diff --git a/CMakeLists.txt b/CMakeLists.txt index 7153f8424..a56ecc46e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -113,7 +113,6 @@ set(DUCKDB_SRC_FILES src/duckdb/ub_src_common_row_operations.cpp src/duckdb/ub_src_common_serializer.cpp src/duckdb/ub_src_common_sort.cpp - src/duckdb/ub_src_common_sorting.cpp src/duckdb/ub_src_common_tree_renderer.cpp src/duckdb/ub_src_common_types.cpp src/duckdb/ub_src_common_types_column.cpp @@ -164,6 +163,7 @@ set(DUCKDB_SRC_FILES src/duckdb/ub_src_function_scalar.cpp src/duckdb/ub_src_function_scalar_date.cpp src/duckdb/ub_src_function_scalar_generic.cpp + src/duckdb/ub_src_function_scalar_geometry.cpp src/duckdb/ub_src_function_scalar_list.cpp src/duckdb/ub_src_function_scalar_map.cpp src/duckdb/ub_src_function_scalar_operator.cpp @@ -377,9 +377,11 @@ set(DUCKDB_SRC_FILES src/duckdb/extension/parquet/parquet_timestamp.cpp src/duckdb/extension/parquet/parquet_float16.cpp src/duckdb/extension/parquet/parquet_statistics.cpp + src/duckdb/extension/parquet/parquet_shredding.cpp + src/duckdb/extension/parquet/parquet_geometry.cpp src/duckdb/extension/parquet/parquet_multi_file_info.cpp src/duckdb/extension/parquet/column_reader.cpp - src/duckdb/extension/parquet/geo_parquet.cpp + src/duckdb/extension/parquet/parquet_field_id.cpp src/duckdb/extension/parquet/parquet_extension.cpp src/duckdb/extension/parquet/column_writer.cpp src/duckdb/extension/parquet/parquet_file_metadata_cache.cpp @@ -389,6 +391,7 @@ set(DUCKDB_SRC_FILES src/duckdb/ub_extension_parquet_reader.cpp src/duckdb/ub_extension_parquet_reader_variant.cpp src/duckdb/ub_extension_parquet_writer.cpp + src/duckdb/ub_extension_parquet_writer_variant.cpp src/duckdb/third_party/parquet/parquet_types.cpp src/duckdb/third_party/thrift/thrift/protocol/TProtocol.cpp src/duckdb/third_party/thrift/thrift/transport/TTransportException.cpp diff --git a/src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp b/src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp index c2cfd61f8..6e55010e2 100644 --- a/src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp +++ b/src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp @@ -272,7 +272,7 @@ unique_ptr BindDecimalAvg(ClientContext &context, AggregateFunctio function = GetAverageAggregate(decimal_type.InternalType()); function.name = "avg"; function.arguments[0] = decimal_type; - function.return_type = LogicalType::DOUBLE; + function.SetReturnType(LogicalType::DOUBLE); return make_uniq( Hugeint::Cast(Hugeint::POWERS_OF_TEN[DecimalType::GetScale(decimal_type)])); } diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp index 40b426390..9e478dedd 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp @@ -90,7 +90,7 @@ AggregateFunction GetApproxCountDistinctFunction(const LogicalType &input_type) AggregateFunction::StateCombine, AggregateFunction::StateFinalize, ApproxCountDistinctSimpleUpdateFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp index d2bdfbe54..15d560321 100644 --- 
a/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp @@ -15,7 +15,7 @@ namespace duckdb { namespace { struct ArgMinMaxStateBase { - ArgMinMaxStateBase() : is_initialized(false), arg_null(false) { + ArgMinMaxStateBase() : is_initialized(false), arg_null(false), val_null(false) { } template @@ -34,6 +34,7 @@ struct ArgMinMaxStateBase { bool is_initialized; bool arg_null; + bool val_null; }; // Out-of-line specialisations @@ -81,7 +82,7 @@ struct ArgMinMaxState : public ArgMinMaxStateBase { } }; -template +template struct ArgMinMaxBase { template static void Initialize(STATE &state) { @@ -94,25 +95,48 @@ struct ArgMinMaxBase { } template - static void Assign(STATE &state, const A_TYPE &x, const B_TYPE &y, const bool x_null, + static void Assign(STATE &state, const A_TYPE &x, const B_TYPE &y, const bool x_null, const bool y_null, AggregateInputData &aggregate_input_data) { - if (IGNORE_NULL) { + D_ASSERT(aggregate_input_data.bind_data); + const auto &bind_data = aggregate_input_data.bind_data->Cast(); + + if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL) { STATE::template AssignValue(state.arg, x, aggregate_input_data); STATE::template AssignValue(state.value, y, aggregate_input_data); } else { state.arg_null = x_null; + state.val_null = y_null; if (!state.arg_null) { STATE::template AssignValue(state.arg, x, aggregate_input_data); } - STATE::template AssignValue(state.value, y, aggregate_input_data); + if (!state.val_null) { + STATE::template AssignValue(state.value, y, aggregate_input_data); + } } } template static void Operation(STATE &state, const A_TYPE &x, const B_TYPE &y, AggregateBinaryInput &binary) { + D_ASSERT(binary.input.bind_data); + const auto &bind_data = binary.input.bind_data->Cast(); if (!state.is_initialized) { - if (IGNORE_NULL || binary.right_mask.RowIsValid(binary.ridx)) { - Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx), binary.input); + if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL && + binary.left_mask.RowIsValid(binary.lidx) && binary.right_mask.RowIsValid(binary.ridx)) { + Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx), + !binary.right_mask.RowIsValid(binary.ridx), binary.input); + state.is_initialized = true; + return; + } + if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ARG_NULL && + binary.right_mask.RowIsValid(binary.ridx)) { + Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx), + !binary.right_mask.RowIsValid(binary.ridx), binary.input); + state.is_initialized = true; + return; + } + if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ANY_NULL) { + Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx), + !binary.right_mask.RowIsValid(binary.ridx), binary.input); state.is_initialized = true; } } else { @@ -122,8 +146,15 @@ struct ArgMinMaxBase { template static void Execute(STATE &state, A_TYPE x_data, B_TYPE y_data, AggregateBinaryInput &binary) { - if ((IGNORE_NULL || binary.right_mask.RowIsValid(binary.ridx)) && COMPARATOR::Operation(y_data, state.value)) { - Assign(state, x_data, y_data, !binary.left_mask.RowIsValid(binary.lidx), binary.input); + D_ASSERT(binary.input.bind_data); + const auto &bind_data = binary.input.bind_data->Cast(); + + if (binary.right_mask.RowIsValid(binary.ridx) && + (state.val_null || COMPARATOR::Operation(y_data, state.value))) { + if (bind_data.null_handling != ArgMinMaxNullHandling::IGNORE_ANY_NULL || + 
binary.left_mask.RowIsValid(binary.lidx)) { + Assign(state, x_data, y_data, !binary.left_mask.RowIsValid(binary.lidx), false, binary.input); + } } } @@ -132,8 +163,10 @@ struct ArgMinMaxBase { if (!source.is_initialized) { return; } - if (!target.is_initialized || COMPARATOR::Operation(source.value, target.value)) { - Assign(target, source.arg, source.value, source.arg_null, aggregate_input_data); + + if (!target.is_initialized || target.val_null || + (!source.val_null && COMPARATOR::Operation(source.value, target.value))) { + Assign(target, source.arg, source.value, source.arg_null, false, aggregate_input_data); target.is_initialized = true; } } @@ -148,17 +181,20 @@ struct ArgMinMaxBase { } static bool IgnoreNull() { - return IGNORE_NULL; + return false; } + template static unique_ptr Bind(ClientContext &context, AggregateFunction &function, vector> &arguments) { if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) { ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->return_type); } function.arguments[0] = arguments[0]->return_type; - function.return_type = arguments[0]->return_type; - return nullptr; + function.SetReturnType(arguments[0]->return_type); + + auto function_data = make_uniq(NULL_HANDLING); + return unique_ptr(std::move(function_data)); } }; @@ -186,12 +222,14 @@ struct GenericArgMinMaxState { } }; -template -struct VectorArgMinMaxBase : ArgMinMaxBase { +template +struct VectorArgMinMaxBase : ArgMinMaxBase { template static void Update(Vector inputs[], AggregateInputData &aggregate_input_data, idx_t input_count, Vector &state_vector, idx_t count) { + D_ASSERT(aggregate_input_data.bind_data); + const auto &bind_data = aggregate_input_data.bind_data->Cast(); + auto &arg = inputs[0]; UnifiedVectorFormat adata; arg.ToUnifiedFormat(count, adata); @@ -213,21 +251,36 @@ struct VectorArgMinMaxBase : ArgMinMaxBase { auto states = UnifiedVectorFormat::GetData(sdata); for (idx_t i = 0; i < count; i++) { - const auto bidx = bdata.sel->get_index(i); - if (!bdata.validity.RowIsValid(bidx)) { - continue; - } - const auto bval = bys[bidx]; + const auto sidx = sdata.sel->get_index(i); + auto &state = *states[sidx]; const auto aidx = adata.sel->get_index(i); const auto arg_null = !adata.validity.RowIsValid(aidx); - if (IGNORE_NULL && arg_null) { + + if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL && arg_null) { continue; } - const auto sidx = sdata.sel->get_index(i); - auto &state = *states[sidx]; - if (!state.is_initialized || COMPARATOR::template Operation(bval, state.value)) { + const auto bidx = bdata.sel->get_index(i); + + if (!bdata.validity.RowIsValid(bidx)) { + if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ANY_NULL && !state.is_initialized) { + state.is_initialized = true; + state.val_null = true; + if (!arg_null) { + if (&state == last_state) { + assign_count--; + } + assign_sel[assign_count++] = UnsafeNumericCast(i); + last_state = &state; + } + } + continue; + } + + const auto bval = bys[bidx]; + + if (!state.is_initialized || state.val_null || COMPARATOR::template Operation(bval, state.value)) { STATE::template AssignValue(state.value, bval, aggregate_input_data); state.arg_null = arg_null; // micro-adaptivity: it is common we overwrite the same state repeatedly @@ -270,8 +323,12 @@ struct VectorArgMinMaxBase : ArgMinMaxBase { if (!source.is_initialized) { return; } - if (!target.is_initialized || COMPARATOR::Operation(source.value, target.value)) { - STATE::template AssignValue(target.value, source.value, 
aggregate_input_data); + if (!target.is_initialized || target.val_null || + (!source.val_null && COMPARATOR::Operation(source.value, target.value))) { + target.val_null = source.val_null; + if (!target.val_null) { + STATE::template AssignValue(target.value, source.value, aggregate_input_data); + } target.arg_null = source.arg_null; if (!target.arg_null) { STATE::template AssignValue(target.arg, source.arg, aggregate_input_data); @@ -290,38 +347,56 @@ struct VectorArgMinMaxBase : ArgMinMaxBase { } } + template static unique_ptr Bind(ClientContext &context, AggregateFunction &function, vector> &arguments) { if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) { ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->return_type); } function.arguments[0] = arguments[0]->return_type; - function.return_type = arguments[0]->return_type; - return nullptr; + function.SetReturnType(arguments[0]->return_type); + + auto function_data = make_uniq(NULL_HANDLING); + return unique_ptr(std::move(function_data)); } }; template -AggregateFunction GetGenericArgMinMaxFunction() { +bind_aggregate_function_t GetBindFunction(const ArgMinMaxNullHandling null_handling) { + switch (null_handling) { + case ArgMinMaxNullHandling::HANDLE_ARG_NULL: + return OP::template Bind; + case ArgMinMaxNullHandling::HANDLE_ANY_NULL: + return OP::template Bind; + default: + return OP::template Bind; + } +} + +template +AggregateFunction GetGenericArgMinMaxFunction(const ArgMinMaxNullHandling null_handling) { using STATE = ArgMinMaxState; + auto bind = GetBindFunction(null_handling); return AggregateFunction( {LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, AggregateFunction::StateSize, AggregateFunction::StateInitialize, OP::template Update, - AggregateFunction::StateCombine, AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, + AggregateFunction::StateCombine, AggregateFunction::StateVoidFinalize, nullptr, bind, AggregateFunction::StateDestroy); } template -AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) { +AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type, + const ArgMinMaxNullHandling null_handling) { #ifndef DUCKDB_SMALLER_BINARY using STATE = ArgMinMaxState; + auto bind = GetBindFunction(null_handling); return AggregateFunction({type, by_type}, type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, OP::template Update, AggregateFunction::StateCombine, - AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, + AggregateFunction::StateVoidFinalize, nullptr, bind, AggregateFunction::StateDestroy); #else - auto function = GetGenericArgMinMaxFunction(); + auto function = GetGenericArgMinMaxFunction(null_handling); function.arguments = {type, by_type}; function.return_type = type; return function; @@ -330,18 +405,19 @@ AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, #ifndef DUCKDB_SMALLER_BINARY template -AggregateFunction GetVectorArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type) { +AggregateFunction GetVectorArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type, + const ArgMinMaxNullHandling null_handling) { switch (by_type.InternalType()) { case PhysicalType::INT32: - return GetVectorArgMinMaxFunctionInternal(by_type, type); + return GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::INT64: - return GetVectorArgMinMaxFunctionInternal(by_type, type); + 
return GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::INT128: - return GetVectorArgMinMaxFunctionInternal(by_type, type); + return GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::DOUBLE: - return GetVectorArgMinMaxFunctionInternal(by_type, type); + return GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::VARCHAR: - return GetVectorArgMinMaxFunctionInternal(by_type, type); + return GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling); default: throw InternalException("Unimplemented arg_min/arg_max aggregate"); } @@ -356,19 +432,21 @@ const vector ArgMaxByTypes() { } template -void AddVectorArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type) { +void AddVectorArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type, + const ArgMinMaxNullHandling null_handling) { auto by_types = ArgMaxByTypes(); for (const auto &by_type : by_types) { #ifndef DUCKDB_SMALLER_BINARY - fun.AddFunction(GetVectorArgMinMaxFunctionBy(by_type, type)); + fun.AddFunction(GetVectorArgMinMaxFunctionBy(by_type, type, null_handling)); #else - fun.AddFunction(GetVectorArgMinMaxFunctionInternal(by_type, type)); + fun.AddFunction(GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling)); #endif } } template -AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) { +AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type, + const ArgMinMaxNullHandling null_handling) { #ifndef DUCKDB_SMALLER_BINARY using STATE = ArgMinMaxState; auto function = @@ -377,9 +455,9 @@ AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const if (type.InternalType() == PhysicalType::VARCHAR || by_type.InternalType() == PhysicalType::VARCHAR) { function.destructor = AggregateFunction::StateDestroy; } - function.bind = OP::Bind; + function.bind = GetBindFunction(null_handling); #else - auto function = GetGenericArgMinMaxFunction(); + auto function = GetGenericArgMinMaxFunction(null_handling); function.arguments = {type, by_type}; function.return_type = type; #endif @@ -388,18 +466,19 @@ AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const #ifndef DUCKDB_SMALLER_BINARY template -AggregateFunction GetArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type) { +AggregateFunction GetArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type, + const ArgMinMaxNullHandling null_handling) { switch (by_type.InternalType()) { case PhysicalType::INT32: - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::INT64: - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::INT128: - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::DOUBLE: - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::VARCHAR: - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); default: throw InternalException("Unimplemented arg_min/arg_max by aggregate"); } @@ -407,37 +486,38 @@ AggregateFunction GetArgMinMaxFunctionBy(const LogicalType &by_type, const 
Logic #endif template -void AddArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type) { +void AddArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type, ArgMinMaxNullHandling null_handling) { auto by_types = ArgMaxByTypes(); for (const auto &by_type : by_types) { #ifndef DUCKDB_SMALLER_BINARY - fun.AddFunction(GetArgMinMaxFunctionBy(by_type, type)); + fun.AddFunction(GetArgMinMaxFunctionBy(by_type, type, null_handling)); #else - fun.AddFunction(GetArgMinMaxFunctionInternal(by_type, type)); + fun.AddFunction(GetArgMinMaxFunctionInternal(by_type, type, null_handling)); #endif } } template -AggregateFunction GetDecimalArgMinMaxFunction(const LogicalType &by_type, const LogicalType &type) { +AggregateFunction GetDecimalArgMinMaxFunction(const LogicalType &by_type, const LogicalType &type, + ArgMinMaxNullHandling null_handling) { D_ASSERT(type.id() == LogicalTypeId::DECIMAL); #ifndef DUCKDB_SMALLER_BINARY switch (type.InternalType()) { case PhysicalType::INT16: - return GetArgMinMaxFunctionBy(by_type, type); + return GetArgMinMaxFunctionBy(by_type, type, null_handling); case PhysicalType::INT32: - return GetArgMinMaxFunctionBy(by_type, type); + return GetArgMinMaxFunctionBy(by_type, type, null_handling); case PhysicalType::INT64: - return GetArgMinMaxFunctionBy(by_type, type); + return GetArgMinMaxFunctionBy(by_type, type, null_handling); default: - return GetArgMinMaxFunctionBy(by_type, type); + return GetArgMinMaxFunctionBy(by_type, type, null_handling); } #else - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); #endif } -template +template unique_ptr BindDecimalArgMinMax(ClientContext &context, AggregateFunction &function, vector> &arguments) { auto decimal_type = arguments[0]->return_type; @@ -469,51 +549,69 @@ unique_ptr BindDecimalArgMinMax(ClientContext &context, AggregateF } auto name = std::move(function.name); - function = GetDecimalArgMinMaxFunction(by_type, decimal_type); + function = GetDecimalArgMinMaxFunction(by_type, decimal_type, NULL_HANDLING); function.name = std::move(name); - function.return_type = decimal_type; - return nullptr; + function.SetReturnType(decimal_type); + + auto function_data = make_uniq(NULL_HANDLING); + return unique_ptr(std::move(function_data)); } template -void AddDecimalArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &by_type) { - fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr, - nullptr, nullptr, nullptr, nullptr, BindDecimalArgMinMax)); +void AddDecimalArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &by_type, + const ArgMinMaxNullHandling null_handling) { + switch (null_handling) { + case ArgMinMaxNullHandling::IGNORE_ANY_NULL: + fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, + BindDecimalArgMinMax)); + break; + case ArgMinMaxNullHandling::HANDLE_ARG_NULL: + fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, + BindDecimalArgMinMax)); + break; + case ArgMinMaxNullHandling::HANDLE_ANY_NULL: + fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, + BindDecimalArgMinMax)); + break; + } } template -void AddGenericArgMinMaxFunction(AggregateFunctionSet &fun) { - 
fun.AddFunction(GetGenericArgMinMaxFunction()); +void AddGenericArgMinMaxFunction(AggregateFunctionSet &fun, const ArgMinMaxNullHandling null_handling) { + fun.AddFunction(GetGenericArgMinMaxFunction(null_handling)); } -template -void AddArgMinMaxFunctions(AggregateFunctionSet &fun) { - using GENERIC_VECTOR_OP = VectorArgMinMaxBase>; +template +void AddArgMinMaxFunctions(AggregateFunctionSet &fun, const ArgMinMaxNullHandling null_handling) { + using GENERIC_VECTOR_OP = VectorArgMinMaxBase>; #ifndef DUCKDB_SMALLER_BINARY - using OP = ArgMinMaxBase; - using VECTOR_OP = VectorArgMinMaxBase; + using OP = ArgMinMaxBase; + using VECTOR_OP = VectorArgMinMaxBase; #else using OP = GENERIC_VECTOR_OP; using VECTOR_OP = GENERIC_VECTOR_OP; #endif - AddArgMinMaxFunctionBy(fun, LogicalType::INTEGER); - AddArgMinMaxFunctionBy(fun, LogicalType::BIGINT); - AddArgMinMaxFunctionBy(fun, LogicalType::DOUBLE); - AddArgMinMaxFunctionBy(fun, LogicalType::VARCHAR); - AddArgMinMaxFunctionBy(fun, LogicalType::DATE); - AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP); - AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP_TZ); - AddArgMinMaxFunctionBy(fun, LogicalType::BLOB); + AddArgMinMaxFunctionBy(fun, LogicalType::INTEGER, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::BIGINT, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::DOUBLE, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::VARCHAR, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::DATE, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP_TZ, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::BLOB, null_handling); auto by_types = ArgMaxByTypes(); for (const auto &by_type : by_types) { - AddDecimalArgMinMaxFunctionBy(fun, by_type); + AddDecimalArgMinMaxFunctionBy(fun, by_type, null_handling); } - AddVectorArgMinMaxFunctionBy(fun, LogicalType::ANY); + AddVectorArgMinMaxFunctionBy(fun, LogicalType::ANY, null_handling); // we always use LessThan when using sort keys because the ORDER_TYPE takes care of selecting the lowest or highest - AddGenericArgMinMaxFunction(fun); + AddGenericArgMinMaxFunction(fun, null_handling); } //------------------------------------------------------------------------------ @@ -547,6 +645,8 @@ class ArgMinMaxNState { template void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector, idx_t count) { + D_ASSERT(aggr_input.bind_data); + const auto &bind_data = aggr_input.bind_data->Cast(); auto &val_vector = inputs[0]; auto &arg_vector = inputs[1]; @@ -560,8 +660,8 @@ void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t inp auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(val_vector, count); auto arg_extra_state = STATE::ARG_TYPE::CreateExtraState(arg_vector, count); - STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format); - STATE::ARG_TYPE::PrepareData(arg_vector, count, arg_extra_state, arg_format); + STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format, bind_data.nulls_last); + STATE::ARG_TYPE::PrepareData(arg_vector, count, arg_extra_state, arg_format, bind_data.nulls_last); n_vector.ToUnifiedFormat(count, n_format); state_vector.ToUnifiedFormat(count, state_format); @@ -571,9 +671,16 @@ void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t inp for (idx_t i = 0; i < count; i++) { const auto arg_idx = arg_format.sel->get_index(i); const auto val_idx 
= val_format.sel->get_index(i); - if (!arg_format.validity.RowIsValid(arg_idx) || !val_format.validity.RowIsValid(val_idx)) { + + if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL && + (!arg_format.validity.RowIsValid(arg_idx) || !val_format.validity.RowIsValid(val_idx))) { + continue; + } + if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ARG_NULL && + !val_format.validity.RowIsValid(val_idx)) { continue; } + const auto state_idx = state_format.sel->get_index(i); auto &state = *states[state_idx]; @@ -671,7 +778,77 @@ void SpecializeArgMinMaxNFunction(PhysicalType val_type, PhysicalType arg_type, } } -template +template +void SpecializeArgMinMaxNullNFunction(AggregateFunction &function) { + using STATE = ArgMinMaxNState; + using OP = MinMaxNOperation; + + function.state_size = AggregateFunction::StateSize; + function.initialize = AggregateFunction::StateInitialize; + function.combine = AggregateFunction::StateCombine; + function.destructor = AggregateFunction::StateDestroy; + + function.finalize = MinMaxNOperation::Finalize; + function.update = ArgMinMaxNUpdate; +} + +template +void SpecializeArgMinMaxNullNFunction(PhysicalType arg_type, AggregateFunction &function) { + switch (arg_type) { +#ifndef DUCKDB_SMALLER_BINARY + case PhysicalType::VARCHAR: + SpecializeArgMinMaxNullNFunction(function); + break; + case PhysicalType::INT32: + SpecializeArgMinMaxNullNFunction, COMPARATOR>(function); + break; + case PhysicalType::INT64: + SpecializeArgMinMaxNullNFunction, COMPARATOR>(function); + break; + case PhysicalType::FLOAT: + SpecializeArgMinMaxNullNFunction, COMPARATOR>(function); + break; + case PhysicalType::DOUBLE: + SpecializeArgMinMaxNullNFunction, COMPARATOR>(function); + break; +#endif + default: + SpecializeArgMinMaxNullNFunction(function); + break; + } +} + +template +void SpecializeArgMinMaxNullNFunction(PhysicalType val_type, PhysicalType arg_type, AggregateFunction &function) { + switch (val_type) { +#ifndef DUCKDB_SMALLER_BINARY + case PhysicalType::VARCHAR: + SpecializeArgMinMaxNullNFunction(arg_type, function); + break; + case PhysicalType::INT32: + SpecializeArgMinMaxNullNFunction, NULLS_LAST, COMPARATOR>(arg_type, + function); + break; + case PhysicalType::INT64: + SpecializeArgMinMaxNullNFunction, NULLS_LAST, COMPARATOR>(arg_type, + function); + break; + case PhysicalType::FLOAT: + SpecializeArgMinMaxNullNFunction, NULLS_LAST, COMPARATOR>(arg_type, + function); + break; + case PhysicalType::DOUBLE: + SpecializeArgMinMaxNullNFunction, NULLS_LAST, COMPARATOR>(arg_type, + function); + break; +#endif + default: + SpecializeArgMinMaxNullNFunction(arg_type, function); + break; + } +} + +template unique_ptr ArgMinMaxNBind(ClientContext &context, AggregateFunction &function, vector> &arguments) { for (auto &arg : arguments) { @@ -682,19 +859,24 @@ unique_ptr ArgMinMaxNBind(ClientContext &context, AggregateFunctio const auto val_type = arguments[0]->return_type.InternalType(); const auto arg_type = arguments[1]->return_type.InternalType(); + function.SetReturnType(LogicalType::LIST(arguments[0]->return_type)); // Specialize the function based on the input types - SpecializeArgMinMaxNFunction(val_type, arg_type, function); + auto function_data = make_uniq(NULL_HANDLING, NULLS_LAST); + if (NULL_HANDLING != ArgMinMaxNullHandling::IGNORE_ANY_NULL) { + SpecializeArgMinMaxNullNFunction(val_type, arg_type, function); + } else { + SpecializeArgMinMaxNFunction(val_type, arg_type, function); + } - function.return_type = 
LogicalType::LIST(arguments[0]->return_type); - return nullptr; + return unique_ptr(std::move(function_data)); } -template +template void AddArgMinMaxNFunction(AggregateFunctionSet &set) { AggregateFunction function({LogicalTypeId::ANY, LogicalTypeId::ANY, LogicalType::BIGINT}, LogicalType::LIST(LogicalType::ANY), nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, ArgMinMaxNBind); + nullptr, ArgMinMaxNBind); return set.AddFunction(function); } @@ -707,27 +889,41 @@ void AddArgMinMaxNFunction(AggregateFunctionSet &set) { AggregateFunctionSet ArgMinFun::GetFunctions() { AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); - AddArgMinMaxNFunction(fun); + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::IGNORE_ANY_NULL); + AddArgMinMaxNFunction(fun); return fun; } AggregateFunctionSet ArgMaxFun::GetFunctions() { AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); - AddArgMinMaxNFunction(fun); + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::IGNORE_ANY_NULL); + AddArgMinMaxNFunction(fun); return fun; } AggregateFunctionSet ArgMinNullFun::GetFunctions() { AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::HANDLE_ARG_NULL); return fun; } AggregateFunctionSet ArgMaxNullFun::GetFunctions() { AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::HANDLE_ARG_NULL); + return fun; +} + +AggregateFunctionSet ArgMinNullsLastFun::GetFunctions() { + AggregateFunctionSet fun; + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::HANDLE_ANY_NULL); + AddArgMinMaxNFunction(fun); + return fun; +} + +AggregateFunctionSet ArgMaxNullsLastFun::GetFunctions() { + AggregateFunctionSet fun; + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::HANDLE_ANY_NULL); + AddArgMinMaxNFunction(fun); return fun; } diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp index 168d3a539..fccfd0ac8 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp @@ -166,7 +166,6 @@ struct BitStringBitwiseOperation : public BitwiseOperation { }; struct BitStringAndOperation : public BitStringBitwiseOperation { - template static void Execute(STATE &state, INPUT_TYPE input) { Bit::BitwiseAnd(input, state.value, state.value); @@ -174,7 +173,6 @@ struct BitStringAndOperation : public BitStringBitwiseOperation { }; struct BitStringOrOperation : public BitStringBitwiseOperation { - template static void Execute(STATE &state, INPUT_TYPE input) { Bit::BitwiseOr(input, state.value, state.value); diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp index fad7550d8..f3bfb115b 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp @@ -235,7 +235,6 @@ idx_t BitStringAggOperation::GetRange(uhugeint_t min, uhugeint_t max) { unique_ptr BitstringPropagateStats(ClientContext &context, BoundAggregateExpression &expr, AggregateStatisticsInput &input) { - if (NumericStats::HasMinMax(input.child_stats[0])) { auto &bind_agg_data = input.bind_data->Cast(); bind_agg_data.min = NumericStats::Min(input.child_stats[0]); diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp 
b/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp index aa551eca5..d1ca6b694 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp @@ -106,7 +106,7 @@ AggregateFunction KurtosisFun::GetFunction() { auto result = AggregateFunction::UnaryAggregate>( LogicalType::DOUBLE, LogicalType::DOUBLE); - result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; + result.SetFallible(); return result; } @@ -114,7 +114,7 @@ AggregateFunction KurtosisPopFun::GetFunction() { auto result = AggregateFunction::UnaryAggregate>( LogicalType::DOUBLE, LogicalType::DOUBLE); - result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; + result.SetFallible(); return result; } diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp index ddbecbf28..7c5aa8764 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp @@ -120,6 +120,12 @@ unique_ptr StringAggBind(ClientContext &context, AggregateFunction return make_uniq(","); } D_ASSERT(arguments.size() == 2); + // Check if any argument is of UNKNOWN type (parameter not yet bound) + for (auto &arg : arguments) { + if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + } if (arguments[1]->HasParameter()) { throw ParameterNotResolvedException(); } diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp index bfea19644..5b821d44c 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp @@ -84,7 +84,7 @@ void SumNoOverflowSerialize(Serializer &serializer, const optional_ptr SumNoOverflowDeserialize(Deserializer &deserializer, AggregateFunction &function) { - function.return_type = deserializer.Get(); + function.SetReturnType(deserializer.Get()); return nullptr; } @@ -207,7 +207,7 @@ unique_ptr BindDecimalSum(ClientContext &context, AggregateFunctio function = GetSumAggregate(decimal_type.InternalType()); function.name = "sum"; function.arguments[0] = decimal_type; - function.return_type = LogicalType::DECIMAL(Decimal::MAX_WIDTH_DECIMAL, DecimalType::GetScale(decimal_type)); + function.SetReturnType(LogicalType::DECIMAL(Decimal::MAX_WIDTH_DECIMAL, DecimalType::GetScale(decimal_type))); function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; return nullptr; } diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp index 641a5010e..5199b514b 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp @@ -398,7 +398,7 @@ unique_ptr ApproxTopKBind(ClientContext &context, AggregateFunctio function.update = ApproxTopKUpdate; function.finalize = ApproxTopKFinalize; } - function.return_type = LogicalType::LIST(arguments[0]->return_type); + function.SetReturnType(LogicalType::LIST(arguments[0]->return_type)); return nullptr; } diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp 
b/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp index 35336383b..49c87ddbb 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp @@ -294,7 +294,6 @@ AggregateFunction GetApproximateQuantileAggregate(const LogicalType &type) { template struct ApproxQuantileListOperation : public ApproxQuantileOperation { - template static void Finalize(STATE &state, RESULT_TYPE &target, AggregateFinalizeData &finalize_data) { if (state.pos == 0) { diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp index 9835e44b3..41e965e64 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp @@ -57,7 +57,6 @@ struct QuantileReuseUpdater { }; void ReuseIndexes(idx_t *index, const SubFrames &currs, const SubFrames &prevs) { - // Copy overlapping indices by scanning the previous set and copying down into holes. // We copy instead of leaving gaps in case there are fewer values in the current frame. FrameSet prev_set(prevs); @@ -317,7 +316,7 @@ AggregateFunction GetMedianAbsoluteDeviationAggregateFunctionInternal(const Logi AggregateFunction GetMedianAbsoluteDeviationAggregateFunction(const LogicalType &type) { auto result = GetMedianAbsoluteDeviationAggregateFunctionInternal(type); - result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; + result.SetFallible(); return result; } diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp index dc09dd32f..45a8a6f60 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp @@ -518,7 +518,7 @@ AggregateFunction GetTypedEntropyFunction(const LogicalType &type) { auto func = AggregateFunction::UnaryAggregateDestructor( type, LogicalType::DOUBLE); - func.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + func.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return func; } @@ -530,7 +530,7 @@ AggregateFunction GetFallbackEntropyFunction(const LogicalType &type) { AggregateSortKeyHelpers::UnaryUpdate, AggregateFunction::StateCombine, AggregateFunction::StateFinalize, nullptr); func.destructor = AggregateFunction::StateDestroy; - func.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + func.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return func; } diff --git a/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp b/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp index e1af92578..790c60d17 100644 --- a/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp +++ b/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp @@ -61,7 +61,6 @@ struct StringMapType { template void HistogramUpdateFunction(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector, idx_t count) { - D_ASSERT(input_count == 1); auto &input = inputs[0]; @@ -209,14 +208,13 @@ AggregateFunction GetHistogramFunction(const LogicalType &type) { template unique_ptr HistogramBindFunction(ClientContext &context, AggregateFunction &function, vector> &arguments) { - D_ASSERT(arguments.size() == 1); if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { throw ParameterNotResolvedException(); } 
function = GetHistogramFunction(arguments[0]->return_type); - return make_uniq(function.return_type); + return make_uniq(function.GetReturnType()); } } // namespace diff --git a/src/duckdb/extension/core_functions/aggregate/nested/list.cpp b/src/duckdb/extension/core_functions/aggregate/nested/list.cpp index 5771e14eb..92916c71b 100644 --- a/src/duckdb/extension/core_functions/aggregate/nested/list.cpp +++ b/src/duckdb/extension/core_functions/aggregate/nested/list.cpp @@ -47,7 +47,6 @@ struct ListFunction { void ListUpdateFunction(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &state_vector, idx_t count) { - D_ASSERT(input_count == 1); auto &input = inputs[0]; RecursiveUnifiedVectorFormat input_data; @@ -75,7 +74,6 @@ void ListAbsorbFunction(Vector &states_vector, Vector &combined, AggregateInputD auto combined_ptr = FlatVector::GetData(combined); for (idx_t i = 0; i < count; i++) { - auto &state = *states_ptr[states_data.sel->get_index(i)]; if (state.linked_list.total_capacity == 0) { // NULL, no need to append @@ -98,7 +96,6 @@ void ListAbsorbFunction(Vector &states_vector, Vector &combined, AggregateInputD void ListFinalize(Vector &states_vector, AggregateInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { - UnifiedVectorFormat states_data; states_vector.ToUnifiedFormat(count, states_data); auto states = UnifiedVectorFormat::GetData(states_data); @@ -132,7 +129,6 @@ void ListFinalize(Vector &states_vector, AggregateInputData &aggr_input_data, Ve ListVector::Reserve(result, total_len); auto &result_child = ListVector::GetEntry(result); for (idx_t i = 0; i < count; i++) { - auto &state = *states[states_data.sel->get_index(i)]; const auto rid = i + offset; if (state.linked_list.total_capacity == 0) { @@ -147,7 +143,6 @@ void ListFinalize(Vector &states_vector, AggregateInputData &aggr_input_data, Ve } void ListCombineFunction(Vector &states_vector, Vector &combined, AggregateInputData &aggr_input_data, idx_t count) { - // Can we use destructive combining? 
if (aggr_input_data.combine_type == AggregateCombineType::ALLOW_DESTRUCTIVE) { ListAbsorbFunction(states_vector, combined, aggr_input_data, count); @@ -182,9 +177,8 @@ void ListCombineFunction(Vector &states_vector, Vector &combined, AggregateInput unique_ptr ListBindFunction(ClientContext &context, AggregateFunction &function, vector> &arguments) { - - function.return_type = LogicalType::LIST(arguments[0]->return_type); - return make_uniq(function.return_type); + function.SetReturnType(LogicalType::LIST(arguments[0]->return_type)); + return make_uniq(function.GetReturnType()); } } // namespace diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp index 9215fcfb8..89962af8b 100644 --- a/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp +++ b/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp @@ -11,7 +11,7 @@ AggregateFunction RegrCountFun::GetFunction() { auto regr_count = AggregateFunction::BinaryAggregate( LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::UINTEGER); regr_count.name = "regr_count"; - regr_count.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + regr_count.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return regr_count; } diff --git a/src/duckdb/extension/core_functions/function_list.cpp b/src/duckdb/extension/core_functions/function_list.cpp index a8ba52658..810b020ab 100644 --- a/src/duckdb/extension/core_functions/function_list.cpp +++ b/src/duckdb/extension/core_functions/function_list.cpp @@ -73,8 +73,10 @@ static const StaticFunctionDefinition core_functions[] = { DUCKDB_AGGREGATE_FUNCTION(ApproxTopKFun), DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxFun), DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxNullFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxNullsLastFun), DUCKDB_AGGREGATE_FUNCTION_SET(ArgMinFun), DUCKDB_AGGREGATE_FUNCTION_SET(ArgMinNullFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ArgMinNullsLastFun), DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArgmaxFun), DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArgminFun), DUCKDB_AGGREGATE_FUNCTION_ALIAS(ArrayAggFun), diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp index b2626ee27..e626e117c 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp @@ -9,7 +9,8 @@ #pragma once #include "duckdb/function/aggregate_function.hpp" -#include +#include +#include namespace duckdb { diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp index 39bc9459c..4add0a00d 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp @@ -57,6 +57,16 @@ struct ArgMinNullFun { static AggregateFunctionSet GetFunctions(); }; +struct ArgMinNullsLastFun { + static constexpr const char *Name = "arg_min_nulls_last"; + static constexpr const char *Parameters = "arg,val,N"; + static constexpr const char *Description = "Finds the rows with N minimum vals, including nulls. 
Calculates the arg expression at that row."; + static constexpr const char *Example = "arg_min_nulls_last(A, B, N)"; + static constexpr const char *Categories = ""; + + static AggregateFunctionSet GetFunctions(); +}; + struct ArgMaxFun { static constexpr const char *Name = "arg_max"; static constexpr const char *Parameters = "arg,val"; @@ -89,6 +99,16 @@ struct ArgMaxNullFun { static AggregateFunctionSet GetFunctions(); }; +struct ArgMaxNullsLastFun { + static constexpr const char *Name = "arg_max_nulls_last"; + static constexpr const char *Parameters = "arg,val,N"; + static constexpr const char *Description = "Finds the rows with N maximum vals, including nulls. Calculates the arg expression at that row."; + static constexpr const char *Example = "arg_max_nulls_last(A, B, N)"; + static constexpr const char *Categories = ""; + + static AggregateFunctionSet GetFunctions(); +}; + struct BitAndFun { static constexpr const char *Name = "bit_and"; static constexpr const char *Parameters = "arg"; diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp index 2c796b2e1..c36d0d4fc 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp @@ -300,7 +300,6 @@ struct QuantileIncluded { }; struct QuantileSortTree { - unique_ptr index_tree; QuantileSortTree(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition) { diff --git a/src/duckdb/extension/core_functions/include/core_functions/array_kernels.hpp b/src/duckdb/extension/core_functions/include/core_functions/array_kernels.hpp index dd6e29153..ddcdb92b6 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/array_kernels.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/array_kernels.hpp @@ -13,7 +13,6 @@ struct InnerProductOp { template static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, const idx_t count) { - TYPE result = 0; auto lhs_ptr = lhs_data; @@ -43,7 +42,6 @@ struct CosineSimilarityOp { template static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, const idx_t count) { - TYPE distance = 0; TYPE norm_l = 0; TYPE norm_r = 0; @@ -78,7 +76,6 @@ struct DistanceSquaredOp { template static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, const idx_t count) { - TYPE distance = 0; auto l_ptr = lhs_data; diff --git a/src/duckdb/extension/core_functions/lambda_functions.cpp b/src/duckdb/extension/core_functions/lambda_functions.cpp index f1aa80af7..89356921c 100644 --- a/src/duckdb/extension/core_functions/lambda_functions.cpp +++ b/src/duckdb/extension/core_functions/lambda_functions.cpp @@ -18,7 +18,6 @@ struct LambdaExecuteInfo { LambdaExecuteInfo(ClientContext &context, const Expression &lambda_expr, const DataChunk &args, const bool has_index, const Vector &child_vector) : has_index(has_index) { - expr_executor = make_uniq(context, lambda_expr); // get the input types for the input chunk @@ -103,7 +102,6 @@ struct ListFilterFunctor { //!
Uses the lambda vector to filter the incoming list and to append the filtered list to the result vector static void AppendResult(Vector &result, Vector &lambda_vector, const idx_t elem_cnt, list_entry_t *result_entries, ListFilterInfo &info, LambdaExecuteInfo &execute_info) { - idx_t count = 0; SelectionVector sel(elem_cnt); UnifiedVectorFormat lambda_data; @@ -184,7 +182,6 @@ LambdaFunctions::GetMutableColumnInfo(vector &data) static void ExecuteExpression(const idx_t elem_cnt, const LambdaFunctions::ColumnInfo &column_info, const vector &column_infos, const Vector &index_vector, LambdaExecuteInfo &info) { - info.input_chunk.SetCardinality(elem_cnt); info.lambda_chunk.SetCardinality(elem_cnt); @@ -203,7 +200,6 @@ static void ExecuteExpression(const idx_t elem_cnt, const LambdaFunctions::Colum // (slice and) reference the other columns vector slices; for (idx_t i = 0; i < column_infos.size(); i++) { - if (column_infos[i].vector.get().GetVectorType() == VectorType::CONSTANT_VECTOR) { // only reference constant vectorsl info.input_chunk.data[i + slice_offset].Reference(column_infos[i].vector); @@ -273,7 +269,6 @@ LogicalType LambdaFunctions::BindBinaryChildren(const vector &funct template static void ExecuteLambda(DataChunk &args, ExpressionState &state, Vector &result) { - bool result_is_null = false; LambdaFunctions::LambdaInfo info(args, state, result, result_is_null); if (result_is_null) { @@ -302,7 +297,6 @@ static void ExecuteLambda(DataChunk &args, ExpressionState &state, Vector &resul idx_t elem_cnt = 0; idx_t offset = 0; for (idx_t row_idx = 0; row_idx < info.row_count; row_idx++) { - auto list_idx = info.list_column_format.sel->get_index(row_idx); const auto &list_entry = info.list_entries[list_idx]; @@ -322,10 +316,8 @@ static void ExecuteLambda(DataChunk &args, ExpressionState &state, Vector &resul // iterate the elements of the current list and create the corresponding selection vectors for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { - // reached STANDARD_VECTOR_SIZE elements if (elem_cnt == STANDARD_VECTOR_SIZE) { - execute_info.lambda_chunk.Reset(); ExecuteExpression(elem_cnt, child_info, info.column_infos, index_vector, execute_info); auto &lambda_vector = execute_info.lambda_chunk.data[0]; @@ -368,8 +360,8 @@ unique_ptr LambdaFunctions::ListLambdaPrepareBind(vectorreturn_type.id() == LogicalTypeId::SQLNULL) { bound_function.arguments[0] = LogicalType::SQLNULL; - bound_function.return_type = LogicalType::SQLNULL; - return make_uniq(bound_function.return_type, nullptr); + bound_function.SetReturnType(LogicalType::SQLNULL); + return make_uniq(bound_function.GetReturnType(), nullptr); } // prepared statements if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { @@ -393,7 +385,7 @@ unique_ptr LambdaFunctions::ListLambdaBind(ClientContext &context, auto &bound_lambda_expr = arguments[1]->Cast(); auto lambda_expr = std::move(bound_lambda_expr.lambda_expr); - return make_uniq(bound_function.return_type, std::move(lambda_expr), has_index); + return make_uniq(bound_function.GetReturnType(), std::move(lambda_expr), has_index); } void LambdaFunctions::ListTransformFunction(DataChunk &args, ExpressionState &state, Vector &result) { diff --git a/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp b/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp index ecc1ce97f..a5d7067f1 100644 --- a/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp +++ 
b/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp @@ -6,14 +6,13 @@ namespace duckdb { static unique_ptr ArrayGenericBinaryBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - const auto &lhs_type = arguments[0]->return_type; const auto &rhs_type = arguments[1]->return_type; if (lhs_type.IsUnknown() && rhs_type.IsUnknown()) { bound_function.arguments[0] = rhs_type; bound_function.arguments[1] = lhs_type; - bound_function.return_type = LogicalType::UNKNOWN; + bound_function.SetReturnType(LogicalType::UNKNOWN); return nullptr; } @@ -212,11 +211,11 @@ static void AddArrayFoldFunction(ScalarFunctionSet &set, const LogicalType &type const auto array = LogicalType::ARRAY(type, optional_idx()); if (type.id() == LogicalTypeId::FLOAT) { ScalarFunction function({array, array}, type, ArrayGenericFold, ArrayGenericBinaryBind); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); set.AddFunction(function); } else if (type.id() == LogicalTypeId::DOUBLE) { ScalarFunction function({array, array}, type, ArrayGenericFold, ArrayGenericBinaryBind); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); set.AddFunction(function); } else { throw NotImplementedException("Array function not implemented for type %s", type.ToString()); @@ -273,7 +272,7 @@ ScalarFunctionSet ArrayCrossProductFun::GetFunctions() { set.AddFunction( ScalarFunction({double_array, double_array}, double_array, ArrayFixedCombine)); for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return set; } diff --git a/src/duckdb/extension/core_functions/scalar/array/array_value.cpp b/src/duckdb/extension/core_functions/scalar/array/array_value.cpp index ec8500a87..27f8b8969 100644 --- a/src/duckdb/extension/core_functions/scalar/array/array_value.cpp +++ b/src/duckdb/extension/core_functions/scalar/array/array_value.cpp @@ -62,8 +62,8 @@ unique_ptr ArrayValueBind(ClientContext &context, ScalarFunction & // this is more for completeness reasons bound_function.varargs = child_type; - bound_function.return_type = LogicalType::ARRAY(child_type, arguments.size()); - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::ARRAY(child_type, arguments.size())); + return make_uniq(bound_function.GetReturnType()); } unique_ptr ArrayValueStats(ClientContext &context, FunctionStatisticsInput &input) { @@ -84,7 +84,7 @@ ScalarFunction ArrayValueFun::GetFunction() { ScalarFunction fun("array_value", {}, LogicalTypeId::ARRAY, ArrayValueFunction, ArrayValueBind, nullptr, ArrayValueStats); fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp b/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp index fdd499e22..b3e26a7f1 100644 --- a/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp +++ b/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp @@ -47,7 +47,7 @@ ScalarFunctionSet BitStringFun::GetFunctions() { bitstring.AddFunction( ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitStringFunction)); for (auto &func : bitstring.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return bitstring; } @@ -72,7 +72,7 @@ struct GetBitOperator { ScalarFunction GetBitFun::GetFunction() { ScalarFunction func({LogicalType::BIT, 
LogicalType::INTEGER}, LogicalType::INTEGER, ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); return func; } @@ -100,7 +100,7 @@ static void SetBitOperation(DataChunk &args, ExpressionState &state, Vector &res ScalarFunction SetBitFun::GetFunction() { ScalarFunction function({LogicalType::BIT, LogicalType::INTEGER, LogicalType::INTEGER}, LogicalType::BIT, SetBitOperation); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } diff --git a/src/duckdb/extension/core_functions/scalar/blob/base64.cpp b/src/duckdb/extension/core_functions/scalar/blob/base64.cpp index d2c372114..77ae51731 100644 --- a/src/duckdb/extension/core_functions/scalar/blob/base64.cpp +++ b/src/duckdb/extension/core_functions/scalar/blob/base64.cpp @@ -43,7 +43,7 @@ ScalarFunction ToBase64Fun::GetFunction() { ScalarFunction FromBase64Fun::GetFunction() { ScalarFunction function({LogicalType::VARCHAR}, LogicalType::BLOB, Base64DecodeFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } diff --git a/src/duckdb/extension/core_functions/scalar/blob/encode.cpp b/src/duckdb/extension/core_functions/scalar/blob/encode.cpp index b9bfa986a..08b594b3d 100644 --- a/src/duckdb/extension/core_functions/scalar/blob/encode.cpp +++ b/src/duckdb/extension/core_functions/scalar/blob/encode.cpp @@ -39,7 +39,7 @@ ScalarFunction EncodeFun::GetFunction() { ScalarFunction DecodeFun::GetFunction() { ScalarFunction function({LogicalType::BLOB}, LogicalType::VARCHAR, DecodeFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } diff --git a/src/duckdb/extension/core_functions/scalar/date/current.cpp b/src/duckdb/extension/core_functions/scalar/date/current.cpp index aa041f627..bf928618d 100644 --- a/src/duckdb/extension/core_functions/scalar/date/current.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/current.cpp @@ -23,7 +23,7 @@ static void CurrentTimestampFunction(DataChunk &input, ExpressionState &state, V ScalarFunction GetCurrentTimestampFun::GetFunction() { ScalarFunction current_timestamp({}, LogicalType::TIMESTAMP_TZ, CurrentTimestampFunction); - current_timestamp.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + current_timestamp.SetStability(FunctionStability::CONSISTENT_WITHIN_QUERY); return current_timestamp; } diff --git a/src/duckdb/extension/core_functions/scalar/date/date_part.cpp b/src/duckdb/extension/core_functions/scalar/date/date_part.cpp index 7ced59dcb..5f0325938 100644 --- a/src/duckdb/extension/core_functions/scalar/date/date_part.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/date_part.cpp @@ -1772,7 +1772,7 @@ unique_ptr DatePartBind(ClientContext &context, ScalarFunction &bo arguments.erase(arguments.begin()); bound_function.arguments.erase(bound_function.arguments.begin()); bound_function.name = "julian"; - bound_function.return_type = LogicalType::DOUBLE; + bound_function.SetReturnType(LogicalType::DOUBLE); switch (arguments[0]->return_type.id()) { case LogicalType::TIMESTAMP: case LogicalType::TIMESTAMP_S: @@ -1793,7 +1793,7 @@ unique_ptr DatePartBind(ClientContext &context, ScalarFunction &bo arguments.erase(arguments.begin()); bound_function.arguments.erase(bound_function.arguments.begin()); bound_function.name = "epoch"; - bound_function.return_type = LogicalType::DOUBLE; + bound_function.SetReturnType(LogicalType::DOUBLE); switch (arguments[0]->return_type.id()) { case LogicalType::TIMESTAMP: 
case LogicalType::TIMESTAMP_S: @@ -1844,7 +1844,7 @@ ScalarFunctionSet GetGenericDatePartFunction(scalar_function_t date_func, scalar nullptr, ts_stats, DATE_CACHE)); operator_set.AddFunction(ScalarFunction({LogicalType::INTERVAL}, LogicalType::BIGINT, std::move(interval_func))); for (auto &func : operator_set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return operator_set; } @@ -1974,8 +1974,8 @@ struct StructDatePart { } Function::EraseArgument(bound_function, arguments, 0); - bound_function.return_type = LogicalType::STRUCT(struct_children); - return make_uniq(bound_function.return_type, part_codes); + bound_function.SetReturnType(LogicalType::STRUCT(struct_children)); + return make_uniq(bound_function.GetReturnType(), part_codes); } template @@ -2168,7 +2168,7 @@ ScalarFunctionSet QuarterFun::GetFunctions() { ScalarFunctionSet DayOfWeekFun::GetFunctions() { auto set = GetDatePartFunction(); for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return set; } @@ -2203,7 +2203,7 @@ ScalarFunctionSet TimezoneFun::GetFunctions() { operator_set.AddFunction(function); for (auto &func : operator_set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return operator_set; @@ -2408,7 +2408,7 @@ ScalarFunctionSet DatePartFun::GetFunctions() { date_part.AddFunction(StructDatePart::GetFunction(LogicalType::TIME_TZ)); for (auto &func : date_part.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return date_part; diff --git a/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp b/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp index dab1c8231..6e3d450d7 100644 --- a/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp @@ -37,7 +37,6 @@ struct DateSub { struct MonthOperator { template static inline TR Operation(TA start_ts, TB end_ts) { - if (start_ts > end_ts) { return -MonthOperator::Operation(end_ts, start_ts); } diff --git a/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp b/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp index 819efbac4..ef6c9ff26 100644 --- a/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp @@ -703,7 +703,7 @@ unique_ptr DateTruncBind(ClientContext &context, ScalarFunction &b default: throw NotImplementedException("Temporal argument type for DATETRUNC"); } - bound_function.return_type = LogicalType::DATE; + bound_function.SetReturnType(LogicalType::DATE); break; default: switch (bound_function.arguments[1].id()) { @@ -733,7 +733,7 @@ ScalarFunctionSet DateTruncFun::GetFunctions() { date_trunc.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::INTERVAL}, LogicalType::INTERVAL, DateTruncFunction)); for (auto &func : date_trunc.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return date_trunc; } diff --git a/src/duckdb/extension/core_functions/scalar/date/make_date.cpp b/src/duckdb/extension/core_functions/scalar/date/make_date.cpp index 189d2a229..d7f1eaf99 100644 --- a/src/duckdb/extension/core_functions/scalar/date/make_date.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/make_date.cpp @@ -65,7 +65,6 @@ void ExecuteStructMakeDate(DataChunk &input, ExpressionState &state, Vector &res struct MakeTimeOperator { template static RESULT_TYPE Operation(HH hh, MM mm, SS ss) { - 
auto hh_32 = Cast::Operation(hh); auto mm_32 = Cast::Operation(mm); // Have to check this separately because safe casting of DOUBLE => INT32 can round. @@ -149,7 +148,7 @@ ScalarFunctionSet MakeDateFun::GetFunctions() { make_date.AddFunction( ScalarFunction({LogicalType::STRUCT(make_date_children)}, LogicalType::DATE, ExecuteStructMakeDate)); for (auto &func : make_date.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return make_date; } @@ -157,7 +156,7 @@ ScalarFunctionSet MakeDateFun::GetFunctions() { ScalarFunction MakeTimeFun::GetFunction() { ScalarFunction function({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::DOUBLE}, LogicalType::TIME, ExecuteMakeTime); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -170,7 +169,7 @@ ScalarFunctionSet MakeTimestampFun::GetFunctions() { ScalarFunction({LogicalType::BIGINT}, LogicalType::TIMESTAMP, ExecuteMakeTimestamp)); for (auto &func : operator_set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return operator_set; } diff --git a/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp b/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp index 6427a55f5..e767282d3 100644 --- a/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp @@ -16,7 +16,6 @@ namespace duckdb { namespace { struct TimeBucket { - // Use 2000-01-03 00:00:00 (Monday) as origin when bucket_width is days, hours, ... for TimescaleDB compatibility // There are 10959 days between 1970-01-01 and 2000-01-03 constexpr static const int64_t DEFAULT_ORIGIN_MICROS = 10959 * Interval::MICROS_PER_DAY; @@ -369,7 +368,7 @@ ScalarFunctionSet TimeBucketFun::GetFunctions() { time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, LogicalType::TIMESTAMP, TimeBucketOriginFunction)); for (auto &func : time_bucket.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return time_bucket; } diff --git a/src/duckdb/extension/core_functions/scalar/date/to_interval.cpp b/src/duckdb/extension/core_functions/scalar/date/to_interval.cpp index d8c0f58e0..8ad21e543 100644 --- a/src/duckdb/extension/core_functions/scalar/date/to_interval.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/to_interval.cpp @@ -183,7 +183,7 @@ ScalarFunctionSet GetIntegerIntervalFunctions() { function_set.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::INTERVAL, ScalarFunction::UnaryFunction)); for (auto &func : function_set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return function_set; } @@ -225,35 +225,35 @@ ScalarFunctionSet ToDaysFun::GetFunctions() { ScalarFunction ToHoursFun::GetFunction() { ScalarFunction function({LogicalType::BIGINT}, LogicalType::INTERVAL, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } ScalarFunction ToMinutesFun::GetFunction() { ScalarFunction function({LogicalType::BIGINT}, LogicalType::INTERVAL, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } ScalarFunction ToSecondsFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::INTERVAL, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } 
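
The make_date and to_interval hunks above repeat the substitution that runs through this whole patch: the static helper BaseScalarFunction::SetReturnsError(func) becomes the member call func.SetFallible(), and direct field writes such as fun.stability = ... and fun.null_handling = ... become SetStability(...) and SetNullHandling(...). Below is a minimal sketch of a registration written against that setter-based interface, assuming only the calls visible in this diff; MyToUnitFun and MyToUnitFunction are placeholder names, not DuckDB symbols.

```cpp
// Sketch only: setter-based ScalarFunction registration as used throughout this
// patch. MyToUnitFun / MyToUnitFunction are hypothetical names; the include path
// assumes the usual DuckDB amalgamation header.
#include "duckdb.hpp"

namespace duckdb {

static void MyToUnitFunction(DataChunk &args, ExpressionState &state, Vector &result) {
	// ... convert the BIGINT counts in args.data[0] into INTERVAL values in result ...
}

ScalarFunction MyToUnitFun() {
	ScalarFunction function({LogicalType::BIGINT}, LogicalType::INTERVAL, MyToUnitFunction);
	function.SetFallible(); // was: BaseScalarFunction::SetReturnsError(function);
	return function;
}

} // namespace duckdb
```
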
ScalarFunction ToMillisecondsFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::INTERVAL, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } ScalarFunction ToMicrosecondsFun::GetFunction() { ScalarFunction function({LogicalType::BIGINT}, LogicalType::INTERVAL, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } diff --git a/src/duckdb/extension/core_functions/scalar/debug/vector_type.cpp b/src/duckdb/extension/core_functions/scalar/debug/vector_type.cpp index 627d7ac28..73545544f 100644 --- a/src/duckdb/extension/core_functions/scalar/debug/vector_type.cpp +++ b/src/duckdb/extension/core_functions/scalar/debug/vector_type.cpp @@ -17,7 +17,7 @@ ScalarFunction VectorTypeFun::GetFunction() { {LogicalType::ANY}, // argument list LogicalType::VARCHAR, // return type VectorTypeFunction); - vector_type_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + vector_type_fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return vector_type_fun; } diff --git a/src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp b/src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp index be3c5c03b..8de43097e 100644 --- a/src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp +++ b/src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp @@ -90,16 +90,16 @@ static unique_ptr BindEnumCodeFunction(ClientContext &context, Sca auto phy_type = EnumType::GetPhysicalType(arguments[0]->return_type); switch (phy_type) { case PhysicalType::UINT8: - bound_function.return_type = LogicalType(LogicalTypeId::UTINYINT); + bound_function.SetReturnType(LogicalType(LogicalTypeId::UTINYINT)); break; case PhysicalType::UINT16: - bound_function.return_type = LogicalType(LogicalTypeId::USMALLINT); + bound_function.SetReturnType(LogicalType(LogicalTypeId::USMALLINT)); break; case PhysicalType::UINT32: - bound_function.return_type = LogicalType(LogicalTypeId::UINTEGER); + bound_function.SetReturnType(LogicalType(LogicalTypeId::UINTEGER)); break; case PhysicalType::UINT64: - bound_function.return_type = LogicalType(LogicalTypeId::UBIGINT); + bound_function.SetReturnType(LogicalType(LogicalTypeId::UBIGINT)); break; default: throw InternalException("Unsupported Enum Internal Type"); @@ -131,33 +131,33 @@ static unique_ptr BindEnumRangeBoundaryFunction(ClientContext &con ScalarFunction EnumFirstFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, EnumFirstFunction, BindEnumFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } ScalarFunction EnumLastFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, EnumLastFunction, BindEnumFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } ScalarFunction EnumCodeFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::ANY, EnumCodeFunction, BindEnumCodeFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } ScalarFunction EnumRangeFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::LIST(LogicalType::VARCHAR), EnumRangeFunction, BindEnumFunction); - 
fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } ScalarFunction EnumRangeBoundaryFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::LIST(LogicalType::VARCHAR), EnumRangeBoundaryFunction, BindEnumRangeBoundaryFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/alias.cpp b/src/duckdb/extension/core_functions/scalar/generic/alias.cpp index 4edadcaaf..222510cb8 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/alias.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/alias.cpp @@ -11,7 +11,7 @@ static void AliasFunction(DataChunk &args, ExpressionState &state, Vector &resul ScalarFunction AliasFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, AliasFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/binning.cpp b/src/duckdb/extension/core_functions/scalar/generic/binning.cpp index ffaceaf3d..a799bd1f1 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/binning.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/binning.cpp @@ -422,7 +422,7 @@ unique_ptr BindEquiWidthFunction(ClientContext &, ScalarFunction & child_type = arguments[1]->return_type; break; } - bound_function.return_type = LogicalType::LIST(child_type); + bound_function.SetReturnType(LogicalType::LIST(child_type)); return nullptr; } @@ -478,7 +478,7 @@ void EquiWidthBinSerialize(Serializer &, const optional_ptr, const } unique_ptr EquiWidthBinDeserialize(Deserializer &deserializer, ScalarFunction &function) { - function.return_type = deserializer.Get(); + function.SetReturnType(deserializer.Get()); return nullptr; } @@ -504,7 +504,7 @@ ScalarFunctionSet EquiWidthBinsFun::GetFunctions() { for (auto &function : functions.functions) { function.serialize = EquiWidthBinSerialize; function.deserialize = EquiWidthBinDeserialize; - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); } return functions; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp b/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp index 1f28c8da8..a1f46f3d3 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp @@ -36,7 +36,7 @@ unique_ptr BindCanCastImplicitlyExpression(FunctionBindExpressionInp ScalarFunction CanCastImplicitlyFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::BOOLEAN, CanCastImplicitlyFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.bind_expression = BindCanCastImplicitlyExpression; return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/cast_to_type.cpp b/src/duckdb/extension/core_functions/scalar/generic/cast_to_type.cpp index 4b87705d7..0cbf56344 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/cast_to_type.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/cast_to_type.cpp @@ -24,7 +24,7 @@ unique_ptr 
BindCastToTypeFunction(FunctionBindExpressionInput &input } // namespace ScalarFunction CastToTypeFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, CastToTypeFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.bind_expression = BindCastToTypeFunction; return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp b/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp index 4464f0544..46e045c40 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp @@ -33,7 +33,6 @@ void CurrentSettingFunction(DataChunk &args, ExpressionState &state, Vector &res unique_ptr CurrentSettingBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - auto &key_child = arguments[0]; if (key_child->return_type.id() == LogicalTypeId::UNKNOWN) { throw ParameterNotResolvedException(); @@ -53,13 +52,10 @@ unique_ptr CurrentSettingBind(ClientContext &context, ScalarFuncti if (!context.TryGetCurrentSetting(key, val)) { auto extension_name = Catalog::AutoloadExtensionByConfigName(context, key); // If autoloader didn't throw, the config is now available - if (!context.TryGetCurrentSetting(key, val)) { - throw InternalException("Extension %s did not provide the '%s' config setting", - extension_name.ToStdString(), key); - } + context.TryGetCurrentSetting(key, val); } - bound_function.return_type = val.type(); + bound_function.SetReturnType(val.type()); return make_uniq(val); } @@ -67,7 +63,7 @@ unique_ptr CurrentSettingBind(ClientContext &context, ScalarFuncti ScalarFunction CurrentSettingFun::GetFunction() { auto fun = ScalarFunction({LogicalType::VARCHAR}, LogicalType::ANY, CurrentSettingFunction, CurrentSettingBind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/hash.cpp b/src/duckdb/extension/core_functions/scalar/generic/hash.cpp index 184919447..a829d67fe 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/hash.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/hash.cpp @@ -12,7 +12,7 @@ static void HashFunction(DataChunk &args, ExpressionState &state, Vector &result ScalarFunction HashFun::GetFunction() { auto hash_fun = ScalarFunction({LogicalType::ANY}, LogicalType::HASH, HashFunction); hash_fun.varargs = LogicalType::ANY; - hash_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + hash_fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return hash_fun; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/least.cpp b/src/duckdb/extension/core_functions/scalar/generic/least.cpp index 519350c1b..d922f1a17 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/least.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/least.cpp @@ -232,7 +232,7 @@ unique_ptr BindLeastGreatest(ClientContext &context, ScalarFunctio } bound_function.arguments[0] = child_type; bound_function.varargs = child_type; - bound_function.return_type = child_type; + bound_function.SetReturnType(child_type); return nullptr; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/replace_type.cpp b/src/duckdb/extension/core_functions/scalar/generic/replace_type.cpp index b6c823a33..5a25237c8 
100644 --- a/src/duckdb/extension/core_functions/scalar/generic/replace_type.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/replace_type.cpp @@ -26,7 +26,7 @@ unique_ptr BindReplaceTypeFunction(FunctionBindExpressionInput &inpu ScalarFunction ReplaceTypeFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY, LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, ReplaceTypeFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.bind_expression = BindReplaceTypeFunction; return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/stats.cpp b/src/duckdb/extension/core_functions/scalar/generic/stats.cpp index d6b5f5e13..3bd18ae01 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/stats.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/stats.cpp @@ -49,8 +49,8 @@ unique_ptr StatsPropagateStats(ClientContext &context, FunctionS ScalarFunction StatsFun::GetFunction() { ScalarFunction stats({LogicalType::ANY}, LogicalType::VARCHAR, StatsFunction, StatsBind, nullptr, StatsPropagateStats); - stats.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - stats.stability = FunctionStability::VOLATILE; + stats.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + stats.SetStability(FunctionStability::VOLATILE); return stats; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp b/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp index 5a0b25a6d..ea35972bd 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp @@ -108,19 +108,19 @@ void VersionFunction(DataChunk &input, ExpressionState &state, Vector &result) { ScalarFunction CurrentQueryFun::GetFunction() { ScalarFunction current_query({}, LogicalType::VARCHAR, CurrentQueryFunction); - current_query.stability = FunctionStability::VOLATILE; + current_query.SetStability(FunctionStability::VOLATILE); return current_query; } ScalarFunction CurrentSchemaFun::GetFunction() { ScalarFunction current_schema({}, LogicalType::VARCHAR, CurrentSchemaFunction); - current_schema.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + current_schema.SetStability(FunctionStability::CONSISTENT_WITHIN_QUERY); return current_schema; } ScalarFunction CurrentDatabaseFun::GetFunction() { ScalarFunction current_database({}, LogicalType::VARCHAR, CurrentDatabaseFunction); - current_database.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + current_database.SetStability(FunctionStability::CONSISTENT_WITHIN_QUERY); return current_database; } @@ -128,20 +128,20 @@ ScalarFunction CurrentSchemasFun::GetFunction() { auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); ScalarFunction current_schemas({LogicalType::BOOLEAN}, varchar_list_type, CurrentSchemasFunction, CurrentSchemasBind); - current_schemas.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + current_schemas.SetStability(FunctionStability::CONSISTENT_WITHIN_QUERY); return current_schemas; } ScalarFunction InSearchPathFun::GetFunction() { ScalarFunction in_search_path({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, InSearchPathFunction); - in_search_path.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + in_search_path.SetStability(FunctionStability::CONSISTENT_WITHIN_QUERY); return in_search_path; } ScalarFunction CurrentTransactionIdFun::GetFunction() 
{ ScalarFunction txid_current({}, LogicalType::UBIGINT, TransactionIdCurrent); - txid_current.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + txid_current.SetStability(FunctionStability::CONSISTENT_WITHIN_QUERY); return txid_current; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/typeof.cpp b/src/duckdb/extension/core_functions/scalar/generic/typeof.cpp index a5d26ad8c..3b5ddf4b2 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/typeof.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/typeof.cpp @@ -25,7 +25,7 @@ unique_ptr BindTypeOfFunctionExpression(FunctionBindExpressionInput ScalarFunction TypeOfFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, TypeOfFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.bind_expression = BindTypeOfFunctionExpression; return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp b/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp index 98cbef28a..05124ea9b 100644 --- a/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp @@ -161,7 +161,6 @@ template void ExecuteConstantSlice(Vector &result, Vector &str_vector, Vector &begin_vector, Vector &end_vector, optional_ptr step_vector, const idx_t count, SelectionVector &sel, idx_t &sel_idx, optional_ptr result_child_vector, bool begin_is_empty, bool end_is_empty) { - // check all this nullness early auto str_valid = !ConstantVector::IsNull(str_vector); auto begin_valid = !ConstantVector::IsNull(begin_vector); @@ -404,11 +403,11 @@ unique_ptr ArraySliceBind(ClientContext &context, ScalarFunction & auto child_type = ArrayType::GetChildType(arguments[0]->return_type); auto target_type = LogicalType::LIST(child_type); arguments[0] = BoundCastExpression::AddCastToType(context, std::move(arguments[0]), target_type); - bound_function.return_type = arguments[0]->return_type; + bound_function.SetReturnType(arguments[0]->return_type); } break; case LogicalTypeId::LIST: // The result is the same type - bound_function.return_type = arguments[0]->return_type; + bound_function.SetReturnType(arguments[0]->return_type); break; case LogicalTypeId::BLOB: case LogicalTypeId::VARCHAR: @@ -421,9 +420,9 @@ unique_ptr ArraySliceBind(ClientContext &context, ScalarFunction & if (arguments[0]->return_type.IsJSONType()) { // This is needed to avoid producing invalid JSON bound_function.arguments[0] = LogicalType::VARCHAR; - bound_function.return_type = LogicalType::VARCHAR; + bound_function.SetReturnType(LogicalType::VARCHAR); } else { - bound_function.return_type = arguments[0]->return_type; + bound_function.SetReturnType(arguments[0]->return_type); } for (idx_t i = 1; i < 3; i++) { if (arguments[i]->return_type.id() != LogicalTypeId::LIST) { @@ -434,7 +433,7 @@ unique_ptr ArraySliceBind(ClientContext &context, ScalarFunction & case LogicalTypeId::SQLNULL: case LogicalTypeId::UNKNOWN: bound_function.arguments[0] = LogicalTypeId::UNKNOWN; - bound_function.return_type = LogicalType::SQLNULL; + bound_function.SetReturnType(LogicalType::SQLNULL); break; default: throw BinderException("ARRAY_SLICE can only operate on LISTs and VARCHARs"); @@ -449,7 +448,7 @@ unique_ptr ArraySliceBind(ClientContext &context, ScalarFunction & bound_function.arguments[2] = LogicalType::BIGINT; } - return make_uniq(bound_function.return_type, begin_is_empty, 
end_is_empty); + return make_uniq(bound_function.GetReturnType(), begin_is_empty, end_is_empty); } } // namespace @@ -457,8 +456,8 @@ ScalarFunctionSet ListSliceFun::GetFunctions() { // the arguments and return types are actually set in the binder function ScalarFunction fun({LogicalType::ANY, LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, ArraySliceFunction, ArraySliceBind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - BaseScalarFunction::SetReturnsError(fun); + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + fun.SetFallible(); ScalarFunctionSet set; set.AddFunction(fun); fun.arguments.push_back(LogicalType::BIGINT); diff --git a/src/duckdb/extension/core_functions/scalar/list/flatten.cpp b/src/duckdb/extension/core_functions/scalar/list/flatten.cpp index 97b3d625f..23cbf8660 100644 --- a/src/duckdb/extension/core_functions/scalar/list/flatten.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/flatten.cpp @@ -10,7 +10,6 @@ namespace duckdb { namespace { void ListFlattenFunction(DataChunk &args, ExpressionState &, Vector &result) { - const auto flat_list_data = FlatVector::GetData(result); auto &flat_list_mask = FlatVector::Validity(result); diff --git a/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp b/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp index 8be7134ab..0016e806d 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp @@ -32,7 +32,7 @@ unique_ptr ListAggregatesInitLocalState(ExpressionState &sta unique_ptr ListAggregatesBindFailure(ScalarFunction &bound_function) { bound_function.arguments[0] = LogicalType::SQLNULL; - bound_function.return_type = LogicalType::SQLNULL; + bound_function.SetReturnType(LogicalType::SQLNULL); return make_uniq(LogicalType::SQLNULL); } @@ -187,7 +187,6 @@ struct UniqueFunctor { auto result_data = FlatVector::GetData(result); for (idx_t i = 0; i < count; i++) { - auto state = states[sdata.sel->get_index(i)]; if (!state->hist) { @@ -253,7 +252,6 @@ void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector &res idx_t states_idx = 0; for (idx_t i = 0; i < count; i++) { - // initialize the state for this list auto state_ptr = state_buffer.get() + size * i; states[i] = state_ptr; @@ -390,7 +388,6 @@ template unique_ptr ListAggregatesBindFunction(ClientContext &context, ScalarFunction &bound_function, const LogicalType &list_child_type, AggregateFunction &aggr_function, vector> &arguments) { - // create the child expression and its type vector> children; auto expr = make_uniq(Value(list_child_type)); @@ -408,7 +405,7 @@ ListAggregatesBindFunction(ClientContext &context, ScalarFunction &bound_functio bound_function.arguments[0] = LogicalType::LIST(bound_aggr_function->function.arguments[0]); if (IS_AGGR) { - bound_function.return_type = bound_aggr_function->function.return_type; + bound_function.SetReturnType(bound_aggr_function->function.GetReturnType()); } // check if the aggregate function consumed all the extra input arguments if (bound_aggr_function->children.size() > 1) { @@ -417,13 +414,12 @@ ListAggregatesBindFunction(ClientContext &context, ScalarFunction &bound_functio bound_aggr_function->ToString()); } - return make_uniq(bound_function.return_type, std::move(bound_aggr_function)); + return make_uniq(bound_function.GetReturnType(), std::move(bound_aggr_function)); } template unique_ptr ListAggregatesBind(ClientContext &context, ScalarFunction 
&bound_function, vector> &arguments) { - arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); if (arguments[0]->return_type.id() == LogicalTypeId::SQLNULL) { @@ -459,7 +455,7 @@ unique_ptr ListAggregatesBind(ClientContext &context, ScalarFuncti if (is_parameter) { bound_function.arguments[0] = LogicalTypeId::UNKNOWN; - bound_function.return_type = LogicalType::SQLNULL; + bound_function.SetReturnType(LogicalType::SQLNULL); return nullptr; } @@ -481,7 +477,7 @@ unique_ptr ListAggregatesBind(ClientContext &context, ScalarFuncti // found a matching function, bind it as an aggregate auto best_function = func.functions.GetFunctionByOffset(best_function_idx.GetIndex()); if (IS_AGGR) { - bound_function.errors = best_function.errors; + bound_function.SetErrorMode(best_function.GetErrorMode()); return ListAggregatesBindFunction(context, bound_function, child_type, best_function, arguments); } @@ -493,7 +489,6 @@ unique_ptr ListAggregatesBind(ClientContext &context, ScalarFuncti unique_ptr ListAggregateBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - // the list column and the name of the aggregate function D_ASSERT(bound_function.arguments.size() >= 2); D_ASSERT(arguments.size() >= 2); @@ -507,8 +502,8 @@ ScalarFunction ListAggregateFun::GetFunction() { auto result = ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, LogicalType::ANY, ListAggregateFunction, ListAggregateBind, nullptr, nullptr, ListAggregatesInitLocalState); - BaseScalarFunction::SetReturnsError(result); - result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + result.SetFallible(); + result.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); result.varargs = LogicalType::ANY; result.serialize = ListAggregatesBindData::SerializeFunction; result.deserialize = ListAggregatesBindData::DeserializeFunction; diff --git a/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp b/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp index 5c3513b2a..ad0c488a0 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp @@ -88,7 +88,7 @@ ScalarFunctionSet ListDistanceFun::GetFunctions() { AddListFoldFunction(set, type); } for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return set; } @@ -115,7 +115,7 @@ ScalarFunctionSet ListCosineSimilarityFun::GetFunctions() { AddListFoldFunction(set, type); } for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return set; } diff --git a/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp b/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp index 4224fad24..a183d498d 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp @@ -7,7 +7,6 @@ namespace duckdb { static unique_ptr ListFilterBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - // the list column and the bound lambda expression D_ASSERT(arguments.size() == 2); if (arguments[1]->GetExpressionClass() != ExpressionClass::BOUND_LAMBDA) { @@ -25,7 +24,7 @@ static unique_ptr ListFilterBind(ClientContext &context, ScalarFun arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - bound_function.return_type = arguments[0]->return_type; + 
bound_function.SetReturnType(arguments[0]->return_type); auto has_index = bound_lambda_expr.parameter_count == 2; return LambdaFunctions::ListLambdaBind(context, bound_function, arguments, has_index); } @@ -39,7 +38,7 @@ ScalarFunction ListFilterFun::GetFunction() { ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LAMBDA}, LogicalType::LIST(LogicalType::ANY), LambdaFunctions::ListFilterFunction, ListFilterBind, nullptr, nullptr); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.serialize = ListLambdaBindData::Serialize; fun.deserialize = ListLambdaBindData::Deserialize; fun.bind_lambda = ListFilterBindLambda; diff --git a/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp b/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp index 51b4980cd..ff2fd5354 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp @@ -7,7 +7,6 @@ namespace duckdb { static void ListHasAnyFunction(DataChunk &args, ExpressionState &, Vector &result) { - auto &l_vec = args.data[0]; auto &r_vec = args.data[1]; @@ -63,7 +62,6 @@ static void ListHasAnyFunction(DataChunk &args, ExpressionState &, Vector &resul // Use the smaller list to build the set if (r_list.length < l_list.length) { - build_list = r_list; probe_list = l_list; @@ -96,7 +94,6 @@ static void ListHasAnyFunction(DataChunk &args, ExpressionState &, Vector &resul } static void ListHasAllFunction(DataChunk &args, ExpressionState &state, Vector &result) { - const auto &func_expr = state.expr.Cast(); const auto swap = func_expr.function.name == "<@"; diff --git a/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp b/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp index 08f64b54e..caedeba21 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp @@ -175,7 +175,6 @@ bool ExecuteReduce(const idx_t loops, ReduceExecuteInfo &execute_info, LambdaFun unique_ptr ListReduceBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - // the list column and the bound lambda expression D_ASSERT(arguments.size() == 2 || arguments.size() == 3); if (arguments[1]->GetExpressionClass() != ExpressionClass::BOUND_LAMBDA) { @@ -223,8 +222,8 @@ unique_ptr ListReduceBind(ClientContext &context, ScalarFunction & if (!cast_lambda_expr) { throw BinderException("Could not cast lambda expression to list child type"); } - bound_function.return_type = cast_lambda_expr->return_type; - return make_uniq(bound_function.return_type, std::move(cast_lambda_expr), has_index, + bound_function.SetReturnType(cast_lambda_expr->return_type); + return make_uniq(bound_function.GetReturnType(), std::move(cast_lambda_expr), has_index, has_initial); } @@ -311,7 +310,7 @@ ScalarFunctionSet ListReduceFun::GetFunctions() { ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LAMBDA}, LogicalType::ANY, LambdaFunctions::ListReduceFunction, ListReduceBind, nullptr, nullptr); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.serialize = ListLambdaBindData::Serialize; fun.deserialize = ListLambdaBindData::Deserialize; fun.bind_lambda = ListReduceBindLambda; diff --git a/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp 
b/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp index 1263500c9..61fab0938 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp @@ -37,7 +37,6 @@ ListSortBindData::ListSortBindData(OrderType order_type_p, OrderByNullType null_ ClientContext &context_p) : order_type(order_type_p), null_order(null_order_p), return_type(return_type_p), child_type(child_type_p), is_grade_up(is_grade_up_p), context(context_p) { - // get the vector types types.emplace_back(LogicalType::USMALLINT); types.emplace_back(child_type); @@ -71,7 +70,6 @@ static void SinkDataChunk(const Sort &sort, ExecutionContext &context, OperatorS Vector *child_vector, SelectionVector &sel, idx_t offset_lists_indices, vector &types, Vector &payload_vector, bool &data_to_sort, Vector &lists_indices) { - // slice the child vector Vector slice(*child_vector, sel, offset_lists_indices); @@ -256,22 +254,22 @@ static void ListSortFunction(DataChunk &args, ExpressionState &state, Vector &re static unique_ptr ListSortBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments, OrderType &order, OrderByNullType &null_order) { - LogicalType child_type; if (arguments[0]->return_type == LogicalTypeId::UNKNOWN) { bound_function.arguments[0] = LogicalTypeId::UNKNOWN; - bound_function.return_type = LogicalType::SQLNULL; - child_type = bound_function.return_type; - return make_uniq(order, null_order, false, bound_function.return_type, child_type, context); + bound_function.SetReturnType(LogicalType::SQLNULL); + child_type = bound_function.GetReturnType(); + return make_uniq(order, null_order, false, bound_function.GetReturnType(), child_type, + context); } arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); child_type = ListType::GetChildType(arguments[0]->return_type); bound_function.arguments[0] = arguments[0]->return_type; - bound_function.return_type = arguments[0]->return_type; + bound_function.SetReturnType(arguments[0]->return_type); - return make_uniq(order, null_order, false, bound_function.return_type, child_type, context); + return make_uniq(order, null_order, false, bound_function.GetReturnType(), child_type, context); } template @@ -286,7 +284,6 @@ static T GetOrder(ClientContext &context, Expression &expr) { static unique_ptr ListGradeUpBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - D_ASSERT(!arguments.empty() && arguments.size() <= 3); auto order = OrderType::ORDER_DEFAULT; auto null_order = OrderByNullType::ORDER_DEFAULT; @@ -306,9 +303,9 @@ static unique_ptr ListGradeUpBind(ClientContext &context, ScalarFu arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); bound_function.arguments[0] = arguments[0]->return_type; - bound_function.return_type = LogicalType::LIST(LogicalTypeId::BIGINT); + bound_function.SetReturnType(LogicalType::LIST(LogicalTypeId::BIGINT)); auto child_type = ListType::GetChildType(arguments[0]->return_type); - return make_uniq(order, null_order, true, bound_function.return_type, child_type, context); + return make_uniq(order, null_order, true, bound_function.GetReturnType(), child_type, context); } static unique_ptr ListNormalSortBind(ClientContext &context, ScalarFunction &bound_function, diff --git a/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp b/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp index 97e8be006..4e2400f71 100644 --- 
a/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp @@ -7,7 +7,6 @@ namespace duckdb { static unique_ptr ListTransformBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - // the list column and the bound lambda expression D_ASSERT(arguments.size() == 2); if (arguments[1]->GetExpressionClass() != ExpressionClass::BOUND_LAMBDA) { @@ -17,7 +16,7 @@ static unique_ptr ListTransformBind(ClientContext &context, Scalar arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); auto &bound_lambda_expr = arguments[1]->Cast(); - bound_function.return_type = LogicalType::LIST(bound_lambda_expr.lambda_expr->return_type); + bound_function.SetReturnType(LogicalType::LIST(bound_lambda_expr.lambda_expr->return_type)); auto has_index = bound_lambda_expr.parameter_count == 2; return LambdaFunctions::ListLambdaBind(context, bound_function, arguments, has_index); } @@ -31,7 +30,7 @@ ScalarFunction ListTransformFun::GetFunction() { ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LAMBDA}, LogicalType::LIST(LogicalType::ANY), LambdaFunctions::ListTransformFunction, ListTransformBind, nullptr, nullptr); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.serialize = ListLambdaBindData::Serialize; fun.deserialize = ListLambdaBindData::Deserialize; fun.bind_lambda = ListTransformBindLambda; diff --git a/src/duckdb/extension/core_functions/scalar/list/list_value.cpp b/src/duckdb/extension/core_functions/scalar/list/list_value.cpp index cec76fe89..4a0ddbf6b 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_value.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_value.cpp @@ -291,8 +291,8 @@ unique_ptr UnpivotBind(ClientContext &context, ScalarFunction &bou // this is more for completeness reasons bound_function.varargs = child_type; - bound_function.return_type = LogicalType::LIST(child_type); - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::LIST(child_type)); + return make_uniq(bound_function.GetReturnType()); } unique_ptr ListValueStats(ClientContext &context, FunctionStatisticsInput &input) { @@ -309,7 +309,6 @@ unique_ptr ListValueStats(ClientContext &context, FunctionStatis } // namespace ScalarFunctionSet ListValueFun::GetFunctions() { - ScalarFunctionSet set("list_value"); // Overload for 0 arguments, which returns an empty list. 
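
The list_transform.cpp and list_value.cpp hunks around this point show the bind-function side of the same migration: bound_function.return_type is no longer assigned directly but set via SetReturnType(), and read back through GetReturnType() where the bind data needs it. A short sketch of a bind callback written against that interface, under the same assumptions as the previous example; MyListWrapBind, MyListWrapFunction and MyListWrapFun are placeholder names, not DuckDB symbols.

```cpp
// Sketch only: a hypothetical bind callback plus registration using the
// SetReturnType / SetNullHandling setters introduced by this patch.
#include "duckdb.hpp"

namespace duckdb {

static void MyListWrapFunction(DataChunk &args, ExpressionState &state, Vector &result) {
	// ... wrap each input value of args.data[0] into a single-element LIST ...
}

static unique_ptr<FunctionData> MyListWrapBind(ClientContext &context, ScalarFunction &bound_function,
                                               vector<unique_ptr<Expression>> &arguments) {
	// resolve the element type from the argument and publish it through the setter
	const auto &child_type = arguments[0]->return_type;
	bound_function.arguments[0] = child_type;
	bound_function.SetReturnType(LogicalType::LIST(child_type));
	return nullptr;
}

ScalarFunction MyListWrapFun() {
	ScalarFunction fun({LogicalType::ANY}, LogicalType::LIST(LogicalType::ANY), MyListWrapFunction, MyListWrapBind);
	fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); // was: fun.null_handling = ...
	return fun;
}

} // namespace duckdb
```
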
@@ -322,7 +321,7 @@ ScalarFunctionSet ListValueFun::GetFunctions() { ScalarFunction value_fun({element_type}, LogicalType::LIST(element_type), ListValueFunction, nullptr, nullptr, ListValueStats); value_fun.varargs = element_type; - value_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + value_fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); set.AddFunction(value_fun); return set; @@ -332,7 +331,7 @@ ScalarFunction UnpivotListFun::GetFunction() { ScalarFunction fun("unpivot_list", {}, LogicalTypeId::LIST, ListValueFunction, UnpivotBind, nullptr, ListValueStats); fun.varargs = LogicalTypeId::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/list/range.cpp b/src/duckdb/extension/core_functions/scalar/list/range.cpp index 494039d41..13281c09a 100644 --- a/src/duckdb/extension/core_functions/scalar/list/range.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/range.cpp @@ -258,7 +258,7 @@ ScalarFunctionSet ListRangeFun::GetFunctions() { LogicalType::LIST(LogicalType::TIMESTAMP), ListRangeFunction)); for (auto &func : range_set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return range_set; } @@ -277,7 +277,7 @@ ScalarFunctionSet GenerateSeriesFun::GetFunctions() { LogicalType::LIST(LogicalType::TIMESTAMP), ListRangeFunction)); for (auto &func : generate_series.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return generate_series; } diff --git a/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp b/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp index 9c81223e7..9806b5d76 100644 --- a/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp @@ -36,14 +36,14 @@ static unique_ptr CardinalityBind(ClientContext &context, ScalarFu throw BinderException("Cardinality can only operate on MAPs"); } - bound_function.return_type = LogicalType::UBIGINT; - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::UBIGINT); + return make_uniq(bound_function.GetReturnType()); } ScalarFunction CardinalityFun::GetFunction() { ScalarFunction fun({LogicalType::ANY}, LogicalType::UBIGINT, CardinalityFunction, CardinalityBind); fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::DEFAULT_NULL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/map/map.cpp b/src/duckdb/extension/core_functions/scalar/map/map.cpp index ab9bea1bb..8b1e86a13 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map.cpp @@ -38,7 +38,6 @@ static bool MapIsNull(DataChunk &chunk) { } static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) { - // internal MAP representation // - LIST-vector that contains STRUCTs as child entries // - STRUCTs have exactly two fields, a key-field, and a value-field @@ -107,7 +106,6 @@ static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) { idx_t offset = 0; for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { - auto keys_idx = keys_data.sel->get_index(row_idx); auto values_idx = values_data.sel->get_index(row_idx); auto result_idx = result_data.sel->get_index(row_idx); @@ -128,7 +126,6 @@ static void 
MapFunction(DataChunk &args, ExpressionState &, Vector &result) { // set the selection vectors and perform a duplicate key check value_set_t unique_keys; for (idx_t child_idx = 0; child_idx < keys_entry.length; child_idx++) { - auto key_idx = keys_child_data.sel->get_index(keys_entry.offset + child_idx); auto value_idx = values_child_data.sel->get_index(values_entry.offset + child_idx); @@ -173,16 +170,15 @@ static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) { } ScalarFunctionSet MapFun::GetFunctions() { - ScalarFunction empty_func({}, LogicalType::MAP(LogicalType::SQLNULL, LogicalType::SQLNULL), MapFunction); - BaseScalarFunction::SetReturnsError(empty_func); + empty_func.SetFallible(); auto key_type = LogicalType::TEMPLATE("K"); auto val_type = LogicalType::TEMPLATE("V"); ScalarFunction value_func({LogicalType::LIST(key_type), LogicalType::LIST(val_type)}, LogicalType::MAP(key_type, val_type), MapFunction); - BaseScalarFunction::SetReturnsError(value_func); - value_func.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + value_func.SetFallible(); + value_func.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); ScalarFunctionSet set; diff --git a/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp b/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp index 4c733d56f..33fac37ec 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp @@ -132,7 +132,6 @@ bool IsEmptyMap(const LogicalType &map) { unique_ptr MapConcatBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - auto arg_count = arguments.size(); if (arg_count < 2) { throw InvalidInputException("The provided amount of arguments is incorrect, please provide 2 or more maps"); @@ -141,7 +140,7 @@ unique_ptr MapConcatBind(ClientContext &context, ScalarFunction &b if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { // Prepared statement bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + bound_function.SetReturnType(LogicalTypeId::SQLNULL); return nullptr; } @@ -155,7 +154,7 @@ unique_ptr MapConcatBind(ClientContext &context, ScalarFunction &b if (map.id() == LogicalTypeId::UNKNOWN) { // Prepared statement bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + bound_function.SetReturnType(LogicalTypeId::SQLNULL); return nullptr; } if (map.id() == LogicalTypeId::SQLNULL) { @@ -183,8 +182,8 @@ unique_ptr MapConcatBind(ClientContext &context, ScalarFunction &b if (expected.id() == LogicalTypeId::SQLNULL && is_null == false) { expected = LogicalType::MAP(LogicalType::SQLNULL, LogicalType::SQLNULL); } - bound_function.return_type = expected; - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(expected); + return make_uniq(bound_function.GetReturnType()); } } // namespace @@ -192,7 +191,7 @@ unique_ptr MapConcatBind(ClientContext &context, ScalarFunction &b ScalarFunction MapConcatFun::GetFunction() { //! 
the arguments and return types are actually set in the binder function ScalarFunction fun("map_concat", {}, LogicalTypeId::LIST, MapConcatFunction, MapConcatBind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.varargs = LogicalType::ANY; return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp b/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp index 06af34e66..0d9372903 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp @@ -29,14 +29,13 @@ static void MapEntriesFunction(DataChunk &args, ExpressionState &state, Vector & } ScalarFunction MapEntriesFun::GetFunction() { - auto key_type = LogicalType::TEMPLATE("K"); auto val_type = LogicalType::TEMPLATE("V"); auto map_type = LogicalType::MAP(key_type, val_type); auto row_type = LogicalType::STRUCT({{"key", key_type}, {"value", val_type}}); ScalarFunction fun({map_type}, LogicalType::LIST(row_type), MapEntriesFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp b/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp index fcea0b133..b7b8a3091 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp @@ -118,7 +118,7 @@ ScalarFunction MapExtractValueFun::GetFunction() { auto val_type = LogicalType::TEMPLATE("V"); ScalarFunction fun({LogicalType::MAP(key_type, val_type), key_type}, val_type, MapExtractValueFunc); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } @@ -128,7 +128,7 @@ ScalarFunction MapExtractFun::GetFunction() { ScalarFunction fun({LogicalType::MAP(key_type, val_type), key_type}, LogicalType::LIST(val_type), MapExtractListFunc); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp b/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp index 2344b9a6e..169b9177c 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp @@ -26,9 +26,9 @@ ScalarFunction MapFromEntriesFun::GetFunction() { auto row_type = LogicalType::STRUCT({{"", key_type}, {"", val_type}}); ScalarFunction fun({LogicalType::LIST(row_type)}, map_type, MapFromEntriesFunction); - fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::DEFAULT_NULL_HANDLING); - BaseScalarFunction::SetReturnsError(fun); + fun.SetFallible(); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp b/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp index eec32a0a6..2ee626ba1 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp @@ -57,9 +57,9 @@ ScalarFunction MapKeysFun::GetFunction() { auto val_type = LogicalType::TEMPLATE("V"); ScalarFunction function({LogicalType::MAP(key_type, val_type)}, LogicalType::LIST(key_type), MapKeysFunction); - 
function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -68,9 +68,9 @@ ScalarFunction MapValuesFun::GetFunction() { auto val_type = LogicalType::TEMPLATE("V"); ScalarFunction function({LogicalType::MAP(key_type, val_type)}, LogicalType::LIST(val_type), MapValuesFunction); - function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } diff --git a/src/duckdb/extension/core_functions/scalar/math/numeric.cpp b/src/duckdb/extension/core_functions/scalar/math/numeric.cpp index d6ae71bc0..0897daa98 100644 --- a/src/duckdb/extension/core_functions/scalar/math/numeric.cpp +++ b/src/duckdb/extension/core_functions/scalar/math/numeric.cpp @@ -159,7 +159,7 @@ unique_ptr DecimalUnaryOpBind(ClientContext &context, ScalarFuncti break; } bound_function.arguments[0] = decimal_type; - bound_function.return_type = decimal_type; + bound_function.SetReturnType(decimal_type); return nullptr; } @@ -192,7 +192,7 @@ ScalarFunctionSet AbsOperatorFun::GetFunctions() { } } for (auto &func : abs.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return abs; } @@ -356,7 +356,7 @@ static unique_ptr BindGenericRoundFunctionDecimal(ClientContext &c } } bound_function.arguments[0] = decimal_type; - bound_function.return_type = LogicalType::DECIMAL(width, 0); + bound_function.SetReturnType(LogicalType::DECIMAL(width, 0)); return nullptr; } @@ -550,7 +550,7 @@ static unique_ptr BindDecimalRoundPrecision(ClientContext &context } } bound_function.arguments[0] = decimal_type; - bound_function.return_type = LogicalType::DECIMAL(width, target_scale); + bound_function.SetReturnType(LogicalType::DECIMAL(width, target_scale)); return make_uniq(round_value); } @@ -972,7 +972,7 @@ struct SqrtOperator { ScalarFunction SqrtFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1017,7 +1017,7 @@ struct LnOperator { ScalarFunction LnFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1044,7 +1044,7 @@ struct Log10Operator { ScalarFunction Log10Fun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1073,7 +1073,7 @@ ScalarFunctionSet LogFun::GetFunctions() { funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::BinaryFunction)); for (auto &function : funcs.functions) { - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); } return funcs; } @@ -1099,7 +1099,7 @@ struct Log2Operator { ScalarFunction Log2Fun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1289,7 +1289,7 @@ struct SinOperator { ScalarFunction SinFun::GetFunction() { ScalarFunction 
function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1308,7 +1308,7 @@ struct CosOperator { ScalarFunction CosFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1327,7 +1327,7 @@ struct TanOperator { ScalarFunction TanFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1349,7 +1349,7 @@ struct ASinOperator { ScalarFunction AsinFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1405,7 +1405,7 @@ struct ACos { ScalarFunction AcosFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1515,7 +1515,7 @@ struct AtanhOperator { ScalarFunction AtanhFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1550,7 +1550,7 @@ struct CotOperator { ScalarFunction CotFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1572,7 +1572,7 @@ struct GammaOperator { ScalarFunction GammaFun::GetFunction() { auto func = ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); return func; } @@ -1594,7 +1594,7 @@ struct LogGammaOperator { ScalarFunction LogGammaFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1619,7 +1619,7 @@ struct FactorialOperator { ScalarFunction FactorialOperatorFun::GetFunction() { ScalarFunction function({LogicalType::INTEGER}, LogicalType::HUGEINT, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1735,7 +1735,7 @@ ScalarFunctionSet LeastCommonMultipleFun::GetFunctions() { ScalarFunction({LogicalType::HUGEINT, LogicalType::HUGEINT}, LogicalType::HUGEINT, ScalarFunction::BinaryFunction)); for (auto &function : funcs.functions) { - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); } return funcs; } diff --git a/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp b/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp index 56844b0f1..9e65138c2 100644 --- a/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp +++ b/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp @@ -116,7 +116,7 @@ ScalarFunctionSet BitwiseAndFun::GetFunctions() { } functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseANDOperation)); for (auto 
&function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); } return functions; } @@ -153,7 +153,7 @@ ScalarFunctionSet BitwiseOrFun::GetFunctions() { } functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseOROperation)); for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); } return functions; } @@ -190,7 +190,7 @@ ScalarFunctionSet BitwiseXorFun::GetFunctions() { } functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseXOROperation)); for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); } return functions; } @@ -225,7 +225,7 @@ ScalarFunctionSet BitwiseNotFun::GetFunctions() { } functions.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIT, BitwiseNOTOperation)); for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); } return functions; } @@ -294,7 +294,7 @@ ScalarFunctionSet LeftShiftFun::GetFunctions() { functions.AddFunction( ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitwiseShiftLeftOperation)); for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); } return functions; } @@ -344,7 +344,7 @@ ScalarFunctionSet RightShiftFun::GetFunctions() { functions.AddFunction( ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitwiseShiftRightOperation)); for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); } return functions; } diff --git a/src/duckdb/extension/core_functions/scalar/random/random.cpp b/src/duckdb/extension/core_functions/scalar/random/random.cpp index 589e264b4..738556161 100644 --- a/src/duckdb/extension/core_functions/scalar/random/random.cpp +++ b/src/duckdb/extension/core_functions/scalar/random/random.cpp @@ -114,7 +114,7 @@ void GenerateUUIDv7Function(DataChunk &args, ExpressionState &state, Vector &res ScalarFunction RandomFun::GetFunction() { ScalarFunction random("random", {}, LogicalType::DOUBLE, RandomFunction, nullptr, nullptr, nullptr, RandomInitLocalState); - random.stability = FunctionStability::VOLATILE; + random.SetStability(FunctionStability::VOLATILE); return random; } @@ -126,7 +126,7 @@ ScalarFunction UUIDv4Fun::GetFunction() { ScalarFunction uuid_v4_function({}, LogicalType::UUID, GenerateUUIDv4Function, nullptr, nullptr, nullptr, RandomInitLocalState); // generate a random uuid v4 - uuid_v4_function.stability = FunctionStability::VOLATILE; + uuid_v4_function.SetStability(FunctionStability::VOLATILE); return uuid_v4_function; } @@ -134,7 +134,7 @@ ScalarFunction UUIDv7Fun::GetFunction() { ScalarFunction uuid_v7_function({}, LogicalType::UUID, GenerateUUIDv7Function, nullptr, nullptr, nullptr, RandomInitLocalState); // generate a random uuid v7 - uuid_v7_function.stability = FunctionStability::VOLATILE; + uuid_v7_function.SetStability(FunctionStability::VOLATILE); return uuid_v7_function; } diff --git a/src/duckdb/extension/core_functions/scalar/random/setseed.cpp b/src/duckdb/extension/core_functions/scalar/random/setseed.cpp index 29072de56..1364b7ddf 100644 --- a/src/duckdb/extension/core_functions/scalar/random/setseed.cpp +++ b/src/duckdb/extension/core_functions/scalar/random/setseed.cpp @@ -58,8 +58,8 @@ unique_ptr 
SetSeedBind(ClientContext &context, ScalarFunction &bou ScalarFunction SetseedFun::GetFunction() { ScalarFunction setseed("setseed", {LogicalType::DOUBLE}, LogicalType::SQLNULL, SetSeedFunction, SetSeedBind); - setseed.stability = FunctionStability::VOLATILE; - BaseScalarFunction::SetReturnsError(setseed); + setseed.SetVolatile(); + setseed.SetFallible(); return setseed; } diff --git a/src/duckdb/extension/core_functions/scalar/string/hex.cpp b/src/duckdb/extension/core_functions/scalar/string/hex.cpp index d3d6eee7b..6ce5db5ce 100644 --- a/src/duckdb/extension/core_functions/scalar/string/hex.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/hex.cpp @@ -89,7 +89,6 @@ struct HexStrOperator { struct HexIntegralOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto num_leading_zero = CountZeros::Leading(static_cast(input)); idx_t num_bits_to_check = 64 - num_leading_zero; D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); @@ -119,7 +118,6 @@ struct HexIntegralOperator { struct HexHugeIntOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(UnsafeNumericCast(input)); idx_t buffer_size = sizeof(INPUT_TYPE) * 2 - (num_leading_zero / 4); @@ -146,7 +144,6 @@ struct HexHugeIntOperator { struct HexUhugeIntOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(UnsafeNumericCast(input)); idx_t buffer_size = sizeof(INPUT_TYPE) * 2 - (num_leading_zero / 4); @@ -204,7 +201,6 @@ struct BinaryStrOperator { struct BinaryIntegralOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto num_leading_zero = CountZeros::Leading(static_cast(input)); idx_t num_bits_to_check = 64 - num_leading_zero; D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); @@ -409,7 +405,7 @@ ScalarFunctionSet HexFun::GetFunctions() { ScalarFunction UnhexFun::GetFunction() { ScalarFunction function({LogicalType::VARCHAR}, LogicalType::BLOB, FromHexFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -433,7 +429,7 @@ ScalarFunctionSet BinFun::GetFunctions() { ScalarFunction UnbinFun::GetFunction() { ScalarFunction function({LogicalType::VARCHAR}, LogicalType::BLOB, FromBinaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } diff --git a/src/duckdb/extension/core_functions/scalar/string/instr.cpp b/src/duckdb/extension/core_functions/scalar/string/instr.cpp index cc0fde9f1..2a1411084 100644 --- a/src/duckdb/extension/core_functions/scalar/string/instr.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/instr.cpp @@ -53,7 +53,7 @@ ScalarFunction InstrFun::GetFunction() { auto function = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, ScalarFunction::BinaryFunction, nullptr, nullptr, InStrPropagateStats); - function.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + function.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return function; } diff --git a/src/duckdb/extension/core_functions/scalar/string/pad.cpp b/src/duckdb/extension/core_functions/scalar/string/pad.cpp index 586e1605a..44fb8a763 100644 --- a/src/duckdb/extension/core_functions/scalar/string/pad.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/pad.cpp @@ -133,14 +133,14 @@ static void PadFunction(DataChunk &args, 
ExpressionState &state, Vector &result) ScalarFunction LpadFun::GetFunction() { ScalarFunction func({LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, LogicalType::VARCHAR, PadFunction); - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); return func; } ScalarFunction RpadFun::GetFunction() { ScalarFunction func({LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, LogicalType::VARCHAR, PadFunction); - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); return func; } diff --git a/src/duckdb/extension/core_functions/scalar/string/printf.cpp b/src/duckdb/extension/core_functions/scalar/string/printf.cpp index 1ec8ae2cd..b98512d51 100644 --- a/src/duckdb/extension/core_functions/scalar/string/printf.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/printf.cpp @@ -189,7 +189,7 @@ ScalarFunction PrintfFun::GetFunction() { ScalarFunction printf_fun({LogicalType::VARCHAR}, LogicalType::VARCHAR, PrintfFunction, BindPrintfFunction); printf_fun.varargs = LogicalType::ANY; - BaseScalarFunction::SetReturnsError(printf_fun); + printf_fun.SetFallible(); return printf_fun; } @@ -198,7 +198,7 @@ ScalarFunction FormatFun::GetFunction() { ScalarFunction format_fun({LogicalType::VARCHAR}, LogicalType::VARCHAR, PrintfFunction, BindPrintfFunction); format_fun.varargs = LogicalType::ANY; - BaseScalarFunction::SetReturnsError(format_fun); + format_fun.SetFallible(); return format_fun; } diff --git a/src/duckdb/extension/core_functions/scalar/string/repeat.cpp b/src/duckdb/extension/core_functions/scalar/string/repeat.cpp index 2bfceae03..c93bbfa5e 100644 --- a/src/duckdb/extension/core_functions/scalar/string/repeat.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/repeat.cpp @@ -67,7 +67,7 @@ ScalarFunctionSet RepeatFun::GetFunctions() { repeat.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::BIGINT}, LogicalType::LIST(LogicalType::TEMPLATE("T")), RepeatListFunction)); for (auto &func : repeat.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return repeat; } diff --git a/src/duckdb/extension/core_functions/scalar/string/starts_with.cpp b/src/duckdb/extension/core_functions/scalar/string/starts_with.cpp index 7ef277292..dd918c8a3 100644 --- a/src/duckdb/extension/core_functions/scalar/string/starts_with.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/starts_with.cpp @@ -17,7 +17,6 @@ static bool StartsWith(const unsigned char *haystack, idx_t haystack_size, const } static bool StartsWith(const string_t &haystack_s, const string_t &needle_s) { - auto haystack = const_uchar_ptr_cast(haystack_s.GetData()); auto haystack_size = haystack_s.GetSize(); auto needle = const_uchar_ptr_cast(needle_s.GetData()); @@ -39,7 +38,7 @@ struct StartsWithOperator { ScalarFunction StartsWithOperatorFun::GetFunction() { ScalarFunction starts_with({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, ScalarFunction::BinaryFunction); - starts_with.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + starts_with.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return starts_with; } diff --git a/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp b/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp index cc4fd6f01..2fa3e2efb 100644 --- a/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp +++ b/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp 
@@ -68,8 +68,8 @@ static unique_ptr StructInsertBind(ClientContext &context, ScalarF new_children.push_back(make_pair(child->GetAlias(), arguments[i]->return_type)); } - bound_function.return_type = LogicalType::STRUCT(new_children); - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::STRUCT(new_children)); + return make_uniq(bound_function.GetReturnType()); } static unique_ptr StructInsertStats(ClientContext &context, FunctionStatisticsInput &input) { @@ -93,7 +93,7 @@ static unique_ptr StructInsertStats(ClientContext &context, Func ScalarFunction StructInsertFun::GetFunction() { ScalarFunction fun({}, LogicalTypeId::STRUCT, StructInsertFunction, StructInsertBind, nullptr, StructInsertStats); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.varargs = LogicalType::ANY; fun.serialize = VariableReturnBindData::Serialize; fun.deserialize = VariableReturnBindData::Deserialize; diff --git a/src/duckdb/extension/core_functions/scalar/struct/struct_update.cpp b/src/duckdb/extension/core_functions/scalar/struct/struct_update.cpp index e83c9b884..e60366950 100644 --- a/src/duckdb/extension/core_functions/scalar/struct/struct_update.cpp +++ b/src/duckdb/extension/core_functions/scalar/struct/struct_update.cpp @@ -108,8 +108,8 @@ static unique_ptr StructUpdateBind(ClientContext &context, ScalarF } } - bound_function.return_type = LogicalType::STRUCT(new_children); - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::STRUCT(new_children)); + return make_uniq(bound_function.GetReturnType()); } unique_ptr StructUpdateStats(ClientContext &context, FunctionStatisticsInput &input) { @@ -151,7 +151,7 @@ unique_ptr StructUpdateStats(ClientContext &context, FunctionSta ScalarFunction StructUpdateFun::GetFunction() { ScalarFunction fun({}, LogicalTypeId::STRUCT, StructUpdateFunction, StructUpdateBind, nullptr, StructUpdateStats); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.varargs = LogicalType::ANY; fun.serialize = VariableReturnBindData::Serialize; fun.deserialize = VariableReturnBindData::Deserialize; diff --git a/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp b/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp index b322f18ea..3feed5a2a 100644 --- a/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp +++ b/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp @@ -97,7 +97,7 @@ unique_ptr UnionExtractBind(ClientContext &context, ScalarFunction throw BinderException("Could not find key \"%s\" in union\n%s", key, message); } - bound_function.return_type = return_type; + bound_function.SetReturnType(return_type); return make_uniq(key, key_index, return_type); } diff --git a/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp b/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp index 95f63590a..98f210fa3 100644 --- a/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp +++ b/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp @@ -10,7 +10,6 @@ namespace { unique_ptr UnionTagBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - if (arguments.empty()) { throw BinderException("Missing required arguments for union_tag function."); } @@ -42,7 +41,7 @@ unique_ptr UnionTagBind(ClientContext &context, ScalarFunction &bo str.IsInlined() ? 
str : StringVector::AddString(varchar_vector, str); } auto enum_type = LogicalType::ENUM(varchar_vector, member_count); - bound_function.return_type = enum_type; + bound_function.SetReturnType(enum_type); return nullptr; } diff --git a/src/duckdb/extension/core_functions/scalar/union/union_value.cpp b/src/duckdb/extension/core_functions/scalar/union/union_value.cpp index 44274b3fd..12ecb18e7 100644 --- a/src/duckdb/extension/core_functions/scalar/union/union_value.cpp +++ b/src/duckdb/extension/core_functions/scalar/union/union_value.cpp @@ -40,7 +40,6 @@ void UnionValueFunction(DataChunk &args, ExpressionState &state, Vector &result) unique_ptr UnionValueBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - if (arguments.size() != 1) { throw BinderException("union_value takes exactly one argument"); } @@ -54,8 +53,8 @@ unique_ptr UnionValueBind(ClientContext &context, ScalarFunction & union_members.push_back(make_pair(child->GetAlias(), child->return_type)); - bound_function.return_type = LogicalType::UNION(std::move(union_members)); - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::UNION(std::move(union_members))); + return make_uniq(bound_function.GetReturnType()); } } // namespace @@ -63,7 +62,7 @@ unique_ptr UnionValueBind(ClientContext &context, ScalarFunction & ScalarFunction UnionValueFun::GetFunction() { ScalarFunction fun("union_value", {}, LogicalTypeId::UNION, UnionValueFunction, UnionValueBind, nullptr, nullptr); fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.serialize = VariableReturnBindData::Serialize; fun.deserialize = VariableReturnBindData::Deserialize; return fun; diff --git a/src/duckdb/extension/icu/icu-current.cpp b/src/duckdb/extension/icu/icu-current.cpp index 65bf29c54..76a7ae0f3 100644 --- a/src/duckdb/extension/icu/icu-current.cpp +++ b/src/duckdb/extension/icu/icu-current.cpp @@ -36,13 +36,13 @@ static void CurrentDateFunction(DataChunk &input, ExpressionState &state, Vector ScalarFunction GetCurrentTimeFun() { ScalarFunction current_time({}, LogicalType::TIME_TZ, CurrentTimeFunction); - current_time.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + current_time.SetStability(FunctionStability::CONSISTENT_WITHIN_QUERY); return current_time; } ScalarFunction GetCurrentDateFun() { ScalarFunction current_date({}, LogicalType::DATE, CurrentDateFunction); - current_date.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + current_date.SetStability(FunctionStability::CONSISTENT_WITHIN_QUERY); return current_date; } diff --git a/src/duckdb/extension/icu/icu-dateadd.cpp b/src/duckdb/extension/icu/icu-dateadd.cpp index 7f979f8a5..56a025861 100644 --- a/src/duckdb/extension/icu/icu-dateadd.cpp +++ b/src/duckdb/extension/icu/icu-dateadd.cpp @@ -219,7 +219,6 @@ interval_t ICUCalendarAge::Operation(timestamp_t end_date, timestamp_t start_dat } struct ICUDateAdd : public ICUDateFunc { - template static void ExecuteUnary(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 1); diff --git a/src/duckdb/extension/icu/icu-datefunc.cpp b/src/duckdb/extension/icu/icu-datefunc.cpp index 2d5fdce78..b0924b83a 100644 --- a/src/duckdb/extension/icu/icu-datefunc.cpp +++ b/src/duckdb/extension/icu/icu-datefunc.cpp @@ -16,7 +16,6 @@ ICUDateFunc::BindData::BindData(const BindData &other) ICUDateFunc::BindData::BindData(const string &tz_setting_p, const string 
&cal_setting_p) : tz_setting(tz_setting_p), cal_setting(cal_setting_p) { - InitCalendar(); } diff --git a/src/duckdb/extension/icu/icu-datepart.cpp b/src/duckdb/extension/icu/icu-datepart.cpp index 570430283..3af96b517 100644 --- a/src/duckdb/extension/icu/icu-datepart.cpp +++ b/src/duckdb/extension/icu/icu-datepart.cpp @@ -611,7 +611,7 @@ struct ICUDatePart : public ICUDateFunc { set.AddFunction(GetBinaryPartCodeFunction(LogicalType::TIMESTAMP_TZ)); set.AddFunction(GetStructFunction(LogicalType::TIMESTAMP_TZ)); for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } loader.RegisterFunction(set); } diff --git a/src/duckdb/extension/icu/icu-datesub.cpp b/src/duckdb/extension/icu/icu-datesub.cpp index 00e14f9e1..9c5edbcc1 100644 --- a/src/duckdb/extension/icu/icu-datesub.cpp +++ b/src/duckdb/extension/icu/icu-datesub.cpp @@ -9,7 +9,6 @@ namespace duckdb { struct ICUCalendarSub : public ICUDateFunc { - // ICU only has 32 bit precision for date parts, so it can overflow a high resolution. // Since there is no difference between ICU and the obvious calculations, // we make these using the DuckDB internal type. @@ -192,7 +191,6 @@ ICUDateFunc::part_sub_t ICUDateFunc::SubtractFactory(DatePartSpecifier type) { // MS-SQL differences can be computed using ICU by truncating both arguments // to the desired part precision and then applying ICU subtraction/difference struct ICUCalendarDiff : public ICUDateFunc { - template static int64_t DifferenceFunc(icu::Calendar *calendar, timestamp_t start_date, timestamp_t end_date, part_trunc_t trunc_func, part_sub_t sub_func) { diff --git a/src/duckdb/extension/icu/icu-list-range.cpp b/src/duckdb/extension/icu/icu-list-range.cpp index 4ee9e0b46..a1ec558e2 100644 --- a/src/duckdb/extension/icu/icu-list-range.cpp +++ b/src/duckdb/extension/icu/icu-list-range.cpp @@ -181,7 +181,6 @@ struct ICUListRange : public ICUDateFunc { } static void AddICUListRangeFunction(ExtensionLoader &loader) { - ScalarFunctionSet range("range"); range.AddFunction(ScalarFunction({LogicalType::TIMESTAMP_TZ, LogicalType::TIMESTAMP_TZ, LogicalType::INTERVAL}, LogicalType::LIST(LogicalType::TIMESTAMP_TZ), ICUListRangeFunction, diff --git a/src/duckdb/extension/icu/icu-makedate.cpp b/src/duckdb/extension/icu/icu-makedate.cpp index 7c8efb2cb..128e80d93 100644 --- a/src/duckdb/extension/icu/icu-makedate.cpp +++ b/src/duckdb/extension/icu/icu-makedate.cpp @@ -145,7 +145,7 @@ struct ICUMakeTimestampTZFunc : public ICUDateFunc { static ScalarFunction GetSenaryFunction(const LogicalTypeId &type) { ScalarFunction function({type, type, type, type, type, LogicalType::DOUBLE}, LogicalType::TIMESTAMP_TZ, Execute, Bind); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -153,7 +153,7 @@ struct ICUMakeTimestampTZFunc : public ICUDateFunc { static ScalarFunction GetSeptenaryFunction(const LogicalTypeId &type) { ScalarFunction function({type, type, type, type, type, LogicalType::DOUBLE, LogicalType::VARCHAR}, LogicalType::TIMESTAMP_TZ, Execute, Bind); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -162,7 +162,7 @@ struct ICUMakeTimestampTZFunc : public ICUDateFunc { set.AddFunction(GetSenaryFunction(LogicalType::BIGINT)); set.AddFunction(GetSeptenaryFunction(LogicalType::BIGINT)); ScalarFunction function({LogicalType::BIGINT}, LogicalType::TIMESTAMP_TZ, FromMicros); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); 
set.AddFunction(function); loader.RegisterFunction(set); } diff --git a/src/duckdb/extension/icu/icu-timebucket.cpp b/src/duckdb/extension/icu/icu-timebucket.cpp index 1336e0189..9a4035d18 100644 --- a/src/duckdb/extension/icu/icu-timebucket.cpp +++ b/src/duckdb/extension/icu/icu-timebucket.cpp @@ -16,7 +16,6 @@ namespace duckdb { struct ICUTimeBucket : public ICUDateFunc { - // Use 2000-01-03 00:00:00 (Monday) as origin when bucket_width is days, hours, ... for TimescaleDB compatibility // There are 10959 days between 1970-01-01 and 2000-01-03 constexpr static const int64_t DEFAULT_ORIGIN_MICROS_1 = 10959 * Interval::MICROS_PER_DAY; @@ -630,7 +629,7 @@ struct ICUTimeBucket : public ICUDateFunc { set.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::TIMESTAMP_TZ, LogicalType::VARCHAR}, LogicalType::TIMESTAMP_TZ, ICUTimeBucketTimeZoneFunction, Bind)); for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } loader.RegisterFunction(set); } diff --git a/src/duckdb/extension/icu/icu-timezone.cpp b/src/duckdb/extension/icu/icu-timezone.cpp index 86b8b6033..65993beaf 100644 --- a/src/duckdb/extension/icu/icu-timezone.cpp +++ b/src/duckdb/extension/icu/icu-timezone.cpp @@ -267,7 +267,6 @@ struct ICUToNaiveTimestamp : public ICUDateFunc { }; struct ICULocalTimestampFunc : public ICUDateFunc { - struct BindDataNow : public BindData { explicit BindDataNow(ClientContext &context) : BindData(context) { now = MetaTransaction::Get(context).start_timestamp; @@ -452,7 +451,7 @@ struct ICUTimeZoneFunc : public ICUDateFunc { set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIME_TZ}, LogicalType::TIME_TZ, Execute, Bind)); for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } loader.RegisterFunction(set); } diff --git a/src/duckdb/extension/icu/icu_extension.cpp b/src/duckdb/extension/icu/icu_extension.cpp index 006283576..59b41afae 100644 --- a/src/duckdb/extension/icu/icu_extension.cpp +++ b/src/duckdb/extension/icu/icu_extension.cpp @@ -5,11 +5,8 @@ #include "duckdb/function/scalar_function.hpp" #include "duckdb/main/config.hpp" #include "duckdb/main/connection.hpp" -#include "duckdb/main/database.hpp" #include "duckdb/main/extension/extension_loader.hpp" #include "duckdb/parser/parsed_data/create_collation_info.hpp" -#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" -#include "duckdb/parser/parsed_data/create_table_function_info.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "include/icu-current.hpp" #include "include/icu-dateadd.hpp" @@ -25,8 +22,6 @@ #include "include/icu_extension.hpp" #include "unicode/calendar.h" #include "unicode/coll.h" -#include "unicode/errorcode.h" -#include "unicode/sortkey.h" #include "unicode/stringpiece.h" #include "unicode/timezone.h" #include "unicode/ucol.h" @@ -209,7 +204,7 @@ static ScalarFunction GetICUCollateFunction(const string &collation, const strin return result; } -unique_ptr GetTimeZoneInternal(string &tz_str, vector &candidates) { +unique_ptr GetKnownTimeZone(const string &tz_str) { icu::StringPiece tz_name_utf8(tz_str); const auto uid = icu::UnicodeString::fromUTF8(tz_name_utf8); duckdb::unique_ptr tz(icu::TimeZone::createTimeZone(uid)); @@ -217,6 +212,66 @@ unique_ptr GetTimeZoneInternal(string &tz_str, vector &ca return tz; } + return nullptr; +} + +static string NormalizeTimeZone(const string &tz_str) { + if (GetKnownTimeZone(tz_str)) { + return tz_str; + } + + // Map UTC±NN00 
to Etc/UTC±N + do { + if (tz_str.size() <= 4) { + break; + } + if (tz_str.compare(0, 3, "UTC")) { + break; + } + + idx_t pos = 3; + const auto sign = tz_str[pos++]; + if (sign != '+' && sign != '-') { + break; + } + + string mapped = "Etc/GMT"; + mapped += sign; + const auto base_len = mapped.size(); + for (; pos < tz_str.size(); ++pos) { + const auto digit = tz_str[pos]; + // We could get fancy here and count colons and their locations, but I doubt anyone cares. + if (digit == '0' || digit == ':') { + continue; + } + if (!StringUtil::CharacterIsDigit(digit)) { + break; + } + mapped += digit; + } + if (pos < tz_str.size()) { + break; + } + // If we didn't add anything, then make it +0 + if (mapped.size() == base_len) { + mapped.back() = '+'; + mapped += '0'; + } + // Final sanity check + if (GetKnownTimeZone(mapped)) { + return mapped; + } + } while (false); + + return tz_str; +} + +unique_ptr GetTimeZoneInternal(string &tz_str, vector &candidates) { + auto tz = GetKnownTimeZone(tz_str); + if (tz) { + return tz; + } + // Try to be friendlier // Go through all the zone names and look for a case insensitive match // If we don't find one, make a suggestion @@ -269,6 +324,7 @@ unique_ptr ICUHelpers::GetTimeZone(string &tz_str, string *error_ static void SetICUTimeZone(ClientContext &context, SetScope scope, Value ¶meter) { auto tz_str = StringValue::Get(parameter); + tz_str = NormalizeTimeZone(tz_str); ICUHelpers::GetTimeZone(tz_str); parameter = Value(tz_str); } @@ -362,7 +418,6 @@ static void SetICUCalendar(ClientContext &context, SetScope scope, Value ¶me } static void LoadInternal(ExtensionLoader &loader) { - // iterate over all the collations int32_t count; auto locales = icu::Collator::getAvailableLocales(count); @@ -405,6 +460,11 @@ static void LoadInternal(ExtensionLoader &loader) { icu::UnicodeString tz_id; std::string tz_string; tz->getID(tz_id).toUTF8String(tz_string); + // If the environment TZ is invalid, look for some alternatives + tz_string = NormalizeTimeZone(tz_string); + if (!GetKnownTimeZone(tz_string)) { + tz_string = "UTC"; + } config.AddExtensionOption("TimeZone", "The current time zone", LogicalType::VARCHAR, Value(tz_string), SetICUTimeZone); diff --git a/src/duckdb/extension/icu/third_party/icu/common/putil.cpp b/src/duckdb/extension/icu/third_party/icu/common/putil.cpp index c79811499..0c3fd9376 100644 --- a/src/duckdb/extension/icu/third_party/icu/common/putil.cpp +++ b/src/duckdb/extension/icu/third_party/icu/common/putil.cpp @@ -1090,9 +1090,15 @@ uprv_tzname(int n) if (tzid[0] == ':') { tzid++; } - /* This might be a good Olson ID. */ - skipZoneIDPrefix(&tzid); - return tzid; +#if defined(TZDEFAULT) + if (uprv_strcmp(tzid, TZDEFAULT) != 0) { +#endif + /* This might be a good Olson ID. */ + skipZoneIDPrefix(&tzid); + return tzid; +#if defined(TZDEFAULT) + } +#endif } /* else U_TZNAME will give a better result. */ #endif diff --git a/src/duckdb/extension/icu/third_party/icu/common/unicode/ucnv.h b/src/duckdb/extension/icu/third_party/icu/common/unicode/ucnv.h index e69de29bb..c1f295577 100644 --- a/src/duckdb/extension/icu/third_party/icu/common/unicode/ucnv.h +++ b/src/duckdb/extension/icu/third_party/icu/common/unicode/ucnv.h @@ -0,0 +1,11 @@ +/** + * Converter option for EBCDIC SBCS or mixed-SBCS/DBCS (stateful) codepages. + * Swaps Unicode mappings for EBCDIC LF and NL codes, as used on + * S/390 (z/OS) Unix System Services (Open Edition). + * For example, ucnv_open("ibm-1047,swaplfnl", &errorCode); + * See convrtrs.txt. 
+ * + * @see ucnv_open + * @stable ICU 2.4 + */ +#define UCNV_SWAP_LFNL_OPTION_STRING ",swaplfnl" diff --git a/src/duckdb/extension/json/include/json_common.hpp b/src/duckdb/extension/json/include/json_common.hpp index f6dd78f05..81bbd6868 100644 --- a/src/duckdb/extension/json/include/json_common.hpp +++ b/src/duckdb/extension/json/include/json_common.hpp @@ -13,6 +13,7 @@ #include "duckdb/common/operator/string_cast.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "yyjson.hpp" +#include "duckdb/common/types/blob.hpp" using namespace duckdb_yyjson; // NOLINT @@ -228,11 +229,8 @@ struct JSONCommon { static string FormatParseError(const char *data, idx_t length, yyjson_read_err &error, const string &extra = "") { D_ASSERT(error.code != YYJSON_READ_SUCCESS); - // Go to blob so we can have a better error message for weird strings - auto blob = Value::BLOB(string(data, length)); // Truncate, so we don't print megabytes worth of JSON - string input = blob.ToString(); - input = input.length() > 50 ? string(input.c_str(), 47) + "..." : input; + auto input = length > 50 ? string(data, 47) + "..." : string(data, length); // Have to replace \r, otherwise output is unreadable input = StringUtil::Replace(input, "\r", "\\r"); return StringUtil::Format("Malformed JSON at byte %lld of input: %s. %s Input: \"%s\"", error.pos, error.msg, diff --git a/src/duckdb/extension/json/include/json_reader.hpp b/src/duckdb/extension/json/include/json_reader.hpp index de75af996..b78da3e31 100644 --- a/src/duckdb/extension/json/include/json_reader.hpp +++ b/src/duckdb/extension/json/include/json_reader.hpp @@ -210,8 +210,8 @@ class JSONReader : public BaseFileReader { void PrepareReader(ClientContext &context, GlobalTableFunctionState &) override; bool TryInitializeScan(ClientContext &context, GlobalTableFunctionState &gstate, LocalTableFunctionState &lstate) override; - void Scan(ClientContext &context, GlobalTableFunctionState &global_state, LocalTableFunctionState &local_state, - DataChunk &chunk) override; + AsyncResult Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &chunk) override; void FinishFile(ClientContext &context, GlobalTableFunctionState &gstate_p) override; double GetProgressInFile(ClientContext &context) override; diff --git a/src/duckdb/extension/json/include/json_serializer.hpp b/src/duckdb/extension/json/include/json_serializer.hpp index aa17f3ffd..e856bff79 100644 --- a/src/duckdb/extension/json/include/json_serializer.hpp +++ b/src/duckdb/extension/json/include/json_serializer.hpp @@ -39,6 +39,18 @@ struct JsonSerializer : Serializer { return serializer.GetRootObject(); } + template + static string SerializeToString(T &value) { + auto doc = yyjson_mut_doc_new(nullptr); + JsonSerializer serializer(doc, false, false, false); + value.Serialize(serializer); + auto result_obj = serializer.GetRootObject(); + idx_t len = 0; + auto data = yyjson_mut_val_write_opts(result_obj, JSONCommon::WRITE_PRETTY_FLAG, nullptr, + reinterpret_cast(&len), nullptr); + return string(data, len); + } + yyjson_mut_val *GetRootObject() { D_ASSERT(stack.size() == 1); // or we forgot to pop somewhere return stack.front(); diff --git a/src/duckdb/extension/json/json_functions.cpp b/src/duckdb/extension/json/json_functions.cpp index 2d09828c3..2d0ef11f5 100644 --- a/src/duckdb/extension/json/json_functions.cpp +++ b/src/duckdb/extension/json/json_functions.cpp @@ -394,7 +394,11 @@ void 
JSONFunctions::RegisterSimpleCastFunctions(ExtensionLoader &loader) { loader.RegisterCastFunction(LogicalType::LIST(LogicalType::JSON()), LogicalTypeId::VARCHAR, CastJSONListToVarchar, json_list_to_varchar_cost); - // VARCHAR to JSON[] (also needs a special case otherwise get a VARCHAR -> VARCHAR[] cast first) + // JSON[] to JSON is allowed implicitly + loader.RegisterCastFunction(LogicalType::LIST(LogicalType::JSON()), LogicalType::JSON(), CastJSONListToVarchar, + 100); + + // VARCHAR to JSON[] (also needs a special case otherwise we get a VARCHAR -> VARCHAR[] cast first) const auto varchar_to_json_list_cost = CastFunctionSet::ImplicitCastCost(db, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::JSON())) - 1; BoundCastInfo varchar_to_json_list_info(CastVarcharToJSONList, nullptr, JSONFunctionLocalState::InitCastLocalState); diff --git a/src/duckdb/extension/json/json_functions/json_create.cpp b/src/duckdb/extension/json/json_functions/json_create.cpp index 4cd00249c..d1c8a8afb 100644 --- a/src/duckdb/extension/json/json_functions/json_create.cpp +++ b/src/duckdb/extension/json/json_functions/json_create.cpp @@ -111,11 +111,11 @@ static unique_ptr JSONCreateBindParams(ScalarFunction &bound_funct auto &type = arguments[i]->return_type; if (arguments[i]->HasParameter()) { throw ParameterNotResolvedException(); - } else if (type == LogicalTypeId::SQLNULL) { - // This is needed for macro's - bound_function.arguments.push_back(type); } else if (object && i % 2 == 0) { - // Key, must be varchar + if (type != LogicalType::VARCHAR) { + throw BinderException("json_object() keys must be VARCHAR, add an explicit cast to argument \"%s\"", + arguments[i]->GetName()); + } bound_function.arguments.push_back(LogicalType::VARCHAR); } else { // Value, cast to types that we can put in JSON @@ -128,7 +128,7 @@ static unique_ptr JSONCreateBindParams(ScalarFunction &bound_funct static unique_ptr JSONObjectBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { if (arguments.size() % 2 != 0) { - throw InvalidInputException("json_object() requires an even number of arguments"); + throw BinderException("json_object() requires an even number of arguments"); } return JSONCreateBindParams(bound_function, arguments, true); } @@ -141,7 +141,7 @@ static unique_ptr JSONArrayBind(ClientContext &context, ScalarFunc static unique_ptr ToJSONBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { if (arguments.size() != 1) { - throw InvalidInputException("to_json() takes exactly one argument"); + throw BinderException("to_json() takes exactly one argument"); } return JSONCreateBindParams(bound_function, arguments, false); } @@ -149,14 +149,14 @@ static unique_ptr ToJSONBind(ClientContext &context, ScalarFunctio static unique_ptr ArrayToJSONBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { if (arguments.size() != 1) { - throw InvalidInputException("array_to_json() takes exactly one argument"); + throw BinderException("array_to_json() takes exactly one argument"); } auto arg_id = arguments[0]->return_type.id(); if (arguments[0]->HasParameter()) { throw ParameterNotResolvedException(); } if (arg_id != LogicalTypeId::LIST && arg_id != LogicalTypeId::SQLNULL) { - throw InvalidInputException("array_to_json() argument type must be LIST"); + throw BinderException("array_to_json() argument type must be LIST"); } return JSONCreateBindParams(bound_function, arguments, false); } @@ -164,14 +164,14 @@ static unique_ptr ArrayToJSONBind(ClientContext 
&context, ScalarFu static unique_ptr RowToJSONBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { if (arguments.size() != 1) { - throw InvalidInputException("row_to_json() takes exactly one argument"); + throw BinderException("row_to_json() takes exactly one argument"); } auto arg_id = arguments[0]->return_type.id(); if (arguments[0]->HasParameter()) { throw ParameterNotResolvedException(); } if (arguments[0]->return_type.id() != LogicalTypeId::STRUCT && arg_id != LogicalTypeId::SQLNULL) { - throw InvalidInputException("row_to_json() argument type must be STRUCT"); + throw BinderException("row_to_json() argument type must be STRUCT"); } return JSONCreateBindParams(bound_function, arguments, false); } @@ -473,7 +473,6 @@ static void CreateValuesList(const StructNames &names, yyjson_mut_doc *doc, yyjs static void CreateValuesArray(const StructNames &names, yyjson_mut_doc *doc, yyjson_mut_val *vals[], Vector &value_v, idx_t count) { - value_v.Flatten(count); // Initialize array for the nested values @@ -616,6 +615,7 @@ static void CreateValues(const StructNames &names, yyjson_mut_doc *doc, yyjson_m case LogicalTypeId::VALIDITY: case LogicalTypeId::TABLE: case LogicalTypeId::LAMBDA: + case LogicalTypeId::GEOMETRY: // TODO! Add support for GEOMETRY throw InternalException("Unsupported type arrived at JSON create function"); } } @@ -728,7 +728,7 @@ ScalarFunctionSet JSONFunctions::GetObjectFunction() { ScalarFunction fun("json_object", {}, LogicalType::JSON(), ObjectFunction, JSONObjectBind, nullptr, nullptr, JSONFunctionLocalState::Init); fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return ScalarFunctionSet(fun); } @@ -736,7 +736,7 @@ ScalarFunctionSet JSONFunctions::GetArrayFunction() { ScalarFunction fun("json_array", {}, LogicalType::JSON(), ArrayFunction, JSONArrayBind, nullptr, nullptr, JSONFunctionLocalState::Init); fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return ScalarFunctionSet(fun); } diff --git a/src/duckdb/extension/json/json_functions/json_merge_patch.cpp b/src/duckdb/extension/json/json_functions/json_merge_patch.cpp index 225228924..de1caadc2 100644 --- a/src/duckdb/extension/json/json_functions/json_merge_patch.cpp +++ b/src/duckdb/extension/json/json_functions/json_merge_patch.cpp @@ -84,7 +84,7 @@ ScalarFunctionSet JSONFunctions::GetMergePatchFunction() { ScalarFunction fun("json_merge_patch", {LogicalType::JSON(), LogicalType::JSON()}, LogicalType::JSON(), MergePatchFunction, nullptr, nullptr, nullptr, JSONFunctionLocalState::Init); fun.varargs = LogicalType::JSON(); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return ScalarFunctionSet(fun); } diff --git a/src/duckdb/extension/json/json_functions/json_table_in_out.cpp b/src/duckdb/extension/json/json_functions/json_table_in_out.cpp index 787e393b9..404164181 100644 --- a/src/duckdb/extension/json/json_functions/json_table_in_out.cpp +++ b/src/duckdb/extension/json/json_functions/json_table_in_out.cpp @@ -284,7 +284,6 @@ static void InitializeLocalState(JSONTableInOutLocalState &lstate, DataChunk &in template static bool JSONTableInOutHandleValue(JSONTableInOutLocalState &lstate, JSONTableInOutResult &result, idx_t &child_index, size_t &idx, yyjson_val *child_key, yyjson_val *child_val) { - if 
(idx < child_index) { return false; // Continue: Get back to where we left off } diff --git a/src/duckdb/extension/json/json_multi_file_info.cpp b/src/duckdb/extension/json/json_multi_file_info.cpp index 7771af489..1f131e6af 100644 --- a/src/duckdb/extension/json/json_multi_file_info.cpp +++ b/src/duckdb/extension/json/json_multi_file_info.cpp @@ -1,6 +1,7 @@ #include "json_multi_file_info.hpp" #include "json_scan.hpp" #include "duckdb/common/types/value.hpp" +#include "duckdb/parallel/async_result.hpp" namespace duckdb { @@ -530,8 +531,17 @@ void ReadJSONObjectsFunction(ClientContext &context, JSONReader &json_reader, JS output.SetCardinality(count); } -void JSONReader::Scan(ClientContext &context, GlobalTableFunctionState &global_state, - LocalTableFunctionState &local_state, DataChunk &output) { +AsyncResult JSONReader::Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &output) { +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE + { + vector> tasks = AsyncResult::GenerateTestTasks(); + if (!tasks.empty()) { + return AsyncResult(std::move(tasks)); + } + } +#endif + auto &gstate = global_state.Cast().state; auto &lstate = local_state.Cast().state; auto &json_data = gstate.bind_data.bind_data->Cast(); @@ -545,6 +555,7 @@ void JSONReader::Scan(ClientContext &context, GlobalTableFunctionState &global_s default: throw InternalException("Unsupported scan type for JSONMultiFileInfo::Scan"); } + return AsyncResult(output.size() ? SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED); } void JSONReader::FinishFile(ClientContext &context, GlobalTableFunctionState &global_state) { diff --git a/src/duckdb/extension/json/json_reader.cpp b/src/duckdb/extension/json/json_reader.cpp index b52026a4e..ad61da7d2 100644 --- a/src/duckdb/extension/json/json_reader.cpp +++ b/src/duckdb/extension/json/json_reader.cpp @@ -184,8 +184,7 @@ void JSONReader::OpenJSONFile() { if (!IsOpen()) { auto &fs = FileSystem::GetFileSystem(context); auto regular_file_handle = fs.OpenFile(file, FileFlags::FILE_FLAGS_READ | options.compression); - file_handle = make_uniq(QueryContext(context), std::move(regular_file_handle), - BufferAllocator::Get(context)); + file_handle = make_uniq(context, std::move(regular_file_handle), BufferAllocator::Get(context)); } Reset(); } diff --git a/src/duckdb/extension/parquet/column_reader.cpp b/src/duckdb/extension/parquet/column_reader.cpp index c13a71b6f..6f280c496 100644 --- a/src/duckdb/extension/parquet/column_reader.cpp +++ b/src/duckdb/extension/parquet/column_reader.cpp @@ -895,7 +895,6 @@ unique_ptr ColumnReader::CreateReader(ParquetReader &reader, const default: throw NotImplementedException("Unrecognized Parquet type for Decimal"); } - break; case LogicalTypeId::UUID: return make_uniq(reader, schema); case LogicalTypeId::INTERVAL: diff --git a/src/duckdb/extension/parquet/column_writer.cpp b/src/duckdb/extension/parquet/column_writer.cpp index 7cdd51bc5..4ac113f41 100644 --- a/src/duckdb/extension/parquet/column_writer.cpp +++ b/src/duckdb/extension/parquet/column_writer.cpp @@ -1,7 +1,7 @@ #include "column_writer.hpp" #include "duckdb.hpp" -#include "geo_parquet.hpp" +#include "parquet_geometry.hpp" #include "parquet_rle_bp_decoder.hpp" #include "parquet_bss_encoder.hpp" #include "parquet_statistics.hpp" @@ -13,6 +13,7 @@ #include "writer/list_column_writer.hpp" #include "writer/primitive_column_writer.hpp" #include "writer/struct_column_writer.hpp" +#include "writer/variant_column_writer.hpp" #include 
"writer/templated_column_writer.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/operator/comparison_operators.hpp" @@ -96,7 +97,7 @@ bool ColumnWriterStatistics::HasGeoStats() { return false; } -optional_ptr ColumnWriterStatistics::GetGeoStats() { +optional_ptr ColumnWriterStatistics::GetGeoStats() { return nullptr; } @@ -181,8 +182,7 @@ void ColumnWriter::CompressPage(MemoryStream &temp_writer, size_t &compressed_si } } -void ColumnWriter::HandleRepeatLevels(ColumnWriterState &state, ColumnWriterState *parent, idx_t count, - idx_t max_repeat) const { +void ColumnWriter::HandleRepeatLevels(ColumnWriterState &state, ColumnWriterState *parent, idx_t count) const { if (!parent) { // no repeat levels without a parent node return; @@ -245,8 +245,9 @@ void ColumnWriter::HandleDefineLevels(ColumnWriterState &state, ColumnWriterStat //===--------------------------------------------------------------------===// ParquetColumnSchema ColumnWriter::FillParquetSchema(vector &schemas, - const LogicalType &type, const string &name, - optional_ptr field_ids, idx_t max_repeat, + const LogicalType &type, const string &name, bool allow_geometry, + optional_ptr field_ids, + optional_ptr shredding_types, idx_t max_repeat, idx_t max_define, bool can_have_nulls) { auto null_type = can_have_nulls ? FieldRepetitionType::OPTIONAL : FieldRepetitionType::REQUIRED; if (!can_have_nulls) { @@ -263,6 +264,70 @@ ParquetColumnSchema ColumnWriter::FillParquetSchema(vectorchild_field_ids; } } + optional_ptr shredding_type; + if (shredding_types) { + shredding_type = shredding_types->GetChild(name); + } + + if (type.id() == LogicalTypeId::STRUCT && type.GetAlias() == "PARQUET_VARIANT") { + // variant type + // variants are stored as follows: + // group VARIANT { + // metadata BYTE_ARRAY, + // value BYTE_ARRAY, + // [] + // } + + const bool is_shredded = shredding_type != nullptr; + + child_list_t child_types; + child_types.emplace_back("metadata", LogicalType::BLOB); + child_types.emplace_back("value", LogicalType::BLOB); + if (is_shredded) { + auto &typed_value_type = shredding_type->type; + if (typed_value_type.id() != LogicalTypeId::ANY) { + child_types.emplace_back("typed_value", + VariantColumnWriter::TransformTypedValueRecursive(typed_value_type)); + } + } + + // variant group + duckdb_parquet::SchemaElement top_element; + top_element.repetition_type = null_type; + top_element.num_children = child_types.size(); + top_element.logicalType.__isset.VARIANT = true; + top_element.logicalType.VARIANT.__isset.specification_version = true; + top_element.logicalType.VARIANT.specification_version = 1; + top_element.__isset.logicalType = true; + top_element.__isset.num_children = true; + top_element.__isset.repetition_type = true; + top_element.name = name; + schemas.push_back(std::move(top_element)); + + ParquetColumnSchema variant_column(name, type, max_define, max_repeat, schema_idx, 0); + variant_column.children.reserve(child_types.size()); + for (auto &child_type : child_types) { + auto &child_name = child_type.first; + bool is_optional; + if (child_name == "metadata") { + is_optional = false; + } else if (child_name == "value") { + if (is_shredded) { + //! 
When shredding the variant, the 'value' becomes optional + is_optional = true; + } else { + is_optional = false; + } + } else { + D_ASSERT(child_name == "typed_value"); + is_optional = true; + } + variant_column.children.emplace_back(FillParquetSchema(schemas, child_type.second, child_type.first, + allow_geometry, child_field_ids, shredding_type, + max_repeat, max_define + 1, is_optional)); + } + return variant_column; + } if (type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION) { auto &child_types = StructType::GetChildTypes(type); @@ -285,7 +350,8 @@ ParquetColumnSchema ColumnWriter::FillParquetSchema(vectorfield_id; } - ParquetWriter::SetSchemaProperties(type, schema_element); + ParquetWriter::SetSchemaProperties(type, schema_element, allow_geometry); schemas.push_back(std::move(schema_element)); return ParquetColumnSchema(name, type, max_define, max_repeat, schema_idx, 0); } @@ -400,6 +467,17 @@ ColumnWriter::CreateWriterRecursive(ClientContext &context, ParquetWriter &write auto &type = schema.type; auto can_have_nulls = parquet_schemas[schema.schema_index].repetition_type == FieldRepetitionType::OPTIONAL; path_in_schema.push_back(schema.name); + + if (type.id() == LogicalTypeId::STRUCT && type.GetAlias() == "PARQUET_VARIANT") { + vector> child_writers; + child_writers.reserve(schema.children.size()); + for (idx_t i = 0; i < schema.children.size(); i++) { + child_writers.push_back( + CreateWriterRecursive(context, writer, parquet_schemas, schema.children[i], path_in_schema)); + } + return make_uniq(writer, schema, path_in_schema, std::move(child_writers), can_have_nulls); + } + if (type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION) { // construct the child writers recursively vector> child_writers; @@ -439,11 +517,6 @@ ColumnWriter::CreateWriterRecursive(ClientContext &context, ParquetWriter &write return make_uniq(writer, schema, path_in_schema, std::move(struct_writer), can_have_nulls); } - if (type.id() == LogicalTypeId::BLOB && type.GetAlias() == "WKB_BLOB") { - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); - } - switch (type.id()) { case LogicalTypeId::BOOLEAN: return make_uniq(writer, schema, std::move(path_in_schema), can_have_nulls); @@ -514,6 +587,9 @@ ColumnWriter::CreateWriterRecursive(ClientContext &context, ParquetWriter &write case LogicalTypeId::BLOB: return make_uniq>( writer, schema, std::move(path_in_schema), can_have_nulls); + case LogicalTypeId::GEOMETRY: + return make_uniq>( + writer, schema, std::move(path_in_schema), can_have_nulls); case LogicalTypeId::VARCHAR: return make_uniq>( writer, schema, std::move(path_in_schema), can_have_nulls); diff --git a/src/duckdb/extension/parquet/decoder/delta_length_byte_array_decoder.cpp b/src/duckdb/extension/parquet/decoder/delta_length_byte_array_decoder.cpp index 9a0c1eac5..a2fd7abd9 100644 --- a/src/duckdb/extension/parquet/decoder/delta_length_byte_array_decoder.cpp +++ b/src/duckdb/extension/parquet/decoder/delta_length_byte_array_decoder.cpp @@ -34,13 +34,21 @@ void DeltaLengthByteArrayDecoder::InitializePage() { void DeltaLengthByteArrayDecoder::Read(shared_ptr &block_ref, uint8_t *defines, idx_t read_count, Vector &result, idx_t result_offset) { if (defines) { - ReadInternal(block_ref, defines, read_count, result, result_offset); + if (reader.Type().IsJSONType()) { + ReadInternal(block_ref, defines, read_count, result, result_offset); + } else { + ReadInternal(block_ref, defines, read_count, result, result_offset); + } } else { - 
ReadInternal(block_ref, defines, read_count, result, result_offset); + if (reader.Type().IsJSONType()) { + ReadInternal(block_ref, defines, read_count, result, result_offset); + } else { + ReadInternal(block_ref, defines, read_count, result, result_offset); + } } } -template +template void DeltaLengthByteArrayDecoder::ReadInternal(shared_ptr &block_ref, uint8_t *const defines, const idx_t read_count, Vector &result, const idx_t result_offset) { auto &block = *block_ref; @@ -58,6 +66,8 @@ void DeltaLengthByteArrayDecoder::ReadInternal(shared_ptr &blo } } + const auto &string_column_reader = reader.Cast(); + const auto start_ptr = block.ptr; for (idx_t row_idx = 0; row_idx < read_count; row_idx++) { const auto result_idx = result_offset + row_idx; @@ -75,11 +85,15 @@ void DeltaLengthByteArrayDecoder::ReadInternal(shared_ptr &blo } const auto &str_len = length_data[length_idx++]; result_data[result_idx] = string_t(char_ptr_cast(block.ptr), str_len); + if (VALIDATE_INDIVIDUAL_STRINGS) { + string_column_reader.VerifyString(char_ptr_cast(block.ptr), str_len); + } block.unsafe_inc(str_len); } - // Verify that the strings we read are valid UTF-8 - reader.Cast().VerifyString(char_ptr_cast(start_ptr), block.ptr - start_ptr); + if (!VALIDATE_INDIVIDUAL_STRINGS) { + string_column_reader.VerifyString(char_ptr_cast(start_ptr), NumericCast(block.ptr - start_ptr)); + } StringColumnReader::ReferenceBlock(result, block_ref); } diff --git a/src/duckdb/extension/parquet/decoder/dictionary_decoder.cpp b/src/duckdb/extension/parquet/decoder/dictionary_decoder.cpp index dfce2343b..7ec8bed74 100644 --- a/src/duckdb/extension/parquet/decoder/dictionary_decoder.cpp +++ b/src/duckdb/extension/parquet/decoder/dictionary_decoder.cpp @@ -14,39 +14,36 @@ DictionaryDecoder::DictionaryDecoder(ColumnReader &reader) void DictionaryDecoder::InitializeDictionary(idx_t new_dictionary_size, optional_ptr filter, optional_ptr filter_state, bool has_defines) { - auto old_dict_size = dictionary_size; dictionary_size = new_dictionary_size; filter_result.reset(); filter_count = 0; can_have_nulls = has_defines; - // we use the first value in the dictionary to keep a NULL - if (!dictionary) { - dictionary = make_uniq(reader.Type(), dictionary_size + 1); - } else if (dictionary_size > old_dict_size) { - dictionary->Resize(old_dict_size, dictionary_size + 1); - } - dictionary_id = - reader.reader.GetFileName() + "_" + reader.Schema().name + "_" + std::to_string(reader.chunk_read_offset); + // we use the last entry as a NULL, dictionary vectors don't have a separate validity mask - auto &dict_validity = FlatVector::Validity(*dictionary); - dict_validity.Reset(dictionary_size + 1); + const auto duckdb_dictionary_size = dictionary_size + can_have_nulls; + dictionary = DictionaryVector::CreateReusableDictionary(reader.Type(), duckdb_dictionary_size); + auto &dict_validity = FlatVector::Validity(dictionary->data); + dict_validity.Reset(duckdb_dictionary_size); if (can_have_nulls) { dict_validity.SetInvalid(dictionary_size); } - reader.Plain(reader.block, nullptr, dictionary_size, 0, *dictionary); + // now read the non-NULL values from Parquet + reader.Plain(reader.block, nullptr, dictionary_size, 0, dictionary->data); + + // immediately filter the dictionary, if applicable if (filter && CanFilter(*filter, *filter_state)) { // no filter result yet - apply filter to the dictionary // initialize the filter result - setting everything to false - filter_result = make_unsafe_uniq_array(dictionary_size); + filter_result = 
make_unsafe_uniq_array(duckdb_dictionary_size); // apply the filter UnifiedVectorFormat vdata; - dictionary->ToUnifiedFormat(dictionary_size, vdata); + dictionary->data.ToUnifiedFormat(duckdb_dictionary_size, vdata); SelectionVector dict_sel; - filter_count = dictionary_size; - ColumnSegment::FilterSelection(dict_sel, *dictionary, vdata, *filter, *filter_state, dictionary_size, - filter_count); + filter_count = duckdb_dictionary_size; + ColumnSegment::FilterSelection(dict_sel, dictionary->data, vdata, *filter, *filter_state, + duckdb_dictionary_size, filter_count); // now set all matching tuples to true for (idx_t i = 0; i < filter_count; i++) { @@ -97,7 +94,8 @@ idx_t DictionaryDecoder::Read(uint8_t *defines, idx_t read_count, Vector &result idx_t valid_count = GetValidValues(defines, read_count, result_offset); if (valid_count == read_count) { // all values are valid - we can directly decompress the offsets into the selection vector - dict_decoder->GetBatch(data_ptr_cast(dictionary_selection_vector.data()), valid_count); + dict_decoder->GetBatch(data_ptr_cast(dictionary_selection_vector.data()), + NumericCast(valid_count)); // we do still need to verify the offsets though uint32_t max_index = 0; for (idx_t idx = 0; idx < valid_count; idx++) { @@ -109,19 +107,18 @@ idx_t DictionaryDecoder::Read(uint8_t *defines, idx_t read_count, Vector &result } else if (valid_count > 0) { // for the valid entries - decode the offsets offset_buffer.resize(reader.reader.allocator, sizeof(uint32_t) * valid_count); - dict_decoder->GetBatch(offset_buffer.ptr, valid_count); + dict_decoder->GetBatch(offset_buffer.ptr, NumericCast(valid_count)); ConvertDictToSelVec(reinterpret_cast(offset_buffer.ptr), valid_sel, valid_count); } #ifdef DEBUG dictionary_selection_vector.Verify(read_count, dictionary_size + can_have_nulls); #endif if (result_offset == 0) { - result.Dictionary(*dictionary, dictionary_size + can_have_nulls, dictionary_selection_vector, read_count); - DictionaryVector::SetDictionaryId(result, dictionary_id); + result.Dictionary(dictionary, dictionary_selection_vector); D_ASSERT(result.GetVectorType() == VectorType::DICTIONARY_VECTOR); } else { D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - VectorOperations::Copy(*dictionary, result, dictionary_selection_vector, read_count, 0, result_offset); + VectorOperations::Copy(dictionary->data, result, dictionary_selection_vector, read_count, 0, result_offset); } return valid_count; } @@ -132,7 +129,7 @@ void DictionaryDecoder::Skip(uint8_t *defines, idx_t skip_count) { } idx_t valid_count = reader.GetValidCount(defines, skip_count); // skip past the valid offsets - dict_decoder->Skip(valid_count); + dict_decoder->Skip(NumericCast(valid_count)); } bool DictionaryDecoder::DictionarySupportsFilter(const TableFilter &filter, TableFilterState &filter_state) { diff --git a/src/duckdb/extension/parquet/include/column_writer.hpp b/src/duckdb/extension/parquet/include/column_writer.hpp index d475e903b..929e94a11 100644 --- a/src/duckdb/extension/parquet/include/column_writer.hpp +++ b/src/duckdb/extension/parquet/include/column_writer.hpp @@ -18,6 +18,7 @@ class ParquetWriter; class ColumnWriterPageState; class PrimitiveColumnWriterState; struct ChildFieldIDs; +struct ShreddingType; class ResizeableBuffer; class ParquetBloomFilter; @@ -71,11 +72,6 @@ class ColumnWriter { bool can_have_nulls); virtual ~ColumnWriter(); - ParquetWriter &writer; - const ParquetColumnSchema &column_schema; - vector schema_path; - bool can_have_nulls; - public: const 
LogicalType &Type() const { return column_schema.type; @@ -94,9 +90,11 @@ class ColumnWriter { } static ParquetColumnSchema FillParquetSchema(vector &schemas, - const LogicalType &type, const string &name, - optional_ptr field_ids, idx_t max_repeat = 0, - idx_t max_define = 1, bool can_have_nulls = true); + const LogicalType &type, const string &name, bool allow_geometry, + optional_ptr field_ids, + optional_ptr shredding_types, + idx_t max_repeat = 0, idx_t max_define = 1, + bool can_have_nulls = true); //! Create the column writer for a specific type recursively static unique_ptr CreateWriterRecursive(ClientContext &context, ParquetWriter &writer, const vector &parquet_schemas, @@ -129,10 +127,19 @@ class ColumnWriter { protected: void HandleDefineLevels(ColumnWriterState &state, ColumnWriterState *parent, const ValidityMask &validity, const idx_t count, const uint16_t define_value, const uint16_t null_value) const; - void HandleRepeatLevels(ColumnWriterState &state_p, ColumnWriterState *parent, idx_t count, idx_t max_repeat) const; + void HandleRepeatLevels(ColumnWriterState &state_p, ColumnWriterState *parent, idx_t count) const; void CompressPage(MemoryStream &temp_writer, size_t &compressed_size, data_ptr_t &compressed_data, AllocatedData &compressed_buf); + +public: + ParquetWriter &writer; + const ParquetColumnSchema &column_schema; + vector schema_path; + bool can_have_nulls; + +protected: + vector> child_writers; }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/decoder/delta_length_byte_array_decoder.hpp b/src/duckdb/extension/parquet/include/decoder/delta_length_byte_array_decoder.hpp index f8141e26e..9f304da25 100644 --- a/src/duckdb/extension/parquet/include/decoder/delta_length_byte_array_decoder.hpp +++ b/src/duckdb/extension/parquet/include/decoder/delta_length_byte_array_decoder.hpp @@ -27,7 +27,7 @@ class DeltaLengthByteArrayDecoder { void Skip(uint8_t *defines, idx_t skip_count); private: - template + template void ReadInternal(shared_ptr &block, uint8_t *defines, idx_t read_count, Vector &result, idx_t result_offset); template diff --git a/src/duckdb/extension/parquet/include/decoder/dictionary_decoder.hpp b/src/duckdb/extension/parquet/include/decoder/dictionary_decoder.hpp index c012a82dc..de75b045e 100644 --- a/src/duckdb/extension/parquet/include/decoder/dictionary_decoder.hpp +++ b/src/duckdb/extension/parquet/include/decoder/dictionary_decoder.hpp @@ -47,11 +47,10 @@ class DictionaryDecoder { SelectionVector valid_sel; SelectionVector dictionary_selection_vector; idx_t dictionary_size; - unique_ptr dictionary; + buffer_ptr dictionary; unsafe_unique_array filter_result; idx_t filter_count; bool can_have_nulls; - string dictionary_id; }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/geo_parquet.hpp b/src/duckdb/extension/parquet/include/geo_parquet.hpp deleted file mode 100644 index 6dc82bc8d..000000000 --- a/src/duckdb/extension/parquet/include/geo_parquet.hpp +++ /dev/null @@ -1,241 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// geo_parquet.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "column_writer.hpp" -#include "duckdb/common/string.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/common/unordered_map.hpp" -#include "duckdb/common/unordered_set.hpp" -#include "parquet_types.h" - -namespace duckdb { - -struct ParquetColumnSchema; - -struct 
GeometryKindSet { - - uint8_t bits[4] = {0, 0, 0, 0}; - - void Add(uint32_t wkb_type) { - auto kind = wkb_type % 1000; - auto dims = wkb_type / 1000; - if (kind < 1 || kind > 7 || (dims) > 3) { - return; - } - bits[dims] |= (1 << (kind - 1)); - } - - void Combine(const GeometryKindSet &other) { - for (uint32_t d = 0; d < 4; d++) { - bits[d] |= other.bits[d]; - } - } - - bool IsEmpty() const { - for (uint32_t d = 0; d < 4; d++) { - if (bits[d] != 0) { - return false; - } - } - return true; - } - - template - vector ToList() const { - vector result; - for (uint32_t d = 0; d < 4; d++) { - for (uint32_t i = 1; i <= 7; i++) { - if (bits[d] & (1 << (i - 1))) { - result.push_back(i + d * 1000); - } - } - } - return result; - } - - vector ToString(bool snake_case) const { - vector result; - for (uint32_t d = 0; d < 4; d++) { - for (uint32_t i = 1; i <= 7; i++) { - if (bits[d] & (1 << (i - 1))) { - string str; - switch (i) { - case 1: - str = snake_case ? "point" : "Point"; - break; - case 2: - str = snake_case ? "linestring" : "LineString"; - break; - case 3: - str = snake_case ? "polygon" : "Polygon"; - break; - case 4: - str = snake_case ? "multipoint" : "MultiPoint"; - break; - case 5: - str = snake_case ? "multilinestring" : "MultiLineString"; - break; - case 6: - str = snake_case ? "multipolygon" : "MultiPolygon"; - break; - case 7: - str = snake_case ? "geometrycollection" : "GeometryCollection"; - break; - default: - str = snake_case ? "unknown" : "Unknown"; - break; - } - switch (d) { - case 1: - str += snake_case ? "_z" : " Z"; - break; - case 2: - str += snake_case ? "_m" : " M"; - break; - case 3: - str += snake_case ? "_zm" : " ZM"; - break; - default: - break; - } - - result.push_back(str); - } - } - } - return result; - } -}; - -struct GeometryExtent { - - double xmin = NumericLimits::Maximum(); - double xmax = NumericLimits::Minimum(); - double ymin = NumericLimits::Maximum(); - double ymax = NumericLimits::Minimum(); - double zmin = NumericLimits::Maximum(); - double zmax = NumericLimits::Minimum(); - double mmin = NumericLimits::Maximum(); - double mmax = NumericLimits::Minimum(); - - bool IsSet() const { - return xmin != NumericLimits::Maximum() && xmax != NumericLimits::Minimum() && - ymin != NumericLimits::Maximum() && ymax != NumericLimits::Minimum(); - } - - bool HasZ() const { - return zmin != NumericLimits::Maximum() && zmax != NumericLimits::Minimum(); - } - - bool HasM() const { - return mmin != NumericLimits::Maximum() && mmax != NumericLimits::Minimum(); - } - - void Combine(const GeometryExtent &other) { - xmin = std::min(xmin, other.xmin); - xmax = std::max(xmax, other.xmax); - ymin = std::min(ymin, other.ymin); - ymax = std::max(ymax, other.ymax); - zmin = std::min(zmin, other.zmin); - zmax = std::max(zmax, other.zmax); - mmin = std::min(mmin, other.mmin); - mmax = std::max(mmax, other.mmax); - } - - void Combine(const double &xmin_p, const double &xmax_p, const double &ymin_p, const double &ymax_p) { - xmin = std::min(xmin, xmin_p); - xmax = std::max(xmax, xmax_p); - ymin = std::min(ymin, ymin_p); - ymax = std::max(ymax, ymax_p); - } - - void ExtendX(const double &x) { - xmin = std::min(xmin, x); - xmax = std::max(xmax, x); - } - void ExtendY(const double &y) { - ymin = std::min(ymin, y); - ymax = std::max(ymax, y); - } - void ExtendZ(const double &z) { - zmin = std::min(zmin, z); - zmax = std::max(zmax, z); - } - void ExtendM(const double &m) { - mmin = std::min(mmin, m); - mmax = std::max(mmax, m); - } -}; - -struct GeometryStats { - GeometryKindSet types; - 
GeometryExtent bbox; - - void Update(const string_t &wkb); -}; - -//------------------------------------------------------------------------------ -// GeoParquetMetadata -//------------------------------------------------------------------------------ -class ParquetReader; -class ColumnReader; -class ClientContext; -class ExpressionExecutor; - -enum class GeoParquetColumnEncoding : uint8_t { - WKB = 1, - POINT, - LINESTRING, - POLYGON, - MULTIPOINT, - MULTILINESTRING, - MULTIPOLYGON, -}; - -struct GeoParquetColumnMetadata { - // The encoding of the geometry column - GeoParquetColumnEncoding geometry_encoding; - - // The statistics of the geometry column - GeometryStats stats; - - // The crs of the geometry column (if any) in PROJJSON format - string projjson; - - // Used to track the "primary" geometry column (if any) - idx_t insertion_index = 0; -}; - -class GeoParquetFileMetadata { -public: - void AddGeoParquetStats(const string &column_name, const LogicalType &type, const GeometryStats &stats); - void Write(duckdb_parquet::FileMetaData &file_meta_data); - - // Try to read GeoParquet metadata. Returns nullptr if not found, invalid or the required spatial extension is not - // available. - static unique_ptr TryRead(const duckdb_parquet::FileMetaData &file_meta_data, - const ClientContext &context); - const unordered_map &GetColumnMeta() const; - - static unique_ptr CreateColumnReader(ParquetReader &reader, const ParquetColumnSchema &schema, - ClientContext &context); - - bool IsGeometryColumn(const string &column_name) const; - - static bool IsGeoParquetConversionEnabled(const ClientContext &context); - static LogicalType GeometryType(); - -private: - mutex write_lock; - string version = "1.1.0"; - unordered_map geometry_columns; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_column_schema.hpp b/src/duckdb/extension/parquet/include/parquet_column_schema.hpp index d467e2a02..fd49a4ec4 100644 --- a/src/duckdb/extension/parquet/include/parquet_column_schema.hpp +++ b/src/duckdb/extension/parquet/include/parquet_column_schema.hpp @@ -15,7 +15,7 @@ namespace duckdb { using duckdb_parquet::FileMetaData; struct ParquetOptions; -enum class ParquetColumnSchemaType { COLUMN, FILE_ROW_NUMBER, GEOMETRY, EXPRESSION, VARIANT }; +enum class ParquetColumnSchemaType { COLUMN, FILE_ROW_NUMBER, EXPRESSION, VARIANT, GEOMETRY }; enum class ParquetExtraTypeInfo { NONE, @@ -35,7 +35,7 @@ struct ParquetColumnSchema { ParquetColumnSchemaType schema_type = ParquetColumnSchemaType::COLUMN); ParquetColumnSchema(string name, LogicalType type, idx_t max_define, idx_t max_repeat, idx_t schema_index, idx_t column_index, ParquetColumnSchemaType schema_type = ParquetColumnSchemaType::COLUMN); - ParquetColumnSchema(ParquetColumnSchema parent, LogicalType result_type, ParquetColumnSchemaType schema_type); + ParquetColumnSchema(ParquetColumnSchema child, LogicalType result_type, ParquetColumnSchemaType schema_type); ParquetColumnSchemaType schema_type; string name; diff --git a/src/duckdb/extension/parquet/include/parquet_dbp_decoder.hpp b/src/duckdb/extension/parquet/include/parquet_dbp_decoder.hpp index 31fb26cc9..775160215 100644 --- a/src/duckdb/extension/parquet/include/parquet_dbp_decoder.hpp +++ b/src/duckdb/extension/parquet/include/parquet_dbp_decoder.hpp @@ -18,7 +18,7 @@ class DbpDecoder { : buffer_(buffer, buffer_len), // block_size_in_values(ParquetDecodeUtils::VarintDecode(buffer_)), - number_of_miniblocks_per_block(ParquetDecodeUtils::VarintDecode(buffer_)), + 
number_of_miniblocks_per_block(DecodeNumberOfMiniblocksPerBlock(buffer_)), number_of_values_in_a_miniblock(block_size_in_values / number_of_miniblocks_per_block), total_value_count(ParquetDecodeUtils::VarintDecode(buffer_)), previous_value(ParquetDecodeUtils::ZigzagToInt(ParquetDecodeUtils::VarintDecode(buffer_))), @@ -31,7 +31,7 @@ class DbpDecoder { number_of_values_in_a_miniblock % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE == 0)) { throw InvalidInputException("Parquet file has invalid block sizes for DELTA_BINARY_PACKED"); } - }; + } ByteBuffer BufferPtr() const { return buffer_; @@ -68,6 +68,15 @@ class DbpDecoder { } private: + static idx_t DecodeNumberOfMiniblocksPerBlock(ByteBuffer &buffer) { + auto res = ParquetDecodeUtils::VarintDecode(buffer); + if (res == 0) { + throw InvalidInputException( + "Parquet file has invalid number of miniblocks per block for DELTA_BINARY_PACKED"); + } + return res; + } + template void GetBatchInternal(const data_ptr_t target_values_ptr, const idx_t batch_size) { if (batch_size == 0) { diff --git a/src/duckdb/extension/parquet/include/parquet_field_id.hpp b/src/duckdb/extension/parquet/include/parquet_field_id.hpp new file mode 100644 index 000000000..9d5dd754c --- /dev/null +++ b/src/duckdb/extension/parquet/include/parquet_field_id.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/common/case_insensitive_map.hpp" + +namespace duckdb { + +struct FieldID; +struct ChildFieldIDs { + ChildFieldIDs(); + ChildFieldIDs Copy() const; + unique_ptr> ids; + + void Serialize(Serializer &serializer) const; + static ChildFieldIDs Deserialize(Deserializer &source); +}; + +struct FieldID { +public: + static constexpr const auto DUCKDB_FIELD_ID = "__duckdb_field_id"; + FieldID(); + explicit FieldID(int32_t field_id); + FieldID Copy() const; + bool set; + int32_t field_id; + ChildFieldIDs child_field_ids; + + void Serialize(Serializer &serializer) const; + static FieldID Deserialize(Deserializer &source); + +public: + static void GenerateFieldIDs(ChildFieldIDs &field_ids, idx_t &field_id, const vector &names, + const vector &sql_types); + static void GetFieldIDs(const Value &field_ids_value, ChildFieldIDs &field_ids, + unordered_set &unique_field_ids, + const case_insensitive_map_t &name_to_type_map); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp b/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp index aa1c1c9b5..552fbff7c 100644 --- a/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp +++ b/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp @@ -9,7 +9,7 @@ #include "duckdb.hpp" #include "duckdb/storage/object_cache.hpp" -#include "geo_parquet.hpp" +#include "parquet_geometry.hpp" #include "parquet_types.h" namespace duckdb { diff --git a/src/duckdb/extension/parquet/include/parquet_geometry.hpp b/src/duckdb/extension/parquet/include/parquet_geometry.hpp new file mode 100644 index 000000000..3c367ee37 --- /dev/null +++ b/src/duckdb/extension/parquet/include/parquet_geometry.hpp @@ -0,0 +1,107 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// geo_parquet.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "column_writer.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/unordered_map.hpp" 
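The DbpDecoder change above derives the number of values per miniblock by dividing the block size by the miniblock count, so a corrupt DELTA_BINARY_PACKED header that encodes zero miniblocks per block would otherwise divide by zero. A minimal standalone sketch of that guard, using plain C++ and standard exceptions rather than DuckDB's ByteBuffer/InvalidInputException machinery (the helper name is hypothetical):

#include <cstdint>
#include <cstdio>
#include <stdexcept>

// Hypothetical helper mirroring the intent of DecodeNumberOfMiniblocksPerBlock:
// reject a zero miniblock count before it is ever used as a divisor.
static uint64_t ValuesPerMiniblock(uint64_t block_size_in_values, uint64_t miniblocks_per_block) {
	if (miniblocks_per_block == 0) {
		throw std::runtime_error("invalid DELTA_BINARY_PACKED header: zero miniblocks per block");
	}
	return block_size_in_values / miniblocks_per_block;
}

int main() {
	// 128 values per block split into 4 miniblocks -> 32 values per miniblock
	std::printf("%llu\n", static_cast<unsigned long long>(ValuesPerMiniblock(128, 4)));
	return 0;
}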
+#include "duckdb/common/unordered_set.hpp" +#include "parquet_types.h" + +namespace duckdb { + +struct ParquetColumnSchema; +class ParquetReader; +class ColumnReader; +class ClientContext; + +struct GeometryColumnReader { + static unique_ptr Create(ParquetReader &reader, const ParquetColumnSchema &schema, + ClientContext &context); +}; + +enum class GeoParquetColumnEncoding : uint8_t { + WKB = 1, + POINT, + LINESTRING, + POLYGON, + MULTIPOINT, + MULTILINESTRING, + MULTIPOLYGON, +}; + +enum class GeoParquetVersion : uint8_t { + // Write GeoParquet 1.0 metadata + // GeoParquet 1.0 has the widest support among readers and writers + V1, + + // Write GeoParquet 2.0 + // The GeoParquet 2.0 options is identical to GeoParquet 1.0 except the underlying storage + // of spatial columns is Parquet native geometry, where the Parquet writer will include + // native statistics according to the underlying Parquet options. Compared to 'BOTH', this will + // actually write the metadata as containing GeoParquet version 2.0.0 + // However, V2 isnt standardized yet, so this option is still a bit experimental + V2, + + // Write GeoParquet 1.0 metadata, with native Parquet geometry types + // This is a bit of a hold-over option for compatibility with systems that + // reject GeoParquet 2.0 metadata, but can read Parquet native geometry types as they simply ignore the extra + // logical type. DuckDB v1.4.0 falls into this category. + BOTH, + + // Do not write GeoParquet metadata + // This option suppresses GeoParquet metadata; however, spatial types will be written as + // Parquet native Geometry/Geography. + NONE, +}; + +struct GeoParquetColumnMetadata { + // The encoding of the geometry column + GeoParquetColumnEncoding geometry_encoding; + + // The statistics of the geometry column + GeometryStatsData stats; + + // The crs of the geometry column (if any) in PROJJSON format + string projjson; + + // Used to track the "primary" geometry column (if any) + idx_t insertion_index = 0; + + GeoParquetColumnMetadata() { + geometry_encoding = GeoParquetColumnEncoding::WKB; + stats.SetEmpty(); + } +}; + +class GeoParquetFileMetadata { +public: + explicit GeoParquetFileMetadata(GeoParquetVersion geo_parquet_version) : version(geo_parquet_version) { + } + void AddGeoParquetStats(const string &column_name, const LogicalType &type, const GeometryStatsData &stats); + void Write(duckdb_parquet::FileMetaData &file_meta_data); + + // Try to read GeoParquet metadata. Returns nullptr if not found, invalid or the required spatial extension is not + // available. 
+ static unique_ptr TryRead(const duckdb_parquet::FileMetaData &file_meta_data, + const ClientContext &context); + const unordered_map &GetColumnMeta() const; + + bool IsGeometryColumn(const string &column_name) const; + + static bool IsGeoParquetConversionEnabled(const ClientContext &context); + +private: + mutex write_lock; + unordered_map geometry_columns; + GeoParquetVersion version; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_metadata.hpp b/src/duckdb/extension/parquet/include/parquet_metadata.hpp index 09ecd5afa..fb0900610 100644 --- a/src/duckdb/extension/parquet/include/parquet_metadata.hpp +++ b/src/duckdb/extension/parquet/include/parquet_metadata.hpp @@ -38,4 +38,9 @@ class ParquetBloomProbeFunction : public TableFunction { ParquetBloomProbeFunction(); }; +class ParquetFullMetadataFunction : public TableFunction { +public: + ParquetFullMetadataFunction(); +}; + } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_reader.hpp b/src/duckdb/extension/parquet/include/parquet_reader.hpp index de905c70c..8c9d43da7 100644 --- a/src/duckdb/extension/parquet/include/parquet_reader.hpp +++ b/src/duckdb/extension/parquet/include/parquet_reader.hpp @@ -105,7 +105,6 @@ struct ParquetOptions { explicit ParquetOptions(ClientContext &context); bool binary_as_string = false; - bool variant_legacy_encoding = false; bool file_row_number = false; shared_ptr encryption_config; bool debug_use_openssl = true; @@ -166,14 +165,14 @@ class ParquetReader : public BaseFileReader { bool TryInitializeScan(ClientContext &context, GlobalTableFunctionState &gstate, LocalTableFunctionState &lstate) override; - void Scan(ClientContext &context, GlobalTableFunctionState &global_state, LocalTableFunctionState &local_state, - DataChunk &chunk) override; + AsyncResult Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &chunk) override; void FinishFile(ClientContext &context, GlobalTableFunctionState &gstate_p) override; double GetProgressInFile(ClientContext &context) override; public: void InitializeScan(ClientContext &context, ParquetReaderScanState &state, vector groups_to_read); - void Scan(ClientContext &context, ParquetReaderScanState &state, DataChunk &output); + AsyncResult Scan(ClientContext &context, ParquetReaderScanState &state, DataChunk &output); idx_t NumRows() const; idx_t NumRowGroups() const; @@ -209,7 +208,6 @@ class ParquetReader : public BaseFileReader { shared_ptr metadata); void InitializeSchema(ClientContext &context); - bool ScanInternal(ClientContext &context, ParquetReaderScanState &state, DataChunk &output); //! 
Parse the schema of the file unique_ptr ParseSchema(ClientContext &context); ParquetColumnSchema ParseSchemaRecursive(idx_t depth, idx_t max_define, idx_t max_repeat, idx_t &next_schema_idx, diff --git a/src/duckdb/extension/parquet/include/parquet_shredding.hpp b/src/duckdb/extension/parquet/include/parquet_shredding.hpp new file mode 100644 index 000000000..f43cbc42c --- /dev/null +++ b/src/duckdb/extension/parquet/include/parquet_shredding.hpp @@ -0,0 +1,49 @@ +#pragma once + +#include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/types/variant.hpp" + +namespace duckdb { + +struct ShreddingType; + +struct ChildShreddingTypes { +public: + ChildShreddingTypes(); + +public: + ChildShreddingTypes Copy() const; + +public: + void Serialize(Serializer &serializer) const; + static ChildShreddingTypes Deserialize(Deserializer &source); + +public: + unique_ptr> types; +}; + +struct ShreddingType { +public: + ShreddingType(); + explicit ShreddingType(const LogicalType &type); + +public: + ShreddingType Copy() const; + +public: + void Serialize(Serializer &serializer) const; + static ShreddingType Deserialize(Deserializer &source); + +public: + static ShreddingType GetShreddingTypes(const Value &val); + void AddChild(const string &name, ShreddingType &&child); + optional_ptr GetChild(const string &name) const; + +public: + bool set = false; + LogicalType type; + ChildShreddingTypes children; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_statistics.hpp b/src/duckdb/extension/parquet/include/parquet_statistics.hpp index cb05dae3b..e138d9763 100644 --- a/src/duckdb/extension/parquet/include/parquet_statistics.hpp +++ b/src/duckdb/extension/parquet/include/parquet_statistics.hpp @@ -23,7 +23,6 @@ struct ParquetColumnSchema; class ResizeableBuffer; struct ParquetStatisticsUtils { - static unique_ptr TransformColumnStatistics(const ParquetColumnSchema &reader, const vector &columns, bool can_have_nan); diff --git a/src/duckdb/extension/parquet/include/parquet_support.hpp b/src/duckdb/extension/parquet/include/parquet_support.hpp index 91c43fcb4..0b00e6242 100644 --- a/src/duckdb/extension/parquet/include/parquet_support.hpp +++ b/src/duckdb/extension/parquet/include/parquet_support.hpp @@ -118,7 +118,6 @@ class StripeStreams { }; class ColumnReader { - public: ColumnReader(const EncodingKey &ek, StripeStreams &stripe); diff --git a/src/duckdb/extension/parquet/include/parquet_writer.hpp b/src/duckdb/extension/parquet/include/parquet_writer.hpp index a2bfc3a80..aa5874ed5 100644 --- a/src/duckdb/extension/parquet/include/parquet_writer.hpp +++ b/src/duckdb/extension/parquet/include/parquet_writer.hpp @@ -21,8 +21,10 @@ #include "parquet_statistics.hpp" #include "column_writer.hpp" +#include "parquet_field_id.hpp" +#include "parquet_shredding.hpp" #include "parquet_types.h" -#include "geo_parquet.hpp" +#include "parquet_geometry.hpp" #include "writer/parquet_write_stats.hpp" #include "thrift/protocol/TCompactProtocol.h" @@ -43,29 +45,6 @@ struct PreparedRowGroup { vector> states; }; -struct FieldID; -struct ChildFieldIDs { - ChildFieldIDs(); - ChildFieldIDs Copy() const; - unique_ptr> ids; - - void Serialize(Serializer &serializer) const; - static ChildFieldIDs Deserialize(Deserializer &source); -}; - -struct FieldID { - static constexpr const auto DUCKDB_FIELD_ID = "__duckdb_field_id"; - FieldID(); - explicit FieldID(int32_t field_id); - FieldID Copy() const; - bool set; - int32_t 
field_id; - ChildFieldIDs child_field_ids; - - void Serialize(Serializer &serializer) const; - static FieldID Deserialize(Deserializer &source); -}; - struct ParquetBloomFilterEntry { unique_ptr bloom_filter; idx_t row_group_idx; @@ -81,11 +60,11 @@ class ParquetWriter { public: ParquetWriter(ClientContext &context, FileSystem &fs, string file_name, vector types, vector names, duckdb_parquet::CompressionCodec::type codec, ChildFieldIDs field_ids, - const vector> &kv_metadata, + ShreddingType shredding_types, const vector> &kv_metadata, shared_ptr encryption_config, optional_idx dictionary_size_limit, idx_t string_dictionary_page_size_limit, bool enable_bloom_filters, double bloom_filter_false_positive_ratio, int64_t compression_level, bool debug_use_openssl, - ParquetVersion parquet_version); + ParquetVersion parquet_version, GeoParquetVersion geoparquet_version); ~ParquetWriter(); public: @@ -95,7 +74,8 @@ class ParquetWriter { void Finalize(); static duckdb_parquet::Type::type DuckDBTypeToParquetType(const LogicalType &duckdb_type); - static void SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele); + static void SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele, + bool allow_geometry); ClientContext &GetContext() { return context; @@ -139,6 +119,9 @@ class ParquetWriter { ParquetVersion GetParquetVersion() const { return parquet_version; } + GeoParquetVersion GetGeoParquetVersion() const { + return geoparquet_version; + } const string &GetFileName() const { return file_name; } @@ -166,6 +149,7 @@ class ParquetWriter { vector column_names; duckdb_parquet::CompressionCodec::type codec; ChildFieldIDs field_ids; + ShreddingType shredding_types; shared_ptr encryption_config; optional_idx dictionary_size_limit; idx_t string_dictionary_page_size_limit; @@ -175,6 +159,7 @@ class ParquetWriter { bool debug_use_openssl; shared_ptr encryption_util; ParquetVersion parquet_version; + GeoParquetVersion geoparquet_version; vector column_schemas; unique_ptr writer; diff --git a/src/duckdb/extension/parquet/include/reader/interval_column_reader.hpp b/src/duckdb/extension/parquet/include/reader/interval_column_reader.hpp index 1ead9cf04..0f93bf9d5 100644 --- a/src/duckdb/extension/parquet/include/reader/interval_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/reader/interval_column_reader.hpp @@ -57,7 +57,6 @@ struct IntervalValueConversion { }; class IntervalColumnReader : public TemplatedColumnReader { - public: IntervalColumnReader(ParquetReader &reader, const ParquetColumnSchema &schema) : TemplatedColumnReader(reader, schema) { diff --git a/src/duckdb/extension/parquet/include/reader/string_column_reader.hpp b/src/duckdb/extension/parquet/include/reader/string_column_reader.hpp index 4bc19516a..d0d18b80c 100644 --- a/src/duckdb/extension/parquet/include/reader/string_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/reader/string_column_reader.hpp @@ -14,16 +14,30 @@ namespace duckdb { class StringColumnReader : public ColumnReader { +public: + enum class StringColumnType : uint8_t { VARCHAR, JSON, OTHER }; + + static StringColumnType GetStringColumnType(const LogicalType &type) { + if (type.IsJSONType()) { + return StringColumnType::JSON; + } + if (type.id() == LogicalTypeId::VARCHAR) { + return StringColumnType::VARCHAR; + } + return StringColumnType::OTHER; + } + public: static constexpr const PhysicalType TYPE = PhysicalType::VARCHAR; public: StringColumnReader(ParquetReader &reader, 
const ParquetColumnSchema &schema); idx_t fixed_width_string_length; + const StringColumnType string_column_type; public: static void VerifyString(const char *str_data, uint32_t str_len, const bool isVarchar); - void VerifyString(const char *str_data, uint32_t str_len); + void VerifyString(const char *str_data, uint32_t str_len) const; static void ReferenceBlock(Vector &result, shared_ptr &block); diff --git a/src/duckdb/extension/parquet/include/reader/templated_column_reader.hpp b/src/duckdb/extension/parquet/include/reader/templated_column_reader.hpp index 3bd0e96d6..b6bd55cc7 100644 --- a/src/duckdb/extension/parquet/include/reader/templated_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/reader/templated_column_reader.hpp @@ -79,7 +79,6 @@ class TemplatedColumnReader : public ColumnReader { template struct CallbackParquetValueConversion { - template static DUCKDB_PHYSICAL_TYPE PlainRead(ByteBuffer &plain_data, ColumnReader &reader) { if (CHECKED) { diff --git a/src/duckdb/extension/parquet/include/reader/uuid_column_reader.hpp b/src/duckdb/extension/parquet/include/reader/uuid_column_reader.hpp index 86193d9a6..22d468d0f 100644 --- a/src/duckdb/extension/parquet/include/reader/uuid_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/reader/uuid_column_reader.hpp @@ -50,7 +50,6 @@ struct UUIDValueConversion { }; class UUIDColumnReader : public TemplatedColumnReader { - public: UUIDColumnReader(ParquetReader &reader, const ParquetColumnSchema &schema) : TemplatedColumnReader(reader, schema) { diff --git a/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp b/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp index a7c717709..17efcd46e 100644 --- a/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp +++ b/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp @@ -137,10 +137,8 @@ class VariantBinaryDecoder { static VariantValue Decode(const VariantMetadata &metadata, const_data_ptr_t data); public: - static VariantValue PrimitiveTypeDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, - const_data_ptr_t data); - static VariantValue ShortStringDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, - const_data_ptr_t data); + static VariantValue PrimitiveTypeDecode(const VariantValueMetadata &value_metadata, const_data_ptr_t data); + static VariantValue ShortStringDecode(const VariantValueMetadata &value_metadata, const_data_ptr_t data); static VariantValue ObjectDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, const_data_ptr_t data); static VariantValue ArrayDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, diff --git a/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp b/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp index 27ece7d70..bbcf71792 100644 --- a/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp +++ b/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp @@ -11,13 +11,14 @@ class VariantShreddedConversion { public: static vector Convert(Vector &metadata, Vector &group, idx_t offset, idx_t length, idx_t total_size, - bool is_field = false); + bool is_field); static vector ConvertShreddedLeaf(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, - idx_t length, idx_t total_size); + 
idx_t length, idx_t total_size, const bool is_field); static vector ConvertShreddedArray(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, - idx_t length, idx_t total_size); + idx_t length, idx_t total_size, const bool is_field); static vector ConvertShreddedObject(Vector &metadata, Vector &value, Vector &typed_value, - idx_t offset, idx_t length, idx_t total_size); + idx_t offset, idx_t length, idx_t total_size, + const bool is_field); }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/reader/variant/variant_value.hpp b/src/duckdb/extension/parquet/include/reader/variant/variant_value.hpp index a4c38ede7..9d3f502c3 100644 --- a/src/duckdb/extension/parquet/include/reader/variant/variant_value.hpp +++ b/src/duckdb/extension/parquet/include/reader/variant/variant_value.hpp @@ -42,6 +42,7 @@ struct VariantValue { public: yyjson_mut_val *ToJSON(ClientContext &context, yyjson_mut_doc *doc) const; + static void ToVARIANT(vector &input, Vector &result); public: VariantValueType value_type; diff --git a/src/duckdb/extension/parquet/include/reader/variant_column_reader.hpp b/src/duckdb/extension/parquet/include/reader/variant_column_reader.hpp index 78670b14a..69b429626 100644 --- a/src/duckdb/extension/parquet/include/reader/variant_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/reader/variant_column_reader.hpp @@ -15,7 +15,7 @@ namespace duckdb { class VariantColumnReader : public ColumnReader { public: - static constexpr const PhysicalType TYPE = PhysicalType::VARCHAR; + static constexpr const PhysicalType TYPE = PhysicalType::STRUCT; public: VariantColumnReader(ClientContext &context, ParquetReader &reader, const ParquetColumnSchema &schema, diff --git a/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp index f1070b0f1..902d3001c 100644 --- a/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp @@ -28,13 +28,11 @@ class ListColumnWriter : public ColumnWriter { public: ListColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, unique_ptr child_writer_p, bool can_have_nulls) - : ColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls), - child_writer(std::move(child_writer_p)) { + : ColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls) { + child_writers.push_back(std::move(child_writer_p)); } ~ListColumnWriter() override = default; - unique_ptr child_writer; - public: unique_ptr InitializeWriteState(duckdb_parquet::RowGroup &row_group) override; bool HasAnalyze() override; @@ -46,6 +44,9 @@ class ListColumnWriter : public ColumnWriter { void BeginWrite(ColumnWriterState &state) override; void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; void FinalizeWrite(ColumnWriterState &state) override; + +protected: + ColumnWriter &GetChildWriter(); }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/writer/parquet_write_stats.hpp b/src/duckdb/extension/parquet/include/writer/parquet_write_stats.hpp index 1016c81fe..840830e3a 100644 --- a/src/duckdb/extension/parquet/include/writer/parquet_write_stats.hpp +++ b/src/duckdb/extension/parquet/include/writer/parquet_write_stats.hpp @@ -9,7 +9,7 @@ #pragma once #include "column_writer.hpp" -#include "geo_parquet.hpp" +#include "parquet_geometry.hpp" namespace duckdb { @@ -28,7 +28,7 @@ class 
ColumnWriterStatistics { virtual bool MaxIsExact(); virtual bool HasGeoStats(); - virtual optional_ptr GetGeoStats(); + virtual optional_ptr GetGeoStats(); virtual void WriteGeoStats(duckdb_parquet::GeospatialStatistics &stats); public: @@ -255,10 +255,11 @@ class UUIDStatisticsState : public ColumnWriterStatistics { class GeoStatisticsState final : public ColumnWriterStatistics { public: explicit GeoStatisticsState() : has_stats(false) { + geo_stats.SetEmpty(); } bool has_stats; - GeometryStats geo_stats; + GeometryStatsData geo_stats; public: void Update(const string_t &val) { @@ -268,37 +269,36 @@ class GeoStatisticsState final : public ColumnWriterStatistics { bool HasGeoStats() override { return has_stats; } - optional_ptr GetGeoStats() override { + optional_ptr GetGeoStats() override { return geo_stats; } void WriteGeoStats(duckdb_parquet::GeospatialStatistics &stats) override { const auto &types = geo_stats.types; - const auto &bbox = geo_stats.bbox; - - if (bbox.IsSet()) { + const auto &bbox = geo_stats.extent; + if (bbox.HasXY()) { stats.__isset.bbox = true; - stats.bbox.xmin = bbox.xmin; - stats.bbox.xmax = bbox.xmax; - stats.bbox.ymin = bbox.ymin; - stats.bbox.ymax = bbox.ymax; + stats.bbox.xmin = bbox.x_min; + stats.bbox.xmax = bbox.x_max; + stats.bbox.ymin = bbox.y_min; + stats.bbox.ymax = bbox.y_max; if (bbox.HasZ()) { stats.bbox.__isset.zmin = true; stats.bbox.__isset.zmax = true; - stats.bbox.zmin = bbox.zmin; - stats.bbox.zmax = bbox.zmax; + stats.bbox.zmin = bbox.z_min; + stats.bbox.zmax = bbox.z_max; } if (bbox.HasM()) { stats.bbox.__isset.mmin = true; stats.bbox.__isset.mmax = true; - stats.bbox.mmin = bbox.mmin; - stats.bbox.mmax = bbox.mmax; + stats.bbox.mmin = bbox.m_min; + stats.bbox.mmax = bbox.m_max; } } stats.__isset.geospatial_types = true; - stats.geospatial_types = types.ToList(); + stats.geospatial_types = types.ToWKBList(); } }; diff --git a/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp index 8927c391b..bbb6cd06b 100644 --- a/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp @@ -16,13 +16,11 @@ class StructColumnWriter : public ColumnWriter { public: StructColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, vector> child_writers_p, bool can_have_nulls) - : ColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls), - child_writers(std::move(child_writers_p)) { + : ColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls) { + child_writers = std::move(child_writers_p); } ~StructColumnWriter() override = default; - vector> child_writers; - public: unique_ptr InitializeWriteState(duckdb_parquet::RowGroup &row_group) override; bool HasAnalyze() override; diff --git a/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp index c035bba43..c0dfa12ae 100644 --- a/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp @@ -197,6 +197,7 @@ class StandardColumnWriter : public PrimitiveColumnWriter { const bool check_parent_empty = parent && !parent->is_empty.empty(); const idx_t parent_index = state.definition_levels.size(); + D_ASSERT(!check_parent_empty || parent_index < parent->is_empty.size()); const idx_t vcount 
= check_parent_empty ? parent->definition_levels.size() - state.definition_levels.size() : count; @@ -207,7 +208,7 @@ class StandardColumnWriter : public PrimitiveColumnWriter { // Fast path for (; vector_index < vcount; vector_index++) { const auto &src_value = data_ptr[vector_index]; - state.dictionary.Insert(src_value); + state.dictionary.template Insert(src_value); state.total_value_count++; state.total_string_size += DlbaEncoder::GetStringSize(src_value); } @@ -218,7 +219,7 @@ class StandardColumnWriter : public PrimitiveColumnWriter { } if (validity.RowIsValid(vector_index)) { const auto &src_value = data_ptr[vector_index]; - state.dictionary.Insert(src_value); + state.dictionary.template Insert(src_value); state.total_value_count++; state.total_string_size += DlbaEncoder::GetStringSize(src_value); } diff --git a/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp new file mode 100644 index 000000000..74fdda608 --- /dev/null +++ b/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp @@ -0,0 +1,30 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// writer/variant_column_writer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "struct_column_writer.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +class VariantColumnWriter : public StructColumnWriter { +public: + VariantColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, + vector> child_writers_p, bool can_have_nulls) + : StructColumnWriter(writer, column_schema, std::move(schema_path_p), std::move(child_writers_p), + can_have_nulls) { + } + ~VariantColumnWriter() override = default; + +public: + static ScalarFunction GetTransformFunction(); + static LogicalType TransformTypedValueRecursive(const LogicalType &type); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_crypto.cpp b/src/duckdb/extension/parquet/parquet_crypto.cpp index b60c01155..959dbee27 100644 --- a/src/duckdb/extension/parquet/parquet_crypto.cpp +++ b/src/duckdb/extension/parquet/parquet_crypto.cpp @@ -198,7 +198,6 @@ class DecryptionTransport : public TTransport { } uint32_t Finalize() { - if (read_buffer_offset != read_buffer_size) { throw InternalException("DecryptionTransport::Finalize was called with bytes remaining in read buffer: \n" "read buffer offset: %d, read buffer size: %d", diff --git a/src/duckdb/extension/parquet/parquet_extension.cpp b/src/duckdb/extension/parquet/parquet_extension.cpp index 37e6cd0b7..7171a8ece 100644 --- a/src/duckdb/extension/parquet/parquet_extension.cpp +++ b/src/duckdb/extension/parquet/parquet_extension.cpp @@ -7,14 +7,16 @@ #include "duckdb/parser/tableref/subqueryref.hpp" #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" -#include "geo_parquet.hpp" +#include "parquet_geometry.hpp" #include "parquet_crypto.hpp" #include "parquet_metadata.hpp" #include "parquet_reader.hpp" #include "parquet_writer.hpp" +#include "parquet_shredding.hpp" #include "reader/struct_column_reader.hpp" #include "zstd_file_system.hpp" #include "writer/primitive_column_writer.hpp" +#include "writer/variant_column_writer.hpp" #include #include @@ -43,6 +45,9 @@ #include "duckdb/parser/parsed_data/create_table_function_info.hpp" #include 
"duckdb/parser/tableref/table_function_ref.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/storage/statistics/base_statistics.hpp" #include "duckdb/storage/table/row_group.hpp" @@ -54,156 +59,6 @@ namespace duckdb { -static case_insensitive_map_t GetChildNameToTypeMap(const LogicalType &type) { - case_insensitive_map_t name_to_type_map; - switch (type.id()) { - case LogicalTypeId::LIST: - name_to_type_map.emplace("element", ListType::GetChildType(type)); - break; - case LogicalTypeId::MAP: - name_to_type_map.emplace("key", MapType::KeyType(type)); - name_to_type_map.emplace("value", MapType::ValueType(type)); - break; - case LogicalTypeId::STRUCT: - for (auto &child_type : StructType::GetChildTypes(type)) { - if (child_type.first == FieldID::DUCKDB_FIELD_ID) { - throw BinderException("Cannot have column named \"%s\" with FIELD_IDS", FieldID::DUCKDB_FIELD_ID); - } - name_to_type_map.emplace(child_type); - } - break; - default: // LCOV_EXCL_START - throw InternalException("Unexpected type in GetChildNameToTypeMap"); - } // LCOV_EXCL_STOP - return name_to_type_map; -} - -static void GetChildNamesAndTypes(const LogicalType &type, vector &child_names, - vector &child_types) { - switch (type.id()) { - case LogicalTypeId::LIST: - child_names.emplace_back("element"); - child_types.emplace_back(ListType::GetChildType(type)); - break; - case LogicalTypeId::MAP: - child_names.emplace_back("key"); - child_names.emplace_back("value"); - child_types.emplace_back(MapType::KeyType(type)); - child_types.emplace_back(MapType::ValueType(type)); - break; - case LogicalTypeId::STRUCT: - for (auto &child_type : StructType::GetChildTypes(type)) { - child_names.emplace_back(child_type.first); - child_types.emplace_back(child_type.second); - } - break; - default: // LCOV_EXCL_START - throw InternalException("Unexpected type in GetChildNamesAndTypes"); - } // LCOV_EXCL_STOP -} - -static void GenerateFieldIDs(ChildFieldIDs &field_ids, idx_t &field_id, const vector &names, - const vector &sql_types) { - D_ASSERT(names.size() == sql_types.size()); - for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { - const auto &col_name = names[col_idx]; - auto inserted = field_ids.ids->insert(make_pair(col_name, FieldID(UnsafeNumericCast(field_id++)))); - D_ASSERT(inserted.second); - - const auto &col_type = sql_types[col_idx]; - if (col_type.id() != LogicalTypeId::LIST && col_type.id() != LogicalTypeId::MAP && - col_type.id() != LogicalTypeId::STRUCT) { - continue; - } - - // Cannot use GetChildNameToTypeMap here because we lose order, and we want to generate depth-first - vector child_names; - vector child_types; - GetChildNamesAndTypes(col_type, child_names, child_types); - - GenerateFieldIDs(inserted.first->second.child_field_ids, field_id, child_names, child_types); - } -} - -static void GetFieldIDs(const Value &field_ids_value, ChildFieldIDs &field_ids, - unordered_set &unique_field_ids, - const case_insensitive_map_t &name_to_type_map) { - const auto &struct_type = field_ids_value.type(); - if (struct_type.id() != LogicalTypeId::STRUCT) { - throw BinderException( - "Expected FIELD_IDS to be a STRUCT, e.g., {col1: 42, col2: {%s: 43, nested_col: 44}, col3: 44}", - FieldID::DUCKDB_FIELD_ID); - } - const auto &struct_children = 
StructValue::GetChildren(field_ids_value); - D_ASSERT(StructType::GetChildTypes(struct_type).size() == struct_children.size()); - for (idx_t i = 0; i < struct_children.size(); i++) { - const auto &col_name = StringUtil::Lower(StructType::GetChildName(struct_type, i)); - if (col_name == FieldID::DUCKDB_FIELD_ID) { - continue; - } - - auto it = name_to_type_map.find(col_name); - if (it == name_to_type_map.end()) { - string names; - for (const auto &name : name_to_type_map) { - if (!names.empty()) { - names += ", "; - } - names += name.first; - } - throw BinderException( - "Column name \"%s\" specified in FIELD_IDS not found. Consider using WRITE_PARTITION_COLUMNS if this " - "column is a partition column. Available column names: [%s]", - col_name, names); - } - D_ASSERT(field_ids.ids->find(col_name) == field_ids.ids->end()); // Caught by STRUCT - deduplicates keys - - const auto &child_value = struct_children[i]; - const auto &child_type = child_value.type(); - optional_ptr field_id_value; - optional_ptr child_field_ids_value; - - if (child_type.id() == LogicalTypeId::STRUCT) { - const auto &nested_children = StructValue::GetChildren(child_value); - D_ASSERT(StructType::GetChildTypes(child_type).size() == nested_children.size()); - for (idx_t nested_i = 0; nested_i < nested_children.size(); nested_i++) { - const auto &field_id_or_nested_col = StructType::GetChildName(child_type, nested_i); - if (field_id_or_nested_col == FieldID::DUCKDB_FIELD_ID) { - field_id_value = &nested_children[nested_i]; - } else { - child_field_ids_value = &child_value; - } - } - } else { - field_id_value = &child_value; - } - - FieldID field_id; - if (field_id_value) { - Value field_id_integer_value = field_id_value->DefaultCastAs(LogicalType::INTEGER); - const uint32_t field_id_int = IntegerValue::Get(field_id_integer_value); - if (!unique_field_ids.insert(field_id_int).second) { - throw BinderException("Duplicate field_id %s found in FIELD_IDS", field_id_integer_value.ToString()); - } - field_id = FieldID(UnsafeNumericCast(field_id_int)); - } - auto inserted = field_ids.ids->insert(make_pair(col_name, std::move(field_id))); - D_ASSERT(inserted.second); - - if (child_field_ids_value) { - const auto &col_type = it->second; - if (col_type.id() != LogicalTypeId::LIST && col_type.id() != LogicalTypeId::MAP && - col_type.id() != LogicalTypeId::STRUCT) { - throw BinderException("Column \"%s\" with type \"%s\" cannot have a nested FIELD_IDS specification", - col_name, LogicalTypeIdToString(col_type.id())); - } - - GetFieldIDs(*child_field_ids_value, inserted.first->second.child_field_ids, unique_field_ids, - GetChildNameToTypeMap(col_type)); - } - } -} - struct ParquetWriteBindData : public TableFunctionData { vector sql_types; vector column_names; @@ -233,11 +88,15 @@ struct ParquetWriteBindData : public TableFunctionData { optional_idx row_groups_per_file; ChildFieldIDs field_ids; + ShreddingType shredding_types; //! The compression level, higher value is more int64_t compression_level = ZStdFileSystem::DefaultCompressionLevel(); //! Which encodings to include when writing ParquetVersion parquet_version = ParquetVersion::V1; + + //! 
Which geo-parquet version to use when writing + GeoParquetVersion geoparquet_version = GeoParquetVersion::V1; }; struct ParquetWriteGlobalState : public GlobalFunctionData { @@ -291,6 +150,8 @@ static void ParquetListCopyOptions(ClientContext &context, CopyOptionsInput &inp copy_options["binary_as_string"] = CopyOption(LogicalType::BOOLEAN, CopyOptionMode::READ_ONLY); copy_options["file_row_number"] = CopyOption(LogicalType::BOOLEAN, CopyOptionMode::READ_ONLY); copy_options["can_have_nan"] = CopyOption(LogicalType::BOOLEAN, CopyOptionMode::READ_ONLY); + copy_options["geoparquet_version"] = CopyOption(LogicalType::VARCHAR, CopyOptionMode::WRITE_ONLY); + copy_options["shredding"] = CopyOption(LogicalType::ANY, CopyOptionMode::WRITE_ONLY); } static unique_ptr ParquetWriteBind(ClientContext &context, CopyFunctionBindInput &input, @@ -342,7 +203,7 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun if (option.second[0].type().id() == LogicalTypeId::VARCHAR && StringUtil::Lower(StringValue::Get(option.second[0])) == "auto") { idx_t field_id = 0; - GenerateFieldIDs(bind_data->field_ids, field_id, names, sql_types); + FieldID::GenerateFieldIDs(bind_data->field_ids, field_id, names, sql_types); } else { unordered_set unique_field_ids; case_insensitive_map_t name_to_type_map; @@ -353,7 +214,57 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun } name_to_type_map.emplace(names[col_idx], sql_types[col_idx]); } - GetFieldIDs(option.second[0], bind_data->field_ids, unique_field_ids, name_to_type_map); + FieldID::GetFieldIDs(option.second[0], bind_data->field_ids, unique_field_ids, name_to_type_map); + } + } else if (loption == "shredding") { + if (option.second[0].type().id() == LogicalTypeId::VARCHAR && + StringUtil::Lower(StringValue::Get(option.second[0])) == "auto") { + throw NotImplementedException("The 'auto' option is not yet implemented for 'shredding'"); + } else { + case_insensitive_set_t variant_names; + for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { + if (sql_types[col_idx].id() != LogicalTypeId::STRUCT) { + continue; + } + if (sql_types[col_idx].GetAlias() != "PARQUET_VARIANT") { + continue; + } + variant_names.emplace(names[col_idx]); + } + auto &shredding_types_value = option.second[0]; + if (shredding_types_value.type().id() != LogicalTypeId::STRUCT) { + throw BinderException("SHREDDING value should be a STRUCT of column names to types, e.g., {col1: " + "'INTEGER[]', col2: 'BOOLEAN'}"); + } + const auto &struct_type = shredding_types_value.type(); + const auto &struct_children = StructValue::GetChildren(shredding_types_value); + D_ASSERT(StructType::GetChildTypes(struct_type).size() == struct_children.size()); + for (idx_t i = 0; i < struct_children.size(); i++) { + const auto &col_name = StringUtil::Lower(StructType::GetChildName(struct_type, i)); + auto it = variant_names.find(col_name); + if (it == variant_names.end()) { + string names; + for (const auto &entry : variant_names) { + if (!names.empty()) { + names += ", "; + } + names += entry; + } + if (names.empty()) { + throw BinderException("VARIANT column \"%s\" specified in SHREDDING not found. There are " + "no VARIANT columns present.", + col_name); + } else { + throw BinderException( + "VARIANT column \"%s\" specified in SHREDDING not found. Consider using " + "WRITE_PARTITION_COLUMNS if this " + "column is a partition column. 
Available names of VARIANT columns: [%s]", + col_name, names); + } + } + const auto &child_value = struct_children[i]; + bind_data->shredding_types.AddChild(col_name, ShreddingType::GetShreddingTypes(child_value)); + } } } else if (loption == "kv_metadata") { auto &kv_struct = option.second[0]; @@ -426,6 +337,19 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun } else { throw BinderException("Expected parquet_version 'V1' or 'V2'"); } + } else if (loption == "geoparquet_version") { + const auto roption = StringUtil::Upper(option.second[0].ToString()); + if (roption == "NONE") { + bind_data->geoparquet_version = GeoParquetVersion::NONE; + } else if (roption == "V1") { + bind_data->geoparquet_version = GeoParquetVersion::V1; + } else if (roption == "V2") { + bind_data->geoparquet_version = GeoParquetVersion::V2; + } else if (roption == "BOTH") { + bind_data->geoparquet_version = GeoParquetVersion::BOTH; + } else { + throw BinderException("Expected geoparquet_version 'NONE', 'V1' or 'BOTH'"); + } } else { throw InternalException("Unrecognized option for PARQUET: %s", option.first.c_str()); } @@ -454,10 +378,11 @@ static unique_ptr ParquetWriteInitializeGlobal(ClientContext auto &fs = FileSystem::GetFileSystem(context); global_state->writer = make_uniq( context, fs, file_path, parquet_bind.sql_types, parquet_bind.column_names, parquet_bind.codec, - parquet_bind.field_ids.Copy(), parquet_bind.kv_metadata, parquet_bind.encryption_config, - parquet_bind.dictionary_size_limit, parquet_bind.string_dictionary_page_size_limit, - parquet_bind.enable_bloom_filters, parquet_bind.bloom_filter_false_positive_ratio, - parquet_bind.compression_level, parquet_bind.debug_use_openssl, parquet_bind.parquet_version); + parquet_bind.field_ids.Copy(), parquet_bind.shredding_types.Copy(), parquet_bind.kv_metadata, + parquet_bind.encryption_config, parquet_bind.dictionary_size_limit, + parquet_bind.string_dictionary_page_size_limit, parquet_bind.enable_bloom_filters, + parquet_bind.bloom_filter_false_positive_ratio, parquet_bind.compression_level, parquet_bind.debug_use_openssl, + parquet_bind.parquet_version, parquet_bind.geoparquet_version); return std::move(global_state); } @@ -626,6 +551,39 @@ ParquetVersion EnumUtil::FromString(const char *value) { throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } +template <> +const char *EnumUtil::ToChars(GeoParquetVersion value) { + switch (value) { + case GeoParquetVersion::NONE: + return "NONE"; + case GeoParquetVersion::V1: + return "V1"; + case GeoParquetVersion::V2: + return "V2"; + case GeoParquetVersion::BOTH: + return "BOTH"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); + } +} + +template <> +GeoParquetVersion EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NONE")) { + return GeoParquetVersion::NONE; + } + if (StringUtil::Equals(value, "V1")) { + return GeoParquetVersion::V1; + } + if (StringUtil::Equals(value, "V2")) { + return GeoParquetVersion::V2; + } + if (StringUtil::Equals(value, "BOTH")) { + return GeoParquetVersion::BOTH; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + static optional_idx SerializeCompressionLevel(const int64_t compression_level) { return compression_level < 0 ? 
NumericLimits::Maximum() - NumericCast(AbsValue(compression_level)) : NumericCast(compression_level); @@ -679,6 +637,9 @@ static void ParquetCopySerialize(Serializer &serializer, const FunctionData &bin serializer.WritePropertyWithDefault(115, "string_dictionary_page_size_limit", bind_data.string_dictionary_page_size_limit, default_value.string_dictionary_page_size_limit); + serializer.WritePropertyWithDefault(116, "geoparquet_version", bind_data.geoparquet_version, + default_value.geoparquet_version); + serializer.WriteProperty(117, "shredding_types", bind_data.shredding_types); } static unique_ptr ParquetCopyDeserialize(Deserializer &deserializer, CopyFunction &function) { @@ -711,6 +672,9 @@ static unique_ptr ParquetCopyDeserialize(Deserializer &deserialize deserializer.ReadPropertyWithExplicitDefault(114, "parquet_version", default_value.parquet_version); data->string_dictionary_page_size_limit = deserializer.ReadPropertyWithExplicitDefault( 115, "string_dictionary_page_size_limit", default_value.string_dictionary_page_size_limit); + data->geoparquet_version = + deserializer.ReadPropertyWithExplicitDefault(116, "geoparquet_version", default_value.geoparquet_version); + data->shredding_types = deserializer.ReadProperty(117, "shredding_types"); return std::move(data); } @@ -828,8 +792,52 @@ static bool IsTypeLossy(const LogicalType &type) { return type.id() == LogicalTypeId::HUGEINT || type.id() == LogicalTypeId::UHUGEINT; } -static vector> ParquetWriteSelect(CopyToSelectInput &input) { +static bool IsExtensionGeometryType(const LogicalType &type, ClientContext &context) { + if (type.id() != LogicalTypeId::BLOB) { + return false; + } + if (!type.HasAlias()) { + return false; + } + if (type.GetAlias() != "GEOMETRY") { + return false; + } + return GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context); +} + +static string GetShredding(case_insensitive_map_t> &options, const string &col_name) { + //! At this point, the options haven't been parsed yet, so we have to parse them ourselves. + auto it = options.find("shredding"); + if (it == options.end()) { + return string(); + } + auto &shredding = it->second; + if (shredding.empty()) { + return string(); + } + + auto &shredding_val = shredding[0]; + if (shredding_val.type().id() != LogicalTypeId::STRUCT) { + return string(); + } + + auto &shredded_variants = StructType::GetChildTypes(shredding_val.type()); + auto &values = StructValue::GetChildren(shredding_val); + for (idx_t i = 0; i < shredded_variants.size(); i++) { + auto &shredded_variant = shredded_variants[i]; + if (shredded_variant.first != col_name) { + continue; + } + auto &shredded_val = values[i]; + if (shredded_val.type().id() != LogicalTypeId::VARCHAR) { + return string(); + } + return shredded_val.GetValue(); + } + return string(); +} +static vector> ParquetWriteSelect(CopyToSelectInput &input) { auto &context = input.context; vector> result; @@ -837,22 +845,35 @@ static vector> ParquetWriteSelect(CopyToSelectInput &inpu bool any_change = false; for (auto &expr : input.select_list) { - const auto &type = expr->return_type; const auto &name = expr->GetAlias(); // Spatial types need to be encoded into WKB when writing GeoParquet. 
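For orientation, here is a sketch of how the two new writer options registered in ParquetListCopyOptions above might be exercised through the C++ API. The option names and the SHREDDING value shape come from this diff; the table, column names and shredded type are made up, and VARIANT column support depends on the DuckDB version in use:

#include "duckdb.hpp"
#include <cstdio>

int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	con.Query("CREATE TABLE events (id INTEGER, payload VARIANT)");
	// GEOPARQUET_VERSION selects the GeoParquetVersion enum value; SHREDDING maps
	// VARIANT column names to the type their typed_value should be shredded as.
	auto res = con.Query("COPY events TO 'events.parquet' (FORMAT parquet, "
	                     "GEOPARQUET_VERSION 'V1', SHREDDING {payload: 'STRUCT(a INTEGER)'})");
	if (res->HasError()) {
		std::printf("%s\n", res->GetError().c_str());
	}
	return 0;
}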
// But dont perform this conversion if this is a EXPORT DATABASE statement - if (input.copy_to_type == CopyToType::COPY_TO_FILE && type.id() == LogicalTypeId::BLOB && type.HasAlias() && - type.GetAlias() == "GEOMETRY" && GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context)) { - - LogicalType wkb_blob_type(LogicalTypeId::BLOB); - wkb_blob_type.SetAlias("WKB_BLOB"); - - auto cast_expr = BoundCastExpression::AddCastToType(context, std::move(expr), wkb_blob_type, false); + if (input.copy_to_type == CopyToType::COPY_TO_FILE && IsExtensionGeometryType(type, context)) { + // Cast the column to GEOMETRY + auto cast_expr = + BoundCastExpression::AddCastToType(context, std::move(expr), LogicalType::GEOMETRY(), false); cast_expr->SetAlias(name); result.push_back(std::move(cast_expr)); any_change = true; + } else if (input.copy_to_type == CopyToType::COPY_TO_FILE && type.id() == LogicalTypeId::VARIANT) { + vector> arguments; + arguments.push_back(std::move(expr)); + + auto shredded_type_str = GetShredding(input.options, name); + if (!shredded_type_str.empty()) { + arguments.push_back(make_uniq(Value(shredded_type_str))); + } + + auto transform_func = VariantColumnWriter::GetTransformFunction(); + transform_func.bind(context, transform_func, arguments); + + auto func_expr = make_uniq(transform_func.GetReturnType(), transform_func, + std::move(arguments), nullptr, false); + func_expr->SetAlias(name); + result.push_back(std::move(func_expr)); + any_change = true; } // If this is an EXPORT DATABASE statement, we dont want to write "lossy" types, instead cast them to VARCHAR else if (input.copy_to_type == CopyToType::EXPORT_DATABASE && TypeVisitor::Contains(type, IsTypeLossy)) { @@ -924,6 +945,13 @@ static void LoadInternal(ExtensionLoader &loader) { ParquetBloomProbeFunction bloom_probe_fun; loader.RegisterFunction(MultiFileReader::CreateFunctionSet(bloom_probe_fun)); + // parquet_full_metadata + ParquetFullMetadataFunction full_meta_fun; + loader.RegisterFunction(MultiFileReader::CreateFunctionSet(full_meta_fun)); + + // variant_to_parquet_variant + loader.RegisterFunction(VariantColumnWriter::GetTransformFunction()); + CopyFunction function("parquet"); function.copy_to_select = ParquetWriteSelect; function.copy_to_bind = ParquetWriteBind; @@ -970,9 +998,6 @@ static void LoadInternal(ExtensionLoader &loader) { "enable_geoparquet_conversion", "Attempt to decode/encode geometry data in/as GeoParquet files if the spatial extension is present.", LogicalType::BOOLEAN, Value::BOOLEAN(true)); - config.AddExtensionOption("variant_legacy_encoding", - "Enables the Parquet reader to identify a Variant structurally.", LogicalType::BOOLEAN, - Value::BOOLEAN(false)); } void ParquetExtension::Load(ExtensionLoader &loader) { diff --git a/src/duckdb/extension/parquet/parquet_field_id.cpp b/src/duckdb/extension/parquet/parquet_field_id.cpp new file mode 100644 index 000000000..642fc26c7 --- /dev/null +++ b/src/duckdb/extension/parquet/parquet_field_id.cpp @@ -0,0 +1,180 @@ +#include "parquet_field_id.hpp" +#include "duckdb/common/exception/binder_exception.hpp" + +namespace duckdb { + +constexpr const char *FieldID::DUCKDB_FIELD_ID; + +ChildFieldIDs::ChildFieldIDs() : ids(make_uniq>()) { +} + +ChildFieldIDs ChildFieldIDs::Copy() const { + ChildFieldIDs result; + for (const auto &id : *ids) { + result.ids->emplace(id.first, id.second.Copy()); + } + return result; +} + +FieldID::FieldID() : set(false) { +} + +FieldID::FieldID(int32_t field_id_p) : set(true), field_id(field_id_p) { +} + +FieldID FieldID::Copy() 
const { + auto result = set ? FieldID(field_id) : FieldID(); + result.child_field_ids = child_field_ids.Copy(); + return result; +} + +static case_insensitive_map_t GetChildNameToTypeMap(const LogicalType &type) { + case_insensitive_map_t name_to_type_map; + switch (type.id()) { + case LogicalTypeId::LIST: + name_to_type_map.emplace("element", ListType::GetChildType(type)); + break; + case LogicalTypeId::MAP: + name_to_type_map.emplace("key", MapType::KeyType(type)); + name_to_type_map.emplace("value", MapType::ValueType(type)); + break; + case LogicalTypeId::STRUCT: + for (auto &child_type : StructType::GetChildTypes(type)) { + if (child_type.first == FieldID::DUCKDB_FIELD_ID) { + throw BinderException("Cannot have column named \"%s\" with FIELD_IDS", FieldID::DUCKDB_FIELD_ID); + } + name_to_type_map.emplace(child_type); + } + break; + default: // LCOV_EXCL_START + throw InternalException("Unexpected type in GetChildNameToTypeMap"); + } // LCOV_EXCL_STOP + return name_to_type_map; +} + +static void GetChildNamesAndTypes(const LogicalType &type, vector &child_names, + vector &child_types) { + switch (type.id()) { + case LogicalTypeId::LIST: + child_names.emplace_back("element"); + child_types.emplace_back(ListType::GetChildType(type)); + break; + case LogicalTypeId::MAP: + child_names.emplace_back("key"); + child_names.emplace_back("value"); + child_types.emplace_back(MapType::KeyType(type)); + child_types.emplace_back(MapType::ValueType(type)); + break; + case LogicalTypeId::STRUCT: + for (auto &child_type : StructType::GetChildTypes(type)) { + child_names.emplace_back(child_type.first); + child_types.emplace_back(child_type.second); + } + break; + default: // LCOV_EXCL_START + throw InternalException("Unexpected type in GetChildNamesAndTypes"); + } // LCOV_EXCL_STOP +} + +void FieldID::GenerateFieldIDs(ChildFieldIDs &field_ids, idx_t &field_id, const vector &names, + const vector &sql_types) { + D_ASSERT(names.size() == sql_types.size()); + for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { + const auto &col_name = names[col_idx]; + auto inserted = field_ids.ids->insert(make_pair(col_name, FieldID(UnsafeNumericCast(field_id++)))); + D_ASSERT(inserted.second); + + const auto &col_type = sql_types[col_idx]; + if (col_type.id() != LogicalTypeId::LIST && col_type.id() != LogicalTypeId::MAP && + col_type.id() != LogicalTypeId::STRUCT) { + continue; + } + + // Cannot use GetChildNameToTypeMap here because we lose order, and we want to generate depth-first + vector child_names; + vector child_types; + GetChildNamesAndTypes(col_type, child_names, child_types); + GenerateFieldIDs(inserted.first->second.child_field_ids, field_id, child_names, child_types); + } +} + +void FieldID::GetFieldIDs(const Value &field_ids_value, ChildFieldIDs &field_ids, + unordered_set &unique_field_ids, + const case_insensitive_map_t &name_to_type_map) { + const auto &struct_type = field_ids_value.type(); + if (struct_type.id() != LogicalTypeId::STRUCT) { + throw BinderException( + "Expected FIELD_IDS to be a STRUCT, e.g., {col1: 42, col2: {%s: 43, nested_col: 44}, col3: 44}", + FieldID::DUCKDB_FIELD_ID); + } + const auto &struct_children = StructValue::GetChildren(field_ids_value); + D_ASSERT(StructType::GetChildTypes(struct_type).size() == struct_children.size()); + for (idx_t i = 0; i < struct_children.size(); i++) { + const auto &col_name = StringUtil::Lower(StructType::GetChildName(struct_type, i)); + if (col_name == FieldID::DUCKDB_FIELD_ID) { + continue; + } + + auto it = 
name_to_type_map.find(col_name); + if (it == name_to_type_map.end()) { + string names; + for (const auto &name : name_to_type_map) { + if (!names.empty()) { + names += ", "; + } + names += name.first; + } + throw BinderException( + "Column name \"%s\" specified in FIELD_IDS not found. Consider using WRITE_PARTITION_COLUMNS if this " + "column is a partition column. Available column names: [%s]", + col_name, names); + } + D_ASSERT(field_ids.ids->find(col_name) == field_ids.ids->end()); // Caught by STRUCT - deduplicates keys + + const auto &child_value = struct_children[i]; + const auto &child_type = child_value.type(); + optional_ptr field_id_value; + optional_ptr child_field_ids_value; + + if (child_type.id() == LogicalTypeId::STRUCT) { + const auto &nested_children = StructValue::GetChildren(child_value); + D_ASSERT(StructType::GetChildTypes(child_type).size() == nested_children.size()); + for (idx_t nested_i = 0; nested_i < nested_children.size(); nested_i++) { + const auto &field_id_or_nested_col = StructType::GetChildName(child_type, nested_i); + if (field_id_or_nested_col == FieldID::DUCKDB_FIELD_ID) { + field_id_value = &nested_children[nested_i]; + } else { + child_field_ids_value = &child_value; + } + } + } else { + field_id_value = &child_value; + } + + FieldID field_id; + if (field_id_value) { + Value field_id_integer_value = field_id_value->DefaultCastAs(LogicalType::INTEGER); + const uint32_t field_id_int = IntegerValue::Get(field_id_integer_value); + if (!unique_field_ids.insert(field_id_int).second) { + throw BinderException("Duplicate field_id %s found in FIELD_IDS", field_id_integer_value.ToString()); + } + field_id = FieldID(UnsafeNumericCast(field_id_int)); + } + auto inserted = field_ids.ids->insert(make_pair(col_name, std::move(field_id))); + D_ASSERT(inserted.second); + + if (child_field_ids_value) { + const auto &col_type = it->second; + if (col_type.id() != LogicalTypeId::LIST && col_type.id() != LogicalTypeId::MAP && + col_type.id() != LogicalTypeId::STRUCT) { + throw BinderException("Column \"%s\" with type \"%s\" cannot have a nested FIELD_IDS specification", + col_name, LogicalTypeIdToString(col_type.id())); + } + + GetFieldIDs(*child_field_ids_value, inserted.first->second.child_field_ids, unique_field_ids, + GetChildNameToTypeMap(col_type)); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/geo_parquet.cpp b/src/duckdb/extension/parquet/parquet_geometry.cpp similarity index 54% rename from src/duckdb/extension/parquet/geo_parquet.cpp rename to src/duckdb/extension/parquet/parquet_geometry.cpp index bddc36b43..7ab81cc2a 100644 --- a/src/duckdb/extension/parquet/geo_parquet.cpp +++ b/src/duckdb/extension/parquet/parquet_geometry.cpp @@ -1,193 +1,29 @@ -#include "geo_parquet.hpp" +#include "parquet_geometry.hpp" #include "column_reader.hpp" #include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/function/scalar_function.hpp" +#include "duckdb/function/scalar/geometry_functions.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/main/extension_helper.hpp" #include "reader/expression_column_reader.hpp" #include "parquet_reader.hpp" #include "yyjson.hpp" +#include "reader/string_column_reader.hpp" namespace duckdb { using namespace duckdb_yyjson; // NOLINT -//------------------------------------------------------------------------------ -// WKB stats 
-//------------------------------------------------------------------------------ -namespace { - -class BinaryReader { -public: - const char *beg; - const char *end; - const char *ptr; - - BinaryReader(const char *beg, uint32_t len) : beg(beg), end(beg + len), ptr(beg) { - } - - template - T Read() { - if (ptr + sizeof(T) > end) { - throw InvalidInputException("Unexpected end of WKB data"); - } - T val; - memcpy(&val, ptr, sizeof(T)); - ptr += sizeof(T); - return val; - } - - void Skip(idx_t len) { - if (ptr + len > end) { - throw InvalidInputException("Unexpected end of WKB data"); - } - ptr += len; - } - - const char *Reserve(idx_t len) { - if (ptr + len > end) { - throw InvalidInputException("Unexpected end of WKB data"); - } - auto ret = ptr; - ptr += len; - return ret; - } - - bool IsAtEnd() const { - return ptr >= end; - } -}; - -} // namespace - -static void UpdateBoundsFromVertexArray(GeometryExtent &bbox, uint32_t flag, const char *vert_array, - uint32_t vert_count) { - switch (flag) { - case 0: { // XY - constexpr auto vert_width = sizeof(double) * 2; - for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { - double vert[2]; - memcpy(vert, vert_array + vert_idx * vert_width, vert_width); - bbox.ExtendX(vert[0]); - bbox.ExtendY(vert[1]); - } - } break; - case 1: { // XYZ - constexpr auto vert_width = sizeof(double) * 3; - for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { - double vert[3]; - memcpy(vert, vert_array + vert_idx * vert_width, vert_width); - bbox.ExtendX(vert[0]); - bbox.ExtendY(vert[1]); - bbox.ExtendZ(vert[2]); - } - } break; - case 2: { // XYM - constexpr auto vert_width = sizeof(double) * 3; - for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { - double vert[3]; - memcpy(vert, vert_array + vert_idx * vert_width, vert_width); - bbox.ExtendX(vert[0]); - bbox.ExtendY(vert[1]); - bbox.ExtendM(vert[2]); - } - } break; - case 3: { // XYZM - constexpr auto vert_width = sizeof(double) * 4; - for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { - double vert[4]; - memcpy(vert, vert_array + vert_idx * vert_width, vert_width); - bbox.ExtendX(vert[0]); - bbox.ExtendY(vert[1]); - bbox.ExtendZ(vert[2]); - bbox.ExtendM(vert[3]); - } - } break; - default: - break; - } -} - -void GeometryStats::Update(const string_t &wkb) { - BinaryReader reader(wkb.GetData(), wkb.GetSize()); - - bool first_geom = true; - while (!reader.IsAtEnd()) { - reader.Read(); // byte order - auto type = reader.Read(); - auto kind = type % 1000; - auto flag = type / 1000; - const auto hasz = (flag & 0x01) != 0; - const auto hasm = (flag & 0x02) != 0; - - if (first_geom) { - // Only add the top-level geometry type - types.Add(type); - first_geom = false; - } - - const auto vert_width = sizeof(double) * (2 + (hasz ? 1 : 0) + (hasm ? 1 : 0)); - - switch (kind) { - case 1: { // POINT - - // Point are special in that they are considered "empty" if they are all-nan - const auto vert_array = reader.Reserve(vert_width); - const auto dims_count = 2 + (hasz ? 1 : 0) + (hasm ? 
1 : 0); - double vert_point[4] = {0, 0, 0, 0}; - - memcpy(vert_point, vert_array, vert_width); - - for (auto dim_idx = 0; dim_idx < dims_count; dim_idx++) { - if (!std::isnan(vert_point[dim_idx])) { - bbox.ExtendX(vert_point[0]); - bbox.ExtendY(vert_point[1]); - if (hasz && hasm) { - bbox.ExtendZ(vert_point[2]); - bbox.ExtendM(vert_point[3]); - } else if (hasz) { - bbox.ExtendZ(vert_point[2]); - } else if (hasm) { - bbox.ExtendM(vert_point[2]); - } - break; - } - } - } break; - case 2: { // LINESTRING - const auto vert_count = reader.Read(); - const auto vert_array = reader.Reserve(vert_count * vert_width); - UpdateBoundsFromVertexArray(bbox, flag, vert_array, vert_count); - } break; - case 3: { // POLYGON - const auto ring_count = reader.Read(); - for (uint32_t ring_idx = 0; ring_idx < ring_count; ring_idx++) { - const auto vert_count = reader.Read(); - const auto vert_array = reader.Reserve(vert_count * vert_width); - UpdateBoundsFromVertexArray(bbox, flag, vert_array, vert_count); - } - } break; - case 4: // MULTIPOINT - case 5: // MULTILINESTRING - case 6: // MULTIPOLYGON - case 7: { // GEOMETRYCOLLECTION - reader.Skip(sizeof(uint32_t)); - } break; - } - } -} - //------------------------------------------------------------------------------ // GeoParquetFileMetadata //------------------------------------------------------------------------------ unique_ptr GeoParquetFileMetadata::TryRead(const duckdb_parquet::FileMetaData &file_meta_data, const ClientContext &context) { - // Conversion not enabled, or spatial is not loaded! if (!IsGeoParquetConversionEnabled(context)) { return nullptr; @@ -208,17 +44,19 @@ unique_ptr GeoParquetFileMetadata::TryRead(const duckdb_ throw InvalidInputException("Geoparquet metadata is not an object"); } - auto result = make_uniq(); + // We dont actually care about the version for now, as we only support V1+native + auto result = make_uniq(GeoParquetVersion::BOTH); // Check and parse the version const auto version_val = yyjson_obj_get(root, "version"); if (!yyjson_is_str(version_val)) { throw InvalidInputException("Geoparquet metadata does not have a version"); } - result->version = yyjson_get_str(version_val); - if (StringUtil::StartsWith(result->version, "2")) { - // Guard against a breaking future 2.0 version - throw InvalidInputException("Geoparquet version %s is not supported", result->version); + + auto version = yyjson_get_str(version_val); + if (StringUtil::StartsWith(version, "3")) { + // Guard against a breaking future 3.0 version + throw InvalidInputException("Geoparquet version %s is not supported", version); } // Check and parse the geometry columns @@ -292,8 +130,7 @@ unique_ptr GeoParquetFileMetadata::TryRead(const duckdb_ } void GeoParquetFileMetadata::AddGeoParquetStats(const string &column_name, const LogicalType &type, - const GeometryStats &stats) { - + const GeometryStatsData &stats) { // Lock the metadata lock_guard glock(write_lock); @@ -301,21 +138,18 @@ void GeoParquetFileMetadata::AddGeoParquetStats(const string &column_name, const if (it == geometry_columns.end()) { auto &column = geometry_columns[column_name]; - column.stats.types.Combine(stats.types); - column.stats.bbox.Combine(stats.bbox); + column.stats.Merge(stats); column.insertion_index = geometry_columns.size() - 1; } else { - it->second.stats.types.Combine(stats.types); - it->second.stats.bbox.Combine(stats.bbox); + it->second.stats.Merge(stats); } } void GeoParquetFileMetadata::Write(duckdb_parquet::FileMetaData &file_meta_data) { - // GeoParquet does not support M or ZM 
coordinates. So remove any columns that have them. unordered_set invalid_columns; for (auto &column : geometry_columns) { - if (column.second.stats.bbox.HasM()) { + if (column.second.stats.extent.HasM()) { invalid_columns.insert(column.first); } } @@ -344,7 +178,20 @@ void GeoParquetFileMetadata::Write(duckdb_parquet::FileMetaData &file_meta_data) yyjson_mut_doc_set_root(doc, root); // Add the version - yyjson_mut_obj_add_strncpy(doc, root, "version", version.c_str(), version.size()); + switch (version) { + case GeoParquetVersion::V1: + case GeoParquetVersion::BOTH: + yyjson_mut_obj_add_strcpy(doc, root, "version", "1.0.0"); + break; + case GeoParquetVersion::V2: + yyjson_mut_obj_add_strcpy(doc, root, "version", "2.0.0"); + break; + case GeoParquetVersion::NONE: + default: + // Should never happen, we should not be writing anything + yyjson_mut_doc_free(doc); + throw InternalException("GeoParquetVersion::NONE should not write metadata"); + } // Add the primary column yyjson_mut_obj_add_strncpy(doc, root, "primary_column", primary_geometry_column.c_str(), @@ -354,32 +201,31 @@ void GeoParquetFileMetadata::Write(duckdb_parquet::FileMetaData &file_meta_data) const auto json_columns = yyjson_mut_obj_add_obj(doc, root, "columns"); for (auto &column : geometry_columns) { - const auto column_json = yyjson_mut_obj_add_obj(doc, json_columns, column.first.c_str()); yyjson_mut_obj_add_str(doc, column_json, "encoding", "WKB"); const auto geometry_types = yyjson_mut_obj_add_arr(doc, column_json, "geometry_types"); + for (auto &type_name : column.second.stats.types.ToString(false)) { yyjson_mut_arr_add_strcpy(doc, geometry_types, type_name.c_str()); } - const auto &bbox = column.second.stats.bbox; - - if (bbox.IsSet()) { + const auto &bbox = column.second.stats.extent; + if (bbox.HasXY()) { const auto bbox_arr = yyjson_mut_obj_add_arr(doc, column_json, "bbox"); - if (!column.second.stats.bbox.HasZ()) { - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.xmin); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.ymin); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.xmax); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.ymax); + if (!column.second.stats.extent.HasZ()) { + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.x_min); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.y_min); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.x_max); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.y_max); } else { - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.xmin); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.ymin); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.zmin); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.xmax); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.ymax); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.zmax); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.x_min); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.y_min); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.z_min); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.x_max); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.y_max); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.z_max); } } @@ -432,52 +278,31 @@ bool GeoParquetFileMetadata::IsGeoParquetConversionEnabled(const ClientContext & // Disabled by setting return false; } - if (!context.db->ExtensionIsLoaded("spatial")) { - // Spatial extension is not loaded, we cant convert anyway - return false; - } return true; } -LogicalType GeoParquetFileMetadata::GeometryType() { - auto blob_type = LogicalType(LogicalTypeId::BLOB); - blob_type.SetAlias("GEOMETRY"); - return blob_type; -} - const unordered_map 
&GeoParquetFileMetadata::GetColumnMeta() const { return geometry_columns; } -unique_ptr GeoParquetFileMetadata::CreateColumnReader(ParquetReader &reader, - const ParquetColumnSchema &schema, - ClientContext &context) { - - // Get the catalog - auto &catalog = Catalog::GetSystemCatalog(context); +unique_ptr GeometryColumnReader::Create(ParquetReader &reader, const ParquetColumnSchema &schema, + ClientContext &context) { + D_ASSERT(schema.type.id() == LogicalTypeId::GEOMETRY); + D_ASSERT(schema.children.size() == 1 && schema.children[0].type.id() == LogicalTypeId::BLOB); - // WKB encoding - if (schema.children[0].type.id() == LogicalTypeId::BLOB) { - // Look for a conversion function in the catalog - auto &conversion_func_set = - catalog.GetEntry(context, DEFAULT_SCHEMA, "st_geomfromwkb"); - auto conversion_func = conversion_func_set.functions.GetFunctionByArguments(context, {LogicalType::BLOB}); + // Make a string reader for the underlying WKB data + auto string_reader = make_uniq(reader, schema.children[0]); - // Create a bound function call expression - auto args = vector>(); - args.push_back(std::move(make_uniq(LogicalType::BLOB, 0))); - auto expr = - make_uniq(conversion_func.return_type, conversion_func, std::move(args), nullptr); - - // Create a child reader - auto child_reader = ColumnReader::CreateReader(reader, schema.children[0]); - - // Create an expression reader that applies the conversion function to the child reader - return make_uniq(context, std::move(child_reader), std::move(expr), schema); - } + // Wrap the string reader in a geometry reader + auto args = vector>(); + auto ref = make_uniq_base(LogicalTypeId::BLOB, 0); + args.push_back(std::move(ref)); - // Otherwise, unrecognized encoding - throw NotImplementedException("Unsupported geometry encoding"); + // TODO: Pass the actual target type here so we get the CRS information too + auto func = StGeomfromwkbFun::GetFunction(); + func.name = "ST_GeomFromWKB"; + auto expr = make_uniq_base(schema.type, func, std::move(args), nullptr); + return make_uniq(context, std::move(string_reader), std::move(expr), schema); } } // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_metadata.cpp b/src/duckdb/extension/parquet/parquet_metadata.cpp index 2f34efae2..9fe14688f 100644 --- a/src/duckdb/extension/parquet/parquet_metadata.cpp +++ b/src/duckdb/extension/parquet/parquet_metadata.cpp @@ -46,23 +46,23 @@ enum class ParquetMetadataOperatorType : uint8_t { SCHEMA, KEY_VALUE_META_DATA, FILE_META_DATA, - BLOOM_PROBE + BLOOM_PROBE, + FULL_METADATA }; class ParquetMetadataFileProcessor { public: ParquetMetadataFileProcessor() = default; virtual ~ParquetMetadataFileProcessor() = default; - void Initialize(ClientContext &context, OpenFileInfo &file_info) { - ParquetOptions parquet_options(context); - reader = make_uniq(context, file_info, parquet_options); + void Initialize(ClientContext &context, ParquetReader &reader) { + InitializeInternal(context, reader); + } + virtual void InitializeInternal(ClientContext &context, ParquetReader &reader) {}; + virtual idx_t TotalRowCount(ParquetReader &reader) = 0; + virtual void ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, ParquetReader &reader) = 0; + virtual bool ForceFlush() { + return false; } - virtual void InitializeInternal(ClientContext &context) {}; - virtual idx_t TotalRowCount() = 0; - virtual void ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) = 0; - -protected: - unique_ptr reader; }; struct ParquetMetaDataBindData; @@ -115,10 +115,20 @@ struct 
ParquetMetadataGlobalState : public GlobalTableFunctionState { }; struct ParquetMetadataLocalState : public LocalTableFunctionState { + unique_ptr reader; unique_ptr processor; bool file_exhausted = true; idx_t row_idx = 0; idx_t total_rows = 0; + + void Initialize(ClientContext &context, OpenFileInfo &file_info) { + ParquetOptions parquet_options(context); + reader = make_uniq(context, file_info, parquet_options); + processor->Initialize(context, *reader); + total_rows = processor->TotalRowCount(*reader); + row_idx = 0; + file_exhausted = false; + } }; template @@ -179,9 +189,9 @@ static Value ParquetElementBoolean(bool value, bool is_iset) { class ParquetRowGroupMetadataProcessor : public ParquetMetadataFileProcessor { public: - void InitializeInternal(ClientContext &context) override; - idx_t TotalRowCount() override; - void ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) override; + void InitializeInternal(ClientContext &context, ParquetReader &reader) override; + idx_t TotalRowCount(ParquetReader &reader) override; + void ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, ParquetReader &reader) override; private: vector column_schemas; @@ -334,18 +344,27 @@ static Value ConvertParquetGeoStatsTypes(const duckdb_parquet::GeospatialStatist vector types; types.reserve(stats.geospatial_types.size()); - GeometryKindSet kind_set; + GeometryTypeSet type_set = GeometryTypeSet::Empty(); for (auto &type : stats.geospatial_types) { - kind_set.Add(type); + const auto geom_type = (type % 1000); + const auto vert_type = (type / 1000); + if (geom_type < 1 || geom_type > 7) { + throw InvalidInputException("Unsupported geometry type in Parquet geo metadata"); + } + if (vert_type < 0 || vert_type > 3) { + throw InvalidInputException("Unsupported geometry vertex type in Parquet geo metadata"); + } + type_set.Add(static_cast(geom_type), static_cast(vert_type)); } - for (auto &type_name : kind_set.ToString(true)) { + + for (auto &type_name : type_set.ToString(true)) { types.push_back(Value(type_name)); } return Value::LIST(LogicalType::VARCHAR, types); } -void ParquetRowGroupMetadataProcessor::InitializeInternal(ClientContext &context) { - auto meta_data = reader->GetFileMetadata(); +void ParquetRowGroupMetadataProcessor::InitializeInternal(ClientContext &context, ParquetReader &reader) { + auto meta_data = reader.GetFileMetadata(); column_schemas.clear(); for (idx_t schema_idx = 0; schema_idx < meta_data->schema.size(); schema_idx++) { auto &schema_element = meta_data->schema[schema_idx]; @@ -353,18 +372,19 @@ void ParquetRowGroupMetadataProcessor::InitializeInternal(ClientContext &context continue; } ParquetColumnSchema column_schema; - column_schema.type = reader->DeriveLogicalType(schema_element, column_schema); + column_schema.type = reader.DeriveLogicalType(schema_element, column_schema); column_schemas.push_back(std::move(column_schema)); } } -idx_t ParquetRowGroupMetadataProcessor::TotalRowCount() { - auto meta_data = reader->GetFileMetadata(); +idx_t ParquetRowGroupMetadataProcessor::TotalRowCount(ParquetReader &reader) { + auto meta_data = reader.GetFileMetadata(); return meta_data->row_groups.size() * column_schemas.size(); } -void ParquetRowGroupMetadataProcessor::ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) { - auto meta_data = reader->GetFileMetadata(); +void ParquetRowGroupMetadataProcessor::ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, + ParquetReader &reader) { + auto meta_data = reader.GetFileMetadata(); idx_t col_idx = row_idx % 
column_schemas.size(); idx_t row_group_idx = row_idx / column_schemas.size(); @@ -377,86 +397,90 @@ void ParquetRowGroupMetadataProcessor::ReadRow(DataChunk &output, idx_t output_i auto &column_type = column_schema.type; // file_name - output.SetValue(0, output_idx, reader->file.path); + output[0].get().SetValue(output_idx, reader.file.path); // row_group_id - output.SetValue(1, output_idx, Value::BIGINT(UnsafeNumericCast(row_group_idx))); + output[1].get().SetValue(output_idx, Value::BIGINT(UnsafeNumericCast(row_group_idx))); // row_group_num_rows - output.SetValue(2, output_idx, Value::BIGINT(row_group.num_rows)); + output[2].get().SetValue(output_idx, Value::BIGINT(row_group.num_rows)); // row_group_num_columns - output.SetValue(3, output_idx, Value::BIGINT(UnsafeNumericCast(row_group.columns.size()))); + output[3].get().SetValue(output_idx, Value::BIGINT(UnsafeNumericCast(row_group.columns.size()))); // row_group_bytes - output.SetValue(4, output_idx, Value::BIGINT(row_group.total_byte_size)); + output[4].get().SetValue(output_idx, Value::BIGINT(row_group.total_byte_size)); // column_id - output.SetValue(5, output_idx, Value::BIGINT(UnsafeNumericCast(col_idx))); + output[5].get().SetValue(output_idx, Value::BIGINT(UnsafeNumericCast(col_idx))); // file_offset - output.SetValue(6, output_idx, ParquetElementBigint(column.file_offset, row_group.__isset.file_offset)); + output[6].get().SetValue(output_idx, ParquetElementBigint(column.file_offset, row_group.__isset.file_offset)); // num_values - output.SetValue(7, output_idx, Value::BIGINT(col_meta.num_values)); + output[7].get().SetValue(output_idx, Value::BIGINT(col_meta.num_values)); // path_in_schema - output.SetValue(8, output_idx, StringUtil::Join(col_meta.path_in_schema, ", ")); + output[8].get().SetValue(output_idx, StringUtil::Join(col_meta.path_in_schema, ", ")); // type - output.SetValue(9, output_idx, ConvertParquetElementToString(col_meta.type)); + output[9].get().SetValue(output_idx, ConvertParquetElementToString(col_meta.type)); // stats_min - output.SetValue(10, output_idx, ConvertParquetStats(column_type, column_schema, stats.__isset.min, stats.min)); + output[10].get().SetValue(output_idx, + ConvertParquetStats(column_type, column_schema, stats.__isset.min, stats.min)); // stats_max - output.SetValue(11, output_idx, ConvertParquetStats(column_type, column_schema, stats.__isset.max, stats.max)); + output[11].get().SetValue(output_idx, + ConvertParquetStats(column_type, column_schema, stats.__isset.max, stats.max)); // stats_null_count - output.SetValue(12, output_idx, ParquetElementBigint(stats.null_count, stats.__isset.null_count)); + output[12].get().SetValue(output_idx, ParquetElementBigint(stats.null_count, stats.__isset.null_count)); // stats_distinct_count - output.SetValue(13, output_idx, ParquetElementBigint(stats.distinct_count, stats.__isset.distinct_count)); + output[13].get().SetValue(output_idx, ParquetElementBigint(stats.distinct_count, stats.__isset.distinct_count)); // stats_min_value - output.SetValue(14, output_idx, - ConvertParquetStats(column_type, column_schema, stats.__isset.min_value, stats.min_value)); + output[14].get().SetValue( + output_idx, ConvertParquetStats(column_type, column_schema, stats.__isset.min_value, stats.min_value)); // stats_max_value - output.SetValue(15, output_idx, - ConvertParquetStats(column_type, column_schema, stats.__isset.max_value, stats.max_value)); + output[15].get().SetValue( + output_idx, ConvertParquetStats(column_type, column_schema, stats.__isset.max_value, 
stats.max_value)); // compression - output.SetValue(16, output_idx, ConvertParquetElementToString(col_meta.codec)); + output[16].get().SetValue(output_idx, ConvertParquetElementToString(col_meta.codec)); // encodings vector encoding_string; encoding_string.reserve(col_meta.encodings.size()); for (auto &encoding : col_meta.encodings) { encoding_string.push_back(ConvertParquetElementToString(encoding)); } - output.SetValue(17, output_idx, Value(StringUtil::Join(encoding_string, ", "))); + output[17].get().SetValue(output_idx, Value(StringUtil::Join(encoding_string, ", "))); // index_page_offset - output.SetValue(18, output_idx, - ParquetElementBigint(col_meta.index_page_offset, col_meta.__isset.index_page_offset)); + output[18].get().SetValue(output_idx, + ParquetElementBigint(col_meta.index_page_offset, col_meta.__isset.index_page_offset)); // dictionary_page_offset - output.SetValue(19, output_idx, - ParquetElementBigint(col_meta.dictionary_page_offset, col_meta.__isset.dictionary_page_offset)); + output[19].get().SetValue( + output_idx, ParquetElementBigint(col_meta.dictionary_page_offset, col_meta.__isset.dictionary_page_offset)); // data_page_offset - output.SetValue(20, output_idx, Value::BIGINT(col_meta.data_page_offset)); + output[20].get().SetValue(output_idx, Value::BIGINT(col_meta.data_page_offset)); // total_compressed_size - output.SetValue(21, output_idx, Value::BIGINT(col_meta.total_compressed_size)); + output[21].get().SetValue(output_idx, Value::BIGINT(col_meta.total_compressed_size)); // total_uncompressed_size - output.SetValue(22, output_idx, Value::BIGINT(col_meta.total_uncompressed_size)); + output[22].get().SetValue(output_idx, Value::BIGINT(col_meta.total_uncompressed_size)); // key_value_metadata vector map_keys, map_values; for (auto &entry : col_meta.key_value_metadata) { map_keys.push_back(Value::BLOB_RAW(entry.key)); map_values.push_back(Value::BLOB_RAW(entry.value)); } - output.SetValue(23, output_idx, - Value::MAP(LogicalType::BLOB, LogicalType::BLOB, std::move(map_keys), std::move(map_values))); + output[23].get().SetValue( + output_idx, Value::MAP(LogicalType::BLOB, LogicalType::BLOB, std::move(map_keys), std::move(map_values))); // bloom_filter_offset - output.SetValue(24, output_idx, - ParquetElementBigint(col_meta.bloom_filter_offset, col_meta.__isset.bloom_filter_offset)); + output[24].get().SetValue(output_idx, + ParquetElementBigint(col_meta.bloom_filter_offset, col_meta.__isset.bloom_filter_offset)); // bloom_filter_length - output.SetValue(25, output_idx, - ParquetElementBigint(col_meta.bloom_filter_length, col_meta.__isset.bloom_filter_length)); + output[25].get().SetValue(output_idx, + ParquetElementBigint(col_meta.bloom_filter_length, col_meta.__isset.bloom_filter_length)); // min_is_exact - output.SetValue(26, output_idx, ParquetElementBoolean(stats.is_min_value_exact, stats.__isset.is_min_value_exact)); + output[26].get().SetValue(output_idx, + ParquetElementBoolean(stats.is_min_value_exact, stats.__isset.is_min_value_exact)); // max_is_exact - output.SetValue(27, output_idx, ParquetElementBoolean(stats.is_max_value_exact, stats.__isset.is_max_value_exact)); + output[27].get().SetValue(output_idx, + ParquetElementBoolean(stats.is_max_value_exact, stats.__isset.is_max_value_exact)); // row_group_compressed_bytes - output.SetValue(28, output_idx, - ParquetElementBigint(row_group.total_compressed_size, row_group.__isset.total_compressed_size)); + output[28].get().SetValue( + output_idx, ParquetElementBigint(row_group.total_compressed_size, 
row_group.__isset.total_compressed_size)); // geo_stats_bbox, LogicalType::STRUCT(...) - output.SetValue(29, output_idx, ConvertParquetGeoStatsBBOX(col_meta.geospatial_statistics)); + output[29].get().SetValue(output_idx, ConvertParquetGeoStatsBBOX(col_meta.geospatial_statistics)); // geo_stats_types, LogicalType::LIST(LogicalType::VARCHAR) - output.SetValue(30, output_idx, ConvertParquetGeoStatsTypes(col_meta.geospatial_statistics)); + output[30].get().SetValue(output_idx, ConvertParquetGeoStatsTypes(col_meta.geospatial_statistics)); } //===--------------------------------------------------------------------===// @@ -465,8 +489,8 @@ void ParquetRowGroupMetadataProcessor::ReadRow(DataChunk &output, idx_t output_i class ParquetSchemaProcessor : public ParquetMetadataFileProcessor { public: - idx_t TotalRowCount() override; - void ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) override; + idx_t TotalRowCount(ParquetReader &reader) override; + void ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, ParquetReader &reader) override; }; template <> @@ -567,45 +591,46 @@ static Value ParquetLogicalTypeToString(const duckdb_parquet::LogicalType &type, return Value(); } -idx_t ParquetSchemaProcessor::TotalRowCount() { - return reader->GetFileMetadata()->schema.size(); +idx_t ParquetSchemaProcessor::TotalRowCount(ParquetReader &reader) { + return reader.GetFileMetadata()->schema.size(); } -void ParquetSchemaProcessor::ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) { - auto meta_data = reader->GetFileMetadata(); +void ParquetSchemaProcessor::ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, + ParquetReader &reader) { + auto meta_data = reader.GetFileMetadata(); const auto &column = meta_data->schema[row_idx]; // file_name - output.SetValue(0, output_idx, reader->file.path); + output[0].get().SetValue(output_idx, reader.file.path); // name - output.SetValue(1, output_idx, column.name); + output[1].get().SetValue(output_idx, column.name); // type - output.SetValue(2, output_idx, ParquetElementString(column.type, column.__isset.type)); + output[2].get().SetValue(output_idx, ParquetElementString(column.type, column.__isset.type)); // type_length - output.SetValue(3, output_idx, ParquetElementInteger(column.type_length, column.__isset.type_length)); + output[3].get().SetValue(output_idx, ParquetElementInteger(column.type_length, column.__isset.type_length)); // repetition_type - output.SetValue(4, output_idx, ParquetElementString(column.repetition_type, column.__isset.repetition_type)); + output[4].get().SetValue(output_idx, ParquetElementString(column.repetition_type, column.__isset.repetition_type)); // num_children - output.SetValue(5, output_idx, ParquetElementBigint(column.num_children, column.__isset.num_children)); + output[5].get().SetValue(output_idx, ParquetElementBigint(column.num_children, column.__isset.num_children)); // converted_type - output.SetValue(6, output_idx, ParquetElementString(column.converted_type, column.__isset.converted_type)); + output[6].get().SetValue(output_idx, ParquetElementString(column.converted_type, column.__isset.converted_type)); // scale - output.SetValue(7, output_idx, ParquetElementBigint(column.scale, column.__isset.scale)); + output[7].get().SetValue(output_idx, ParquetElementBigint(column.scale, column.__isset.scale)); // precision - output.SetValue(8, output_idx, ParquetElementBigint(column.precision, column.__isset.precision)); + output[8].get().SetValue(output_idx, ParquetElementBigint(column.precision, 
column.__isset.precision)); // field_id - output.SetValue(9, output_idx, ParquetElementBigint(column.field_id, column.__isset.field_id)); + output[9].get().SetValue(output_idx, ParquetElementBigint(column.field_id, column.__isset.field_id)); // logical_type - output.SetValue(10, output_idx, ParquetLogicalTypeToString(column.logicalType, column.__isset.logicalType)); + output[10].get().SetValue(output_idx, ParquetLogicalTypeToString(column.logicalType, column.__isset.logicalType)); // duckdb_type ParquetColumnSchema column_schema; Value duckdb_type; if (column.__isset.type) { - duckdb_type = reader->DeriveLogicalType(column, column_schema).ToString(); + duckdb_type = reader.DeriveLogicalType(column, column_schema).ToString(); } - output.SetValue(11, output_idx, duckdb_type); + output[11].get().SetValue(output_idx, duckdb_type); // column_id - output.SetValue(12, output_idx, Value::BIGINT(UnsafeNumericCast(row_idx))); + output[12].get().SetValue(output_idx, Value::BIGINT(UnsafeNumericCast(row_idx))); } //===--------------------------------------------------------------------===// @@ -614,8 +639,8 @@ void ParquetSchemaProcessor::ReadRow(DataChunk &output, idx_t output_idx, idx_t class ParquetKeyValueMetadataProcessor : public ParquetMetadataFileProcessor { public: - idx_t TotalRowCount() override; - void ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) override; + idx_t TotalRowCount(ParquetReader &reader) override; + void ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, ParquetReader &reader) override; }; template <> @@ -631,17 +656,18 @@ void ParquetMetaDataOperator::BindSchemaGetFileMetadata()->key_value_metadata.size(); +idx_t ParquetKeyValueMetadataProcessor::TotalRowCount(ParquetReader &reader) { + return reader.GetFileMetadata()->key_value_metadata.size(); } -void ParquetKeyValueMetadataProcessor::ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) { - auto meta_data = reader->GetFileMetadata(); +void ParquetKeyValueMetadataProcessor::ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, + ParquetReader &reader) { + auto meta_data = reader.GetFileMetadata(); auto &entry = meta_data->key_value_metadata[row_idx]; - output.SetValue(0, output_idx, Value(reader->file.path)); - output.SetValue(1, output_idx, Value::BLOB_RAW(entry.key)); - output.SetValue(2, output_idx, Value::BLOB_RAW(entry.value)); + output[0].get().SetValue(output_idx, Value(reader.file.path)); + output[1].get().SetValue(output_idx, Value::BLOB_RAW(entry.key)); + output[2].get().SetValue(output_idx, Value::BLOB_RAW(entry.value)); } //===--------------------------------------------------------------------===// @@ -650,8 +676,8 @@ void ParquetKeyValueMetadataProcessor::ReadRow(DataChunk &output, idx_t output_i class ParquetFileMetadataProcessor : public ParquetMetadataFileProcessor { public: - idx_t TotalRowCount() override; - void ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) override; + idx_t TotalRowCount(ParquetReader &reader) override; + void ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, ParquetReader &reader) override; }; template <> @@ -685,34 +711,34 @@ void ParquetMetaDataOperator::BindSchemaGetFileMetadata(); +void ParquetFileMetadataProcessor::ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, + ParquetReader &reader) { + auto meta_data = reader.GetFileMetadata(); // file_name - output.SetValue(0, output_idx, Value(reader->file.path)); + output[0].get().SetValue(output_idx, Value(reader.file.path)); // created_by - output.SetValue(1, output_idx, 
ParquetElementStringVal(meta_data->created_by, meta_data->__isset.created_by)); + output[1].get().SetValue(output_idx, ParquetElementStringVal(meta_data->created_by, meta_data->__isset.created_by)); // num_rows - output.SetValue(2, output_idx, Value::BIGINT(meta_data->num_rows)); + output[2].get().SetValue(output_idx, Value::BIGINT(meta_data->num_rows)); // num_row_groups - output.SetValue(3, output_idx, Value::BIGINT(UnsafeNumericCast(meta_data->row_groups.size()))); + output[3].get().SetValue(output_idx, Value::BIGINT(UnsafeNumericCast(meta_data->row_groups.size()))); // format_version - output.SetValue(4, output_idx, Value::BIGINT(meta_data->version)); + output[4].get().SetValue(output_idx, Value::BIGINT(meta_data->version)); // encryption_algorithm - output.SetValue(5, output_idx, - ParquetElementString(meta_data->encryption_algorithm, meta_data->__isset.encryption_algorithm)); + output[5].get().SetValue( + output_idx, ParquetElementString(meta_data->encryption_algorithm, meta_data->__isset.encryption_algorithm)); // footer_signing_key_metadata - output.SetValue(6, output_idx, - ParquetElementStringVal(meta_data->footer_signing_key_metadata, - meta_data->__isset.footer_signing_key_metadata)); + output[6].get().SetValue(output_idx, ParquetElementStringVal(meta_data->footer_signing_key_metadata, + meta_data->__isset.footer_signing_key_metadata)); // file_size_bytes - output.SetValue(7, output_idx, Value::UBIGINT(reader->GetHandle().GetFileSize())); + output[7].get().SetValue(output_idx, Value::UBIGINT(reader.GetHandle().GetFileSize())); // footer_size - output.SetValue(8, output_idx, Value::UBIGINT(reader->metadata->footer_size)); + output[8].get().SetValue(output_idx, Value::UBIGINT(reader.metadata->footer_size)); } //===--------------------------------------------------------------------===// @@ -723,9 +749,9 @@ class ParquetBloomProbeProcessor : public ParquetMetadataFileProcessor { public: ParquetBloomProbeProcessor(const string &probe_column, const Value &probe_value); - void InitializeInternal(ClientContext &context) override; - idx_t TotalRowCount() override; - void ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) override; + void InitializeInternal(ClientContext &context, ParquetReader &reader) override; + idx_t TotalRowCount(ParquetReader &reader) override; + void ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, ParquetReader &reader) override; private: string probe_column_name; @@ -754,34 +780,35 @@ ParquetBloomProbeProcessor::ParquetBloomProbeProcessor(const string &probe_colum : probe_column_name(probe_column), probe_constant(probe_value) { } -void ParquetBloomProbeProcessor::InitializeInternal(ClientContext &context) { +void ParquetBloomProbeProcessor::InitializeInternal(ClientContext &context, ParquetReader &reader) { probe_column_idx = optional_idx::Invalid(); - for (idx_t column_idx = 0; column_idx < reader->columns.size(); column_idx++) { - if (reader->columns[column_idx].name == probe_column_name) { + for (idx_t column_idx = 0; column_idx < reader.columns.size(); column_idx++) { + if (reader.columns[column_idx].name == probe_column_name) { probe_column_idx = column_idx; break; } } if (!probe_column_idx.IsValid()) { - throw InvalidInputException("Column %s not found in %s", probe_column_name, reader->file.path); + throw InvalidInputException("Column %s not found in %s", probe_column_name, reader.file.path); } - auto transport = duckdb_base_std::make_shared(reader->GetHandle(), false); + auto transport = duckdb_base_std::make_shared(reader.GetHandle(), 
false); protocol = make_uniq>(std::move(transport)); allocator = &BufferAllocator::Get(context); filter = make_uniq( ExpressionType::COMPARE_EQUAL, - probe_constant.CastAs(context, reader->GetColumns()[probe_column_idx.GetIndex()].type)); + probe_constant.CastAs(context, reader.GetColumns()[probe_column_idx.GetIndex()].type)); } -idx_t ParquetBloomProbeProcessor::TotalRowCount() { - return reader->GetFileMetadata()->row_groups.size(); +idx_t ParquetBloomProbeProcessor::TotalRowCount(ParquetReader &reader) { + return reader.GetFileMetadata()->row_groups.size(); } -void ParquetBloomProbeProcessor::ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) { - auto meta_data = reader->GetFileMetadata(); +void ParquetBloomProbeProcessor::ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, + ParquetReader &reader) { + auto meta_data = reader.GetFileMetadata(); auto &row_group = meta_data->row_groups[row_idx]; auto &column = row_group.columns[probe_column_idx.GetIndex()]; @@ -789,9 +816,124 @@ void ParquetBloomProbeProcessor::ReadRow(DataChunk &output, idx_t output_idx, id auto bloom_excludes = ParquetStatisticsUtils::BloomFilterExcludes(*filter, column.meta_data, *protocol, *allocator); - output.SetValue(0, output_idx, Value(reader->file.path)); - output.SetValue(1, output_idx, Value::BIGINT(NumericCast(row_idx))); - output.SetValue(2, output_idx, Value::BOOLEAN(bloom_excludes)); + output[0].get().SetValue(output_idx, Value(reader.file.path)); + output[1].get().SetValue(output_idx, Value::BIGINT(NumericCast(row_idx))); + output[2].get().SetValue(output_idx, Value::BOOLEAN(bloom_excludes)); +} + +//===--------------------------------------------------------------------===// +// Full Metadata +//===--------------------------------------------------------------------===// + +class FullMetadataProcessor : public ParquetMetadataFileProcessor { +public: + FullMetadataProcessor() = default; + + void InitializeInternal(ClientContext &context, ParquetReader &reader) override; + idx_t TotalRowCount(ParquetReader &reader) override; + void ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, ParquetReader &reader) override; + bool ForceFlush() override { + return true; + } + +private: + void PopulateMetadata(ParquetMetadataFileProcessor &processor, Vector &output, idx_t output_idx, + ParquetReader &reader); + + ParquetFileMetadataProcessor file_processor; + ParquetRowGroupMetadataProcessor row_group_processor; + ParquetSchemaProcessor schema_processor; + ParquetKeyValueMetadataProcessor kv_processor; +}; + +void FullMetadataProcessor::PopulateMetadata(ParquetMetadataFileProcessor &processor, Vector &output, idx_t output_idx, + ParquetReader &reader) { + auto count = processor.TotalRowCount(reader); + auto *result_data = FlatVector::GetData(output); + auto &result_struct = ListVector::GetEntry(output); + auto &result_struct_entries = StructVector::GetEntries(result_struct); + + ListVector::SetListSize(output, count); + ListVector::Reserve(output, count); + + result_data[output_idx].offset = 0; + result_data[output_idx].length = count; + + FlatVector::Validity(output).SetValid(output_idx); + + vector> vectors; + for (auto &entry : result_struct_entries) { + vectors.push_back(std::ref(*entry.get())); + entry->SetVectorType(VectorType::FLAT_VECTOR); + auto &validity = FlatVector::Validity(*entry); + validity.Initialize(count); + } + for (idx_t i = 0; i < count; i++) { + processor.ReadRow(vectors, i, i, reader); + } +} + +template <> +void ParquetMetaDataOperator::BindSchema(vector &return_types, + 
vector &names) { + names.emplace_back("parquet_file_metadata"); + vector file_meta_types; + vector file_meta_names; + ParquetMetaDataOperator::BindSchema(file_meta_types, file_meta_names); + child_list_t file_meta_children; + for (idx_t i = 0; i < file_meta_types.size(); i++) { + file_meta_children.push_back(make_pair(file_meta_names[i], file_meta_types[i])); + } + return_types.emplace_back(LogicalType::LIST(LogicalType::STRUCT(std::move(file_meta_children)))); + + names.emplace_back("parquet_metadata"); + vector row_group_types; + vector row_group_names; + ParquetMetaDataOperator::BindSchema(row_group_types, row_group_names); + child_list_t row_group_children; + for (idx_t i = 0; i < row_group_types.size(); i++) { + row_group_children.push_back(make_pair(row_group_names[i], row_group_types[i])); + } + return_types.emplace_back(LogicalType::LIST(LogicalType::STRUCT(std::move(row_group_children)))); + + names.emplace_back("parquet_schema"); + vector schema_types; + vector schema_names; + ParquetMetaDataOperator::BindSchema(schema_types, schema_names); + child_list_t schema_children; + for (idx_t i = 0; i < schema_types.size(); i++) { + schema_children.push_back(make_pair(schema_names[i], schema_types[i])); + } + return_types.emplace_back(LogicalType::LIST(LogicalType::STRUCT(std::move(schema_children)))); + + names.emplace_back("parquet_kv_metadata"); + vector kv_types; + vector kv_names; + ParquetMetaDataOperator::BindSchema(kv_types, kv_names); + child_list_t kv_children; + for (idx_t i = 0; i < kv_types.size(); i++) { + kv_children.push_back(make_pair(kv_names[i], kv_types[i])); + } + return_types.emplace_back(LogicalType::LIST(LogicalType::STRUCT(std::move(kv_children)))); +} + +void FullMetadataProcessor::InitializeInternal(ClientContext &context, ParquetReader &reader) { + file_processor.Initialize(context, reader); + row_group_processor.Initialize(context, reader); + schema_processor.Initialize(context, reader); + kv_processor.Initialize(context, reader); +} + +idx_t FullMetadataProcessor::TotalRowCount(ParquetReader &reader) { + return 1; +} + +void FullMetadataProcessor::ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, + ParquetReader &reader) { + PopulateMetadata(file_processor, output[0].get(), output_idx, reader); + PopulateMetadata(row_group_processor, output[1].get(), output_idx, reader); + PopulateMetadata(schema_processor, output[2].get(), output_idx, reader); + PopulateMetadata(kv_processor, output[3].get(), output_idx, reader); } //===--------------------------------------------------------------------===// @@ -859,6 +1001,10 @@ unique_ptr ParquetMetaDataOperator::InitLocal(Execution make_uniq(probe_bind_data.probe_column_name, probe_bind_data.probe_constant); break; } + case ParquetMetadataOperatorType::FULL_METADATA: { + res->processor = make_uniq(); + break; + } default: throw InternalException("Unsupported ParquetMetadataOperatorType"); } @@ -872,6 +1018,11 @@ void ParquetMetaDataOperator::Function(ClientContext &context, TableFunctionInpu idx_t output_count = 0; + vector> output_vectors; + for (idx_t i = 0; i < output.ColumnCount(); i++) { + output_vectors.push_back(std::ref(output.data[i])); + } + while (output_count < STANDARD_VECTOR_SIZE) { // Check if we need a new file if (local_state.file_exhausted) { @@ -880,11 +1031,7 @@ void ParquetMetaDataOperator::Function(ClientContext &context, TableFunctionInpu break; // No more files to process } - local_state.processor->Initialize(context, next_file); - local_state.processor->InitializeInternal(context); - 
local_state.file_exhausted = false; - local_state.row_idx = 0; - local_state.total_rows = local_state.processor->TotalRowCount(); + local_state.Initialize(context, next_file); } idx_t left_in_vector = STANDARD_VECTOR_SIZE - output_count; @@ -897,14 +1044,19 @@ void ParquetMetaDataOperator::Function(ClientContext &context, TableFunctionInpu rows_to_output = left_in_vector; } + output.SetCardinality(output_count + rows_to_output); + for (idx_t i = 0; i < rows_to_output; ++i) { - local_state.processor->ReadRow(output, output_count + i, local_state.row_idx + i); + local_state.processor->ReadRow(output_vectors, output_count + i, local_state.row_idx + i, + *local_state.reader); } output_count += rows_to_output; local_state.row_idx += rows_to_output; - } - output.SetCardinality(output_count); + if (local_state.processor->ForceFlush()) { + break; + } + } } double ParquetMetaDataOperator::Progress(ClientContext &context, const FunctionData *bind_data_p, @@ -957,4 +1109,13 @@ ParquetBloomProbeFunction::ParquetBloomProbeFunction() ParquetMetaDataOperator::InitLocal) { table_scan_progress = ParquetMetaDataOperator::Progress; } + +ParquetFullMetadataFunction::ParquetFullMetadataFunction() + : TableFunction("parquet_full_metadata", {LogicalType::VARCHAR}, + ParquetMetaDataOperator::Function, + ParquetMetaDataOperator::Bind, + ParquetMetaDataOperator::InitGlobal, + ParquetMetaDataOperator::InitLocal) { + table_scan_progress = ParquetMetaDataOperator::Progress; +} } // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_multi_file_info.cpp b/src/duckdb/extension/parquet/parquet_multi_file_info.cpp index 9617f0c83..160211b69 100644 --- a/src/duckdb/extension/parquet/parquet_multi_file_info.cpp +++ b/src/duckdb/extension/parquet/parquet_multi_file_info.cpp @@ -397,10 +397,6 @@ bool ParquetMultiFileInfo::ParseOption(ClientContext &context, const string &ori options.binary_as_string = BooleanValue::Get(val); return true; } - if (key == "variant_legacy_encoding") { - options.variant_legacy_encoding = BooleanValue::Get(val); - return true; - } if (key == "file_row_number") { options.file_row_number = BooleanValue::Get(val); return true; @@ -575,12 +571,21 @@ void ParquetReader::FinishFile(ClientContext &context, GlobalTableFunctionState gstate.row_group_index = 0; } -void ParquetReader::Scan(ClientContext &context, GlobalTableFunctionState &gstate_p, - LocalTableFunctionState &local_state_p, DataChunk &chunk) { +AsyncResult ParquetReader::Scan(ClientContext &context, GlobalTableFunctionState &gstate_p, + LocalTableFunctionState &local_state_p, DataChunk &chunk) { +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE + { + vector> tasks = AsyncResult::GenerateTestTasks(); + if (!tasks.empty()) { + return AsyncResult(std::move(tasks)); + } + } +#endif + auto &gstate = gstate_p.Cast(); auto &local_state = local_state_p.Cast(); local_state.scan_state.op = gstate.op; - Scan(context, local_state.scan_state, chunk); + return Scan(context, local_state.scan_state, chunk); } unique_ptr ParquetMultiFileInfo::Copy() { diff --git a/src/duckdb/extension/parquet/parquet_reader.cpp b/src/duckdb/extension/parquet/parquet_reader.cpp index cad5f3a9b..b806beb1a 100644 --- a/src/duckdb/extension/parquet/parquet_reader.cpp +++ b/src/duckdb/extension/parquet/parquet_reader.cpp @@ -5,7 +5,7 @@ #include "column_reader.hpp" #include "duckdb.hpp" #include "reader/expression_column_reader.hpp" -#include "geo_parquet.hpp" +#include "parquet_geometry.hpp" #include "reader/list_column_reader.hpp" #include "parquet_crypto.hpp" #include 
"parquet_file_metadata_cache.hpp" @@ -92,7 +92,7 @@ static shared_ptr LoadMetadata(ClientContext &context, Allocator &allocator, CachingFileHandle &file_handle, const shared_ptr &encryption_config, const EncryptionUtil &encryption_util, optional_idx footer_size) { - auto file_proto = CreateThriftFileProtocol(QueryContext(context), file_handle, false); + auto file_proto = CreateThriftFileProtocol(context, file_handle, false); auto &transport = reinterpret_cast(*file_proto->getTransport()); auto file_size = transport.GetSize(); if (file_size < 12) { @@ -225,10 +225,6 @@ LogicalType ParquetReader::DeriveLogicalType(const SchemaElement &s_ele, Parquet return LogicalType::TIME_TZ; } return LogicalType::TIME; - } else if (s_ele.logicalType.__isset.GEOMETRY) { - return LogicalType::BLOB; - } else if (s_ele.logicalType.__isset.GEOGRAPHY) { - return LogicalType::BLOB; } } if (s_ele.__isset.converted_type) { @@ -406,10 +402,11 @@ unique_ptr ParquetReader::CreateReaderRecursive(ClientContext &con const vector &indexes, const ParquetColumnSchema &schema) { switch (schema.schema_type) { - case ParquetColumnSchemaType::GEOMETRY: - return GeoParquetFileMetadata::CreateColumnReader(*this, schema, context); case ParquetColumnSchemaType::FILE_ROW_NUMBER: return make_uniq(*this, schema); + case ParquetColumnSchemaType::GEOMETRY: { + return GeometryColumnReader::Create(*this, schema, context); + } case ParquetColumnSchemaType::COLUMN: { if (schema.children.empty()) { // leaf reader @@ -487,11 +484,11 @@ ParquetColumnSchema::ParquetColumnSchema(string name_p, LogicalType type_p, idx_ max_repeat(max_repeat), schema_index(schema_index), column_index(column_index) { } -ParquetColumnSchema::ParquetColumnSchema(ParquetColumnSchema parent, LogicalType result_type, +ParquetColumnSchema::ParquetColumnSchema(ParquetColumnSchema child, LogicalType result_type, ParquetColumnSchemaType schema_type) - : schema_type(schema_type), name(parent.name), type(std::move(result_type)), max_define(parent.max_define), - max_repeat(parent.max_repeat), schema_index(parent.schema_index), column_index(parent.column_index) { - children.push_back(std::move(parent)); + : schema_type(schema_type), name(child.name), type(std::move(result_type)), max_define(child.max_define), + max_repeat(child.max_repeat), schema_index(child.schema_index), column_index(child.column_index) { + children.push_back(std::move(child)); } unique_ptr ParquetColumnSchema::Stats(const FileMetaData &file_meta_data, @@ -518,59 +515,41 @@ unique_ptr ParquetColumnSchema::Stats(const FileMetaData &file_m return ParquetStatisticsUtils::TransformColumnStatistics(*this, columns, parquet_options.can_have_nan); } -static bool IsVariantType(const SchemaElement &root, const vector &children) { - if (children.size() < 2) { +static bool IsGeometryType(const SchemaElement &s_ele, const ParquetFileMetadataCache &metadata, idx_t depth) { + const auto is_blob = s_ele.__isset.type && s_ele.type == Type::BYTE_ARRAY; + if (!is_blob) { return false; } - auto &child0 = children[0]; - auto &child1 = children[1]; - ParquetColumnSchema const *metadata; - ParquetColumnSchema const *value; - - if (child0.name == "metadata" && child1.name == "value") { - metadata = &child0; - value = &child1; - } else if (child1.name == "metadata" && child0.name == "value") { - metadata = &child1; - value = &child0; - } else { - return false; + // TODO: Handle CRS in the future + const auto is_native_geom = s_ele.__isset.logicalType && s_ele.logicalType.__isset.GEOMETRY; + const auto is_native_geog = 
s_ele.__isset.logicalType && s_ele.logicalType.__isset.GEOGRAPHY; + if (is_native_geom || is_native_geog) { + return true; } - //! Verify names - if (metadata->name != "metadata") { - return false; - } - if (value->name != "value") { - return false; - } + // geoparquet types have to be at the root of the schema, and have to be present in the kv metadata. + const auto is_at_root = depth == 1; + const auto is_in_gpq_metadata = metadata.geo_metadata && metadata.geo_metadata->IsGeometryColumn(s_ele.name); + const auto is_leaf = s_ele.num_children == 0; + const auto is_geoparquet_geom = is_at_root && is_in_gpq_metadata && is_leaf; - //! Verify types - if (metadata->parquet_type != duckdb_parquet::Type::BYTE_ARRAY) { - return false; - } - if (value->parquet_type != duckdb_parquet::Type::BYTE_ARRAY) { - return false; - } - if (children.size() == 3) { - auto &typed_value = children[2]; - if (typed_value.name != "typed_value") { - return false; - } - } else if (children.size() != 2) { - return false; + if (is_geoparquet_geom) { + return true; } - return true; + + return false; } ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_define, idx_t max_repeat, idx_t &next_schema_idx, idx_t &next_file_idx, ClientContext &context) { - auto file_meta_data = GetFileMetadata(); D_ASSERT(file_meta_data); - D_ASSERT(next_schema_idx < file_meta_data->schema.size()); + if (next_schema_idx >= file_meta_data->schema.size()) { + throw InvalidInputException("Malformed Parquet schema in file \"%s\": invalid schema index %d", file.path, + next_schema_idx); + } auto &s_ele = file_meta_data->schema[next_schema_idx]; auto this_idx = next_schema_idx; @@ -585,15 +564,26 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d max_repeat++; } - // Check for geoparquet spatial types - if (depth == 1) { - // geoparquet types have to be at the root of the schema, and have to be present in the kv metadata. - // geoarrow types, although geometry columns, are structs and have children and are handled below. - if (metadata->geo_metadata && metadata->geo_metadata->IsGeometryColumn(s_ele.name) && s_ele.num_children == 0) { - auto root_schema = ParseColumnSchema(s_ele, max_define, max_repeat, this_idx, next_file_idx++); - return ParquetColumnSchema(std::move(root_schema), GeoParquetFileMetadata::GeometryType(), - ParquetColumnSchemaType::GEOMETRY); - } + // Check for geometry type + if (IsGeometryType(s_ele, *metadata, depth)) { + // Geometries in both GeoParquet and native Parquet are stored as a WKB-encoded BLOB. + // Because we don't just want to validate that the WKB encoding is correct, but also transform it into + // little-endian if necessary, we can't just make use of the StringColumnReader without heavily modifying it. + // Therefore, we create a dedicated GEOMETRY parquet column schema type, which wraps the underlying BLOB column. + // This schema type gets instantiated as an ExpressionColumnReader on top of the standard Blob/String reader, + // which performs the WKB validation/transformation using the `ST_GeomFromWKB` function of DuckDB. + // This enables us to also support other geometry encodings (such as GeoArrow geometries) more easily in the future.
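To make the byte-order concern in the comment above concrete: ISO WKB starts every geometry with a one-byte byte-order marker (0 = big-endian, 1 = little-endian) followed by a uint32 geometry type in that byte order, so a reader that wants a canonical little-endian representation has to inspect that first byte. Below is a minimal standalone sketch of that header check in plain C++; it is not the ST_GeomFromWKB implementation and it assumes a little-endian host:

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>

// Reads the geometry type from a WKB blob, byte-swapping when the payload is big-endian.
// Assumes a little-endian host; a real implementation would also validate the geometry body.
static uint32_t ReadWkbGeometryType(const uint8_t *wkb, std::size_t len) {
	if (len < 5) {
		throw std::runtime_error("unexpected end of WKB data");
	}
	uint32_t type;
	std::memcpy(&type, wkb + 1, sizeof(type));
	if (wkb[0] == 0) {
		// big-endian (XDR) payload: swap into little-endian host order
		type = (type >> 24) | ((type >> 8) & 0xFF00u) | ((type << 8) & 0xFF0000u) | (type << 24);
	} else if (wkb[0] != 1) {
		throw std::runtime_error("invalid WKB byte-order marker");
	}
	return type; // 1 = POINT, 2 = LINESTRING, 3 = POLYGON, ...
}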
+ + // Inner BLOB schema + ParquetColumnSchema blob_schema(max_define, max_repeat, this_idx, next_file_idx++, + ParquetColumnSchemaType::COLUMN); + blob_schema.name = s_ele.name; + blob_schema.type = LogicalType::BLOB; + + // Wrap in geometry schema + ParquetColumnSchema geom_schema(std::move(blob_schema), LogicalType::GEOMETRY(), + ParquetColumnSchemaType::GEOMETRY); + return geom_schema; } if (s_ele.__isset.num_children && s_ele.num_children > 0) { // inner node @@ -627,9 +617,6 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d const bool is_map = s_ele.__isset.converted_type && s_ele.converted_type == ConvertedType::MAP; bool is_map_kv = s_ele.__isset.converted_type && s_ele.converted_type == ConvertedType::MAP_KEY_VALUE; bool is_variant = s_ele.__isset.logicalType && s_ele.logicalType.__isset.VARIANT == true; - if (!is_variant) { - is_variant = parquet_options.variant_legacy_encoding && IsVariantType(s_ele, child_schemas); - } if (!is_map_kv && this_idx > 0) { // check if the parent node of this is a map @@ -665,7 +652,7 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d LogicalType result_type; if (is_variant) { - result_type = LogicalType::JSON(); + result_type = LogicalType::VARIANT(); } else { result_type = LogicalType::STRUCT(std::move(struct_types)); } @@ -705,13 +692,6 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d return list_schema; } - // Convert to geometry type if possible - if (s_ele.__isset.logicalType && (s_ele.logicalType.__isset.GEOMETRY || s_ele.logicalType.__isset.GEOGRAPHY) && - GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context)) { - return ParquetColumnSchema(std::move(result), GeoParquetFileMetadata::GeometryType(), - ParquetColumnSchemaType::GEOMETRY); - } - return result; } } @@ -727,23 +707,28 @@ unique_ptr ParquetReader::ParseSchema(ClientContext &contex idx_t next_file_idx = 0; if (file_meta_data->schema.empty()) { - throw IOException("Parquet reader: no schema elements found"); + throw IOException("Failed to read Parquet file \"%s\": no schema elements found", file.path); } if (file_meta_data->schema[0].num_children == 0) { - throw IOException("Parquet reader: root schema element has no children"); + throw IOException("Failed to read Parquet file \"%s\": root schema element has no children", file.path); } auto root = ParseSchemaRecursive(0, 0, 0, next_schema_idx, next_file_idx, context); if (root.type.id() != LogicalTypeId::STRUCT) { - throw InvalidInputException("Root element of Parquet file must be a struct"); + throw InvalidInputException("Failed to read Parquet file \"%s\": Root element of Parquet file must be a struct", + file.path); } D_ASSERT(next_schema_idx == file_meta_data->schema.size() - 1); - D_ASSERT(file_meta_data->row_groups.empty() || next_file_idx == file_meta_data->row_groups[0].columns.size()); + if (!file_meta_data->row_groups.empty() && next_file_idx != file_meta_data->row_groups[0].columns.size()) { + throw InvalidInputException("Failed to read Parquet file \"%s\": row group does not have enough columns", + file.path); + } if (parquet_options.file_row_number) { for (auto &column : root.children) { auto &name = column.name; if (StringUtil::CIEquals(name, "file_row_number")) { - throw BinderException( - "Using file_row_number option on file with column named file_row_number is not supported"); + throw BinderException("Failed to read Parquet file \"%s\": Using file_row_number option on file with " + "column named 
file_row_number is not supported", + file.path); } } root.children.push_back(FileRowNumberSchema()); @@ -808,9 +793,6 @@ ParquetOptions::ParquetOptions(ClientContext &context) { if (context.TryGetCurrentSetting("binary_as_string", lookup_value)) { binary_as_string = lookup_value.GetValue(); } - if (context.TryGetCurrentSetting("variant_legacy_encoding", lookup_value)) { - variant_legacy_encoding = lookup_value.GetValue(); - } } ParquetColumnDefinition ParquetColumnDefinition::FromSchemaValue(ClientContext &context, const Value &column_value) { @@ -837,7 +819,7 @@ ParquetReader::ParquetReader(ClientContext &context_p, OpenFileInfo file_p, Parq shared_ptr metadata_p) : BaseFileReader(std::move(file_p)), fs(CachingFileSystem::Get(context_p)), allocator(BufferAllocator::Get(context_p)), parquet_options(std::move(parquet_options_p)) { - file_handle = fs.OpenFile(QueryContext(context_p), file, FileFlags::FILE_FLAGS_READ); + file_handle = fs.OpenFile(context_p, file, FileFlags::FILE_FLAGS_READ); if (!file_handle->CanSeek()) { throw NotImplementedException( "Reading parquet files from a FIFO stream is not supported and cannot be efficiently supported since " @@ -1046,7 +1028,6 @@ uint64_t ParquetReader::GetGroupSpan(ParquetReaderScanState &state) { idx_t max_offset = NumericLimits::Minimum(); for (auto &column_chunk : group.columns) { - // Set the min offset idx_t current_min_offset = NumericLimits::Maximum(); if (column_chunk.meta_data.__isset.dictionary_page_offset) { @@ -1236,7 +1217,7 @@ void ParquetReader::InitializeScan(ClientContext &context, ParquetReaderScanStat state.prefetch_mode = false; } - state.file_handle = fs.OpenFile(QueryContext(context), file, flags); + state.file_handle = fs.OpenFile(context, file, flags); } state.adaptive_filter.reset(); state.scan_filters.clear(); @@ -1247,21 +1228,12 @@ void ParquetReader::InitializeScan(ClientContext &context, ParquetReaderScanStat } } - state.thrift_file_proto = CreateThriftFileProtocol(QueryContext(context), *state.file_handle, state.prefetch_mode); + state.thrift_file_proto = CreateThriftFileProtocol(context, *state.file_handle, state.prefetch_mode); state.root_reader = CreateReader(context); state.define_buf.resize(allocator, STANDARD_VECTOR_SIZE); state.repeat_buf.resize(allocator, STANDARD_VECTOR_SIZE); } -void ParquetReader::Scan(ClientContext &context, ParquetReaderScanState &state, DataChunk &result) { - while (ScanInternal(context, state, result)) { - if (result.size() > 0) { - break; - } - result.Reset(); - } -} - void ParquetReader::GetPartitionStats(vector &result) { GetPartitionStats(*GetFileMetadata(), result); } @@ -1279,9 +1251,10 @@ void ParquetReader::GetPartitionStats(const duckdb_parquet::FileMetaData &metada } } -bool ParquetReader::ScanInternal(ClientContext &context, ParquetReaderScanState &state, DataChunk &result) { +AsyncResult ParquetReader::Scan(ClientContext &context, ParquetReaderScanState &state, DataChunk &result) { + result.Reset(); if (state.finished) { - return false; + return SourceResultType::FINISHED; } // see if we have to switch to the next row group in the parquet file @@ -1295,7 +1268,7 @@ bool ParquetReader::ScanInternal(ClientContext &context, ParquetReaderScanState if ((idx_t)state.current_group == state.group_idx_list.size()) { state.finished = true; - return false; + return SourceResultType::FINISHED; } // TODO: only need this if we have a deletion vector? 
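A note on the Scan rework visible in the hunks above and below: the old blocking wrapper that looped over ScanInternal until it produced rows is removed, and Scan itself now resets the output chunk and reports progress as a SourceResultType (wrapped in AsyncResult) instead of a bool. The sketch below shows how a caller might drive that contract; the helper DrainParquetReader and the GetResultType() accessor are illustrative assumptions, not part of this patch.

// Sketch only: a hypothetical caller loop for the reworked Scan() contract shown in this diff.
// Scan() now resets `chunk` itself and may legitimately return HAVE_MORE_OUTPUT with an
// empty chunk (e.g. right after switching row groups), so the caller simply calls again.
static void DrainParquetReader(ClientContext &context, ParquetReader &reader,
                               ParquetReaderScanState &state, DataChunk &chunk) {
	while (true) {
		auto scan_result = reader.Scan(context, state, chunk);
		if (chunk.size() > 0) {
			// hand the produced rows to the consumer here
		}
		// GetResultType() is an assumed accessor on the returned AsyncResult exposing the
		// underlying SourceResultType; the real wrapper type in the patch may differ.
		if (scan_result.GetResultType() == SourceResultType::FINISHED) {
			break;
		}
	}
}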
@@ -1367,7 +1340,8 @@ bool ParquetReader::ScanInternal(ClientContext &context, ParquetReaderScanState } } } - return true; + result.Reset(); + return SourceResultType::HAVE_MORE_OUTPUT; } auto scan_count = MinValue(STANDARD_VECTOR_SIZE, GetGroup(state).num_rows - state.offset_in_group); @@ -1375,7 +1349,8 @@ bool ParquetReader::ScanInternal(ClientContext &context, ParquetReaderScanState if (scan_count == 0) { state.finished = true; - return false; // end of last group, we are done + // end of last group, we are done + return SourceResultType::FINISHED; } auto &deletion_filter = state.root_reader->Reader().deletion_filter; @@ -1461,7 +1436,7 @@ bool ParquetReader::ScanInternal(ClientContext &context, ParquetReaderScanState rows_read += scan_count; state.offset_in_group += scan_count; - return true; + return SourceResultType::HAVE_MORE_OUTPUT; } } // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_shredding.cpp b/src/duckdb/extension/parquet/parquet_shredding.cpp new file mode 100644 index 000000000..b7ed673a8 --- /dev/null +++ b/src/duckdb/extension/parquet/parquet_shredding.cpp @@ -0,0 +1,81 @@ +#include "parquet_shredding.hpp" +#include "duckdb/common/exception/binder_exception.hpp" +#include "duckdb/common/type_visitor.hpp" + +namespace duckdb { + +ChildShreddingTypes::ChildShreddingTypes() : types(make_uniq>()) { +} + +ChildShreddingTypes ChildShreddingTypes::Copy() const { + ChildShreddingTypes result; + for (const auto &type : *types) { + result.types->emplace(type.first, type.second.Copy()); + } + return result; +} + +ShreddingType::ShreddingType() : set(false) { +} + +ShreddingType::ShreddingType(const LogicalType &type) : set(true), type(type) { +} + +ShreddingType ShreddingType::Copy() const { + auto result = set ? ShreddingType(type) : ShreddingType(); + result.children = children.Copy(); + return result; +} + +static ShreddingType ConvertShreddingTypeRecursive(const LogicalType &type) { + if (type.id() == LogicalTypeId::VARIANT) { + return ShreddingType(LogicalType(LogicalTypeId::ANY)); + } + if (!type.IsNested()) { + return ShreddingType(type); + } + + switch (type.id()) { + case LogicalTypeId::STRUCT: { + ShreddingType res(type); + auto &children = StructType::GetChildTypes(type); + for (auto &entry : children) { + res.AddChild(entry.first, ConvertShreddingTypeRecursive(entry.second)); + } + return res; + } + case LogicalTypeId::LIST: { + ShreddingType res(type); + const auto &child = ListType::GetChildType(type); + res.AddChild("element", ConvertShreddingTypeRecursive(child)); + return res; + } + default: + break; + } + throw BinderException("VARIANT can only be shredded on LIST/STRUCT/ANY/non-nested type, not %s", type.ToString()); +} + +void ShreddingType::AddChild(const string &name, ShreddingType &&child) { + children.types->emplace(name, std::move(child)); +} + +optional_ptr ShreddingType::GetChild(const string &name) const { + auto it = children.types->find(name); + if (it == children.types->end()) { + return nullptr; + } + return it->second; +} + +ShreddingType ShreddingType::GetShreddingTypes(const Value &val) { + if (val.type().id() != LogicalTypeId::VARCHAR) { + throw BinderException("SHREDDING value should be of type VARCHAR, a stringified type to use for the column"); + } + auto type_str = val.GetValue(); + auto logical_type = TransformStringToLogicalType(type_str); + + return ConvertShreddingTypeRecursive(logical_type); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_statistics.cpp 
b/src/duckdb/extension/parquet/parquet_statistics.cpp index 5f7d93718..27c5daacc 100644 --- a/src/duckdb/extension/parquet/parquet_statistics.cpp +++ b/src/duckdb/extension/parquet/parquet_statistics.cpp @@ -322,7 +322,6 @@ Value ParquetStatisticsUtils::ConvertValueInternal(const LogicalType &type, cons unique_ptr ParquetStatisticsUtils::TransformColumnStatistics(const ParquetColumnSchema &schema, const vector &columns, bool can_have_nan) { - // Not supported types auto &type = schema.type; if (type.id() == LogicalTypeId::ARRAY || type.id() == LogicalTypeId::MAP || type.id() == LogicalTypeId::LIST) { @@ -395,26 +394,71 @@ unique_ptr ParquetStatisticsUtils::TransformColumnStatistics(con } break; case LogicalTypeId::VARCHAR: { - auto string_stats = StringStats::CreateEmpty(type); + auto string_stats = StringStats::CreateUnknown(type); if (parquet_stats.__isset.min_value) { StringColumnReader::VerifyString(parquet_stats.min_value.c_str(), parquet_stats.min_value.size(), true); - StringStats::Update(string_stats, parquet_stats.min_value); + StringStats::SetMin(string_stats, parquet_stats.min_value); } else if (parquet_stats.__isset.min) { StringColumnReader::VerifyString(parquet_stats.min.c_str(), parquet_stats.min.size(), true); - StringStats::Update(string_stats, parquet_stats.min); + StringStats::SetMin(string_stats, parquet_stats.min); } if (parquet_stats.__isset.max_value) { StringColumnReader::VerifyString(parquet_stats.max_value.c_str(), parquet_stats.max_value.size(), true); - StringStats::Update(string_stats, parquet_stats.max_value); + StringStats::SetMax(string_stats, parquet_stats.max_value); } else if (parquet_stats.__isset.max) { StringColumnReader::VerifyString(parquet_stats.max.c_str(), parquet_stats.max.size(), true); - StringStats::Update(string_stats, parquet_stats.max); + StringStats::SetMax(string_stats, parquet_stats.max); } - StringStats::SetContainsUnicode(string_stats); - StringStats::ResetMaxStringLength(string_stats); row_group_stats = string_stats.ToUnique(); break; } + case LogicalTypeId::GEOMETRY: { + auto geo_stats = GeometryStats::CreateUnknown(type); + if (column_chunk.meta_data.__isset.geospatial_statistics) { + if (column_chunk.meta_data.geospatial_statistics.__isset.bbox) { + auto &bbox = column_chunk.meta_data.geospatial_statistics.bbox; + auto &stats_bbox = GeometryStats::GetExtent(geo_stats); + + // xmin > xmax is allowed if the geometry crosses the antimeridian, + // but we don't handle this right now + if (bbox.xmin <= bbox.xmax) { + stats_bbox.x_min = bbox.xmin; + stats_bbox.x_max = bbox.xmax; + } + + if (bbox.ymin <= bbox.ymax) { + stats_bbox.y_min = bbox.ymin; + stats_bbox.y_max = bbox.ymax; + } + + if (bbox.__isset.zmin && bbox.__isset.zmax && bbox.zmin <= bbox.zmax) { + stats_bbox.z_min = bbox.zmin; + stats_bbox.z_max = bbox.zmax; + } + + if (bbox.__isset.mmin && bbox.__isset.mmax && bbox.mmin <= bbox.mmax) { + stats_bbox.m_min = bbox.mmin; + stats_bbox.m_max = bbox.mmax; + } + } + if (column_chunk.meta_data.geospatial_statistics.__isset.geospatial_types) { + auto &types = column_chunk.meta_data.geospatial_statistics.geospatial_types; + auto &stats_types = GeometryStats::GetTypes(geo_stats); + + // if types are set but empty, that still means "any type" - so we leave stats_types as-is (unknown) + // otherwise, clear and set to the actual types + + if (!types.empty()) { + stats_types.Clear(); + for (auto &geom_type : types) { + stats_types.AddWKBType(geom_type); + } + } + } + } + row_group_stats = geo_stats.ToUnique(); + break; + } default: // no 
stats for you break; @@ -580,7 +624,6 @@ bool ParquetStatisticsUtils::BloomFilterExcludes(const TableFilter &duckdb_filte } ParquetBloomFilter::ParquetBloomFilter(idx_t num_entries, double bloom_filter_false_positive_ratio) { - // aim for hit ratio of 0.01% // see http://tfk.mit.edu/pdf/bloom.pdf double f = bloom_filter_false_positive_ratio; diff --git a/src/duckdb/extension/parquet/parquet_writer.cpp b/src/duckdb/extension/parquet/parquet_writer.cpp index 2021335ad..2012a4884 100644 --- a/src/duckdb/extension/parquet/parquet_writer.cpp +++ b/src/duckdb/extension/parquet/parquet_writer.cpp @@ -3,6 +3,7 @@ #include "duckdb.hpp" #include "mbedtls_wrapper.hpp" #include "parquet_crypto.hpp" +#include "parquet_shredding.hpp" #include "parquet_timestamp.hpp" #include "resizable_buffer.hpp" #include "duckdb/common/file_system.hpp" @@ -35,29 +36,6 @@ using duckdb_parquet::PageType; using ParquetRowGroup = duckdb_parquet::RowGroup; using duckdb_parquet::Type; -ChildFieldIDs::ChildFieldIDs() : ids(make_uniq>()) { -} - -ChildFieldIDs ChildFieldIDs::Copy() const { - ChildFieldIDs result; - for (const auto &id : *ids) { - result.ids->emplace(id.first, id.second.Copy()); - } - return result; -} - -FieldID::FieldID() : set(false) { -} - -FieldID::FieldID(int32_t field_id_p) : set(true), field_id(field_id_p) { -} - -FieldID FieldID::Copy() const { - auto result = set ? FieldID(field_id) : FieldID(); - result.child_field_ids = child_field_ids.Copy(); - return result; -} - class MyTransport : public TTransport { public: explicit MyTransport(WriteStream &serializer) : serializer(serializer) { @@ -109,6 +87,7 @@ bool ParquetWriter::TryGetParquetType(const LogicalType &duckdb_type, optional_p case LogicalTypeId::ENUM: case LogicalTypeId::BLOB: case LogicalTypeId::VARCHAR: + case LogicalTypeId::GEOMETRY: parquet_type = Type::BYTE_ARRAY; break; case LogicalTypeId::TIME: @@ -166,7 +145,8 @@ Type::type ParquetWriter::DuckDBTypeToParquetType(const LogicalType &duckdb_type throw NotImplementedException("Unimplemented type for Parquet \"%s\"", duckdb_type.ToString()); } -void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele) { +void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele, + bool allow_geometry) { if (duckdb_type.IsJSONType()) { schema_ele.converted_type = ConvertedType::JSON; schema_ele.__isset.converted_type = true; @@ -174,13 +154,6 @@ void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, duckdb_p schema_ele.logicalType.__set_JSON(duckdb_parquet::JsonType()); return; } - if (duckdb_type.GetAlias() == "WKB_BLOB") { - schema_ele.__isset.logicalType = true; - schema_ele.logicalType.__isset.GEOMETRY = true; - // TODO: Set CRS in the future - schema_ele.logicalType.GEOMETRY.__isset.crs = false; - return; - } switch (duckdb_type.id()) { case LogicalTypeId::TINYINT: schema_ele.converted_type = ConvertedType::INT_8; @@ -285,6 +258,13 @@ void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, duckdb_p schema_ele.logicalType.DECIMAL.precision = schema_ele.precision; schema_ele.logicalType.DECIMAL.scale = schema_ele.scale; break; + case LogicalTypeId::GEOMETRY: + if (allow_geometry) { // Don't set this if we write GeoParquet V1 + schema_ele.__isset.logicalType = true; + schema_ele.logicalType.__isset.GEOMETRY = true; + // TODO: Set CRS in the future + schema_ele.logicalType.GEOMETRY.__isset.crs = false; + } default: break; } @@ -336,9 +316,9 @@ struct 
ColumnStatsUnifier { bool can_have_nan = false; bool has_nan = false; - unique_ptr geo_stats; + unique_ptr geo_stats; - virtual void UnifyGeoStats(const GeometryStats &other) { + virtual void UnifyGeoStats(const GeometryStatsData &other) { } virtual void UnifyMinMax(const string &new_min, const string &new_max) = 0; @@ -352,19 +332,21 @@ class ParquetStatsAccumulator { ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file_name_p, vector types_p, vector names_p, CompressionCodec::type codec, ChildFieldIDs field_ids_p, - const vector> &kv_metadata, + ShreddingType shredding_types_p, const vector> &kv_metadata, shared_ptr encryption_config_p, optional_idx dictionary_size_limit_p, idx_t string_dictionary_page_size_limit_p, bool enable_bloom_filters_p, double bloom_filter_false_positive_ratio_p, - int64_t compression_level_p, bool debug_use_openssl_p, ParquetVersion parquet_version) + int64_t compression_level_p, bool debug_use_openssl_p, ParquetVersion parquet_version, + GeoParquetVersion geoparquet_version) : context(context), file_name(std::move(file_name_p)), sql_types(std::move(types_p)), column_names(std::move(names_p)), codec(codec), field_ids(std::move(field_ids_p)), - encryption_config(std::move(encryption_config_p)), dictionary_size_limit(dictionary_size_limit_p), + shredding_types(std::move(shredding_types_p)), encryption_config(std::move(encryption_config_p)), + dictionary_size_limit(dictionary_size_limit_p), string_dictionary_page_size_limit(string_dictionary_page_size_limit_p), enable_bloom_filters(enable_bloom_filters_p), bloom_filter_false_positive_ratio(bloom_filter_false_positive_ratio_p), compression_level(compression_level_p), - debug_use_openssl(debug_use_openssl_p), parquet_version(parquet_version), total_written(0), num_row_groups(0) { - + debug_use_openssl(debug_use_openssl_p), parquet_version(parquet_version), geoparquet_version(geoparquet_version), + total_written(0), num_row_groups(0) { // initialize the file writer writer = make_uniq(fs, file_name.c_str(), FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW); @@ -390,7 +372,7 @@ ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file protocol = tproto_factory.getProtocol(duckdb_base_std::make_shared(*writer)); file_meta_data.num_rows = 0; - file_meta_data.version = 1; + file_meta_data.version = UnsafeNumericCast(parquet_version); file_meta_data.__isset.created_by = true; file_meta_data.created_by = @@ -416,10 +398,13 @@ ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file auto &unique_names = column_names; VerifyUniqueNames(unique_names); + // V1 GeoParquet stores geometries as blobs, no logical type + auto allow_geometry = geoparquet_version != GeoParquetVersion::V1; + // construct the child schemas for (idx_t i = 0; i < sql_types.size(); i++) { - auto child_schema = - ColumnWriter::FillParquetSchema(file_meta_data.schema, sql_types[i], unique_names[i], &field_ids); + auto child_schema = ColumnWriter::FillParquetSchema(file_meta_data.schema, sql_types[i], unique_names[i], + allow_geometry, &field_ids, &shredding_types); column_schemas.push_back(std::move(child_schema)); } // now construct the writers based on the schemas @@ -459,7 +444,7 @@ void ParquetWriter::PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGro write_states.emplace_back(col_writers.back().get().InitializeWriteState(row_group)); } - for (auto &chunk : buffer.Chunks({column_ids})) { + for (auto &chunk : buffer.Chunks(column_ids)) { for (idx_t i = 
0; i < next; i++) { if (col_writers[i].get().HasAnalyze()) { col_writers[i].get().Analyze(*write_states[i], nullptr, chunk.data[i], chunk.size()); @@ -556,7 +541,7 @@ void ParquetWriter::FlushRowGroup(PreparedRowGroup &prepared) { row_group.__isset.total_compressed_size = true; if (encryption_config) { - auto row_group_ordinal = num_row_groups.load(); + const auto row_group_ordinal = file_meta_data.row_groups.size(); if (row_group_ordinal > std::numeric_limits::max()) { throw InvalidInputException("RowGroup ordinal exceeds 32767 when encryption enabled"); } @@ -577,6 +562,14 @@ void ParquetWriter::Flush(ColumnDataCollection &buffer) { return; } + // "total_written" is only used for the FILE_SIZE_BYTES flag, and only when threads are writing in parallel. + // We pre-emptively increase it here to try to reduce overshooting when many threads are writing in parallel. + // However, waiting for the exact value (PrepareRowGroup) takes too long, and would cause overshoots to happen. + // So, we guess the compression ratio. We guess 3x, but this will be off depending on the data. + // "total_written" is restored to the exact number of written bytes at the end of FlushRowGroup. + // PhysicalCopyToFile should be reworked to use prepare/flush batch separately for better accuracy. + total_written += buffer.SizeInBytes() / 2; + PreparedRowGroup prepared_row_group; PrepareRowGroup(buffer, prepared_row_group); buffer.Reset(); @@ -685,15 +678,13 @@ struct BlobStatsUnifier : public BaseStringStatsUnifier { }; struct GeoStatsUnifier : public ColumnStatsUnifier { - - void UnifyGeoStats(const GeometryStats &other) override { + void UnifyGeoStats(const GeometryStatsData &other) override { if (geo_stats) { - geo_stats->bbox.Combine(other.bbox); - geo_stats->types.Combine(other.types); + geo_stats->Merge(other); } else { // Make copy - geo_stats = make_uniq(); - geo_stats->bbox = other.bbox; + geo_stats = make_uniq(); + geo_stats->extent = other.extent; geo_stats->types = other.types; } } @@ -707,17 +698,17 @@ struct GeoStatsUnifier : public ColumnStatsUnifier { return string(); } - const auto &bbox = geo_stats->bbox; + const auto &bbox = geo_stats->extent; const auto &types = geo_stats->types; - const auto bbox_value = Value::STRUCT({{"xmin", bbox.xmin}, - {"xmax", bbox.xmax}, - {"ymin", bbox.ymin}, - {"ymax", bbox.ymax}, - {"zmin", bbox.zmin}, - {"zmax", bbox.zmax}, - {"mmin", bbox.mmin}, - {"mmax", bbox.mmax}}); + const auto bbox_value = Value::STRUCT({{"xmin", bbox.x_min}, + {"xmax", bbox.x_max}, + {"ymin", bbox.y_min}, + {"ymax", bbox.y_max}, + {"zmin", bbox.z_min}, + {"zmax", bbox.z_max}, + {"mmin", bbox.m_min}, + {"mmax", bbox.m_max}}); vector type_strings; for (const auto &type : types.ToString(true)) { @@ -810,11 +801,9 @@ static unique_ptr GetBaseStatsUnifier(const LogicalType &typ } } case LogicalTypeId::BLOB: - if (type.GetAlias() == "WKB_BLOB") { - return make_uniq(); - } else { - return make_uniq(); - } + return make_uniq(); + case LogicalTypeId::GEOMETRY: + return make_uniq(); case LogicalTypeId::VARCHAR: return make_uniq(); case LogicalTypeId::UUID: @@ -903,22 +892,24 @@ void ParquetWriter::GatherWrittenStatistics() { column_stats["has_nan"] = Value::BOOLEAN(stats_unifier->has_nan); } if (stats_unifier->geo_stats) { - const auto &bbox = stats_unifier->geo_stats->bbox; + const auto &bbox = stats_unifier->geo_stats->extent; const auto &types = stats_unifier->geo_stats->types; - column_stats["bbox_xmin"] = Value::DOUBLE(bbox.xmin); - column_stats["bbox_xmax"] = Value::DOUBLE(bbox.xmax); - 
column_stats["bbox_ymin"] = Value::DOUBLE(bbox.ymin); - column_stats["bbox_ymax"] = Value::DOUBLE(bbox.ymax); + if (bbox.HasXY()) { + column_stats["bbox_xmin"] = Value::DOUBLE(bbox.x_min); + column_stats["bbox_xmax"] = Value::DOUBLE(bbox.x_max); + column_stats["bbox_ymin"] = Value::DOUBLE(bbox.y_min); + column_stats["bbox_ymax"] = Value::DOUBLE(bbox.y_max); - if (bbox.HasZ()) { - column_stats["bbox_zmin"] = Value::DOUBLE(bbox.zmin); - column_stats["bbox_zmax"] = Value::DOUBLE(bbox.zmax); - } + if (bbox.HasZ()) { + column_stats["bbox_zmin"] = Value::DOUBLE(bbox.z_min); + column_stats["bbox_zmax"] = Value::DOUBLE(bbox.z_max); + } - if (bbox.HasM()) { - column_stats["bbox_mmin"] = Value::DOUBLE(bbox.mmin); - column_stats["bbox_mmax"] = Value::DOUBLE(bbox.mmax); + if (bbox.HasM()) { + column_stats["bbox_mmin"] = Value::DOUBLE(bbox.m_min); + column_stats["bbox_mmax"] = Value::DOUBLE(bbox.m_max); + } } if (!types.IsEmpty()) { @@ -934,7 +925,6 @@ void ParquetWriter::GatherWrittenStatistics() { } void ParquetWriter::Finalize() { - // dump the bloom filters right before footer, not if stuff is encrypted for (auto &bloom_filter_entry : bloom_filters) { @@ -975,7 +965,8 @@ void ParquetWriter::Finalize() { } // Add geoparquet metadata to the file metadata - if (geoparquet_data && GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context)) { + if (geoparquet_data && GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context) && + geoparquet_version != GeoParquetVersion::NONE) { geoparquet_data->Write(file_meta_data); } @@ -1005,7 +996,7 @@ void ParquetWriter::Finalize() { GeoParquetFileMetadata &ParquetWriter::GetGeoParquetData() { if (!geoparquet_data) { - geoparquet_data = make_uniq(); + geoparquet_data = make_uniq(geoparquet_version); } return *geoparquet_data; } diff --git a/src/duckdb/extension/parquet/reader/list_column_reader.cpp b/src/duckdb/extension/parquet/reader/list_column_reader.cpp index 0ff1be271..b291e1019 100644 --- a/src/duckdb/extension/parquet/reader/list_column_reader.cpp +++ b/src/duckdb/extension/parquet/reader/list_column_reader.cpp @@ -175,7 +175,6 @@ ListColumnReader::ListColumnReader(ParquetReader &reader, const ParquetColumnSch unique_ptr child_column_reader_p) : ColumnReader(reader, schema), child_column_reader(std::move(child_column_reader_p)), read_cache(reader.allocator, ListType::GetChildType(Type())), read_vector(read_cache), overflow_child_count(0) { - child_defines.resize(reader.allocator, STANDARD_VECTOR_SIZE); child_repeats.resize(reader.allocator, STANDARD_VECTOR_SIZE); child_defines_ptr = (uint8_t *)child_defines.ptr; diff --git a/src/duckdb/extension/parquet/reader/string_column_reader.cpp b/src/duckdb/extension/parquet/reader/string_column_reader.cpp index 6b2a3db6d..019abd71a 100644 --- a/src/duckdb/extension/parquet/reader/string_column_reader.cpp +++ b/src/duckdb/extension/parquet/reader/string_column_reader.cpp @@ -9,7 +9,7 @@ namespace duckdb { // String Column Reader //===--------------------------------------------------------------------===// StringColumnReader::StringColumnReader(ParquetReader &reader, const ParquetColumnSchema &schema) - : ColumnReader(reader, schema) { + : ColumnReader(reader, schema), string_column_type(GetStringColumnType(Type())) { fixed_width_string_length = 0; if (schema.parquet_type == Type::FIXED_LEN_BYTE_ARRAY) { fixed_width_string_length = schema.type_length; @@ -26,13 +26,26 @@ void StringColumnReader::VerifyString(const char *str_data, uint32_t str_len, co size_t pos; auto utf_type = Utf8Proc::Analyze(str_data, 
str_len, &reason, &pos); if (utf_type == UnicodeType::INVALID) { - throw InvalidInputException("Invalid string encoding found in Parquet file: value \"" + - Blob::ToString(string_t(str_data, str_len)) + "\" is not valid UTF8!"); + throw InvalidInputException("Invalid string encoding found in Parquet file: value \"%s\" is not valid UTF8!", + Blob::ToString(string_t(str_data, str_len))); } } -void StringColumnReader::VerifyString(const char *str_data, uint32_t str_len) { - VerifyString(str_data, str_len, Type().id() == LogicalTypeId::VARCHAR); +void StringColumnReader::VerifyString(const char *str_data, uint32_t str_len) const { + switch (string_column_type) { + case StringColumnType::VARCHAR: + VerifyString(str_data, str_len, true); + break; + case StringColumnType::JSON: { + const auto error = StringUtil::ValidateJSON(str_data, str_len); + if (!error.empty()) { + throw InvalidInputException("Invalid JSON found in Parquet file: %s", error); + } + break; + } + default: + break; + } } class ParquetStringVectorBuffer : public VectorBuffer { diff --git a/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp b/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp index eacff5501..0388da0b3 100644 --- a/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp +++ b/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp @@ -15,7 +15,7 @@ static constexpr uint8_t VERSION_MASK = 0xF; static constexpr uint8_t SORTED_STRINGS_MASK = 0x1; static constexpr uint8_t SORTED_STRINGS_SHIFT = 4; static constexpr uint8_t OFFSET_SIZE_MINUS_ONE_MASK = 0x3; -static constexpr uint8_t OFFSET_SIZE_MINUS_ONE_SHIFT = 5; +static constexpr uint8_t OFFSET_SIZE_MINUS_ONE_SHIFT = 6; static constexpr uint8_t BASIC_TYPE_MASK = 0x3; static constexpr uint8_t VALUE_HEADER_SHIFT = 2; @@ -74,8 +74,8 @@ VariantMetadata::VariantMetadata(const string_t &metadata) : metadata(metadata) const_data_ptr_t ptr = reinterpret_cast(metadata_data + sizeof(uint8_t)); idx_t dictionary_size = ReadVariableLengthLittleEndian(header.offset_size, ptr); - offsets = ptr; - bytes = offsets + ((dictionary_size + 1) * header.offset_size); + auto offsets = ptr; + auto bytes = offsets + ((dictionary_size + 1) * header.offset_size); idx_t last_offset = ReadVariableLengthLittleEndian(header.offset_size, ptr); for (idx_t i = 0; i < dictionary_size; i++) { auto next_offset = ReadVariableLengthLittleEndian(header.offset_size, ptr); @@ -140,8 +140,7 @@ hugeint_t DecodeDecimal(const_data_ptr_t data, uint8_t &scale, uint8_t &width) { return result; } -VariantValue VariantBinaryDecoder::PrimitiveTypeDecode(const VariantMetadata &metadata, - const VariantValueMetadata &value_metadata, +VariantValue VariantBinaryDecoder::PrimitiveTypeDecode(const VariantValueMetadata &value_metadata, const_data_ptr_t data) { switch (value_metadata.primitive_type) { case VariantPrimitiveType::NULL_TYPE: { @@ -267,8 +266,7 @@ VariantValue VariantBinaryDecoder::PrimitiveTypeDecode(const VariantMetadata &me } } -VariantValue VariantBinaryDecoder::ShortStringDecode(const VariantMetadata &metadata, - const VariantValueMetadata &value_metadata, +VariantValue VariantBinaryDecoder::ShortStringDecode(const VariantValueMetadata &value_metadata, const_data_ptr_t data) { D_ASSERT(value_metadata.string_size < 64); auto string_data = reinterpret_cast(data); @@ -348,10 +346,10 @@ VariantValue VariantBinaryDecoder::Decode(const VariantMetadata &variant_metadat data++; switch (value_metadata.basic_type) { case VariantBasicType::PRIMITIVE: 
{ - return PrimitiveTypeDecode(variant_metadata, value_metadata, data); + return PrimitiveTypeDecode(value_metadata, data); } case VariantBasicType::SHORT_STRING: { - return ShortStringDecode(variant_metadata, value_metadata, data); + return ShortStringDecode(value_metadata, data); } case VariantBasicType::OBJECT: { return ObjectDecode(variant_metadata, value_metadata, data); diff --git a/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp b/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp index 916e6e2cd..b96304d98 100644 --- a/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp +++ b/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp @@ -124,7 +124,7 @@ VariantValue ConvertShreddedValue::Convert(hugeint_t val) { template vector ConvertTypedValues(Vector &vec, Vector &metadata, Vector &blob, idx_t offset, idx_t length, - idx_t total_size) { + idx_t total_size, const bool is_field) { UnifiedVectorFormat metadata_format; metadata.ToUnifiedFormat(length, metadata_format); auto metadata_data = metadata_format.GetData(metadata_format); @@ -174,7 +174,12 @@ vector ConvertTypedValues(Vector &vec, Vector &metadata, Vector &b } else { ret[i] = OP::Convert(data[typed_index]); } - } else if (value_validity.RowIsValid(value_index)) { + } else { + if (is_field && !value_validity.RowIsValid(value_index)) { + //! Value is missing for this field + continue; + } + D_ASSERT(value_validity.RowIsValid(value_index)); auto metadata_value = metadata_data[metadata_format.sel->get_index(i)]; VariantMetadata variant_metadata(metadata_value); ret[i] = VariantBinaryDecoder::Decode(variant_metadata, @@ -187,7 +192,7 @@ vector ConvertTypedValues(Vector &vec, Vector &metadata, Vector &b vector VariantShreddedConversion::ConvertShreddedLeaf(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, idx_t length, - idx_t total_size) { + idx_t total_size, const bool is_field) { D_ASSERT(!typed_value.GetType().IsNested()); vector result; @@ -196,37 +201,37 @@ vector VariantShreddedConversion::ConvertShreddedLeaf(Vector &meta //! boolean case LogicalTypeId::BOOLEAN: { return ConvertTypedValues, LogicalTypeId::BOOLEAN>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! int8 case LogicalTypeId::TINYINT: { return ConvertTypedValues, LogicalTypeId::TINYINT>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! int16 case LogicalTypeId::SMALLINT: { return ConvertTypedValues, LogicalTypeId::SMALLINT>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! int32 case LogicalTypeId::INTEGER: { return ConvertTypedValues, LogicalTypeId::INTEGER>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! int64 case LogicalTypeId::BIGINT: { return ConvertTypedValues, LogicalTypeId::BIGINT>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! float case LogicalTypeId::FLOAT: { return ConvertTypedValues, LogicalTypeId::FLOAT>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! 
double case LogicalTypeId::DOUBLE: { return ConvertTypedValues, LogicalTypeId::DOUBLE>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! decimal4/decimal8/decimal16 case LogicalTypeId::DECIMAL: { @@ -234,15 +239,15 @@ vector VariantShreddedConversion::ConvertShreddedLeaf(Vector &meta switch (physical_type) { case PhysicalType::INT32: { return ConvertTypedValues, LogicalTypeId::DECIMAL>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } case PhysicalType::INT64: { return ConvertTypedValues, LogicalTypeId::DECIMAL>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } case PhysicalType::INT128: { return ConvertTypedValues, LogicalTypeId::DECIMAL>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } default: throw NotImplementedException("Decimal with PhysicalType (%s) not implemented for shredded Variant", @@ -252,42 +257,42 @@ vector VariantShreddedConversion::ConvertShreddedLeaf(Vector &meta //! date case LogicalTypeId::DATE: { return ConvertTypedValues, LogicalTypeId::DATE>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! time case LogicalTypeId::TIME: { return ConvertTypedValues, LogicalTypeId::TIME>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! timestamptz(6) (timestamptz(9) not implemented in DuckDB) case LogicalTypeId::TIMESTAMP_TZ: { return ConvertTypedValues, LogicalTypeId::TIMESTAMP_TZ>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! timestampntz(6) case LogicalTypeId::TIMESTAMP: { return ConvertTypedValues, LogicalTypeId::TIMESTAMP>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! timestampntz(9) case LogicalTypeId::TIMESTAMP_NS: { return ConvertTypedValues, LogicalTypeId::TIMESTAMP_NS>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! binary case LogicalTypeId::BLOB: { return ConvertTypedValues, LogicalTypeId::BLOB>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! string case LogicalTypeId::VARCHAR: { return ConvertTypedValues, LogicalTypeId::VARCHAR>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! 
uuid case LogicalTypeId::UUID: { return ConvertTypedValues, LogicalTypeId::UUID>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } default: throw NotImplementedException("Variant shredding on type: '%s' is not implemented", type.ToString()); @@ -395,7 +400,7 @@ static VariantValue ConvertPartiallyShreddedObject(vector vector VariantShreddedConversion::ConvertShreddedObject(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, idx_t length, - idx_t total_size) { + idx_t total_size, const bool is_field) { auto &type = typed_value.GetType(); D_ASSERT(type.id() == LogicalTypeId::STRUCT); auto &fields = StructType::GetChildTypes(type); @@ -445,7 +450,10 @@ vector VariantShreddedConversion::ConvertShreddedObject(Vector &me if (typed_validity.RowIsValid(typed_index)) { ret[i] = ConvertPartiallyShreddedObject(shredded_fields, metadata_format, value_format, i, offset); } else { - //! The value on this row is not an object, and guaranteed to be present + if (is_field && !validity.RowIsValid(value_index)) { + //! This object is a field in the parent object, the value is missing, skip it + continue; + } D_ASSERT(validity.RowIsValid(value_index)); auto &metadata_value = metadata_data[metadata_format.sel->get_index(i)]; VariantMetadata variant_metadata(metadata_value); @@ -463,7 +471,7 @@ vector VariantShreddedConversion::ConvertShreddedObject(Vector &me vector VariantShreddedConversion::ConvertShreddedArray(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, idx_t length, - idx_t total_size) { + idx_t total_size, const bool is_field) { auto &child = ListVector::GetEntry(typed_value); auto list_size = ListVector::GetListSize(typed_value); @@ -489,23 +497,26 @@ vector VariantShreddedConversion::ConvertShreddedArray(Vector &met //! We can be sure that none of the values are binary encoded for (idx_t i = 0; i < length; i++) { auto typed_index = list_format.sel->get_index(i + offset); - //! FIXME: next 4 lines duplicated below auto entry = list_data[typed_index]; Vector child_metadata(metadata.GetValue(i)); ret[i] = VariantValue(VariantValueType::ARRAY); - ret[i].array_items = Convert(child_metadata, child, entry.offset, entry.length, list_size); + ret[i].array_items = Convert(child_metadata, child, entry.offset, entry.length, list_size, false); } } else { for (idx_t i = 0; i < length; i++) { auto typed_index = list_format.sel->get_index(i + offset); auto value_index = value_format.sel->get_index(i + offset); if (validity.RowIsValid(typed_index)) { - //! FIXME: next 4 lines duplicate auto entry = list_data[typed_index]; Vector child_metadata(metadata.GetValue(i)); ret[i] = VariantValue(VariantValueType::ARRAY); - ret[i].array_items = Convert(child_metadata, child, entry.offset, entry.length, list_size); - } else if (value_validity.RowIsValid(value_index)) { + ret[i].array_items = Convert(child_metadata, child, entry.offset, entry.length, list_size, false); + } else { + if (is_field && !value_validity.RowIsValid(value_index)) { + //! 
Value is missing for this field + continue; + } + D_ASSERT(value_validity.RowIsValid(value_index)); auto metadata_value = metadata_data[metadata_format.sel->get_index(i)]; VariantMetadata variant_metadata(metadata_value); ret[i] = VariantBinaryDecoder::Decode(variant_metadata, @@ -547,11 +558,11 @@ vector VariantShreddedConversion::Convert(Vector &metadata, Vector auto &type = typed_value->GetType(); vector ret; if (type.id() == LogicalTypeId::STRUCT) { - return ConvertShreddedObject(metadata, *value, *typed_value, offset, length, total_size); + return ConvertShreddedObject(metadata, *value, *typed_value, offset, length, total_size, is_field); } else if (type.id() == LogicalTypeId::LIST) { - return ConvertShreddedArray(metadata, *value, *typed_value, offset, length, total_size); + return ConvertShreddedArray(metadata, *value, *typed_value, offset, length, total_size, is_field); } else { - return ConvertShreddedLeaf(metadata, *value, *typed_value, offset, length, total_size); + return ConvertShreddedLeaf(metadata, *value, *typed_value, offset, length, total_size, is_field); } } else { if (is_field) { diff --git a/src/duckdb/extension/parquet/reader/variant/variant_value.cpp b/src/duckdb/extension/parquet/reader/variant/variant_value.cpp index 0ac213469..6b3d290f4 100644 --- a/src/duckdb/extension/parquet/reader/variant/variant_value.cpp +++ b/src/duckdb/extension/parquet/reader/variant/variant_value.cpp @@ -1,4 +1,18 @@ #include "reader/variant/variant_value.hpp" +#include "duckdb/common/serializer/varint.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/datetime.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/types/variant.hpp" +#include "duckdb/common/hugeint.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" +#include "duckdb/common/string_map_set.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/function/cast/variant/to_variant_fwd.hpp" namespace duckdb { @@ -12,6 +26,560 @@ void VariantValue::AddItem(VariantValue &&val) { array_items.push_back(std::move(val)); } +static void InitializeOffsets(DataChunk &offsets, idx_t count) { + auto keys = variant::OffsetData::GetKeys(offsets); + auto children = variant::OffsetData::GetChildren(offsets); + auto values = variant::OffsetData::GetValues(offsets); + auto blob = variant::OffsetData::GetBlob(offsets); + for (idx_t i = 0; i < count; i++) { + keys[i] = 0; + children[i] = 0; + values[i] = 0; + blob[i] = 0; + } +} + +static void AnalyzeValue(const VariantValue &value, idx_t row, DataChunk &offsets) { + auto &keys_offset = variant::OffsetData::GetKeys(offsets)[row]; + auto &children_offset = variant::OffsetData::GetChildren(offsets)[row]; + auto &values_offset = variant::OffsetData::GetValues(offsets)[row]; + auto &data_offset = variant::OffsetData::GetBlob(offsets)[row]; + + values_offset++; + switch (value.value_type) { + case VariantValueType::OBJECT: { + //! Write the count of the children + auto &children = value.object_children; + data_offset += GetVarintSize(children.size()); + if (!children.empty()) { + //! 
Write the children offset + data_offset += GetVarintSize(children_offset); + children_offset += children.size(); + keys_offset += children.size(); + for (auto &child : children) { + auto &child_value = child.second; + AnalyzeValue(child_value, row, offsets); + } + } + break; + } + case VariantValueType::ARRAY: { + //! Write the count of the children + auto &children = value.array_items; + data_offset += GetVarintSize(children.size()); + if (!children.empty()) { + //! Write the children offset + data_offset += GetVarintSize(children_offset); + children_offset += children.size(); + for (auto &child : children) { + AnalyzeValue(child, row, offsets); + } + } + break; + } + case VariantValueType::PRIMITIVE: { + auto &primitive = value.primitive_value; + auto type_id = primitive.type().id(); + switch (type_id) { + case LogicalTypeId::BOOLEAN: + case LogicalTypeId::SQLNULL: { + break; + } + case LogicalTypeId::TINYINT: { + data_offset += sizeof(int8_t); + break; + } + case LogicalTypeId::SMALLINT: { + data_offset += sizeof(int16_t); + break; + } + case LogicalTypeId::INTEGER: { + data_offset += sizeof(int32_t); + break; + } + case LogicalTypeId::BIGINT: { + data_offset += sizeof(int64_t); + break; + } + case LogicalTypeId::DOUBLE: { + data_offset += sizeof(double); + break; + } + case LogicalTypeId::FLOAT: { + data_offset += sizeof(float); + break; + } + case LogicalTypeId::DATE: { + data_offset += sizeof(date_t); + break; + } + case LogicalTypeId::TIMESTAMP_TZ: { + data_offset += sizeof(timestamp_tz_t); + break; + } + case LogicalTypeId::TIMESTAMP: { + data_offset += sizeof(timestamp_t); + break; + } + case LogicalTypeId::TIME: { + data_offset += sizeof(dtime_t); + break; + } + case LogicalTypeId::TIMESTAMP_NS: { + data_offset += sizeof(timestamp_ns_t); + break; + } + case LogicalTypeId::UUID: { + data_offset += sizeof(hugeint_t); + break; + } + case LogicalTypeId::DECIMAL: { + auto &type = primitive.type(); + uint8_t width; + uint8_t scale; + type.GetDecimalProperties(width, scale); + + auto physical_type = type.InternalType(); + data_offset += GetVarintSize(width); + data_offset += GetVarintSize(scale); + switch (physical_type) { + case PhysicalType::INT32: { + data_offset += sizeof(int32_t); + break; + } + case PhysicalType::INT64: { + data_offset += sizeof(int64_t); + break; + } + case PhysicalType::INT128: { + data_offset += sizeof(hugeint_t); + break; + } + default: + throw InternalException("Unexpected physical type for Decimal value: %s", + EnumUtil::ToString(physical_type)); + } + break; + } + case LogicalTypeId::BLOB: + case LogicalTypeId::VARCHAR: { + auto string_data = primitive.GetValueUnsafe(); + data_offset += GetVarintSize(string_data.GetSize()); + data_offset += string_data.GetSize(); + break; + } + default: + throw InternalException("Encountered unrecognized LogicalType in VariantValue::AnalyzeValue: %s", + primitive.type().ToString()); + } + break; + } + default: + throw InternalException("VariantValueType not handled"); + } +} + +uint32_t GetOrCreateIndex(OrderedOwningStringMap &dictionary, const string_t &key) { + auto unsorted_idx = dictionary.size(); + //! 
This will later be remapped to the sorted idx (see FinalizeVariantKeys in 'to_variant.cpp') + return dictionary.emplace(std::make_pair(key, unsorted_idx)).first->second; +} + +static void ConvertValue(const VariantValue &value, VariantVectorData &result, idx_t row, DataChunk &offsets, + SelectionVector &keys_selvec, OrderedOwningStringMap &dictionary) { + auto blob_data = data_ptr_cast(result.blob_data[row].GetDataWriteable()); + auto keys_list_offset = result.keys_data[row].offset; + auto children_list_offset = result.children_data[row].offset; + auto values_list_offset = result.values_data[row].offset; + + auto &keys_offset = variant::OffsetData::GetKeys(offsets)[row]; + auto &children_offset = variant::OffsetData::GetChildren(offsets)[row]; + auto &values_offset = variant::OffsetData::GetValues(offsets)[row]; + auto &data_offset = variant::OffsetData::GetBlob(offsets)[row]; + + switch (value.value_type) { + case VariantValueType::OBJECT: { + //! Write the count of the children + auto &children = value.object_children; + + //! values + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::OBJECT); + result.byte_offset_data[values_list_offset + values_offset] = data_offset; + values_offset++; + + //! data + VarintEncode(children.size(), blob_data + data_offset); + data_offset += GetVarintSize(children.size()); + + if (!children.empty()) { + //! Write the children offset + VarintEncode(children_offset, blob_data + data_offset); + data_offset += GetVarintSize(children_offset); + + auto start_of_children = children_offset; + children_offset += children.size(); + + auto it = children.begin(); + for (idx_t i = 0; i < children.size(); i++) { + //! children + result.keys_index_data[children_list_offset + start_of_children + i] = keys_offset; + result.values_index_data[children_list_offset + start_of_children + i] = values_offset; + + auto &child = *it; + //! keys + auto &child_key = child.first; + auto dictionary_index = GetOrCreateIndex(dictionary, child_key); + keys_selvec.set_index(keys_list_offset + keys_offset, dictionary_index); + keys_offset++; + + auto &child_value = child.second; + ConvertValue(child_value, result, row, offsets, keys_selvec, dictionary); + it++; + } + } + break; + } + case VariantValueType::ARRAY: { + //! Write the count of the children + auto &children = value.array_items; + + //! values + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::ARRAY); + result.byte_offset_data[values_list_offset + values_offset] = data_offset; + values_offset++; + + //! data + VarintEncode(children.size(), blob_data + data_offset); + data_offset += GetVarintSize(children.size()); + + if (!children.empty()) { + //! Write the children offset + VarintEncode(children_offset, blob_data + data_offset); + data_offset += GetVarintSize(children_offset); + + auto start_of_children = children_offset; + children_offset += children.size(); + + for (idx_t i = 0; i < children.size(); i++) { + //! 
children + result.keys_index_validity.SetInvalid(children_list_offset + start_of_children + i); + result.values_index_data[children_list_offset + start_of_children + i] = values_offset; + + auto &child_value = children[i]; + ConvertValue(child_value, result, row, offsets, keys_selvec, dictionary); + } + } + break; + } + case VariantValueType::PRIMITIVE: { + auto &primitive = value.primitive_value; + auto type_id = primitive.type().id(); + result.byte_offset_data[values_list_offset + values_offset] = data_offset; + switch (type_id) { + case LogicalTypeId::BOOLEAN: { + if (primitive.GetValue()) { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::BOOL_TRUE); + } else { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::BOOL_FALSE); + } + break; + } + case LogicalTypeId::SQLNULL: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::VARIANT_NULL); + break; + } + case LogicalTypeId::TINYINT: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::INT8); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(int8_t); + break; + } + case LogicalTypeId::SMALLINT: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::INT16); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(int16_t); + break; + } + case LogicalTypeId::INTEGER: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::INT32); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(int32_t); + break; + } + case LogicalTypeId::BIGINT: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::INT64); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(int64_t); + break; + } + case LogicalTypeId::DOUBLE: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::DOUBLE); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(double); + break; + } + case LogicalTypeId::FLOAT: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::FLOAT); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(float); + break; + } + case LogicalTypeId::DATE: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::DATE); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(date_t); + break; + } + case LogicalTypeId::TIMESTAMP_TZ: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::TIMESTAMP_MICROS_TZ); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(timestamp_tz_t); + break; + } + case LogicalTypeId::TIMESTAMP: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::TIMESTAMP_MICROS); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(timestamp_t); + break; + } + case LogicalTypeId::TIME: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::TIME_MICROS); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(dtime_t); + break; + } + case LogicalTypeId::TIMESTAMP_NS: { + result.type_ids_data[values_list_offset + 
values_offset] = + static_cast(VariantLogicalType::TIMESTAMP_NANOS); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(timestamp_ns_t); + break; + } + case LogicalTypeId::UUID: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::UUID); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(hugeint_t); + break; + } + case LogicalTypeId::DECIMAL: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::DECIMAL); + auto &type = primitive.type(); + uint8_t width; + uint8_t scale; + type.GetDecimalProperties(width, scale); + + auto physical_type = type.InternalType(); + VarintEncode(width, blob_data + data_offset); + data_offset += GetVarintSize(width); + VarintEncode(scale, blob_data + data_offset); + data_offset += GetVarintSize(scale); + switch (physical_type) { + case PhysicalType::INT32: { + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(int32_t); + break; + } + case PhysicalType::INT64: { + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(int64_t); + break; + } + case PhysicalType::INT128: { + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(hugeint_t); + break; + } + default: + throw InternalException("Unexpected physical type for Decimal value: %s", + EnumUtil::ToString(physical_type)); + } + break; + } + case LogicalTypeId::BLOB: + case LogicalTypeId::VARCHAR: { + if (type_id == LogicalTypeId::BLOB) { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::BLOB); + } else { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::VARCHAR); + } + auto string_data = primitive.GetValueUnsafe(); + auto string_size = string_data.GetSize(); + VarintEncode(string_size, blob_data + data_offset); + data_offset += GetVarintSize(string_size); + memcpy(blob_data + data_offset, string_data.GetData(), string_size); + data_offset += string_size; + break; + } + default: + throw InternalException("Encountered unrecognized LogicalType in VariantValue::AnalyzeValue: %s", + primitive.type().ToString()); + } + values_offset++; + break; + } + default: + throw InternalException("VariantValueType not handled"); + } +} + +//! Copied and modified from 'to_variant.cpp' +static void InitializeVariants(DataChunk &offsets, Vector &result, SelectionVector &keys_selvec, idx_t &selvec_size) { + auto &keys = VariantVector::GetKeys(result); + auto keys_data = ListVector::GetData(keys); + + auto &children = VariantVector::GetChildren(result); + auto children_data = ListVector::GetData(children); + + auto &values = VariantVector::GetValues(result); + auto values_data = ListVector::GetData(values); + + auto &blob = VariantVector::GetData(result); + auto blob_data = FlatVector::GetData(blob); + + idx_t children_offset = 0; + idx_t values_offset = 0; + idx_t keys_offset = 0; + + auto keys_sizes = variant::OffsetData::GetKeys(offsets); + auto children_sizes = variant::OffsetData::GetChildren(offsets); + auto values_sizes = variant::OffsetData::GetValues(offsets); + auto blob_sizes = variant::OffsetData::GetBlob(offsets); + + auto count = offsets.size(); + for (idx_t i = 0; i < count; i++) { + auto &keys_entry = keys_data[i]; + auto &children_entry = children_data[i]; + auto &values_entry = values_data[i]; + + //! 
keys + keys_entry.length = keys_sizes[i]; + keys_entry.offset = keys_offset; + keys_offset += keys_entry.length; + + //! children + children_entry.length = children_sizes[i]; + children_entry.offset = children_offset; + children_offset += children_entry.length; + + //! values + values_entry.length = values_sizes[i]; + values_entry.offset = values_offset; + values_offset += values_entry.length; + + //! value + blob_data[i] = StringVector::EmptyString(blob, blob_sizes[i]); + } + + //! Reserve for the children of the lists + ListVector::Reserve(keys, keys_offset); + ListVector::Reserve(children, children_offset); + ListVector::Reserve(values, values_offset); + + //! Set list sizes + ListVector::SetListSize(keys, keys_offset); + ListVector::SetListSize(children, children_offset); + ListVector::SetListSize(values, values_offset); + + keys_selvec.Initialize(keys_offset); + selvec_size = keys_offset; +} + +void VariantValue::ToVARIANT(vector &input, Vector &result) { + auto count = input.size(); + if (input.empty()) { + return; + } + + //! Keep track of all the offsets for each row. + DataChunk analyze_offsets; + analyze_offsets.Initialize( + Allocator::DefaultAllocator(), + {LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER}, count); + analyze_offsets.SetCardinality(count); + InitializeOffsets(analyze_offsets, count); + + for (idx_t i = 0; i < count; i++) { + auto &value = input[i]; + if (value.IsNull()) { + continue; + } + AnalyzeValue(value, i, analyze_offsets); + } + + SelectionVector keys_selvec; + idx_t keys_selvec_size; + InitializeVariants(analyze_offsets, result, keys_selvec, keys_selvec_size); + + auto &keys = VariantVector::GetKeys(result); + auto &keys_entry = ListVector::GetEntry(keys); + OrderedOwningStringMap dictionary(StringVector::GetStringBuffer(keys_entry).GetStringAllocator()); + + DataChunk conversion_offsets; + conversion_offsets.Initialize( + Allocator::DefaultAllocator(), + {LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER}, count); + conversion_offsets.SetCardinality(count); + InitializeOffsets(conversion_offsets, count); + + VariantVectorData variant_data(result); + for (idx_t i = 0; i < count; i++) { + auto &value = input[i]; + if (value.IsNull()) { + FlatVector::SetNull(result, i, true); + continue; + } + ConvertValue(value, variant_data, i, conversion_offsets, keys_selvec, dictionary); + } + +#ifdef DEBUG + { + auto conversion_keys_offset = variant::OffsetData::GetKeys(conversion_offsets); + auto conversion_children_offset = variant::OffsetData::GetChildren(conversion_offsets); + auto conversion_values_offset = variant::OffsetData::GetValues(conversion_offsets); + auto conversion_data_offset = variant::OffsetData::GetBlob(conversion_offsets); + + auto analyze_keys_offset = variant::OffsetData::GetKeys(analyze_offsets); + auto analyze_children_offset = variant::OffsetData::GetChildren(analyze_offsets); + auto analyze_values_offset = variant::OffsetData::GetValues(analyze_offsets); + auto analyze_data_offset = variant::OffsetData::GetBlob(analyze_offsets); + + for (idx_t i = 0; i < count; i++) { + D_ASSERT(conversion_keys_offset[i] == analyze_keys_offset[i]); + D_ASSERT(conversion_children_offset[i] == analyze_children_offset[i]); + D_ASSERT(conversion_values_offset[i] == analyze_values_offset[i]); + D_ASSERT(conversion_data_offset[i] == analyze_data_offset[i]); + } + } + +#endif + + //! 
Finalize the 'data' column of the VARIANT + auto conversion_data_offsets = variant::OffsetData::GetBlob(conversion_offsets); + for (idx_t i = 0; i < count; i++) { + auto &data = variant_data.blob_data[i]; + data.SetSizeAndFinalize(conversion_data_offsets[i], conversion_data_offsets[i]); + } + + VariantUtils::FinalizeVariantKeys(result, dictionary, keys_selvec, keys_selvec_size); + + keys_entry.Slice(keys_selvec, keys_selvec_size); + keys_entry.Flatten(keys_selvec_size); + + if (input.size() == 1) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + result.Verify(count); +} + yyjson_mut_val *VariantValue::ToJSON(ClientContext &context, yyjson_mut_doc *doc) const { switch (value_type) { case VariantValueType::PRIMITIVE: { diff --git a/src/duckdb/extension/parquet/reader/variant_column_reader.cpp b/src/duckdb/extension/parquet/reader/variant_column_reader.cpp index 402bcbb07..635bfbbb5 100644 --- a/src/duckdb/extension/parquet/reader/variant_column_reader.cpp +++ b/src/duckdb/extension/parquet/reader/variant_column_reader.cpp @@ -11,7 +11,7 @@ VariantColumnReader::VariantColumnReader(ClientContext &context, ParquetReader & const ParquetColumnSchema &schema, vector> child_readers_p) : ColumnReader(reader, schema), context(context), child_readers(std::move(child_readers_p)) { - D_ASSERT(Type().InternalType() == PhysicalType::VARCHAR); + D_ASSERT(Type().InternalType() == PhysicalType::STRUCT); if (child_readers[0]->Schema().name == "metadata" && child_readers[1]->Schema().name == "value") { metadata_reader_idx = 0; @@ -80,10 +80,7 @@ idx_t VariantColumnReader::Read(uint64_t num_values, data_ptr_t define_out, data "The Variant column did not contain the same amount of values for 'metadata' and 'value'"); } - auto result_data = FlatVector::GetData(result); - auto &result_validity = FlatVector::Validity(result); - - vector conversion_result; + vector intermediate; if (typed_value_reader) { auto typed_values = typed_value_reader->Read(num_values, define_out, repeat_out, *group_entries[1]); if (typed_values != value_values) { @@ -91,29 +88,9 @@ idx_t VariantColumnReader::Read(uint64_t num_values, data_ptr_t define_out, data "The shredded Variant column did not contain the same amount of values for 'typed_value' and 'value'"); } } - conversion_result = - VariantShreddedConversion::Convert(metadata_intermediate, intermediate_group, 0, num_values, num_values); - - for (idx_t i = 0; i < conversion_result.size(); i++) { - auto &variant = conversion_result[i]; - if (variant.IsNull()) { - result_validity.SetInvalid(i); - continue; - } - - //! 
Write the result to a string - VariantDecodeResult decode_result; - decode_result.doc = yyjson_mut_doc_new(nullptr); - auto json_val = variant.ToJSON(context, decode_result.doc); - - size_t len; - decode_result.data = - yyjson_mut_val_write_opts(json_val, YYJSON_WRITE_ALLOW_INF_AND_NAN, nullptr, &len, nullptr); - if (!decode_result.data) { - throw InvalidInputException("Could not serialize the JSON to string, yyjson failed"); - } - result_data[i] = StringVector::AddString(result, decode_result.data, static_cast(len)); - } + intermediate = + VariantShreddedConversion::Convert(metadata_intermediate, intermediate_group, 0, num_values, num_values, false); + VariantValue::ToVARIANT(intermediate, result); read_count = value_values; return read_count.GetIndex(); diff --git a/src/duckdb/extension/parquet/serialize_parquet.cpp b/src/duckdb/extension/parquet/serialize_parquet.cpp index aa5632077..6f12d5d89 100644 --- a/src/duckdb/extension/parquet/serialize_parquet.cpp +++ b/src/duckdb/extension/parquet/serialize_parquet.cpp @@ -7,7 +7,8 @@ #include "duckdb/common/serializer/deserializer.hpp" #include "parquet_reader.hpp" #include "parquet_crypto.hpp" -#include "parquet_writer.hpp" +#include "parquet_field_id.hpp" +#include "parquet_shredding.hpp" namespace duckdb { @@ -21,6 +22,16 @@ ChildFieldIDs ChildFieldIDs::Deserialize(Deserializer &deserializer) { return result; } +void ChildShreddingTypes::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault>(100, "types", types.operator*()); +} + +ChildShreddingTypes ChildShreddingTypes::Deserialize(Deserializer &deserializer) { + ChildShreddingTypes result; + deserializer.ReadPropertyWithDefault>(100, "types", result.types.operator*()); + return result; +} + void FieldID::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault(100, "set", set); serializer.WritePropertyWithDefault(101, "field_id", field_id); @@ -89,4 +100,18 @@ ParquetOptionsSerialization ParquetOptionsSerialization::Deserialize(Deserialize return result; } +void ShreddingType::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(100, "set", set); + serializer.WriteProperty(101, "type", type); + serializer.WriteProperty(102, "children", children); +} + +ShreddingType ShreddingType::Deserialize(Deserializer &deserializer) { + ShreddingType result; + deserializer.ReadPropertyWithDefault(100, "set", result.set); + deserializer.ReadProperty(101, "type", result.type); + deserializer.ReadProperty(102, "children", result.children); + return result; +} + } // namespace duckdb diff --git a/src/duckdb/extension/parquet/writer/array_column_writer.cpp b/src/duckdb/extension/parquet/writer/array_column_writer.cpp index 60284ff28..2a9c9a9d5 100644 --- a/src/duckdb/extension/parquet/writer/array_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/array_column_writer.cpp @@ -6,7 +6,7 @@ void ArrayColumnWriter::Analyze(ColumnWriterState &state_p, ColumnWriterState *p auto &state = state_p.Cast(); auto &array_child = ArrayVector::GetEntry(vector); auto array_size = ArrayType::GetSize(vector.GetType()); - child_writer->Analyze(*state.child_state, &state_p, array_child, array_size * count); + GetChildWriter().Analyze(*state.child_state, &state_p, array_child, array_size * count); } void ArrayColumnWriter::WriteArrayState(ListColumnWriterState &state, idx_t array_size, uint16_t first_repeat_level, @@ -35,10 +35,9 @@ void ArrayColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *p // write definition levels and 
repeats // the main difference between this and ListColumnWriter::Prepare is that we need to make sure to write out // repetition levels and definitions for the child elements of the array even if the array itself is NULL. - idx_t start = 0; idx_t vcount = parent ? parent->definition_levels.size() - state.parent_index : count; idx_t vector_index = 0; - for (idx_t i = start; i < vcount; i++) { + for (idx_t i = 0; i < vcount; i++) { idx_t parent_index = state.parent_index + i; if (parent && !parent->is_empty.empty() && parent->is_empty[parent_index]) { WriteArrayState(state, array_size, parent->repetition_levels[parent_index], @@ -63,14 +62,14 @@ void ArrayColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *p auto &array_child = ArrayVector::GetEntry(vector); // The elements of a single array should not span multiple Parquet pages // So, we force the entire vector to fit on a single page by setting "vector_can_span_multiple_pages=false" - child_writer->Prepare(*state.child_state, &state_p, array_child, count * array_size, false); + GetChildWriter().Prepare(*state.child_state, &state_p, array_child, count * array_size, false); } void ArrayColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t count) { auto &state = state_p.Cast(); auto array_size = ArrayType::GetSize(vector.GetType()); auto &array_child = ArrayVector::GetEntry(vector); - child_writer->Write(*state.child_state, array_child, count * array_size); + GetChildWriter().Write(*state.child_state, array_child, count * array_size); } } // namespace duckdb diff --git a/src/duckdb/extension/parquet/writer/list_column_writer.cpp b/src/duckdb/extension/parquet/writer/list_column_writer.cpp index 8fba00c23..b043a94bc 100644 --- a/src/duckdb/extension/parquet/writer/list_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/list_column_writer.cpp @@ -4,23 +4,23 @@ namespace duckdb { unique_ptr ListColumnWriter::InitializeWriteState(duckdb_parquet::RowGroup &row_group) { auto result = make_uniq(row_group, row_group.columns.size()); - result->child_state = child_writer->InitializeWriteState(row_group); + result->child_state = GetChildWriter().InitializeWriteState(row_group); return std::move(result); } bool ListColumnWriter::HasAnalyze() { - return child_writer->HasAnalyze(); + return GetChildWriter().HasAnalyze(); } void ListColumnWriter::Analyze(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count) { auto &state = state_p.Cast(); auto &list_child = ListVector::GetEntry(vector); auto list_count = ListVector::GetListSize(vector); - child_writer->Analyze(*state.child_state, &state_p, list_child, list_count); + GetChildWriter().Analyze(*state.child_state, &state_p, list_child, list_count); } void ListColumnWriter::FinalizeAnalyze(ColumnWriterState &state_p) { auto &state = state_p.Cast(); - child_writer->FinalizeAnalyze(*state.child_state); + GetChildWriter().FinalizeAnalyze(*state.child_state); } static idx_t GetConsecutiveChildList(Vector &list, Vector &result, idx_t offset, idx_t count) { @@ -114,12 +114,12 @@ void ListColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *pa auto child_length = GetConsecutiveChildList(vector, child_list, 0, count); // The elements of a single list should not span multiple Parquet pages // So, we force the entire vector to fit on a single page by setting "vector_can_span_multiple_pages=false" - child_writer->Prepare(*state.child_state, &state_p, child_list, child_length, false); + GetChildWriter().Prepare(*state.child_state, 
&state_p, child_list, child_length, false); } void ListColumnWriter::BeginWrite(ColumnWriterState &state_p) { auto &state = state_p.Cast(); - child_writer->BeginWrite(*state.child_state); + GetChildWriter().BeginWrite(*state.child_state); } void ListColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t count) { @@ -128,12 +128,17 @@ void ListColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t c auto &list_child = ListVector::GetEntry(vector); Vector child_list(list_child); auto child_length = GetConsecutiveChildList(vector, child_list, 0, count); - child_writer->Write(*state.child_state, child_list, child_length); + GetChildWriter().Write(*state.child_state, child_list, child_length); } void ListColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { auto &state = state_p.Cast(); - child_writer->FinalizeWrite(*state.child_state); + GetChildWriter().FinalizeWrite(*state.child_state); +} + +ColumnWriter &ListColumnWriter::GetChildWriter() { + D_ASSERT(child_writers.size() == 1); + return *child_writers[0]; } } // namespace duckdb diff --git a/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp b/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp index d3ebd7dfc..7c7050be2 100644 --- a/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp @@ -7,6 +7,9 @@ namespace duckdb { using duckdb_parquet::Encoding; using duckdb_parquet::PageType; +constexpr const idx_t PrimitiveColumnWriter::MAX_UNCOMPRESSED_PAGE_SIZE; +constexpr const idx_t PrimitiveColumnWriter::MAX_UNCOMPRESSED_DICT_PAGE_SIZE; + PrimitiveColumnWriter::PrimitiveColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path, bool can_have_nulls) : ColumnWriter(writer, column_schema, std::move(schema_path), can_have_nulls) { @@ -44,7 +47,7 @@ void PrimitiveColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterStat idx_t vcount = parent ? 
parent->definition_levels.size() - state.definition_levels.size() : count; idx_t parent_index = state.definition_levels.size(); auto &validity = FlatVector::Validity(vector); - HandleRepeatLevels(state, parent, count, MaxRepeat()); + HandleRepeatLevels(state, parent, count); HandleDefineLevels(state, parent, validity, count, MaxDefine(), MaxDefine() - 1); idx_t vector_index = 0; @@ -111,7 +114,7 @@ void PrimitiveColumnWriter::BeginWrite(ColumnWriterState &state_p) { hdr.type = PageType::DATA_PAGE; hdr.__isset.data_page_header = true; - hdr.data_page_header.num_values = UnsafeNumericCast(page_info.row_count); + hdr.data_page_header.num_values = NumericCast(page_info.row_count); hdr.data_page_header.encoding = GetEncoding(state); hdr.data_page_header.definition_level_encoding = Encoding::RLE; hdr.data_page_header.repetition_level_encoding = Encoding::RLE; @@ -304,12 +307,23 @@ void PrimitiveColumnWriter::SetParquetStatistics(PrimitiveColumnWriterState &sta } if (state.stats_state->HasGeoStats()) { - column_chunk.meta_data.__isset.geospatial_statistics = true; - state.stats_state->WriteGeoStats(column_chunk.meta_data.geospatial_statistics); + auto gpq_version = writer.GetGeoParquetVersion(); + + const auto has_real_stats = gpq_version == GeoParquetVersion::NONE || gpq_version == GeoParquetVersion::BOTH || + gpq_version == GeoParquetVersion::V2; + const auto has_json_stats = gpq_version == GeoParquetVersion::V1 || gpq_version == GeoParquetVersion::BOTH || + gpq_version == GeoParquetVersion::V2; - // Add the geospatial statistics to the extra GeoParquet metadata - writer.GetGeoParquetData().AddGeoParquetStats(column_schema.name, column_schema.type, - *state.stats_state->GetGeoStats()); + if (has_real_stats) { + // Write the parquet native geospatial statistics + column_chunk.meta_data.__isset.geospatial_statistics = true; + state.stats_state->WriteGeoStats(column_chunk.meta_data.geospatial_statistics); + } + if (has_json_stats) { + // Add the geospatial statistics to the extra GeoParquet metadata + writer.GetGeoParquetData().AddGeoParquetStats(column_schema.name, column_schema.type, + *state.stats_state->GetGeoStats()); + } } for (const auto &write_info : state.write_info) { diff --git a/src/duckdb/extension/parquet/writer/struct_column_writer.cpp b/src/duckdb/extension/parquet/writer/struct_column_writer.cpp index e65515ad5..c9b6bcf9d 100644 --- a/src/duckdb/extension/parquet/writer/struct_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/struct_column_writer.cpp @@ -67,7 +67,7 @@ void StructColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState * parent->is_empty.end()); } } - HandleRepeatLevels(state_p, parent, count, MaxRepeat()); + HandleRepeatLevels(state_p, parent, count); HandleDefineLevels(state_p, parent, validity, count, PARQUET_DEFINE_VALID, MaxDefine() - 1); auto &child_vectors = StructVector::GetEntries(vector); for (idx_t child_idx = 0; child_idx < child_writers.size(); child_idx++) { diff --git a/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp b/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp new file mode 100644 index 000000000..8d5c755ed --- /dev/null +++ b/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp @@ -0,0 +1,1206 @@ +#include "writer/variant_column_writer.hpp" +#include "duckdb/common/types/variant.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" +#include "reader/variant/variant_binary_decoder.hpp" +#include 
"parquet_shredding.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/types/uuid.hpp" + +namespace duckdb { + +static idx_t CalculateByteLength(idx_t value) { + if (value == 0) { + return 1; + } + auto value_data = reinterpret_cast(&value); + idx_t irrelevant_bytes = 0; + //! Check how many of the most significant bytes are 0 + for (idx_t i = sizeof(idx_t); i > 0 && value_data[i - 1] == 0; i--) { + irrelevant_bytes++; + } + return sizeof(idx_t) - irrelevant_bytes; +} + +static uint8_t EncodeMetadataHeader(idx_t byte_length) { + D_ASSERT(byte_length <= 4); + + uint8_t header_byte = 0; + //! Set 'version' to 1 + header_byte |= static_cast(1); + //! Set 'sorted_strings' to 1 + header_byte |= static_cast(1) << 4; + //! Set 'offset_size_minus_one' to byte_length-1 + header_byte |= (static_cast(byte_length) - 1) << 6; + +#ifdef DEBUG + auto decoded_header = VariantMetadataHeader::FromHeaderByte(header_byte); + D_ASSERT(decoded_header.offset_size == byte_length); +#endif + + return header_byte; +} + +static void CreateMetadata(UnifiedVariantVectorData &variant, Vector &metadata, idx_t count) { + auto &keys = variant.keys; + auto keys_data = variant.keys_data; + + //! NOTE: the parquet variant is limited to a max dictionary size of NumericLimits::Maximum() + //! Whereas we can have NumericLimits::Maximum() *per* string in DuckDB + auto metadata_data = FlatVector::GetData(metadata); + for (idx_t row = 0; row < count; row++) { + uint64_t dictionary_count = 0; + if (variant.RowIsValid(row)) { + auto list_entry = keys_data[keys.sel->get_index(row)]; + dictionary_count = list_entry.length; + } + idx_t dictionary_size = 0; + for (idx_t i = 0; i < dictionary_count; i++) { + auto &key = variant.GetKey(row, i); + dictionary_size += key.GetSize(); + } + if (dictionary_size >= NumericLimits::Maximum()) { + throw InvalidInputException("The total length of the dictionary exceeds a 4 byte value (uint32_t), failed " + "to export VARIANT to Parquet"); + } + + auto byte_length = CalculateByteLength(dictionary_size); + auto total_length = 1 + (byte_length * (dictionary_count + 2)) + dictionary_size; + + metadata_data[row] = StringVector::EmptyString(metadata, total_length); + auto &metadata_blob = metadata_data[row]; + auto metadata_blob_data = metadata_blob.GetDataWriteable(); + + metadata_blob_data[0] = EncodeMetadataHeader(byte_length); + memcpy(metadata_blob_data + 1, reinterpret_cast(&dictionary_count), byte_length); + + auto offset_ptr = metadata_blob_data + 1 + byte_length; + auto string_ptr = metadata_blob_data + 1 + byte_length + ((dictionary_count + 1) * byte_length); + idx_t total_offset = 0; + for (idx_t i = 0; i < dictionary_count; i++) { + memcpy(offset_ptr + (i * byte_length), reinterpret_cast(&total_offset), byte_length); + auto &key = variant.GetKey(row, i); + + memcpy(string_ptr + total_offset, key.GetData(), key.GetSize()); + total_offset += key.GetSize(); + } + memcpy(offset_ptr + (dictionary_count * byte_length), reinterpret_cast(&total_offset), byte_length); + D_ASSERT(offset_ptr + ((dictionary_count + 1) * byte_length) == string_ptr); + D_ASSERT(string_ptr + total_offset == metadata_blob_data + total_length); + metadata_blob.SetSizeAndFinalize(total_length, total_length); + +#ifdef DEBUG + auto decoded_metadata = VariantMetadata(metadata_blob); + D_ASSERT(decoded_metadata.strings.size() == dictionary_count); + for (idx_t i = 0; i < dictionary_count; i++) { + D_ASSERT(decoded_metadata.strings[i] == variant.GetKey(row, i).GetString()); + } +#endif + } +} + +namespace { + 
+static unordered_set GetVariantType(const LogicalType &type) { + if (type.id() == LogicalTypeId::ANY) { + return {}; + } + switch (type.id()) { + case LogicalTypeId::STRUCT: + return {VariantLogicalType::OBJECT}; + case LogicalTypeId::LIST: + return {VariantLogicalType::ARRAY}; + case LogicalTypeId::BOOLEAN: + return {VariantLogicalType::BOOL_TRUE, VariantLogicalType::BOOL_FALSE}; + case LogicalTypeId::TINYINT: + return {VariantLogicalType::INT8}; + case LogicalTypeId::SMALLINT: + return {VariantLogicalType::INT16}; + case LogicalTypeId::INTEGER: + return {VariantLogicalType::INT32}; + case LogicalTypeId::BIGINT: + return {VariantLogicalType::INT64}; + case LogicalTypeId::FLOAT: + return {VariantLogicalType::FLOAT}; + case LogicalTypeId::DOUBLE: + return {VariantLogicalType::DOUBLE}; + case LogicalTypeId::DECIMAL: + return {VariantLogicalType::DECIMAL}; + case LogicalTypeId::DATE: + return {VariantLogicalType::DATE}; + case LogicalTypeId::TIME: + return {VariantLogicalType::TIME_MICROS}; + case LogicalTypeId::TIMESTAMP_TZ: + return {VariantLogicalType::TIMESTAMP_MICROS_TZ}; + case LogicalTypeId::TIMESTAMP: + return {VariantLogicalType::TIMESTAMP_MICROS}; + case LogicalTypeId::TIMESTAMP_NS: + return {VariantLogicalType::TIMESTAMP_NANOS}; + case LogicalTypeId::BLOB: + return {VariantLogicalType::BLOB}; + case LogicalTypeId::VARCHAR: + return {VariantLogicalType::VARCHAR}; + case LogicalTypeId::UUID: + return {VariantLogicalType::UUID}; + default: + throw BinderException("Type '%s' can't be translated to a VARIANT type", type.ToString()); + } +} + +struct ShreddingState { +public: + explicit ShreddingState(const LogicalType &type, idx_t total_count) + : type(type), shredded_sel(total_count), values_index_sel(total_count), result_sel(total_count) { + variant_types = GetVariantType(type); + } + +public: + bool ValueIsShredded(UnifiedVariantVectorData &variant, idx_t row, idx_t values_index) { + auto type_id = variant.GetTypeId(row, values_index); + if (!variant_types.count(type_id)) { + return false; + } + if (type_id == VariantLogicalType::DECIMAL) { + auto physical_type = type.InternalType(); + auto decimal_data = VariantUtils::DecodeDecimalData(variant, row, values_index); + auto decimal_physical_type = decimal_data.GetPhysicalType(); + return physical_type == decimal_physical_type; + } + return true; + } + void SetShredded(idx_t row, idx_t values_index, idx_t result_idx) { + shredded_sel[count] = row; + values_index_sel[count] = values_index; + result_sel[count] = result_idx; + count++; + } + case_insensitive_string_set_t ObjectFields() { + D_ASSERT(type.id() == LogicalTypeId::STRUCT); + case_insensitive_string_set_t res; + auto &child_types = StructType::GetChildTypes(type); + for (auto &entry : child_types) { + auto &type = entry.first; + res.emplace(string_t(type.c_str(), type.size())); + } + return res; + } + +public: + //! The type the field is shredded on + const LogicalType &type; + unordered_set variant_types; + //! row that is shredded + SelectionVector shredded_sel; + //! 'values_index' of the shredded value + SelectionVector values_index_sel; + //! result row of the shredded value + SelectionVector result_sel; + //! 
The amount of rows that are shredded on + idx_t count = 0; +}; + +} // namespace + +vector GetChildIndices(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + optional_ptr shredding_state) { + vector child_indices; + if (!shredding_state || shredding_state->type.id() != LogicalTypeId::STRUCT) { + for (idx_t i = 0; i < nested_data.child_count; i++) { + child_indices.push_back(i); + } + return child_indices; + } + //! FIXME: The variant spec says that field names should be case-sensitive, not insensitive + case_insensitive_string_set_t shredded_fields = shredding_state->ObjectFields(); + + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto keys_index = variant.GetKeysIndex(row, i + nested_data.children_idx); + auto &key = variant.GetKey(row, keys_index); + + if (shredded_fields.count(key)) { + //! This field is shredded on, omit it from the value + continue; + } + child_indices.push_back(i); + } + return child_indices; +} + +static idx_t AnalyzeValueData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_index, + vector &offsets, optional_ptr shredding_state) { + idx_t total_size = 0; + //! Every value has at least a value header + total_size++; + + idx_t offset_size = offsets.size(); + VariantLogicalType type_id = VariantLogicalType::VARIANT_NULL; + if (variant.RowIsValid(row)) { + type_id = variant.GetTypeId(row, values_index); + } + switch (type_id) { + case VariantLogicalType::OBJECT: { + auto nested_data = VariantUtils::DecodeNestedData(variant, row, values_index); + + //! Calculate value and key offsets for all children + idx_t total_offset = 0; + uint32_t highest_keys_index = 0; + + auto child_indices = GetChildIndices(variant, row, nested_data, shredding_state); + if (nested_data.child_count && child_indices.empty()) { + //! All fields of the object are shredded, omit the object entirely + return 0; + } + + auto num_elements = child_indices.size(); + offsets.resize(offset_size + num_elements + 1); + + for (idx_t entry = 0; entry < child_indices.size(); entry++) { + auto i = child_indices[entry]; + auto keys_index = variant.GetKeysIndex(row, i + nested_data.children_idx); + auto values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); + offsets[offset_size + entry] = total_offset; + + total_offset += AnalyzeValueData(variant, row, values_index, offsets, nullptr); + highest_keys_index = MaxValue(highest_keys_index, keys_index); + } + offsets[offset_size + num_elements] = total_offset; + + //! Calculate the sizes for the objects value data + auto field_id_size = CalculateByteLength(highest_keys_index); + auto field_offset_size = CalculateByteLength(total_offset); + const bool is_large = num_elements > NumericLimits::Maximum(); + + //! 
Now add the sizes for the objects value data + if (is_large) { + total_size += sizeof(uint32_t); + } else { + total_size += sizeof(uint8_t); + } + total_size += num_elements * field_id_size; + total_size += (num_elements + 1) * field_offset_size; + total_size += total_offset; + break; + } + case VariantLogicalType::ARRAY: { + auto nested_data = VariantUtils::DecodeNestedData(variant, row, values_index); + + idx_t total_offset = 0; + offsets.resize(offset_size + nested_data.child_count + 1); + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); + offsets[offset_size + i] = total_offset; + + total_offset += AnalyzeValueData(variant, row, values_index, offsets, nullptr); + } + offsets[offset_size + nested_data.child_count] = total_offset; + + auto field_offset_size = CalculateByteLength(total_offset); + auto num_elements = nested_data.child_count; + const bool is_large = num_elements > NumericLimits::Maximum(); + + if (is_large) { + total_size += sizeof(uint32_t); + } else { + total_size += sizeof(uint8_t); + } + total_size += (num_elements + 1) * field_offset_size; + total_size += total_offset; + break; + } + case VariantLogicalType::BLOB: + case VariantLogicalType::VARCHAR: { + auto string_value = VariantUtils::DecodeStringData(variant, row, values_index); + total_size += string_value.GetSize(); + if (type_id == VariantLogicalType::BLOB || string_value.GetSize() > 64) { + //! Save as regular string value + total_size += sizeof(uint32_t); + } + break; + } + case VariantLogicalType::VARIANT_NULL: + case VariantLogicalType::BOOL_TRUE: + case VariantLogicalType::BOOL_FALSE: + break; + case VariantLogicalType::INT8: + total_size += sizeof(uint8_t); + break; + case VariantLogicalType::INT16: + total_size += sizeof(uint16_t); + break; + case VariantLogicalType::INT32: + total_size += sizeof(uint32_t); + break; + case VariantLogicalType::INT64: + total_size += sizeof(uint64_t); + break; + case VariantLogicalType::FLOAT: + total_size += sizeof(float); + break; + case VariantLogicalType::DOUBLE: + total_size += sizeof(double); + break; + case VariantLogicalType::DECIMAL: { + auto decimal_data = VariantUtils::DecodeDecimalData(variant, row, values_index); + total_size += 1; + if (decimal_data.width <= 9) { + total_size += sizeof(int32_t); + } else if (decimal_data.width <= 18) { + total_size += sizeof(int64_t); + } else if (decimal_data.width <= 38) { + total_size += sizeof(uhugeint_t); + } else { + throw InvalidInputException("Can't convert VARIANT DECIMAL(%d, %d) to Parquet VARIANT", decimal_data.width, + decimal_data.scale); + } + break; + } + case VariantLogicalType::UUID: + total_size += sizeof(uhugeint_t); + break; + case VariantLogicalType::DATE: + total_size += sizeof(uint32_t); + break; + case VariantLogicalType::TIME_MICROS: + case VariantLogicalType::TIMESTAMP_MICROS: + case VariantLogicalType::TIMESTAMP_NANOS: + case VariantLogicalType::TIMESTAMP_MICROS_TZ: + total_size += sizeof(uint64_t); + break; + case VariantLogicalType::INTERVAL: + case VariantLogicalType::BIGNUM: + case VariantLogicalType::BITSTRING: + case VariantLogicalType::TIMESTAMP_MILIS: + case VariantLogicalType::TIMESTAMP_SEC: + case VariantLogicalType::TIME_MICROS_TZ: + case VariantLogicalType::TIME_NANOS: + case VariantLogicalType::UINT8: + case VariantLogicalType::UINT16: + case VariantLogicalType::UINT32: + case VariantLogicalType::UINT64: + case VariantLogicalType::UINT128: + case VariantLogicalType::INT128: + default: + throw 
InvalidInputException("Can't convert VARIANT of type '%s' to Parquet VARIANT", + EnumUtil::ToString(type_id)); + } + + return total_size; +} + +template +void WritePrimitiveTypeHeader(data_ptr_t &value_data) { + uint8_t value_header = 0; + value_header |= static_cast(VariantBasicType::PRIMITIVE); + value_header |= static_cast(TYPE_ID) << 2; + + *value_data = value_header; + value_data++; +} + +template +void CopySimplePrimitiveData(const UnifiedVariantVectorData &variant, data_ptr_t &value_data, idx_t row, + uint32_t values_index) { + auto byte_offset = variant.GetByteOffset(row, values_index); + auto data = const_data_ptr_cast(variant.GetData(row).GetData()); + auto ptr = data + byte_offset; + memcpy(value_data, ptr, sizeof(T)); + value_data += sizeof(T); +} + +void CopyUUIDData(const UnifiedVariantVectorData &variant, data_ptr_t &value_data, idx_t row, uint32_t values_index) { + auto byte_offset = variant.GetByteOffset(row, values_index); + auto data = const_data_ptr_cast(variant.GetData(row).GetData()); + auto ptr = data + byte_offset; + + auto uuid = Load(ptr); + BaseUUID::ToBlob(uuid, value_data); + value_data += sizeof(uhugeint_t); +} + +static void WritePrimitiveValueData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_index, + data_ptr_t &value_data, const vector &offsets, idx_t &offset_index) { + VariantLogicalType type_id = VariantLogicalType::VARIANT_NULL; + if (variant.RowIsValid(row)) { + type_id = variant.GetTypeId(row, values_index); + } + + D_ASSERT(type_id != VariantLogicalType::OBJECT && type_id != VariantLogicalType::ARRAY); + switch (type_id) { + case VariantLogicalType::BLOB: + case VariantLogicalType::VARCHAR: { + auto string_value = VariantUtils::DecodeStringData(variant, row, values_index); + auto string_size = string_value.GetSize(); + if (type_id == VariantLogicalType::BLOB || string_size > 64) { + if (type_id == VariantLogicalType::BLOB) { + WritePrimitiveTypeHeader(value_data); + } else { + WritePrimitiveTypeHeader(value_data); + } + Store(string_size, value_data); + value_data += sizeof(uint32_t); + } else { + uint8_t value_header = 0; + value_header |= static_cast(VariantBasicType::SHORT_STRING); + value_header |= static_cast(string_size) << 2; + + *value_data = value_header; + value_data++; + } + memcpy(value_data, reinterpret_cast(string_value.GetData()), string_size); + value_data += string_size; + break; + } + case VariantLogicalType::VARIANT_NULL: + WritePrimitiveTypeHeader(value_data); + break; + case VariantLogicalType::BOOL_TRUE: + WritePrimitiveTypeHeader(value_data); + break; + case VariantLogicalType::BOOL_FALSE: + WritePrimitiveTypeHeader(value_data); + break; + case VariantLogicalType::INT8: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::INT16: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::INT32: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::INT64: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::FLOAT: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::DOUBLE: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, 
values_index); + break; + case VariantLogicalType::UUID: + WritePrimitiveTypeHeader(value_data); + CopyUUIDData(variant, value_data, row, values_index); + break; + case VariantLogicalType::DATE: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::TIME_MICROS: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::TIMESTAMP_MICROS: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::TIMESTAMP_NANOS: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::TIMESTAMP_MICROS_TZ: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::DECIMAL: { + auto decimal_data = VariantUtils::DecodeDecimalData(variant, row, values_index); + + if (decimal_data.width <= 4 || decimal_data.width > 38) { + throw InvalidInputException("Can't convert VARIANT DECIMAL(%d, %d) to Parquet VARIANT", decimal_data.width, + decimal_data.scale); + } else if (decimal_data.width <= 9) { + WritePrimitiveTypeHeader(value_data); + Store(decimal_data.scale, value_data); + value_data++; + memcpy(value_data, decimal_data.value_ptr, sizeof(int32_t)); + value_data += sizeof(int32_t); + } else if (decimal_data.width <= 18) { + WritePrimitiveTypeHeader(value_data); + Store(decimal_data.scale, value_data); + value_data++; + memcpy(value_data, decimal_data.value_ptr, sizeof(int64_t)); + value_data += sizeof(int64_t); + } else if (decimal_data.width <= 38) { + WritePrimitiveTypeHeader(value_data); + Store(decimal_data.scale, value_data); + value_data++; + memcpy(value_data, decimal_data.value_ptr, sizeof(hugeint_t)); + value_data += sizeof(hugeint_t); + } else { + throw InternalException( + "Uncovered VARIANT(DECIMAL) -> Parquet VARIANT conversion for type 'DECIMAL(%d, %d)'", + decimal_data.width, decimal_data.scale); + } + break; + } + case VariantLogicalType::INTERVAL: + case VariantLogicalType::BIGNUM: + case VariantLogicalType::BITSTRING: + case VariantLogicalType::TIMESTAMP_MILIS: + case VariantLogicalType::TIMESTAMP_SEC: + case VariantLogicalType::TIME_MICROS_TZ: + case VariantLogicalType::TIME_NANOS: + case VariantLogicalType::UINT8: + case VariantLogicalType::UINT16: + case VariantLogicalType::UINT32: + case VariantLogicalType::UINT64: + case VariantLogicalType::UINT128: + case VariantLogicalType::INT128: + default: + throw InvalidInputException("Can't convert VARIANT of type '%s' to Parquet VARIANT", + EnumUtil::ToString(type_id)); + } +} + +static void WriteValueData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_index, + data_ptr_t &value_data, const vector &offsets, idx_t &offset_index, + optional_ptr shredding_state) { + VariantLogicalType type_id = VariantLogicalType::VARIANT_NULL; + if (variant.RowIsValid(row)) { + type_id = variant.GetTypeId(row, values_index); + } + if (type_id == VariantLogicalType::OBJECT) { + auto nested_data = VariantUtils::DecodeNestedData(variant, row, values_index); + + //! 
-- Object value header -- + + auto child_indices = GetChildIndices(variant, row, nested_data, shredding_state); + if (nested_data.child_count && child_indices.empty()) { + throw InternalException( + "The entire should be omitted, should have been handled by the Analyze step already"); + } + auto num_elements = child_indices.size(); + + //! Determine the 'field_id_size' + uint32_t highest_keys_index = 0; + for (auto &i : child_indices) { + auto keys_index = variant.GetKeysIndex(row, i + nested_data.children_idx); + highest_keys_index = MaxValue(highest_keys_index, keys_index); + } + auto field_id_size = CalculateByteLength(highest_keys_index); + + uint32_t last_offset = 0; + if (num_elements) { + last_offset = offsets[offset_index + num_elements]; + } + offset_index += num_elements + 1; + auto field_offset_size = CalculateByteLength(last_offset); + + const bool is_large = num_elements > NumericLimits::Maximum(); + + uint8_t value_header = 0; + value_header |= static_cast(VariantBasicType::OBJECT); + value_header |= static_cast(is_large) << 6; + value_header |= (static_cast(field_id_size) - 1) << 4; + value_header |= (static_cast(field_offset_size) - 1) << 2; + +#ifdef DEBUG + auto object_value_header = VariantValueMetadata::FromHeaderByte(value_header); + D_ASSERT(object_value_header.basic_type == VariantBasicType::OBJECT); + D_ASSERT(object_value_header.is_large == is_large); + D_ASSERT(object_value_header.field_offset_size == field_offset_size); + D_ASSERT(object_value_header.field_id_size == field_id_size); +#endif + + *value_data = value_header; + value_data++; + + //! Write the 'num_elements' + if (is_large) { + Store(static_cast(num_elements), value_data); + value_data += sizeof(uint32_t); + } else { + Store(static_cast(num_elements), value_data); + value_data += sizeof(uint8_t); + } + + //! Write the 'field_id' entries + for (auto &i : child_indices) { + auto keys_index = variant.GetKeysIndex(row, i + nested_data.children_idx); + memcpy(value_data, reinterpret_cast(&keys_index), field_id_size); + value_data += field_id_size; + } + + //! Write the 'field_offset' entries and the child 'value's + auto children_ptr = value_data + ((num_elements + 1) * field_offset_size); + idx_t total_offset = 0; + for (auto &i : child_indices) { + auto values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); + + memcpy(value_data, reinterpret_cast(&total_offset), field_offset_size); + value_data += field_offset_size; + auto start_ptr = children_ptr; + WriteValueData(variant, row, values_index, children_ptr, offsets, offset_index, nullptr); + total_offset += (children_ptr - start_ptr); + } + memcpy(value_data, reinterpret_cast(&total_offset), field_offset_size); + value_data += field_offset_size; + D_ASSERT(children_ptr - total_offset == value_data); + value_data = children_ptr; + } else if (type_id == VariantLogicalType::ARRAY) { + auto nested_data = VariantUtils::DecodeNestedData(variant, row, values_index); + + //! 
-- Array value header -- + + uint32_t last_offset = 0; + if (nested_data.child_count) { + last_offset = offsets[offset_index + nested_data.child_count]; + } + offset_index += nested_data.child_count + 1; + auto field_offset_size = CalculateByteLength(last_offset); + + auto num_elements = nested_data.child_count; + const bool is_large = num_elements > NumericLimits::Maximum(); + + uint8_t value_header = 0; + value_header |= static_cast(VariantBasicType::ARRAY); + value_header |= static_cast(is_large) << 4; + value_header |= (static_cast(field_offset_size) - 1) << 2; + +#ifdef DEBUG + auto array_value_header = VariantValueMetadata::FromHeaderByte(value_header); + D_ASSERT(array_value_header.basic_type == VariantBasicType::ARRAY); + D_ASSERT(array_value_header.is_large == is_large); + D_ASSERT(array_value_header.field_offset_size == field_offset_size); +#endif + + *value_data = value_header; + value_data++; + + //! Write the 'num_elements' + if (is_large) { + Store(static_cast(num_elements), value_data); + value_data += sizeof(uint32_t); + } else { + Store(static_cast(num_elements), value_data); + value_data += sizeof(uint8_t); + } + + //! Write the 'field_offset' entries and the child 'value's + auto children_ptr = value_data + ((num_elements + 1) * field_offset_size); + idx_t total_offset = 0; + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); + + memcpy(value_data, reinterpret_cast(&total_offset), field_offset_size); + value_data += field_offset_size; + auto start_ptr = children_ptr; + WriteValueData(variant, row, values_index, children_ptr, offsets, offset_index, nullptr); + total_offset += (children_ptr - start_ptr); + } + memcpy(value_data, reinterpret_cast(&total_offset), field_offset_size); + value_data += field_offset_size; + D_ASSERT(children_ptr - total_offset == value_data); + value_data = children_ptr; + } else { + WritePrimitiveValueData(variant, row, values_index, value_data, offsets, offset_index); + } +} + +static void CreateValues(UnifiedVariantVectorData &variant, Vector &value, optional_ptr sel, + optional_ptr value_index_sel, + optional_ptr result_sel, optional_ptr shredding_state, + idx_t count) { + auto &validity = FlatVector::Validity(value); + auto value_data = FlatVector::GetData(value); + + for (idx_t i = 0; i < count; i++) { + idx_t value_index = 0; + if (value_index_sel) { + value_index = value_index_sel->get_index(i); + } + + idx_t row = i; + if (sel) { + row = sel->get_index(i); + } + + idx_t result_index = i; + if (result_sel) { + result_index = result_sel->get_index(i); + } + + bool is_shredded = false; + if (variant.RowIsValid(row) && shredding_state && shredding_state->ValueIsShredded(variant, row, value_index)) { + shredding_state->SetShredded(row, value_index, result_index); + is_shredded = true; + if (shredding_state->type.id() != LogicalTypeId::STRUCT) { + //! Value is shredded, directly write a NULL to the 'value' if the type is not an OBJECT + //! When the type is OBJECT, all excess fields would still need to be written to the 'value' + validity.SetInvalid(result_index); + continue; + } + } + + //! The (relative) offsets for each value, in the case of nesting + vector offsets; + //! Determine the size of this 'value' blob + idx_t blob_length = AnalyzeValueData(variant, row, value_index, offsets, shredding_state); + if (!blob_length) { + //! This is only allowed to happen for a shredded OBJECT, where there are no excess fields to write for the + //! 
OBJECT + (void)is_shredded; + D_ASSERT(is_shredded); + validity.SetInvalid(result_index); + continue; + } + value_data[result_index] = StringVector::EmptyString(value, blob_length); + auto &value_blob = value_data[result_index]; + auto value_blob_data = reinterpret_cast(value_blob.GetDataWriteable()); + + idx_t offset_index = 0; + WriteValueData(variant, row, value_index, value_blob_data, offsets, offset_index, shredding_state); + D_ASSERT(data_ptr_cast(value_blob.GetDataWriteable() + blob_length) == value_blob_data); + value_blob.SetSizeAndFinalize(blob_length, blob_length); + } +} + +//! fwd-declare static method +static void WriteVariantValues(UnifiedVariantVectorData &variant, Vector &result, + optional_ptr sel, + optional_ptr value_index_sel, + optional_ptr result_sel, idx_t count); + +static void WriteTypedObjectValues(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count) { + auto &type = result.GetType(); + D_ASSERT(type.id() == LogicalTypeId::STRUCT); + + auto &validity = FlatVector::Validity(result); + (void)validity; + + //! Collect the nested data for the objects + auto nested_data = make_unsafe_uniq_array_uninitialized(count); + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + //! When we're shredding an object, the top-level struct of it should always be valid + D_ASSERT(validity.RowIsValid(result_sel[i])); + auto value_index = value_index_sel[i]; + D_ASSERT(variant.GetTypeId(row, value_index) == VariantLogicalType::OBJECT); + nested_data[i] = VariantUtils::DecodeNestedData(variant, row, value_index); + } + + auto &shredded_types = StructType::GetChildTypes(type); + auto &shredded_fields = StructVector::GetEntries(result); + D_ASSERT(shredded_types.size() == shredded_fields.size()); + + SelectionVector child_values_indexes; + SelectionVector child_row_sel; + SelectionVector child_result_sel; + child_values_indexes.Initialize(count); + child_row_sel.Initialize(count); + child_result_sel.Initialize(count); + + for (idx_t child_idx = 0; child_idx < shredded_types.size(); child_idx++) { + auto &child_vec = *shredded_fields[child_idx]; + D_ASSERT(child_vec.GetType() == shredded_types[child_idx].second); + + //! Prepare the path component to perform the lookup for + auto &key = shredded_types[child_idx].first; + VariantPathComponent path_component; + path_component.lookup_mode = VariantChildLookupMode::BY_KEY; + path_component.key = key; + + ValidityMask lookup_validity(count); + VariantUtils::FindChildValues(variant, path_component, sel, child_values_indexes, lookup_validity, + nested_data.get(), count); + + if (!lookup_validity.AllValid()) { + auto &child_variant_vectors = StructVector::GetEntries(child_vec); + + //! For some of the rows the field is missing, adjust the selection vector to exclude these rows. + idx_t child_count = 0; + for (idx_t i = 0; i < count; i++) { + if (!lookup_validity.RowIsValid(i)) { + //! The field is missing, set it to null + FlatVector::SetNull(*child_variant_vectors[0], result_sel[i], true); + if (child_variant_vectors.size() >= 2) { + FlatVector::SetNull(*child_variant_vectors[1], result_sel[i], true); + } + continue; + } + + child_row_sel[child_count] = sel[i]; + child_values_indexes[child_count] = child_values_indexes[i]; + child_result_sel[child_count] = result_sel[i]; + child_count++; + } + + if (child_count) { + //! 
If not all rows are missing this field, write the values for it + WriteVariantValues(variant, child_vec, child_row_sel, child_values_indexes, child_result_sel, + child_count); + } + } else { + WriteVariantValues(variant, child_vec, &sel, child_values_indexes, result_sel, count); + } + } +} + +static void WriteTypedArrayValues(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count) { + auto list_data = FlatVector::GetData(result); + + auto nested_data = make_unsafe_uniq_array_uninitialized(count); + + idx_t total_offset = 0; + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto value_index = value_index_sel[i]; + auto result_row = result_sel[i]; + + D_ASSERT(variant.GetTypeId(row, value_index) == VariantLogicalType::ARRAY); + nested_data[i] = VariantUtils::DecodeNestedData(variant, row, value_index); + + list_entry_t list_entry; + list_entry.length = nested_data[i].child_count; + list_entry.offset = total_offset; + list_data[result_row] = list_entry; + + total_offset += nested_data[i].child_count; + } + ListVector::Reserve(result, total_offset); + ListVector::SetListSize(result, total_offset); + + SelectionVector child_sel; + child_sel.Initialize(total_offset); + + SelectionVector child_value_index_sel; + child_value_index_sel.Initialize(total_offset); + + SelectionVector child_result_sel; + child_result_sel.Initialize(total_offset); + + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto result_row = result_sel[i]; + + auto &array_data = nested_data[i]; + auto &entry = list_data[result_row]; + for (idx_t j = 0; j < entry.length; j++) { + auto offset = entry.offset + j; + child_sel[offset] = row; + child_value_index_sel[offset] = variant.GetValuesIndex(row, array_data.children_idx + j); + child_result_sel[offset] = offset; + } + } + + auto &child_vector = ListVector::GetEntry(result); + WriteVariantValues(variant, child_vector, child_sel, child_value_index_sel, child_result_sel, total_offset); +} + +//! TODO: introduce a third selection vector, because we also need one to map to the result row to write +//! This becomes necessary when we introduce LISTs into the equation because lists are stored on the same VARIANT row, +//! 
but we're now going to write the flattened child vector +static void WriteShreddedPrimitive(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count, idx_t type_size) { + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto result_row = result_sel[i]; + auto value_index = value_index_sel[i]; + D_ASSERT(variant.RowIsValid(row)); + + auto byte_offset = variant.GetByteOffset(row, value_index); + auto &data = variant.GetData(row); + auto value_ptr = data.GetData(); + auto result_offset = type_size * result_row; + memcpy(result_data + result_offset, value_ptr + byte_offset, type_size); + } +} + +template +static void WriteShreddedDecimal(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count) { + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto result_row = result_sel[i]; + auto value_index = value_index_sel[i]; + D_ASSERT(variant.RowIsValid(row) && variant.GetTypeId(row, value_index) == VariantLogicalType::DECIMAL); + + auto decimal_data = VariantUtils::DecodeDecimalData(variant, row, value_index); + D_ASSERT(decimal_data.width <= DecimalWidth::max); + auto result_offset = sizeof(T) * result_row; + memcpy(result_data + result_offset, decimal_data.value_ptr, sizeof(T)); + } +} + +static void WriteShreddedString(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count) { + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto result_row = result_sel[i]; + auto value_index = value_index_sel[i]; + D_ASSERT(variant.RowIsValid(row) && (variant.GetTypeId(row, value_index) == VariantLogicalType::VARCHAR || + variant.GetTypeId(row, value_index) == VariantLogicalType::BLOB)); + + auto string_data = VariantUtils::DecodeStringData(variant, row, value_index); + result_data[result_row] = StringVector::AddStringOrBlob(result, string_data); + } +} + +static void WriteShreddedBoolean(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count) { + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto result_row = result_sel[i]; + auto value_index = value_index_sel[i]; + D_ASSERT(variant.RowIsValid(row)); + auto type_id = variant.GetTypeId(row, value_index); + D_ASSERT(type_id == VariantLogicalType::BOOL_FALSE || type_id == VariantLogicalType::BOOL_TRUE); + + result_data[result_row] = type_id == VariantLogicalType::BOOL_TRUE; + } +} + +static void WriteTypedPrimitiveValues(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count) { + auto &type = result.GetType(); + D_ASSERT(!type.IsNested()); + switch (type.id()) { + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::DATE: + case LogicalTypeId::TIME: + case LogicalTypeId::TIMESTAMP_TZ: + case 
LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::UUID: { + const auto physical_type = type.InternalType(); + WriteShreddedPrimitive(variant, result, sel, value_index_sel, result_sel, count, GetTypeIdSize(physical_type)); + break; + } + case LogicalTypeId::DECIMAL: { + const auto physical_type = type.InternalType(); + switch (physical_type) { + //! DECIMAL4 + case PhysicalType::INT32: + WriteShreddedDecimal(variant, result, sel, value_index_sel, result_sel, count); + break; + //! DECIMAL8 + case PhysicalType::INT64: + WriteShreddedDecimal(variant, result, sel, value_index_sel, result_sel, count); + break; + //! DECIMAL16 + case PhysicalType::INT128: + WriteShreddedDecimal(variant, result, sel, value_index_sel, result_sel, count); + break; + default: + throw InvalidInputException("Can't shred on column of type '%s'", type.ToString()); + } + break; + } + case LogicalTypeId::BLOB: + case LogicalTypeId::VARCHAR: { + WriteShreddedString(variant, result, sel, value_index_sel, result_sel, count); + break; + } + case LogicalTypeId::BOOLEAN: + WriteShreddedBoolean(variant, result, sel, value_index_sel, result_sel, count); + break; + default: + throw InvalidInputException("Can't shred on type: %s", type.ToString()); + } +} + +static void WriteTypedValues(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, idx_t count) { + auto &type = result.GetType(); + + if (type.id() == LogicalTypeId::STRUCT) { + //! Shredded OBJECT + WriteTypedObjectValues(variant, result, sel, value_index_sel, result_sel, count); + } else if (type.id() == LogicalTypeId::LIST) { + //! Shredded ARRAY + WriteTypedArrayValues(variant, result, sel, value_index_sel, result_sel, count); + } else { + //! Primitive types + WriteTypedPrimitiveValues(variant, result, sel, value_index_sel, result_sel, count); + } +} + +static void WriteVariantValues(UnifiedVariantVectorData &variant, Vector &result, + optional_ptr sel, + optional_ptr value_index_sel, + optional_ptr result_sel, idx_t count) { + optional_ptr value; + optional_ptr typed_value; + + auto &result_type = result.GetType(); + D_ASSERT(result_type.id() == LogicalTypeId::STRUCT); + auto &child_types = StructType::GetChildTypes(result_type); + auto &child_vectors = StructVector::GetEntries(result); + D_ASSERT(child_types.size() == child_vectors.size()); + for (idx_t i = 0; i < child_types.size(); i++) { + auto &name = child_types[i].first; + if (name == "value") { + value = child_vectors[i].get(); + } else if (name == "typed_value") { + typed_value = child_vectors[i].get(); + } + } + + if (typed_value) { + ShreddingState shredding_state(typed_value->GetType(), count); + CreateValues(variant, *value, sel, value_index_sel, result_sel, &shredding_state, count); + + SelectionVector null_values; + if (shredding_state.count) { + WriteTypedValues(variant, *typed_value, shredding_state.shredded_sel, shredding_state.values_index_sel, + shredding_state.result_sel, shredding_state.count); + //! 'shredding_state.result_sel' will always be a subset of 'result_sel', set the rows not in the subset to + //! NULL + idx_t sel_idx = 0; + for (idx_t i = 0; i < count; i++) { + auto original_index = result_sel ? result_sel->get_index(i) : i; + if (sel_idx < shredding_state.count && shredding_state.result_sel[sel_idx] == original_index) { + sel_idx++; + continue; + } + FlatVector::SetNull(*typed_value, original_index, true); + } + } else { + //! 
Nothing is shredded on, so set all rows of the typed_value to NULL + for (idx_t i = 0; i < count; i++) { + FlatVector::SetNull(*typed_value, result_sel ? result_sel->get_index(i) : i, true); + } + } + } else { + CreateValues(variant, *value, sel, value_index_sel, result_sel, nullptr, count); + } +} + +static void ToParquetVariant(DataChunk &input, ExpressionState &state, Vector &result) { + // DuckDB Variant: + // - keys = VARCHAR[] + // - children = STRUCT(keys_index UINTEGER, values_index UINTEGER)[] + // - values = STRUCT(type_id UTINYINT, byte_offset UINTEGER)[] + // - data = BLOB + + // Parquet VARIANT: + // - metadata = BLOB + // - value = BLOB + + auto &variant_vec = input.data[0]; + auto count = input.size(); + + RecursiveUnifiedVectorFormat recursive_format; + Vector::RecursiveToUnifiedFormat(variant_vec, count, recursive_format); + UnifiedVariantVectorData variant(recursive_format); + + auto &result_vectors = StructVector::GetEntries(result); + auto &metadata = *result_vectors[0]; + CreateMetadata(variant, metadata, count); + WriteVariantValues(variant, result, nullptr, nullptr, nullptr, count); + + if (input.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +LogicalType VariantColumnWriter::TransformTypedValueRecursive(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::STRUCT: { + //! Wrap all fields of the struct in a struct with 'value' and 'typed_value' fields + auto &child_types = StructType::GetChildTypes(type); + child_list_t replaced_types; + for (auto &entry : child_types) { + child_list_t child_children; + child_children.emplace_back("value", LogicalType::BLOB); + if (entry.second.id() != LogicalTypeId::VARIANT) { + child_children.emplace_back("typed_value", TransformTypedValueRecursive(entry.second)); + } + replaced_types.emplace_back(entry.first, LogicalType::STRUCT(child_children)); + } + return LogicalType::STRUCT(replaced_types); + } + case LogicalTypeId::LIST: { + auto &child_type = ListType::GetChildType(type); + child_list_t replaced_types; + replaced_types.emplace_back("value", LogicalType::BLOB); + if (child_type.id() != LogicalTypeId::VARIANT) { + replaced_types.emplace_back("typed_value", TransformTypedValueRecursive(child_type)); + } + return LogicalType::LIST(LogicalType::STRUCT(replaced_types)); + } + case LogicalTypeId::UNION: + case LogicalTypeId::MAP: + case LogicalTypeId::VARIANT: + case LogicalTypeId::ARRAY: + throw BinderException("'%s' can't appear inside a 'typed_value' shredded type!", type.ToString()); + default: + return type; + } +} + +static LogicalType GetParquetVariantType(optional_ptr shredding = nullptr) { + child_list_t children; + children.emplace_back("metadata", LogicalType::BLOB); + children.emplace_back("value", LogicalType::BLOB); + if (shredding) { + children.emplace_back("typed_value", VariantColumnWriter::TransformTypedValueRecursive(*shredding)); + } + auto res = LogicalType::STRUCT(std::move(children)); + res.SetAlias("PARQUET_VARIANT"); + return res; +} + +static unique_ptr BindTransform(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (arguments.empty()) { + return nullptr; + } + auto type = ExpressionBinder::GetExpressionReturnType(*arguments[0]); + + if (arguments.size() == 2) { + auto &shredding = *arguments[1]; + auto expr_return_type = ExpressionBinder::GetExpressionReturnType(shredding); + expr_return_type = LogicalType::NormalizeType(expr_return_type); + if (expr_return_type.id() != LogicalTypeId::VARCHAR) { + throw
BinderException("Optional second argument 'shredding' has to be of type VARCHAR, i.e: " + "'STRUCT(my_field BOOLEAN)', found type: '%s' instead", + expr_return_type); + } + if (!shredding.IsFoldable()) { + throw BinderException("Optional second argument 'shredding' has to be a constant expression"); + } + Value type_str = ExpressionExecutor::EvaluateScalar(context, shredding); + if (type_str.IsNull()) { + throw BinderException("Optional second argument 'shredding' can not be NULL"); + } + auto shredded_type = TransformStringToLogicalType(type_str.GetValue()); + bound_function.SetReturnType(GetParquetVariantType(shredded_type)); + } else { + bound_function.SetReturnType(GetParquetVariantType()); + } + + return nullptr; +} + +ScalarFunction VariantColumnWriter::GetTransformFunction() { + ScalarFunction transform("variant_to_parquet_variant", {LogicalType::VARIANT()}, LogicalType::ANY, ToParquetVariant, + BindTransform); + transform.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + return transform; +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry.cpp index 7fdc0c3be..8fca4a954 100644 --- a/src/duckdb/src/catalog/catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry.cpp @@ -48,7 +48,7 @@ unique_ptr CatalogEntry::GetInfo() const { } string CatalogEntry::ToSQL() const { - throw InternalException("Unsupported catalog type for ToSQL()"); + throw InternalException({{"catalog_type", CatalogTypeToString(type)}}, "Unsupported catalog type for ToSQL()"); } void CatalogEntry::SetChild(unique_ptr child_p) { diff --git a/src/duckdb/src/catalog/catalog_entry/copy_function_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/copy_function_catalog_entry.cpp index 25544a343..6d639bef6 100644 --- a/src/duckdb/src/catalog/catalog_entry/copy_function_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/copy_function_catalog_entry.cpp @@ -3,6 +3,8 @@ namespace duckdb { +constexpr const char *CopyFunctionCatalogEntry::Name; + CopyFunctionCatalogEntry::CopyFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateCopyFunctionInfo &info) : StandardEntry(CatalogType::COPY_FUNCTION_ENTRY, schema, catalog, info.name), function(info.function) { diff --git a/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp index c70984e53..769a06b9f 100644 --- a/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp @@ -22,7 +22,6 @@ void DuckIndexEntry::Rollback(CatalogEntry &) { DuckIndexEntry::DuckIndexEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &create_info, TableCatalogEntry &table_p) : IndexCatalogEntry(catalog, schema, create_info), initial_index_size(0) { - auto &table = table_p.Cast(); auto &storage = table.GetStorage(); info = make_shared_ptr(storage.GetDataTableInfo(), name); diff --git a/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp index a0f40ce82..33d0db4da 100644 --- a/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp @@ -391,7 +391,7 @@ CatalogSet &DuckSchemaEntry::GetCatalogSet(CatalogType type) { case CatalogType::TYPE_ENTRY: return types; default: - throw InternalException("Unsupported catalog type in schema"); + throw InternalException({{"catalog_type", CatalogTypeToString(type)}}, "Unsupported catalog 
type in schema"); } } diff --git a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp index b80204ac0..4f1866ee4 100644 --- a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp @@ -29,7 +29,6 @@ namespace duckdb { IndexStorageInfo GetIndexInfo(const IndexConstraintType type, const bool v1_0_0_storage, unique_ptr &info, const idx_t id) { - auto &table_info = info->Cast(); auto constraint_name = EnumUtil::ToString(type) + "_"; auto name = constraint_name + table_info.table + "_" + to_string(id); @@ -44,7 +43,6 @@ DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, Bou shared_ptr inherited_storage) : TableCatalogEntry(catalog, schema, info.Base()), storage(std::move(inherited_storage)), column_dependency_manager(std::move(info.column_dependency_manager)) { - if (storage) { if (!info.indexes.empty()) { storage->SetIndexStorageInfo(std::move(info.indexes)); @@ -68,7 +66,6 @@ DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, Bou for (idx_t i = 0; i < constraints.size(); i++) { auto &constraint = constraints[i]; if (constraint->type == ConstraintType::UNIQUE) { - // UNIQUE constraint: Create a unique index. auto &unique = constraint->Cast(); IndexConstraintType constraint_type = IndexConstraintType::UNIQUE; @@ -99,7 +96,6 @@ DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, Bou auto &bfk = constraint->Cast(); if (bfk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE || bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { - vector column_indexes; for (const auto &physical_index : bfk.info.fk_keys) { auto &col = columns.GetColumn(physical_index); @@ -595,12 +591,24 @@ void DuckTableEntry::UpdateConstraintsOnColumnDrop(const LogicalIndex &removed_i auto copy = constraint->Copy(); auto &unique = copy->Cast(); if (unique.HasIndex()) { + // Single-column UNIQUE constraint if (unique.GetIndex() == removed_index) { throw CatalogException( "Cannot drop column \"%s\" because there is a UNIQUE constraint that depends on it", info.removed_column); } unique.SetIndex(adjusted_indices[unique.GetIndex().index]); + } else { + // Multi-column UNIQUE constraint - check if any column matches the one being dropped + for (const auto &col_name : unique.GetColumnNames()) { + if (col_name == info.removed_column) { + // Build constraint string for error message: UNIQUE(col1, col2, ...) 
+ auto constraint_str = "UNIQUE(" + StringUtil::Join(unique.GetColumnNames(), ", ") + ")"; + throw CatalogException( + "Cannot drop column \"%s\" because it is referenced in unique constraint %s", + info.removed_column, constraint_str); + } + } } create_info.constraints.push_back(std::move(copy)); break; @@ -1285,8 +1293,8 @@ TableFunction DuckTableEntry::GetScanFunction(ClientContext &context, unique_ptr return TableScanFunction::GetFunction(); } -vector DuckTableEntry::GetColumnSegmentInfo() { - return storage->GetColumnSegmentInfo(); +vector DuckTableEntry::GetColumnSegmentInfo(const QueryContext &context) { + return storage->GetColumnSegmentInfo(context); } TableStorageInfo DuckTableEntry::GetStorageInfo(ClientContext &context) { diff --git a/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp index 2c5cb9ae7..ed71e174c 100644 --- a/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp @@ -5,7 +5,6 @@ namespace duckdb { IndexCatalogEntry::IndexCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &info) : StandardEntry(CatalogType::INDEX_ENTRY, schema, catalog, info.index_name), sql(info.sql), options(info.options), index_type(info.index_type), index_constraint_type(info.constraint_type), column_ids(info.column_ids) { - this->temporary = info.temporary; this->dependencies = info.dependencies; this->comment = info.comment; diff --git a/src/duckdb/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp index ff247dcb0..9d9789192 100644 --- a/src/duckdb/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp @@ -2,6 +2,7 @@ #include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" namespace duckdb { +constexpr const char *PragmaFunctionCatalogEntry::Name; PragmaFunctionCatalogEntry::PragmaFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreatePragmaFunctionInfo &info) diff --git a/src/duckdb/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp index 49b20f677..e5778ad4c 100644 --- a/src/duckdb/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp @@ -5,6 +5,8 @@ namespace duckdb { +constexpr const char *ScalarFunctionCatalogEntry::Name; + ScalarFunctionCatalogEntry::ScalarFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateScalarFunctionInfo &info) : FunctionEntry(CatalogType::SCALAR_FUNCTION_ENTRY, catalog, schema, info), functions(info.functions) { diff --git a/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp index 6153a8e8a..d6a548a26 100644 --- a/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp @@ -13,6 +13,8 @@ namespace duckdb { +constexpr const char *SequenceCatalogEntry::Name; + SequenceData::SequenceData(CreateSequenceInfo &info) : usage_count(info.usage_count), counter(info.start_value), last_value(info.start_value), increment(info.increment), start_value(info.start_value), min_value(info.min_value), max_value(info.max_value), cycle(info.cycle) { diff --git 
a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp index 22a173fd8..8582fa93c 100644 --- a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp @@ -19,6 +19,8 @@ namespace duckdb { +constexpr const char *TableCatalogEntry::Name; + TableCatalogEntry::TableCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTableInfo &info) : StandardEntry(CatalogType::TABLE_ENTRY, schema, catalog, info.table), columns(std::move(info.columns)), constraints(std::move(info.constraints)) { @@ -266,7 +268,7 @@ void LogicalUpdate::BindExtraColumns(TableCatalogEntry &table, LogicalGet &get, } } -vector TableCatalogEntry::GetColumnSegmentInfo() { +vector TableCatalogEntry::GetColumnSegmentInfo(const QueryContext &context) { return {}; } diff --git a/src/duckdb/src/catalog/catalog_entry/table_function_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/table_function_catalog_entry.cpp index a6a41ff61..f06ef164e 100644 --- a/src/duckdb/src/catalog/catalog_entry/table_function_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/table_function_catalog_entry.cpp @@ -4,6 +4,8 @@ namespace duckdb { +constexpr const char *TableFunctionCatalogEntry::Name; + TableFunctionCatalogEntry::TableFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTableFunctionInfo &info) : FunctionEntry(CatalogType::TABLE_FUNCTION_ENTRY, catalog, schema, info), functions(std::move(info.functions)) { diff --git a/src/duckdb/src/catalog/catalog_entry/type_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/type_catalog_entry.cpp index 0bb4a3f3a..324413b7c 100644 --- a/src/duckdb/src/catalog/catalog_entry/type_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/type_catalog_entry.cpp @@ -9,6 +9,8 @@ namespace duckdb { +constexpr const char *TypeCatalogEntry::Name; + TypeCatalogEntry::TypeCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTypeInfo &info) : StandardEntry(CatalogType::TYPE_ENTRY, schema, catalog, info.name), user_type(info.type), bind_function(info.bind_function) { diff --git a/src/duckdb/src/catalog/catalog_search_path.cpp b/src/duckdb/src/catalog/catalog_search_path.cpp index 6af56c22d..6388b9134 100644 --- a/src/duckdb/src/catalog/catalog_search_path.cpp +++ b/src/duckdb/src/catalog/catalog_search_path.cpp @@ -24,8 +24,8 @@ string CatalogSearchEntry::ToString() const { string CatalogSearchEntry::WriteOptionallyQuoted(const string &input) { for (idx_t i = 0; i < input.size(); i++) { - if (input[i] == '.' || input[i] == ',') { - return "\"" + input + "\""; + if (input[i] == '.' 
|| input[i] == ',' || input[i] == '"') { + return "\"" + StringUtil::Replace(input, "\"", "\"\"") + "\""; } } return input; diff --git a/src/duckdb/src/catalog/catalog_set.cpp b/src/duckdb/src/catalog/catalog_set.cpp index deff8daae..d374f6999 100644 --- a/src/duckdb/src/catalog/catalog_set.cpp +++ b/src/duckdb/src/catalog/catalog_set.cpp @@ -401,8 +401,6 @@ bool CatalogSet::DropEntryInternal(CatalogTransaction transaction, const string throw CatalogException("Cannot drop entry \"%s\" because it is an internal system entry", entry->name); } - entry->OnDrop(); - // create a new tombstone entry and replace the currently stored one // set the timestamp to the timestamp of the current transaction // and point it at the tombstone node @@ -454,6 +452,7 @@ void CatalogSet::VerifyExistenceOfDependency(transaction_t commit_id, CatalogEnt void CatalogSet::CommitDrop(transaction_t commit_id, transaction_t start_time, CatalogEntry &entry) { auto &duck_catalog = GetCatalog(); + entry.OnDrop(); // Make sure that we don't see any uncommitted changes auto transaction_id = MAX_TRANSACTION_ID; // This will allow us to see all committed changes made before this COMMIT happened diff --git a/src/duckdb/src/catalog/default/default_functions.cpp b/src/duckdb/src/catalog/default/default_functions.cpp index f51038b1e..9ecc89739 100644 --- a/src/duckdb/src/catalog/default/default_functions.cpp +++ b/src/duckdb/src/catalog/default/default_functions.cpp @@ -98,9 +98,9 @@ static const DefaultMacro internal_macros[] = { {DEFAULT_SCHEMA, "array_pop_front", {"arr", nullptr}, {{nullptr, nullptr}}, "arr[2:]"}, {DEFAULT_SCHEMA, "array_push_back", {"arr", "e", nullptr}, {{nullptr, nullptr}}, "list_concat(arr, list_value(e))"}, {DEFAULT_SCHEMA, "array_push_front", {"arr", "e", nullptr}, {{nullptr, nullptr}}, "list_concat(list_value(e), arr)"}, - {DEFAULT_SCHEMA, "array_to_string", {"arr", "sep", nullptr}, {{nullptr, nullptr}}, "list_aggr(arr::varchar[], 'string_agg', sep)"}, + {DEFAULT_SCHEMA, "array_to_string", {"arr", "sep", nullptr}, {{nullptr, nullptr}}, "case len(arr::varchar[]) when 0 then '' else list_aggr(arr::varchar[], 'string_agg', sep) end"}, // Test default parameters - {DEFAULT_SCHEMA, "array_to_string_comma_default", {"arr", nullptr}, {{"sep", "','"}, {nullptr, nullptr}}, "list_aggr(arr::varchar[], 'string_agg', sep)"}, + {DEFAULT_SCHEMA, "array_to_string_comma_default", {"arr", nullptr}, {{"sep", "','"}, {nullptr, nullptr}}, "case len(arr::varchar[]) when 0 then '' else list_aggr(arr::varchar[], 'string_agg', sep) end"}, {DEFAULT_SCHEMA, "generate_subscripts", {"arr", "dim", nullptr}, {{nullptr, nullptr}}, "unnest(generate_series(1, array_length(arr, dim)))"}, {DEFAULT_SCHEMA, "fdiv", {"x", "y", nullptr}, {{nullptr, nullptr}}, "floor(x/y)"}, diff --git a/src/duckdb/src/common/adbc/adbc.cpp b/src/duckdb/src/common/adbc/adbc.cpp index 054eaaf0f..54a486dc2 100644 --- a/src/duckdb/src/common/adbc/adbc.cpp +++ b/src/duckdb/src/common/adbc/adbc.cpp @@ -196,7 +196,6 @@ AdbcStatusCode DatabaseInit(struct AdbcDatabase *database, struct AdbcError *err } AdbcStatusCode DatabaseRelease(struct AdbcDatabase *database, struct AdbcError *error) { - if (database && database->private_data) { auto wrapper = static_cast(database->private_data); @@ -537,7 +536,8 @@ static int get_schema(struct ArrowArrayStream *stream, struct ArrowSchema *out) auto count = duckdb_column_count(&result_wrapper->result); std::vector types(count); - std::vector owned_names(count); + std::vector owned_names; + owned_names.reserve(count); duckdb::vector 
names(count); for (idx_t i = 0; i < count; i++) { types[i] = duckdb_column_logical_type(&result_wrapper->result, i); @@ -605,7 +605,6 @@ const char *get_last_error(struct ArrowArrayStream *stream) { duckdb::unique_ptr stream_produce(uintptr_t factory_ptr, duckdb::ArrowStreamParameters ¶meters) { - // TODO this will ignore any projections or filters but since we don't expose the scan it should be sort of fine auto res = duckdb::make_uniq(); res->arrow_array_stream = *reinterpret_cast(factory_ptr); @@ -619,7 +618,6 @@ void stream_schema(ArrowArrayStream *stream, ArrowSchema &schema) { AdbcStatusCode Ingest(duckdb_connection connection, const char *table_name, const char *schema, struct ArrowArrayStream *input, struct AdbcError *error, IngestionMode ingestion_mode, bool temporary) { - if (!connection) { SetError(error, "Missing connection object"); return ADBC_STATUS_INVALID_ARGUMENT; @@ -659,12 +657,12 @@ AdbcStatusCode Ingest(duckdb_connection connection, const char *table_name, cons std::ostringstream create_table; create_table << "CREATE TABLE "; if (schema) { - create_table << schema << "."; + create_table << duckdb::KeywordHelper::WriteOptionallyQuoted(schema) << "."; } - create_table << table_name << " ("; + create_table << duckdb::KeywordHelper::WriteOptionallyQuoted(table_name) << " ("; for (idx_t i = 0; i < types.size(); i++) { - create_table << names[i] << " "; - create_table << types[i].ToString(); + create_table << duckdb::KeywordHelper::WriteOptionallyQuoted(names[i]); + create_table << " " << types[i].ToString(); if (i + 1 < types.size()) { create_table << ", "; } @@ -793,7 +791,8 @@ AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, stru count = 1; } std::vector types(count); - std::vector owned_names(count); + std::vector owned_names; + owned_names.reserve(count); duckdb::vector names(count); for (idx_t i = 0; i < count; i++) { diff --git a/src/duckdb/src/common/adbc/driver_manager.cpp b/src/duckdb/src/common/adbc/driver_manager.cpp index 45fb8c24d..106c3e598 100644 --- a/src/duckdb/src/common/adbc/driver_manager.cpp +++ b/src/duckdb/src/common/adbc/driver_manager.cpp @@ -1080,7 +1080,6 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection *connection, st AdbcStatusCode AdbcConnectionInit(struct AdbcConnection *connection, struct AdbcDatabase *database, struct AdbcError *error) { - if (!connection->private_data) { SetError(error, "Must call AdbcConnectionNew first"); return ADBC_STATUS_INVALID_STATE; diff --git a/src/duckdb/src/common/allocator.cpp b/src/duckdb/src/common/allocator.cpp index 977087939..ea92b2524 100644 --- a/src/duckdb/src/common/allocator.cpp +++ b/src/duckdb/src/common/allocator.cpp @@ -35,6 +35,8 @@ namespace duckdb { +constexpr const idx_t Allocator::MAXIMUM_ALLOC_SIZE; + AllocatedData::AllocatedData() : allocator(nullptr), pointer(nullptr), allocated_size(0) { } @@ -254,7 +256,7 @@ static void MallocTrim(idx_t pad) { return; // Another thread has updated LAST_TRIM_TIMESTAMP_MS since we loaded it } - // We succesfully updated LAST_TRIM_TIMESTAMP_MS, we can trim + // We successfully updated LAST_TRIM_TIMESTAMP_MS, we can trim malloc_trim(pad); #endif } diff --git a/src/duckdb/src/common/arrow/arrow_converter.cpp b/src/duckdb/src/common/arrow/arrow_converter.cpp index d5acf3698..b5429763b 100644 --- a/src/duckdb/src/common/arrow/arrow_converter.cpp +++ b/src/duckdb/src/common/arrow/arrow_converter.cpp @@ -358,7 +358,6 @@ void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, co } child.children = 
&root_holder.nested_children_ptr.back()[0]; for (size_t type_idx = 0; type_idx < child_types.size(); type_idx++) { - InitializeChild(*child.children[type_idx], root_holder); root_holder.owned_type_names.push_back(AddName(child_types[type_idx].first)); diff --git a/src/duckdb/src/common/arrow/arrow_query_result.cpp b/src/duckdb/src/common/arrow/arrow_query_result.cpp index 396a99944..608a0bd32 100644 --- a/src/duckdb/src/common/arrow/arrow_query_result.cpp +++ b/src/duckdb/src/common/arrow/arrow_query_result.cpp @@ -16,10 +16,7 @@ ArrowQueryResult::ArrowQueryResult(StatementType statement_type, StatementProper ArrowQueryResult::ArrowQueryResult(ErrorData error) : QueryResult(QueryResultType::ARROW_RESULT, std::move(error)) { } -unique_ptr ArrowQueryResult::Fetch() { - throw NotImplementedException("Can't 'Fetch' from ArrowQueryResult"); -} -unique_ptr ArrowQueryResult::FetchRaw() { +unique_ptr ArrowQueryResult::FetchInternal() { throw NotImplementedException("Can't 'FetchRaw' from ArrowQueryResult"); } diff --git a/src/duckdb/src/common/arrow/arrow_type_extension.cpp b/src/duckdb/src/common/arrow/arrow_type_extension.cpp index 93979cd36..d3dff923c 100644 --- a/src/duckdb/src/common/arrow/arrow_type_extension.cpp +++ b/src/duckdb/src/common/arrow/arrow_type_extension.cpp @@ -7,6 +7,8 @@ #include "duckdb/common/arrow/schema_metadata.hpp" #include "duckdb/common/types/vector.hpp" +#include "yyjson.hpp" + namespace duckdb { ArrowTypeExtension::ArrowTypeExtension(string extension_name, string arrow_format, @@ -365,6 +367,72 @@ struct ArrowBool8 { } }; +struct ArrowGeometry { + static unique_ptr GetType(const ArrowSchema &schema, const ArrowSchemaMetadata &schema_metadata) { + // Validate extension metadata. This metadata also contains a CRS, which we drop + // because the GEOMETRY type does not implement a CRS at the type level (yet). 
+ const auto extension_metadata = schema_metadata.GetOption(ArrowSchemaMetadata::ARROW_METADATA_KEY); + if (!extension_metadata.empty()) { + unique_ptr doc( + duckdb_yyjson::yyjson_read(extension_metadata.data(), extension_metadata.size(), + duckdb_yyjson::YYJSON_READ_NOFLAG), + duckdb_yyjson::yyjson_doc_free); + if (!doc) { + throw SerializationException("Invalid JSON in GeoArrow metadata"); + } + + duckdb_yyjson::yyjson_val *val = yyjson_doc_get_root(doc.get()); + if (!yyjson_is_obj(val)) { + throw SerializationException("Invalid GeoArrow metadata: not a JSON object"); + } + + duckdb_yyjson::yyjson_val *edges = yyjson_obj_get(val, "edges"); + if (edges && yyjson_is_str(edges) && std::strcmp(yyjson_get_str(edges), "planar") != 0) { + throw NotImplementedException("Can't import non-planar edges"); + } + } + + const auto format = string(schema.format); + if (format == "z") { + return make_uniq(LogicalType::GEOMETRY(), + make_uniq(ArrowVariableSizeType::NORMAL)); + } + if (format == "Z") { + return make_uniq(LogicalType::GEOMETRY(), + make_uniq(ArrowVariableSizeType::SUPER_SIZE)); + } + if (format == "vz") { + return make_uniq(LogicalType::GEOMETRY(), + make_uniq(ArrowVariableSizeType::VIEW)); + } + throw InvalidInputException("Arrow extension type \"%s\" not supported for geoarrow.wkb", format.c_str()); + } + + static void PopulateSchema(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &schema, const LogicalType &type, + ClientContext &context, const ArrowTypeExtension &extension) { + ArrowSchemaMetadata schema_metadata; + schema_metadata.AddOption(ArrowSchemaMetadata::ARROW_EXTENSION_NAME, "geoarrow.wkb"); + schema_metadata.AddOption(ArrowSchemaMetadata::ARROW_METADATA_KEY, "{}"); + root_holder.metadata_info.emplace_back(schema_metadata.SerializeMetadata()); + schema.metadata = root_holder.metadata_info.back().get(); + + const auto options = context.GetClientProperties(); + if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { + schema.format = "Z"; + } else { + schema.format = "z"; + } + } + + static void ArrowToDuck(ClientContext &, Vector &source, Vector &result, idx_t count) { + Geometry::FromBinary(source, result, count, true); + } + + static void DuckToArrow(ClientContext &context, Vector &source, Vector &result, idx_t count) { + Geometry::ToBinary(source, result, count); + } +}; + void ArrowTypeExtensionSet::Initialize(const DBConfig &config) { // Types that are 1:1 config.RegisterArrowExtension({"arrow.uuid", "w:16", make_shared_ptr(LogicalType::UUID)}); @@ -380,6 +448,11 @@ void ArrowTypeExtensionSet::Initialize(const DBConfig &config) { config.RegisterArrowExtension( {"DuckDB", "time_tz", "w:8", make_shared_ptr(LogicalType::TIME_TZ)}); + config.RegisterArrowExtension( + {"geoarrow.wkb", ArrowGeometry::PopulateSchema, ArrowGeometry::GetType, + make_shared_ptr(LogicalType::GEOMETRY(), LogicalType::BLOB, ArrowGeometry::ArrowToDuck, + ArrowGeometry::DuckToArrow)}); + // Types that are 1:n config.RegisterArrowExtension({"arrow.json", &ArrowJson::PopulateSchema, &ArrowJson::GetType, make_shared_ptr(LogicalType::JSON())}); diff --git a/src/duckdb/src/common/arrow/physical_arrow_collector.cpp b/src/duckdb/src/common/arrow/physical_arrow_collector.cpp index 0636865be..a8b225f75 100644 --- a/src/duckdb/src/common/arrow/physical_arrow_collector.cpp +++ b/src/duckdb/src/common/arrow/physical_arrow_collector.cpp @@ -88,7 +88,7 @@ SinkCombineResultType PhysicalArrowCollector::Combine(ExecutionContext &context, return SinkCombineResultType::FINISHED; } -unique_ptr 
PhysicalArrowCollector::GetResult(GlobalSinkState &state_p) { +unique_ptr PhysicalArrowCollector::GetResult(GlobalSinkState &state_p) const { auto &gstate = state_p.Cast(); return std::move(gstate.result); } diff --git a/src/duckdb/src/common/bignum.cpp b/src/duckdb/src/common/bignum.cpp index fb3613e88..4414b3a5b 100644 --- a/src/duckdb/src/common/bignum.cpp +++ b/src/duckdb/src/common/bignum.cpp @@ -1,30 +1,39 @@ #include "duckdb/common/bignum.hpp" #include "duckdb/common/types/bignum.hpp" -#include +#include "duckdb/common/printer.hpp" +#include "duckdb/common/to_string.hpp" namespace duckdb { void PrintBits(const char value) { + string result; for (int i = 7; i >= 0; --i) { - std::cout << ((value >> i) & 1); + result += to_string((value >> i) & 1); } + Printer::RawPrint(OutputStream::STREAM_STDOUT, result); } void bignum_t::Print() const { auto ptr = data.GetData(); auto length = data.GetSize(); + string result; for (idx_t i = 0; i < length; ++i) { - PrintBits(ptr[i]); - std::cout << " "; + for (int j = 7; j >= 0; --j) { + result += to_string((ptr[i] >> j) & 1); + } + result += " "; } - std::cout << '\n'; + Printer::Print(OutputStream::STREAM_STDOUT, result); } void BignumIntermediate::Print() const { + string result; for (idx_t i = 0; i < size; ++i) { - PrintBits(static_cast(data[i])); - std::cout << " "; + for (int j = 7; j >= 0; --j) { + result += to_string((data[i] >> j) & 1); + } + result += " "; } - std::cout << '\n'; + Printer::Print(OutputStream::STREAM_STDOUT, result); } BignumIntermediate::BignumIntermediate(const bignum_t &value) { @@ -232,7 +241,6 @@ void BignumAddition(data_ptr_t result, int64_t result_end, bool is_target_absolu } string_t BignumIntermediate::Negate(Vector &result_vector) const { - auto target = StringVector::EmptyString(result_vector, size + Bignum::BIGNUM_HEADER_SIZE); auto ptr = target.GetDataWriteable(); diff --git a/src/duckdb/src/common/csv_writer.cpp b/src/duckdb/src/common/csv_writer.cpp index bb9ff81d2..8f8992347 100644 --- a/src/duckdb/src/common/csv_writer.cpp +++ b/src/duckdb/src/common/csv_writer.cpp @@ -16,7 +16,7 @@ CSVWriterState::CSVWriterState() } CSVWriterState::CSVWriterState(ClientContext &context, idx_t flush_size_p) - : flush_size(flush_size_p), stream(make_uniq(Allocator::Get(context))) { + : flush_size(flush_size_p), stream(make_uniq(Allocator::Get(context), flush_size)) { } CSVWriterState::CSVWriterState(DatabaseInstance &db, idx_t flush_size_p) @@ -71,7 +71,6 @@ CSVWriter::CSVWriter(CSVReaderOptions &options_p, FileSystem &fs, const string & FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW | FileLockType::WRITE_LOCK | compression)), write_stream(*file_writer), should_initialize(true), shared(shared) { - if (!shared) { global_write_state = make_uniq(); } @@ -198,18 +197,6 @@ void CSVWriter::ResetInternal(optional_ptr local_state) { bytes_written = 0; } -unique_ptr CSVWriter::InitializeLocalWriteState(ClientContext &context, idx_t flush_size) { - auto res = make_uniq(context, flush_size); - res->stream = make_uniq(); - return res; -} - -unique_ptr CSVWriter::InitializeLocalWriteState(DatabaseInstance &db, idx_t flush_size) { - auto res = make_uniq(db, flush_size); - res->stream = make_uniq(); - return res; -} - idx_t CSVWriter::BytesWritten() { if (shared) { lock_guard flock(lock); diff --git a/src/duckdb/src/common/encryption_key_manager.cpp b/src/duckdb/src/common/encryption_key_manager.cpp index 482c4a006..9d16a159b 100644 --- a/src/duckdb/src/common/encryption_key_manager.cpp +++ 
b/src/duckdb/src/common/encryption_key_manager.cpp @@ -31,6 +31,8 @@ EncryptionKey::~EncryptionKey() { void EncryptionKey::LockEncryptionKey(data_ptr_t key, idx_t key_len) { #if defined(_WIN32) VirtualLock(key, key_len); +#elif defined(__MVS__) + __mlockall(_BPX_NONSWAP); #else mlock(key, key_len); #endif @@ -40,6 +42,8 @@ void EncryptionKey::UnlockEncryptionKey(data_ptr_t key, idx_t key_len) { memset(key, 0, key_len); #if defined(_WIN32) VirtualUnlock(key, key_len); +#elif defined(__MVS__) + __mlockall(_BPX_SWAP); #else munlock(key, key_len); #endif diff --git a/src/duckdb/src/common/enum_util.cpp b/src/duckdb/src/common/enum_util.cpp index 324ba7004..c5e015978 100644 --- a/src/duckdb/src/common/enum_util.cpp +++ b/src/duckdb/src/common/enum_util.cpp @@ -82,13 +82,13 @@ #include "duckdb/common/multi_file/multi_file_options.hpp" #include "duckdb/common/operator/decimal_cast_operators.hpp" #include "duckdb/common/printer.hpp" -#include "duckdb/common/sort/partition_state.hpp" #include "duckdb/common/sorting/sort_key.hpp" #include "duckdb/common/types.hpp" #include "duckdb/common/types/column/column_data_scan_states.hpp" #include "duckdb/common/types/column/partitioned_column_data.hpp" #include "duckdb/common/types/conflict_manager.hpp" #include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/geometry.hpp" #include "duckdb/common/types/hyperloglog.hpp" #include "duckdb/common/types/row/block_iterator.hpp" #include "duckdb/common/types/row/partitioned_tuple_data.hpp" @@ -101,8 +101,10 @@ #include "duckdb/execution/index/art/art_scanner.hpp" #include "duckdb/execution/index/art/node.hpp" #include "duckdb/execution/index/bound_index.hpp" +#include "duckdb/execution/index/unbound_index.hpp" #include "duckdb/execution/operator/csv_scanner/csv_option.hpp" #include "duckdb/execution/operator/csv_scanner/csv_state.hpp" +#include "duckdb/execution/physical_table_scan_enum.hpp" #include "duckdb/execution/reservoir_sample.hpp" #include "duckdb/function/aggregate_state.hpp" #include "duckdb/function/compression_function.hpp" @@ -121,15 +123,18 @@ #include "duckdb/logging/log_storage.hpp" #include "duckdb/logging/logging.hpp" #include "duckdb/main/appender.hpp" +#include "duckdb/main/attached_database.hpp" #include "duckdb/main/capi/capi_internal.hpp" #include "duckdb/main/error_manager.hpp" #include "duckdb/main/extension.hpp" #include "duckdb/main/extension_helper.hpp" #include "duckdb/main/extension_install_info.hpp" +#include "duckdb/main/query_parameters.hpp" #include "duckdb/main/query_profiler.hpp" #include "duckdb/main/query_result.hpp" #include "duckdb/main/secret/secret.hpp" #include "duckdb/main/setting_info.hpp" +#include "duckdb/parallel/async_result.hpp" #include "duckdb/parallel/interrupt.hpp" #include "duckdb/parallel/meta_pipeline.hpp" #include "duckdb/parallel/task.hpp" @@ -631,6 +636,45 @@ ArrowVariableSizeType EnumUtil::FromString(const char *va return static_cast(StringUtil::StringToEnum(GetArrowVariableSizeTypeValues(), 4, "ArrowVariableSizeType", value)); } +const StringUtil::EnumStringLiteral *GetAsyncResultTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(AsyncResultType::INVALID), "INVALID" }, + { static_cast(AsyncResultType::IMPLICIT), "IMPLICIT" }, + { static_cast(AsyncResultType::HAVE_MORE_OUTPUT), "HAVE_MORE_OUTPUT" }, + { static_cast(AsyncResultType::FINISHED), "FINISHED" }, + { static_cast(AsyncResultType::BLOCKED), "BLOCKED" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(AsyncResultType value) 
{ + return StringUtil::EnumToString(GetAsyncResultTypeValues(), 5, "AsyncResultType", static_cast(value)); +} + +template<> +AsyncResultType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetAsyncResultTypeValues(), 5, "AsyncResultType", value)); +} + +const StringUtil::EnumStringLiteral *GetAsyncResultsExecutionModeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(AsyncResultsExecutionMode::SYNCHRONOUS), "SYNCHRONOUS" }, + { static_cast(AsyncResultsExecutionMode::TASK_EXECUTOR), "TASK_EXECUTOR" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(AsyncResultsExecutionMode value) { + return StringUtil::EnumToString(GetAsyncResultsExecutionModeValues(), 2, "AsyncResultsExecutionMode", static_cast(value)); +} + +template<> +AsyncResultsExecutionMode EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetAsyncResultsExecutionModeValues(), 2, "AsyncResultsExecutionMode", value)); +} + const StringUtil::EnumStringLiteral *GetBinderTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(BinderType::REGULAR_BINDER), "REGULAR_BINDER" }, @@ -727,6 +771,24 @@ BlockState EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetBlockStateValues(), 2, "BlockState", value)); } +const StringUtil::EnumStringLiteral *GetBufferedIndexReplayValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(BufferedIndexReplay::INSERT_ENTRY), "INSERT_ENTRY" }, + { static_cast(BufferedIndexReplay::DEL_ENTRY), "DEL_ENTRY" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(BufferedIndexReplay value) { + return StringUtil::EnumToString(GetBufferedIndexReplayValues(), 2, "BufferedIndexReplay", static_cast(value)); +} + +template<> +BufferedIndexReplay EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetBufferedIndexReplayValues(), 2, "BufferedIndexReplay", value)); +} + const StringUtil::EnumStringLiteral *GetCAPIResultSetTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(CAPIResultSetType::CAPI_RESULT_TYPE_NONE), "CAPI_RESULT_TYPE_NONE" }, @@ -1464,19 +1526,20 @@ const StringUtil::EnumStringLiteral *GetExplainFormatValues() { { static_cast(ExplainFormat::JSON), "JSON" }, { static_cast(ExplainFormat::HTML), "HTML" }, { static_cast(ExplainFormat::GRAPHVIZ), "GRAPHVIZ" }, - { static_cast(ExplainFormat::YAML), "YAML" } + { static_cast(ExplainFormat::YAML), "YAML" }, + { static_cast(ExplainFormat::MERMAID), "MERMAID" } }; return values; } template<> const char* EnumUtil::ToChars(ExplainFormat value) { - return StringUtil::EnumToString(GetExplainFormatValues(), 6, "ExplainFormat", static_cast(value)); + return StringUtil::EnumToString(GetExplainFormatValues(), 7, "ExplainFormat", static_cast(value)); } template<> ExplainFormat EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetExplainFormatValues(), 6, "ExplainFormat", value)); + return static_cast(StringUtil::StringToEnum(GetExplainFormatValues(), 7, "ExplainFormat", value)); } const StringUtil::EnumStringLiteral *GetExplainOutputTypeValues() { @@ -1795,19 +1858,20 @@ const StringUtil::EnumStringLiteral *GetExtraTypeInfoTypeValues() { { static_cast(ExtraTypeInfoType::ARRAY_TYPE_INFO), "ARRAY_TYPE_INFO" }, { static_cast(ExtraTypeInfoType::ANY_TYPE_INFO), "ANY_TYPE_INFO" }, { 
static_cast(ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO), "INTEGER_LITERAL_TYPE_INFO" }, - { static_cast(ExtraTypeInfoType::TEMPLATE_TYPE_INFO), "TEMPLATE_TYPE_INFO" } + { static_cast(ExtraTypeInfoType::TEMPLATE_TYPE_INFO), "TEMPLATE_TYPE_INFO" }, + { static_cast(ExtraTypeInfoType::GEO_TYPE_INFO), "GEO_TYPE_INFO" } }; return values; } template<> const char* EnumUtil::ToChars(ExtraTypeInfoType value) { - return StringUtil::EnumToString(GetExtraTypeInfoTypeValues(), 13, "ExtraTypeInfoType", static_cast(value)); + return StringUtil::EnumToString(GetExtraTypeInfoTypeValues(), 14, "ExtraTypeInfoType", static_cast(value)); } template<> ExtraTypeInfoType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetExtraTypeInfoTypeValues(), 13, "ExtraTypeInfoType", value)); + return static_cast(StringUtil::StringToEnum(GetExtraTypeInfoTypeValues(), 14, "ExtraTypeInfoType", value)); } const StringUtil::EnumStringLiteral *GetFileBufferTypeValues() { @@ -2059,6 +2123,30 @@ GateStatus EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetGateStatusValues(), 2, "GateStatus", value)); } +const StringUtil::EnumStringLiteral *GetGeometryTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(GeometryType::INVALID), "INVALID" }, + { static_cast(GeometryType::POINT), "POINT" }, + { static_cast(GeometryType::LINESTRING), "LINESTRING" }, + { static_cast(GeometryType::POLYGON), "POLYGON" }, + { static_cast(GeometryType::MULTIPOINT), "MULTIPOINT" }, + { static_cast(GeometryType::MULTILINESTRING), "MULTILINESTRING" }, + { static_cast(GeometryType::MULTIPOLYGON), "MULTIPOLYGON" }, + { static_cast(GeometryType::GEOMETRYCOLLECTION), "GEOMETRYCOLLECTION" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(GeometryType value) { + return StringUtil::EnumToString(GetGeometryTypeValues(), 8, "GeometryType", static_cast(value)); +} + +template<> +GeometryType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetGeometryTypeValues(), 8, "GeometryType", value)); +} + const StringUtil::EnumStringLiteral *GetHLLStorageTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(HLLStorageType::HLL_V1), "HLL_V1" }, @@ -2599,6 +2687,7 @@ const StringUtil::EnumStringLiteral *GetLogicalTypeIdValues() { { static_cast(LogicalTypeId::POINTER), "POINTER" }, { static_cast(LogicalTypeId::VALIDITY), "VALIDITY" }, { static_cast(LogicalTypeId::UUID), "UUID" }, + { static_cast(LogicalTypeId::GEOMETRY), "GEOMETRY" }, { static_cast(LogicalTypeId::STRUCT), "STRUCT" }, { static_cast(LogicalTypeId::LIST), "LIST" }, { static_cast(LogicalTypeId::MAP), "MAP" }, @@ -2615,12 +2704,12 @@ const StringUtil::EnumStringLiteral *GetLogicalTypeIdValues() { template<> const char* EnumUtil::ToChars(LogicalTypeId value) { - return StringUtil::EnumToString(GetLogicalTypeIdValues(), 50, "LogicalTypeId", static_cast(value)); + return StringUtil::EnumToString(GetLogicalTypeIdValues(), 51, "LogicalTypeId", static_cast(value)); } template<> LogicalTypeId EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetLogicalTypeIdValues(), 50, "LogicalTypeId", value)); + return static_cast(StringUtil::StringToEnum(GetLogicalTypeIdValues(), 51, "LogicalTypeId", value)); } const StringUtil::EnumStringLiteral *GetLookupResultTypeValues() { @@ -2772,32 +2861,38 @@ MetaPipelineType EnumUtil::FromString(const char *value) { const StringUtil::EnumStringLiteral 
*GetMetricsTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(MetricsType::QUERY_NAME), "QUERY_NAME" }, + { static_cast(MetricsType::ATTACH_LOAD_STORAGE_LATENCY), "ATTACH_LOAD_STORAGE_LATENCY" }, + { static_cast(MetricsType::ATTACH_REPLAY_WAL_LATENCY), "ATTACH_REPLAY_WAL_LATENCY" }, { static_cast(MetricsType::BLOCKED_THREAD_TIME), "BLOCKED_THREAD_TIME" }, + { static_cast(MetricsType::CHECKPOINT_LATENCY), "CHECKPOINT_LATENCY" }, + { static_cast(MetricsType::COMMIT_WRITE_WAL_LATENCY), "COMMIT_WRITE_WAL_LATENCY" }, { static_cast(MetricsType::CPU_TIME), "CPU_TIME" }, - { static_cast(MetricsType::EXTRA_INFO), "EXTRA_INFO" }, { static_cast(MetricsType::CUMULATIVE_CARDINALITY), "CUMULATIVE_CARDINALITY" }, - { static_cast(MetricsType::OPERATOR_TYPE), "OPERATOR_TYPE" }, - { static_cast(MetricsType::OPERATOR_CARDINALITY), "OPERATOR_CARDINALITY" }, { static_cast(MetricsType::CUMULATIVE_ROWS_SCANNED), "CUMULATIVE_ROWS_SCANNED" }, + { static_cast(MetricsType::EXTRA_INFO), "EXTRA_INFO" }, + { static_cast(MetricsType::LATENCY), "LATENCY" }, + { static_cast(MetricsType::OPERATOR_CARDINALITY), "OPERATOR_CARDINALITY" }, + { static_cast(MetricsType::OPERATOR_NAME), "OPERATOR_NAME" }, { static_cast(MetricsType::OPERATOR_ROWS_SCANNED), "OPERATOR_ROWS_SCANNED" }, { static_cast(MetricsType::OPERATOR_TIMING), "OPERATOR_TIMING" }, + { static_cast(MetricsType::OPERATOR_TYPE), "OPERATOR_TYPE" }, + { static_cast(MetricsType::QUERY_NAME), "QUERY_NAME" }, { static_cast(MetricsType::RESULT_SET_SIZE), "RESULT_SET_SIZE" }, - { static_cast(MetricsType::LATENCY), "LATENCY" }, { static_cast(MetricsType::ROWS_RETURNED), "ROWS_RETURNED" }, - { static_cast(MetricsType::OPERATOR_NAME), "OPERATOR_NAME" }, { static_cast(MetricsType::SYSTEM_PEAK_BUFFER_MEMORY), "SYSTEM_PEAK_BUFFER_MEMORY" }, { static_cast(MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE), "SYSTEM_PEAK_TEMP_DIR_SIZE" }, { static_cast(MetricsType::TOTAL_BYTES_READ), "TOTAL_BYTES_READ" }, { static_cast(MetricsType::TOTAL_BYTES_WRITTEN), "TOTAL_BYTES_WRITTEN" }, + { static_cast(MetricsType::WAITING_TO_ATTACH_LATENCY), "WAITING_TO_ATTACH_LATENCY" }, + { static_cast(MetricsType::WAL_REPLAY_ENTRY_COUNT), "WAL_REPLAY_ENTRY_COUNT" }, { static_cast(MetricsType::ALL_OPTIMIZERS), "ALL_OPTIMIZERS" }, { static_cast(MetricsType::CUMULATIVE_OPTIMIZER_TIMING), "CUMULATIVE_OPTIMIZER_TIMING" }, - { static_cast(MetricsType::PLANNER), "PLANNER" }, - { static_cast(MetricsType::PLANNER_BINDING), "PLANNER_BINDING" }, { static_cast(MetricsType::PHYSICAL_PLANNER), "PHYSICAL_PLANNER" }, { static_cast(MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING), "PHYSICAL_PLANNER_COLUMN_BINDING" }, - { static_cast(MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES), "PHYSICAL_PLANNER_RESOLVE_TYPES" }, { static_cast(MetricsType::PHYSICAL_PLANNER_CREATE_PLAN), "PHYSICAL_PLANNER_CREATE_PLAN" }, + { static_cast(MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES), "PHYSICAL_PLANNER_RESOLVE_TYPES" }, + { static_cast(MetricsType::PLANNER), "PLANNER" }, + { static_cast(MetricsType::PLANNER_BINDING), "PLANNER_BINDING" }, { static_cast(MetricsType::OPTIMIZER_EXPRESSION_REWRITER), "OPTIMIZER_EXPRESSION_REWRITER" }, { static_cast(MetricsType::OPTIMIZER_FILTER_PULLUP), "OPTIMIZER_FILTER_PULLUP" }, { static_cast(MetricsType::OPTIMIZER_FILTER_PUSHDOWN), "OPTIMIZER_FILTER_PUSHDOWN" }, @@ -2816,6 +2911,7 @@ const StringUtil::EnumStringLiteral *GetMetricsTypeValues() { { static_cast(MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE), "OPTIMIZER_BUILD_SIDE_PROBE_SIDE" }, { 
static_cast(MetricsType::OPTIMIZER_LIMIT_PUSHDOWN), "OPTIMIZER_LIMIT_PUSHDOWN" }, { static_cast(MetricsType::OPTIMIZER_TOP_N), "OPTIMIZER_TOP_N" }, + { static_cast(MetricsType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION), "OPTIMIZER_TOP_N_WINDOW_ELIMINATION" }, { static_cast(MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION), "OPTIMIZER_COMPRESSED_MATERIALIZATION" }, { static_cast(MetricsType::OPTIMIZER_DUPLICATE_GROUPS), "OPTIMIZER_DUPLICATE_GROUPS" }, { static_cast(MetricsType::OPTIMIZER_REORDER_FILTER), "OPTIMIZER_REORDER_FILTER" }, @@ -2825,19 +2921,20 @@ const StringUtil::EnumStringLiteral *GetMetricsTypeValues() { { static_cast(MetricsType::OPTIMIZER_MATERIALIZED_CTE), "OPTIMIZER_MATERIALIZED_CTE" }, { static_cast(MetricsType::OPTIMIZER_SUM_REWRITER), "OPTIMIZER_SUM_REWRITER" }, { static_cast(MetricsType::OPTIMIZER_LATE_MATERIALIZATION), "OPTIMIZER_LATE_MATERIALIZATION" }, - { static_cast(MetricsType::OPTIMIZER_CTE_INLINING), "OPTIMIZER_CTE_INLINING" } + { static_cast(MetricsType::OPTIMIZER_CTE_INLINING), "OPTIMIZER_CTE_INLINING" }, + { static_cast(MetricsType::OPTIMIZER_COMMON_SUBPLAN), "OPTIMIZER_COMMON_SUBPLAN" } }; return values; } template<> const char* EnumUtil::ToChars(MetricsType value) { - return StringUtil::EnumToString(GetMetricsTypeValues(), 54, "MetricsType", static_cast(value)); + return StringUtil::EnumToString(GetMetricsTypeValues(), 62, "MetricsType", static_cast(value)); } template<> MetricsType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetMetricsTypeValues(), 54, "MetricsType", value)); + return static_cast(StringUtil::StringToEnum(GetMetricsTypeValues(), 62, "MetricsType", value)); } const StringUtil::EnumStringLiteral *GetMultiFileColumnMappingModeValues() { @@ -3060,6 +3157,7 @@ const StringUtil::EnumStringLiteral *GetOptimizerTypeValues() { { static_cast(OptimizerType::BUILD_SIDE_PROBE_SIDE), "BUILD_SIDE_PROBE_SIDE" }, { static_cast(OptimizerType::LIMIT_PUSHDOWN), "LIMIT_PUSHDOWN" }, { static_cast(OptimizerType::TOP_N), "TOP_N" }, + { static_cast(OptimizerType::TOP_N_WINDOW_ELIMINATION), "TOP_N_WINDOW_ELIMINATION" }, { static_cast(OptimizerType::COMPRESSED_MATERIALIZATION), "COMPRESSED_MATERIALIZATION" }, { static_cast(OptimizerType::DUPLICATE_GROUPS), "DUPLICATE_GROUPS" }, { static_cast(OptimizerType::REORDER_FILTER), "REORDER_FILTER" }, @@ -3069,19 +3167,20 @@ const StringUtil::EnumStringLiteral *GetOptimizerTypeValues() { { static_cast(OptimizerType::MATERIALIZED_CTE), "MATERIALIZED_CTE" }, { static_cast(OptimizerType::SUM_REWRITER), "SUM_REWRITER" }, { static_cast(OptimizerType::LATE_MATERIALIZATION), "LATE_MATERIALIZATION" }, - { static_cast(OptimizerType::CTE_INLINING), "CTE_INLINING" } + { static_cast(OptimizerType::CTE_INLINING), "CTE_INLINING" }, + { static_cast(OptimizerType::COMMON_SUBPLAN), "COMMON_SUBPLAN" } }; return values; } template<> const char* EnumUtil::ToChars(OptimizerType value) { - return StringUtil::EnumToString(GetOptimizerTypeValues(), 29, "OptimizerType", static_cast(value)); + return StringUtil::EnumToString(GetOptimizerTypeValues(), 31, "OptimizerType", static_cast(value)); } template<> OptimizerType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetOptimizerTypeValues(), 29, "OptimizerType", value)); + return static_cast(StringUtil::StringToEnum(GetOptimizerTypeValues(), 31, "OptimizerType", value)); } const StringUtil::EnumStringLiteral *GetOrderByNullTypeValues() { @@ -3237,28 +3336,6 @@ ParserExtensionResultType EnumUtil::FromString(const return 
static_cast(StringUtil::StringToEnum(GetParserExtensionResultTypeValues(), 3, "ParserExtensionResultType", value)); } -const StringUtil::EnumStringLiteral *GetPartitionSortStageValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(PartitionSortStage::INIT), "INIT" }, - { static_cast(PartitionSortStage::SCAN), "SCAN" }, - { static_cast(PartitionSortStage::PREPARE), "PREPARE" }, - { static_cast(PartitionSortStage::MERGE), "MERGE" }, - { static_cast(PartitionSortStage::SORTED), "SORTED" }, - { static_cast(PartitionSortStage::FINISHED), "FINISHED" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(PartitionSortStage value) { - return StringUtil::EnumToString(GetPartitionSortStageValues(), 6, "PartitionSortStage", static_cast(value)); -} - -template<> -PartitionSortStage EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetPartitionSortStageValues(), 6, "PartitionSortStage", value)); -} - const StringUtil::EnumStringLiteral *GetPartitionedColumnDataTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(PartitionedColumnDataType::INVALID), "INVALID" }, @@ -3416,6 +3493,26 @@ PhysicalOperatorType EnumUtil::FromString(const char *valu return static_cast(StringUtil::StringToEnum(GetPhysicalOperatorTypeValues(), 82, "PhysicalOperatorType", value)); } +const StringUtil::EnumStringLiteral *GetPhysicalTableScanExecutionStrategyValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(PhysicalTableScanExecutionStrategy::DEFAULT), "DEFAULT" }, + { static_cast(PhysicalTableScanExecutionStrategy::TASK_EXECUTOR), "TASK_EXECUTOR" }, + { static_cast(PhysicalTableScanExecutionStrategy::SYNCHRONOUS), "SYNCHRONOUS" }, + { static_cast(PhysicalTableScanExecutionStrategy::TASK_EXECUTOR_BUT_FORCE_SYNC_CHECKS), "TASK_EXECUTOR_BUT_FORCE_SYNC_CHECKS" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(PhysicalTableScanExecutionStrategy value) { + return StringUtil::EnumToString(GetPhysicalTableScanExecutionStrategyValues(), 4, "PhysicalTableScanExecutionStrategy", static_cast(value)); +} + +template<> +PhysicalTableScanExecutionStrategy EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetPhysicalTableScanExecutionStrategyValues(), 4, "PhysicalTableScanExecutionStrategy", value)); +} + const StringUtil::EnumStringLiteral *GetPhysicalTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(PhysicalType::BOOL), "BOOL" }, @@ -3535,19 +3632,20 @@ const StringUtil::EnumStringLiteral *GetProfilerPrintFormatValues() { { static_cast(ProfilerPrintFormat::QUERY_TREE_OPTIMIZER), "QUERY_TREE_OPTIMIZER" }, { static_cast(ProfilerPrintFormat::NO_OUTPUT), "NO_OUTPUT" }, { static_cast(ProfilerPrintFormat::HTML), "HTML" }, - { static_cast(ProfilerPrintFormat::GRAPHVIZ), "GRAPHVIZ" } + { static_cast(ProfilerPrintFormat::GRAPHVIZ), "GRAPHVIZ" }, + { static_cast(ProfilerPrintFormat::MERMAID), "MERMAID" } }; return values; } template<> const char* EnumUtil::ToChars(ProfilerPrintFormat value) { - return StringUtil::EnumToString(GetProfilerPrintFormatValues(), 6, "ProfilerPrintFormat", static_cast(value)); + return StringUtil::EnumToString(GetProfilerPrintFormatValues(), 7, "ProfilerPrintFormat", static_cast(value)); } template<> ProfilerPrintFormat EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetProfilerPrintFormatValues(), 6, "ProfilerPrintFormat", value)); + 
return static_cast(StringUtil::StringToEnum(GetProfilerPrintFormatValues(), 7, "ProfilerPrintFormat", value)); } const StringUtil::EnumStringLiteral *GetProfilingCoverageValues() { @@ -3595,19 +3693,56 @@ const StringUtil::EnumStringLiteral *GetQueryNodeTypeValues() { { static_cast(QueryNodeType::SET_OPERATION_NODE), "SET_OPERATION_NODE" }, { static_cast(QueryNodeType::BOUND_SUBQUERY_NODE), "BOUND_SUBQUERY_NODE" }, { static_cast(QueryNodeType::RECURSIVE_CTE_NODE), "RECURSIVE_CTE_NODE" }, - { static_cast(QueryNodeType::CTE_NODE), "CTE_NODE" } + { static_cast(QueryNodeType::CTE_NODE), "CTE_NODE" }, + { static_cast(QueryNodeType::STATEMENT_NODE), "STATEMENT_NODE" } }; return values; } template<> const char* EnumUtil::ToChars(QueryNodeType value) { - return StringUtil::EnumToString(GetQueryNodeTypeValues(), 5, "QueryNodeType", static_cast(value)); + return StringUtil::EnumToString(GetQueryNodeTypeValues(), 6, "QueryNodeType", static_cast(value)); } template<> QueryNodeType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetQueryNodeTypeValues(), 5, "QueryNodeType", value)); + return static_cast(StringUtil::StringToEnum(GetQueryNodeTypeValues(), 6, "QueryNodeType", value)); +} + +const StringUtil::EnumStringLiteral *GetQueryResultMemoryTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(QueryResultMemoryType::IN_MEMORY), "IN_MEMORY" }, + { static_cast(QueryResultMemoryType::BUFFER_MANAGED), "BUFFER_MANAGED" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(QueryResultMemoryType value) { + return StringUtil::EnumToString(GetQueryResultMemoryTypeValues(), 2, "QueryResultMemoryType", static_cast(value)); +} + +template<> +QueryResultMemoryType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetQueryResultMemoryTypeValues(), 2, "QueryResultMemoryType", value)); +} + +const StringUtil::EnumStringLiteral *GetQueryResultOutputTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(QueryResultOutputType::FORCE_MATERIALIZED), "FORCE_MATERIALIZED" }, + { static_cast(QueryResultOutputType::ALLOW_STREAMING), "ALLOW_STREAMING" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(QueryResultOutputType value) { + return StringUtil::EnumToString(GetQueryResultOutputTypeValues(), 2, "QueryResultOutputType", static_cast(value)); +} + +template<> +QueryResultOutputType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetQueryResultOutputTypeValues(), 2, "QueryResultOutputType", value)); } const StringUtil::EnumStringLiteral *GetQueryResultTypeValues() { @@ -3630,6 +3765,24 @@ QueryResultType EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetQueryResultTypeValues(), 4, "QueryResultType", value)); } +const StringUtil::EnumStringLiteral *GetRecoveryModeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(RecoveryMode::DEFAULT), "DEFAULT" }, + { static_cast(RecoveryMode::NO_WAL_WRITES), "NO_WAL_WRITES" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(RecoveryMode value) { + return StringUtil::EnumToString(GetRecoveryModeValues(), 2, "RecoveryMode", static_cast(value)); +} + +template<> +RecoveryMode EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetRecoveryModeValues(), 2, "RecoveryMode", value)); +} + const StringUtil::EnumStringLiteral 
*GetRelationTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(RelationType::INVALID_RELATION), "INVALID_RELATION" }, @@ -4220,19 +4373,20 @@ const StringUtil::EnumStringLiteral *GetStatisticsTypeValues() { { static_cast(StatisticsType::LIST_STATS), "LIST_STATS" }, { static_cast(StatisticsType::STRUCT_STATS), "STRUCT_STATS" }, { static_cast(StatisticsType::BASE_STATS), "BASE_STATS" }, - { static_cast(StatisticsType::ARRAY_STATS), "ARRAY_STATS" } + { static_cast(StatisticsType::ARRAY_STATS), "ARRAY_STATS" }, + { static_cast(StatisticsType::GEOMETRY_STATS), "GEOMETRY_STATS" } }; return values; } template<> const char* EnumUtil::ToChars(StatisticsType value) { - return StringUtil::EnumToString(GetStatisticsTypeValues(), 6, "StatisticsType", static_cast(value)); + return StringUtil::EnumToString(GetStatisticsTypeValues(), 7, "StatisticsType", static_cast(value)); } template<> StatisticsType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetStatisticsTypeValues(), 6, "StatisticsType", value)); + return static_cast(StringUtil::StringToEnum(GetStatisticsTypeValues(), 7, "StatisticsType", value)); } const StringUtil::EnumStringLiteral *GetStatsInfoValues() { @@ -4806,6 +4960,7 @@ const StringUtil::EnumStringLiteral *GetVariantLogicalTypeValues() { { static_cast(VariantLogicalType::ARRAY), "ARRAY" }, { static_cast(VariantLogicalType::BIGNUM), "BIGNUM" }, { static_cast(VariantLogicalType::BITSTRING), "BITSTRING" }, + { static_cast(VariantLogicalType::GEOMETRY), "GEOMETRY" }, { static_cast(VariantLogicalType::ENUM_SIZE), "ENUM_SIZE" } }; return values; @@ -4813,12 +4968,12 @@ const StringUtil::EnumStringLiteral *GetVariantLogicalTypeValues() { template<> const char* EnumUtil::ToChars(VariantLogicalType value) { - return StringUtil::EnumToString(GetVariantLogicalTypeValues(), 34, "VariantLogicalType", static_cast(value)); + return StringUtil::EnumToString(GetVariantLogicalTypeValues(), 35, "VariantLogicalType", static_cast(value)); } template<> VariantLogicalType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetVariantLogicalTypeValues(), 34, "VariantLogicalType", value)); + return static_cast(StringUtil::StringToEnum(GetVariantLogicalTypeValues(), 35, "VariantLogicalType", value)); } const StringUtil::EnumStringLiteral *GetVectorAuxiliaryDataTypeValues() { @@ -4931,6 +5086,26 @@ VerifyExistenceType EnumUtil::FromString(const char *value) return static_cast(StringUtil::StringToEnum(GetVerifyExistenceTypeValues(), 3, "VerifyExistenceType", value)); } +const StringUtil::EnumStringLiteral *GetVertexTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(VertexType::XY), "XY" }, + { static_cast(VertexType::XYZ), "XYZ" }, + { static_cast(VertexType::XYM), "XYM" }, + { static_cast(VertexType::XYZM), "XYZM" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(VertexType value) { + return StringUtil::EnumToString(GetVertexTypeValues(), 4, "VertexType", static_cast(value)); +} + +template<> +VertexType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetVertexTypeValues(), 4, "VertexType", value)); +} + const StringUtil::EnumStringLiteral *GetWALTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(WALType::INVALID), "INVALID" }, diff --git a/src/duckdb/src/common/enums/compression_type.cpp b/src/duckdb/src/common/enums/compression_type.cpp index ec551eff1..427cfbe91 
100644 --- a/src/duckdb/src/common/enums/compression_type.cpp +++ b/src/duckdb/src/common/enums/compression_type.cpp @@ -17,25 +17,60 @@ vector ListCompressionTypes(void) { return compression_types; } -bool CompressionTypeIsDeprecated(CompressionType compression_type, optional_ptr storage_manager) { - vector types({CompressionType::COMPRESSION_PATAS, CompressionType::COMPRESSION_CHIMP}); - if (storage_manager) { - if (storage_manager->GetStorageVersion() >= 5) { - //! NOTE: storage_manager is an optional_ptr because it's called from ForceCompressionSetting, which doesn't - //! have guaranteed access to a StorageManager The introduction of DICT_FSST deprecates Dictionary and FSST - //! compression methods - types.emplace_back(CompressionType::COMPRESSION_DICTIONARY); - types.emplace_back(CompressionType::COMPRESSION_FSST); - } else { - types.emplace_back(CompressionType::COMPRESSION_DICT_FSST); - } +namespace { +struct CompressionMethodRequirements { + CompressionType type; + optional_idx minimum_storage_version; + optional_idx maximum_storage_version; +}; +} // namespace + +CompressionAvailabilityResult CompressionTypeIsAvailable(CompressionType compression_type, + optional_ptr storage_manager) { + //! Max storage compatibility + vector candidates({{CompressionType::COMPRESSION_PATAS, optional_idx(), 0}, + {CompressionType::COMPRESSION_CHIMP, optional_idx(), 0}, + {CompressionType::COMPRESSION_DICTIONARY, 0, 4}, + {CompressionType::COMPRESSION_FSST, 0, 4}, + {CompressionType::COMPRESSION_DICT_FSST, 5, optional_idx()}}); + + optional_idx current_storage_version; + if (storage_manager && storage_manager->HasStorageVersion()) { + current_storage_version = storage_manager->GetStorageVersion(); } - for (auto &type : types) { - if (type == compression_type) { - return true; + for (auto &candidate : candidates) { + auto &type = candidate.type; + if (type != compression_type) { + continue; + } + auto &min = candidate.minimum_storage_version; + auto &max = candidate.maximum_storage_version; + + if (!min.IsValid()) { + //! Used to signal: always deprecated + return CompressionAvailabilityResult::Deprecated(); + } + + if (!current_storage_version.IsValid()) { + //! Can't determine in this call whether it's available or not, default to available + return CompressionAvailabilityResult(); + } + + auto current_version = current_storage_version.GetIndex(); + D_ASSERT(min.IsValid()); + if (min.GetIndex() > current_version) { + //! Minimum required storage version is higher than the current storage version, this method isn't available + //! yet + return CompressionAvailabilityResult::NotAvailableYet(); + } + if (max.IsValid() && max.GetIndex() < current_version) { + //! Maximum supported storage version is lower than the current storage version, this method is no longer + //! 
available + return CompressionAvailabilityResult::Deprecated(); } + return CompressionAvailabilityResult(); } - return false; + return CompressionAvailabilityResult(); } CompressionType CompressionTypeFromString(const string &str) { diff --git a/src/duckdb/src/common/enums/metric_type.cpp b/src/duckdb/src/common/enums/metric_type.cpp index 866049251..84b552037 100644 --- a/src/duckdb/src/common/enums/metric_type.cpp +++ b/src/duckdb/src/common/enums/metric_type.cpp @@ -31,6 +31,7 @@ profiler_settings_t MetricsUtils::GetOptimizerMetrics() { MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE, MetricsType::OPTIMIZER_LIMIT_PUSHDOWN, MetricsType::OPTIMIZER_TOP_N, + MetricsType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION, MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION, MetricsType::OPTIMIZER_DUPLICATE_GROUPS, MetricsType::OPTIMIZER_REORDER_FILTER, @@ -41,6 +42,7 @@ profiler_settings_t MetricsUtils::GetOptimizerMetrics() { MetricsType::OPTIMIZER_SUM_REWRITER, MetricsType::OPTIMIZER_LATE_MATERIALIZATION, MetricsType::OPTIMIZER_CTE_INLINING, + MetricsType::OPTIMIZER_COMMON_SUBPLAN, }; } @@ -48,12 +50,12 @@ profiler_settings_t MetricsUtils::GetPhaseTimingMetrics() { return { MetricsType::ALL_OPTIMIZERS, MetricsType::CUMULATIVE_OPTIMIZER_TIMING, - MetricsType::PLANNER, - MetricsType::PLANNER_BINDING, MetricsType::PHYSICAL_PLANNER, MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING, - MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES, MetricsType::PHYSICAL_PLANNER_CREATE_PLAN, + MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES, + MetricsType::PLANNER, + MetricsType::PLANNER_BINDING, }; } @@ -95,6 +97,8 @@ MetricsType MetricsUtils::GetOptimizerMetricByType(OptimizerType type) { return MetricsType::OPTIMIZER_LIMIT_PUSHDOWN; case OptimizerType::TOP_N: return MetricsType::OPTIMIZER_TOP_N; + case OptimizerType::TOP_N_WINDOW_ELIMINATION: + return MetricsType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION; case OptimizerType::COMPRESSED_MATERIALIZATION: return MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION; case OptimizerType::DUPLICATE_GROUPS: @@ -115,6 +119,8 @@ MetricsType MetricsUtils::GetOptimizerMetricByType(OptimizerType type) { return MetricsType::OPTIMIZER_LATE_MATERIALIZATION; case OptimizerType::CTE_INLINING: return MetricsType::OPTIMIZER_CTE_INLINING; + case OptimizerType::COMMON_SUBPLAN: + return MetricsType::OPTIMIZER_COMMON_SUBPLAN; default: throw InternalException("OptimizerType %s cannot be converted to a MetricsType", EnumUtil::ToString(type)); }; @@ -158,6 +164,8 @@ OptimizerType MetricsUtils::GetOptimizerTypeByMetric(MetricsType type) { return OptimizerType::LIMIT_PUSHDOWN; case MetricsType::OPTIMIZER_TOP_N: return OptimizerType::TOP_N; + case MetricsType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION: + return OptimizerType::TOP_N_WINDOW_ELIMINATION; case MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION: return OptimizerType::COMPRESSED_MATERIALIZATION; case MetricsType::OPTIMIZER_DUPLICATE_GROUPS: @@ -178,6 +186,8 @@ OptimizerType MetricsUtils::GetOptimizerTypeByMetric(MetricsType type) { return OptimizerType::LATE_MATERIALIZATION; case MetricsType::OPTIMIZER_CTE_INLINING: return OptimizerType::CTE_INLINING; + case MetricsType::OPTIMIZER_COMMON_SUBPLAN: + return OptimizerType::COMMON_SUBPLAN; default: return OptimizerType::INVALID; }; @@ -203,6 +213,7 @@ bool MetricsUtils::IsOptimizerMetric(MetricsType type) { case MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE: case MetricsType::OPTIMIZER_LIMIT_PUSHDOWN: case MetricsType::OPTIMIZER_TOP_N: + case MetricsType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION: case 
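The rewritten CompressionTypeIsAvailable replaces the old flat deprecation list with a per-method storage-version window: PATAS and CHIMP are always deprecated, DICTIONARY and FSST are only valid up to storage version 4, DICT_FSST requires version 5 or newer, and when no storage version can be determined the check defaults to "available". A simplified sketch of that range check, using std::optional and hypothetical names in place of DuckDB's optional_idx and CompressionAvailabilityResult:

#include <cstdint>
#include <optional>

enum class Availability { AVAILABLE, DEPRECATED, NOT_AVAILABLE_YET };

// Hypothetical version window per compression method: an empty minimum means
// "always deprecated" (PATAS/CHIMP); an empty maximum means "no upper bound".
struct VersionWindow {
	std::optional<uint64_t> min_version;
	std::optional<uint64_t> max_version;
};

Availability CheckAvailability(const VersionWindow &window, std::optional<uint64_t> current_version) {
	if (!window.min_version) {
		return Availability::DEPRECATED;
	}
	if (!current_version) {
		// storage version unknown at this call site: default to available
		return Availability::AVAILABLE;
	}
	if (*window.min_version > *current_version) {
		// e.g. DICT_FSST before storage version 5
		return Availability::NOT_AVAILABLE_YET;
	}
	if (window.max_version && *window.max_version < *current_version) {
		// e.g. DICTIONARY/FSST once storage version 5 is in use
		return Availability::DEPRECATED;
	}
	return Availability::AVAILABLE;
}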
MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION: case MetricsType::OPTIMIZER_DUPLICATE_GROUPS: case MetricsType::OPTIMIZER_REORDER_FILTER: @@ -213,6 +224,7 @@ bool MetricsUtils::IsOptimizerMetric(MetricsType type) { case MetricsType::OPTIMIZER_SUM_REWRITER: case MetricsType::OPTIMIZER_LATE_MATERIALIZATION: case MetricsType::OPTIMIZER_CTE_INLINING: + case MetricsType::OPTIMIZER_COMMON_SUBPLAN: return true; default: return false; @@ -223,12 +235,12 @@ bool MetricsUtils::IsPhaseTimingMetric(MetricsType type) { switch(type) { case MetricsType::ALL_OPTIMIZERS: case MetricsType::CUMULATIVE_OPTIMIZER_TIMING: - case MetricsType::PLANNER: - case MetricsType::PLANNER_BINDING: case MetricsType::PHYSICAL_PLANNER: case MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING: - case MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES: case MetricsType::PHYSICAL_PLANNER_CREATE_PLAN: + case MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES: + case MetricsType::PLANNER: + case MetricsType::PLANNER_BINDING: return true; default: return false; @@ -237,9 +249,13 @@ bool MetricsUtils::IsPhaseTimingMetric(MetricsType type) { bool MetricsUtils::IsQueryGlobalMetric(MetricsType type) { switch(type) { + case MetricsType::ATTACH_LOAD_STORAGE_LATENCY: + case MetricsType::ATTACH_REPLAY_WAL_LATENCY: case MetricsType::BLOCKED_THREAD_TIME: + case MetricsType::CHECKPOINT_LATENCY: case MetricsType::SYSTEM_PEAK_BUFFER_MEMORY: case MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE: + case MetricsType::WAITING_TO_ATTACH_LATENCY: return true; default: return false; diff --git a/src/duckdb/src/common/enums/optimizer_type.cpp b/src/duckdb/src/common/enums/optimizer_type.cpp index b0d669500..c7441a0fa 100644 --- a/src/duckdb/src/common/enums/optimizer_type.cpp +++ b/src/duckdb/src/common/enums/optimizer_type.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/exception/parser_exception.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/optimizer/optimizer.hpp" namespace duckdb { @@ -29,6 +30,7 @@ static const DefaultOptimizerType internal_optimizer_types[] = { {"column_lifetime", OptimizerType::COLUMN_LIFETIME}, {"limit_pushdown", OptimizerType::LIMIT_PUSHDOWN}, {"top_n", OptimizerType::TOP_N}, + {"top_n_window_elimination", OptimizerType::TOP_N_WINDOW_ELIMINATION}, {"build_side_probe_side", OptimizerType::BUILD_SIDE_PROBE_SIDE}, {"compressed_materialization", OptimizerType::COMPRESSED_MATERIALIZATION}, {"duplicate_groups", OptimizerType::DUPLICATE_GROUPS}, @@ -40,6 +42,7 @@ static const DefaultOptimizerType internal_optimizer_types[] = { {"sum_rewriter", OptimizerType::SUM_REWRITER}, {"late_materialization", OptimizerType::LATE_MATERIALIZATION}, {"cte_inlining", OptimizerType::CTE_INLINING}, + {"common_subplan", OptimizerType::COMMON_SUBPLAN}, {nullptr, OptimizerType::INVALID}}; string OptimizerTypeToString(OptimizerType type) { diff --git a/src/duckdb/src/common/error_data.cpp b/src/duckdb/src/common/error_data.cpp index 2ddf94af6..f70620d33 100644 --- a/src/duckdb/src/common/error_data.cpp +++ b/src/duckdb/src/common/error_data.cpp @@ -24,7 +24,6 @@ ErrorData::ErrorData(ExceptionType type, const string &message) ErrorData::ErrorData(const string &message) : initialized(true), type(ExceptionType::INVALID), raw_message(string()), final_message(string()) { - // parse the constructed JSON if (message.empty() || message[0] != '{') { // not JSON! 
Use the message as a raw Exception message and leave type as uninitialized @@ -80,9 +79,9 @@ void ErrorData::Throw(const string &prepended_message) const { D_ASSERT(initialized); if (!prepended_message.empty()) { string new_message = prepended_message + raw_message; - throw Exception(type, new_message, extra_info); + throw Exception(extra_info, type, new_message); } else { - throw Exception(type, raw_message, extra_info); + throw Exception(extra_info, type, raw_message); } } diff --git a/src/duckdb/src/common/exception.cpp b/src/duckdb/src/common/exception.cpp index 2012c1fcc..d9bd31049 100644 --- a/src/duckdb/src/common/exception.cpp +++ b/src/duckdb/src/common/exception.cpp @@ -19,17 +19,17 @@ Exception::Exception(ExceptionType exception_type, const string &message) : std::runtime_error(ToJSON(exception_type, message)) { } -Exception::Exception(ExceptionType exception_type, const string &message, - const unordered_map &extra_info) - : std::runtime_error(ToJSON(exception_type, message, extra_info)) { +Exception::Exception(const unordered_map &extra_info, ExceptionType exception_type, + const string &message) + : std::runtime_error(ToJSON(extra_info, exception_type, message)) { } string Exception::ToJSON(ExceptionType type, const string &message) { unordered_map extra_info; - return ToJSON(type, message, extra_info); + return ToJSON(extra_info, type, message); } -string Exception::ToJSON(ExceptionType type, const string &message, const unordered_map &extra_info) { +string Exception::ToJSON(const unordered_map &extra_info, ExceptionType type, const string &message) { #ifndef DUCKDB_DEBUG_STACKTRACE // by default we only enable stack traces for internal exceptions if (type == ExceptionType::INTERNAL || type == ExceptionType::FATAL) @@ -240,9 +240,8 @@ TypeMismatchException::TypeMismatchException(const LogicalType &type_1, const Lo TypeMismatchException::TypeMismatchException(optional_idx error_location, const LogicalType &type_1, const LogicalType &type_2, const string &msg) - : Exception(ExceptionType::MISMATCH_TYPE, - "Type " + type_1.ToString() + " does not match with " + type_2.ToString() + ". " + msg, - Exception::InitializeExtraInfo(error_location)) { + : Exception(Exception::InitializeExtraInfo(error_location), ExceptionType::MISMATCH_TYPE, + "Type " + type_1.ToString() + " does not match with " + type_2.ToString() + ". 
" + msg) { } TypeMismatchException::TypeMismatchException(const string &msg) : Exception(ExceptionType::MISMATCH_TYPE, msg) { @@ -306,8 +305,12 @@ DependencyException::DependencyException(const string &msg) : Exception(Exceptio IOException::IOException(const string &msg) : Exception(ExceptionType::IO, msg) { } -IOException::IOException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::IO, msg, extra_info) { +IOException::IOException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::IO, msg) { +} + +NotImplementedException::NotImplementedException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::NOT_IMPLEMENTED, msg) { } MissingExtensionException::MissingExtensionException(const string &msg) @@ -339,20 +342,24 @@ InternalException::InternalException(const string &msg) : Exception(ExceptionTyp #endif } +InternalException::InternalException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::INTERNAL, msg) { +} + InvalidInputException::InvalidInputException(const string &msg) : Exception(ExceptionType::INVALID_INPUT, msg) { } -InvalidInputException::InvalidInputException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::INVALID_INPUT, msg, extra_info) { +InvalidInputException::InvalidInputException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::INVALID_INPUT, msg) { } InvalidConfigurationException::InvalidConfigurationException(const string &msg) : Exception(ExceptionType::INVALID_CONFIGURATION, msg) { } -InvalidConfigurationException::InvalidConfigurationException(const string &msg, - const unordered_map &extra_info) - : Exception(ExceptionType::INVALID_CONFIGURATION, msg, extra_info) { +InvalidConfigurationException::InvalidConfigurationException(const unordered_map &extra_info, + const string &msg) + : Exception(extra_info, ExceptionType::INVALID_CONFIGURATION, msg) { } OutOfMemoryException::OutOfMemoryException(const string &msg) diff --git a/src/duckdb/src/common/exception/binder_exception.cpp b/src/duckdb/src/common/exception/binder_exception.cpp index 62dca06fb..aa9a9459e 100644 --- a/src/duckdb/src/common/exception/binder_exception.cpp +++ b/src/duckdb/src/common/exception/binder_exception.cpp @@ -7,8 +7,8 @@ namespace duckdb { BinderException::BinderException(const string &msg) : Exception(ExceptionType::BINDER, msg) { } -BinderException::BinderException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::BINDER, msg, extra_info) { +BinderException::BinderException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::BINDER, msg) { } BinderException BinderException::ColumnNotFound(const string &name, const vector &similar_bindings, @@ -18,9 +18,13 @@ BinderException BinderException::ColumnNotFound(const string &name, const vector extra_info["name"] = name; if (!similar_bindings.empty()) { extra_info["candidates"] = StringUtil::Join(similar_bindings, ","); + return BinderException(extra_info, StringUtil::Format("Referenced column \"%s\" not found in FROM clause!%s", + name, candidate_str)); + } else { + return BinderException( + extra_info, + StringUtil::Format("Referenced column \"%s\" was not found because the FROM clause is missing", name)); } - return BinderException( - StringUtil::Format("Referenced column \"%s\" not found in FROM clause!%s", name, candidate_str), extra_info); } 
BinderException BinderException::NoMatchingFunction(const string &catalog_name, const string &schema_name, @@ -45,15 +49,14 @@ BinderException BinderException::NoMatchingFunction(const string &catalog_name, extra_info["candidates"] = StringUtil::Join(candidates, ","); } return BinderException( + extra_info, StringUtil::Format("No function matches the given name and argument types '%s'. You might need to add " "explicit type casts.\n\tCandidate functions:\n%s", - call_str, candidate_str), - extra_info); + call_str, candidate_str)); } BinderException BinderException::Unsupported(ParsedExpression &expr, const string &message) { auto extra_info = Exception::InitializeExtraInfo("UNSUPPORTED", expr.GetQueryLocation()); - return BinderException(message, extra_info); + return BinderException(extra_info, message); } - } // namespace duckdb diff --git a/src/duckdb/src/common/exception/catalog_exception.cpp b/src/duckdb/src/common/exception/catalog_exception.cpp index 5d890f1cd..b1cd4caf7 100644 --- a/src/duckdb/src/common/exception/catalog_exception.cpp +++ b/src/duckdb/src/common/exception/catalog_exception.cpp @@ -9,8 +9,8 @@ namespace duckdb { CatalogException::CatalogException(const string &msg) : Exception(ExceptionType::CATALOG, msg) { } -CatalogException::CatalogException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::CATALOG, msg, extra_info) { +CatalogException::CatalogException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::CATALOG, msg) { } CatalogException CatalogException::MissingEntry(const EntryLookupInfo &lookup_info, const string &suggestion) { @@ -35,9 +35,9 @@ CatalogException CatalogException::MissingEntry(const EntryLookupInfo &lookup_in if (!suggestion.empty()) { extra_info["candidates"] = suggestion; } - return CatalogException(StringUtil::Format("%s with name %s does not exist%s!%s", CatalogTypeToString(type), name, - version_info, did_you_mean), - extra_info); + return CatalogException(extra_info, + StringUtil::Format("%s with name %s does not exist%s!%s", CatalogTypeToString(type), name, + version_info, did_you_mean)); } CatalogException CatalogException::MissingEntry(CatalogType type, const string &name, const string &suggestion, @@ -55,17 +55,17 @@ CatalogException CatalogException::MissingEntry(const string &type, const string if (!suggestions.empty()) { extra_info["candidates"] = StringUtil::Join(suggestions, ", "); } - return CatalogException(StringUtil::Format("unrecognized %s \"%s\"\n%s", type, name, - StringUtil::CandidatesErrorMessage(suggestions, name, "Did you mean")), - extra_info); + return CatalogException(extra_info, + StringUtil::Format("unrecognized %s \"%s\"\n%s", type, name, + StringUtil::CandidatesErrorMessage(suggestions, name, "Did you mean"))); } CatalogException CatalogException::EntryAlreadyExists(CatalogType type, const string &name, QueryErrorContext context) { auto extra_info = Exception::InitializeExtraInfo("ENTRY_ALREADY_EXISTS", optional_idx()); extra_info["name"] = name; extra_info["type"] = CatalogTypeToString(type); - return CatalogException(StringUtil::Format("%s with name \"%s\" already exists!", CatalogTypeToString(type), name), - extra_info); + return CatalogException(extra_info, + StringUtil::Format("%s with name \"%s\" already exists!", CatalogTypeToString(type), name)); } } // namespace duckdb diff --git a/src/duckdb/src/common/exception/conversion_exception.cpp b/src/duckdb/src/common/exception/conversion_exception.cpp index 013dbdb9e..bf021b4eb 100644 
--- a/src/duckdb/src/common/exception/conversion_exception.cpp +++ b/src/duckdb/src/common/exception/conversion_exception.cpp @@ -17,7 +17,7 @@ ConversionException::ConversionException(const string &msg) : Exception(Exceptio } ConversionException::ConversionException(optional_idx error_location, const string &msg) - : Exception(ExceptionType::CONVERSION, msg, Exception::InitializeExtraInfo(error_location)) { + : Exception(Exception::InitializeExtraInfo(error_location), ExceptionType::CONVERSION, msg) { } } // namespace duckdb diff --git a/src/duckdb/src/common/exception/parser_exception.cpp b/src/duckdb/src/common/exception/parser_exception.cpp index f3875da38..3afb2ea3d 100644 --- a/src/duckdb/src/common/exception/parser_exception.cpp +++ b/src/duckdb/src/common/exception/parser_exception.cpp @@ -7,13 +7,12 @@ namespace duckdb { ParserException::ParserException(const string &msg) : Exception(ExceptionType::PARSER, msg) { } -ParserException::ParserException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::PARSER, msg, extra_info) { +ParserException::ParserException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::PARSER, msg) { } ParserException ParserException::SyntaxError(const string &query, const string &error_message, optional_idx error_location) { - return ParserException(error_message, Exception::InitializeExtraInfo("SYNTAX_ERROR", error_location)); + return ParserException(Exception::InitializeExtraInfo("SYNTAX_ERROR", error_location), error_message); } - } // namespace duckdb diff --git a/src/duckdb/src/common/exception_format_value.cpp b/src/duckdb/src/common/exception_format_value.cpp index 51e34ec0e..27b4eb465 100644 --- a/src/duckdb/src/common/exception_format_value.cpp +++ b/src/duckdb/src/common/exception_format_value.cpp @@ -28,65 +28,61 @@ ExceptionFormatValue::ExceptionFormatValue(uhugeint_t uhuge_val) ExceptionFormatValue::ExceptionFormatValue(string str_val) : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING), str_val(std::move(str_val)) { } -ExceptionFormatValue::ExceptionFormatValue(String str_val) - : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING), str_val(str_val.ToStdString()) { +ExceptionFormatValue::ExceptionFormatValue(const String &str_val) : ExceptionFormatValue(str_val.ToStdString()) { } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(PhysicalType value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const PhysicalType &value) { return ExceptionFormatValue(TypeIdToString(value)); } template <> -ExceptionFormatValue -ExceptionFormatValue::CreateFormatValue(LogicalType value) { // NOLINT: templating requires us to copy value here +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const LogicalType &value) { return ExceptionFormatValue(value.ToString()); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(float value) { - return ExceptionFormatValue(double(value)); +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const float &value) { + return ExceptionFormatValue(static_cast(value)); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(double value) { - return ExceptionFormatValue(double(value)); +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const double &value) { + return ExceptionFormatValue(value); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(string value) { - return ExceptionFormatValue(std::move(value)); 
+ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const string &value) { + return ExceptionFormatValue(value); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(String value) { - return ExceptionFormatValue(std::move(value)); +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const String &value) { + return ExceptionFormatValue(value); } template <> -ExceptionFormatValue -ExceptionFormatValue::CreateFormatValue(SQLString value) { // NOLINT: templating requires us to copy value here +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const SQLString &value) { return KeywordHelper::WriteQuoted(value.raw_string, '\''); } template <> -ExceptionFormatValue -ExceptionFormatValue::CreateFormatValue(SQLIdentifier value) { // NOLINT: templating requires us to copy value here +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const SQLIdentifier &value) { return KeywordHelper::WriteOptionallyQuoted(value.raw_string, '"'); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *const &value) { return ExceptionFormatValue(string(value)); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *const &value) { return ExceptionFormatValue(string(value)); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(idx_t value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const idx_t &value) { return ExceptionFormatValue(value); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(hugeint_t value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const hugeint_t &value) { return ExceptionFormatValue(value); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(uhugeint_t value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const uhugeint_t &value) { return ExceptionFormatValue(value); } diff --git a/src/duckdb/src/common/extra_type_info.cpp b/src/duckdb/src/common/extra_type_info.cpp index 1d3160814..6218f3e7b 100644 --- a/src/duckdb/src/common/extra_type_info.cpp +++ b/src/duckdb/src/common/extra_type_info.cpp @@ -507,4 +507,19 @@ shared_ptr TemplateTypeInfo::Copy() const { return make_shared_ptr(*this); } +//===--------------------------------------------------------------------===// +// Geo Type Info +//===--------------------------------------------------------------------===// +GeoTypeInfo::GeoTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::GEO_TYPE_INFO) { +} + +bool GeoTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { + // No additional info to compare + return true; +} + +shared_ptr GeoTypeInfo::Copy() const { + return make_shared_ptr(*this); +} + } // namespace duckdb diff --git a/src/duckdb/src/common/file_system.cpp b/src/duckdb/src/common/file_system.cpp index 926cfb6a0..5a712aff0 100644 --- a/src/duckdb/src/common/file_system.cpp +++ b/src/duckdb/src/common/file_system.cpp @@ -30,10 +30,7 @@ #include #ifdef __MVS__ -#define _XOPEN_SOURCE_EXTENDED 1 #include -// enjoy - https://reviews.llvm.org/D92110 -#define PATH_MAX _XOPEN_PATH_MAX #endif #else @@ -628,39 +625,9 @@ bool FileSystem::CanHandleFile(const string &fpath) { throw NotImplementedException("%s: CanHandleFile is not implemented!", GetName()); } -static string LookupExtensionForPattern(const string &pattern) { - for (const auto &entry : 
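The ExceptionFormatValue::CreateFormatValue specializations above switch from by-value parameters to const references, so building a format value no longer copies strings, LogicalType instances, and similar arguments just to read them; float still widens to double, and the char-pointer overloads become references to the pointer itself (const char *const &). A reduced sketch of the same specialization pattern on a hypothetical FormatValue wrapper:

#include <string>
#include <utility>

// Hypothetical format-value wrapper; the interesting part is the specialization signatures.
struct FormatValue {
	explicit FormatValue(double d) : dbl_val(d), is_string(false) {
	}
	explicit FormatValue(std::string s) : str_val(std::move(s)), is_string(true) {
	}
	double dbl_val = 0;
	std::string str_val;
	bool is_string;
};

template <class T>
FormatValue CreateFormatValue(const T &value); // primary template, specialized per type

template <>
FormatValue CreateFormatValue(const float &value) {
	return FormatValue(static_cast<double>(value)); // widen float to double, as in the diff
}

template <>
FormatValue CreateFormatValue(const std::string &value) {
	return FormatValue(value); // bound by reference; only the wrapper itself copies
}

template <>
FormatValue CreateFormatValue(const char *const &value) {
	// the parameter is a reference to the pointer, hence const char *const &
	return FormatValue(std::string(value));
}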
EXTENSION_FILE_PREFIXES) { - if (StringUtil::StartsWith(pattern, entry.name)) { - return entry.extension; - } - } - return ""; -} - vector FileSystem::GlobFiles(const string &pattern, ClientContext &context, const FileGlobInput &input) { auto result = Glob(pattern); if (result.empty()) { - string required_extension = LookupExtensionForPattern(pattern); - if (!required_extension.empty() && !context.db->ExtensionIsLoaded(required_extension)) { - auto &dbconfig = DBConfig::GetConfig(context); - if (!ExtensionHelper::CanAutoloadExtension(required_extension) || - !dbconfig.options.autoload_known_extensions) { - auto error_message = - "File " + pattern + " requires the extension " + required_extension + " to be loaded"; - error_message = - ExtensionHelper::AddExtensionInstallHintToErrorMsg(context, error_message, required_extension); - throw MissingExtensionException(error_message); - } - // an extension is required to read this file, but it is not loaded - try to load it - ExtensionHelper::AutoLoadExtension(context, required_extension); - // success! glob again - // check the extension is loaded just in case to prevent an infinite loop here - if (!context.db->ExtensionIsLoaded(required_extension)) { - throw InternalException("Extension load \"%s\" did not throw but somehow the extension was not loaded", - required_extension); - } - return GlobFiles(pattern, context, input); - } if (input.behavior == FileGlobOptions::FALLBACK_GLOB && !HasGlob(pattern)) { // if we have no glob in the pattern and we have an extension, we try to glob if (!HasGlob(pattern)) { @@ -724,7 +691,7 @@ int64_t FileHandle::Read(void *buffer, idx_t nr_bytes) { int64_t FileHandle::Read(QueryContext context, void *buffer, idx_t nr_bytes) { if (context.GetClientContext() != nullptr) { - context.GetClientContext()->client_data->profiler->AddBytesRead(nr_bytes); + context.GetClientContext()->client_data->profiler->AddToCounter(MetricsType::TOTAL_BYTES_READ, nr_bytes); } return file_system.Read(*this, buffer, UnsafeNumericCast(nr_bytes)); @@ -744,7 +711,7 @@ void FileHandle::Read(void *buffer, idx_t nr_bytes, idx_t location) { void FileHandle::Read(QueryContext context, void *buffer, idx_t nr_bytes, idx_t location) { if (context.GetClientContext() != nullptr) { - context.GetClientContext()->client_data->profiler->AddBytesRead(nr_bytes); + context.GetClientContext()->client_data->profiler->AddToCounter(MetricsType::TOTAL_BYTES_READ, nr_bytes); } file_system.Read(*this, buffer, UnsafeNumericCast(nr_bytes), location); @@ -752,7 +719,7 @@ void FileHandle::Read(QueryContext context, void *buffer, idx_t nr_bytes, idx_t void FileHandle::Write(QueryContext context, void *buffer, idx_t nr_bytes, idx_t location) { if (context.GetClientContext() != nullptr) { - context.GetClientContext()->client_data->profiler->AddBytesWritten(nr_bytes); + context.GetClientContext()->client_data->profiler->AddToCounter(MetricsType::TOTAL_BYTES_WRITTEN, nr_bytes); } file_system.Write(*this, buffer, UnsafeNumericCast(nr_bytes), location); diff --git a/src/duckdb/src/common/hive_partitioning.cpp b/src/duckdb/src/common/hive_partitioning.cpp index 932943b8f..78f3b40e8 100644 --- a/src/duckdb/src/common/hive_partitioning.cpp +++ b/src/duckdb/src/common/hive_partitioning.cpp @@ -153,7 +153,6 @@ void HivePartitioning::ApplyFiltersToFileList(ClientContext &context, vector> &filters, const HivePartitioningFilterInfo &filter_info, MultiFilePushdownInfo &info) { - vector pruned_files; vector have_preserved_filter(filters.size(), false); vector> pruned_filters; diff 
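FileHandle::Read and FileHandle::Write now report their byte counts through a single generic profiler entry point, AddToCounter(MetricsType::TOTAL_BYTES_READ / TOTAL_BYTES_WRITTEN, n), instead of dedicated AddBytesRead/AddBytesWritten methods. A generic sketch of that kind of counter sink, using a plain map keyed by metric name rather than DuckDB's QueryProfiler API:

#include <cstdint>
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical counter sink: one generic entry point instead of one method per metric.
class CounterProfiler {
public:
	void AddToCounter(const std::string &metric, uint64_t amount) {
		std::lock_guard<std::mutex> guard(lock);
		counters[metric] += amount;
	}
	uint64_t Get(const std::string &metric) const {
		std::lock_guard<std::mutex> guard(lock);
		auto entry = counters.find(metric);
		return entry == counters.end() ? 0 : entry->second;
	}

private:
	mutable std::mutex lock;
	std::unordered_map<std::string, uint64_t> counters;
};

// A read wrapper would then do something like
//   profiler.AddToCounter("TOTAL_BYTES_READ", nr_bytes);
// before delegating to the underlying file system call.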
--git a/src/duckdb/src/common/local_file_system.cpp b/src/duckdb/src/common/local_file_system.cpp index 8733e0162..7433fe4cb 100644 --- a/src/duckdb/src/common/local_file_system.cpp +++ b/src/duckdb/src/common/local_file_system.cpp @@ -369,7 +369,7 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF if (flags.ReturnNullIfExists() && errno == EEXIST) { return nullptr; } - throw IOException("Cannot open file \"%s\": %s", {{"errno", std::to_string(errno)}}, path, strerror(errno)); + throw IOException({{"errno", std::to_string(errno)}}, "Cannot open file \"%s\": %s", path, strerror(errno)); } #if defined(__DARWIN__) || defined(__APPLE__) @@ -436,7 +436,7 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF extended_error += ". Also, failed closing file"; } extended_error += ". See also https://duckdb.org/docs/stable/connect/concurrency"; - throw IOException("Could not set lock on file \"%s\": %s", {{"errno", std::to_string(retained_errno)}}, + throw IOException({{"errno", std::to_string(retained_errno)}}, "Could not set lock on file \"%s\": %s", path, extended_error); } } @@ -454,7 +454,7 @@ void LocalFileSystem::SetFilePointer(FileHandle &handle, idx_t location) { int fd = handle.Cast().fd; off_t offset = lseek(fd, UnsafeNumericCast(location), SEEK_SET); if (offset == (off_t)-1) { - throw IOException("Could not seek to location %lld for file \"%s\": %s", {{"errno", std::to_string(errno)}}, + throw IOException({{"errno", std::to_string(errno)}}, "Could not seek to location %lld for file \"%s\": %s", location, handle.path, strerror(errno)); } } @@ -463,7 +463,7 @@ idx_t LocalFileSystem::GetFilePointer(FileHandle &handle) { int fd = handle.Cast().fd; off_t position = lseek(fd, 0, SEEK_CUR); if (position == (off_t)-1) { - throw IOException("Could not get file position file \"%s\": %s", {{"errno", std::to_string(errno)}}, + throw IOException({{"errno", std::to_string(errno)}}, "Could not get file position file \"%s\": %s", handle.path, strerror(errno)); } return UnsafeNumericCast(position); @@ -477,7 +477,7 @@ void LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, i int64_t bytes_read = pread(fd, read_buffer, UnsafeNumericCast(nr_bytes), UnsafeNumericCast(location)); if (bytes_read == -1) { - throw IOException("Could not read from file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not read from file \"%s\": %s", handle.path, strerror(errno)); } if (bytes_read == 0) { @@ -498,7 +498,7 @@ int64_t LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes int fd = unix_handle.fd; int64_t bytes_read = read(fd, buffer, UnsafeNumericCast(nr_bytes)); if (bytes_read == -1) { - throw IOException("Could not read from file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not read from file \"%s\": %s", handle.path, strerror(errno)); } @@ -519,12 +519,13 @@ void LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, int64_t bytes_written = pwrite(fd, write_buffer, UnsafeNumericCast(bytes_to_write), UnsafeNumericCast(current_location)); if (bytes_written < 0) { - throw IOException("Could not write file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not write file \"%s\": %s", handle.path, strerror(errno)); } if (bytes_written == 0) { - throw IOException("Could not write to file \"%s\" - 
attempted to write 0 bytes: %s", - {{"errno", std::to_string(errno)}}, handle.path, strerror(errno)); + throw IOException({{"errno", std::to_string(errno)}}, + "Could not write to file \"%s\" - attempted to write 0 bytes: %s", handle.path, + strerror(errno)); } write_buffer += bytes_written; bytes_to_write -= bytes_written; @@ -544,7 +545,7 @@ int64_t LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_byte MinValue(idx_t(NumericLimits::Maximum()), idx_t(bytes_to_write)); int64_t current_bytes_written = write(fd, buffer, bytes_to_write_this_call); if (current_bytes_written <= 0) { - throw IOException("Could not write file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not write file \"%s\": %s", handle.path, strerror(errno)); } buffer = (void *)(data_ptr_cast(buffer) + current_bytes_written); @@ -577,7 +578,7 @@ int64_t LocalFileSystem::GetFileSize(FileHandle &handle) { int fd = handle.Cast().fd; struct stat s; if (fstat(fd, &s) == -1) { - throw IOException("Failed to get file size for file \"%s\": %s", {{"errno", std::to_string(errno)}}, + throw IOException({{"errno", std::to_string(errno)}}, "Failed to get file size for file \"%s\": %s", handle.path, strerror(errno)); } return s.st_size; @@ -587,7 +588,7 @@ timestamp_t LocalFileSystem::GetLastModifiedTime(FileHandle &handle) { int fd = handle.Cast().fd; struct stat s; if (fstat(fd, &s) == -1) { - throw IOException("Failed to get last modified time for file \"%s\": %s", {{"errno", std::to_string(errno)}}, + throw IOException({{"errno", std::to_string(errno)}}, "Failed to get last modified time for file \"%s\": %s", handle.path, strerror(errno)); } return Timestamp::FromEpochSeconds(s.st_mtime); @@ -601,7 +602,7 @@ FileType LocalFileSystem::GetFileType(FileHandle &handle) { void LocalFileSystem::Truncate(FileHandle &handle, int64_t new_size) { int fd = handle.Cast().fd; if (ftruncate(fd, new_size) != 0) { - throw IOException("Could not truncate file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not truncate file \"%s\": %s", handle.path, strerror(errno)); } } @@ -612,7 +613,7 @@ bool LocalFileSystem::DirectoryExists(const string &directory, optional_ptr opener) { auto normalized_file = NormalizeLocalPath(filename); if (std::remove(normalized_file) != 0) { - throw IOException("Could not remove file \"%s\": %s", {{"errno", std::to_string(errno)}}, filename, + throw IOException({{"errno", std::to_string(errno)}}, "Could not remove file \"%s\": %s", filename, strerror(errno)); } } @@ -718,7 +719,7 @@ bool LocalFileSystem::ListFilesExtended(const string &directory, if (res != 0) { continue; } - if (!(status.st_mode & S_IFREG) && !(status.st_mode & S_IFDIR)) { + if (!S_ISREG(status.st_mode) && !S_ISDIR(status.st_mode)) { // not a file or directory: skip continue; } @@ -726,7 +727,7 @@ bool LocalFileSystem::ListFilesExtended(const string &directory, info.extended_info = make_shared_ptr(); auto &options = info.extended_info->options; // file type - Value file_type(status.st_mode & S_IFDIR ? "directory" : "file"); + Value file_type(S_ISDIR(status.st_mode) ? "directory" : "file"); options.emplace("type", std::move(file_type)); // file size options.emplace("file_size", Value::BIGINT(UnsafeNumericCast(status.st_size))); @@ -767,8 +768,7 @@ void LocalFileSystem::FileSync(FileHandle &handle) { } // For other types of errors, throw normal IO exception. 
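The ListFilesExtended hunk above swaps the raw bit tests on st_mode for the S_ISREG/S_ISDIR macros. The file-type bits form a multi-bit field masked by S_IFMT, so testing a single constant with & can match unrelated types, while the macros compare the whole field. A small standalone POSIX illustration, not part of the diff:

#include <cassert>
#include <sys/stat.h>

int main() {
	// On POSIX, the type bits in st_mode form one field masked by S_IFMT; e.g. a socket
	// (S_IFSOCK = 0140000) overlaps both S_IFREG (0100000) and S_IFDIR (0040000).
	mode_t sock_mode = S_IFSOCK | 0644;
	assert((sock_mode & S_IFREG) != 0); // the old (st_mode & S_IFREG) test would treat a socket as a file
	assert((sock_mode & S_IFDIR) != 0); // ...and the old (st_mode & S_IFDIR) test would label it a directory
	assert(!S_ISREG(sock_mode));        // the macros compare the full S_IFMT field
	assert(!S_ISDIR(sock_mode));
	return 0;
}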
- throw IOException("Could not fsync file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.GetPath(), - strerror(errno)); + throw IOException("Could not fsync file \"%s\": %s", handle.GetPath(), strerror(errno)); } void LocalFileSystem::MoveFile(const string &source, const string &target, optional_ptr opener) { @@ -776,7 +776,7 @@ void LocalFileSystem::MoveFile(const string &source, const string &target, optio auto normalized_target = NormalizeLocalPath(target); //! FIXME: rename does not guarantee atomicity or overwriting target file if it exists if (rename(normalized_source, normalized_target) != 0) { - throw IOException("Could not rename file!", {{"errno", std::to_string(errno)}}); + throw IOException({{"errno", std::to_string(errno)}}, "Could not rename file!"); } } @@ -1052,7 +1052,7 @@ static int64_t FSWrite(FileHandle &handle, HANDLE hFile, void *buffer, int64_t n auto bytes_to_write = MinValue(idx_t(NumericLimits::Maximum()), idx_t(nr_bytes)); DWORD current_bytes_written = FSInternalWrite(handle, hFile, buffer, bytes_to_write, location); if (current_bytes_written <= 0) { - throw IOException("Could not write file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not write file \"%s\": %s", handle.path, strerror(errno)); } bytes_written += current_bytes_written; @@ -1319,7 +1319,6 @@ static bool IsSymbolicLink(const string &path) { static void RecursiveGlobDirectories(FileSystem &fs, const string &path, vector &result, bool match_directory, bool join_path) { - fs.ListFiles(path, [&](OpenFileInfo &info) { if (join_path) { info.path = fs.JoinPath(path, info.path); diff --git a/src/duckdb/src/common/multi_file/multi_file_reader.cpp b/src/duckdb/src/common/multi_file/multi_file_reader.cpp index 21413261e..56c8d3b93 100644 --- a/src/duckdb/src/common/multi_file/multi_file_reader.cpp +++ b/src/duckdb/src/common/multi_file/multi_file_reader.cpp @@ -299,7 +299,6 @@ void MultiFileReader::FinalizeBind(MultiFileReaderData &reader_data, const Multi const vector &global_columns, const vector &global_column_ids, ClientContext &context, optional_ptr global_state) { - // create a map of name -> column index auto &local_columns = reader_data.reader->GetColumns(); auto &filename = reader_data.reader->GetFileName(); diff --git a/src/duckdb/src/common/operator/cast_operators.cpp b/src/duckdb/src/common/operator/cast_operators.cpp index f26c16131..5998d7787 100644 --- a/src/duckdb/src/common/operator/cast_operators.cpp +++ b/src/duckdb/src/common/operator/cast_operators.cpp @@ -19,6 +19,7 @@ #include "duckdb/common/types/time.hpp" #include "duckdb/common/types/timestamp.hpp" #include "duckdb/common/types/vector.hpp" +#include "duckdb/common/types/geometry.hpp" #include "duckdb/common/types.hpp" #include "fast_float/fast_float.h" #include "duckdb/common/types/bit.hpp" @@ -1406,7 +1407,6 @@ string_t CastFromBlobToBit::Operation(string_t input, Vector &vector) { //===--------------------------------------------------------------------===// template <> string_t CastFromBitToString::Operation(string_t input, Vector &vector) { - idx_t result_size = Bit::BitLength(input); string_t result = StringVector::EmptyString(vector, result_size); Bit::ToString(input, result.GetDataWriteable()); @@ -1560,6 +1560,14 @@ bool TryCastBlobToUUID::Operation(string_t input, hugeint_t &result, bool strict return true; } +//===--------------------------------------------------------------------===// +// Cast To Geometry 
+//===--------------------------------------------------------------------===// +template <> +bool TryCastToGeometry::Operation(string_t input, string_t &result, Vector &result_vector, CastParameters ¶meters) { + return Geometry::FromString(input, result, result_vector, parameters.strict); +} + //===--------------------------------------------------------------------===// // Cast To Date //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp b/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp index 417ac609a..50b44bf97 100644 --- a/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp +++ b/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp @@ -14,7 +14,7 @@ int32_t TerminalProgressBarDisplay::NormalizePercentage(double percentage) { return int32_t(percentage); } -static string FormatETA(double seconds, bool elapsed = false) { +string TerminalProgressBarDisplay::FormatETA(double seconds, bool elapsed) { // for terminal rendering purposes, we need to make sure the length is always the same // we pad the end with spaces if that is not the case // the maximum length here is "(~10.35 minutes remaining)" (26 bytes) @@ -68,14 +68,38 @@ static string FormatETA(double seconds, bool elapsed = false) { return result; } -void TerminalProgressBarDisplay::PrintProgressInternal(int32_t percentage, double seconds, bool finished) { - string result; +string TerminalProgressBarDisplay::FormatProgressBar(const ProgressBarDisplayInfo &display, int32_t percentage) { // we divide the number of blocks by the percentage // 0% = 0 // 100% = PROGRESS_BAR_WIDTH // the percentage determines how many blocks we need to draw - double blocks_to_draw = PROGRESS_BAR_WIDTH * (percentage / 100.0); + double blocks_to_draw = static_cast(display.width) * (percentage / 100.0); // because of the power of unicode, we can also draw partial blocks + string result; + result += display.progress_start; + idx_t i; + for (i = 0; i < idx_t(blocks_to_draw); i++) { + result += display.progress_block; + } + if (i < display.width) { + // print a partial block based on the percentage of the progress bar remaining + idx_t index = idx_t((blocks_to_draw - static_cast(idx_t(blocks_to_draw))) * + static_cast(display.partial_block_count)); + if (index >= display.partial_block_count) { + index = display.partial_block_count - 1; + } + result += display.progress_partial[index]; + i++; + } + for (; i < display.width; i++) { + result += display.progress_empty; + } + result += display.progress_end; + return result; +} + +void TerminalProgressBarDisplay::PrintProgressInternal(int32_t percentage, double seconds, bool finished) { + string result; // render the percentage with some padding to ensure everything stays nicely aligned result = "\r"; @@ -87,24 +111,7 @@ void TerminalProgressBarDisplay::PrintProgressInternal(int32_t percentage, doubl } result += to_string(percentage) + "%"; result += " "; - result += PROGRESS_START; - idx_t i; - for (i = 0; i < idx_t(blocks_to_draw); i++) { - result += PROGRESS_BLOCK; - } - if (i < PROGRESS_BAR_WIDTH) { - // print a partial block based on the percentage of the progress bar remaining - idx_t index = idx_t((blocks_to_draw - static_cast(idx_t(blocks_to_draw))) * PARTIAL_BLOCK_COUNT); - if (index >= PARTIAL_BLOCK_COUNT) { - index = PARTIAL_BLOCK_COUNT - 1; - } - result += PROGRESS_PARTIAL[index]; - i++; - } - for (; i < PROGRESS_BAR_WIDTH; i++) { - result += 
PROGRESS_EMPTY; - } - result += PROGRESS_END; + result += FormatProgressBar(display_info, percentage); result += " "; result += FormatETA(seconds, finished); diff --git a/src/duckdb/src/common/progress_bar/unscented_kalman_filter.cpp b/src/duckdb/src/common/progress_bar/unscented_kalman_filter.cpp index 551cdbb0f..158d59fd7 100644 --- a/src/duckdb/src/common/progress_bar/unscented_kalman_filter.cpp +++ b/src/duckdb/src/common/progress_bar/unscented_kalman_filter.cpp @@ -6,7 +6,6 @@ UnscentedKalmanFilter::UnscentedKalmanFilter() : x(STATE_DIM, 0.0), P(STATE_DIM, vector(STATE_DIM, 0.0)), Q(STATE_DIM, vector(STATE_DIM, 0.0)), R(OBS_DIM, vector(OBS_DIM, 0.0)), last_time(0.0), initialized(false), last_progress(-1.0), scale_factor(1.0) { - // Calculate UKF parameters lambda = ALPHA * ALPHA * (STATE_DIM + KAPPA) - STATE_DIM; @@ -254,11 +253,11 @@ void UnscentedKalmanFilter::UpdateInternal(double measured_progress) { } // Ensure progress stays in bounds - x[0] = std::max(0.0, std::min(1.0, x[0])); + x[0] = std::max(0.0, std::min(scale_factor, x[0])); } double UnscentedKalmanFilter::GetProgress() const { - return x[0]; + return x[0] / scale_factor; } double UnscentedKalmanFilter::GetVelocity() const { diff --git a/src/duckdb/src/common/radix_partitioning.cpp b/src/duckdb/src/common/radix_partitioning.cpp index 487e106af..7ad92eec5 100644 --- a/src/duckdb/src/common/radix_partitioning.cpp +++ b/src/duckdb/src/common/radix_partitioning.cpp @@ -98,6 +98,7 @@ struct ComputePartitionIndicesFunctor { const auto source_data = UnifiedVectorFormat::GetData(format); const auto &source_sel = *format.sel; + partition_indices.SetVectorType(VectorType::FLAT_VECTOR); const auto target = FlatVector::GetData(partition_indices); if (source_sel.IsSet()) { @@ -178,7 +179,7 @@ RadixPartitionedTupleData::RadixPartitionedTupleData(BufferManager &buffer_manag Initialize(); } -RadixPartitionedTupleData::RadixPartitionedTupleData(const RadixPartitionedTupleData &other) +RadixPartitionedTupleData::RadixPartitionedTupleData(RadixPartitionedTupleData &other) : PartitionedTupleData(other), radix_bits(other.radix_bits), hash_col_idx(other.hash_col_idx) { Initialize(); } @@ -189,7 +190,7 @@ RadixPartitionedTupleData::~RadixPartitionedTupleData() { void RadixPartitionedTupleData::Initialize() { const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); for (idx_t i = 0; i < num_partitions; i++) { - partitions.emplace_back(CreatePartitionCollection(i)); + partitions.emplace_back(CreatePartitionCollection()); partitions.back()->SetPartitionIndex(i); } } diff --git a/src/duckdb/src/common/render_tree.cpp b/src/duckdb/src/common/render_tree.cpp index 582d5e1ad..ee9621814 100644 --- a/src/duckdb/src/common/render_tree.cpp +++ b/src/duckdb/src/common/render_tree.cpp @@ -103,7 +103,7 @@ static unique_ptr CreateNode(const ProfilingNode &op) { auto &info = op.GetProfilingInfo(); InsertionOrderPreservingMap extra_info; if (info.Enabled(info.settings, MetricsType::EXTRA_INFO)) { - extra_info = op.GetProfilingInfo().extra_info; + extra_info = op.GetProfilingInfo().GetMetricValue>(MetricsType::EXTRA_INFO); } string node_name = "QUERY"; diff --git a/src/duckdb/src/common/row_operations/row_external.cpp b/src/duckdb/src/common/row_operations/row_external.cpp deleted file mode 100644 index e4e3ec87d..000000000 --- a/src/duckdb/src/common/row_operations/row_external.cpp +++ /dev/null @@ -1,157 +0,0 @@ -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/row_layout.hpp" - -namespace 
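The unscented Kalman filter change keeps the internal progress state x[0] on a scaled axis: UpdateInternal now clamps it to [0, scale_factor] rather than [0, 1], and GetProgress divides by scale_factor before reporting. A reduced sketch of that invariant, with a plain struct standing in for the filter and an assumed scale of 100:

#include <algorithm>
#include <cassert>

// Hypothetical reduced version of the bookkeeping: the internal estimate lives on a
// scaled axis and only GetProgress() converts back to the [0, 1] range.
struct ScaledProgress {
	double scale_factor = 100.0; // assumed scale; the real filter derives its own
	double internal = 0.0;       // plays the role of x[0]

	void Update(double scaled_estimate) {
		// clamp against scale_factor, not 1.0, because the state is stored pre-division
		internal = std::max(0.0, std::min(scale_factor, scaled_estimate));
	}
	double GetProgress() const {
		return internal / scale_factor; // always reported in [0, 1]
	}
};

int main() {
	ScaledProgress p;
	p.Update(150.0); // overshoot is clamped to scale_factor
	assert(p.GetProgress() == 1.0);
	p.Update(25.0);
	assert(p.GetProgress() == 0.25);
	return 0;
}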
duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -void RowOperations::SwizzleColumns(const RowLayout &layout, const data_ptr_t base_row_ptr, const idx_t count) { - const idx_t row_width = layout.GetRowWidth(); - data_ptr_t heap_row_ptrs[STANDARD_VECTOR_SIZE]; - idx_t done = 0; - while (done != count) { - const idx_t next = MinValue(count - done, STANDARD_VECTOR_SIZE); - const data_ptr_t row_ptr = base_row_ptr + done * row_width; - // Load heap row pointers - data_ptr_t heap_ptr_ptr = row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < next; i++) { - heap_row_ptrs[i] = Load(heap_ptr_ptr); - heap_ptr_ptr += row_width; - } - // Loop through the blob columns - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - auto physical_type = layout.GetTypes()[col_idx].InternalType(); - if (TypeIsConstantSize(physical_type)) { - continue; - } - data_ptr_t col_ptr = row_ptr + layout.GetOffsets()[col_idx]; - if (physical_type == PhysicalType::VARCHAR) { - data_ptr_t string_ptr = col_ptr + string_t::HEADER_SIZE; - for (idx_t i = 0; i < next; i++) { - if (Load(col_ptr) > string_t::INLINE_LENGTH) { - // Overwrite the string pointer with the within-row offset (if not inlined) - Store(UnsafeNumericCast(Load(string_ptr) - heap_row_ptrs[i]), - string_ptr); - } - col_ptr += row_width; - string_ptr += row_width; - } - } else { - // Non-varchar blob columns - for (idx_t i = 0; i < next; i++) { - // Overwrite the column data pointer with the within-row offset - Store(UnsafeNumericCast(Load(col_ptr) - heap_row_ptrs[i]), col_ptr); - col_ptr += row_width; - } - } - } - done += next; - } -} - -void RowOperations::SwizzleHeapPointer(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - const idx_t count, const idx_t base_offset) { - const idx_t row_width = layout.GetRowWidth(); - row_ptr += layout.GetHeapOffset(); - idx_t cumulative_offset = 0; - for (idx_t i = 0; i < count; i++) { - Store(base_offset + cumulative_offset, row_ptr); - cumulative_offset += Load(heap_base_ptr + cumulative_offset); - row_ptr += row_width; - } -} - -void RowOperations::CopyHeapAndSwizzle(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - data_ptr_t heap_ptr, const idx_t count) { - const auto row_width = layout.GetRowWidth(); - const auto heap_offset = layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - // Figure out source and size - const auto source_heap_ptr = Load(row_ptr + heap_offset); - const auto size = Load(source_heap_ptr); - D_ASSERT(size >= sizeof(uint32_t)); - - // Copy and swizzle - memcpy(heap_ptr, source_heap_ptr, size); - Store(UnsafeNumericCast(heap_ptr - heap_base_ptr), row_ptr + heap_offset); - - // Increment for next iteration - row_ptr += row_width; - heap_ptr += size; - } -} - -void RowOperations::UnswizzleHeapPointer(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count) { - const auto row_width = layout.GetRowWidth(); - data_ptr_t heap_ptr_ptr = base_row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - Store(base_heap_ptr + Load(heap_ptr_ptr), heap_ptr_ptr); - heap_ptr_ptr += row_width; - } -} - -static inline void VerifyUnswizzledString(const RowLayout &layout, const idx_t &col_idx, const data_ptr_t &row_ptr) { -#ifdef DEBUG - if (layout.GetTypes()[col_idx].id() != LogicalTypeId::VARCHAR) { - return; - } - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - ValidityBytes row_mask(row_ptr, 
layout.ColumnCount()); - if (row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - auto str = Load(row_ptr + layout.GetOffsets()[col_idx]); - str.Verify(); - } -#endif -} - -void RowOperations::UnswizzlePointers(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count) { - const idx_t row_width = layout.GetRowWidth(); - data_ptr_t heap_row_ptrs[STANDARD_VECTOR_SIZE]; - idx_t done = 0; - while (done != count) { - const idx_t next = MinValue(count - done, STANDARD_VECTOR_SIZE); - const data_ptr_t row_ptr = base_row_ptr + done * row_width; - // Restore heap row pointers - data_ptr_t heap_ptr_ptr = row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < next; i++) { - heap_row_ptrs[i] = base_heap_ptr + Load(heap_ptr_ptr); - Store(heap_row_ptrs[i], heap_ptr_ptr); - heap_ptr_ptr += row_width; - } - // Loop through the blob columns - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - auto physical_type = layout.GetTypes()[col_idx].InternalType(); - if (TypeIsConstantSize(physical_type)) { - continue; - } - data_ptr_t col_ptr = row_ptr + layout.GetOffsets()[col_idx]; - if (physical_type == PhysicalType::VARCHAR) { - data_ptr_t string_ptr = col_ptr + string_t::HEADER_SIZE; - for (idx_t i = 0; i < next; i++) { - if (Load(col_ptr) > string_t::INLINE_LENGTH) { - // Overwrite the string offset with the pointer (if not inlined) - Store(heap_row_ptrs[i] + Load(string_ptr), string_ptr); - VerifyUnswizzledString(layout, col_idx, row_ptr + i * row_width); - } - col_ptr += row_width; - string_ptr += row_width; - } - } else { - // Non-varchar blob columns - for (idx_t i = 0; i < next; i++) { - // Overwrite the column data offset with the pointer - Store(heap_row_ptrs[i] + Load(col_ptr), col_ptr); - col_ptr += row_width; - } - } - } - done += next; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_gather.cpp b/src/duckdb/src/common/row_operations/row_gather.cpp deleted file mode 100644 index 8e5ed315b..000000000 --- a/src/duckdb/src/common/row_operations/row_gather.cpp +++ /dev/null @@ -1,176 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/operator/constant_operators.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/common/types/row/row_layout.hpp" -#include "duckdb/common/types/row/tuple_data_layout.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -template -static void TemplatedGatherLoop(Vector &rows, const SelectionVector &row_sel, Vector &col, - const SelectionVector &col_sel, idx_t count, const RowLayout &layout, idx_t col_no, - idx_t build_size) { - // Precompute mask indexes - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); - - auto ptrs = FlatVector::GetData(rows); - auto data = FlatVector::GetData(col); - auto &col_mask = FlatVector::Validity(col); - - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - auto col_idx = col_sel.get_index(i); - data[col_idx] = Load(row + col_offset); - ValidityBytes row_mask(row, layout.ColumnCount()); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - if (build_size > STANDARD_VECTOR_SIZE && col_mask.AllValid()) { - //! 
We need to initialize the mask with the vector size. - col_mask.Initialize(build_size); - } - col_mask.SetInvalid(col_idx); - } - } -} - -static void GatherVarchar(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, - idx_t count, const RowLayout &layout, idx_t col_no, idx_t build_size, - data_ptr_t base_heap_ptr) { - // Precompute mask indexes - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - const auto heap_offset = layout.GetHeapOffset(); - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); - - auto ptrs = FlatVector::GetData(rows); - auto data = FlatVector::GetData(col); - auto &col_mask = FlatVector::Validity(col); - - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - auto col_idx = col_sel.get_index(i); - auto col_ptr = row + col_offset; - data[col_idx] = Load(col_ptr); - ValidityBytes row_mask(row, layout.ColumnCount()); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - if (build_size > STANDARD_VECTOR_SIZE && col_mask.AllValid()) { - //! We need to initialize the mask with the vector size. - col_mask.Initialize(build_size); - } - col_mask.SetInvalid(col_idx); - } else if (base_heap_ptr && Load(col_ptr) > string_t::INLINE_LENGTH) { - // Not inline, so unswizzle the copied pointer the pointer - auto heap_ptr_ptr = row + heap_offset; - auto heap_row_ptr = base_heap_ptr + Load(heap_ptr_ptr); - auto string_ptr = data_ptr_t(data + col_idx) + string_t::HEADER_SIZE; - Store(heap_row_ptr + Load(string_ptr), string_ptr); -#ifdef DEBUG - data[col_idx].Verify(); -#endif - } - } -} - -static void GatherNestedVector(Vector &rows, const SelectionVector &row_sel, Vector &col, - const SelectionVector &col_sel, idx_t count, const RowLayout &layout, idx_t col_no, - data_ptr_t base_heap_ptr) { - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - const auto heap_offset = layout.GetHeapOffset(); - auto ptrs = FlatVector::GetData(rows); - - // Build the gather locations - auto data_locations = make_unsafe_uniq_array_uninitialized(count); - auto mask_locations = make_unsafe_uniq_array_uninitialized(count); - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - mask_locations[i] = row; - auto col_ptr = ptrs[row_idx] + col_offset; - if (base_heap_ptr) { - auto heap_ptr_ptr = row + heap_offset; - auto heap_row_ptr = base_heap_ptr + Load(heap_ptr_ptr); - data_locations[i] = heap_row_ptr + Load(col_ptr); - } else { - data_locations[i] = Load(col_ptr); - } - } - - // Deserialise into the selected locations - NestedValidity parent_validity(mask_locations.get(), col_no); - RowOperations::HeapGather(col, count, col_sel, data_locations.get(), &parent_validity); -} - -void RowOperations::Gather(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, - const idx_t count, const RowLayout &layout, const idx_t col_no, const idx_t build_size, - data_ptr_t heap_ptr) { - D_ASSERT(rows.GetVectorType() == VectorType::FLAT_VECTOR); - D_ASSERT(rows.GetType().id() == LogicalTypeId::POINTER); // "Cannot gather from non-pointer type!" 
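The deleted row_gather.cpp implemented column gathers over row-major data: for every selected row pointer the value is loaded from the column's fixed byte offset, and the row's validity bits decide whether the destination slot is marked NULL. A reduced sketch of that access pattern with plain structs instead of DuckDB's Vector and RowLayout types (the one-bit-per-column mask layout and its "bit set means valid" convention are assumptions of the sketch):

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical row layout: the column lives at a fixed col_offset within each row, and
// the leading byte(s) of the row hold a one-bit-per-column validity mask.
struct SketchLayout {
	size_t col_offset;
};

void GatherInt64(const std::vector<const uint8_t *> &row_pointers, const SketchLayout &layout, size_t col_no,
                 std::vector<int64_t> &values, std::vector<bool> &valid) {
	values.resize(row_pointers.size());
	valid.resize(row_pointers.size());
	for (size_t i = 0; i < row_pointers.size(); i++) {
		const uint8_t *row = row_pointers[i];
		// unaligned load of the fixed-width value from the column's offset within the row
		std::memcpy(&values[i], row + layout.col_offset, sizeof(int64_t));
		// validity: in this sketch, bit col_no of the leading mask is 1 when the value is valid
		const uint8_t mask_byte = row[col_no / 8];
		valid[i] = ((mask_byte >> (col_no % 8)) & 1) != 0;
	}
}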
- - col.SetVectorType(VectorType::FLAT_VECTOR); - switch (col.GetType().InternalType()) { - case PhysicalType::UINT8: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT16: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT32: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT64: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT128: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT16: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT32: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT64: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT128: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::FLOAT: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::DOUBLE: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INTERVAL: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::VARCHAR: - GatherVarchar(rows, row_sel, col, col_sel, count, layout, col_no, build_size, heap_ptr); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - GatherNestedVector(rows, row_sel, col, col_sel, count, layout, col_no, heap_ptr); - break; - default: - throw InternalException("Unimplemented type for RowOperations::Gather"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_heap_gather.cpp b/src/duckdb/src/common/row_operations/row_heap_gather.cpp deleted file mode 100644 index fa433c64e..000000000 --- a/src/duckdb/src/common/row_operations/row_heap_gather.cpp +++ /dev/null @@ -1,276 +0,0 @@ -#include "duckdb/common/helper.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/vector.hpp" - -namespace duckdb { - -using ValidityBytes = TemplatedValidityMask; - -template -static void TemplatedHeapGather(Vector &v, const idx_t count, const SelectionVector &sel, data_ptr_t *key_locations) { - auto target = FlatVector::GetData(v); - - for (idx_t i = 0; i < count; ++i) { - const auto col_idx = sel.get_index(i); - target[col_idx] = Load(key_locations[i]); - key_locations[i] += sizeof(T); - } -} - -static void HeapGatherStringVector(Vector &v, const idx_t vcount, const SelectionVector &sel, - data_ptr_t *key_locations) { - const auto &validity = FlatVector::Validity(v); - auto target = FlatVector::GetData(v); - - for (idx_t i = 0; i < vcount; i++) { - const auto col_idx = sel.get_index(i); - if (!validity.RowIsValid(col_idx)) { - continue; - } - auto len = Load(key_locations[i]); - key_locations[i] += sizeof(uint32_t); - target[col_idx] = StringVector::AddStringOrBlob(v, string_t(const_char_ptr_cast(key_locations[i]), len)); - key_locations[i] += len; - 
} -} - -static void HeapGatherStructVector(Vector &v, const idx_t vcount, const SelectionVector &sel, - data_ptr_t *key_locations) { - // struct must have a validitymask for its fields - auto &child_types = StructType::GetChildTypes(v.GetType()); - const idx_t struct_validitymask_size = (child_types.size() + 7) / 8; - data_ptr_t struct_validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < vcount; i++) { - // use key_locations as the validitymask, and create struct_key_locations - struct_validitymask_locations[i] = key_locations[i]; - key_locations[i] += struct_validitymask_size; - } - - // now deserialize into the struct vectors - auto &children = StructVector::GetEntries(v); - for (idx_t i = 0; i < child_types.size(); i++) { - NestedValidity parent_validity(struct_validitymask_locations, i); - RowOperations::HeapGather(*children[i], vcount, sel, key_locations, &parent_validity); - } -} - -static void HeapGatherListVector(Vector &v, const idx_t vcount, const SelectionVector &sel, data_ptr_t *key_locations) { - const auto &validity = FlatVector::Validity(v); - - auto child_type = ListType::GetChildType(v.GetType()); - auto list_data = ListVector::GetData(v); - data_ptr_t list_entry_locations[STANDARD_VECTOR_SIZE]; - - uint64_t entry_offset = ListVector::GetListSize(v); - for (idx_t i = 0; i < vcount; i++) { - const auto col_idx = sel.get_index(i); - if (!validity.RowIsValid(col_idx)) { - continue; - } - // read list length - auto entry_remaining = Load(key_locations[i]); - key_locations[i] += sizeof(uint64_t); - // set list entry attributes - list_data[col_idx].length = entry_remaining; - list_data[col_idx].offset = entry_offset; - // skip over the validity mask - data_ptr_t validitymask_location = key_locations[i]; - idx_t offset_in_byte = 0; - key_locations[i] += (entry_remaining + 7) / 8; - // entry sizes - data_ptr_t var_entry_size_ptr = nullptr; - if (!TypeIsConstantSize(child_type.InternalType())) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += entry_remaining * sizeof(idx_t); - } - - // now read the list data - while (entry_remaining > 0) { - auto next = MinValue(entry_remaining, (idx_t)STANDARD_VECTOR_SIZE); - - // initialize a new vector to append - Vector append_vector(v.GetType()); - append_vector.SetVectorType(v.GetVectorType()); - - auto &list_vec_to_append = ListVector::GetEntry(append_vector); - - // set validity - //! Since we are constructing the vector, this will always be a flat vector. 
- auto &append_validity = FlatVector::Validity(list_vec_to_append); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - append_validity.Set(entry_idx, *(validitymask_location) & (1 << offset_in_byte)); - if (++offset_in_byte == 8) { - validitymask_location++; - offset_in_byte = 0; - } - } - - // compute entry sizes and set locations where the list entries are - if (TypeIsConstantSize(child_type.InternalType())) { - // constant size list entries - const idx_t type_size = GetTypeIdSize(child_type.InternalType()); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += type_size; - } - } else { - // variable size list entries - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += Load(var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } - - // now deserialize and add to listvector - RowOperations::HeapGather(list_vec_to_append, next, *FlatVector::IncrementalSelectionVector(), - list_entry_locations, nullptr); - ListVector::Append(v, list_vec_to_append, next); - - // update for next iteration - entry_remaining -= next; - entry_offset += next; - } - } -} - -static void HeapGatherArrayVector(Vector &v, const idx_t vcount, const SelectionVector &sel, - data_ptr_t *key_locations) { - // Setup - auto &child_type = ArrayType::GetChildType(v.GetType()); - auto array_size = ArrayType::GetSize(v.GetType()); - auto &child_vector = ArrayVector::GetEntry(v); - auto child_type_size = GetTypeIdSize(child_type.InternalType()); - auto child_type_is_var_size = !TypeIsConstantSize(child_type.InternalType()); - - data_ptr_t array_entry_locations[STANDARD_VECTOR_SIZE]; - - // array must have a validitymask for its elements - auto array_validitymask_size = (array_size + 7) / 8; - - for (idx_t i = 0; i < vcount; i++) { - // Setup validity mask - data_ptr_t array_validitymask_location = key_locations[i]; - key_locations[i] += array_validitymask_size; - - NestedValidity parent_validity(array_validitymask_location); - - // The size of each variable size entry is stored after the validity mask - // (if the child type is variable size) - data_ptr_t var_entry_size_ptr = nullptr; - if (child_type_is_var_size) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += array_size * sizeof(idx_t); - } - - // row idx - const auto row_idx = sel.get_index(i); - - idx_t array_start = row_idx * array_size; - idx_t elem_remaining = array_size; - - while (elem_remaining > 0) { - auto chunk_size = MinValue(static_cast(STANDARD_VECTOR_SIZE), elem_remaining); - - SelectionVector array_sel(STANDARD_VECTOR_SIZE); - - if (child_type_is_var_size) { - // variable size list entries - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += Load(var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - array_sel.set_index(elem_idx, array_start + elem_idx); - } - } else { - // constant size list entries - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += child_type_size; - array_sel.set_index(elem_idx, array_start + elem_idx); - } - } - - // Pass on this array's validity mask to the child vector - RowOperations::HeapGather(child_vector, chunk_size, array_sel, array_entry_locations, &parent_validity); - - elem_remaining -= chunk_size; - array_start += chunk_size; - 
parent_validity.OffsetListBy(chunk_size); - } - } -} - -void RowOperations::HeapGather(Vector &v, const idx_t &vcount, const SelectionVector &sel, data_ptr_t *key_locations, - optional_ptr parent_validity) { - v.SetVectorType(VectorType::FLAT_VECTOR); - - auto &validity = FlatVector::Validity(v); - if (parent_validity) { - for (idx_t i = 0; i < vcount; i++) { - const auto valid = parent_validity->IsValid(i); - const auto col_idx = sel.get_index(i); - validity.Set(col_idx, valid); - } - } - - auto type = v.GetType().InternalType(); - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT16: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT32: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT64: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT8: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT16: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT32: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT64: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT128: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT128: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::FLOAT: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::DOUBLE: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INTERVAL: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::VARCHAR: - HeapGatherStringVector(v, vcount, sel, key_locations); - break; - case PhysicalType::STRUCT: - HeapGatherStructVector(v, vcount, sel, key_locations); - break; - case PhysicalType::LIST: - HeapGatherListVector(v, vcount, sel, key_locations); - break; - case PhysicalType::ARRAY: - HeapGatherArrayVector(v, vcount, sel, key_locations); - break; - default: - throw NotImplementedException("Unimplemented deserialize from row-format"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_heap_scatter.cpp b/src/duckdb/src/common/row_operations/row_heap_scatter.cpp deleted file mode 100644 index 01cf7b589..000000000 --- a/src/duckdb/src/common/row_operations/row_heap_scatter.cpp +++ /dev/null @@ -1,581 +0,0 @@ -#include "duckdb/common/helper.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -using ValidityBytes = TemplatedValidityMask; - -NestedValidity::NestedValidity(data_ptr_t validitymask_location) - : list_validity_location(validitymask_location), struct_validity_locations(nullptr), entry_idx(0), idx_in_entry(0), - list_validity_offset(0) { -} - -NestedValidity::NestedValidity(data_ptr_t *validitymask_locations, idx_t child_vector_index) - : list_validity_location(nullptr), struct_validity_locations(validitymask_locations), entry_idx(0), idx_in_entry(0), - list_validity_offset(0) { - ValidityBytes::GetEntryIndex(child_vector_index, entry_idx, idx_in_entry); -} - -void NestedValidity::SetInvalid(idx_t idx) { - if (list_validity_location) { - // Is List - - idx = idx + list_validity_offset; - - idx_t list_entry_idx; - idx_t list_idx_in_entry; - 
ValidityBytes::GetEntryIndex(idx, list_entry_idx, list_idx_in_entry); - const auto bit = ~(1UL << list_idx_in_entry); - list_validity_location[list_entry_idx] &= bit; - } else { - // Is Struct - const auto bit = ~(1UL << idx_in_entry); - *(struct_validity_locations[idx] + entry_idx) &= bit; - } -} - -void NestedValidity::OffsetListBy(idx_t offset) { - list_validity_offset += offset; -} - -bool NestedValidity::IsValid(idx_t idx) { - if (list_validity_location) { - // Is List - - idx = idx + list_validity_offset; - - idx_t list_entry_idx; - idx_t list_idx_in_entry; - ValidityBytes::GetEntryIndex(idx, list_entry_idx, list_idx_in_entry); - const auto bit = (1UL << list_idx_in_entry); - return list_validity_location[list_entry_idx] & bit; - } else { - // Is Struct - const auto bit = (1UL << idx_in_entry); - return *(struct_validity_locations[idx] + entry_idx) & bit; - } -} - -static void ComputeStringEntrySizes(UnifiedVectorFormat &vdata, idx_t entry_sizes[], const idx_t ser_count, - const SelectionVector &sel, const idx_t offset) { - auto strings = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto str_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(str_idx)) { - entry_sizes[i] += sizeof(uint32_t) + strings[str_idx].GetSize(); - } - } -} - -static void ComputeStructEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - // obtain child vectors - idx_t num_children; - auto &children = StructVector::GetEntries(v); - num_children = children.size(); - // add struct validitymask size - const idx_t struct_validitymask_size = (num_children + 7) / 8; - for (idx_t i = 0; i < ser_count; i++) { - entry_sizes[i] += struct_validitymask_size; - } - // compute size of child vectors - for (auto &struct_vector : children) { - RowOperations::ComputeEntrySizes(*struct_vector, entry_sizes, vcount, ser_count, sel, offset); - } -} - -static void ComputeListEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - idx_t list_entry_sizes[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto list_entry = list_data[source_idx]; - - // make room for list length, list validitymask - entry_sizes[i] += sizeof(list_entry.length); - entry_sizes[i] += (list_entry.length + 7) / 8; - - // serialize size of each entry (if non-constant size) - if (!TypeIsConstantSize(ListType::GetChildType(v.GetType()).InternalType())) { - entry_sizes[i] += list_entry.length * sizeof(list_entry.length); - } - - // compute size of each the elements in list_entry and sum them - auto entry_remaining = list_entry.length; - auto entry_offset = list_entry.offset; - while (entry_remaining > 0) { - // the list entry can span multiple vectors - auto next = MinValue((idx_t)STANDARD_VECTOR_SIZE, entry_remaining); - - // compute and add to the total - std::fill_n(list_entry_sizes, next, 0); - RowOperations::ComputeEntrySizes(child_vector, list_entry_sizes, next, next, - *FlatVector::IncrementalSelectionVector(), entry_offset); - for (idx_t list_idx = 0; list_idx < next; list_idx++) { - entry_sizes[i] += list_entry_sizes[list_idx]; - } - - // update for next iteration - entry_remaining -= 
next; - entry_offset += next; - } - } - } -} - -static void ComputeArrayEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - - auto array_size = ArrayType::GetSize(v.GetType()); - auto child_vector = ArrayVector::GetEntry(v); - - idx_t array_entry_sizes[STANDARD_VECTOR_SIZE]; - const idx_t array_validitymask_size = (array_size + 7) / 8; - - for (idx_t i = 0; i < ser_count; i++) { - - // Validity for the array elements - entry_sizes[i] += array_validitymask_size; - - // serialize size of each entry (if non-constant size) - if (!TypeIsConstantSize(ArrayType::GetChildType(v.GetType()).InternalType())) { - entry_sizes[i] += array_size * sizeof(idx_t); - } - - auto elem_idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(elem_idx + offset); - - auto array_start = source_idx * array_size; - auto elem_remaining = array_size; - - // the array could span multiple vectors, so we divide it into chunks - while (elem_remaining > 0) { - auto chunk_size = MinValue(static_cast(STANDARD_VECTOR_SIZE), elem_remaining); - - // compute and add to the total - std::fill_n(array_entry_sizes, chunk_size, 0); - RowOperations::ComputeEntrySizes(child_vector, array_entry_sizes, chunk_size, chunk_size, - *FlatVector::IncrementalSelectionVector(), array_start); - for (idx_t arr_elem_idx = 0; arr_elem_idx < chunk_size; arr_elem_idx++) { - entry_sizes[i] += array_entry_sizes[arr_elem_idx]; - } - // update for next iteration - elem_remaining -= chunk_size; - array_start += chunk_size; - } - } -} - -void RowOperations::ComputeEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t vcount, - idx_t ser_count, const SelectionVector &sel, idx_t offset) { - const auto physical_type = v.GetType().InternalType(); - if (TypeIsConstantSize(physical_type)) { - const auto type_size = GetTypeIdSize(physical_type); - for (idx_t i = 0; i < ser_count; i++) { - entry_sizes[i] += type_size; - } - } else { - switch (physical_type) { - case PhysicalType::VARCHAR: - ComputeStringEntrySizes(vdata, entry_sizes, ser_count, sel, offset); - break; - case PhysicalType::STRUCT: - ComputeStructEntrySizes(v, entry_sizes, vcount, ser_count, sel, offset); - break; - case PhysicalType::LIST: - ComputeListEntrySizes(v, vdata, entry_sizes, ser_count, sel, offset); - break; - case PhysicalType::ARRAY: - ComputeArrayEntrySizes(v, vdata, entry_sizes, ser_count, sel, offset); - break; - default: - // LCOV_EXCL_START - throw NotImplementedException("Column with variable size type %s cannot be serialized to row-format", - v.GetType().ToString()); - // LCOV_EXCL_STOP - } - } -} - -void RowOperations::ComputeEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - ComputeEntrySizes(v, vdata, entry_sizes, vcount, ser_count, sel, offset); -} - -template -static void TemplatedHeapScatter(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (!parent_validity) { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - - auto target = (T *)key_locations[i]; - Store(source[source_idx], data_ptr_cast(target)); - key_locations[i] += sizeof(T); - } - } else { - for (idx_t i = 0; i < count; i++) { - auto idx = 
sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - - auto target = (T *)key_locations[i]; - Store(source[source_idx], data_ptr_cast(target)); - key_locations[i] += sizeof(T); - - // set the validitymask - if (!vdata.validity.RowIsValid(source_idx)) { - parent_validity->SetInvalid(i); - } - } - } -} - -static void HeapScatterStringVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto strings = UnifiedVectorFormat::GetData(vdata); - if (!parent_validity) { - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto &string_entry = strings[source_idx]; - // store string size - Store(NumericCast(string_entry.GetSize()), key_locations[i]); - key_locations[i] += sizeof(uint32_t); - // store the string - memcpy(key_locations[i], string_entry.GetData(), string_entry.GetSize()); - key_locations[i] += string_entry.GetSize(); - } - } - } else { - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto &string_entry = strings[source_idx]; - // store string size - Store(NumericCast(string_entry.GetSize()), key_locations[i]); - key_locations[i] += sizeof(uint32_t); - // store the string - memcpy(key_locations[i], string_entry.GetData(), string_entry.GetSize()); - key_locations[i] += string_entry.GetSize(); - } else { - // set the validitymask - parent_validity->SetInvalid(i); - } - } - } -} - -static void HeapScatterStructVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto &children = StructVector::GetEntries(v); - idx_t num_children = children.size(); - - // struct must have a validitymask for its fields - const idx_t struct_validitymask_size = (num_children + 7) / 8; - data_ptr_t struct_validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < ser_count; i++) { - // initialize the struct validity mask - struct_validitymask_locations[i] = key_locations[i]; - memset(struct_validitymask_locations[i], -1, struct_validitymask_size); - key_locations[i] += struct_validitymask_size; - - // set whether the whole struct is null - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - if (parent_validity && !vdata.validity.RowIsValid(source_idx)) { - parent_validity->SetInvalid(i); - } - } - - // now serialize the struct vectors - for (idx_t i = 0; i < children.size(); i++) { - auto &struct_vector = *children[i]; - NestedValidity struct_validity(struct_validitymask_locations, i); - RowOperations::HeapScatter(struct_vector, vcount, sel, ser_count, key_locations, &struct_validity, offset); - } -} - -static void HeapScatterListVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - - UnifiedVectorFormat list_vdata; - child_vector.ToUnifiedFormat(ListVector::GetListSize(v), list_vdata); - auto child_type = 
ListType::GetChildType(v.GetType()).InternalType(); - - idx_t list_entry_sizes[STANDARD_VECTOR_SIZE]; - data_ptr_t list_entry_locations[STANDARD_VECTOR_SIZE]; - - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (!vdata.validity.RowIsValid(source_idx)) { - if (parent_validity) { - // set the row validitymask for this column to invalid - parent_validity->SetInvalid(i); - } - continue; - } - auto list_entry = list_data[source_idx]; - - // store list length - Store(list_entry.length, key_locations[i]); - key_locations[i] += sizeof(list_entry.length); - - // make room for the validitymask - data_ptr_t list_validitymask_location = key_locations[i]; - idx_t entry_offset_in_byte = 0; - idx_t validitymask_size = (list_entry.length + 7) / 8; - memset(list_validitymask_location, -1, validitymask_size); - key_locations[i] += validitymask_size; - - // serialize size of each entry (if non-constant size) - data_ptr_t var_entry_size_ptr = nullptr; - if (!TypeIsConstantSize(child_type)) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += list_entry.length * sizeof(idx_t); - } - - auto entry_remaining = list_entry.length; - auto entry_offset = list_entry.offset; - while (entry_remaining > 0) { - // the list entry can span multiple vectors - auto next = MinValue((idx_t)STANDARD_VECTOR_SIZE, entry_remaining); - - // serialize list validity - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - auto list_idx = list_vdata.sel->get_index(entry_idx + entry_offset); - if (!list_vdata.validity.RowIsValid(list_idx)) { - *(list_validitymask_location) &= ~(1UL << entry_offset_in_byte); - } - if (++entry_offset_in_byte == 8) { - list_validitymask_location++; - entry_offset_in_byte = 0; - } - } - - if (TypeIsConstantSize(child_type)) { - // constant size list entries: set list entry locations - const idx_t type_size = GetTypeIdSize(child_type); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += type_size; - } - } else { - // variable size list entries: compute entry sizes and set list entry locations - std::fill_n(list_entry_sizes, next, 0); - RowOperations::ComputeEntrySizes(child_vector, list_entry_sizes, next, next, - *FlatVector::IncrementalSelectionVector(), entry_offset); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += list_entry_sizes[entry_idx]; - Store(list_entry_sizes[entry_idx], var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } - - // now serialize to the locations - RowOperations::HeapScatter(child_vector, ListVector::GetListSize(v), - *FlatVector::IncrementalSelectionVector(), next, list_entry_locations, nullptr, - entry_offset); - - // update for next iteration - entry_remaining -= next; - entry_offset += next; - } - } -} - -static void HeapScatterArrayVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - - auto &child_vector = ArrayVector::GetEntry(v); - auto array_size = ArrayType::GetSize(v.GetType()); - auto child_type = ArrayType::GetChildType(v.GetType()); - auto child_type_size = GetTypeIdSize(child_type.InternalType()); - auto child_type_is_var_size = !TypeIsConstantSize(child_type.InternalType()); - - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - UnifiedVectorFormat child_vdata; - 
child_vector.ToUnifiedFormat(ArrayVector::GetTotalSize(v), child_vdata); - - data_ptr_t array_entry_locations[STANDARD_VECTOR_SIZE]; - idx_t array_entry_sizes[STANDARD_VECTOR_SIZE]; - - // array must have a validitymask for its elements - auto array_validitymask_size = (array_size + 7) / 8; - - for (idx_t i = 0; i < ser_count; i++) { - // Set if the whole array itself is null in the parent entry - auto source_idx = vdata.sel->get_index(sel.get_index(i) + offset); - if (parent_validity && !vdata.validity.RowIsValid(source_idx)) { - parent_validity->SetInvalid(i); - } - - // Now we can serialize the array itself - // Every array starts with a validity mask for the children - data_ptr_t array_validitymask_location = key_locations[i]; - memset(array_validitymask_location, -1, array_validitymask_size); - key_locations[i] += array_validitymask_size; - - NestedValidity array_parent_validity(array_validitymask_location); - - // If the array contains variable size entries, we reserve spaces for them here - data_ptr_t var_entry_size_ptr = nullptr; - if (child_type_is_var_size) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += array_size * sizeof(idx_t); - } - - // Then comes the elements - auto array_start = source_idx * array_size; - auto elem_remaining = array_size; - - while (elem_remaining > 0) { - // the array elements can span multiple vectors, so we divide it into chunks - auto chunk_size = MinValue(static_cast(STANDARD_VECTOR_SIZE), elem_remaining); - - // Setup the locations for the elements - if (child_type_is_var_size) { - // The elements are variable sized - std::fill_n(array_entry_sizes, chunk_size, 0); - RowOperations::ComputeEntrySizes(child_vector, array_entry_sizes, chunk_size, chunk_size, - *FlatVector::IncrementalSelectionVector(), array_start); - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += array_entry_sizes[elem_idx]; - - // Now store the size of the entry - Store(array_entry_sizes[elem_idx], var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } else { - // The elements are constant sized - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += child_type_size; - } - } - - RowOperations::HeapScatter(child_vector, ArrayVector::GetTotalSize(v), - *FlatVector::IncrementalSelectionVector(), chunk_size, array_entry_locations, - &array_parent_validity, array_start); - - // update for next iteration - elem_remaining -= chunk_size; - array_start += chunk_size; - array_parent_validity.OffsetListBy(chunk_size); - } - } -} - -void RowOperations::HeapScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, idx_t offset) { - if (TypeIsConstantSize(v.GetType().InternalType())) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - RowOperations::HeapScatterVData(vdata, v.GetType().InternalType(), sel, ser_count, key_locations, - parent_validity, offset); - } else { - switch (v.GetType().InternalType()) { - case PhysicalType::VARCHAR: - HeapScatterStringVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::STRUCT: - HeapScatterStructVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::LIST: - HeapScatterListVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - case 
PhysicalType::ARRAY: - HeapScatterArrayVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - default: - // LCOV_EXCL_START - throw NotImplementedException("Serialization of variable length vector with type %s", - v.GetType().ToString()); - // LCOV_EXCL_STOP - } - } -} - -void RowOperations::HeapScatterVData(UnifiedVectorFormat &vdata, PhysicalType type, const SelectionVector &sel, - idx_t ser_count, data_ptr_t *key_locations, - optional_ptr parent_validity, idx_t offset) { - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT16: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT32: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT64: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT8: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT16: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT32: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT64: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT128: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT128: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::FLOAT: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::DOUBLE: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INTERVAL: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - default: - throw NotImplementedException("FIXME: Serialize to of constant type column to row-format"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_radix_scatter.cpp b/src/duckdb/src/common/row_operations/row_radix_scatter.cpp deleted file mode 100644 index a85a71997..000000000 --- a/src/duckdb/src/common/row_operations/row_radix_scatter.cpp +++ /dev/null @@ -1,360 +0,0 @@ -#include "duckdb/common/helper.hpp" -#include "duckdb/common/radix.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -template -void TemplatedRadixScatter(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 
1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - Radix::EncodeData(key_locations[i] + 1, source[source_idx]); - // invert bits if desc - if (desc) { - for (idx_t s = 1; s < sizeof(T) + 1; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', sizeof(T)); - } - key_locations[i] += sizeof(T) + 1; - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write value - Radix::EncodeData(key_locations[i], source[source_idx]); - // invert bits if desc - if (desc) { - for (idx_t s = 0; s < sizeof(T); s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - key_locations[i] += sizeof(T); - } - } -} - -void RadixScatterStringVector(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t prefix_len, idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - Radix::EncodeStringDataPrefix(key_locations[i] + 1, source[source_idx], prefix_len); - // invert bits if desc - if (desc) { - for (idx_t s = 1; s < prefix_len + 1; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', prefix_len); - } - key_locations[i] += prefix_len + 1; - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write value - Radix::EncodeStringDataPrefix(key_locations[i], source[source_idx], prefix_len); - // invert bits if desc - if (desc) { - for (idx_t s = 0; s < prefix_len; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - key_locations[i] += prefix_len; - } - } -} - -void RadixScatterListVector(Vector &v, UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t prefix_len, const idx_t width, const idx_t offset) { - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - auto list_size = ListVector::GetListSize(v); - child_vector.Flatten(list_size); - - // serialize null values - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 
1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - *key_location++ = valid; - auto &list_entry = list_data[source_idx]; - if (list_entry.length > 0) { - // denote that the list is not empty with a 1 - *key_location++ = 1; - RowOperations::RadixScatter(child_vector, list_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 2, - list_entry.offset); - } else { - // denote that the list is empty with a 0 - *key_location++ = 0; - // mark rest of bits as empty - memset(key_location, '\0', width - 2); - key_location += width - 2; - } - // invert bits if desc - if (desc) { - // skip over validity byte, handled by nulls first/last - for (key_location = key_location_start + 1; key_location < key_location_start + width; - key_location++) { - *key_location = ~*key_location; - } - } - } else { - *key_location++ = invalid; - memset(key_location, '\0', width - 1); - key_location += width - 1; - } - D_ASSERT(key_location == key_location_start + width); - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - auto &list_entry = list_data[source_idx]; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - if (list_entry.length > 0) { - // denote that the list is not empty with a 1 - *key_location++ = 1; - RowOperations::RadixScatter(child_vector, list_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 1, - list_entry.offset); - } else { - // denote that the list is empty with a 0 - *key_location++ = 0; - // mark rest of bits as empty - memset(key_location, '\0', width - 1); - key_location += width - 1; - } - // invert bits if desc - if (desc) { - for (key_location = key_location_start; key_location < key_location_start + width; key_location++) { - *key_location = ~*key_location; - } - } - D_ASSERT(key_location == key_location_start + width); - } - } -} - -void RadixScatterArrayVector(Vector &v, UnifiedVectorFormat &vdata, idx_t vcount, const SelectionVector &sel, - idx_t add_count, data_ptr_t *key_locations, const bool desc, const bool has_null, - const bool nulls_first, const idx_t prefix_len, idx_t width, const idx_t offset) { - auto &child_vector = ArrayVector::GetEntry(v); - auto array_size = ArrayType::GetSize(v.GetType()); - - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 
1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - - if (validity.RowIsValid(source_idx)) { - *key_location++ = valid; - - auto array_offset = source_idx * array_size; - RowOperations::RadixScatter(child_vector, array_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 1, array_offset); - - // invert bits if desc - if (desc) { - // skip over validity byte, handled by nulls first/last - for (key_location = key_location_start + 1; key_location < key_location_start + width; - key_location++) { - *key_location = ~*key_location; - } - } - } else { - *key_location++ = invalid; - memset(key_location, '\0', width - 1); - key_location += width - 1; - } - D_ASSERT(key_location == key_location_start + width); - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - - auto array_offset = source_idx * array_size; - RowOperations::RadixScatter(child_vector, array_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width, array_offset); - // invert bits if desc - if (desc) { - for (key_location = key_location_start; key_location < key_location_start + width; key_location++) { - *key_location = ~*key_location; - } - } - D_ASSERT(key_location == key_location_start + width); - } - } -} - -void RadixScatterStructVector(Vector &v, UnifiedVectorFormat &vdata, idx_t vcount, const SelectionVector &sel, - idx_t add_count, data_ptr_t *key_locations, const bool desc, const bool has_null, - const bool nulls_first, const idx_t prefix_len, idx_t width, const idx_t offset) { - // serialize null values - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 
1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', width - 1); - } - key_locations[i]++; - } - width--; - } - // serialize the struct - auto &child_vector = *StructVector::GetEntries(v)[0]; - RowOperations::RadixScatter(child_vector, vcount, *FlatVector::IncrementalSelectionVector(), add_count, - key_locations, false, true, false, prefix_len, width, offset); - // invert bits if desc - if (desc) { - for (idx_t i = 0; i < add_count; i++) { - for (idx_t s = 0; s < width; s++) { - *(key_locations[i] - width + s) = ~*(key_locations[i] - width + s); - } - } - } -} - -void RowOperations::RadixScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, bool desc, bool has_null, bool nulls_first, - idx_t prefix_len, idx_t width, idx_t offset) { -#ifdef DEBUG - // initialize to verify written width later - auto key_locations_copy = make_uniq_array(ser_count); - for (idx_t i = 0; i < ser_count; i++) { - key_locations_copy[i] = key_locations[i]; - } -#endif - - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - switch (v.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT16: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT32: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT64: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT8: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT16: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT32: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT64: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT128: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT128: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::FLOAT: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::DOUBLE: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INTERVAL: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::VARCHAR: - RadixScatterStringVector(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, prefix_len, offset); - break; - case PhysicalType::LIST: - RadixScatterListVector(v, vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, prefix_len, width, - offset); 
- break; - case PhysicalType::STRUCT: - RadixScatterStructVector(v, vdata, vcount, sel, ser_count, key_locations, desc, has_null, nulls_first, - prefix_len, width, offset); - break; - case PhysicalType::ARRAY: - RadixScatterArrayVector(v, vdata, vcount, sel, ser_count, key_locations, desc, has_null, nulls_first, - prefix_len, width, offset); - break; - default: - throw NotImplementedException("Cannot ORDER BY column with type %s", v.GetType().ToString()); - } - -#ifdef DEBUG - for (idx_t i = 0; i < ser_count; i++) { - D_ASSERT(key_locations[i] == key_locations_copy[i] + width); - } -#endif -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_scatter.cpp b/src/duckdb/src/common/row_operations/row_scatter.cpp deleted file mode 100644 index 1912d2484..000000000 --- a/src/duckdb/src/common/row_operations/row_scatter.cpp +++ /dev/null @@ -1,230 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/null_value.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/common/types/row/row_layout.hpp" -#include "duckdb/common/types/selection_vector.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -template -static void TemplatedScatter(UnifiedVectorFormat &col, Vector &rows, const SelectionVector &sel, const idx_t count, - const idx_t col_offset, const idx_t col_no, const idx_t col_count) { - auto data = UnifiedVectorFormat::GetData(col); - auto ptrs = FlatVector::GetData(rows); - - if (!col.validity.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - - auto isnull = !col.validity.RowIsValid(col_idx); - T store_value = isnull ? NullValue() : data[col_idx]; - Store(store_value, row + col_offset); - if (isnull) { - ValidityBytes col_mask(ptrs[idx], col_count); - col_mask.SetInvalidUnsafe(col_no); - } - } - } else { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - - Store(data[col_idx], row + col_offset); - } - } -} - -static void ComputeStringEntrySizes(const UnifiedVectorFormat &col, idx_t entry_sizes[], const SelectionVector &sel, - const idx_t count, const idx_t offset = 0) { - auto data = UnifiedVectorFormat::GetData(col); - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx) + offset; - const auto &str = data[col_idx]; - if (col.validity.RowIsValid(col_idx) && !str.IsInlined()) { - entry_sizes[i] += str.GetSize(); - } - } -} - -static void ScatterStringVector(UnifiedVectorFormat &col, Vector &rows, data_ptr_t str_locations[], - const SelectionVector &sel, const idx_t count, const idx_t col_offset, - const idx_t col_no, const idx_t col_count) { - auto string_data = UnifiedVectorFormat::GetData(col); - auto ptrs = FlatVector::GetData(rows); - - // Write out zero length to avoid swizzling problems. 
- const string_t null(nullptr, 0); - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - if (!col.validity.RowIsValid(col_idx)) { - ValidityBytes col_mask(row, col_count); - col_mask.SetInvalidUnsafe(col_no); - Store(null, row + col_offset); - } else if (string_data[col_idx].IsInlined()) { - Store(string_data[col_idx], row + col_offset); - } else { - const auto &str = string_data[col_idx]; - string_t inserted(const_char_ptr_cast(str_locations[i]), UnsafeNumericCast(str.GetSize())); - memcpy(inserted.GetDataWriteable(), str.GetData(), str.GetSize()); - str_locations[i] += str.GetSize(); - inserted.Finalize(); - Store(inserted, row + col_offset); - } - } -} - -static void ScatterNestedVector(Vector &vec, UnifiedVectorFormat &col, Vector &rows, data_ptr_t data_locations[], - const SelectionVector &sel, const idx_t count, const idx_t col_offset, - const idx_t col_no, const idx_t vcount) { - // Store pointers to the data in the row - // Do this first because SerializeVector destroys the locations - auto ptrs = FlatVector::GetData(rows); - data_ptr_t validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto row = ptrs[idx]; - validitymask_locations[i] = row; - - Store(data_locations[i], row + col_offset); - } - - // Serialise the data - NestedValidity parent_validity(validitymask_locations, col_no); - RowOperations::HeapScatter(vec, vcount, sel, count, data_locations, &parent_validity); -} - -void RowOperations::Scatter(DataChunk &columns, UnifiedVectorFormat col_data[], const RowLayout &layout, Vector &rows, - RowDataCollection &string_heap, const SelectionVector &sel, idx_t count) { - if (count == 0) { - return; - } - - // Set the validity mask for each row before inserting data - idx_t column_count = layout.ColumnCount(); - auto ptrs = FlatVector::GetData(rows); - for (idx_t i = 0; i < count; ++i) { - auto row_idx = sel.get_index(i); - auto row = ptrs[row_idx]; - ValidityBytes(row, column_count).SetAllValid(layout.ColumnCount()); - } - - const auto vcount = columns.size(); - auto &offsets = layout.GetOffsets(); - auto &types = layout.GetTypes(); - - // Compute the entry size of the variable size columns - vector handles; - data_ptr_t data_locations[STANDARD_VECTOR_SIZE]; - if (!layout.AllConstant()) { - idx_t entry_sizes[STANDARD_VECTOR_SIZE]; - std::fill_n(entry_sizes, count, sizeof(uint32_t)); - for (idx_t col_no = 0; col_no < types.size(); col_no++) { - if (TypeIsConstantSize(types[col_no].InternalType())) { - continue; - } - - auto &vec = columns.data[col_no]; - auto &col = col_data[col_no]; - switch (types[col_no].InternalType()) { - case PhysicalType::VARCHAR: - ComputeStringEntrySizes(col, entry_sizes, sel, count); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - RowOperations::ComputeEntrySizes(vec, col, entry_sizes, vcount, count, sel); - break; - default: - throw InternalException("Unsupported type for RowOperations::Scatter"); - } - } - - // Build out the buffer space - handles = string_heap.Build(count, data_locations, entry_sizes); - - // Serialize information that is needed for swizzling if the computation goes out-of-core - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - auto row_idx = sel.get_index(i); - auto row = ptrs[row_idx]; - // Pointer to this row in the heap block - Store(data_locations[i], row + heap_pointer_offset); - // Row 
size is stored in the heap in front of each row - Store(NumericCast(entry_sizes[i]), data_locations[i]); - data_locations[i] += sizeof(uint32_t); - } - } - - for (idx_t col_no = 0; col_no < types.size(); col_no++) { - auto &vec = columns.data[col_no]; - auto &col = col_data[col_no]; - auto col_offset = offsets[col_no]; - - switch (types[col_no].InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT16: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT32: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT64: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT8: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT16: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT32: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT64: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT128: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT128: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::FLOAT: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::DOUBLE: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INTERVAL: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::VARCHAR: - ScatterStringVector(col, rows, data_locations, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - ScatterNestedVector(vec, col, rows, data_locations, sel, count, col_offset, col_no, vcount); - break; - default: - throw InternalException("Unsupported type for RowOperations::Scatter"); - } - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/comparators.cpp b/src/duckdb/src/common/sort/comparators.cpp deleted file mode 100644 index 4df4cccc4..000000000 --- a/src/duckdb/src/common/sort/comparators.cpp +++ /dev/null @@ -1,507 +0,0 @@ -#include "duckdb/common/sort/comparators.hpp" - -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -bool Comparators::TieIsBreakable(const idx_t &tie_col, const data_ptr_t &row_ptr, const SortLayout &sort_layout) { - const auto &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - // Check if the blob is NULL - ValidityBytes row_mask(row_ptr, sort_layout.column_count); - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - // Can't break a NULL tie - return false; - } - auto &row_layout = sort_layout.blob_layout; - if (row_layout.GetTypes()[col_idx].InternalType() != PhysicalType::VARCHAR) { - // Nested type, must be broken - return true; - } - const auto &tie_col_offset = row_layout.GetOffsets()[col_idx]; - auto tie_string = Load(row_ptr + 
tie_col_offset); - if (tie_string.GetSize() < sort_layout.prefix_lengths[tie_col] && tie_string.GetSize() > 0) { - // No need to break the tie - we already compared the full string - return false; - } - return true; -} - -int Comparators::CompareTuple(const SBScanState &left, const SBScanState &right, const data_ptr_t &l_ptr, - const data_ptr_t &r_ptr, const SortLayout &sort_layout, const bool &external_sort) { - // Compare the sorting columns one by one - int comp_res = 0; - data_ptr_t l_ptr_offset = l_ptr; - data_ptr_t r_ptr_offset = r_ptr; - for (idx_t col_idx = 0; col_idx < sort_layout.column_count; col_idx++) { - comp_res = FastMemcmp(l_ptr_offset, r_ptr_offset, sort_layout.column_sizes[col_idx]); - if (comp_res == 0 && !sort_layout.constant_size[col_idx]) { - comp_res = BreakBlobTie(col_idx, left, right, sort_layout, external_sort); - } - if (comp_res != 0) { - break; - } - l_ptr_offset += sort_layout.column_sizes[col_idx]; - r_ptr_offset += sort_layout.column_sizes[col_idx]; - } - return comp_res; -} - -int Comparators::CompareVal(const data_ptr_t l_ptr, const data_ptr_t r_ptr, const LogicalType &type) { - switch (type.InternalType()) { - case PhysicalType::VARCHAR: - return TemplatedCompareVal(l_ptr, r_ptr); - case PhysicalType::LIST: - case PhysicalType::ARRAY: - case PhysicalType::STRUCT: { - auto l_nested_ptr = Load(l_ptr); - auto r_nested_ptr = Load(r_ptr); - return CompareValAndAdvance(l_nested_ptr, r_nested_ptr, type, true); - } - default: - throw NotImplementedException("Unimplemented CompareVal for type %s", type.ToString()); - } -} - -int Comparators::BreakBlobTie(const idx_t &tie_col, const SBScanState &left, const SBScanState &right, - const SortLayout &sort_layout, const bool &external) { - data_ptr_t l_data_ptr = left.DataPtr(*left.sb->blob_sorting_data); - data_ptr_t r_data_ptr = right.DataPtr(*right.sb->blob_sorting_data); - if (!TieIsBreakable(tie_col, l_data_ptr, sort_layout) && !TieIsBreakable(tie_col, r_data_ptr, sort_layout)) { - // Quick check to see if ties can be broken - return 0; - } - // Align the pointers - const idx_t &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - const auto &tie_col_offset = sort_layout.blob_layout.GetOffsets()[col_idx]; - l_data_ptr += tie_col_offset; - r_data_ptr += tie_col_offset; - // Do the comparison - const int order = sort_layout.order_types[tie_col] == OrderType::DESCENDING ? 
-1 : 1; - const auto &type = sort_layout.blob_layout.GetTypes()[col_idx]; - int result; - if (external) { - // Store heap pointers - data_ptr_t l_heap_ptr = left.HeapPtr(*left.sb->blob_sorting_data); - data_ptr_t r_heap_ptr = right.HeapPtr(*right.sb->blob_sorting_data); - // Unswizzle offset to pointer - UnswizzleSingleValue(l_data_ptr, l_heap_ptr, type); - UnswizzleSingleValue(r_data_ptr, r_heap_ptr, type); - // Compare - result = CompareVal(l_data_ptr, r_data_ptr, type); - // Swizzle the pointers back to offsets - SwizzleSingleValue(l_data_ptr, l_heap_ptr, type); - SwizzleSingleValue(r_data_ptr, r_heap_ptr, type); - } else { - result = CompareVal(l_data_ptr, r_data_ptr, type); - } - return order * result; -} - -template -int Comparators::TemplatedCompareVal(const data_ptr_t &left_ptr, const data_ptr_t &right_ptr) { - const auto left_val = Load(left_ptr); - const auto right_val = Load(right_ptr); - if (Equals::Operation(left_val, right_val)) { - return 0; - } else if (LessThan::Operation(left_val, right_val)) { - return -1; - } else { - return 1; - } -} - -int Comparators::CompareValAndAdvance(data_ptr_t &l_ptr, data_ptr_t &r_ptr, const LogicalType &type, bool valid) { - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT16: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT32: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT64: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT8: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT16: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT32: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT64: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT128: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT128: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::FLOAT: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::DOUBLE: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INTERVAL: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::VARCHAR: - return CompareStringAndAdvance(l_ptr, r_ptr, valid); - case PhysicalType::LIST: - return CompareListAndAdvance(l_ptr, r_ptr, ListType::GetChildType(type), valid); - case PhysicalType::STRUCT: - return CompareStructAndAdvance(l_ptr, r_ptr, StructType::GetChildTypes(type), valid); - case PhysicalType::ARRAY: - return CompareArrayAndAdvance(l_ptr, r_ptr, ArrayType::GetChildType(type), valid, ArrayType::GetSize(type)); - default: - throw NotImplementedException("Unimplemented CompareValAndAdvance for type %s", type.ToString()); - } -} - -template -int Comparators::TemplatedCompareAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr) { - auto result = TemplatedCompareVal(left_ptr, right_ptr); - left_ptr += sizeof(T); - right_ptr += sizeof(T); - return result; -} - -int Comparators::CompareStringAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, bool valid) { - if (!valid) { - return 0; - } - uint32_t left_string_size = Load(left_ptr); - uint32_t right_string_size = Load(right_ptr); - left_ptr += sizeof(uint32_t); - right_ptr += sizeof(uint32_t); - auto memcmp_res = memcmp(const_char_ptr_cast(left_ptr), const_char_ptr_cast(right_ptr), - std::min(left_string_size, right_string_size)); - - left_ptr += left_string_size; 
- right_ptr += right_string_size; - - if (memcmp_res != 0) { - return memcmp_res; - } - if (left_string_size == right_string_size) { - return 0; - } - if (left_string_size < right_string_size) { - return -1; - } - return 1; -} - -int Comparators::CompareStructAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, - const child_list_t &types, bool valid) { - idx_t count = types.size(); - // Load validity masks - ValidityBytes left_validity(left_ptr, types.size()); - ValidityBytes right_validity(right_ptr, types.size()); - left_ptr += (count + 7) / 8; - right_ptr += (count + 7) / 8; - // Initialize variables - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Compare - int comp_res = 0; - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - auto &type = types[i].second; - if ((left_valid == right_valid) || TypeIsConstantSize(type.InternalType())) { - comp_res = CompareValAndAdvance(left_ptr, right_ptr, types[i].second, left_valid && valid); - } - if (!left_valid && !right_valid) { - comp_res = 0; - } else if (!left_valid) { - comp_res = 1; - } else if (!right_valid) { - comp_res = -1; - } - if (comp_res != 0) { - break; - } - } - return comp_res; -} - -int Comparators::CompareArrayAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, - bool valid, idx_t array_size) { - if (!valid) { - return 0; - } - - // Load array validity masks - ValidityBytes left_validity(left_ptr, array_size); - ValidityBytes right_validity(right_ptr, array_size); - left_ptr += (array_size + 7) / 8; - right_ptr += (array_size + 7) / 8; - - int comp_res = 0; - if (TypeIsConstantSize(type.InternalType())) { - // Templated code for fixed-size types - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT16: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT32: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT64: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT8: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT16: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT32: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT64: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT128: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::FLOAT: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::DOUBLE: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, 
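// Sketch (not from this patch): CompareStringAndAdvance above compares two
// length-prefixed strings by memcmp'ing the common prefix and falling back to
// the lengths, advancing both cursors past the payload either way. A
// standalone equivalent over raw buffers:
#include <algorithm>
#include <cstdint>
#include <cstring>

inline int CompareLengthPrefixedExample(const uint8_t *&l, const uint8_t *&r) {
	uint32_t l_size, r_size;
	std::memcpy(&l_size, l, sizeof(uint32_t));
	std::memcpy(&r_size, r, sizeof(uint32_t));
	l += sizeof(uint32_t);
	r += sizeof(uint32_t);
	const int cmp = std::memcmp(l, r, std::min(l_size, r_size));
	l += l_size; // keep the cursors aligned for whatever follows in the row
	r += r_size;
	if (cmp != 0) {
		return cmp;
	}
	return l_size == r_size ? 0 : (l_size < r_size ? -1 : 1);
}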
array_size); - break; - case PhysicalType::INTERVAL: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - default: - throw NotImplementedException("CompareListAndAdvance for fixed-size type %s", type.ToString()); - } - } else { - // Variable-sized array entries - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Size (in bytes) of all variable-sizes entries is stored before the entries begin, - // to make deserialization easier. We need to skip over them - left_ptr += array_size * sizeof(idx_t); - right_ptr += array_size * sizeof(idx_t); - for (idx_t i = 0; i < array_size; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - if (left_valid && right_valid) { - switch (type.InternalType()) { - case PhysicalType::LIST: - comp_res = CompareListAndAdvance(left_ptr, right_ptr, ListType::GetChildType(type), left_valid); - break; - case PhysicalType::ARRAY: - comp_res = CompareArrayAndAdvance(left_ptr, right_ptr, ArrayType::GetChildType(type), left_valid, - ArrayType::GetSize(type)); - break; - case PhysicalType::VARCHAR: - comp_res = CompareStringAndAdvance(left_ptr, right_ptr, left_valid); - break; - case PhysicalType::STRUCT: - comp_res = - CompareStructAndAdvance(left_ptr, right_ptr, StructType::GetChildTypes(type), left_valid); - break; - default: - throw NotImplementedException("CompareArrayAndAdvance for variable-size type %s", type.ToString()); - } - } else if (!left_valid && !right_valid) { - comp_res = 0; - } else if (left_valid) { - comp_res = -1; - } else { - comp_res = 1; - } - if (comp_res != 0) { - break; - } - } - } - return comp_res; -} - -int Comparators::CompareListAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, - bool valid) { - if (!valid) { - return 0; - } - // Load list lengths - auto left_len = Load(left_ptr); - auto right_len = Load(right_ptr); - left_ptr += sizeof(idx_t); - right_ptr += sizeof(idx_t); - // Load list validity masks - ValidityBytes left_validity(left_ptr, left_len); - ValidityBytes right_validity(right_ptr, right_len); - left_ptr += (left_len + 7) / 8; - right_ptr += (right_len + 7) / 8; - // Compare - int comp_res = 0; - idx_t count = MinValue(left_len, right_len); - if (TypeIsConstantSize(type.InternalType())) { - // Templated code for fixed-size types - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT16: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT32: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT64: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT16: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT32: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, 
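// Sketch (not from this patch): CompareStructAndAdvance and
// CompareArrayAndAdvance above first skip (count + 7) / 8 validity bytes (one
// bit per entry) and then compare entries with NULLs ordered last. Assuming a
// least-significant-bit-first layout, the two building blocks look like this:
#include <cstddef>
#include <cstdint>

inline bool RowIsValidExample(const uint8_t *validity_bytes, size_t idx) {
	return (validity_bytes[idx / 8] >> (idx % 8)) & 1;
}

inline int CompareWithNullsLastExample(bool left_valid, bool right_valid, int value_cmp) {
	if (left_valid && right_valid) {
		return value_cmp;
	}
	if (!left_valid && !right_valid) {
		return 0;
	}
	return left_valid ? -1 : 1; // the NULL side sorts after the non-NULL side
}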
right_validity, count); - break; - case PhysicalType::UINT64: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT128: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT128: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::FLOAT: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::DOUBLE: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INTERVAL: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - default: - throw NotImplementedException("CompareListAndAdvance for fixed-size type %s", type.ToString()); - } - } else { - // Variable-sized list entries - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Size (in bytes) of all variable-sizes entries is stored before the entries begin, - // to make deserialization easier. We need to skip over them - left_ptr += left_len * sizeof(idx_t); - right_ptr += right_len * sizeof(idx_t); - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - if (left_valid && right_valid) { - switch (type.InternalType()) { - case PhysicalType::LIST: - comp_res = CompareListAndAdvance(left_ptr, right_ptr, ListType::GetChildType(type), left_valid); - break; - case PhysicalType::ARRAY: - comp_res = CompareArrayAndAdvance(left_ptr, right_ptr, ArrayType::GetChildType(type), left_valid, - ArrayType::GetSize(type)); - break; - case PhysicalType::VARCHAR: - comp_res = CompareStringAndAdvance(left_ptr, right_ptr, left_valid); - break; - case PhysicalType::STRUCT: - comp_res = - CompareStructAndAdvance(left_ptr, right_ptr, StructType::GetChildTypes(type), left_valid); - break; - default: - throw NotImplementedException("CompareListAndAdvance for variable-size type %s", type.ToString()); - } - } else if (!left_valid && !right_valid) { - comp_res = 0; - } else if (left_valid) { - comp_res = -1; - } else { - comp_res = 1; - } - if (comp_res != 0) { - break; - } - } - } - // All values that we looped over were equal - if (comp_res == 0 && left_len != right_len) { - // Smaller lists first - if (left_len < right_len) { - comp_res = -1; - } else { - comp_res = 1; - } - } - return comp_res; -} - -template -int Comparators::TemplatedCompareListLoop(data_ptr_t &left_ptr, data_ptr_t &right_ptr, - const ValidityBytes &left_validity, const ValidityBytes &right_validity, - const idx_t &count) { - int comp_res = 0; - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - comp_res = TemplatedCompareAndAdvance(left_ptr, right_ptr); - if (!left_valid && !right_valid) { - comp_res = 0; - } else if (!left_valid) { - comp_res = 1; - } else if (!right_valid) { - comp_res 
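// Sketch (not from this patch): CompareListAndAdvance above is a lexicographic
// comparison; it walks the common prefix of the two lists and, if that prefix
// is entirely equal, orders the shorter list first. The same rule over plain
// vectors:
#include <algorithm>
#include <cstddef>
#include <vector>

inline int CompareListsExample(const std::vector<int> &l, const std::vector<int> &r) {
	const size_t common = std::min(l.size(), r.size());
	for (size_t i = 0; i < common; i++) {
		if (l[i] != r[i]) {
			return l[i] < r[i] ? -1 : 1;
		}
	}
	if (l.size() == r.size()) {
		return 0;
	}
	return l.size() < r.size() ? -1 : 1; // smaller lists first
}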
= -1; - } - if (comp_res != 0) { - break; - } - } - return comp_res; -} - -void Comparators::UnswizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type) { - if (type.InternalType() == PhysicalType::VARCHAR) { - data_ptr += string_t::HEADER_SIZE; - } - Store(heap_ptr + Load(data_ptr), data_ptr); -} - -void Comparators::SwizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type) { - if (type.InternalType() == PhysicalType::VARCHAR) { - data_ptr += string_t::HEADER_SIZE; - } - Store(UnsafeNumericCast(Load(data_ptr) - heap_ptr), data_ptr); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sorting/hashed_sort.cpp b/src/duckdb/src/common/sort/hashed_sort.cpp similarity index 68% rename from src/duckdb/src/common/sorting/hashed_sort.cpp rename to src/duckdb/src/common/sort/hashed_sort.cpp index 5571e0bc3..e0244da6c 100644 --- a/src/duckdb/src/common/sorting/hashed_sort.cpp +++ b/src/duckdb/src/common/sort/hashed_sort.cpp @@ -1,6 +1,6 @@ #include "duckdb/common/sorting/hashed_sort.hpp" +#include "duckdb/common/sorting/sorted_run.hpp" #include "duckdb/common/radix_partitioning.hpp" -#include "duckdb/parallel/base_pipeline_event.hpp" #include "duckdb/parallel/thread_context.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" @@ -17,20 +17,32 @@ class HashedSortGroup { HashedSortGroup(ClientContext &client, optional_ptr sort, idx_t group_idx); + bool Scan(TupleDataCollection &payload, TupleDataLocalScanState &local_scan, DataChunk &chunk) { + // Despite the name, TupleDataParallelScanState is not thread safe... + lock_guard guard(scan_lock); + return payload.Scan(parallel_scan, local_scan, chunk); + } + const idx_t group_idx; + atomic count; // Sink optional_ptr sort; unique_ptr sort_global; // Source + mutex scan_lock; + TupleDataParallelScanState parallel_scan; atomic tasks_completed; unique_ptr sort_source; - unique_ptr sorted; + + // Unsorted + unique_ptr columns; + atomic get_columns; }; HashedSortGroup::HashedSortGroup(ClientContext &client, optional_ptr sort, idx_t group_idx) - : group_idx(group_idx), sort(sort), tasks_completed(0) { + : group_idx(group_idx), count(0), sort(sort), tasks_completed(0), get_columns(0) { if (sort) { sort_global = sort->GetGlobalSinkState(client); } @@ -53,6 +65,7 @@ class HashedSortGlobalSinkState : public GlobalSinkState { // OVER(PARTITION BY...) (hash grouping) unique_ptr CreatePartition(idx_t new_bits) const; + void SyncPartitioning(const HashedSortGlobalSinkState &other); void UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &partition_append); void CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); ProgressData GetSinkProgress(ClientContext &context, const ProgressData source_progress) const; @@ -69,6 +82,7 @@ class HashedSortGlobalSinkState : public GlobalSinkState { shared_ptr grouping_types_ptr; //! The number of radix bits if this partition is being synced with another idx_t fixed_bits; + vector scan_ids; // OVER(...) 
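// Sketch (not from this patch): UnswizzleSingleValue/SwizzleSingleValue above
// convert a row slot between an absolute heap pointer (convenient in memory)
// and an offset from the heap block's base (stable across spilling and
// reloading). Assuming a 64-bit build where both fit in 8 bytes:
#include <cstdint>
#include <cstring>

inline void UnswizzleExample(uint8_t *slot, uint8_t *heap_base) {
	uint64_t offset;
	std::memcpy(&offset, slot, sizeof(offset));
	uint8_t *pointer = heap_base + offset; // offset -> absolute pointer
	std::memcpy(slot, &pointer, sizeof(pointer));
}

inline void SwizzleExample(uint8_t *slot, uint8_t *heap_base) {
	uint8_t *pointer;
	std::memcpy(&pointer, slot, sizeof(pointer));
	const uint64_t offset = static_cast<uint64_t>(pointer - heap_base); // pointer -> offset
	std::memcpy(slot, &offset, sizeof(offset));
}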
(sorting) vector hash_groups; @@ -85,7 +99,6 @@ class HashedSortGlobalSinkState : public GlobalSinkState { HashedSortGlobalSinkState::HashedSortGlobalSinkState(ClientContext &client, const HashedSort &hashed_sort) : hashed_sort(hashed_sort), buffer_manager(BufferManager::GetBufferManager(client)), allocator(Allocator::Get(client)), fixed_bits(0), max_bits(1), count(0) { - const auto memory_per_thread = PhysicalOperator::GetMaxThreadMemory(client); const auto thread_pages = PreviousPowerOfTwo(memory_per_thread / (4 * buffer_manager.GetBlockAllocSize())); while (max_bits < 8 && (thread_pages >> max_bits) > 1) { @@ -107,6 +120,9 @@ HashedSortGlobalSinkState::HashedSortGlobalSinkState(ClientContext &client, cons types.push_back(LogicalType::HASH); grouping_types_ptr->Initialize(types, TupleDataValidityType::CAN_HAVE_NULL_VALUES); Rehash(hashed_sort.estimated_cardinality); + for (column_t i = 0; i < payload_types.size(); ++i) { + scan_ids.emplace_back(i); + } } } } @@ -172,6 +188,15 @@ void HashedSortGlobalSinkState::UpdateLocalPartition(GroupingPartition &local_pa SyncLocalPartition(local_partition, partition_append); } +void HashedSortGlobalSinkState::SyncPartitioning(const HashedSortGlobalSinkState &other) { + fixed_bits = other.grouping_data ? other.grouping_data->GetRadixBits() : 0; + + const auto old_bits = grouping_data ? grouping_data->GetRadixBits() : 0; + if (fixed_bits != old_bits) { + grouping_data = CreatePartition(fixed_bits); + } +} + void HashedSortGlobalSinkState::CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { if (!local_partition) { @@ -203,6 +228,9 @@ void HashedSortGlobalSinkState::CombineLocalPartition(GroupingPartition &local_p hash_group = make_uniq(hashed_sort.client, *hashed_sort.sort, group_idx); } } + + // Combine the thread data into the global data + grouping_data->Combine(*local_partition); } ProgressData HashedSortGlobalSinkState::GetSinkProgress(ClientContext &client, const ProgressData source) const { @@ -242,14 +270,36 @@ SinkFinalizeType HashedSort::Finalize(ClientContext &client, OperatorSinkFinaliz return SinkFinalizeType::READY; } - // OVER(...) + // OVER(ORDER BY...) + if (partitions.empty()) { + auto &hash_group = gsink.hash_groups[0]; + if (hash_group) { + auto &global_sink = *hash_group->sort_global; + OperatorSinkFinalizeInput hfinalize {global_sink, finalize.interrupt_state}; + sort->Finalize(client, hfinalize); + hash_group->sort_source = sort->GetGlobalSourceState(client, global_sink); + return SinkFinalizeType::READY; + } + return SinkFinalizeType::NO_OUTPUT_POSSIBLE; + } + + // OVER(PARTITION BY...) + auto &partitions = gsink.grouping_data->GetPartitions(); D_ASSERT(!gsink.hash_groups.empty()); - for (auto &hash_group : gsink.hash_groups) { + for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) { + auto &partition = *partitions[hash_bin]; + if (!partition.Count()) { + continue; + } + + auto &hash_group = gsink.hash_groups[hash_bin]; if (!hash_group) { continue; } - OperatorSinkFinalizeInput hfinalize {*hash_group->sort_global, finalize.interrupt_state}; - sort->Finalize(client, hfinalize); + + // Prepare to scan into the sort + auto ¶llel_scan = hash_group->parallel_scan; + partition.InitializeScan(parallel_scan, gsink.scan_ids); } return SinkFinalizeType::READY; @@ -295,7 +345,6 @@ class HashedSortLocalSinkState : public LocalSinkState { // OVER(ORDER BY...) 
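// Sketch (not from this patch): the constructor above derives the maximum
// number of radix-partitioning bits from the per-thread memory budget; the
// interpretation here (which is an assumption) is that extra partitions only
// help while each partition still gets more than one page. Hypothetical names:
#include <cstdint>

inline uint64_t PreviousPowerOfTwoExample(uint64_t v) {
	if (v == 0) {
		return 0;
	}
	uint64_t result = 1;
	while (result <= v / 2) { // overflow-safe doubling
		result *= 2;
	}
	return result;
}

inline uint64_t MaxRadixBitsExample(uint64_t memory_per_thread, uint64_t block_alloc_size) {
	const uint64_t thread_pages = PreviousPowerOfTwoExample(memory_per_thread / (4 * block_alloc_size));
	uint64_t max_bits = 1;
	while (max_bits < 8 && (thread_pages >> max_bits) > 1) {
		max_bits++;
	}
	return max_bits;
}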
(only sorting) LocalSortStatePtr sort_local; - InterruptState interrupt; // OVER() (no sorting) unique_ptr unsorted; @@ -305,7 +354,6 @@ class HashedSortLocalSinkState : public LocalSinkState { HashedSortLocalSinkState::HashedSortLocalSinkState(ExecutionContext &context, const HashedSort &hashed_sort) : hashed_sort(hashed_sort), allocator(Allocator::Get(context.client)), hash_exec(context.client), sort_exec(context.client) { - vector group_types; for (idx_t prt_idx = 0; prt_idx < hashed_sort.partitions.size(); prt_idx++) { auto &pexpr = *hashed_sort.partitions[prt_idx].expression.get(); @@ -328,14 +376,6 @@ HashedSortLocalSinkState::HashedSortLocalSinkState(ExecutionContext &context, co payload_types.emplace_back(LogicalType::HASH); } else { // OVER(ORDER BY...) - for (idx_t ord_idx = 0; ord_idx < hashed_sort.orders.size(); ord_idx++) { - auto &pexpr = *hashed_sort.orders[ord_idx].expression.get(); - group_types.push_back(pexpr.return_type); - hash_exec.AddExpression(pexpr); - } - group_chunk.Initialize(allocator, group_types); - - // Single partition auto &sort = *hashed_sort.sort; sort_local = sort.GetLocalSinkState(context); } @@ -347,6 +387,12 @@ HashedSortLocalSinkState::HashedSortLocalSinkState(ExecutionContext &context, co } } +void HashedSort::Synchronize(const GlobalSinkState &source, GlobalSinkState &target) const { + auto &src = source.Cast(); + auto &tgt = target.Cast(); + tgt.SyncPartitioning(src); +} + void HashedSortLocalSinkState::Hash(DataChunk &input_chunk, Vector &hash_vector) { const auto count = input_chunk.size(); D_ASSERT(group_chunk.ColumnCount() > 0); @@ -393,6 +439,15 @@ SinkResultType HashedSort::Sink(ExecutionContext &context, DataChunk &input_chun payload_chunk.data[input_chunk.ColumnCount() + i].Reference(sort_chunk.data[i]); } } + + // Append a forced payload column + if (force_payload) { + auto &vec = payload_chunk.data[input_chunk.ColumnCount() + sort_chunk.ColumnCount()]; + D_ASSERT(vec.GetType().id() == LogicalTypeId::BOOLEAN); + vec.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(vec, true); + } + payload_chunk.SetCardinality(input_chunk); // OVER(ORDER BY...) 
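// Sketch (not from this patch): the block above appends a dummy BOOLEAN payload
// column (a constant NULL vector) when force_payload is set. The decision
// itself, made in the HashedSort constructor further down, reduces to checking
// whether the sort keys already cover every payload column; with hypothetical
// names:
#include <cstddef>
#include <set>
#include <vector>

inline bool NeedsForcedPayloadExample(const std::vector<size_t> &sort_column_ids, size_t payload_column_count) {
	// Deduplicate the sort keys; if they cover all payload columns, append a dummy payload column.
	const std::set<size_t> sort_set(sort_column_ids.begin(), sort_column_ids.end());
	return sort_set.size() >= payload_column_count;
}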
@@ -401,6 +456,7 @@ SinkResultType HashedSort::Sink(ExecutionContext &context, DataChunk &input_chun auto &hash_group = *gstate.hash_groups[0]; OperatorSinkInput input {*hash_group.sort_global, *sort_local, sink.interrupt_state}; sort->Sink(context, payload_chunk, input); + hash_group.count += payload_chunk.size(); return SinkResultType::NEED_MORE_INPUT; } @@ -435,14 +491,17 @@ SinkCombineResultType HashedSort::Combine(ExecutionContext &context, OperatorSin auto &hash_groups = gstate.hash_groups; if (!hash_groups.empty()) { D_ASSERT(hash_groups.size() == 1); - auto &unsorted = *hash_groups[0]->sorted; + auto &hash_group = *hash_groups[0]; + auto &unsorted = *hash_group.columns; if (lstate.unsorted) { + hash_group.count += lstate.unsorted->Count(); unsorted.Combine(*lstate.unsorted); lstate.unsorted.reset(); } } else { auto new_group = make_uniq(context.client, sort, idx_t(0)); - new_group->sorted = std::move(lstate.unsorted); + new_group->columns = std::move(lstate.unsorted); + new_group->count += new_group->columns->Count(); hash_groups.emplace_back(std::move(new_group)); } return SinkCombineResultType::FINISHED; @@ -467,152 +526,130 @@ SinkCombineResultType HashedSort::Combine(ExecutionContext &context, OperatorSin auto &grouping_append = lstate.grouping_append; gstate.CombineLocalPartition(local_grouping, grouping_append); - // Don't scan the hash column - vector column_ids; - for (column_t i = 0; i < payload_types.size(); ++i) { - column_ids.emplace_back(i); + return SinkCombineResultType::FINISHED; +} + +void HashedSort::SortColumnData(ExecutionContext &context, hash_t hash_bin, OperatorSinkFinalizeInput &finalize) { + auto &gstate = finalize.global_state.Cast(); + + // OVER() + if (sort_col_count == 0) { + // Nothing to sort + return; + } + + // OVER(ORDER BY...) 
+ if (partitions.empty()) { + // Already sorted in Combine + return; } // Loop over the partitions and add them to each hash group's global sort state - TupleDataScanState scan_state; - DataChunk chunk; - auto &partitions = local_grouping->GetPartitions(); - for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) { + auto &partitions = gstate.grouping_data->GetPartitions(); + if (hash_bin < partitions.size()) { auto &partition = *partitions[hash_bin]; if (!partition.Count()) { - continue; - } - - partition.InitializeScan(scan_state, column_ids, TupleDataPinProperties::DESTROY_AFTER_DONE); - if (chunk.data.empty()) { - partition.InitializeScanChunk(scan_state, chunk); + return; } auto &hash_group = *gstate.hash_groups[hash_bin]; - lstate.sort_local = sort->GetLocalSinkState(context); - OperatorSinkInput sink {*hash_group.sort_global, *lstate.sort_local, combine.interrupt_state}; - while (partition.Scan(scan_state, chunk)) { + auto ¶llel_scan = hash_group.parallel_scan; + + DataChunk chunk; + partition.InitializeScanChunk(parallel_scan.scan_state, chunk); + TupleDataLocalScanState local_scan; + partition.InitializeScan(local_scan); + + auto sort_local = sort->GetLocalSinkState(context); + OperatorSinkInput sink {*hash_group.sort_global, *sort_local, finalize.interrupt_state}; + idx_t combined = 0; + while (hash_group.Scan(partition, local_scan, chunk)) { sort->Sink(context, chunk, sink); + combined += chunk.size(); } - OperatorSinkCombineInput lcombine {*hash_group.sort_global, *lstate.sort_local, combine.interrupt_state}; - sort->Combine(context, lcombine); - } + OperatorSinkCombineInput combine {*hash_group.sort_global, *sort_local, finalize.interrupt_state}; + sort->Combine(context, combine); + hash_group.count += combined; - return SinkCombineResultType::FINISHED; + // Whoever finishes last can Finalize + lock_guard finalize_guard(hash_group.scan_lock); + if (hash_group.count == partition.Count() && !hash_group.sort_source) { + OperatorSinkFinalizeInput lfinalize {*hash_group.sort_global, finalize.interrupt_state}; + sort->Finalize(context.client, lfinalize); + hash_group.sort_source = sort->GetGlobalSourceState(client, *hash_group.sort_global); + } + } } //===--------------------------------------------------------------------===// -// HashedSortMaterializeTask +// HashedSortGlobalSourceState //===--------------------------------------------------------------------===// -class HashedSortMaterializeTask : public ExecutorTask { +class HashedSortGlobalSourceState : public GlobalSourceState { public: - HashedSortMaterializeTask(Pipeline &pipeline, shared_ptr event, const PhysicalOperator &op, - HashedSortGroup &hash_group, idx_t tasks_scheduled); - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; + using HashGroupPtr = unique_ptr; + using SortedRunPtr = unique_ptr; + using ChunkRow = HashedSort::ChunkRow; + using ChunkRows = HashedSort::ChunkRows; - string TaskType() const override { - return "HashedSortMaterializeTask"; - } + HashedSortGlobalSourceState(ClientContext &client, HashedSortGlobalSinkState &gsink); -private: - Pipeline &pipeline; - HashedSortGroup &hash_group; - const idx_t tasks_scheduled; + HashedSortGlobalSinkState &gsink; + ChunkRows chunk_rows; }; -HashedSortMaterializeTask::HashedSortMaterializeTask(Pipeline &pipeline, shared_ptr event, - const PhysicalOperator &op, HashedSortGroup &hash_group, - idx_t tasks_scheduled) - : ExecutorTask(pipeline.GetClientContext(), std::move(event), op), pipeline(pipeline), hash_group(hash_group), - 
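// Sketch (not from this patch): SortColumnData above lets several threads feed
// one hash group's sort; each thread adds the rows it scanned to an atomic
// counter, and whichever thread brings the counter up to the partition size
// finalizes the sort exactly once (checked under a lock). The bare
// synchronization pattern, with hypothetical names:
#include <atomic>
#include <cstdint>
#include <mutex>

struct GroupExample {
	std::atomic<uint64_t> rows_sunk {0};
	std::mutex finalize_lock;
	bool finalized = false;
};

// Returns true for the one caller that should run the finalize step.
inline bool FinishedLastExample(GroupExample &group, uint64_t rows_this_thread, uint64_t total_rows) {
	group.rows_sunk += rows_this_thread;
	std::lock_guard<std::mutex> guard(group.finalize_lock);
	if (group.rows_sunk == total_rows && !group.finalized) {
		group.finalized = true;
		return true;
	}
	return false;
}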
tasks_scheduled(tasks_scheduled) { -} - -TaskExecutionResult HashedSortMaterializeTask::ExecuteTask(TaskExecutionMode mode) { - ExecutionContext execution(pipeline.GetClientContext(), *thread_context, &pipeline); - auto &sort = *hash_group.sort; - auto &sort_global = *hash_group.sort_source; - auto sort_local = sort.GetLocalSourceState(execution, sort_global); - InterruptState interrupt((weak_ptr(shared_from_this()))); - OperatorSourceInput input {sort_global, *sort_local, interrupt}; - sort.MaterializeColumnData(execution, input); - if (++hash_group.tasks_completed == tasks_scheduled) { - hash_group.sorted = sort.GetColumnData(input); +HashedSortGlobalSourceState::HashedSortGlobalSourceState(ClientContext &client, HashedSortGlobalSinkState &gsink) + : gsink(gsink) { + if (!gsink.count) { + return; } - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; -} + auto &hashed_sort = gsink.hashed_sort; -//===--------------------------------------------------------------------===// -// HashedSortMaterializeEvent -//===--------------------------------------------------------------------===// -// Formerly PartitionMergeEvent -class HashedSortMaterializeEvent : public BasePipelineEvent { -public: - HashedSortMaterializeEvent(HashedSortGlobalSinkState &gstate, Pipeline &pipeline, const PhysicalOperator &op); - - HashedSortGlobalSinkState &gstate; - const PhysicalOperator &op; - -public: - void Schedule() override; -}; + // OVER() + if (hashed_sort.sort_col_count == 0) { + // One unsorted group. We have the count and chunks. + ChunkRow chunk_row; -HashedSortMaterializeEvent::HashedSortMaterializeEvent(HashedSortGlobalSinkState &gstate, Pipeline &pipeline, - const PhysicalOperator &op) - : BasePipelineEvent(pipeline), gstate(gstate), op(op) { -} + auto &hash_group = gsink.hash_groups[0]; + if (hash_group) { + chunk_row.count = hash_group->count; + chunk_row.chunks = hash_group->columns->ChunkCount(); + } -void HashedSortMaterializeEvent::Schedule() { - auto &client = pipeline->GetClientContext(); + chunk_rows.emplace_back(chunk_row); + return; + } - // Schedule as many tasks per hash group as the sort will allow - auto &ts = TaskScheduler::GetScheduler(client); - const auto num_threads = NumericCast(ts.NumberOfThreads()); - auto &sort = *gstate.hashed_sort.sort; + // OVER(ORDER BY...) + if (hashed_sort.partitions.empty()) { + // One sorted group + ChunkRow chunk_row; - vector> merge_tasks; - for (auto &hash_group : gstate.hash_groups) { - if (!hash_group) { - continue; - } - auto &global_sink = *hash_group->sort_global; - hash_group->sort_source = sort.GetGlobalSourceState(client, global_sink); - const auto tasks_scheduled = MinValue(num_threads, hash_group->sort_source->MaxThreads()); - for (idx_t t = 0; t < tasks_scheduled; ++t) { - merge_tasks.emplace_back( - make_uniq(*pipeline, shared_from_this(), op, *hash_group, tasks_scheduled)); + auto &hash_group = gsink.hash_groups[0]; + if (hash_group) { + chunk_row.count = hash_group->count; + chunk_row.chunks = (chunk_row.count + STANDARD_VECTOR_SIZE - 1) / STANDARD_VECTOR_SIZE; } - } - SetTasks(std::move(merge_tasks)); -} + chunk_rows.emplace_back(chunk_row); + return; + } -//===--------------------------------------------------------------------===// -// HashedSortGlobalSourceState -//===--------------------------------------------------------------------===// -class HashedSortGlobalSourceState : public GlobalSourceState { -public: - using HashGroupPtr = unique_ptr; + // OVER(PARTITION BY...) 
+ auto &partitions = gsink.grouping_data->GetPartitions(); + for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) { + ChunkRow chunk_row; - HashedSortGlobalSourceState(ClientContext &client, HashedSortGlobalSinkState &gsink) { - if (!gsink.count) { - return; - } - hash_groups.resize(gsink.hash_groups.size()); - for (auto &hash_group : gsink.hash_groups) { - if (!hash_group) { - continue; - } - const auto group_idx = hash_group->group_idx; - hash_groups[group_idx] = std::move(hash_group->sorted); + auto &hash_group = gsink.hash_groups[hash_bin]; + if (hash_group) { + chunk_row.count = partitions[hash_bin]->Count(); + chunk_row.chunks = (chunk_row.count + STANDARD_VECTOR_SIZE - 1) / STANDARD_VECTOR_SIZE; } - } - vector hash_groups; -}; + chunk_rows.emplace_back(chunk_row); + } +} //===--------------------------------------------------------------------===// // HashedSort @@ -620,7 +657,6 @@ class HashedSortGlobalSourceState : public GlobalSourceState { void HashedSort::GenerateOrderings(Orders &partitions, Orders &orders, const vector> &partition_bys, const Orders &order_bys, const vector> &partition_stats) { - // we sort by both 1) partition by expression list and 2) order by expressions const auto partition_cols = partition_bys.size(); for (idx_t prt_idx = 0; prt_idx < partition_cols; prt_idx++) { @@ -642,7 +678,8 @@ void HashedSort::GenerateOrderings(Orders &partitions, Orders &orders, HashedSort::HashedSort(ClientContext &client, const vector> &partition_bys, const vector &order_bys, const Types &input_types, - const vector> &partition_stats, idx_t estimated_cardinality) + const vector> &partition_stats, idx_t estimated_cardinality, + bool require_payload) : client(client), estimated_cardinality(estimated_cardinality), payload_types(input_types) { GenerateOrderings(partitions, orders, partition_bys, order_bys, partition_stats); @@ -673,6 +710,15 @@ HashedSort::HashedSort(ClientContext &client, const vector sort_set(sort_ids.begin(), sort_ids.end()); + force_payload = (sort_set.size() >= payload_types.size()); + if (force_payload) { + payload_types.emplace_back(LogicalType::BOOLEAN); + } + } vector projection_map; sort = make_uniq(client, orders, payload_types, projection_map); } @@ -695,30 +741,90 @@ unique_ptr HashedSort::GetLocalSourceState(ExecutionContext &c return make_uniq(); } -vector &HashedSort::GetHashGroups(GlobalSourceState &gstate) const { +const HashedSort::ChunkRows &HashedSort::GetHashGroups(GlobalSourceState &gstate) const { auto &gsource = gstate.Cast(); - return gsource.hash_groups; + return gsource.chunk_rows; } -SinkFinalizeType HashedSort::MaterializeHashGroups(Pipeline &pipeline, Event &event, const PhysicalOperator &op, - OperatorSinkFinalizeInput &finalize) const { - auto &gsink = finalize.global_state.Cast(); +static SourceResultType MaterializeHashGroupData(ExecutionContext &context, idx_t hash_bin, bool build_runs, + OperatorSourceInput &source) { + auto &gsource = source.global_state.Cast(); + auto &gsink = gsource.gsink; + auto &hash_group = *gsink.hash_groups[hash_bin]; + + // OVER() + if (gsink.hashed_sort.sort_col_count == 0) { + D_ASSERT(hash_bin == 0); + // Hack: Only report finished for the first call + return hash_group.get_columns++ ? 
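// Sketch (not from this patch): the chunk counts above are a plain ceiling
// division: a group with `count` rows yields ceil(count / STANDARD_VECTOR_SIZE)
// chunks.
#include <cstdint>

inline uint64_t ChunkCountExample(uint64_t row_count, uint64_t vector_size) {
	return (row_count + vector_size - 1) / vector_size; // e.g. 2049 rows at 2048 per chunk -> 2 chunks
}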
SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; + } + + auto &sort = *hash_group.sort; + auto &sort_global = *hash_group.sort_source; + auto sort_local = sort.GetLocalSourceState(context, sort_global); + + OperatorSourceInput input {sort_global, *sort_local, source.interrupt_state}; + if (build_runs) { + return sort.MaterializeSortedRun(context, input); + } else { + return sort.MaterializeColumnData(context, input); + } +} + +SourceResultType HashedSort::MaterializeColumnData(ExecutionContext &execution, idx_t hash_bin, + OperatorSourceInput &source) const { + return MaterializeHashGroupData(execution, hash_bin, false, source); +} + +HashedSort::HashGroupPtr HashedSort::GetColumnData(idx_t hash_bin, OperatorSourceInput &source) const { + auto &gsource = source.global_state.Cast(); + auto &gsink = gsource.gsink; + auto &hash_group = *gsink.hash_groups[hash_bin]; // OVER() if (sort_col_count == 0) { - auto &hash_group = *gsink.hash_groups[0]; - auto &unsorted = *hash_group.sorted; - if (!unsorted.Count()) { - return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - return SinkFinalizeType::READY; + D_ASSERT(hash_bin == 0); + return std::move(hash_group.columns); } - // Schedule all the sorts for maximum thread utilisation - auto sort_event = make_shared_ptr(gsink, pipeline, op); - event.InsertEvent(std::move(sort_event)); + auto &sort = *hash_group.sort; + auto &sort_global = *hash_group.sort_source; - return SinkFinalizeType::READY; + OperatorSourceInput input {sort_global, source.local_state, source.interrupt_state}; + auto result = sort.GetColumnData(input); + + // Just because MaterializeColumnData returned FINISHED doesn't mean that the same thread will + // get the result... + if (result && result->Count() == hash_group.count) { + return result; + } + + return nullptr; +} + +SourceResultType HashedSort::MaterializeSortedRun(ExecutionContext &context, idx_t hash_bin, + OperatorSourceInput &source) const { + return MaterializeHashGroupData(context, hash_bin, true, source); +} + +HashedSort::SortedRunPtr HashedSort::GetSortedRun(ClientContext &client, idx_t hash_bin, + OperatorSourceInput &source) const { + auto &gsource = source.global_state.Cast(); + auto &gsink = gsource.gsink; + auto &hash_group = *gsink.hash_groups[hash_bin]; + + D_ASSERT(gsink.hashed_sort.sort_col_count); + + auto &sort = *hash_group.sort; + auto &sort_global = *hash_group.sort_source; + + auto result = sort.GetSortedRun(sort_global); + if (!result) { + D_ASSERT(hash_group.count == 0); + result = make_uniq(client, sort, false); + } + + return result; } } // namespace duckdb diff --git a/src/duckdb/src/common/sort/merge_sorter.cpp b/src/duckdb/src/common/sort/merge_sorter.cpp deleted file mode 100644 index c670fd574..000000000 --- a/src/duckdb/src/common/sort/merge_sorter.cpp +++ /dev/null @@ -1,667 +0,0 @@ -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/sort.hpp" - -namespace duckdb { - -MergeSorter::MergeSorter(GlobalSortState &state, BufferManager &buffer_manager) - : state(state), buffer_manager(buffer_manager), sort_layout(state.sort_layout) { -} - -void MergeSorter::PerformInMergeRound() { - while (true) { - // Check for interrupts after merging a partition - if (state.context.interrupted) { - throw InterruptException(); - } - { - lock_guard pair_guard(state.lock); - if (state.pair_idx == state.num_pairs) { - break; - } - GetNextPartition(); - } - MergePartition(); - } -} - -void MergeSorter::MergePartition() { - auto &left_block = 
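// Sketch (not from this patch): PerformInMergeRound above is a work-claiming
// loop; a thread holds the global lock only long enough to claim the next
// slice of a merge pair, then merges that slice outside the lock. Simplified
// here to claiming whole pairs, with hypothetical names:
#include <cstddef>
#include <functional>
#include <mutex>

struct MergeQueueExample {
	std::mutex lock;
	size_t next_pair = 0;
	size_t num_pairs = 0;
};

inline void MergeWorkerExample(MergeQueueExample &queue, const std::function<void(size_t)> &merge_pair) {
	while (true) {
		size_t claimed;
		{
			std::lock_guard<std::mutex> guard(queue.lock); // claim under the lock...
			if (queue.next_pair == queue.num_pairs) {
				break;
			}
			claimed = queue.next_pair++;
		}
		merge_pair(claimed); // ...do the expensive merge outside of it
	}
}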
*left->sb; - auto &right_block = *right->sb; -#ifdef DEBUG - D_ASSERT(left_block.radix_sorting_data.size() == left_block.payload_data->data_blocks.size()); - D_ASSERT(right_block.radix_sorting_data.size() == right_block.payload_data->data_blocks.size()); - if (!state.payload_layout.AllConstant() && state.external) { - D_ASSERT(left_block.payload_data->data_blocks.size() == left_block.payload_data->heap_blocks.size()); - D_ASSERT(right_block.payload_data->data_blocks.size() == right_block.payload_data->heap_blocks.size()); - } - if (!sort_layout.all_constant) { - D_ASSERT(left_block.radix_sorting_data.size() == left_block.blob_sorting_data->data_blocks.size()); - D_ASSERT(right_block.radix_sorting_data.size() == right_block.blob_sorting_data->data_blocks.size()); - if (state.external) { - D_ASSERT(left_block.blob_sorting_data->data_blocks.size() == - left_block.blob_sorting_data->heap_blocks.size()); - D_ASSERT(right_block.blob_sorting_data->data_blocks.size() == - right_block.blob_sorting_data->heap_blocks.size()); - } - } -#endif - // Set up the write block - // Each merge task produces a SortedBlock with exactly state.block_capacity rows or less - result->InitializeWrite(); - // Initialize arrays to store merge data - bool left_smaller[STANDARD_VECTOR_SIZE]; - idx_t next_entry_sizes[STANDARD_VECTOR_SIZE]; - // Merge loop -#ifdef DEBUG - auto l_count = left->Remaining(); - auto r_count = right->Remaining(); -#endif - while (true) { - auto l_remaining = left->Remaining(); - auto r_remaining = right->Remaining(); - if (l_remaining + r_remaining == 0) { - // Done - break; - } - const idx_t next = MinValue(l_remaining + r_remaining, (idx_t)STANDARD_VECTOR_SIZE); - if (l_remaining != 0 && r_remaining != 0) { - // Compute the merge (not needed if one side is exhausted) - ComputeMerge(next, left_smaller); - } - // Actually merge the data (radix, blob, and payload) - MergeRadix(next, left_smaller); - if (!sort_layout.all_constant) { - MergeData(*result->blob_sorting_data, *left_block.blob_sorting_data, *right_block.blob_sorting_data, next, - left_smaller, next_entry_sizes, true); - D_ASSERT(result->radix_sorting_data.size() == result->blob_sorting_data->data_blocks.size()); - } - MergeData(*result->payload_data, *left_block.payload_data, *right_block.payload_data, next, left_smaller, - next_entry_sizes, false); - D_ASSERT(result->radix_sorting_data.size() == result->payload_data->data_blocks.size()); - } -#ifdef DEBUG - D_ASSERT(result->Count() == l_count + r_count); -#endif -} - -void MergeSorter::GetNextPartition() { - // Create result block - state.sorted_blocks_temp[state.pair_idx].push_back(make_uniq(buffer_manager, state)); - result = state.sorted_blocks_temp[state.pair_idx].back().get(); - // Determine which blocks must be merged - auto &left_block = *state.sorted_blocks[state.pair_idx * 2]; - auto &right_block = *state.sorted_blocks[state.pair_idx * 2 + 1]; - const idx_t l_count = left_block.Count(); - const idx_t r_count = right_block.Count(); - // Initialize left and right reader - left = make_uniq(buffer_manager, state); - right = make_uniq(buffer_manager, state); - // Compute the work that this thread must do using Merge Path - idx_t l_end; - idx_t r_end; - if (state.l_start + state.r_start + state.block_capacity < l_count + r_count) { - left->sb = state.sorted_blocks[state.pair_idx * 2].get(); - right->sb = state.sorted_blocks[state.pair_idx * 2 + 1].get(); - const idx_t intersection = state.l_start + state.r_start + state.block_capacity; - GetIntersection(intersection, l_end, 
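// Sketch (not from this patch): GetNextPartition above splits the merge of two
// sorted blocks along a diagonal using Merge Path, so each thread gets an
// independent, equally sized slice; GetIntersection then finds where that
// diagonal crosses the merge. A self-contained version of the diagonal search
// for two sorted int arrays (the deleted code does the same with its row
// comparator):
#include <algorithm>
#include <cstddef>
#include <vector>

// Returns how many elements of `left` belong to the first `diagonal` outputs of
// merging `left` and `right` (ties go to the left side).
inline size_t MergePathExample(const std::vector<int> &left, const std::vector<int> &right, size_t diagonal) {
	size_t lo = diagonal > right.size() ? diagonal - right.size() : 0;
	size_t hi = std::min(diagonal, left.size());
	while (lo < hi) {
		const size_t l_idx = (lo + hi) / 2;    // candidate number of left elements taken
		const size_t r_idx = diagonal - l_idx; // implied number of right elements taken
		if (left[l_idx] <= right[r_idx - 1]) {
			lo = l_idx + 1; // left[l_idx] is still within the first `diagonal` outputs
		} else {
			hi = l_idx;
		}
	}
	return lo; // the right side contributes diagonal - lo elements
}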
r_end); - D_ASSERT(l_end <= l_count); - D_ASSERT(r_end <= r_count); - D_ASSERT(intersection == l_end + r_end); - } else { - l_end = l_count; - r_end = r_count; - } - // Create slices of the data that this thread must merge - left->SetIndices(0, 0); - right->SetIndices(0, 0); - left_input = left_block.CreateSlice(state.l_start, l_end, left->entry_idx); - right_input = right_block.CreateSlice(state.r_start, r_end, right->entry_idx); - left->sb = left_input.get(); - right->sb = right_input.get(); - state.l_start = l_end; - state.r_start = r_end; - D_ASSERT(left->Remaining() + right->Remaining() == state.block_capacity || (l_end == l_count && r_end == r_count)); - // Update global state - if (state.l_start == l_count && state.r_start == r_count) { - // Delete references to previous pair - state.sorted_blocks[state.pair_idx * 2] = nullptr; - state.sorted_blocks[state.pair_idx * 2 + 1] = nullptr; - // Advance pair - state.pair_idx++; - state.l_start = 0; - state.r_start = 0; - } -} - -int MergeSorter::CompareUsingGlobalIndex(SBScanState &l, SBScanState &r, const idx_t l_idx, const idx_t r_idx) { - D_ASSERT(l_idx < l.sb->Count()); - D_ASSERT(r_idx < r.sb->Count()); - - // Easy comparison using the previous result (intersections must increase monotonically) - if (l_idx < state.l_start) { - return -1; - } - if (r_idx < state.r_start) { - return 1; - } - - l.sb->GlobalToLocalIndex(l_idx, l.block_idx, l.entry_idx); - r.sb->GlobalToLocalIndex(r_idx, r.block_idx, r.entry_idx); - - l.PinRadix(l.block_idx); - r.PinRadix(r.block_idx); - data_ptr_t l_ptr = l.radix_handle.Ptr() + l.entry_idx * sort_layout.entry_size; - data_ptr_t r_ptr = r.radix_handle.Ptr() + r.entry_idx * sort_layout.entry_size; - - int comp_res; - if (sort_layout.all_constant) { - comp_res = FastMemcmp(l_ptr, r_ptr, sort_layout.comparison_size); - } else { - l.PinData(*l.sb->blob_sorting_data); - r.PinData(*r.sb->blob_sorting_data); - comp_res = Comparators::CompareTuple(l, r, l_ptr, r_ptr, sort_layout, state.external); - } - return comp_res; -} - -void MergeSorter::GetIntersection(const idx_t diagonal, idx_t &l_idx, idx_t &r_idx) { - const idx_t l_count = left->sb->Count(); - const idx_t r_count = right->sb->Count(); - // Cover some edge cases - // Code coverage off because these edge cases cannot happen unless other code changes - // Edge cases have been tested extensively while developing Merge Path in a script - // LCOV_EXCL_START - if (diagonal >= l_count + r_count) { - l_idx = l_count; - r_idx = r_count; - return; - } else if (diagonal == 0) { - l_idx = 0; - r_idx = 0; - return; - } else if (l_count == 0) { - l_idx = 0; - r_idx = diagonal; - return; - } else if (r_count == 0) { - r_idx = 0; - l_idx = diagonal; - return; - } - // LCOV_EXCL_STOP - // Determine offsets for the binary search - const idx_t l_offset = MinValue(l_count, diagonal); - const idx_t r_offset = diagonal > l_count ? diagonal - l_count : 0; - D_ASSERT(l_offset + r_offset == diagonal); - const idx_t search_space = diagonal > MaxValue(l_count, r_count) ? 
l_count + r_count - diagonal - : MinValue(diagonal, MinValue(l_count, r_count)); - // Double binary search - idx_t li = 0; - idx_t ri = search_space - 1; - idx_t middle; - int comp_res; - while (li <= ri) { - middle = (li + ri) / 2; - l_idx = l_offset - middle; - r_idx = r_offset + middle; - if (l_idx == l_count || r_idx == 0) { - comp_res = CompareUsingGlobalIndex(*left, *right, l_idx - 1, r_idx); - if (comp_res > 0) { - l_idx--; - r_idx++; - } else { - return; - } - if (l_idx == 0 || r_idx == r_count) { - // This case is incredibly difficult to cover as it is dependent on parallelism randomness - // But it has been tested extensively during development in a script - // LCOV_EXCL_START - return; - // LCOV_EXCL_STOP - } else { - break; - } - } - comp_res = CompareUsingGlobalIndex(*left, *right, l_idx, r_idx); - if (comp_res > 0) { - li = middle + 1; - } else { - ri = middle - 1; - } - } - int l_r_min1 = CompareUsingGlobalIndex(*left, *right, l_idx, r_idx - 1); - int l_min1_r = CompareUsingGlobalIndex(*left, *right, l_idx - 1, r_idx); - if (l_r_min1 > 0 && l_min1_r < 0) { - return; - } else if (l_r_min1 > 0) { - l_idx--; - r_idx++; - } else if (l_min1_r < 0) { - l_idx++; - r_idx--; - } -} - -void MergeSorter::ComputeMerge(const idx_t &count, bool left_smaller[]) { - auto &l = *left; - auto &r = *right; - auto &l_sorted_block = *l.sb; - auto &r_sorted_block = *r.sb; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - // Data pointers for both sides - data_ptr_t l_radix_ptr; - data_ptr_t r_radix_ptr; - // Compute the merge of the next 'count' tuples - idx_t compared = 0; - while (compared < count) { - // Move to the next block (if needed) - if (l.block_idx < l_sorted_block.radix_sorting_data.size() && - l.entry_idx == l_sorted_block.radix_sorting_data[l.block_idx]->count) { - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_sorted_block.radix_sorting_data.size() && - r.entry_idx == r_sorted_block.radix_sorting_data[r.block_idx]->count) { - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_sorted_block.radix_sorting_data.size(); - const bool r_done = r.block_idx == r_sorted_block.radix_sorting_data.size(); - if (l_done || r_done) { - // One of the sides is exhausted, no need to compare - break; - } - // Pin the radix sorting data - left->PinRadix(l.block_idx); - l_radix_ptr = left->RadixPtr(); - right->PinRadix(r.block_idx); - r_radix_ptr = right->RadixPtr(); - - const idx_t l_count = l_sorted_block.radix_sorting_data[l.block_idx]->count; - const idx_t r_count = r_sorted_block.radix_sorting_data[r.block_idx]->count; - // Compute the merge - if (sort_layout.all_constant) { - // All sorting columns are constant size - for (; compared < count && l.entry_idx < l_count && r.entry_idx < r_count; compared++) { - left_smaller[compared] = FastMemcmp(l_radix_ptr, r_radix_ptr, sort_layout.comparison_size) < 0; - const bool &l_smaller = left_smaller[compared]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to increment entries and pointers - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - l_radix_ptr += l_smaller * sort_layout.entry_size; - r_radix_ptr += r_smaller * sort_layout.entry_size; - } - } else { - // Pin the blob data - left->PinData(*l_sorted_block.blob_sorting_data); - right->PinData(*r_sorted_block.blob_sorting_data); - // Merge with variable size sorting columns - for 
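// Sketch (not from this patch): ComputeMerge above keeps the hot loop
// branch-free by treating each comparison as 0 or 1 and using it to advance
// exactly one of the two cursors; the recorded flags later drive the actual
// copying. The same trick on plain arrays:
#include <cstddef>

inline size_t BranchlessMergeFlagsExample(const int *l, size_t l_count, const int *r, size_t r_count,
                                          bool *left_smaller, size_t max_out) {
	size_t li = 0, ri = 0, out = 0;
	while (out < max_out && li < l_count && ri < r_count) {
		const bool l_small = l[li] < r[ri]; // 0 or 1, no branch needed below
		left_smaller[out++] = l_small;
		li += l_small;  // advances only if the left value was taken
		ri += !l_small; // otherwise the right cursor moves
	}
	return out; // number of merge decisions recorded
}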
(; compared < count && l.entry_idx < l_count && r.entry_idx < r_count; compared++) { - left_smaller[compared] = - Comparators::CompareTuple(*left, *right, l_radix_ptr, r_radix_ptr, sort_layout, state.external) < 0; - const bool &l_smaller = left_smaller[compared]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to increment entries and pointers - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - l_radix_ptr += l_smaller * sort_layout.entry_size; - r_radix_ptr += r_smaller * sort_layout.entry_size; - } - } - } - // Reset block indices - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); -} - -void MergeSorter::MergeRadix(const idx_t &count, const bool left_smaller[]) { - auto &l = *left; - auto &r = *right; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - - auto &l_blocks = l.sb->radix_sorting_data; - auto &r_blocks = r.sb->radix_sorting_data; - RowDataBlock *l_block = nullptr; - RowDataBlock *r_block = nullptr; - - data_ptr_t l_ptr; - data_ptr_t r_ptr; - - RowDataBlock *result_block = result->radix_sorting_data.back().get(); - auto result_handle = buffer_manager.Pin(result_block->block); - data_ptr_t result_ptr = result_handle.Ptr() + result_block->count * sort_layout.entry_size; - - idx_t copied = 0; - while (copied < count) { - // Move to the next block (if needed) - if (l.block_idx < l_blocks.size() && l.entry_idx == l_blocks[l.block_idx]->count) { - // Delete reference to previous block - l_blocks[l.block_idx]->block = nullptr; - // Advance block - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_blocks.size() && r.entry_idx == r_blocks[r.block_idx]->count) { - // Delete reference to previous block - r_blocks[r.block_idx]->block = nullptr; - // Advance block - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_blocks.size(); - const bool r_done = r.block_idx == r_blocks.size(); - // Pin the radix sortable blocks - idx_t l_count; - if (!l_done) { - l_block = l_blocks[l.block_idx].get(); - left->PinRadix(l.block_idx); - l_ptr = l.RadixPtr(); - l_count = l_block->count; - } else { - l_count = 0; - } - idx_t r_count; - if (!r_done) { - r_block = r_blocks[r.block_idx].get(); - r.PinRadix(r.block_idx); - r_ptr = r.RadixPtr(); - r_count = r_block->count; - } else { - r_count = 0; - } - // Copy using computed merge - if (!l_done && !r_done) { - // Both sides have data - merge - MergeRows(l_ptr, l.entry_idx, l_count, r_ptr, r.entry_idx, r_count, *result_block, result_ptr, - sort_layout.entry_size, left_smaller, copied, count); - } else if (r_done) { - // Right side is exhausted - FlushRows(l_ptr, l.entry_idx, l_count, *result_block, result_ptr, sort_layout.entry_size, copied, count); - } else { - // Left side is exhausted - FlushRows(r_ptr, r.entry_idx, r_count, *result_block, result_ptr, sort_layout.entry_size, copied, count); - } - } - // Reset block indices - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); -} - -void MergeSorter::MergeData(SortedData &result_data, SortedData &l_data, SortedData &r_data, const idx_t &count, - const bool left_smaller[], idx_t next_entry_sizes[], bool reset_indices) { - auto &l = *left; - auto &r = *right; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t 
l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - - const auto &layout = result_data.layout; - const idx_t row_width = layout.GetRowWidth(); - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - - // Left and right row data to merge - data_ptr_t l_ptr; - data_ptr_t r_ptr; - // Accompanying left and right heap data (if needed) - data_ptr_t l_heap_ptr; - data_ptr_t r_heap_ptr; - - // Result rows to write to - RowDataBlock *result_data_block = result_data.data_blocks.back().get(); - auto result_data_handle = buffer_manager.Pin(result_data_block->block); - data_ptr_t result_data_ptr = result_data_handle.Ptr() + result_data_block->count * row_width; - // Result heap to write to (if needed) - RowDataBlock *result_heap_block = nullptr; - BufferHandle result_heap_handle; - data_ptr_t result_heap_ptr; - if (!layout.AllConstant() && state.external) { - result_heap_block = result_data.heap_blocks.back().get(); - result_heap_handle = buffer_manager.Pin(result_heap_block->block); - result_heap_ptr = result_heap_handle.Ptr() + result_heap_block->byte_offset; - } - - idx_t copied = 0; - while (copied < count) { - // Move to new data blocks (if needed) - if (l.block_idx < l_data.data_blocks.size() && l.entry_idx == l_data.data_blocks[l.block_idx]->count) { - // Delete reference to previous block - l_data.data_blocks[l.block_idx]->block = nullptr; - if (!layout.AllConstant() && state.external) { - l_data.heap_blocks[l.block_idx]->block = nullptr; - } - // Advance block - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_data.data_blocks.size() && r.entry_idx == r_data.data_blocks[r.block_idx]->count) { - // Delete reference to previous block - r_data.data_blocks[r.block_idx]->block = nullptr; - if (!layout.AllConstant() && state.external) { - r_data.heap_blocks[r.block_idx]->block = nullptr; - } - // Advance block - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_data.data_blocks.size(); - const bool r_done = r.block_idx == r_data.data_blocks.size(); - // Pin the row data blocks - if (!l_done) { - l.PinData(l_data); - l_ptr = l.DataPtr(l_data); - } - if (!r_done) { - r.PinData(r_data); - r_ptr = r.DataPtr(r_data); - } - const idx_t &l_count = !l_done ? l_data.data_blocks[l.block_idx]->count : 0; - const idx_t &r_count = !r_done ? r_data.data_blocks[r.block_idx]->count : 0; - // Perform the merge - if (layout.AllConstant() || !state.external) { - // If all constant size, or if we are doing an in-memory sort, we do not need to touch the heap - if (!l_done && !r_done) { - // Both sides have data - merge - MergeRows(l_ptr, l.entry_idx, l_count, r_ptr, r.entry_idx, r_count, *result_data_block, result_data_ptr, - row_width, left_smaller, copied, count); - } else if (r_done) { - // Right side is exhausted - FlushRows(l_ptr, l.entry_idx, l_count, *result_data_block, result_data_ptr, row_width, copied, count); - } else { - // Left side is exhausted - FlushRows(r_ptr, r.entry_idx, r_count, *result_data_block, result_data_ptr, row_width, copied, count); - } - } else { - // External sorting with variable size data. 
Pin the heap blocks too - if (!l_done) { - l_heap_ptr = l.BaseHeapPtr(l_data) + Load(l_ptr + heap_pointer_offset); - D_ASSERT(l_heap_ptr - l.BaseHeapPtr(l_data) >= 0); - D_ASSERT((idx_t)(l_heap_ptr - l.BaseHeapPtr(l_data)) < l_data.heap_blocks[l.block_idx]->byte_offset); - } - if (!r_done) { - r_heap_ptr = r.BaseHeapPtr(r_data) + Load(r_ptr + heap_pointer_offset); - D_ASSERT(r_heap_ptr - r.BaseHeapPtr(r_data) >= 0); - D_ASSERT((idx_t)(r_heap_ptr - r.BaseHeapPtr(r_data)) < r_data.heap_blocks[r.block_idx]->byte_offset); - } - // Both the row and heap data need to be dealt with - if (!l_done && !r_done) { - // Both sides have data - merge - idx_t l_idx_copy = l.entry_idx; - idx_t r_idx_copy = r.entry_idx; - data_ptr_t result_data_ptr_copy = result_data_ptr; - idx_t copied_copy = copied; - // Merge row data - MergeRows(l_ptr, l_idx_copy, l_count, r_ptr, r_idx_copy, r_count, *result_data_block, - result_data_ptr_copy, row_width, left_smaller, copied_copy, count); - const idx_t merged = copied_copy - copied; - // Compute the entry sizes and number of heap bytes that will be copied - idx_t copy_bytes = 0; - data_ptr_t l_heap_ptr_copy = l_heap_ptr; - data_ptr_t r_heap_ptr_copy = r_heap_ptr; - for (idx_t i = 0; i < merged; i++) { - // Store base heap offset in the row data - Store(result_heap_block->byte_offset + copy_bytes, result_data_ptr + heap_pointer_offset); - result_data_ptr += row_width; - // Compute entry size and add to total - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - auto &entry_size = next_entry_sizes[copied + i]; - entry_size = - l_smaller * Load(l_heap_ptr_copy) + r_smaller * Load(r_heap_ptr_copy); - D_ASSERT(entry_size >= sizeof(uint32_t)); - D_ASSERT(NumericCast(l_heap_ptr_copy - l.BaseHeapPtr(l_data)) + l_smaller * entry_size <= - l_data.heap_blocks[l.block_idx]->byte_offset); - D_ASSERT(NumericCast(r_heap_ptr_copy - r.BaseHeapPtr(r_data)) + r_smaller * entry_size <= - r_data.heap_blocks[r.block_idx]->byte_offset); - l_heap_ptr_copy += l_smaller * entry_size; - r_heap_ptr_copy += r_smaller * entry_size; - copy_bytes += entry_size; - } - // Reallocate result heap block size (if needed) - if (result_heap_block->byte_offset + copy_bytes > result_heap_block->capacity) { - idx_t new_capacity = result_heap_block->byte_offset + copy_bytes; - buffer_manager.ReAllocate(result_heap_block->block, new_capacity); - result_heap_block->capacity = new_capacity; - result_heap_ptr = result_heap_handle.Ptr() + result_heap_block->byte_offset; - } - D_ASSERT(result_heap_block->byte_offset + copy_bytes <= result_heap_block->capacity); - // Now copy the heap data - for (idx_t i = 0; i < merged; i++) { - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - const auto &entry_size = next_entry_sizes[copied + i]; - memcpy(result_heap_ptr, - reinterpret_cast(l_smaller * CastPointerToValue(l_heap_ptr) + - r_smaller * CastPointerToValue(r_heap_ptr)), - entry_size); - D_ASSERT(Load(result_heap_ptr) == entry_size); - result_heap_ptr += entry_size; - l_heap_ptr += l_smaller * entry_size; - r_heap_ptr += r_smaller * entry_size; - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - } - // Update result indices and pointers - result_heap_block->count += merged; - result_heap_block->byte_offset += copy_bytes; - copied += merged; - } else if (r_done) { - // Right side is exhausted - flush left - FlushBlobs(layout, l_count, l_ptr, l.entry_idx, l_heap_ptr, *result_data_block, result_data_ptr, - *result_heap_block, result_heap_handle, 
result_heap_ptr, copied, count); - } else { - // Left side is exhausted - flush right - FlushBlobs(layout, r_count, r_ptr, r.entry_idx, r_heap_ptr, *result_data_block, result_data_ptr, - *result_heap_block, result_heap_handle, result_heap_ptr, copied, count); - } - D_ASSERT(result_data_block->count == result_heap_block->count); - } - } - if (reset_indices) { - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); - } -} - -void MergeSorter::MergeRows(data_ptr_t &l_ptr, idx_t &l_entry_idx, const idx_t &l_count, data_ptr_t &r_ptr, - idx_t &r_entry_idx, const idx_t &r_count, RowDataBlock &target_block, - data_ptr_t &target_ptr, const idx_t &entry_size, const bool left_smaller[], idx_t &copied, - const idx_t &count) { - const idx_t next = MinValue(count - copied, target_block.capacity - target_block.count); - idx_t i; - for (i = 0; i < next && l_entry_idx < l_count && r_entry_idx < r_count; i++) { - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to copy an entry from either side - FastMemcpy( - target_ptr, - reinterpret_cast(l_smaller * CastPointerToValue(l_ptr) + r_smaller * CastPointerToValue(r_ptr)), - entry_size); - target_ptr += entry_size; - // Use the comparison bool to increment entries and pointers - l_entry_idx += l_smaller; - r_entry_idx += r_smaller; - l_ptr += l_smaller * entry_size; - r_ptr += r_smaller * entry_size; - } - // Update counts - target_block.count += i; - copied += i; -} - -void MergeSorter::FlushRows(data_ptr_t &source_ptr, idx_t &source_entry_idx, const idx_t &source_count, - RowDataBlock &target_block, data_ptr_t &target_ptr, const idx_t &entry_size, idx_t &copied, - const idx_t &count) { - // Compute how many entries we can fit - idx_t next = MinValue(count - copied, target_block.capacity - target_block.count); - next = MinValue(next, source_count - source_entry_idx); - // Copy them all in a single memcpy - const idx_t copy_bytes = next * entry_size; - memcpy(target_ptr, source_ptr, copy_bytes); - target_ptr += copy_bytes; - source_ptr += copy_bytes; - // Update counts - source_entry_idx += next; - target_block.count += next; - copied += next; -} - -void MergeSorter::FlushBlobs(const RowLayout &layout, const idx_t &source_count, data_ptr_t &source_data_ptr, - idx_t &source_entry_idx, data_ptr_t &source_heap_ptr, RowDataBlock &target_data_block, - data_ptr_t &target_data_ptr, RowDataBlock &target_heap_block, - BufferHandle &target_heap_handle, data_ptr_t &target_heap_ptr, idx_t &copied, - const idx_t &count) { - const idx_t row_width = layout.GetRowWidth(); - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - idx_t source_entry_idx_copy = source_entry_idx; - data_ptr_t target_data_ptr_copy = target_data_ptr; - idx_t copied_copy = copied; - // Flush row data - FlushRows(source_data_ptr, source_entry_idx_copy, source_count, target_data_block, target_data_ptr_copy, row_width, - copied_copy, count); - const idx_t flushed = copied_copy - copied; - // Compute the entry sizes and number of heap bytes that will be copied - idx_t copy_bytes = 0; - data_ptr_t source_heap_ptr_copy = source_heap_ptr; - for (idx_t i = 0; i < flushed; i++) { - // Store base heap offset in the row data - Store(target_heap_block.byte_offset + copy_bytes, target_data_ptr + heap_pointer_offset); - target_data_ptr += row_width; - // Compute entry size and add to total - auto entry_size = Load(source_heap_ptr_copy); - D_ASSERT(entry_size >= 
sizeof(uint32_t)); - source_heap_ptr_copy += entry_size; - copy_bytes += entry_size; - } - // Reallocate result heap block size (if needed) - if (target_heap_block.byte_offset + copy_bytes > target_heap_block.capacity) { - idx_t new_capacity = target_heap_block.byte_offset + copy_bytes; - buffer_manager.ReAllocate(target_heap_block.block, new_capacity); - target_heap_block.capacity = new_capacity; - target_heap_ptr = target_heap_handle.Ptr() + target_heap_block.byte_offset; - } - D_ASSERT(target_heap_block.byte_offset + copy_bytes <= target_heap_block.capacity); - // Copy the heap data in one go - memcpy(target_heap_ptr, source_heap_ptr, copy_bytes); - target_heap_ptr += copy_bytes; - source_heap_ptr += copy_bytes; - source_entry_idx += flushed; - copied += flushed; - // Update result indices and pointers - target_heap_block.count += flushed; - target_heap_block.byte_offset += copy_bytes; - D_ASSERT(target_heap_block.byte_offset <= target_heap_block.capacity); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/partition_state.cpp b/src/duckdb/src/common/sort/partition_state.cpp deleted file mode 100644 index 2a0a65895..000000000 --- a/src/duckdb/src/common/sort/partition_state.cpp +++ /dev/null @@ -1,671 +0,0 @@ -#include "duckdb/common/sort/partition_state.hpp" - -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/parallel/executor_task.hpp" - -namespace duckdb { - -PartitionGlobalHashGroup::PartitionGlobalHashGroup(ClientContext &context, const Orders &partitions, - const Orders &orders, const Types &payload_types, bool external) - : count(0) { - - RowLayout payload_layout; - payload_layout.Initialize(payload_types); - global_sort = make_uniq(context, orders, payload_layout); - global_sort->external = external; - - // Set up a comparator for the partition subset - partition_layout = global_sort->sort_layout.GetPrefixComparisonLayout(partitions.size()); -} - -void PartitionGlobalHashGroup::ComputeMasks(ValidityMask &partition_mask, OrderMasks &order_masks) { - D_ASSERT(count > 0); - - SBIterator prev(*global_sort, ExpressionType::COMPARE_LESSTHAN); - SBIterator curr(*global_sort, ExpressionType::COMPARE_LESSTHAN); - - partition_mask.SetValidUnsafe(0); - unordered_map prefixes; - for (auto &order_mask : order_masks) { - order_mask.second.SetValidUnsafe(0); - D_ASSERT(order_mask.first >= partition_layout.column_count); - prefixes[order_mask.first] = global_sort->sort_layout.GetPrefixComparisonLayout(order_mask.first); - } - - for (++curr; curr.GetIndex() < count; ++curr) { - // Compare the partition subset first because if that differs, then so does the full ordering - const auto part_cmp = ComparePartitions(prev, curr); - - if (part_cmp) { - partition_mask.SetValidUnsafe(curr.GetIndex()); - for (auto &order_mask : order_masks) { - order_mask.second.SetValidUnsafe(curr.GetIndex()); - } - } else { - for (auto &order_mask : order_masks) { - if (prev.Compare(curr, prefixes[order_mask.first])) { - order_mask.second.SetValidUnsafe(curr.GetIndex()); - } - } - } - ++prev; - } -} - -void PartitionGlobalSinkState::GenerateOrderings(Orders &partitions, Orders &orders, - const vector> &partition_bys, - const Orders &order_bys, - const vector> &partition_stats) { - - // we sort by both 1) partition by expression list and 2) order by expressions - const auto partition_cols = partition_bys.size(); - for (idx_t prt_idx = 0; prt_idx < partition_cols; prt_idx++) { - auto &pexpr = partition_bys[prt_idx]; - - if 
(partition_stats.empty() || !partition_stats[prt_idx]) { - orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), nullptr); - } else { - orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), - partition_stats[prt_idx]->ToUnique()); - } - partitions.emplace_back(orders.back().Copy()); - } - - for (const auto &order : order_bys) { - orders.emplace_back(order.Copy()); - } -} - -PartitionGlobalSinkState::PartitionGlobalSinkState(ClientContext &context, - const vector> &partition_bys, - const vector &order_bys, - const Types &payload_types, - const vector> &partition_stats, - idx_t estimated_cardinality) - : context(context), buffer_manager(BufferManager::GetBufferManager(context)), allocator(Allocator::Get(context)), - fixed_bits(0), payload_types(payload_types), memory_per_thread(0), max_bits(1), count(0) { - - GenerateOrderings(partitions, orders, partition_bys, order_bys, partition_stats); - - memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context); - external = ClientConfig::GetConfig(context).force_external; - - const auto thread_pages = PreviousPowerOfTwo(memory_per_thread / (4 * buffer_manager.GetBlockAllocSize())); - while (max_bits < 10 && (thread_pages >> max_bits) > 1) { - ++max_bits; - } - - grouping_types_ptr = make_shared_ptr(); - if (!orders.empty()) { - if (partitions.empty()) { - // Sort early into a dedicated hash group if we only sort. - grouping_types_ptr->Initialize(payload_types, TupleDataValidityType::CAN_HAVE_NULL_VALUES); - auto new_group = make_uniq(context, partitions, orders, payload_types, external); - hash_groups.emplace_back(std::move(new_group)); - } else { - auto types = payload_types; - types.push_back(LogicalType::HASH); - grouping_types_ptr->Initialize(types, TupleDataValidityType::CAN_HAVE_NULL_VALUES); - ResizeGroupingData(estimated_cardinality); - } - } -} - -bool PartitionGlobalSinkState::HasMergeTasks() const { - if (grouping_data) { - auto &groups = grouping_data->GetPartitions(); - return !groups.empty(); - } else if (!hash_groups.empty()) { - D_ASSERT(hash_groups.size() == 1); - return hash_groups[0]->count > 0; - } else { - return false; - } -} - -void PartitionGlobalSinkState::SyncPartitioning(const PartitionGlobalSinkState &other) { - fixed_bits = other.grouping_data ? other.grouping_data->GetRadixBits() : 0; - - const auto old_bits = grouping_data ? grouping_data->GetRadixBits() : 0; - if (fixed_bits != old_bits) { - const auto hash_col_idx = payload_types.size(); - grouping_data = - make_uniq(buffer_manager, grouping_types_ptr, fixed_bits, hash_col_idx); - } -} - -unique_ptr PartitionGlobalSinkState::CreatePartition(idx_t new_bits) const { - const auto hash_col_idx = payload_types.size(); - return make_uniq(buffer_manager, grouping_types_ptr, new_bits, hash_col_idx); -} - -void PartitionGlobalSinkState::ResizeGroupingData(idx_t cardinality) { - // Have we started to combine? Then just live with it. - if (fixed_bits || (grouping_data && !grouping_data->GetPartitions().empty())) { - return; - } - // Is the average partition size too large? - const idx_t partition_size = DEFAULT_ROW_GROUP_SIZE; - const auto bits = grouping_data ? grouping_data->GetRadixBits() : 0; - auto new_bits = bits ? 
bits : 4; - while (new_bits < max_bits && (cardinality / RadixPartitioning::NumberOfPartitions(new_bits)) > partition_size) { - ++new_bits; - } - - // Repartition the grouping data - if (new_bits != bits) { - grouping_data = CreatePartition(new_bits); - } -} - -void PartitionGlobalSinkState::SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { - // We are done if the local_partition is right sized. - auto &local_radix = local_partition->Cast(); - const auto new_bits = grouping_data->GetRadixBits(); - if (local_radix.GetRadixBits() == new_bits) { - return; - } - - // If the local partition is now too small, flush it and reallocate - auto new_partition = CreatePartition(new_bits); - local_partition->FlushAppendState(*local_append); - local_partition->Repartition(context, *new_partition); - - local_partition = std::move(new_partition); - local_append = make_uniq(); - local_partition->InitializeAppendState(*local_append); -} - -void PartitionGlobalSinkState::UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { - // Make sure grouping_data doesn't change under us. - lock_guard guard(lock); - - if (!local_partition) { - local_partition = CreatePartition(grouping_data->GetRadixBits()); - local_append = make_uniq(); - local_partition->InitializeAppendState(*local_append); - return; - } - - // Grow the groups if they are too big - ResizeGroupingData(count); - - // Sync local partition to have the same bit count - SyncLocalPartition(local_partition, local_append); -} - -void PartitionGlobalSinkState::CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { - if (!local_partition) { - return; - } - local_partition->FlushAppendState(*local_append); - - // Make sure grouping_data doesn't change under us. - // Combine has an internal mutex, so this is single-threaded anyway. - lock_guard guard(lock); - SyncLocalPartition(local_partition, local_append); - grouping_data->Combine(*local_partition); -} - -PartitionLocalMergeState::PartitionLocalMergeState(PartitionGlobalSinkState &gstate) - : merge_state(nullptr), stage(PartitionSortStage::INIT), finished(true), executor(gstate.context) { - - // Set up the sort expression computation. - vector sort_types; - for (auto &order : gstate.orders) { - auto &oexpr = order.expression; - sort_types.emplace_back(oexpr->return_type); - executor.AddExpression(*oexpr); - } - sort_chunk.Initialize(gstate.allocator, sort_types); - payload_chunk.Initialize(gstate.allocator, gstate.payload_types); -} - -void PartitionLocalMergeState::Scan() { - if (!merge_state->group_data) { - // OVER(ORDER BY...) - // Already sorted - return; - } - - auto &group_data = *merge_state->group_data; - auto &hash_group = *merge_state->hash_group; - auto &chunk_state = merge_state->chunk_state; - // Copy the data from the group into the sort code. 
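The deleted ResizeGroupingData above grows the number of radix partitions with the estimated cardinality: bits are added until the average partition shrinks to roughly one row group, capped by a memory-derived maximum. Below is a minimal standalone sketch of that heuristic, outside the patch itself; the starting bit count, the row-group constant, and the helper names are assumptions loosely lifted from the deleted code, not the engine's API.

```cpp
// Sketch only: models the radix-bit growth heuristic of the deleted
// ResizeGroupingData. Constants and names are assumptions, not DuckDB API.
#include <cstdint>
#include <iostream>

static uint64_t NumberOfPartitions(uint64_t bits) {
	return uint64_t(1) << bits; // 2^bits radix partitions
}

// Grow the bit count until the average partition holds at most `partition_size`
// rows, but never beyond `max_bits` (bounded by per-thread memory in the original).
static uint64_t ChooseRadixBits(uint64_t cardinality, uint64_t max_bits,
                                uint64_t partition_size = 122880 /* assumed row-group size */) {
	uint64_t new_bits = 4; // assumed starting point before any repartitioning
	while (new_bits < max_bits && cardinality / NumberOfPartitions(new_bits) > partition_size) {
		++new_bits;
	}
	return new_bits;
}

int main() {
	std::cout << ChooseRadixBits(10000000, 10) << "\n"; // prints 7 -> 128 partitions
}
```

With ten million estimated rows and a cap of 10 bits, the sketch settles on 7 bits, i.e. 128 partitions of roughly 78k rows each, which matches the intent of keeping partitions near one row group.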
- auto &global_sort = *hash_group.global_sort; - LocalSortState local_sort; - local_sort.Initialize(global_sort, global_sort.buffer_manager); - - TupleDataScanState local_scan; - group_data.InitializeScan(local_scan, merge_state->column_ids); - while (group_data.Scan(chunk_state, local_scan, payload_chunk)) { - sort_chunk.Reset(); - executor.Execute(payload_chunk, sort_chunk); - - local_sort.SinkChunk(sort_chunk, payload_chunk); - if (local_sort.SizeInBytes() > merge_state->memory_per_thread) { - local_sort.Sort(global_sort, true); - } - hash_group.count += payload_chunk.size(); - } - - global_sort.AddLocalState(local_sort); -} - -// Per-thread sink state -PartitionLocalSinkState::PartitionLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p) - : gstate(gstate_p), allocator(Allocator::Get(context)), executor(context) { - - vector group_types; - for (idx_t prt_idx = 0; prt_idx < gstate.partitions.size(); prt_idx++) { - auto &pexpr = *gstate.partitions[prt_idx].expression.get(); - group_types.push_back(pexpr.return_type); - executor.AddExpression(pexpr); - } - sort_cols = gstate.orders.size() + group_types.size(); - - if (sort_cols) { - auto payload_types = gstate.payload_types; - if (!group_types.empty()) { - // OVER(PARTITION BY...) - group_chunk.Initialize(allocator, group_types); - payload_types.emplace_back(LogicalType::HASH); - } else { - // OVER(ORDER BY...) - for (idx_t ord_idx = 0; ord_idx < gstate.orders.size(); ord_idx++) { - auto &pexpr = *gstate.orders[ord_idx].expression.get(); - group_types.push_back(pexpr.return_type); - executor.AddExpression(pexpr); - } - group_chunk.Initialize(allocator, group_types); - - // Single partition - auto &global_sort = *gstate.hash_groups[0]->global_sort; - local_sort = make_uniq(); - local_sort->Initialize(global_sort, global_sort.buffer_manager); - } - // OVER(...) - payload_chunk.Initialize(allocator, payload_types); - } else { - // OVER() - payload_layout.Initialize(gstate.payload_types); - } -} - -void PartitionLocalSinkState::Hash(DataChunk &input_chunk, Vector &hash_vector) { - const auto count = input_chunk.size(); - D_ASSERT(group_chunk.ColumnCount() > 0); - - // OVER(PARTITION BY...) 
(hash grouping) - group_chunk.Reset(); - executor.Execute(input_chunk, group_chunk); - VectorOperations::Hash(group_chunk.data[0], hash_vector, count); - for (idx_t prt_idx = 1; prt_idx < group_chunk.ColumnCount(); ++prt_idx) { - VectorOperations::CombineHash(hash_vector, group_chunk.data[prt_idx], count); - } -} - -void PartitionLocalSinkState::Sink(DataChunk &input_chunk) { - gstate.count += input_chunk.size(); - - // OVER() - if (sort_cols == 0) { - // No sorts, so build paged row chunks - if (!rows) { - const auto entry_size = payload_layout.GetRowWidth(); - const auto block_size = gstate.buffer_manager.GetBlockSize(); - const auto capacity = MaxValue(STANDARD_VECTOR_SIZE, block_size / entry_size + 1); - rows = make_uniq(gstate.buffer_manager, capacity, entry_size); - strings = make_uniq(gstate.buffer_manager, block_size, 1U, true); - } - const auto row_count = input_chunk.size(); - const auto row_sel = FlatVector::IncrementalSelectionVector(); - Vector addresses(LogicalType::POINTER); - auto key_locations = FlatVector::GetData(addresses); - const auto prev_rows_blocks = rows->blocks.size(); - auto handles = rows->Build(row_count, key_locations, nullptr, row_sel); - auto input_data = input_chunk.ToUnifiedFormat(); - RowOperations::Scatter(input_chunk, input_data.get(), payload_layout, addresses, *strings, *row_sel, row_count); - // Mark that row blocks contain pointers (heap blocks are pinned) - if (!payload_layout.AllConstant()) { - D_ASSERT(strings->keep_pinned); - for (size_t i = prev_rows_blocks; i < rows->blocks.size(); ++i) { - rows->blocks[i]->block->SetSwizzling("PartitionLocalSinkState::Sink"); - } - } - return; - } - - if (local_sort) { - // OVER(ORDER BY...) - group_chunk.Reset(); - executor.Execute(input_chunk, group_chunk); - local_sort->SinkChunk(group_chunk, input_chunk); - - auto &hash_group = *gstate.hash_groups[0]; - hash_group.count += input_chunk.size(); - - if (local_sort->SizeInBytes() > gstate.memory_per_thread) { - auto &global_sort = *hash_group.global_sort; - local_sort->Sort(global_sort, true); - } - return; - } - - // OVER(...) - payload_chunk.Reset(); - auto &hash_vector = payload_chunk.data.back(); - Hash(input_chunk, hash_vector); - for (idx_t col_idx = 0; col_idx < input_chunk.ColumnCount(); ++col_idx) { - payload_chunk.data[col_idx].Reference(input_chunk.data[col_idx]); - } - payload_chunk.SetCardinality(input_chunk); - - gstate.UpdateLocalPartition(local_partition, local_append); - local_partition->Append(*local_append, payload_chunk); -} - -void PartitionLocalSinkState::Combine() { - // OVER() - if (sort_cols == 0) { - // Only one partition again, so need a global lock. - lock_guard glock(gstate.lock); - if (gstate.rows) { - if (rows) { - gstate.rows->Merge(*rows); - gstate.strings->Merge(*strings); - rows.reset(); - strings.reset(); - } - } else { - gstate.rows = std::move(rows); - gstate.strings = std::move(strings); - } - return; - } - - if (local_sort) { - // OVER(ORDER BY...) - auto &hash_group = *gstate.hash_groups[0]; - auto &global_sort = *hash_group.global_sort; - global_sort.AddLocalState(*local_sort); - local_sort.reset(); - return; - } - - // OVER(...) 
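PartitionLocalSinkState::Hash above derives one hash per row from the OVER(PARTITION BY ...) columns by hashing the first column and folding each further column into the running hash. The sketch below shows the same shape with plain std::hash and a boost-style combiner; the combiner, the column-major layout, and the function names are assumptions for illustration, not DuckDB's VectorOperations::Hash/CombineHash.

```cpp
// Sketch only: per-row hash over multiple partition columns, mirroring the
// "hash first column, combine the rest" pattern of the deleted Hash().
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

static uint64_t CombineHash(uint64_t seed, uint64_t value) {
	// boost-style mix step; stands in for the engine's own hash combiner
	return seed ^ (value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2));
}

// One hash per row over all partition columns (columns given column-major).
// Assumes at least one partition column.
static std::vector<uint64_t> HashPartitionColumns(const std::vector<std::vector<int64_t>> &columns) {
	std::vector<uint64_t> hashes(columns.front().size());
	for (size_t row = 0; row < hashes.size(); ++row) {
		hashes[row] = std::hash<int64_t> {}(columns[0][row]);
		for (size_t col = 1; col < columns.size(); ++col) {
			hashes[row] = CombineHash(hashes[row], std::hash<int64_t> {}(columns[col][row]));
		}
	}
	return hashes;
}
```

The per-row hash is what the deleted Sink() appends as an extra HASH column so the radix partitioner can route rows to hash groups.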
- gstate.CombineLocalPartition(local_partition, local_append); -} - -PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data_p, - hash_t hash_bin) - : sink(sink), group_data(std::move(group_data_p)), group_idx(sink.hash_groups.size()), - memory_per_thread(sink.memory_per_thread), - num_threads(NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads())), - stage(PartitionSortStage::INIT), total_tasks(0), tasks_assigned(0), tasks_completed(0) { - - auto new_group = make_uniq(sink.context, sink.partitions, sink.orders, sink.payload_types, - sink.external); - sink.hash_groups.emplace_back(std::move(new_group)); - - hash_group = sink.hash_groups[group_idx].get(); - global_sort = sink.hash_groups[group_idx]->global_sort.get(); - - sink.bin_groups[hash_bin] = group_idx; - - column_ids.reserve(sink.payload_types.size()); - for (column_t i = 0; i < sink.payload_types.size(); ++i) { - column_ids.emplace_back(i); - } - group_data->InitializeScan(chunk_state, column_ids); -} - -PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink) - : sink(sink), group_idx(0), memory_per_thread(sink.memory_per_thread), - num_threads(NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads())), - stage(PartitionSortStage::INIT), total_tasks(0), tasks_assigned(0), tasks_completed(0) { - - const hash_t hash_bin = 0; - hash_group = sink.hash_groups[group_idx].get(); - global_sort = sink.hash_groups[group_idx]->global_sort.get(); - - sink.bin_groups[hash_bin] = group_idx; -} - -void PartitionLocalMergeState::Prepare() { - merge_state->group_data.reset(); - - auto &global_sort = *merge_state->global_sort; - global_sort.PrepareMergePhase(); -} - -void PartitionLocalMergeState::Merge() { - auto &global_sort = *merge_state->global_sort; - MergeSorter merge_sorter(global_sort, global_sort.buffer_manager); - merge_sorter.PerformInMergeRound(); -} - -void PartitionLocalMergeState::Sorted() { - merge_state->sink.OnSortedPartition(merge_state->group_idx); -} - -void PartitionLocalMergeState::ExecuteTask() { - switch (stage) { - case PartitionSortStage::SCAN: - Scan(); - break; - case PartitionSortStage::PREPARE: - Prepare(); - break; - case PartitionSortStage::MERGE: - Merge(); - break; - case PartitionSortStage::SORTED: - Sorted(); - break; - default: - throw InternalException("Unexpected PartitionSortStage in ExecuteTask!"); - } - - merge_state->CompleteTask(); - finished = true; -} - -bool PartitionGlobalMergeState::AssignTask(PartitionLocalMergeState &local_state) { - lock_guard guard(lock); - - if (tasks_assigned >= total_tasks && !TryPrepareNextStage()) { - return false; - } - - local_state.merge_state = this; - local_state.stage = stage; - local_state.finished = false; - tasks_assigned++; - - return true; -} - -void PartitionGlobalMergeState::CompleteTask() { - lock_guard guard(lock); - - ++tasks_completed; -} - -bool PartitionGlobalMergeState::TryPrepareNextStage() { - if (tasks_completed < total_tasks) { - return false; - } - - tasks_assigned = tasks_completed = 0; - - switch (stage.load()) { - case PartitionSortStage::INIT: - // If the partitions are unordered, don't scan in parallel - // because it produces non-deterministic orderings. - // This can theoretically happen with ORDER BY, - // but that is something the query should be explicit about. - total_tasks = sink.orders.size() > sink.partitions.size() ? 
num_threads : 1; - stage = PartitionSortStage::SCAN; - return true; - - case PartitionSortStage::SCAN: - total_tasks = 1; - stage = PartitionSortStage::PREPARE; - return true; - - case PartitionSortStage::PREPARE: - if (!(global_sort->sorted_blocks.size() / 2)) { - break; - } - stage = PartitionSortStage::MERGE; - global_sort->InitializeMergeRound(); - total_tasks = num_threads; - return true; - - case PartitionSortStage::MERGE: - global_sort->CompleteMergeRound(true); - if (!(global_sort->sorted_blocks.size() / 2)) { - break; - } - global_sort->InitializeMergeRound(); - total_tasks = num_threads; - return true; - - case PartitionSortStage::SORTED: - stage = PartitionSortStage::FINISHED; - total_tasks = 0; - return false; - - case PartitionSortStage::FINISHED: - return false; - } - - stage = PartitionSortStage::SORTED; - total_tasks = 1; - - return true; -} - -PartitionGlobalMergeStates::PartitionGlobalMergeStates(PartitionGlobalSinkState &sink) { - // Schedule all the sorts for maximum thread utilisation - if (sink.grouping_data) { - auto &partitions = sink.grouping_data->GetPartitions(); - sink.bin_groups.resize(partitions.size(), partitions.size()); - for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) { - auto &group_data = partitions[hash_bin]; - // Prepare for merge sort phase - if (group_data->Count()) { - auto state = make_uniq(sink, std::move(group_data), hash_bin); - states.emplace_back(std::move(state)); - } - } - } else { - // OVER(ORDER BY...) - // Already sunk into the single global sort, so set up single merge with no data - sink.bin_groups.resize(1, 1); - auto state = make_uniq(sink); - states.emplace_back(std::move(state)); - } - - sink.OnBeginMerge(); -} - -class PartitionMergeTask : public ExecutorTask { -public: - PartitionMergeTask(shared_ptr event_p, ClientContext &context_p, PartitionGlobalMergeStates &hash_groups_p, - PartitionGlobalSinkState &gstate, const PhysicalOperator &op) - : ExecutorTask(context_p, std::move(event_p), op), local_state(gstate), hash_groups(hash_groups_p) { - } - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; - - string TaskType() const override { - return "PartitionMergeTask"; - } - -private: - struct ExecutorCallback : public PartitionGlobalMergeStates::Callback { - explicit ExecutorCallback(Executor &executor) : executor(executor) { - } - - bool HasError() const override { - return executor.HasError(); - } - - Executor &executor; - }; - - PartitionLocalMergeState local_state; - PartitionGlobalMergeStates &hash_groups; -}; - -bool PartitionGlobalMergeStates::ExecuteTask(PartitionLocalMergeState &local_state, Callback &callback) { - // Loop until all hash groups are done - size_t sorted = 0; - while (sorted < states.size()) { - // First check if there is an unfinished task for this thread - if (callback.HasError()) { - return false; - } - if (!local_state.TaskFinished()) { - local_state.ExecuteTask(); - continue; - } - - // Thread is done with its assigned task, try to fetch new work - for (auto group = sorted; group < states.size(); ++group) { - auto &global_state = states[group]; - if (global_state->IsFinished()) { - // This hash group is done - // Update the high water mark of densely completed groups - if (sorted == group) { - ++sorted; - } - continue; - } - - // Try to assign work for this hash group to this thread - if (global_state->AssignTask(local_state)) { - // We assigned a task to this thread! 
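TryPrepareNextStage above is a small state machine per hash group: INIT schedules the scan (single-threaded when only unordered partitions are present, to keep orderings deterministic), SCAN hands off to a single PREPARE task, MERGE rounds then repeat until at most one sorted block remains, and a final SORTED task finishes the group. Below is a condensed standalone model of those transitions, with the pairwise halving of blocks folded into the transition itself; the actual merge work and locking of the deleted code are elided, and the scaffolding names are assumptions.

```cpp
// Sketch only: the per-hash-group stage progression driven by the deleted
// TryPrepareNextStage. `remaining_blocks` stands in for sorted_blocks.size().
#include <cstddef>

enum class PartitionSortStage { INIT, SCAN, PREPARE, MERGE, SORTED, FINISHED };

struct MergeStages {
	PartitionSortStage stage = PartitionSortStage::INIT;
	size_t total_tasks = 0;
	const size_t num_threads;
	size_t remaining_blocks; // sorted blocks still awaiting pairwise merging

	MergeStages(size_t threads, size_t blocks) : num_threads(threads), remaining_blocks(blocks) {
	}

	// Advance to the next stage and set how many tasks it needs.
	// Returns false once the group is finished.
	bool TryPrepareNextStage(bool deterministic_scan_only) {
		switch (stage) {
		case PartitionSortStage::INIT:
			// Unordered partitions are scanned by one thread to stay deterministic.
			total_tasks = deterministic_scan_only ? 1 : num_threads;
			stage = PartitionSortStage::SCAN;
			return true;
		case PartitionSortStage::SCAN:
			total_tasks = 1; // prepare the merge phase once
			stage = PartitionSortStage::PREPARE;
			return true;
		case PartitionSortStage::PREPARE:
		case PartitionSortStage::MERGE:
			if (remaining_blocks < 2) {
				break; // nothing left to merge pairwise -> finalize below
			}
			remaining_blocks = (remaining_blocks + 1) / 2; // each cascaded round halves the blocks
			total_tasks = num_threads;
			stage = PartitionSortStage::MERGE;
			return true;
		case PartitionSortStage::SORTED:
			stage = PartitionSortStage::FINISHED;
			total_tasks = 0;
			return false;
		case PartitionSortStage::FINISHED:
			return false;
		}
		stage = PartitionSortStage::SORTED;
		total_tasks = 1; // single task marks the group as sorted
		return true;
	}
};
```

The surrounding ExecuteTask loop then work-steals across hash groups: a thread keeps asking each unfinished group for a task, and a group hands one out only while tasks_assigned is below total_tasks for the current stage.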
- // Break out of this loop to re-enter the top-level loop and execute the task - break; - } - - // We were able to prepare the next merge round, - // but we were not able to assign a task for it to this thread - // The tasks were assigned to other threads while this thread waited for the lock - // Go to the next iteration to see if another hash group has a task - } - } - - return true; -} - -TaskExecutionResult PartitionMergeTask::ExecuteTask(TaskExecutionMode mode) { - ExecutorCallback callback(executor); - - if (!hash_groups.ExecuteTask(local_state, callback)) { - return TaskExecutionResult::TASK_ERROR; - } - - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; -} - -void PartitionMergeEvent::Schedule() { - auto &context = pipeline->GetClientContext(); - - // Schedule tasks equal to the number of threads, which will each merge multiple partitions - auto &ts = TaskScheduler::GetScheduler(context); - auto num_threads = NumericCast(ts.NumberOfThreads()); - - vector> merge_tasks; - for (idx_t tnum = 0; tnum < num_threads; tnum++) { - merge_tasks.emplace_back(make_uniq(shared_from_this(), context, merge_states, gstate, op)); - } - SetTasks(std::move(merge_tasks)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/radix_sort.cpp b/src/duckdb/src/common/sort/radix_sort.cpp deleted file mode 100644 index b193cee61..000000000 --- a/src/duckdb/src/common/sort/radix_sort.cpp +++ /dev/null @@ -1,352 +0,0 @@ -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/duckdb_pdqsort.hpp" -#include "duckdb/common/sort/sort.hpp" - -namespace duckdb { - -//! Calls std::sort on strings that are tied by their prefix after the radix sort -static void SortTiedBlobs(BufferManager &buffer_manager, const data_ptr_t dataptr, const idx_t &start, const idx_t &end, - const idx_t &tie_col, bool *ties, const data_ptr_t blob_ptr, const SortLayout &sort_layout) { - const auto row_width = sort_layout.blob_layout.GetRowWidth(); - // Locate the first blob row in question - data_ptr_t row_ptr = dataptr + start * sort_layout.entry_size; - data_ptr_t blob_row_ptr = blob_ptr + Load(row_ptr + sort_layout.comparison_size) * row_width; - if (!Comparators::TieIsBreakable(tie_col, blob_row_ptr, sort_layout)) { - // Quick check to see if ties can be broken - return; - } - // Fill pointer array for sorting - auto ptr_block = make_unsafe_uniq_array_uninitialized(end - start); - auto entry_ptrs = (data_ptr_t *)ptr_block.get(); - for (idx_t i = start; i < end; i++) { - entry_ptrs[i - start] = row_ptr; - row_ptr += sort_layout.entry_size; - } - // Slow pointer-based sorting - const int order = sort_layout.order_types[tie_col] == OrderType::DESCENDING ? 
-1 : 1; - const idx_t &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - const auto &tie_col_offset = sort_layout.blob_layout.GetOffsets()[col_idx]; - auto logical_type = sort_layout.blob_layout.GetTypes()[col_idx]; - std::sort(entry_ptrs, entry_ptrs + end - start, - [&blob_ptr, &order, &sort_layout, &tie_col_offset, &row_width, &logical_type](const data_ptr_t l, - const data_ptr_t r) { - idx_t left_idx = Load(l + sort_layout.comparison_size); - idx_t right_idx = Load(r + sort_layout.comparison_size); - data_ptr_t left_ptr = blob_ptr + left_idx * row_width + tie_col_offset; - data_ptr_t right_ptr = blob_ptr + right_idx * row_width + tie_col_offset; - return order * Comparators::CompareVal(left_ptr, right_ptr, logical_type) < 0; - }); - // Re-order - auto temp_block = buffer_manager.GetBufferAllocator().Allocate((end - start) * sort_layout.entry_size); - data_ptr_t temp_ptr = temp_block.get(); - for (idx_t i = 0; i < end - start; i++) { - FastMemcpy(temp_ptr, entry_ptrs[i], sort_layout.entry_size); - temp_ptr += sort_layout.entry_size; - } - memcpy(dataptr + start * sort_layout.entry_size, temp_block.get(), (end - start) * sort_layout.entry_size); - // Determine if there are still ties (if this is not the last column) - if (tie_col < sort_layout.column_count - 1) { - data_ptr_t idx_ptr = dataptr + start * sort_layout.entry_size + sort_layout.comparison_size; - // Load current entry - data_ptr_t current_ptr = blob_ptr + Load(idx_ptr) * row_width + tie_col_offset; - for (idx_t i = 0; i < end - start - 1; i++) { - // Load next entry and compare - idx_ptr += sort_layout.entry_size; - data_ptr_t next_ptr = blob_ptr + Load(idx_ptr) * row_width + tie_col_offset; - ties[start + i] = Comparators::CompareVal(current_ptr, next_ptr, logical_type) == 0; - current_ptr = next_ptr; - } - } -} - -//! Identifies sequences of rows that are tied by the prefix of a blob column, and sorts them -static void SortTiedBlobs(BufferManager &buffer_manager, SortedBlock &sb, bool *ties, data_ptr_t dataptr, - const idx_t &count, const idx_t &tie_col, const SortLayout &sort_layout) { - D_ASSERT(!ties[count - 1]); - auto &blob_block = *sb.blob_sorting_data->data_blocks.back(); - auto blob_handle = buffer_manager.Pin(blob_block.block); - const data_ptr_t blob_ptr = blob_handle.Ptr(); - - for (idx_t i = 0; i < count; i++) { - if (!ties[i]) { - continue; - } - idx_t j; - for (j = i + 1; j < count; j++) { - if (!ties[j]) { - break; - } - } - SortTiedBlobs(buffer_manager, dataptr, i, j + 1, tie_col, ties, blob_ptr, sort_layout); - i = j; - } -} - -//! Returns whether there are any 'true' values in the ties[] array -static bool AnyTies(bool ties[], const idx_t &count) { - D_ASSERT(!ties[count - 1]); - bool any_ties = false; - for (idx_t i = 0; i < count - 1; i++) { - any_ties = any_ties || ties[i]; - } - return any_ties; -} - -//! Compares subsequent rows to check for ties -static void ComputeTies(data_ptr_t dataptr, const idx_t &count, const idx_t &col_offset, const idx_t &tie_size, - bool ties[], const SortLayout &sort_layout) { - D_ASSERT(!ties[count - 1]); - D_ASSERT(col_offset + tie_size <= sort_layout.comparison_size); - // Align dataptr - dataptr += col_offset; - for (idx_t i = 0; i < count - 1; i++) { - ties[i] = ties[i] && FastMemcmp(dataptr, dataptr + sort_layout.entry_size, tie_size) == 0; - dataptr += sort_layout.entry_size; - } -} - -//! 
Textbook LSD radix sort -void RadixSortLSD(BufferManager &buffer_manager, const data_ptr_t &dataptr, const idx_t &count, const idx_t &col_offset, - const idx_t &row_width, const idx_t &sorting_size) { - auto temp_block = buffer_manager.GetBufferAllocator().Allocate(count * row_width); - bool swap = false; - - idx_t counts[SortConstants::VALUES_PER_RADIX]; - for (idx_t r = 1; r <= sorting_size; r++) { - // Init counts to 0 - memset(counts, 0, sizeof(counts)); - // Const some values for convenience - const data_ptr_t source_ptr = swap ? temp_block.get() : dataptr; - const data_ptr_t target_ptr = swap ? dataptr : temp_block.get(); - const idx_t offset = col_offset + sorting_size - r; - // Collect counts - data_ptr_t offset_ptr = source_ptr + offset; - for (idx_t i = 0; i < count; i++) { - counts[*offset_ptr]++; - offset_ptr += row_width; - } - // Compute offsets from counts - idx_t max_count = counts[0]; - for (idx_t val = 1; val < SortConstants::VALUES_PER_RADIX; val++) { - max_count = MaxValue(max_count, counts[val]); - counts[val] = counts[val] + counts[val - 1]; - } - if (max_count == count) { - continue; - } - // Re-order the data in temporary array - data_ptr_t row_ptr = source_ptr + (count - 1) * row_width; - for (idx_t i = 0; i < count; i++) { - idx_t &radix_offset = --counts[*(row_ptr + offset)]; - FastMemcpy(target_ptr + radix_offset * row_width, row_ptr, row_width); - row_ptr -= row_width; - } - swap = !swap; - } - // Move data back to original buffer (if it was swapped) - if (swap) { - memcpy(dataptr, temp_block.get(), count * row_width); - } -} - -//! Insertion sort, used when count of values is low -inline void InsertionSort(const data_ptr_t orig_ptr, const data_ptr_t temp_ptr, const idx_t &count, - const idx_t &col_offset, const idx_t &row_width, const idx_t &total_comp_width, - const idx_t &offset, bool swap) { - const data_ptr_t source_ptr = swap ? temp_ptr : orig_ptr; - const data_ptr_t target_ptr = swap ? orig_ptr : temp_ptr; - if (count > 1) { - const idx_t total_offset = col_offset + offset; - auto temp_val = make_unsafe_uniq_array_uninitialized(row_width); - const data_ptr_t val = temp_val.get(); - const auto comp_width = total_comp_width - offset; - for (idx_t i = 1; i < count; i++) { - FastMemcpy(val, source_ptr + i * row_width, row_width); - idx_t j = i; - while (j > 0 && - FastMemcmp(source_ptr + (j - 1) * row_width + total_offset, val + total_offset, comp_width) > 0) { - FastMemcpy(source_ptr + j * row_width, source_ptr + (j - 1) * row_width, row_width); - j--; - } - FastMemcpy(source_ptr + j * row_width, val, row_width); - } - } - if (swap) { - memcpy(target_ptr, source_ptr, count * row_width); - } -} - -//! MSD radix sort that switches to insertion sort with low bucket sizes -void RadixSortMSD(const data_ptr_t orig_ptr, const data_ptr_t temp_ptr, const idx_t &count, const idx_t &col_offset, - const idx_t &row_width, const idx_t &comp_width, const idx_t &offset, idx_t locations[], bool swap) { - const data_ptr_t source_ptr = swap ? temp_ptr : orig_ptr; - const data_ptr_t target_ptr = swap ? 
orig_ptr : temp_ptr; - // Init counts to 0 - memset(locations, 0, SortConstants::MSD_RADIX_LOCATIONS * sizeof(idx_t)); - idx_t *counts = locations + 1; - // Collect counts - const idx_t total_offset = col_offset + offset; - data_ptr_t offset_ptr = source_ptr + total_offset; - for (idx_t i = 0; i < count; i++) { - counts[*offset_ptr]++; - offset_ptr += row_width; - } - // Compute locations from counts - idx_t max_count = 0; - for (idx_t radix = 0; radix < SortConstants::VALUES_PER_RADIX; radix++) { - max_count = MaxValue(max_count, counts[radix]); - counts[radix] += locations[radix]; - } - if (max_count != count) { - // Re-order the data in temporary array - data_ptr_t row_ptr = source_ptr; - for (idx_t i = 0; i < count; i++) { - const idx_t &radix_offset = locations[*(row_ptr + total_offset)]++; - FastMemcpy(target_ptr + radix_offset * row_width, row_ptr, row_width); - row_ptr += row_width; - } - swap = !swap; - } - // Check if done - if (offset == comp_width - 1) { - if (swap) { - memcpy(orig_ptr, temp_ptr, count * row_width); - } - return; - } - if (max_count == count) { - RadixSortMSD(orig_ptr, temp_ptr, count, col_offset, row_width, comp_width, offset + 1, - locations + SortConstants::MSD_RADIX_LOCATIONS, swap); - return; - } - // Recurse - idx_t radix_count = locations[0]; - for (idx_t radix = 0; radix < SortConstants::VALUES_PER_RADIX; radix++) { - const idx_t loc = (locations[radix] - radix_count) * row_width; - if (radix_count > SortConstants::INSERTION_SORT_THRESHOLD) { - RadixSortMSD(orig_ptr + loc, temp_ptr + loc, radix_count, col_offset, row_width, comp_width, offset + 1, - locations + SortConstants::MSD_RADIX_LOCATIONS, swap); - } else if (radix_count != 0) { - InsertionSort(orig_ptr + loc, temp_ptr + loc, radix_count, col_offset, row_width, comp_width, offset + 1, - swap); - } - radix_count = locations[radix + 1] - locations[radix]; - } -} - -//! Calls different sort functions, depending on the count and sorting sizes -void RadixSort(BufferManager &buffer_manager, const data_ptr_t &dataptr, const idx_t &count, const idx_t &col_offset, - const idx_t &sorting_size, const SortLayout &sort_layout, bool contains_string) { - - if (contains_string) { - auto begin = duckdb_pdqsort::PDQIterator(dataptr, sort_layout.entry_size); - auto end = begin + count; - duckdb_pdqsort::PDQConstants constants(sort_layout.entry_size, col_offset, sorting_size, *end); - return duckdb_pdqsort::pdqsort_branchless(begin, begin + count, constants); - } - - if (count <= SortConstants::INSERTION_SORT_THRESHOLD) { - return InsertionSort(dataptr, nullptr, count, col_offset, sort_layout.entry_size, sorting_size, 0, false); - } - - if (sorting_size <= SortConstants::MSD_RADIX_SORT_SIZE_THRESHOLD) { - return RadixSortLSD(buffer_manager, dataptr, count, col_offset, sort_layout.entry_size, sorting_size); - } - - const auto block_size = buffer_manager.GetBlockSize(); - auto temp_block = - buffer_manager.Allocate(MemoryTag::ORDER_BY, MaxValue(count * sort_layout.entry_size, block_size)); - auto pre_allocated_array = - make_unsafe_uniq_array_uninitialized(sorting_size * SortConstants::MSD_RADIX_LOCATIONS); - RadixSortMSD(dataptr, temp_block.Ptr(), count, col_offset, sort_layout.entry_size, sorting_size, 0, - pre_allocated_array.get(), false); -} - -//! 
Identifies sequences of rows that are tied, and calls radix sort on these -static void SubSortTiedTuples(BufferManager &buffer_manager, const data_ptr_t dataptr, const idx_t &count, - const idx_t &col_offset, const idx_t &sorting_size, bool ties[], - const SortLayout &sort_layout, bool contains_string) { - D_ASSERT(!ties[count - 1]); - for (idx_t i = 0; i < count; i++) { - if (!ties[i]) { - continue; - } - idx_t j; - for (j = i + 1; j < count; j++) { - if (!ties[j]) { - break; - } - } - RadixSort(buffer_manager, dataptr + i * sort_layout.entry_size, j - i + 1, col_offset, sorting_size, - sort_layout, contains_string); - i = j; - } -} - -void LocalSortState::SortInMemory() { - auto &sb = *sorted_blocks.back(); - auto &block = *sb.radix_sorting_data.back(); - const auto &count = block.count; - auto handle = buffer_manager->Pin(block.block); - const auto dataptr = handle.Ptr(); - // Assign an index to each row - data_ptr_t idx_dataptr = dataptr + sort_layout->comparison_size; - for (uint32_t i = 0; i < count; i++) { - Store(i, idx_dataptr); - idx_dataptr += sort_layout->entry_size; - } - // Radix sort and break ties until no more ties, or until all columns are sorted - idx_t sorting_size = 0; - idx_t col_offset = 0; - unsafe_unique_array ties_ptr; - bool *ties = nullptr; - bool contains_string = false; - for (idx_t i = 0; i < sort_layout->column_count; i++) { - sorting_size += sort_layout->column_sizes[i]; - contains_string = contains_string || sort_layout->logical_types[i].InternalType() == PhysicalType::VARCHAR; - if (sort_layout->constant_size[i] && i < sort_layout->column_count - 1) { - // Add columns to the sorting size until we reach a variable size column, or the last column - continue; - } - - if (!ties) { - // This is the first sort - RadixSort(*buffer_manager, dataptr, count, col_offset, sorting_size, *sort_layout, contains_string); - ties_ptr = make_unsafe_uniq_array_uninitialized(count); - ties = ties_ptr.get(); - std::fill_n(ties, count - 1, true); - ties[count - 1] = false; - } else { - // For subsequent sorts, we only have to subsort the tied tuples - SubSortTiedTuples(*buffer_manager, dataptr, count, col_offset, sorting_size, ties, *sort_layout, - contains_string); - } - - contains_string = false; - - if (sort_layout->constant_size[i] && i == sort_layout->column_count - 1) { - // All columns are sorted, no ties to break because last column is constant size - break; - } - - ComputeTies(dataptr, count, col_offset, sorting_size, ties, *sort_layout); - if (!AnyTies(ties, count)) { - // No ties, stop sorting - break; - } - - if (!sort_layout->constant_size[i]) { - SortTiedBlobs(*buffer_manager, sb, ties, dataptr, count, i, *sort_layout); - if (!AnyTies(ties, count)) { - // No more ties after tie-breaking, stop - break; - } - } - - col_offset += sorting_size; - sorting_size = 0; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sorting/sort.cpp b/src/duckdb/src/common/sort/sort.cpp similarity index 94% rename from src/duckdb/src/common/sorting/sort.cpp rename to src/duckdb/src/common/sort/sort.cpp index 2159878ff..b46db0a5f 100644 --- a/src/duckdb/src/common/sorting/sort.cpp +++ b/src/duckdb/src/common/sort/sort.cpp @@ -141,7 +141,7 @@ class SortLocalSinkState : public LocalSinkState { D_ASSERT(!sorted_run); // TODO: we want to pass "sort.is_index_sort" instead of just "false" here // so that we can do an approximate sort, but that causes issues in the ART - sorted_run = make_uniq(context, sort.key_layout, sort.payload_layout, false); + sorted_run = 
make_uniq(context, sort, false); } public: @@ -161,7 +161,7 @@ class SortGlobalSinkState : public GlobalSinkState { public: explicit SortGlobalSinkState(ClientContext &context) : num_threads(NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads())), - temporary_memory_state(TemporaryMemoryManager::Get(context).Register(context)), + temporary_memory_state(TemporaryMemoryManager::Get(context).Register(context)), sorted_tuples(0), external(ClientConfig::GetConfig(context).force_external), any_combined(false), total_count(0), partition_size(0) { } @@ -366,8 +366,7 @@ ProgressData Sort::GetSinkProgress(ClientContext &context, GlobalSinkState &gsta class SortGlobalSourceState : public GlobalSourceState { public: SortGlobalSourceState(const Sort &sort, ClientContext &context, SortGlobalSinkState &sink_p) - : sink(sink_p), merger(*sort.decode_sort_key, sort.key_layout, std::move(sink.sorted_runs), - sort.output_projection_columns, sink.partition_size, sink.external, false), + : sink(sink_p), merger(sort, std::move(sink.sorted_runs), sink.partition_size, sink.external, false), merger_global_state(merger.total_count == 0 ? nullptr : merger.GetGlobalSourceState(context)) { // TODO: we want to pass "sort.is_index_sort" instead of just "false" here // so that we can do an approximate sort, but that causes issues in the ART @@ -378,6 +377,15 @@ class SortGlobalSourceState : public GlobalSourceState { return merger_global_state ? merger_global_state->MaxThreads() : 1; } + void Destroy() { + if (!merger_global_state) { + return; + } + auto guard = merger_global_state->Lock(); + merger.sorted_runs.clear(); + sink.temporary_memory_state.reset(); + } + public: //! The global sink state SortGlobalSinkState &sink; @@ -456,7 +464,8 @@ SourceResultType Sort::MaterializeColumnData(ExecutionContext &context, Operator chunk.Initialize(context.client, types); // Initialize local output collection - auto local_column_data = make_uniq(context.client, types, true); + auto local_column_data = + make_uniq(context.client, types, ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR); while (true) { // Check for interrupts since this could be a long-running task @@ -477,16 +486,26 @@ SourceResultType Sort::MaterializeColumnData(ExecutionContext &context, Operator } // Merge into global output collection - auto guard = gstate.Lock(); - if (!gstate.column_data) { - gstate.column_data = std::move(local_column_data); - } else { - gstate.column_data->Merge(*local_column_data); + { + auto guard = gstate.Lock(); + if (!gstate.column_data) { + gstate.column_data = std::move(local_column_data); + } else { + gstate.column_data->Merge(*local_column_data); + } } + // Destroy local state before returning + input.local_state.Cast().merger_local_state.reset(); + // Return type indicates whether materialization is done const auto progress_data = GetProgress(context.client, input.global_state); - return progress_data.done == progress_data.total ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; + if (progress_data.done == progress_data.total) { + // Destroy global state before returning + gstate.Destroy(); + return SourceResultType::FINISHED; + } + return SourceResultType::HAVE_MORE_OUTPUT; } unique_ptr Sort::GetColumnData(OperatorSourceInput &input) const { @@ -502,12 +521,15 @@ SourceResultType Sort::MaterializeSortedRun(ExecutionContext &context, OperatorS } auto &lstate = input.local_state.Cast(); OperatorSourceInput merger_input {*gstate.merger_global_state, *lstate.merger_local_state, input.interrupt_state}; - return gstate.merger.MaterializeMerge(context, merger_input); + return gstate.merger.MaterializeSortedRun(context, merger_input); } unique_ptr Sort::GetSortedRun(GlobalSourceState &global_state) { auto &gstate = global_state.Cast(); - return gstate.merger.GetMaterialized(gstate); + if (gstate.merger.total_count == 0) { + return nullptr; + } + return gstate.merger.GetSortedRun(*gstate.merger_global_state); } } // namespace duckdb diff --git a/src/duckdb/src/common/sort/sort_state.cpp b/src/duckdb/src/common/sort/sort_state.cpp deleted file mode 100644 index 369f032f1..000000000 --- a/src/duckdb/src/common/sort/sort_state.cpp +++ /dev/null @@ -1,487 +0,0 @@ -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/radix.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/sort/sorted_block.hpp" -#include "duckdb/storage/buffer/buffer_pool.hpp" - -#include -#include - -namespace duckdb { - -idx_t GetNestedSortingColSize(idx_t &col_size, const LogicalType &type) { - auto physical_type = type.InternalType(); - if (TypeIsConstantSize(physical_type)) { - col_size += GetTypeIdSize(physical_type); - return 0; - } else { - switch (physical_type) { - case PhysicalType::VARCHAR: { - // Nested strings are between 4 and 11 chars long for alignment - auto size_before_str = col_size; - col_size += 11; - col_size -= (col_size - 12) % 8; - return col_size - size_before_str; - } - case PhysicalType::LIST: - // Lists get 2 bytes (null and empty list) - col_size += 2; - return GetNestedSortingColSize(col_size, ListType::GetChildType(type)); - case PhysicalType::STRUCT: - // Structs get 1 bytes (null) - col_size++; - return GetNestedSortingColSize(col_size, StructType::GetChildType(type, 0)); - case PhysicalType::ARRAY: - // Arrays get 1 bytes (null) - col_size++; - return GetNestedSortingColSize(col_size, ArrayType::GetChildType(type)); - default: - throw NotImplementedException("Unable to order column with type %s", type.ToString()); - } - } -} - -SortLayout::SortLayout(const vector &orders) - : column_count(orders.size()), all_constant(true), comparison_size(0), entry_size(0) { - vector blob_layout_types; - for (idx_t i = 0; i < column_count; i++) { - const auto &order = orders[i]; - - order_types.push_back(order.type); - order_by_null_types.push_back(order.null_order); - auto &expr = *order.expression; - logical_types.push_back(expr.return_type); - - auto physical_type = expr.return_type.InternalType(); - constant_size.push_back(TypeIsConstantSize(physical_type)); - - if (order.stats) { - stats.push_back(order.stats.get()); - has_null.push_back(stats.back()->CanHaveNull()); - } else { - stats.push_back(nullptr); - has_null.push_back(true); - } - - idx_t col_size = has_null.back() ? 
1 : 0; - prefix_lengths.push_back(0); - if (!TypeIsConstantSize(physical_type) && physical_type != PhysicalType::VARCHAR) { - prefix_lengths.back() = GetNestedSortingColSize(col_size, expr.return_type); - } else if (physical_type == PhysicalType::VARCHAR) { - idx_t size_before = col_size; - if (stats.back() && StringStats::HasMaxStringLength(*stats.back())) { - col_size += StringStats::MaxStringLength(*stats.back()); - if (col_size > 12) { - col_size = 12; - } else { - constant_size.back() = true; - } - } else { - col_size = 12; - } - prefix_lengths.back() = col_size - size_before; - } else { - col_size += GetTypeIdSize(physical_type); - } - - comparison_size += col_size; - column_sizes.push_back(col_size); - } - entry_size = comparison_size + sizeof(uint32_t); - - // 8-byte alignment - if (entry_size % 8 != 0) { - // First assign more bytes to strings instead of aligning - idx_t bytes_to_fill = 8 - (entry_size % 8); - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - if (bytes_to_fill == 0) { - break; - } - if (logical_types[col_idx].InternalType() == PhysicalType::VARCHAR && stats[col_idx] && - StringStats::HasMaxStringLength(*stats[col_idx])) { - idx_t diff = StringStats::MaxStringLength(*stats[col_idx]) - prefix_lengths[col_idx]; - if (diff > 0) { - // Increase all sizes accordingly - idx_t increase = MinValue(bytes_to_fill, diff); - column_sizes[col_idx] += increase; - prefix_lengths[col_idx] += increase; - constant_size[col_idx] = increase == diff; - comparison_size += increase; - entry_size += increase; - bytes_to_fill -= increase; - } - } - } - entry_size = AlignValue(entry_size); - } - - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - all_constant = all_constant && constant_size[col_idx]; - if (!constant_size[col_idx]) { - sorting_to_blob_col[col_idx] = blob_layout_types.size(); - blob_layout_types.push_back(logical_types[col_idx]); - } - } - - blob_layout.Initialize(blob_layout_types); -} - -SortLayout SortLayout::GetPrefixComparisonLayout(idx_t num_prefix_cols) const { - SortLayout result; - result.column_count = num_prefix_cols; - result.all_constant = true; - result.comparison_size = 0; - for (idx_t col_idx = 0; col_idx < num_prefix_cols; col_idx++) { - result.order_types.push_back(order_types[col_idx]); - result.order_by_null_types.push_back(order_by_null_types[col_idx]); - result.logical_types.push_back(logical_types[col_idx]); - - result.all_constant = result.all_constant && constant_size[col_idx]; - result.constant_size.push_back(constant_size[col_idx]); - - result.comparison_size += column_sizes[col_idx]; - result.column_sizes.push_back(column_sizes[col_idx]); - - result.prefix_lengths.push_back(prefix_lengths[col_idx]); - result.stats.push_back(stats[col_idx]); - result.has_null.push_back(has_null[col_idx]); - } - result.entry_size = entry_size; - result.blob_layout = blob_layout; - result.sorting_to_blob_col = sorting_to_blob_col; - return result; -} - -LocalSortState::LocalSortState() : initialized(false) { - if (!Radix::IsLittleEndian()) { - throw NotImplementedException("Sorting is not supported on big endian architectures"); - } -} - -void LocalSortState::Initialize(GlobalSortState &global_sort_state, BufferManager &buffer_manager_p) { - sort_layout = &global_sort_state.sort_layout; - payload_layout = &global_sort_state.payload_layout; - buffer_manager = &buffer_manager_p; - const auto block_size = buffer_manager->GetBlockSize(); - - // Radix sorting data - auto entries_per_block = 
RowDataCollection::EntriesPerBlock(sort_layout->entry_size, block_size); - radix_sorting_data = make_uniq(*buffer_manager, entries_per_block, sort_layout->entry_size); - - // Blob sorting data - if (!sort_layout->all_constant) { - auto blob_row_width = sort_layout->blob_layout.GetRowWidth(); - entries_per_block = RowDataCollection::EntriesPerBlock(blob_row_width, block_size); - blob_sorting_data = make_uniq(*buffer_manager, entries_per_block, blob_row_width); - blob_sorting_heap = make_uniq(*buffer_manager, block_size, 1U, true); - } - - // Payload data - auto payload_row_width = payload_layout->GetRowWidth(); - entries_per_block = RowDataCollection::EntriesPerBlock(payload_row_width, block_size); - payload_data = make_uniq(*buffer_manager, entries_per_block, payload_row_width); - payload_heap = make_uniq(*buffer_manager, block_size, 1U, true); - initialized = true; -} - -void LocalSortState::SinkChunk(DataChunk &sort, DataChunk &payload) { - D_ASSERT(sort.size() == payload.size()); - // Build and serialize sorting data to radix sortable rows - auto data_pointers = FlatVector::GetData(addresses); - auto handles = radix_sorting_data->Build(sort.size(), data_pointers, nullptr); - for (idx_t sort_col = 0; sort_col < sort.ColumnCount(); sort_col++) { - bool has_null = sort_layout->has_null[sort_col]; - bool nulls_first = sort_layout->order_by_null_types[sort_col] == OrderByNullType::NULLS_FIRST; - bool desc = sort_layout->order_types[sort_col] == OrderType::DESCENDING; - RowOperations::RadixScatter(sort.data[sort_col], sort.size(), sel_ptr, sort.size(), data_pointers, desc, - has_null, nulls_first, sort_layout->prefix_lengths[sort_col], - sort_layout->column_sizes[sort_col]); - } - - // Also fully serialize blob sorting columns (to be able to break ties - if (!sort_layout->all_constant) { - DataChunk blob_chunk; - blob_chunk.SetCardinality(sort.size()); - for (idx_t sort_col = 0; sort_col < sort.ColumnCount(); sort_col++) { - if (!sort_layout->constant_size[sort_col]) { - blob_chunk.data.emplace_back(sort.data[sort_col]); - } - } - handles = blob_sorting_data->Build(blob_chunk.size(), data_pointers, nullptr); - auto blob_data = blob_chunk.ToUnifiedFormat(); - RowOperations::Scatter(blob_chunk, blob_data.get(), sort_layout->blob_layout, addresses, *blob_sorting_heap, - sel_ptr, blob_chunk.size()); - D_ASSERT(blob_sorting_heap->keep_pinned); - } - - // Finally, serialize payload data - handles = payload_data->Build(payload.size(), data_pointers, nullptr); - auto input_data = payload.ToUnifiedFormat(); - RowOperations::Scatter(payload, input_data.get(), *payload_layout, addresses, *payload_heap, sel_ptr, - payload.size()); - D_ASSERT(payload_heap->keep_pinned); -} - -idx_t LocalSortState::SizeInBytes() const { - idx_t size_in_bytes = radix_sorting_data->SizeInBytes() + payload_data->SizeInBytes(); - if (!sort_layout->all_constant) { - size_in_bytes += blob_sorting_data->SizeInBytes() + blob_sorting_heap->SizeInBytes(); - } - if (!payload_layout->AllConstant()) { - size_in_bytes += payload_heap->SizeInBytes(); - } - return size_in_bytes; -} - -void LocalSortState::Sort(GlobalSortState &global_sort_state, bool reorder_heap) { - D_ASSERT(radix_sorting_data->count == payload_data->count); - if (radix_sorting_data->count == 0) { - return; - } - // Move all data to a single SortedBlock - sorted_blocks.emplace_back(make_uniq(*buffer_manager, global_sort_state)); - auto &sb = *sorted_blocks.back(); - // Fixed-size sorting data - auto sorting_block = ConcatenateBlocks(*radix_sorting_data); - 
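The deleted radix_sort.cpp earlier in this section dispatches between pdqsort, insertion sort, LSD radix sort, and MSD radix sort depending on row count and key width; its RadixSortLSD is a textbook byte-wise counting sort over fixed-width, memcmp-orderable key encodings (which is exactly what the radix-scatter in SinkChunk above produces). A self-contained sketch of that pass on std::vector rows follows; the block and buffer management of the original is omitted, and only the skip-identical-byte optimisation is kept.

```cpp
// Sketch only: stable byte-wise LSD radix sort, least-significant key byte
// first, in the spirit of the deleted RadixSortLSD.
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

using Row = std::vector<uint8_t>;

static void RadixSortLSD(std::vector<Row> &rows, size_t key_offset, size_t key_width) {
	std::vector<Row> temp(rows.size());
	bool swapped = false;
	for (size_t r = 1; r <= key_width; ++r) {
		auto &source = swapped ? temp : rows;
		auto &target = swapped ? rows : temp;
		const size_t byte_idx = key_offset + key_width - r; // least-significant byte first

		std::array<size_t, 256> counts {};
		for (const auto &row : source) {
			counts[row[byte_idx]]++;
		}
		size_t max_count = counts[0];
		for (size_t v = 1; v < 256; ++v) { // prefix sums turn counts into end offsets
			max_count = std::max(max_count, counts[v]);
			counts[v] += counts[v - 1];
		}
		if (max_count == source.size()) {
			continue; // every row has the same byte here: skip the pass
		}
		for (size_t i = source.size(); i-- > 0;) { // back-to-front keeps the sort stable
			target[--counts[source[i][byte_idx]]] = source[i];
		}
		swapped = !swapped;
	}
	if (swapped) {
		rows = std::move(temp); // result ended up in the scratch buffer
	}
}
```

Because each counting-sort pass is stable, sorting from the least- to the most-significant byte yields the full lexicographic order of the fixed-width keys, which is why the deleted SortLayout goes to such lengths to build memcmp-comparable prefixes.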
sb.radix_sorting_data.push_back(std::move(sorting_block)); - // Variable-size sorting data - if (!sort_layout->all_constant) { - auto &blob_data = *blob_sorting_data; - auto new_block = ConcatenateBlocks(blob_data); - sb.blob_sorting_data->data_blocks.push_back(std::move(new_block)); - } - // Payload data - auto payload_block = ConcatenateBlocks(*payload_data); - sb.payload_data->data_blocks.push_back(std::move(payload_block)); - // Now perform the actual sort - SortInMemory(); - // Re-order before the merge sort - ReOrder(global_sort_state, reorder_heap); -} - -unique_ptr LocalSortState::ConcatenateBlocks(RowDataCollection &row_data) { - // Don't copy and delete if there is only one block. - if (row_data.blocks.size() == 1) { - auto new_block = std::move(row_data.blocks[0]); - row_data.blocks.clear(); - row_data.count = 0; - return new_block; - } - // Create block with the correct capacity - auto &buffer_manager = row_data.buffer_manager; - const idx_t &entry_size = row_data.entry_size; - idx_t capacity = MaxValue((buffer_manager.GetBlockSize() + entry_size - 1) / entry_size, row_data.count); - auto new_block = make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, entry_size); - new_block->count = row_data.count; - auto new_block_handle = buffer_manager.Pin(new_block->block); - data_ptr_t new_block_ptr = new_block_handle.Ptr(); - // Copy the data of the blocks into a single block - for (idx_t i = 0; i < row_data.blocks.size(); i++) { - auto &block = row_data.blocks[i]; - auto block_handle = buffer_manager.Pin(block->block); - memcpy(new_block_ptr, block_handle.Ptr(), block->count * entry_size); - new_block_ptr += block->count * entry_size; - block.reset(); - } - row_data.blocks.clear(); - row_data.count = 0; - return new_block; -} - -void LocalSortState::ReOrder(SortedData &sd, data_ptr_t sorting_ptr, RowDataCollection &heap, GlobalSortState &gstate, - bool reorder_heap) { - sd.swizzled = reorder_heap; - auto &unordered_data_block = sd.data_blocks.back(); - const idx_t count = unordered_data_block->count; - auto unordered_data_handle = buffer_manager->Pin(unordered_data_block->block); - const data_ptr_t unordered_data_ptr = unordered_data_handle.Ptr(); - // Create new block that will hold re-ordered row data - auto ordered_data_block = make_uniq(MemoryTag::ORDER_BY, *buffer_manager, - unordered_data_block->capacity, unordered_data_block->entry_size); - ordered_data_block->count = count; - auto ordered_data_handle = buffer_manager->Pin(ordered_data_block->block); - data_ptr_t ordered_data_ptr = ordered_data_handle.Ptr(); - // Re-order fixed-size row layout - const idx_t row_width = sd.layout.GetRowWidth(); - const idx_t sorting_entry_size = gstate.sort_layout.entry_size; - for (idx_t i = 0; i < count; i++) { - auto index = Load(sorting_ptr); - FastMemcpy(ordered_data_ptr, unordered_data_ptr + index * row_width, row_width); - ordered_data_ptr += row_width; - sorting_ptr += sorting_entry_size; - } - ordered_data_block->block->SetSwizzling( - sd.layout.AllConstant() || !sd.swizzled ? 
nullptr : "LocalSortState::ReOrder.ordered_data"); - // Replace the unordered data block with the re-ordered data block - sd.data_blocks.clear(); - sd.data_blocks.push_back(std::move(ordered_data_block)); - // Deal with the heap (if necessary) - if (!sd.layout.AllConstant() && reorder_heap) { - // Swizzle the column pointers to offsets - RowOperations::SwizzleColumns(sd.layout, ordered_data_handle.Ptr(), count); - sd.data_blocks.back()->block->SetSwizzling(nullptr); - // Create a single heap block to store the ordered heap - idx_t total_byte_offset = - std::accumulate(heap.blocks.begin(), heap.blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->byte_offset; }); - idx_t heap_block_size = MaxValue(total_byte_offset, buffer_manager->GetBlockSize()); - auto ordered_heap_block = make_uniq(MemoryTag::ORDER_BY, *buffer_manager, heap_block_size, 1U); - ordered_heap_block->count = count; - ordered_heap_block->byte_offset = total_byte_offset; - auto ordered_heap_handle = buffer_manager->Pin(ordered_heap_block->block); - data_ptr_t ordered_heap_ptr = ordered_heap_handle.Ptr(); - // Fill the heap in order - ordered_data_ptr = ordered_data_handle.Ptr(); - const idx_t heap_pointer_offset = sd.layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - auto heap_row_ptr = Load(ordered_data_ptr + heap_pointer_offset); - auto heap_row_size = Load(heap_row_ptr); - memcpy(ordered_heap_ptr, heap_row_ptr, heap_row_size); - ordered_heap_ptr += heap_row_size; - ordered_data_ptr += row_width; - } - // Swizzle the base pointer to the offset of each row in the heap - RowOperations::SwizzleHeapPointer(sd.layout, ordered_data_handle.Ptr(), ordered_heap_handle.Ptr(), count); - // Move the re-ordered heap to the SortedData, and clear the local heap - sd.heap_blocks.push_back(std::move(ordered_heap_block)); - heap.pinned_blocks.clear(); - heap.blocks.clear(); - heap.count = 0; - } -} - -void LocalSortState::ReOrder(GlobalSortState &gstate, bool reorder_heap) { - auto &sb = *sorted_blocks.back(); - auto sorting_handle = buffer_manager->Pin(sb.radix_sorting_data.back()->block); - const data_ptr_t sorting_ptr = sorting_handle.Ptr() + gstate.sort_layout.comparison_size; - // Re-order variable size sorting columns - if (!gstate.sort_layout.all_constant) { - ReOrder(*sb.blob_sorting_data, sorting_ptr, *blob_sorting_heap, gstate, reorder_heap); - } - // And the payload - ReOrder(*sb.payload_data, sorting_ptr, *payload_heap, gstate, reorder_heap); -} - -GlobalSortState::GlobalSortState(ClientContext &context_p, const vector &orders, - RowLayout &payload_layout) - : context(context_p), buffer_manager(BufferManager::GetBufferManager(context)), sort_layout(SortLayout(orders)), - payload_layout(payload_layout), block_capacity(0), external(false) { -} - -void GlobalSortState::AddLocalState(LocalSortState &local_sort_state) { - if (!local_sort_state.radix_sorting_data) { - return; - } - - // Sort accumulated data - // we only re-order the heap when the data is expected to not fit in memory - // re-ordering the heap avoids random access when reading/merging but incurs a significant cost of shuffling data - // when data fits in memory, doing random access on reads is cheaper than re-shuffling - local_sort_state.Sort(*this, external || !local_sort_state.sorted_blocks.empty()); - - // Append local state sorted data to this global state - lock_guard append_guard(lock); - for (auto &sb : local_sort_state.sorted_blocks) { - sorted_blocks.push_back(std::move(sb)); - } - auto &payload_heap = 
local_sort_state.payload_heap; - for (idx_t i = 0; i < payload_heap->blocks.size(); i++) { - heap_blocks.push_back(std::move(payload_heap->blocks[i])); - pinned_blocks.push_back(std::move(payload_heap->pinned_blocks[i])); - } - if (!sort_layout.all_constant) { - auto &blob_heap = local_sort_state.blob_sorting_heap; - for (idx_t i = 0; i < blob_heap->blocks.size(); i++) { - heap_blocks.push_back(std::move(blob_heap->blocks[i])); - pinned_blocks.push_back(std::move(blob_heap->pinned_blocks[i])); - } - } -} - -void GlobalSortState::PrepareMergePhase() { - // Determine if we need to use do an external sort - idx_t total_heap_size = - std::accumulate(sorted_blocks.begin(), sorted_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->HeapSize(); }); - if (external || (pinned_blocks.empty() && total_heap_size * 4 > buffer_manager.GetQueryMaxMemory())) { - external = true; - } - // Use the data that we have to determine which partition size to use during the merge - if (external && total_heap_size > 0) { - // If we have variable size data we need to be conservative, as there might be skew - idx_t max_block_size = 0; - for (auto &sb : sorted_blocks) { - idx_t size_in_bytes = sb->SizeInBytes(); - if (size_in_bytes > max_block_size) { - max_block_size = size_in_bytes; - block_capacity = sb->Count(); - } - } - } else { - for (auto &sb : sorted_blocks) { - block_capacity = MaxValue(block_capacity, sb->Count()); - } - } - // Unswizzle and pin heap blocks if we can fit everything in memory - if (!external) { - for (auto &sb : sorted_blocks) { - sb->blob_sorting_data->Unswizzle(); - sb->payload_data->Unswizzle(); - } - } -} - -void GlobalSortState::InitializeMergeRound() { - D_ASSERT(sorted_blocks_temp.empty()); - // If we reverse this list, the blocks that were merged last will be merged first in the next round - // These are still in memory, therefore this reduces the amount of read/write to disk! - std::reverse(sorted_blocks.begin(), sorted_blocks.end()); - // Uneven number of blocks - keep one on the side - if (sorted_blocks.size() % 2 == 1) { - odd_one_out = std::move(sorted_blocks.back()); - sorted_blocks.pop_back(); - } - // Init merge path path indices - pair_idx = 0; - num_pairs = sorted_blocks.size() / 2; - l_start = 0; - r_start = 0; - // Allocate room for merge results - for (idx_t p_idx = 0; p_idx < num_pairs; p_idx++) { - sorted_blocks_temp.emplace_back(); - } -} - -void GlobalSortState::CompleteMergeRound(bool keep_radix_data) { - sorted_blocks.clear(); - for (auto &sorted_block_vector : sorted_blocks_temp) { - sorted_blocks.push_back(make_uniq(buffer_manager, *this)); - sorted_blocks.back()->AppendSortedBlocks(sorted_block_vector); - } - sorted_blocks_temp.clear(); - if (odd_one_out) { - sorted_blocks.push_back(std::move(odd_one_out)); - odd_one_out = nullptr; - } - // Only one block left: Done! 
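The removed GlobalSortState code above pairs sorted runs up for a cascaded merge: the run list is reversed so the most recently produced (still memory-resident) blocks are merged first, and an odd run is parked until the next round. A minimal standalone sketch of that bookkeeping, using simplified hypothetical stand-in types rather than the DuckDB classes:

#include <algorithm>
#include <cstddef>
#include <memory>
#include <vector>

struct SortedRunStub {}; // hypothetical stand-in for the removed SortedBlock

struct MergeRoundState {
	std::vector<std::unique_ptr<SortedRunStub>> runs;
	std::unique_ptr<SortedRunStub> odd_one_out;
	std::size_t pair_idx = 0;
	std::size_t num_pairs = 0;

	void InitializeRound() {
		// Merge the most recently produced runs first: they are likely still in memory,
		// which reduces reads/writes to disk during an external sort.
		std::reverse(runs.begin(), runs.end());
		// With an uneven number of runs, keep one aside for the next round.
		if (runs.size() % 2 == 1) {
			odd_one_out = std::move(runs.back());
			runs.pop_back();
		}
		pair_idx = 0;
		num_pairs = runs.size() / 2;
	}

	bool RoundComplete() const {
		// One run left (and no parked run) means the cascade is done.
		return runs.size() + (odd_one_out ? 1 : 0) <= 1;
	}
};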
- if (sorted_blocks.size() == 1 && !keep_radix_data) { - sorted_blocks[0]->radix_sorting_data.clear(); - sorted_blocks[0]->blob_sorting_data = nullptr; - } -} -void GlobalSortState::Print() { - PayloadScanner scanner(*this, false); - DataChunk chunk; - chunk.Initialize(Allocator::DefaultAllocator(), scanner.GetPayloadTypes()); - for (;;) { - scanner.Scan(chunk); - const auto count = chunk.size(); - if (!count) { - break; - } - chunk.Print(); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/sorted_block.cpp b/src/duckdb/src/common/sort/sorted_block.cpp deleted file mode 100644 index c4766c956..000000000 --- a/src/duckdb/src/common/sort/sorted_block.cpp +++ /dev/null @@ -1,387 +0,0 @@ -#include "duckdb/common/sort/sorted_block.hpp" - -#include "duckdb/common/constants.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" - -#include - -namespace duckdb { - -SortedData::SortedData(SortedDataType type, const RowLayout &layout, BufferManager &buffer_manager, - GlobalSortState &state) - : type(type), layout(layout), swizzled(state.external), buffer_manager(buffer_manager), state(state) { -} - -idx_t SortedData::Count() { - idx_t count = std::accumulate(data_blocks.begin(), data_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; }); - if (!layout.AllConstant() && state.external) { - D_ASSERT(count == std::accumulate(heap_blocks.begin(), heap_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; })); - } - return count; -} - -void SortedData::CreateBlock() { - const auto block_size = buffer_manager.GetBlockSize(); - auto capacity = MaxValue((block_size + layout.GetRowWidth() - 1) / layout.GetRowWidth(), state.block_capacity); - data_blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, layout.GetRowWidth())); - if (!layout.AllConstant() && state.external) { - heap_blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, block_size, 1U)); - D_ASSERT(data_blocks.size() == heap_blocks.size()); - } -} - -unique_ptr SortedData::CreateSlice(idx_t start_block_index, idx_t end_block_index, idx_t end_entry_index) { - // Add the corresponding blocks to the result - auto result = make_uniq(type, layout, buffer_manager, state); - for (idx_t i = start_block_index; i <= end_block_index; i++) { - result->data_blocks.push_back(data_blocks[i]->Copy()); - if (!layout.AllConstant() && state.external) { - result->heap_blocks.push_back(heap_blocks[i]->Copy()); - } - } - // All of the blocks that come before block with idx = start_block_idx can be reset (other references exist) - for (idx_t i = 0; i < start_block_index; i++) { - data_blocks[i]->block = nullptr; - if (!layout.AllConstant() && state.external) { - heap_blocks[i]->block = nullptr; - } - } - // Use start and end entry indices to set the boundaries - D_ASSERT(end_entry_index <= result->data_blocks.back()->count); - result->data_blocks.back()->count = end_entry_index; - if (!layout.AllConstant() && state.external) { - result->heap_blocks.back()->count = end_entry_index; - } - return result; -} - -void SortedData::Unswizzle() { - if (layout.AllConstant() || !swizzled) { - return; - } - for (idx_t i = 0; i < data_blocks.size(); i++) { - auto &data_block = data_blocks[i]; - auto &heap_block = heap_blocks[i]; - D_ASSERT(data_block->block->IsSwizzled()); - auto data_handle_p = buffer_manager.Pin(data_block->block); - auto heap_handle_p = 
buffer_manager.Pin(heap_block->block); - RowOperations::UnswizzlePointers(layout, data_handle_p.Ptr(), heap_handle_p.Ptr(), data_block->count); - state.heap_blocks.push_back(std::move(heap_block)); - state.pinned_blocks.push_back(std::move(heap_handle_p)); - } - swizzled = false; - heap_blocks.clear(); -} - -SortedBlock::SortedBlock(BufferManager &buffer_manager, GlobalSortState &state) - : buffer_manager(buffer_manager), state(state), sort_layout(state.sort_layout), - payload_layout(state.payload_layout) { - blob_sorting_data = make_uniq(SortedDataType::BLOB, sort_layout.blob_layout, buffer_manager, state); - payload_data = make_uniq(SortedDataType::PAYLOAD, payload_layout, buffer_manager, state); -} - -idx_t SortedBlock::Count() const { - idx_t count = std::accumulate(radix_sorting_data.begin(), radix_sorting_data.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; }); - if (!sort_layout.all_constant) { - D_ASSERT(count == blob_sorting_data->Count()); - } - D_ASSERT(count == payload_data->Count()); - return count; -} - -void SortedBlock::InitializeWrite() { - CreateBlock(); - if (!sort_layout.all_constant) { - blob_sorting_data->CreateBlock(); - } - payload_data->CreateBlock(); -} - -void SortedBlock::CreateBlock() { - const auto block_size = buffer_manager.GetBlockSize(); - auto capacity = MaxValue((block_size + sort_layout.entry_size - 1) / sort_layout.entry_size, state.block_capacity); - radix_sorting_data.push_back( - make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, sort_layout.entry_size)); -} - -void SortedBlock::AppendSortedBlocks(vector> &sorted_blocks) { - D_ASSERT(Count() == 0); - for (auto &sb : sorted_blocks) { - for (auto &radix_block : sb->radix_sorting_data) { - radix_sorting_data.push_back(std::move(radix_block)); - } - if (!sort_layout.all_constant) { - for (auto &blob_block : sb->blob_sorting_data->data_blocks) { - blob_sorting_data->data_blocks.push_back(std::move(blob_block)); - } - for (auto &heap_block : sb->blob_sorting_data->heap_blocks) { - blob_sorting_data->heap_blocks.push_back(std::move(heap_block)); - } - } - for (auto &payload_data_block : sb->payload_data->data_blocks) { - payload_data->data_blocks.push_back(std::move(payload_data_block)); - } - if (!payload_data->layout.AllConstant()) { - for (auto &payload_heap_block : sb->payload_data->heap_blocks) { - payload_data->heap_blocks.push_back(std::move(payload_heap_block)); - } - } - } -} - -void SortedBlock::GlobalToLocalIndex(const idx_t &global_idx, idx_t &local_block_index, idx_t &local_entry_index) { - if (global_idx == Count()) { - local_block_index = radix_sorting_data.size() - 1; - local_entry_index = radix_sorting_data.back()->count; - return; - } - D_ASSERT(global_idx < Count()); - local_entry_index = global_idx; - for (local_block_index = 0; local_block_index < radix_sorting_data.size(); local_block_index++) { - const idx_t &block_count = radix_sorting_data[local_block_index]->count; - if (local_entry_index >= block_count) { - local_entry_index -= block_count; - } else { - break; - } - } - D_ASSERT(local_entry_index < radix_sorting_data[local_block_index]->count); -} - -unique_ptr SortedBlock::CreateSlice(const idx_t start, const idx_t end, idx_t &entry_idx) { - // Identify blocks/entry indices of this slice - idx_t start_block_index; - idx_t start_entry_index; - GlobalToLocalIndex(start, start_block_index, start_entry_index); - idx_t end_block_index; - idx_t end_entry_index; - GlobalToLocalIndex(end, end_block_index, end_entry_index); - // Add the corresponding 
blocks to the result - auto result = make_uniq(buffer_manager, state); - for (idx_t i = start_block_index; i <= end_block_index; i++) { - result->radix_sorting_data.push_back(radix_sorting_data[i]->Copy()); - } - // Reset all blocks that come before block with idx = start_block_idx (slice holds new reference) - for (idx_t i = 0; i < start_block_index; i++) { - radix_sorting_data[i]->block = nullptr; - } - // Use start and end entry indices to set the boundaries - entry_idx = start_entry_index; - D_ASSERT(end_entry_index <= result->radix_sorting_data.back()->count); - result->radix_sorting_data.back()->count = end_entry_index; - // Same for the var size sorting data - if (!sort_layout.all_constant) { - result->blob_sorting_data = blob_sorting_data->CreateSlice(start_block_index, end_block_index, end_entry_index); - } - // And the payload data - result->payload_data = payload_data->CreateSlice(start_block_index, end_block_index, end_entry_index); - return result; -} - -idx_t SortedBlock::HeapSize() const { - idx_t result = 0; - if (!sort_layout.all_constant) { - for (auto &block : blob_sorting_data->heap_blocks) { - result += block->capacity; - } - } - if (!payload_layout.AllConstant()) { - for (auto &block : payload_data->heap_blocks) { - result += block->capacity; - } - } - return result; -} - -idx_t SortedBlock::SizeInBytes() const { - idx_t bytes = 0; - for (idx_t i = 0; i < radix_sorting_data.size(); i++) { - bytes += radix_sorting_data[i]->capacity * sort_layout.entry_size; - if (!sort_layout.all_constant) { - bytes += blob_sorting_data->data_blocks[i]->capacity * sort_layout.blob_layout.GetRowWidth(); - bytes += blob_sorting_data->heap_blocks[i]->capacity; - } - bytes += payload_data->data_blocks[i]->capacity * payload_layout.GetRowWidth(); - if (!payload_layout.AllConstant()) { - bytes += payload_data->heap_blocks[i]->capacity; - } - } - return bytes; -} - -SBScanState::SBScanState(BufferManager &buffer_manager, GlobalSortState &state) - : buffer_manager(buffer_manager), sort_layout(state.sort_layout), state(state), block_idx(0), entry_idx(0) { -} - -void SBScanState::PinRadix(idx_t block_idx_to) { - auto &radix_sorting_data = sb->radix_sorting_data; - D_ASSERT(block_idx_to < radix_sorting_data.size()); - auto &block = radix_sorting_data[block_idx_to]; - if (!radix_handle.IsValid() || radix_handle.GetBlockHandle() != block->block) { - radix_handle = buffer_manager.Pin(block->block); - } -} - -void SBScanState::PinData(SortedData &sd) { - D_ASSERT(block_idx < sd.data_blocks.size()); - auto &data_handle = sd.type == SortedDataType::BLOB ? blob_sorting_data_handle : payload_data_handle; - auto &heap_handle = sd.type == SortedDataType::BLOB ? blob_sorting_heap_handle : payload_heap_handle; - - auto &data_block = sd.data_blocks[block_idx]; - if (!data_handle.IsValid() || data_handle.GetBlockHandle() != data_block->block) { - data_handle = buffer_manager.Pin(data_block->block); - } - if (sd.layout.AllConstant() || !state.external) { - return; - } - auto &heap_block = sd.heap_blocks[block_idx]; - if (!heap_handle.IsValid() || heap_handle.GetBlockHandle() != heap_block->block) { - heap_handle = buffer_manager.Pin(heap_block->block); - } -} - -data_ptr_t SBScanState::RadixPtr() const { - return radix_handle.Ptr() + entry_idx * sort_layout.entry_size; -} - -data_ptr_t SBScanState::DataPtr(SortedData &sd) const { - auto &data_handle = sd.type == SortedDataType::BLOB ? 
blob_sorting_data_handle : payload_data_handle; - D_ASSERT(sd.data_blocks[block_idx]->block->Readers() != 0 && - data_handle.GetBlockHandle() == sd.data_blocks[block_idx]->block); - return data_handle.Ptr() + entry_idx * sd.layout.GetRowWidth(); -} - -data_ptr_t SBScanState::HeapPtr(SortedData &sd) const { - return BaseHeapPtr(sd) + Load(DataPtr(sd) + sd.layout.GetHeapOffset()); -} - -data_ptr_t SBScanState::BaseHeapPtr(SortedData &sd) const { - auto &heap_handle = sd.type == SortedDataType::BLOB ? blob_sorting_heap_handle : payload_heap_handle; - D_ASSERT(!sd.layout.AllConstant() && state.external); - D_ASSERT(sd.heap_blocks[block_idx]->block->Readers() != 0 && - heap_handle.GetBlockHandle() == sd.heap_blocks[block_idx]->block); - return heap_handle.Ptr(); -} - -idx_t SBScanState::Remaining() const { - const auto &blocks = sb->radix_sorting_data; - idx_t remaining = 0; - if (block_idx < blocks.size()) { - remaining += blocks[block_idx]->count - entry_idx; - for (idx_t i = block_idx + 1; i < blocks.size(); i++) { - remaining += blocks[i]->count; - } - } - return remaining; -} - -void SBScanState::SetIndices(idx_t block_idx_to, idx_t entry_idx_to) { - block_idx = block_idx_to; - entry_idx = entry_idx_to; -} - -PayloadScanner::PayloadScanner(SortedData &sorted_data, GlobalSortState &global_sort_state, bool flush_p) { - auto count = sorted_data.Count(); - auto &layout = sorted_data.layout; - const auto block_size = global_sort_state.buffer_manager.GetBlockSize(); - - // Create collections to put the data into so we can use RowDataCollectionScanner - rows = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - rows->count = count; - - heap = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - if (!sorted_data.layout.AllConstant()) { - heap->count = count; - } - - if (flush_p) { - // If we are flushing, we can just move the data - rows->blocks = std::move(sorted_data.data_blocks); - if (!layout.AllConstant()) { - heap->blocks = std::move(sorted_data.heap_blocks); - } - } else { - // Not flushing, create references to the blocks - for (auto &block : sorted_data.data_blocks) { - rows->blocks.emplace_back(block->Copy()); - } - if (!layout.AllConstant()) { - for (auto &block : sorted_data.heap_blocks) { - heap->blocks.emplace_back(block->Copy()); - } - } - } - - scanner = make_uniq(*rows, *heap, layout, global_sort_state.external, flush_p); -} - -PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, bool flush_p) - : PayloadScanner(*global_sort_state.sorted_blocks[0]->payload_data, global_sort_state, flush_p) { -} - -PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, idx_t block_idx, bool flush_p) { - auto &sorted_data = *global_sort_state.sorted_blocks[0]->payload_data; - auto count = sorted_data.data_blocks[block_idx]->count; - auto &layout = sorted_data.layout; - const auto block_size = global_sort_state.buffer_manager.GetBlockSize(); - - // Create collections to put the data into so we can use RowDataCollectionScanner - rows = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - if (flush_p) { - rows->blocks.emplace_back(std::move(sorted_data.data_blocks[block_idx])); - } else { - rows->blocks.emplace_back(sorted_data.data_blocks[block_idx]->Copy()); - } - rows->count = count; - - heap = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - if (!sorted_data.layout.AllConstant() && sorted_data.swizzled) { - if (flush_p) { - heap->blocks.emplace_back(std::move(sorted_data.heap_blocks[block_idx])); - } else { - 
heap->blocks.emplace_back(sorted_data.heap_blocks[block_idx]->Copy()); - } - heap->count = count; - } - - scanner = make_uniq(*rows, *heap, layout, global_sort_state.external, flush_p); -} - -void PayloadScanner::Scan(DataChunk &chunk) { - scanner->Scan(chunk); -} - -int SBIterator::ComparisonValue(ExpressionType comparison) { - switch (comparison) { - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_GREATERTHAN: - return -1; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return 0; - default: - throw InternalException("Unimplemented comparison type for IEJoin!"); - } -} - -static idx_t GetBlockCountWithEmptyCheck(const GlobalSortState &gss) { - D_ASSERT(!gss.sorted_blocks.empty()); - return gss.sorted_blocks[0]->radix_sorting_data.size(); -} - -SBIterator::SBIterator(GlobalSortState &gss, ExpressionType comparison, idx_t entry_idx_p) - : sort_layout(gss.sort_layout), block_count(GetBlockCountWithEmptyCheck(gss)), block_capacity(gss.block_capacity), - entry_size(sort_layout.entry_size), all_constant(sort_layout.all_constant), external(gss.external), - cmp(ComparisonValue(comparison)), scan(gss.buffer_manager, gss), block_ptr(nullptr), entry_ptr(nullptr) { - - scan.sb = gss.sorted_blocks[0].get(); - scan.block_idx = block_count; - SetIndex(entry_idx_p); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sorting/sorted_run.cpp b/src/duckdb/src/common/sort/sorted_run.cpp similarity index 67% rename from src/duckdb/src/common/sorting/sorted_run.cpp rename to src/duckdb/src/common/sort/sorted_run.cpp index 57c390d32..644351a74 100644 --- a/src/duckdb/src/common/sorting/sorted_run.cpp +++ b/src/duckdb/src/common/sort/sorted_run.cpp @@ -1,6 +1,7 @@ #include "duckdb/common/sorting/sorted_run.hpp" #include "duckdb/common/types/row/tuple_data_collection.hpp" +#include "duckdb/common/sorting/sort.hpp" #include "duckdb/common/sorting/sort_key.hpp" #include "duckdb/common/types/row/block_iterator.hpp" @@ -9,14 +10,144 @@ namespace duckdb { -SortedRun::SortedRun(ClientContext &context_p, shared_ptr key_layout, - shared_ptr payload_layout, bool is_index_sort_p) - : context(context_p), - key_data(make_uniq(BufferManager::GetBufferManager(context), std::move(key_layout))), - payload_data( - payload_layout && payload_layout->ColumnCount() != 0 - ? 
make_uniq(BufferManager::GetBufferManager(context), std::move(payload_layout)) - : nullptr), +//===--------------------------------------------------------------------===// +// SortedRunScanState +//===--------------------------------------------------------------------===// +SortedRunScanState::SortedRunScanState(ClientContext &context, const Sort &sort_p) + : sort(sort_p), key_executor(context, *sort.decode_sort_key) { + key.Initialize(context, {sort.key_layout->GetTypes()[0]}); + decoded_key.Initialize(context, {sort.decode_sort_key->return_type}); +} + +void SortedRunScanState::Scan(const SortedRun &sorted_run, const Vector &sort_key_pointers, const idx_t &count, + DataChunk &chunk) { + const auto sort_key_type = sort.key_layout->GetSortKeyType(); + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::NO_PAYLOAD_FIXED_16: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::NO_PAYLOAD_FIXED_24: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::NO_PAYLOAD_FIXED_32: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::PAYLOAD_FIXED_16: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::PAYLOAD_FIXED_24: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::PAYLOAD_FIXED_32: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::PAYLOAD_VARIABLE_32: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + default: + throw NotImplementedException("SortedRunMergerLocalState::ScanPartition for %s", + EnumUtil::ToString(sort_key_type)); + } +} + +template +void TemplatedGetKeyAndPayload(SORT_KEY *const *const sort_keys, SORT_KEY *temp_keys, const idx_t &count, + DataChunk &key, data_ptr_t *const payload_ptrs) { + const auto key_data = FlatVector::GetData(key.data[0]); + for (idx_t i = 0; i < count; i++) { + auto &sort_key = temp_keys[i]; + sort_key = *sort_keys[i]; + sort_key.Deconstruct(key_data[i]); + if (SORT_KEY::HAS_PAYLOAD) { + payload_ptrs[i] = sort_key.GetPayload(); + } + } + key.SetCardinality(count); +} + +template +void GetKeyAndPayload(SORT_KEY *const *const sort_keys, SORT_KEY *temp_keys, const idx_t &count, DataChunk &key, + data_ptr_t *const payload_ptrs) { + const auto type_id = key.data[0].GetType().id(); + switch (type_id) { + case LogicalTypeId::BLOB: + return TemplatedGetKeyAndPayload(sort_keys, temp_keys, count, key, payload_ptrs); + case LogicalTypeId::BIGINT: + return TemplatedGetKeyAndPayload(sort_keys, temp_keys, count, key, payload_ptrs); + default: + throw NotImplementedException("GetKeyAndPayload for %s", EnumUtil::ToString(type_id)); + } +} + +template +void SortedRunScanState::TemplatedScan(const SortedRun &sorted_run, const Vector &sort_key_pointers, const idx_t &count, + DataChunk &chunk) { + using SORT_KEY = SortKey; + + const auto &output_projection_columns = sort.output_projection_columns; + idx_t opc_idx = 0; + + const auto sort_keys = FlatVector::GetData(sort_key_pointers); + const auto payload_ptrs = FlatVector::GetData(payload_state.chunk_state.row_locations); + bool gathered_payload = false; + + // Decode from key + if (!output_projection_columns[0].is_payload) { + key.Reset(); + key_buffer.resize(count * 
sizeof(SORT_KEY)); + auto temp_keys = reinterpret_cast(key_buffer.data()); + GetKeyAndPayload(sort_keys, temp_keys, count, key, payload_ptrs); + + decoded_key.Reset(); + key_executor.Execute(key, decoded_key); + + const auto &decoded_key_entries = StructVector::GetEntries(decoded_key.data[0]); + for (; opc_idx < output_projection_columns.size(); opc_idx++) { + const auto &opc = output_projection_columns[opc_idx]; + if (opc.is_payload) { + break; + } + chunk.data[opc.output_col_idx].Reference(*decoded_key_entries[opc.layout_col_idx]); + } + + gathered_payload = true; + } + + // If there are no payload columns, we're done here + if (opc_idx != output_projection_columns.size()) { + if (!gathered_payload) { + // Gather row pointers from keys + for (idx_t i = 0; i < count; i++) { + payload_ptrs[i] = sort_keys[i]->GetPayload(); + } + } + + // Init scan state + auto &payload_data = *sorted_run.payload_data; + if (payload_state.pin_state.properties == TupleDataPinProperties::INVALID) { + payload_data.InitializeScan(payload_state, TupleDataPinProperties::ALREADY_PINNED); + } + TupleDataCollection::ResetCachedCastVectors(payload_state.chunk_state, payload_state.chunk_state.column_ids); + + // Now gather from payload + for (; opc_idx < output_projection_columns.size(); opc_idx++) { + const auto &opc = output_projection_columns[opc_idx]; + D_ASSERT(opc.is_payload); + payload_data.Gather(payload_state.chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(), + count, opc.layout_col_idx, chunk.data[opc.output_col_idx], + *FlatVector::IncrementalSelectionVector(), + payload_state.chunk_state.cached_cast_vectors[opc.layout_col_idx]); + } + } + + chunk.SetCardinality(count); +} + +//===--------------------------------------------------------------------===// +// SortedRun +//===--------------------------------------------------------------------===// +SortedRun::SortedRun(ClientContext &context_p, const Sort &sort_p, bool is_index_sort_p) + : context(context_p), sort(sort_p), key_data(make_uniq(context, sort.key_layout)), + payload_data(sort.payload_layout && sort.payload_layout->ColumnCount() != 0 + ? make_uniq(context, sort.payload_layout) + : nullptr), is_index_sort(is_index_sort_p), finalized(false) { key_data->InitializeAppend(key_append_state, TupleDataPinProperties::KEEP_EVERYTHING_PINNED); if (payload_data) { @@ -25,8 +156,7 @@ SortedRun::SortedRun(ClientContext &context_p, shared_ptr key_l } unique_ptr SortedRun::CreateRunForMaterialization() const { - auto res = make_uniq(context, key_data->GetLayoutPtr(), - payload_data ? 
payload_data->GetLayoutPtr() : nullptr, is_index_sort); + auto res = make_uniq(context, sort, is_index_sort); res->key_append_state.pin_state.properties = TupleDataPinProperties::UNPIN_AFTER_DONE; res->payload_append_state.pin_state.properties = TupleDataPinProperties::UNPIN_AFTER_DONE; res->finalized = true; diff --git a/src/duckdb/src/common/sorting/sorted_run_merger.cpp b/src/duckdb/src/common/sort/sorted_run_merger.cpp similarity index 87% rename from src/duckdb/src/common/sorting/sorted_run_merger.cpp rename to src/duckdb/src/common/sort/sorted_run_merger.cpp index eb879edc5..ee18bc734 100644 --- a/src/duckdb/src/common/sorting/sorted_run_merger.cpp +++ b/src/duckdb/src/common/sort/sorted_run_merger.cpp @@ -1,5 +1,6 @@ #include "duckdb/common/sorting/sorted_run_merger.hpp" +#include "duckdb/common/sorting/sort.hpp" #include "duckdb/common/sorting/sorted_run.hpp" #include "duckdb/common/sorting/sort_key.hpp" #include "duckdb/common/types/row/block_iterator.hpp" @@ -100,7 +101,7 @@ class SortedRunMergerLocalState : public LocalSourceState { //! Whether this thread has finished the work it has been assigned bool TaskFinished() const; //! Do the work this thread has been assigned - void ExecuteTask(SortedRunMergerGlobalState &gstate, optional_ptr chunk); + SourceResultType ExecuteTask(SortedRunMergerGlobalState &gstate, optional_ptr chunk); private: //! Computes upper partition boundaries using K-way Merge Path @@ -154,12 +155,10 @@ class SortedRunMergerLocalState : public LocalSourceState { //! Variables for scanning idx_t merged_partition_count; idx_t merged_partition_index; - TupleDataScanState payload_state; - //! For decoding sort keys - ExpressionExecutor key_executor; - DataChunk key; - DataChunk decoded_key; + //! For scanning + Vector sort_key_pointers; + SortedRunScanState sorted_run_scan_state; }; //===--------------------------------------------------------------------===// @@ -172,7 +171,7 @@ class SortedRunMergerGlobalState : public GlobalSourceState { merger(merger_p), num_runs(merger.sorted_runs.size()), num_partitions((merger.total_count + (merger.partition_size - 1)) / merger.partition_size), iterator_state_type(GetBlockIteratorStateType(merger.external)), - sort_key_type(merger.key_layout->GetSortKeyType()), next_partition_idx(0), total_scanned(0), + sort_key_type(merger.sort.key_layout->GetSortKeyType()), next_partition_idx(0), total_scanned(0), destroy_partition_idx(0) { // Initialize partitions partitions.resize(num_partitions); @@ -263,6 +262,11 @@ class SortedRunMergerGlobalState : public GlobalSourceState { destroy_partition_idx = end_partition_idx; } +private: + static BlockIteratorStateType GetBlockIteratorStateType(const bool &external) { + return external ? 
BlockIteratorStateType::EXTERNAL : BlockIteratorStateType::IN_MEMORY; + } + public: ClientContext &context; const idx_t num_threads; @@ -292,7 +296,7 @@ SortedRunMergerLocalState::SortedRunMergerLocalState(SortedRunMergerGlobalState : iterator_state_type(gstate.iterator_state_type), sort_key_type(gstate.sort_key_type), task(SortedRunMergerTask::FINISHED), run_boundaries(gstate.num_runs), merged_partition_count(DConstants::INVALID_INDEX), merged_partition_index(DConstants::INVALID_INDEX), - key_executor(gstate.context, gstate.merger.decode_sort_key) { + sorted_run_scan_state(gstate.context, gstate.merger.sort), sort_key_pointers(LogicalType::POINTER) { for (const auto &run : gstate.merger.sorted_runs) { auto &key_data = *run->key_data; switch (iterator_state_type) { @@ -308,8 +312,6 @@ SortedRunMergerLocalState::SortedRunMergerLocalState(SortedRunMergerGlobalState EnumUtil::ToString(iterator_state_type)); } } - key.Initialize(gstate.context, {gstate.merger.key_layout->GetTypes()[0]}); - decoded_key.Initialize(gstate.context, {gstate.merger.decode_sort_key.return_type}); } bool SortedRunMergerLocalState::TaskFinished() const { @@ -328,7 +330,8 @@ bool SortedRunMergerLocalState::TaskFinished() const { } } -void SortedRunMergerLocalState::ExecuteTask(SortedRunMergerGlobalState &gstate, optional_ptr chunk) { +SourceResultType SortedRunMergerLocalState::ExecuteTask(SortedRunMergerGlobalState &gstate, + optional_ptr chunk) { D_ASSERT(task != SortedRunMergerTask::FINISHED); switch (task) { case SortedRunMergerTask::COMPUTE_BOUNDARIES: @@ -352,14 +355,20 @@ void SortedRunMergerLocalState::ExecuteTask(SortedRunMergerGlobalState &gstate, if (!chunk || chunk->size() == 0) { gstate.DestroyScannedData(); gstate.partitions[partition_idx.GetIndex()]->scanned = true; - gstate.total_scanned += merged_partition_count; + // fetch_add returns the _previous_ value! 
+ const auto scan_count_before_adding = gstate.total_scanned.fetch_add(merged_partition_count); + const auto scan_count_after_adding = scan_count_before_adding + merged_partition_count; partition_idx = optional_idx::Invalid(); task = SortedRunMergerTask::FINISHED; + if (scan_count_after_adding == gstate.merger.total_count) { + return SourceResultType::FINISHED; + } } break; default: throw NotImplementedException("SortedRunMergerLocalState::ExecuteTask for task"); } + return SourceResultType::HAVE_MORE_OUTPUT; } void SortedRunMergerLocalState::ComputePartitionBoundaries(SortedRunMergerGlobalState &gstate, @@ -685,94 +694,21 @@ void SortedRunMergerLocalState::ScanPartition(SortedRunMergerGlobalState &gstate } } -template -void TemplatedGetKeyAndPayload(SORT_KEY *const merged_partition_keys, const idx_t count, DataChunk &key, - data_ptr_t *const payload_ptrs) { - const auto key_data = FlatVector::GetData(key.data[0]); - for (idx_t i = 0; i < count; i++) { - auto &merged_partition_key = merged_partition_keys[i]; - merged_partition_key.Deconstruct(key_data[i]); - if (SORT_KEY::HAS_PAYLOAD) { - payload_ptrs[i] = merged_partition_key.GetPayload(); - } - } - key.SetCardinality(count); -} - -template -void GetKeyAndPayload(SORT_KEY *const merged_partition_keys, const idx_t count, DataChunk &key, - data_ptr_t *const payload_ptrs) { - const auto type_id = key.data[0].GetType().id(); - switch (type_id) { - case LogicalTypeId::BLOB: - return TemplatedGetKeyAndPayload(merged_partition_keys, count, key, payload_ptrs); - case LogicalTypeId::BIGINT: - return TemplatedGetKeyAndPayload(merged_partition_keys, count, key, payload_ptrs); - default: - throw NotImplementedException("GetKeyAndPayload for %s", EnumUtil::ToString(type_id)); - } -} - template void SortedRunMergerLocalState::TemplatedScanPartition(SortedRunMergerGlobalState &gstate, DataChunk &chunk) { using SORT_KEY = SortKey; const auto count = MinValue(merged_partition_count - merged_partition_index, STANDARD_VECTOR_SIZE); - const auto &output_projection_columns = gstate.merger.output_projection_columns; - idx_t opc_idx = 0; - + // Grab pointers to sort keys const auto merged_partition_keys = reinterpret_cast(merged_partition.get()) + merged_partition_index; - const auto payload_ptrs = FlatVector::GetData(payload_state.chunk_state.row_locations); - bool gathered_payload = false; - - // Decode from key - if (!output_projection_columns[0].is_payload) { - key.Reset(); - GetKeyAndPayload(merged_partition_keys, count, key, payload_ptrs); - - decoded_key.Reset(); - key_executor.Execute(key, decoded_key); - - const auto &decoded_key_entries = StructVector::GetEntries(decoded_key.data[0]); - for (; opc_idx < output_projection_columns.size(); opc_idx++) { - const auto &opc = output_projection_columns[opc_idx]; - if (opc.is_payload) { - break; - } - chunk.data[opc.output_col_idx].Reference(*decoded_key_entries[opc.layout_col_idx]); - } - gathered_payload = true; - } - - // If there are no payload columns, we're done here - if (opc_idx != output_projection_columns.size()) { - if (!gathered_payload) { - // Gather row pointers from keys - for (idx_t i = 0; i < count; i++) { - payload_ptrs[i] = merged_partition_keys[i].GetPayload(); - } - } - - // Init scan state - auto &payload_data = *gstate.merger.sorted_runs.back()->payload_data; - if (payload_state.pin_state.properties == TupleDataPinProperties::INVALID) { - payload_data.InitializeScan(payload_state, TupleDataPinProperties::ALREADY_PINNED); - } - 
TupleDataCollection::ResetCachedCastVectors(payload_state.chunk_state, payload_state.chunk_state.column_ids); - - // Now gather from payload - for (; opc_idx < output_projection_columns.size(); opc_idx++) { - const auto &opc = output_projection_columns[opc_idx]; - D_ASSERT(opc.is_payload); - payload_data.Gather(payload_state.chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(), - count, opc.layout_col_idx, chunk.data[opc.output_col_idx], - *FlatVector::IncrementalSelectionVector(), - payload_state.chunk_state.cached_cast_vectors[opc.layout_col_idx]); - } + const auto sort_keys = FlatVector::GetData(sort_key_pointers); + for (idx_t i = 0; i < count; i++) { + sort_keys[i] = &merged_partition_keys[i]; } - merged_partition_index += count; - chunk.SetCardinality(count); + + // Scan + sorted_run_scan_state.Scan(*gstate.merger.sorted_runs[0], sort_key_pointers, count, chunk); } void SortedRunMergerLocalState::MaterializePartition(SortedRunMergerGlobalState &gstate) { @@ -812,7 +748,9 @@ void SortedRunMergerLocalState::MaterializePartition(SortedRunMergerGlobalState // Add to global state lock_guard guard(gstate.materialized_partition_lock); - gstate.materialized_partitions.resize(partition_idx.GetIndex()); + if (gstate.materialized_partitions.size() < partition_idx.GetIndex() + 1) { + gstate.materialized_partitions.resize(partition_idx.GetIndex() + 1); + } gstate.materialized_partitions[partition_idx.GetIndex()] = std::move(sorted_run); } @@ -833,7 +771,7 @@ unique_ptr SortedRunMergerLocalState::TemplatedMaterializePartition(S while (merged_partition_index < merged_partition_count) { const auto count = MinValue(merged_partition_count - merged_partition_index, STANDARD_VECTOR_SIZE); - for (idx_t i = 0; i < count + count; i++) { + for (idx_t i = 0; i < count; i++) { auto &key = merged_partition_keys[merged_partition_index + i]; key_locations[i] = data_ptr_cast(&key); if (!SORT_KEY::CONSTANT_SIZE) { @@ -855,7 +793,7 @@ unique_ptr SortedRunMergerLocalState::TemplatedMaterializePartition(S if (!sorted_run->payload_data->GetLayout().AllConstant()) { sorted_run->payload_data->FindHeapPointers(payload_data_input, count); } - sorted_run->payload_append_state.chunk_state.heap_sizes.Reference(key_data_input.heap_sizes); + sorted_run->payload_append_state.chunk_state.heap_sizes.Reference(payload_data_input.heap_sizes); sorted_run->payload_data->Build(sorted_run->payload_append_state.pin_state, sorted_run->payload_append_state.chunk_state, 0, count); sorted_run->payload_data->CopyRows(sorted_run->payload_append_state.chunk_state, payload_data_input, @@ -876,18 +814,16 @@ unique_ptr SortedRunMergerLocalState::TemplatedMaterializePartition(S //===--------------------------------------------------------------------===// // Sorted Run Merger //===--------------------------------------------------------------------===// -SortedRunMerger::SortedRunMerger(const Expression &decode_sort_key_p, shared_ptr key_layout_p, - vector> &&sorted_runs_p, - const vector &output_projection_columns_p, +SortedRunMerger::SortedRunMerger(const Sort &sort_p, vector> &&sorted_runs_p, idx_t partition_size_p, bool external_p, bool is_index_sort_p) - : decode_sort_key(decode_sort_key_p), key_layout(std::move(key_layout_p)), sorted_runs(std::move(sorted_runs_p)), - output_projection_columns(output_projection_columns_p), total_count(SortedRunsTotalCount(sorted_runs)), + : sort(sort_p), sorted_runs(std::move(sorted_runs_p)), total_count(SortedRunsTotalCount(sorted_runs)), partition_size(partition_size_p), 
external(external_p), is_index_sort(is_index_sort_p) { } unique_ptr SortedRunMerger::GetLocalSourceState(ExecutionContext &, GlobalSourceState &gstate_p) const { auto &gstate = gstate_p.Cast(); + auto guard = gstate.Lock(); return make_uniq(gstate); } @@ -929,30 +865,28 @@ ProgressData SortedRunMerger::GetProgress(ClientContext &, GlobalSourceState &gs //===--------------------------------------------------------------------===// // Non-Standard Interface //===--------------------------------------------------------------------===// -SourceResultType SortedRunMerger::MaterializeMerge(ExecutionContext &, OperatorSourceInput &input) const { +SourceResultType SortedRunMerger::MaterializeSortedRun(ExecutionContext &, OperatorSourceInput &input) const { auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); + SourceResultType res = SourceResultType::HAVE_MORE_OUTPUT; while (true) { if (!lstate.TaskFinished() || gstate.AssignTask(lstate)) { - lstate.ExecuteTask(gstate, nullptr); + res = lstate.ExecuteTask(gstate, nullptr); } else { break; } } - if (gstate.total_scanned == total_count) { - // This signals that the data has been fully materialized - return SourceResultType::FINISHED; - } - // This signals that no more tasks are left, but that the data has not yet been fully materialized - return SourceResultType::HAVE_MORE_OUTPUT; + // The thread that completes the materialization returns FINISHED, all other threads return HAVE_MORE_OUTPUT + return res; } -unique_ptr SortedRunMerger::GetMaterialized(GlobalSourceState &global_state) { +unique_ptr SortedRunMerger::GetSortedRun(GlobalSourceState &global_state) { auto &gstate = global_state.Cast(); + D_ASSERT(total_count != 0); + lock_guard guard(gstate.materialized_partition_lock); if (gstate.materialized_partitions.empty()) { - D_ASSERT(total_count == 0); return nullptr; } auto &target = *gstate.materialized_partitions[0]; @@ -963,7 +897,9 @@ unique_ptr SortedRunMerger::GetMaterialized(GlobalSourceState &global target.payload_data->Combine(*source.payload_data); } } - return std::move(gstate.materialized_partitions[0]); + auto res = std::move(gstate.materialized_partitions[0]); + gstate.materialized_partitions.clear(); + return res; } } // namespace duckdb diff --git a/src/duckdb/src/common/string_util.cpp b/src/duckdb/src/common/string_util.cpp index 51be7c3eb..504eecf39 100644 --- a/src/duckdb/src/common/string_util.cpp +++ b/src/duckdb/src/common/string_util.cpp @@ -287,9 +287,13 @@ bool StringUtil::IsUpper(const string &str) { // Jenkins hash function: https://en.wikipedia.org/wiki/Jenkins_hash_function uint64_t StringUtil::CIHash(const string &str) { + return StringUtil::CIHash(str.c_str(), str.size()); +} + +uint64_t StringUtil::CIHash(const char *str, idx_t size) { uint32_t hash = 0; - for (auto c : str) { - hash += static_cast(StringUtil::CharacterToLower(static_cast(c))); + for (idx_t i = 0; i < size; i++) { + hash += static_cast(StringUtil::CharacterToLower(static_cast(str[i]))); hash += hash << 10; hash ^= hash >> 6; } @@ -396,7 +400,10 @@ vector StringUtil::TopNStrings(vector> scores, idx_ return vector(); } sort(scores.begin(), scores.end(), [](const pair &a, const pair &b) -> bool { - return a.second > b.second || (a.second == b.second && a.first.size() < b.first.size()); + if (a.second != b.second) { + return a.second > b.second; + } + return StringUtil::CILessThan(a.first, b.first); }); vector result; result.push_back(scores[0].first); @@ -702,6 +709,21 @@ string StringUtil::ToComplexJSONMap(const 
ComplexJSON &complex_json) { return ComplexJSON::GetValueRecursive(complex_json); } +string StringUtil::ValidateJSON(const char *data, const idx_t &len) { + // Same flags as in JSON extension + static constexpr auto READ_FLAG = + YYJSON_READ_ALLOW_INF_AND_NAN | YYJSON_READ_ALLOW_TRAILING_COMMAS | YYJSON_READ_BIGNUM_AS_RAW; + yyjson_read_err error; + yyjson_doc *doc = yyjson_read_opts((char *)data, len, READ_FLAG, nullptr, &error); // NOLINT: for yyjson + if (error.code != YYJSON_READ_SUCCESS) { + return StringUtil::Format("Malformed JSON at byte %lld of input: %s. Input: \"%s\"", error.pos, error.msg, + string(data, len)); + } + + yyjson_doc_free(doc); + return string(); +} + string StringUtil::ExceptionToJSONMap(ExceptionType type, const string &message, const unordered_map &map) { D_ASSERT(map.find("exception_type") == map.end()); @@ -719,7 +741,6 @@ string StringUtil::ExceptionToJSONMap(ExceptionType type, const string &message, } string StringUtil::GetFileName(const string &file_path) { - idx_t pos = file_path.find_last_of("/\\"); if (pos == string::npos) { return file_path; diff --git a/src/duckdb/src/common/tree_renderer.cpp b/src/duckdb/src/common/tree_renderer.cpp index c8d97959b..c7a810468 100644 --- a/src/duckdb/src/common/tree_renderer.cpp +++ b/src/duckdb/src/common/tree_renderer.cpp @@ -4,6 +4,7 @@ #include "duckdb/common/tree_renderer/html_tree_renderer.hpp" #include "duckdb/common/tree_renderer/graphviz_tree_renderer.hpp" #include "duckdb/common/tree_renderer/yaml_tree_renderer.hpp" +#include "duckdb/common/tree_renderer/mermaid_tree_renderer.hpp" #include @@ -22,6 +23,8 @@ unique_ptr TreeRenderer::CreateRenderer(ExplainFormat format) { return make_uniq(); case ExplainFormat::YAML: return make_uniq(); + case ExplainFormat::MERMAID: + return make_uniq(); default: throw NotImplementedException("ExplainFormat %s not implemented", EnumUtil::ToString(format)); } diff --git a/src/duckdb/src/common/tree_renderer/mermaid_tree_renderer.cpp b/src/duckdb/src/common/tree_renderer/mermaid_tree_renderer.cpp new file mode 100644 index 000000000..9ff7b6539 --- /dev/null +++ b/src/duckdb/src/common/tree_renderer/mermaid_tree_renderer.cpp @@ -0,0 +1,133 @@ +#include "duckdb/common/tree_renderer/mermaid_tree_renderer.hpp" + +#include "duckdb/common/pair.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" +#include "duckdb/execution/operator/join/physical_delim_join.hpp" +#include "duckdb/execution/operator/scan/physical_positional_scan.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/main/query_profiler.hpp" +#include "utf8proc_wrapper.hpp" + +#include + +namespace duckdb { + +string MermaidTreeRenderer::ToString(const LogicalOperator &op) { + duckdb::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string MermaidTreeRenderer::ToString(const PhysicalOperator &op) { + duckdb::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string MermaidTreeRenderer::ToString(const ProfilingNode &op) { + duckdb::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string MermaidTreeRenderer::ToString(const Pipeline &op) { + duckdb::stringstream ss; + Render(op, ss); + return ss.str(); +} + +void MermaidTreeRenderer::Render(const LogicalOperator &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void MermaidTreeRenderer::Render(const PhysicalOperator 
&op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void MermaidTreeRenderer::Render(const ProfilingNode &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void MermaidTreeRenderer::Render(const Pipeline &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +static string SanitizeMermaidLabel(const string &text) { + string result; + result.reserve(text.size() * 2); // Reserve more space for potential escape sequences + for (size_t i = 0; i < text.size(); i++) { + char c = text[i]; + // Escape backticks and quotes + if (c == '`') { + result += "\\`"; + } else if (c == '"') { + result += "\\\""; + } else if (c == '\\' && i + 1 < text.size() && text[i + 1] == 'n') { + // Replace literal "\n" with actual newline for Mermaid markdown + result += "\n\t"; + i++; // Skip the 'n' + } else { + result += c; + } + } + return result; +} + +void MermaidTreeRenderer::ToStreamInternal(RenderTree &root, std::ostream &ss) { + vector nodes; + vector edges; + + const string node_format = " node_%d_%d[\"`**%s**%s`\"]"; + + for (idx_t y = 0; y < root.height; y++) { + for (idx_t x = 0; x < root.width; x++) { + auto node = root.GetNode(x, y); + if (!node) { + continue; + } + + // Build node label with markdown formatting + string extra_info; + for (auto &item : node->extra_text) { + auto &key = item.first; + auto &value_raw = item.second; + + auto value = QueryProfiler::JSONSanitize(value_raw); + // Add newline and key-value pair + extra_info += StringUtil::Format("\n\t%s: %s", key, SanitizeMermaidLabel(value)); + } + + // Create node with bold operator name and extra info (trim name to remove trailing spaces) + auto trimmed_name = node->name; + StringUtil::Trim(trimmed_name); + nodes.push_back(StringUtil::Format(node_format, x, y, SanitizeMermaidLabel(trimmed_name), extra_info)); + + // Create Edge(s) + for (auto &coord : node->child_positions) { + edges.push_back(StringUtil::Format(" node_%d_%d --> node_%d_%d", x, y, coord.x, coord.y)); + } + } + } + + // Output Mermaid flowchart + ss << "flowchart TD\n"; + + // Output nodes + for (auto &node : nodes) { + ss << node << "\n\n"; + } + + // Output edges + for (auto &edge : edges) { + ss << edge << "\n"; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types.cpp b/src/duckdb/src/common/types.cpp index 40ff794e6..6542fafc2 100644 --- a/src/duckdb/src/common/types.cpp +++ b/src/duckdb/src/common/types.cpp @@ -31,6 +31,9 @@ namespace duckdb { +constexpr idx_t ArrayType::MAX_ARRAY_SIZE; +const idx_t UnionType::MAX_UNION_MEMBERS; + LogicalType::LogicalType() : LogicalType(LogicalTypeId::INVALID) { } @@ -159,6 +162,8 @@ PhysicalType LogicalType::GetInternalType() { return PhysicalType::UNKNOWN; case LogicalTypeId::AGGREGATE_STATE: return PhysicalType::VARCHAR; + case LogicalTypeId::GEOMETRY: + return PhysicalType::VARCHAR; default: throw InternalException("Invalid LogicalType %s", ToString()); } @@ -1344,6 +1349,8 @@ static idx_t GetLogicalTypeScore(const LogicalType &type) { return 102; case LogicalTypeId::BIGNUM: return 103; + case LogicalTypeId::GEOMETRY: + return 104; // nested types case LogicalTypeId::STRUCT: return 125; @@ -2014,6 +2021,15 @@ LogicalType LogicalType::VARIANT() { return LogicalType(LogicalTypeId::VARIANT, std::move(info)); } +//===--------------------------------------------------------------------===// +// Spatial Types 
+//===--------------------------------------------------------------------===// + +LogicalType LogicalType::GEOMETRY() { + auto info = make_shared_ptr(); + return LogicalType(LogicalTypeId::GEOMETRY, std::move(info)); +} + //===--------------------------------------------------------------------===// // Logical Type //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/common/types/batched_data_collection.cpp b/src/duckdb/src/common/types/batched_data_collection.cpp index fd25dbc1a..6f38c098c 100644 --- a/src/duckdb/src/common/types/batched_data_collection.cpp +++ b/src/duckdb/src/common/types/batched_data_collection.cpp @@ -2,18 +2,47 @@ #include "duckdb/common/optional_ptr.hpp" #include "duckdb/common/printer.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/storage/buffer_manager.hpp" namespace duckdb { BatchedDataCollection::BatchedDataCollection(ClientContext &context_p, vector types_p, - bool buffer_managed_p) - : context(context_p), types(std::move(types_p)), buffer_managed(buffer_managed_p) { + ColumnDataAllocatorType allocator_type_p, + ColumnDataCollectionLifetime lifetime_p) + : context(context_p), types(std::move(types_p)), allocator_type(allocator_type_p), lifetime(lifetime_p) { +} + +BatchedDataCollection::BatchedDataCollection(ClientContext &context, vector types, + QueryResultMemoryType memory_type) + : BatchedDataCollection(context, std::move(types), + memory_type == QueryResultMemoryType::BUFFER_MANAGED + ? ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR + : ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR, + memory_type == QueryResultMemoryType::BUFFER_MANAGED + ? ColumnDataCollectionLifetime::THROW_ERROR_AFTER_DATABASE_CLOSES + : ColumnDataCollectionLifetime::REGULAR) { } BatchedDataCollection::BatchedDataCollection(ClientContext &context_p, vector types_p, batch_map_t batches, - bool buffer_managed_p) - : context(context_p), types(std::move(types_p)), buffer_managed(buffer_managed_p), data(std::move(batches)) { + ColumnDataAllocatorType allocator_type_p, + ColumnDataCollectionLifetime lifetime_p) + : context(context_p), types(std::move(types_p)), allocator_type(allocator_type_p), lifetime(lifetime_p), + data(std::move(batches)) { +} + +unique_ptr BatchedDataCollection::CreateCollection() const { + if (last_collection.collection) { + return make_uniq(*last_collection.collection); + } else if (allocator_type == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR) { + auto &buffer_manager = lifetime == ColumnDataCollectionLifetime::REGULAR + ? 
BufferManager::GetBufferManager(context) + : BufferManager::GetBufferManager(*context.db); + return make_uniq(buffer_manager, types, lifetime); + } else { + D_ASSERT(allocator_type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR); + return make_uniq(Allocator::DefaultAllocator(), types); + } } void BatchedDataCollection::Append(DataChunk &input, idx_t batch_index) { @@ -25,14 +54,7 @@ void BatchedDataCollection::Append(DataChunk &input, idx_t batch_index) { } else { // new collection: check if there is already an entry D_ASSERT(data.find(batch_index) == data.end()); - unique_ptr new_collection; - if (last_collection.collection) { - new_collection = make_uniq(*last_collection.collection); - } else if (buffer_managed) { - new_collection = make_uniq(BufferManager::GetBufferManager(context), types); - } else { - new_collection = make_uniq(Allocator::DefaultAllocator(), types); - } + unique_ptr new_collection = CreateCollection(); last_collection.collection = new_collection.get(); last_collection.batch_index = batch_index; new_collection->InitializeAppend(last_collection.append_state); @@ -98,7 +120,7 @@ unique_ptr BatchedDataCollection::FetchCollection() { data.clear(); if (!result) { // empty result - return make_uniq(Allocator::DefaultAllocator(), types); + return CreateCollection(); } return result; } diff --git a/src/duckdb/src/common/types/column/column_data_allocator.cpp b/src/duckdb/src/common/types/column/column_data_allocator.cpp index b4f3f4d74..b0fefb32e 100644 --- a/src/duckdb/src/common/types/column/column_data_allocator.cpp +++ b/src/duckdb/src/common/types/column/column_data_allocator.cpp @@ -2,6 +2,8 @@ #include "duckdb/common/radix_partitioning.hpp" #include "duckdb/common/types/column/column_data_collection_segment.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/result_set_manager.hpp" #include "duckdb/storage/buffer/block_handle.hpp" #include "duckdb/storage/buffer/buffer_pool.hpp" #include "duckdb/storage/buffer_manager.hpp" @@ -12,17 +14,24 @@ ColumnDataAllocator::ColumnDataAllocator(Allocator &allocator) : type(ColumnData alloc.allocator = &allocator; } -ColumnDataAllocator::ColumnDataAllocator(BufferManager &buffer_manager) +ColumnDataAllocator::ColumnDataAllocator(BufferManager &buffer_manager, ColumnDataCollectionLifetime lifetime) : type(ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR) { alloc.buffer_manager = &buffer_manager; + if (lifetime == ColumnDataCollectionLifetime::THROW_ERROR_AFTER_DATABASE_CLOSES) { + managed_result_set = ResultSetManager::Get(buffer_manager.GetDatabase()).Add(*this); + } } -ColumnDataAllocator::ColumnDataAllocator(ClientContext &context, ColumnDataAllocatorType allocator_type) +ColumnDataAllocator::ColumnDataAllocator(ClientContext &context, ColumnDataAllocatorType allocator_type, + ColumnDataCollectionLifetime lifetime) : type(allocator_type) { switch (type) { case ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR: case ColumnDataAllocatorType::HYBRID: alloc.buffer_manager = &BufferManager::GetBufferManager(context); + if (lifetime == ColumnDataCollectionLifetime::THROW_ERROR_AFTER_DATABASE_CLOSES) { + managed_result_set = ResultSetManager::Get(context).Add(*this); + } break; case ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR: alloc.allocator = &Allocator::Get(context); @@ -38,6 +47,9 @@ ColumnDataAllocator::ColumnDataAllocator(ColumnDataAllocator &other) { case ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR: case ColumnDataAllocatorType::HYBRID: alloc.buffer_manager = other.alloc.buffer_manager; + if 
(other.managed_result_set.IsValid()) { + ResultSetManager::Get(alloc.buffer_manager->GetDatabase()).Add(*this); + } break; case ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR: alloc.allocator = other.alloc.allocator; @@ -51,8 +63,16 @@ ColumnDataAllocator::~ColumnDataAllocator() { if (type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR) { return; } + if (managed_result_set.IsValid()) { + D_ASSERT(type != ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR); + auto db = managed_result_set.GetDatabase(); + if (db) { + ResultSetManager::Get(*db).Remove(*this); + } + return; + } for (auto &block : blocks) { - block.handle->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); + block.GetHandle()->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); } blocks.clear(); } @@ -64,9 +84,9 @@ BufferHandle ColumnDataAllocator::Pin(uint32_t block_id) { // we only need to grab the lock when accessing the vector, because vector access is not thread-safe: // the vector can be resized by another thread while we try to access it lock_guard guard(lock); - handle = blocks[block_id].handle; + handle = blocks[block_id].GetHandle(); } else { - handle = blocks[block_id].handle; + handle = blocks[block_id].GetHandle(); } return alloc.buffer_manager->Pin(handle); } @@ -78,10 +98,10 @@ BufferHandle ColumnDataAllocator::AllocateBlock(idx_t size) { data.size = 0; data.capacity = NumericCast(max_size); auto pin = alloc.buffer_manager->Allocate(MemoryTag::COLUMN_DATA, max_size, false); - data.handle = pin.GetBlockHandle(); + data.SetHandle(managed_result_set, pin.GetBlockHandle()); blocks.push_back(std::move(data)); if (partition_index.IsValid()) { // Set the eviction queue index logarithmically using RadixBits - blocks.back().handle->SetEvictionQueueIndex(RadixPartitioning::RadixBits(partition_index.GetIndex())); + blocks.back().GetHandle()->SetEvictionQueueIndex(RadixPartitioning::RadixBits(partition_index.GetIndex())); } allocated_size += max_size; return pin; @@ -98,7 +118,6 @@ void ColumnDataAllocator::AllocateEmptyBlock(idx_t size) { BlockMetaData data; data.size = 0; data.capacity = NumericCast(allocation_amount); - data.handle = nullptr; blocks.push_back(std::move(data)); allocated_size += allocation_amount; } @@ -131,7 +150,8 @@ void ColumnDataAllocator::AllocateBuffer(idx_t size, uint32_t &block_id, uint32_ block_id = NumericCast(blocks.size() - 1); if (chunk_state && chunk_state->handles.find(block_id) == chunk_state->handles.end()) { // not guaranteed to be pinned already by this thread (if shared allocator) - chunk_state->handles[block_id] = alloc.buffer_manager->Pin(blocks[block_id].handle); + auto handle = blocks[block_id].GetHandle(); + chunk_state->handles[block_id] = alloc.buffer_manager->Pin(handle); } offset = block.size; block.size += size; @@ -235,7 +255,18 @@ void ColumnDataAllocator::UnswizzlePointers(ChunkManagementState &state, Vector } void ColumnDataAllocator::SetDestroyBufferUponUnpin(uint32_t block_id) { - blocks[block_id].handle->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); + blocks[block_id].GetHandle()->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); +} + +shared_ptr ColumnDataAllocator::GetDatabase() const { + if (!managed_result_set.IsValid()) { + return nullptr; + } + auto db = managed_result_set.GetDatabase(); + if (!db) { + throw ConnectionException("Trying to access a query result after the database instance has been closed"); + } + return db; } Allocator &ColumnDataAllocator::GetAllocator() { @@ -282,6 +313,26 @@ void ColumnDataAllocator::InitializeChunkState(ChunkManagementState &state, Chun } } 
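The allocator changes above implement a simple ownership inversion for buffer-managed results: when a collection is created with ColumnDataCollectionLifetime::THROW_ERROR_AFTER_DATABASE_CLOSES, the allocator registers itself via ResultSetManager::Get(...).Add(), the strong block handles are parked in the resulting ManagedResultSet, BlockMetaData keeps only a weak_handle, and GetHandle()/GetDatabase() turn an expired reference into a ConnectionException instead of a dangling pointer. The standalone sketch below illustrates that ownership pattern with plain std::shared_ptr/std::weak_ptr; Registry, Result and Block are illustrative stand-ins, not DuckDB classes.

    #include <cstddef>
    #include <iostream>
    #include <memory>
    #include <stdexcept>
    #include <vector>

    // Stand-in for a pinned buffer block owned by the database side.
    struct Block {
        std::vector<char> data;
    };

    // Stand-in for the manager that owns the strong references and drops them on shutdown.
    class Registry {
    public:
        std::shared_ptr<Block> Allocate(std::size_t size) {
            auto block = std::make_shared<Block>();
            block->data.resize(size);
            blocks.push_back(block); // strong reference lives on the database side
            return block;
        }
        void Shutdown() {
            blocks.clear(); // releases all blocks, like closing the database
        }

    private:
        std::vector<std::shared_ptr<Block>> blocks;
    };

    // Stand-in for a materialized result: keeps only weak references to the blocks.
    class Result {
    public:
        void AddBlock(const std::shared_ptr<Block> &block) {
            weak_blocks.push_back(block);
        }
        std::shared_ptr<Block> GetBlock(std::size_t i) const {
            auto strong = weak_blocks[i].lock();
            if (!strong) {
                // mirrors the ConnectionException thrown by GetHandle()/GetDatabase() above
                throw std::runtime_error(
                    "Trying to access a query result after the database instance has been closed");
            }
            return strong;
        }

    private:
        std::vector<std::weak_ptr<Block>> weak_blocks;
    };

    int main() {
        Registry registry;
        Result result;
        result.AddBlock(registry.Allocate(1024));
        std::cout << "before shutdown: " << result.GetBlock(0)->data.size() << " bytes\n";
        registry.Shutdown();
        try {
            result.GetBlock(0);
        } catch (const std::exception &ex) {
            std::cout << "after shutdown: " << ex.what() << "\n";
        }
        return 0;
    }

The trade-off is the same as in the real code: the result stays cheap to hand around, but every access after the owning side has shut down goes through a lock() that can fail, which is what makes the "query result used after database close" case detectable rather than undefined behaviour.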
+shared_ptr BlockMetaData::GetHandle() const { + if (handle) { + return handle; + } + auto res = weak_handle.lock(); + if (!res) { + throw ConnectionException("Trying to access a query result after the database instance has been closed"); + } + return res; +} + +void BlockMetaData::SetHandle(ManagedResultSet &managed_result_set, shared_ptr handle_p) { + if (managed_result_set.IsValid()) { + managed_result_set.GetHandles().emplace_back(handle_p); + weak_handle = handle_p; + } else { + handle = std::move(handle_p); + } +} + uint32_t BlockMetaData::Capacity() { D_ASSERT(size <= capacity); return capacity - size; diff --git a/src/duckdb/src/common/types/column/column_data_collection.cpp b/src/duckdb/src/common/types/column/column_data_collection.cpp index b53e07d68..6555f8a96 100644 --- a/src/duckdb/src/common/types/column/column_data_collection.cpp +++ b/src/duckdb/src/common/types/column/column_data_collection.cpp @@ -8,6 +8,7 @@ #include "duckdb/common/types/value_map.hpp" #include "duckdb/common/uhugeint.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/main/database.hpp" #include "duckdb/storage/buffer_manager.hpp" namespace duckdb { @@ -59,9 +60,10 @@ ColumnDataCollection::ColumnDataCollection(Allocator &allocator_p, vector(allocator_p); } -ColumnDataCollection::ColumnDataCollection(BufferManager &buffer_manager, vector types_p) { +ColumnDataCollection::ColumnDataCollection(BufferManager &buffer_manager, vector types_p, + ColumnDataCollectionLifetime lifetime) { Initialize(std::move(types_p)); - allocator = make_shared_ptr(buffer_manager); + allocator = make_shared_ptr(buffer_manager, lifetime); } ColumnDataCollection::ColumnDataCollection(shared_ptr allocator_p, vector types_p) { @@ -70,8 +72,8 @@ ColumnDataCollection::ColumnDataCollection(shared_ptr alloc } ColumnDataCollection::ColumnDataCollection(ClientContext &context, vector types_p, - ColumnDataAllocatorType type) - : ColumnDataCollection(make_shared_ptr(context, type), std::move(types_p)) { + ColumnDataAllocatorType type, ColumnDataCollectionLifetime lifetime) + : ColumnDataCollection(make_shared_ptr(context, type, lifetime), std::move(types_p)) { D_ASSERT(!types.empty()); } @@ -146,16 +148,22 @@ idx_t ColumnDataRow::RowIndex() const { //===--------------------------------------------------------------------===// // ColumnDataRowCollection //===--------------------------------------------------------------------===// -ColumnDataRowCollection::ColumnDataRowCollection(const ColumnDataCollection &collection) { +ColumnDataRowCollection::ColumnDataRowCollection(const ColumnDataCollection &collection, + const ColumnDataScanProperties properties) { if (collection.Count() == 0) { return; } // read all the chunks ColumnDataScanState temp_scan_state; - collection.InitializeScan(temp_scan_state, ColumnDataScanProperties::DISALLOW_ZERO_COPY); + collection.InitializeScan(temp_scan_state, properties); while (true) { auto chunk = make_uniq(); - collection.InitializeScanChunk(*chunk); + // Use default allocator so the chunk is independently usable even after the DB allocator is destroyed + if (properties == ColumnDataScanProperties::DISALLOW_ZERO_COPY) { + collection.InitializeScanChunk(Allocator::DefaultAllocator(), *chunk); + } else { + collection.InitializeScanChunk(*chunk); + } if (!collection.Scan(temp_scan_state, *chunk)) { break; } @@ -252,12 +260,13 @@ ColumnDataRowIterationHelper::ColumnDataRowIterationHelper(const ColumnDataColle : collection(collection_p) { } 
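The new constructors above never take the lifetime in isolation: the delegating BatchedDataCollection constructor derives both the allocator type and the lifetime from a single QueryResultMemoryType, and ColumnDataCollection simply forwards the pair on to ColumnDataAllocator. A minimal standalone restatement of that mapping follows; the enum, struct and function names are illustrative stand-ins, not the DuckDB definitions.

    #include <iostream>

    enum class MemoryType { IN_MEMORY, BUFFER_MANAGED };
    enum class AllocatorType { IN_MEMORY_ALLOCATOR, BUFFER_MANAGER_ALLOCATOR };
    enum class Lifetime { REGULAR, THROW_ERROR_AFTER_DATABASE_CLOSES };

    struct CollectionSettings {
        AllocatorType allocator;
        Lifetime lifetime;
    };

    // Mirrors the two ternaries in the delegating constructor above: buffer-managed
    // results use the buffer manager and must error out once the database is gone;
    // plain in-memory results use the default allocator with no such restriction.
    CollectionSettings Map(MemoryType type) {
        if (type == MemoryType::BUFFER_MANAGED) {
            return {AllocatorType::BUFFER_MANAGER_ALLOCATOR, Lifetime::THROW_ERROR_AFTER_DATABASE_CLOSES};
        }
        return {AllocatorType::IN_MEMORY_ALLOCATOR, Lifetime::REGULAR};
    }

    int main() {
        auto settings = Map(MemoryType::BUFFER_MANAGED);
        std::cout << (settings.allocator == AllocatorType::BUFFER_MANAGER_ALLOCATOR) << " "
                  << (settings.lifetime == Lifetime::THROW_ERROR_AFTER_DATABASE_CLOSES) << "\n";
        return 0;
    }

In other words, only buffer-managed results opt in to the "throw after the database closes" behaviour; in-memory results keep the default-allocator path and need no registration.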
-ColumnDataRowIterationHelper::ColumnDataRowIterator::ColumnDataRowIterator(const ColumnDataCollection *collection_p) +ColumnDataRowIterationHelper::ColumnDataRowIterator::ColumnDataRowIterator(const ColumnDataCollection *collection_p, + ColumnDataScanProperties properties) : collection(collection_p), scan_chunk(make_shared_ptr()), current_row(*scan_chunk, 0, 0) { if (!collection) { return; } - collection->InitializeScan(scan_state); + collection->InitializeScan(scan_state, properties); collection->InitializeScanChunk(*scan_chunk); collection->Scan(scan_state, *scan_chunk); } @@ -593,7 +602,6 @@ bool ColumnDataCopyCompressedStrings(ColumnDataMetaData &meta_data, const Vector template <> void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, idx_t offset, idx_t copy_count) { - const auto &allocator_type = meta_data.segment.allocator->GetType(); if (allocator_type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR || allocator_type == ColumnDataAllocatorType::HYBRID) { @@ -733,7 +741,6 @@ void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVector template <> void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, idx_t offset, idx_t copy_count) { - auto &segment = meta_data.segment; auto &child_vector = ListVector::GetEntry(source); @@ -813,7 +820,6 @@ void ColumnDataCopyStruct(ColumnDataMetaData &meta_data, const UnifiedVectorForm void ColumnDataCopyArray(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, idx_t offset, idx_t copy_count) { - auto &segment = meta_data.segment; // copy the NULL values for the main array vector (the same as for a struct vector) @@ -842,7 +848,8 @@ void ColumnDataCopyArray(ColumnDataMetaData &meta_data, const UnifiedVectorForma child_vector.ToUnifiedFormat(copy_count * array_size, child_vector_data); // Broadcast and sync the validity of the array vector to the child vector - + // This requires creating a copy of the validity mask: we cannot modify the input validity + child_vector_data.validity = ValidityMask(child_vector_data.validity, child_vector_data.validity.Capacity()); if (source_data.validity.IsMaskSet()) { for (idx_t i = 0; i < copy_count; i++) { auto source_idx = source_data.sel->get_index(offset + i); @@ -1015,6 +1022,7 @@ void ColumnDataCollection::InitializeScan(ColumnDataScanState &state, ColumnData void ColumnDataCollection::InitializeScan(ColumnDataScanState &state, vector column_ids, ColumnDataScanProperties properties) const { + state.db = allocator->GetDatabase(); state.chunk_index = 0; state.segment_index = 0; state.current_row_index = 0; @@ -1052,7 +1060,11 @@ bool ColumnDataCollection::Scan(ColumnDataParallelScanState &state, ColumnDataLo } void ColumnDataCollection::InitializeScanChunk(DataChunk &chunk) const { - chunk.Initialize(allocator->GetAllocator(), types); + InitializeScanChunk(allocator->GetAllocator(), chunk); +} + +void ColumnDataCollection::InitializeScanChunk(Allocator &allocator, DataChunk &chunk) const { + chunk.Initialize(allocator, types); } void ColumnDataCollection::InitializeScanChunk(ColumnDataScanState &state, DataChunk &chunk) const { @@ -1354,6 +1366,11 @@ ColumnDataAllocatorType ColumnDataCollection::GetAllocatorType() const { return allocator->GetType(); } +BufferManager &ColumnDataCollection::GetBufferManager() const { + D_ASSERT(allocator->GetType() == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR); + return allocator->GetBufferManager(); +} + const vector> 
&ColumnDataCollection::GetSegments() const { return segments; } diff --git a/src/duckdb/src/common/types/data_chunk.cpp b/src/duckdb/src/common/types/data_chunk.cpp index 59e7faba7..6b216b185 100644 --- a/src/duckdb/src/common/types/data_chunk.cpp +++ b/src/duckdb/src/common/types/data_chunk.cpp @@ -254,7 +254,6 @@ string DataChunk::ToString() const { } void DataChunk::Serialize(Serializer &serializer, bool compressed_serialization) const { - // write the count auto row_count = size(); serializer.WriteProperty(100, "rows", NumericCast(row_count)); @@ -279,7 +278,6 @@ void DataChunk::Serialize(Serializer &serializer, bool compressed_serialization) } void DataChunk::Deserialize(Deserializer &deserializer) { - // read and set the row count auto row_count = deserializer.ReadProperty(100, "rows"); diff --git a/src/duckdb/src/common/types/decimal.cpp b/src/duckdb/src/common/types/decimal.cpp index 5ecb39a0a..8fa226455 100644 --- a/src/duckdb/src/common/types/decimal.cpp +++ b/src/duckdb/src/common/types/decimal.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/types/cast_helpers.hpp" namespace duckdb { +constexpr uint8_t Decimal::MAX_WIDTH_DECIMAL; template string TemplatedDecimalToString(SIGNED value, uint8_t width, uint8_t scale) { diff --git a/src/duckdb/src/common/types/geometry.cpp b/src/duckdb/src/common/types/geometry.cpp new file mode 100644 index 000000000..fc73b362c --- /dev/null +++ b/src/duckdb/src/common/types/geometry.cpp @@ -0,0 +1,1139 @@ +#include "duckdb/common/types/geometry.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "fast_float/fast_float.h" +#include "fmt/format.h" + +//---------------------------------------------------------------------------------------------------------------------- +// Internals +//---------------------------------------------------------------------------------------------------------------------- +namespace duckdb { + +namespace { + +class BlobWriter { +public: + template + void Write(const T &value) { + auto ptr = reinterpret_cast(&value); + buffer.insert(buffer.end(), ptr, ptr + sizeof(T)); + } + + template + struct Reserved { + size_t offset; + T value; + }; + + template + Reserved Reserve() { + auto offset = buffer.size(); + buffer.resize(buffer.size() + sizeof(T)); + return {offset, T()}; + } + + template + void Write(const Reserved &reserved) { + if (reserved.offset + sizeof(T) > buffer.size()) { + throw InternalException("Write out of bounds in BinaryWriter"); + } + auto ptr = reinterpret_cast(&reserved.value); + // We've reserved 0 bytes, so we can safely memcpy + memcpy(buffer.data() + reserved.offset, ptr, sizeof(T)); + } + + void Write(const char *data, size_t size) { + D_ASSERT(data != nullptr); + buffer.insert(buffer.end(), data, data + size); + } + + const vector &GetBuffer() const { + return buffer; + } + + void Clear() { + buffer.clear(); + } + +private: + vector buffer; +}; + +class FixedSizeBlobWriter { +public: + FixedSizeBlobWriter(char *data, uint32_t size) : beg(data), pos(data), end(data + size) { + } + + template + void Write(const T &value) { + if (pos + sizeof(T) > end) { + throw InvalidInputException("Writing beyond end of binary data at position %zu", pos - beg); + } + memcpy(pos, &value, sizeof(T)); + pos += sizeof(T); + } + + void Write(const char *data, size_t size) { + if (pos + size > end) { + throw InvalidInputException("Writing beyond end of binary data at position %zu", pos - beg); + } + 
memcpy(pos, data, size); + pos += size; + } + + size_t GetPosition() const { + return static_cast(pos - beg); + } + +private: + const char *beg; + char *pos; + const char *end; +}; + +class BlobReader { +public: + BlobReader(const char *data, uint32_t size) : beg(data), pos(data), end(data + size) { + } + + template + T Read(const bool le) { + if (le) { + return Read(); + } else { + return Read(); + } + } + + template + T Read() { + if (pos + sizeof(T) > end) { + throw InvalidInputException("Unexpected end of binary data at position %zu", pos - beg); + } + T value; + if (LE) { + memcpy(&value, pos, sizeof(T)); + pos += sizeof(T); + } else { + char temp[sizeof(T)]; + for (size_t i = 0; i < sizeof(T); ++i) { + temp[i] = pos[sizeof(T) - 1 - i]; + } + memcpy(&value, temp, sizeof(T)); + pos += sizeof(T); + } + return value; + } + + void Skip(size_t size) { + if (pos + size > end) { + throw InvalidInputException("Skipping beyond end of binary data at position %zu", pos - beg); + } + pos += size; + } + + const char *Reserve(size_t size) { + if (pos + size > end) { + throw InvalidInputException("Reserving beyond end of binary data at position %zu", pos - beg); + } + auto current_pos = pos; + pos += size; + return current_pos; + } + + size_t GetPosition() const { + return static_cast(pos - beg); + } + + const char *GetDataPtr() const { + return pos; + } + + bool IsAtEnd() const { + return pos >= end; + } + + void Reset() { + pos = beg; + } + +private: + const char *beg; + const char *pos; + const char *end; +}; + +class TextWriter { +public: + void Write(const char *str) { + buffer.insert(buffer.end(), str, str + strlen(str)); + } + void Write(char c) { + buffer.push_back(c); + } + void Write(double value) { + duckdb_fmt::format_to(std::back_inserter(buffer), "{}", value); + // Remove trailing zero + if (buffer.back() == '0') { + buffer.pop_back(); + if (buffer.back() == '.') { + buffer.pop_back(); + } + } + } + const vector &GetBuffer() const { + return buffer; + } + +private: + vector buffer; +}; + +class TextReader { +public: + TextReader(const char *text, const uint32_t size) : beg(text), pos(text), end(text + size) { + } + + bool TryMatch(const char *str) { + auto ptr = pos; + while (*str && pos < end && tolower(*pos) == tolower(*str)) { + pos++; + str++; + } + if (*str == '\0') { + SkipWhitespace(); // remove trailing whitespace + return true; // matched + } + pos = ptr; // reset position + return false; // not matched + } + + bool TryMatch(char c) { + if (pos < end && tolower(*pos) == tolower(c)) { + pos++; + SkipWhitespace(); // remove trailing whitespace + return true; // matched + } + return false; // not matched + } + + void Match(const char *str) { + if (!TryMatch(str)) { + throw InvalidInputException("Expected '%s' but got '%c' at position %zu", str, *pos, pos - beg); + } + } + + void Match(char c) { + if (!TryMatch(c)) { + throw InvalidInputException("Expected '%c' but got '%c' at position %zu", c, *pos, pos - beg); + } + } + + double MatchNumber() { + // Now use fast_float to parse the number + double num; + const auto res = duckdb_fast_float::from_chars(pos, end, num); + if (res.ec != std::errc()) { + throw InvalidInputException("Expected number at position %zu", pos - beg); + } + + pos = res.ptr; // update position to the end of the parsed number + + SkipWhitespace(); // remove trailing whitespace + return num; // return the parsed number + } + + idx_t GetPosition() const { + return static_cast(pos - beg); + } + + void Reset() { + pos = beg; + } + +private: + void SkipWhitespace() 
{ + while (pos < end && isspace(*pos)) { + pos++; + } + } + + const char *beg; + const char *pos; + const char *end; +}; + +void FromStringRecursive(TextReader &reader, BlobWriter &writer, uint32_t depth, bool parent_has_z, bool parent_has_m) { + if (depth == Geometry::MAX_RECURSION_DEPTH) { + throw InvalidInputException("Geometry string exceeds maximum recursion depth of %d", + Geometry::MAX_RECURSION_DEPTH); + } + + GeometryType type; + + if (reader.TryMatch("point")) { + type = GeometryType::POINT; + } else if (reader.TryMatch("linestring")) { + type = GeometryType::LINESTRING; + } else if (reader.TryMatch("polygon")) { + type = GeometryType::POLYGON; + } else if (reader.TryMatch("multipoint")) { + type = GeometryType::MULTIPOINT; + } else if (reader.TryMatch("multilinestring")) { + type = GeometryType::MULTILINESTRING; + } else if (reader.TryMatch("multipolygon")) { + type = GeometryType::MULTIPOLYGON; + } else if (reader.TryMatch("geometrycollection")) { + type = GeometryType::GEOMETRYCOLLECTION; + } else { + throw InvalidInputException("Unknown geometry type at position %zu", reader.GetPosition()); + } + + const auto has_z = reader.TryMatch("z"); + const auto has_m = reader.TryMatch("m"); + + const auto is_empty = reader.TryMatch("empty"); + + if ((depth != 0) && ((parent_has_z != has_z) || (parent_has_m != has_m))) { + throw InvalidInputException("Geometry has inconsistent Z/M dimensions, starting at position %zu", + reader.GetPosition()); + } + + // How many dimensions does this geometry have? + const uint32_t dims = 2 + (has_z ? 1 : 0) + (has_m ? 1 : 0); + + // WKB type + const auto meta = static_cast(type) + (has_z ? 1000 : 0) + (has_m ? 2000 : 0); + // Write the geometry type and vertex type + writer.Write(1); // LE Byte Order + writer.Write(meta); + + switch (type) { + case GeometryType::POINT: { + if (is_empty) { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + // Write NaN for each dimension, if point is empty + writer.Write(std::numeric_limits::quiet_NaN()); + } + } else { + reader.Match('('); + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + reader.Match(')'); + } + } break; + case GeometryType::LINESTRING: { + if (is_empty) { + writer.Write(0); // No vertices in empty linestring + break; + } + auto vert_count = writer.Reserve(); + reader.Match('('); + do { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + vert_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(vert_count); + } break; + case GeometryType::POLYGON: { + if (is_empty) { + writer.Write(0); + break; // No rings in empty polygon + } + auto ring_count = writer.Reserve(); + reader.Match('('); + do { + auto vert_count = writer.Reserve(); + reader.Match('('); + do { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + vert_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(vert_count); + ring_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(ring_count); + } break; + case GeometryType::MULTIPOINT: { + if (is_empty) { + writer.Write(0); // No points in empty multipoint + break; + } + auto part_count = writer.Reserve(); + reader.Match('('); + do { + bool has_paren = reader.TryMatch('('); + + const auto part_meta = static_cast(GeometryType::POINT) + (has_z ? 1000 : 0) + (has_m ? 
2000 : 0); + writer.Write(1); + writer.Write(part_meta); + + if (reader.TryMatch("EMPTY")) { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + // Write NaN for each dimension, if point is empty + writer.Write(std::numeric_limits::quiet_NaN()); + } + } else { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + } + if (has_paren) { + reader.Match(')'); // Match the closing parenthesis if it was opened + } + part_count.value++; + } while (reader.TryMatch(',')); + writer.Write(part_count); + } break; + case GeometryType::MULTILINESTRING: { + if (is_empty) { + writer.Write(0); + return; // No linestrings in empty multilinestring + } + auto part_count = writer.Reserve(); + reader.Match('('); + do { + const auto part_meta = + static_cast(GeometryType::LINESTRING) + (has_z ? 1000 : 0) + (has_m ? 2000 : 0); + writer.Write(1); + writer.Write(part_meta); + + auto vert_count = writer.Reserve(); + reader.Match('('); + do { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + vert_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(vert_count); + part_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(part_count); + } break; + case GeometryType::MULTIPOLYGON: { + if (is_empty) { + writer.Write(0); // No polygons in empty multipolygon + break; + } + auto part_count = writer.Reserve(); + reader.Match('('); + do { + const auto part_meta = + static_cast(GeometryType::POLYGON) + (has_z ? 1000 : 0) + (has_m ? 2000 : 0); + writer.Write(1); + writer.Write(part_meta); + + auto ring_count = writer.Reserve(); + reader.Match('('); + do { + auto vert_count = writer.Reserve(); + reader.Match('('); + do { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + vert_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(vert_count); + ring_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(ring_count); + part_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(part_count); + } break; + case GeometryType::GEOMETRYCOLLECTION: { + if (is_empty) { + writer.Write(0); // No geometries in empty geometry collection + break; + } + auto part_count = writer.Reserve(); + reader.Match('('); + do { + // Recursively parse the geometry inside the collection + FromStringRecursive(reader, writer, depth + 1, has_z, has_m); + part_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(part_count); + } break; + default: + throw InvalidInputException("Unknown geometry type %d at position %zu", static_cast(type), + reader.GetPosition()); + } +} + +void ToStringRecursive(BlobReader &reader, TextWriter &writer, idx_t depth, bool parent_has_z, bool parent_has_m) { + if (depth == Geometry::MAX_RECURSION_DEPTH) { + throw InvalidInputException("Geometry exceeds maximum recursion depth of %d", Geometry::MAX_RECURSION_DEPTH); + } + + // Read the byte order (should always be 1 for little-endian) + auto byte_order = reader.Read(); + if (byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", byte_order); + } + + const auto meta = reader.Read(); + const auto type = static_cast((meta & 0x0000FFFF) % 1000); + const auto flag = (meta & 0x0000FFFF) / 1000; + const auto has_z = (flag & 0x01) != 0; + const auto has_m = (flag & 0x02) != 0; + + if 
((depth != 0) && ((parent_has_z != has_z) || (parent_has_m != has_m))) { + throw InvalidInputException("Geometry has inconsistent Z/M dimensions, starting at position %zu", + reader.GetPosition()); + } + + const uint32_t dims = 2 + (has_z ? 1 : 0) + (has_m ? 1 : 0); + const auto flag_str = has_z ? (has_m ? " ZM " : " Z ") : (has_m ? " M " : " "); + + switch (type) { + case GeometryType::POINT: { + writer.Write("POINT"); + writer.Write(flag_str); + + double vert[4] = {0, 0, 0, 0}; + auto all_nan = true; + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + vert[d_idx] = reader.Read(); + all_nan &= std::isnan(vert[d_idx]); + } + if (all_nan) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + writer.Write(vert[d_idx]); + } + writer.Write(')'); + } break; + case GeometryType::LINESTRING: { + writer.Write("LINESTRING"); + ; + writer.Write(flag_str); + const auto vert_count = reader.Read(); + if (vert_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + if (vert_idx > 0) { + writer.Write(", "); + } + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + auto value = reader.Read(); + writer.Write(value); + } + } + writer.Write(')'); + } break; + case GeometryType::POLYGON: { + writer.Write("POLYGON"); + writer.Write(flag_str); + const auto ring_count = reader.Read(); + if (ring_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t ring_idx = 0; ring_idx < ring_count; ring_idx++) { + if (ring_idx > 0) { + writer.Write(", "); + } + const auto vert_count = reader.Read(); + if (vert_count == 0) { + writer.Write("EMPTY"); + continue; + } + writer.Write('('); + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + if (vert_idx > 0) { + writer.Write(", "); + } + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + auto value = reader.Read(); + writer.Write(value); + } + } + writer.Write(')'); + } + writer.Write(')'); + } break; + case GeometryType::MULTIPOINT: { + writer.Write("MULTIPOINT"); + writer.Write(flag_str); + const auto part_count = reader.Read(); + if (part_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t part_idx = 0; part_idx < part_count; part_idx++) { + const auto part_byte_order = reader.Read(); + if (part_byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", part_byte_order); + } + const auto part_meta = reader.Read(); + const auto part_type = static_cast((part_meta & 0x0000FFFF) % 1000); + const auto part_flag = (part_meta & 0x0000FFFF) / 1000; + const auto part_has_z = (part_flag & 0x01) != 0; + const auto part_has_m = (part_flag & 0x02) != 0; + + if (part_type != GeometryType::POINT) { + throw InvalidInputException("Expected POINT in MULTIPOINT but got %d", static_cast(part_type)); + } + + if ((has_z != part_has_z) || (has_m != part_has_m)) { + throw InvalidInputException( + "Geometry has inconsistent Z/M dimensions in MULTIPOINT, starting at position %zu", + reader.GetPosition()); + } + if (part_idx > 0) { + writer.Write(", "); + } + double vert[4] = {0, 0, 0, 0}; + auto all_nan = true; + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + vert[d_idx] = reader.Read(); + all_nan &= std::isnan(vert[d_idx]); + } + if (all_nan) { + writer.Write("EMPTY"); + continue; + } + // 
writer.Write('('); + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + writer.Write(vert[d_idx]); + } + // writer.Write(')'); + } + writer.Write(')'); + + } break; + case GeometryType::MULTILINESTRING: { + writer.Write("MULTILINESTRING"); + writer.Write(flag_str); + const auto part_count = reader.Read(); + if (part_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t part_idx = 0; part_idx < part_count; part_idx++) { + const auto part_byte_order = reader.Read(); + if (part_byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", part_byte_order); + } + const auto part_meta = reader.Read(); + const auto part_type = static_cast((part_meta & 0x0000FFFF) % 1000); + const auto part_flag = (part_meta & 0x0000FFFF) / 1000; + const auto part_has_z = (part_flag & 0x01) != 0; + const auto part_has_m = (part_flag & 0x02) != 0; + + if (part_type != GeometryType::LINESTRING) { + throw InvalidInputException("Expected LINESTRING in MULTILINESTRING but got %d", + static_cast(part_type)); + } + if ((has_z != part_has_z) || (has_m != part_has_m)) { + throw InvalidInputException( + "Geometry has inconsistent Z/M dimensions in MULTILINESTRING, starting at position %zu", + reader.GetPosition()); + } + if (part_idx > 0) { + writer.Write(", "); + } + const auto vert_count = reader.Read(); + if (vert_count == 0) { + writer.Write("EMPTY"); + continue; + } + writer.Write('('); + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + if (vert_idx > 0) { + writer.Write(", "); + } + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + auto value = reader.Read(); + writer.Write(value); + } + } + writer.Write(')'); + } + writer.Write(')'); + } break; + case GeometryType::MULTIPOLYGON: { + writer.Write("MULTIPOLYGON"); + writer.Write(flag_str); + const auto part_count = reader.Read(); + if (part_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t part_idx = 0; part_idx < part_count; part_idx++) { + if (part_idx > 0) { + writer.Write(", "); + } + + const auto part_byte_order = reader.Read(); + if (part_byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", part_byte_order); + } + const auto part_meta = reader.Read(); + const auto part_type = static_cast((part_meta & 0x0000FFFF) % 1000); + const auto part_flag = (part_meta & 0x0000FFFF) / 1000; + const auto part_has_z = (part_flag & 0x01) != 0; + const auto part_has_m = (part_flag & 0x02) != 0; + if (part_type != GeometryType::POLYGON) { + throw InvalidInputException("Expected POLYGON in MULTIPOLYGON but got %d", static_cast(part_type)); + } + if ((has_z != part_has_z) || (has_m != part_has_m)) { + throw InvalidInputException( + "Geometry has inconsistent Z/M dimensions in MULTIPOLYGON, starting at position %zu", + reader.GetPosition()); + } + + const auto ring_count = reader.Read(); + if (ring_count == 0) { + writer.Write("EMPTY"); + continue; + } + writer.Write('('); + for (uint32_t ring_idx = 0; ring_idx < ring_count; ring_idx++) { + if (ring_idx > 0) { + writer.Write(", "); + } + const auto vert_count = reader.Read(); + if (vert_count == 0) { + writer.Write("EMPTY"); + continue; + } + writer.Write('('); + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + if (vert_idx > 0) { + writer.Write(", "); + } + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + auto value = reader.Read(); 
+ writer.Write(value); + } + } + writer.Write(')'); + } + writer.Write(')'); + } + writer.Write(')'); + } break; + case GeometryType::GEOMETRYCOLLECTION: { + writer.Write("GEOMETRYCOLLECTION"); + writer.Write(flag_str); + const auto part_count = reader.Read(); + if (part_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t part_idx = 0; part_idx < part_count; part_idx++) { + if (part_idx > 0) { + writer.Write(", "); + } + // Recursively parse the geometry inside the collection + ToStringRecursive(reader, writer, depth + 1, has_z, has_m); + } + writer.Write(')'); + } break; + default: + throw InvalidInputException("Unsupported geometry type %d in WKB", static_cast(type)); + } +} + +struct WKBAnalysis { + uint32_t size = 0; + bool any_be = false; + bool any_z = false; + bool any_m = false; + bool any_unknown = false; + bool any_ewkb = false; +}; + +WKBAnalysis AnalyzeWKB(BlobReader &reader) { + WKBAnalysis result; + + while (!reader.IsAtEnd()) { + const auto le = reader.Read() == 1; + + const auto meta = reader.Read(le); + const auto type_id = (meta & 0x0000FFFF) % 1000; + const auto flag_id = (meta & 0x0000FFFF) / 1000; + + // Extended WKB detection + const auto has_extz = (meta & 0x80000000) != 0; + const auto has_extm = (meta & 0x40000000) != 0; + const auto has_srid = (meta & 0x20000000) != 0; + + const auto has_z = ((flag_id & 0x01) != 0) || has_extz; + const auto has_m = ((flag_id & 0x02) != 0) || has_extm; + + if (has_srid) { + result.any_ewkb = true; + reader.Skip(sizeof(uint32_t)); // Skip SRID + // Do not include SRID in the size + } + + if (has_extz || has_extm || has_srid) { + // EWKB flags are set + result.any_ewkb = true; + } + + const auto v_size = (2 + (has_z ? 1 : 0) + (has_m ? 1 : 0)) * sizeof(double); + + result.any_z |= has_z; + result.any_m |= has_m; + result.any_be |= !le; + + result.size += sizeof(uint8_t) + sizeof(uint32_t); // Byte order + type/meta + + switch (type_id) { + case 1: { // POINT + reader.Skip(v_size); + result.size += v_size; + } break; + case 2: { // LINESTRING + const auto vert_count = reader.Read(le); + reader.Skip(vert_count * v_size); + result.size += sizeof(uint32_t) + vert_count * v_size; + } break; + case 3: { // POLYGON + const auto ring_count = reader.Read(le); + result.size += sizeof(uint32_t); + for (uint32_t ring_idx = 0; ring_idx < ring_count; ring_idx++) { + const auto vert_count = reader.Read(le); + reader.Skip(vert_count * v_size); + result.size += sizeof(uint32_t) + vert_count * v_size; + } + } break; + case 4: // MULTIPOINT + case 5: // MULTILINESTRING + case 6: // MULTIPOLYGON + case 7: { // GEOMETRYCOLLECTION + reader.Skip(sizeof(uint32_t)); + result.size += sizeof(uint32_t); // part count + } break; + default: { + result.any_unknown = true; + return result; + } + } + } + return result; +} + +void ConvertWKB(BlobReader &reader, FixedSizeBlobWriter &writer) { + while (!reader.IsAtEnd()) { + const auto le = reader.Read() == 1; + const auto meta = reader.Read(le); + const auto type_id = (meta & 0x0000FFFF) % 1000; + const auto flag_id = (meta & 0x0000FFFF) / 1000; + + // Extended WKB detection + const auto has_extz = (meta & 0x80000000) != 0; + const auto has_extm = (meta & 0x40000000) != 0; + const auto has_srid = (meta & 0x20000000) != 0; + + const auto has_z = ((flag_id & 0x01) != 0) || has_extz; + const auto has_m = ((flag_id & 0x02) != 0) || has_extm; + + if (has_srid) { + reader.Skip(sizeof(uint32_t)); // Skip SRID + } + + const auto v_width = static_cast((2 + (has_z ? 1 : 0) + (has_m ? 
1 : 0))); + + writer.Write(1); // Always write LE + writer.Write(type_id + (1000 * has_z) + (2000 * has_m)); // Write meta + + switch (type_id) { + case 1: { // POINT + for (uint32_t d_idx = 0; d_idx < v_width; d_idx++) { + auto value = reader.Read(le); + writer.Write(value); + } + } break; + case 2: { // LINESTRING + const auto vert_count = reader.Read(le); + writer.Write(vert_count); + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + for (uint32_t d_idx = 0; d_idx < v_width; d_idx++) { + auto value = reader.Read(le); + writer.Write(value); + } + } + } break; + case 3: { // POLYGON + const auto ring_count = reader.Read(le); + writer.Write(ring_count); + for (uint32_t ring_idx = 0; ring_idx < ring_count; ring_idx++) { + const auto vert_count = reader.Read(le); + writer.Write(vert_count); + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + for (uint32_t d_idx = 0; d_idx < v_width; d_idx++) { + auto value = reader.Read(le); + writer.Write(value); + } + } + } + } break; + case 4: // MULTIPOINT + case 5: // MULTILINESTRING + case 6: // MULTIPOLYGON + case 7: { // GEOMETRYCOLLECTION + const auto part_count = reader.Read(le); + writer.Write(part_count); + } break; + default: + D_ASSERT(false); + break; + } + } +} + +} // namespace + +} // namespace duckdb + +//---------------------------------------------------------------------------------------------------------------------- +// Public interface +//---------------------------------------------------------------------------------------------------------------------- +namespace duckdb { + +constexpr const idx_t Geometry::MAX_RECURSION_DEPTH; + +bool Geometry::FromBinary(const string_t &wkb, string_t &result, Vector &result_vector, bool strict) { + BlobReader reader(wkb.GetData(), static_cast(wkb.GetSize())); + + const auto analysis = AnalyzeWKB(reader); + if (analysis.any_unknown) { + if (strict) { + throw InvalidInputException("Unsupported geometry type in WKB"); + } + return false; + } + + if (analysis.any_be || analysis.any_ewkb) { + reader.Reset(); + // Make a new WKB with all LE + auto blob = StringVector::EmptyString(result_vector, analysis.size); + FixedSizeBlobWriter writer(blob.GetDataWriteable(), static_cast(blob.GetSize())); + ConvertWKB(reader, writer); + blob.Finalize(); + result = blob; + return true; + } + + // Copy the WKB as-is + result = StringVector::AddStringOrBlob(result_vector, wkb.GetData(), wkb.GetSize()); + return true; +} + +void Geometry::FromBinary(Vector &source, Vector &result, idx_t count, bool strict) { + if (strict) { + UnaryExecutor::Execute(source, result, count, [&](const string_t &wkb) { + string_t geom; + FromBinary(wkb, geom, result, strict); + return geom; + }); + } else { + UnaryExecutor::ExecuteWithNulls(source, result, count, + [&](const string_t &wkb, ValidityMask &mask, idx_t idx) { + string_t geom; + if (!FromBinary(wkb, geom, result, strict)) { + mask.SetInvalid(idx); + return string_t(); + } + return geom; + }); + } +} + +void Geometry::ToBinary(Vector &source, Vector &result, idx_t count) { + // We are currently using WKB internally, so just copy as-is! 
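All of the readers and writers in this file agree on the same wire format: each geometry begins with a one-byte byte-order marker followed by a uint32 "meta" word whose low 16 bits carry the ISO type code (the base type id plus 1000 when Z is present and 2000 when M is present), while PostGIS-style extended WKB instead sets the high bits 0x80000000 (Z), 0x40000000 (M) and 0x20000000 (an SRID follows). AnalyzeWKB flags extended or big-endian inputs and ConvertWKB rewrites them into the little-endian ISO form. The standalone sketch below is not DuckDB code; it only reuses the bit arithmetic shown above (the SRID value 4326 is an arbitrary example) to build such an extended header and decode it the same way.

    #include <cstdint>
    #include <cstring>
    #include <iostream>
    #include <vector>

    // Append a value to a byte buffer; assumes a little-endian host.
    template <class T>
    static void Append(std::vector<uint8_t> &buf, T value) {
        uint8_t bytes[sizeof(T)];
        std::memcpy(bytes, &value, sizeof(T));
        buf.insert(buf.end(), bytes, bytes + sizeof(T));
    }

    int main() {
        const uint32_t EWKB_Z = 0x80000000u, EWKB_M = 0x40000000u, EWKB_SRID = 0x20000000u;

        // EWKB-style POINT Z with an SRID: byte order, meta with high-bit flags, SRID, coordinates.
        std::vector<uint8_t> ewkb;
        Append<uint8_t>(ewkb, 1);                        // little-endian marker
        Append<uint32_t>(ewkb, 1u | EWKB_Z | EWKB_SRID); // POINT (type id 1) + Z flag + SRID flag
        Append<uint32_t>(ewkb, 4326u);                   // SRID payload (present only when the SRID bit is set)
        Append<double>(ewkb, 1.0);
        Append<double>(ewkb, 2.0);
        Append<double>(ewkb, 3.0);

        // Decode the meta word with the same arithmetic as the reader code above.
        uint32_t meta;
        std::memcpy(&meta, ewkb.data() + 1, sizeof(meta));
        const uint32_t type_id = (meta & 0x0000FFFF) % 1000;
        const uint32_t flag_id = (meta & 0x0000FFFF) / 1000;
        const bool has_z = ((flag_id & 0x01) != 0) || (meta & EWKB_Z) != 0;
        const bool has_m = ((flag_id & 0x02) != 0) || (meta & EWKB_M) != 0;
        const bool has_srid = (meta & EWKB_SRID) != 0;

        // The normalized ISO meta that the writers above emit: type + 1000 for Z + 2000 for M, no SRID.
        const uint32_t iso_meta = type_id + (has_z ? 1000u : 0u) + (has_m ? 2000u : 0u);

        std::cout << "type=" << type_id << " z=" << has_z << " m=" << has_m
                  << " srid=" << has_srid << " iso_meta=" << iso_meta << "\n";
        return 0;
    }

On a little-endian machine this prints type=1 z=1 m=0 srid=1 iso_meta=1001, i.e. the EWKB "POINT Z with SRID" header normalizes to the ISO code 1001 used throughout the conversion code above.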
+ result.Reference(source); +} + +bool Geometry::FromString(const string_t &wkt_text, string_t &result, Vector &result_vector, bool strict) { + TextReader reader(wkt_text.GetData(), static_cast(wkt_text.GetSize())); + BlobWriter writer; + + FromStringRecursive(reader, writer, 0, false, false); + + const auto &buffer = writer.GetBuffer(); + result = StringVector::AddStringOrBlob(result_vector, buffer.data(), buffer.size()); + return true; +} + +string_t Geometry::ToString(Vector &result, const string_t &geom) { + BlobReader reader(geom.GetData(), static_cast(geom.GetSize())); + TextWriter writer; + + ToStringRecursive(reader, writer, 0, false, false); + + // Convert the buffer to string_t + const auto &buffer = writer.GetBuffer(); + return StringVector::AddString(result, buffer.data(), buffer.size()); +} + +pair Geometry::GetType(const string_t &wkb) { + BlobReader reader(wkb.GetData(), static_cast(wkb.GetSize())); + + // Read the byte order (should always be 1 for little-endian) + const auto byte_order = reader.Read(); + if (byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", byte_order); + } + + const auto meta = reader.Read(); + const auto type_id = (meta & 0x0000FFFF) % 1000; + const auto flag_id = (meta & 0x0000FFFF) / 1000; + + if (type_id < 1 || type_id > 7) { + throw InvalidInputException("Unsupported geometry type %d in WKB", type_id); + } + if (flag_id > 3) { + throw InvalidInputException("Unsupported geometry flag %d in WKB", flag_id); + } + + const auto geom_type = static_cast(type_id); + const auto vert_type = static_cast(flag_id); + + return {geom_type, vert_type}; +} + +template +static uint32_t ParseVerticesInternal(BlobReader &reader, GeometryExtent &extent, uint32_t vert_count, bool check_nan) { + uint32_t count = 0; + + // Issue a single .Reserve() for all vertices, to minimize bounds checking overhead + const auto ptr = const_data_ptr_cast(reader.Reserve(vert_count * sizeof(VERTEX_TYPE))); + + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + VERTEX_TYPE vertex = Load(ptr + vert_idx * sizeof(VERTEX_TYPE)); + if (check_nan && vertex.AllNan()) { + continue; + } + + extent.Extend(vertex); + count++; + } + return count; +} + +static uint32_t ParseVertices(BlobReader &reader, GeometryExtent &extent, uint32_t vert_count, VertexType type, + bool check_nan) { + switch (type) { + case VertexType::XY: + return ParseVerticesInternal(reader, extent, vert_count, check_nan); + case VertexType::XYZ: + return ParseVerticesInternal(reader, extent, vert_count, check_nan); + case VertexType::XYM: + return ParseVerticesInternal(reader, extent, vert_count, check_nan); + case VertexType::XYZM: + return ParseVerticesInternal(reader, extent, vert_count, check_nan); + default: + throw InvalidInputException("Unsupported vertex type %d in WKB", static_cast(type)); + } +} + +uint32_t Geometry::GetExtent(const string_t &wkb, GeometryExtent &extent) { + BlobReader reader(wkb.GetData(), static_cast(wkb.GetSize())); + + uint32_t vertex_count = 0; + + while (!reader.IsAtEnd()) { + const auto byte_order = reader.Read(); + if (byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", byte_order); + } + const auto meta = reader.Read(); + const auto type_id = (meta & 0x0000FFFF) % 1000; + const auto flag_id = (meta & 0x0000FFFF) / 1000; + if (type_id < 1 || type_id > 7) { + throw InvalidInputException("Unsupported geometry type %d in WKB", type_id); + } + if (flag_id > 3) { + throw InvalidInputException("Unsupported geometry flag 
%d in WKB", flag_id); + } + const auto geom_type = static_cast(type_id); + const auto vert_type = static_cast(flag_id); + + switch (geom_type) { + case GeometryType::POINT: { + vertex_count += ParseVertices(reader, extent, 1, vert_type, true); + } break; + case GeometryType::LINESTRING: { + const auto vert_count = reader.Read(); + vertex_count += ParseVertices(reader, extent, vert_count, vert_type, false); + } break; + case GeometryType::POLYGON: { + const auto ring_count = reader.Read(); + for (uint32_t ring_idx = 0; ring_idx < ring_count; ring_idx++) { + const auto vert_count = reader.Read(); + vertex_count += ParseVertices(reader, extent, vert_count, vert_type, false); + } + } break; + case GeometryType::MULTIPOINT: + case GeometryType::MULTILINESTRING: + case GeometryType::MULTIPOLYGON: + case GeometryType::GEOMETRYCOLLECTION: { + // Skip count. We don't need it for extent calculation. + reader.Skip(sizeof(uint32_t)); + } break; + default: + throw InvalidInputException("Unsupported geometry type %d in WKB", static_cast(geom_type)); + } + } + return vertex_count; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/list_segment.cpp b/src/duckdb/src/common/types/list_segment.cpp index 8145cf07f..88c00dbc2 100644 --- a/src/duckdb/src/common/types/list_segment.cpp +++ b/src/duckdb/src/common/types/list_segment.cpp @@ -239,7 +239,6 @@ static ListSegment *GetSegment(const ListSegmentFunctions &functions, ArenaAlloc template static void WriteDataToPrimitiveSegment(const ListSegmentFunctions &, ArenaAllocator &, ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); // write null validity @@ -258,7 +257,6 @@ static void WriteDataToPrimitiveSegment(const ListSegmentFunctions &, ArenaAlloc static void WriteDataToVarcharSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); // write null validity @@ -297,7 +295,6 @@ static void WriteDataToVarcharSegment(const ListSegmentFunctions &functions, Are static void WriteDataToListSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); // write null validity @@ -331,7 +328,6 @@ static void WriteDataToListSegment(const ListSegmentFunctions &functions, ArenaA static void WriteDataToStructSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); // write null validity @@ -376,7 +372,6 @@ static void WriteDataToArraySegment(const ListSegmentFunctions &functions, Arena void ListSegmentFunctions::AppendRow(ArenaAllocator &allocator, LinkedList &linked_list, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) const { - auto &write_data_to_segment = *this; auto segment = GetSegment(write_data_to_segment, allocator, linked_list); write_data_to_segment.write_data(write_data_to_segment, allocator, segment, input_data, entry_idx); @@ -391,7 +386,6 @@ void ListSegmentFunctions::AppendRow(ArenaAllocator &allocator, LinkedList &link template static void ReadDataFromPrimitiveSegment(const ListSegmentFunctions &, const ListSegment *segment, Vector 
&result, idx_t &total_count) { - auto &aggr_vector_validity = FlatVector::Validity(result); // set NULLs @@ -462,7 +456,6 @@ static void ReadDataFromVarcharSegment(const ListSegmentFunctions &, const ListS static void ReadDataFromListSegment(const ListSegmentFunctions &functions, const ListSegment *segment, Vector &result, idx_t &total_count) { - auto &aggr_vector_validity = FlatVector::Validity(result); // set NULLs @@ -503,7 +496,6 @@ static void ReadDataFromListSegment(const ListSegmentFunctions &functions, const static void ReadDataFromStructSegment(const ListSegmentFunctions &functions, const ListSegment *segment, Vector &result, idx_t &total_count) { - auto &aggr_vector_validity = FlatVector::Validity(result); // set NULLs @@ -528,7 +520,6 @@ static void ReadDataFromStructSegment(const ListSegmentFunctions &functions, con static void ReadDataFromArraySegment(const ListSegmentFunctions &functions, const ListSegment *segment, Vector &result, idx_t &total_count) { - auto &aggr_vector_validity = FlatVector::Validity(result); // set NULLs @@ -570,7 +561,6 @@ void SegmentPrimitiveFunction(ListSegmentFunctions &functions) { } void GetSegmentDataFunctions(ListSegmentFunctions &functions, const LogicalType &type) { - if (type.id() == LogicalTypeId::UNKNOWN) { throw ParameterNotResolvedException(); } diff --git a/src/duckdb/src/common/types/row/block_iterator.cpp b/src/duckdb/src/common/types/row/block_iterator.cpp deleted file mode 100644 index bebba60e1..000000000 --- a/src/duckdb/src/common/types/row/block_iterator.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include "duckdb/common/types/row/block_iterator.hpp" - -namespace duckdb { - -BlockIteratorStateType GetBlockIteratorStateType(const bool &external) { - return external ? BlockIteratorStateType::EXTERNAL : BlockIteratorStateType::IN_MEMORY; -} - -InMemoryBlockIteratorState::InMemoryBlockIteratorState(const TupleDataCollection &key_data) - : block_ptrs(ConvertBlockPointers(key_data.GetRowBlockPointers())), fast_mod(key_data.TuplesPerBlock()), - tuple_count(key_data.Count()) { -} - -unsafe_vector InMemoryBlockIteratorState::ConvertBlockPointers(const vector &block_ptrs) { - unsafe_vector converted_block_ptrs; - converted_block_ptrs.reserve(block_ptrs.size()); - for (const auto &block_ptr : block_ptrs) { - converted_block_ptrs.emplace_back(block_ptr); - } - return converted_block_ptrs; -} - -ExternalBlockIteratorState::ExternalBlockIteratorState(TupleDataCollection &key_data_p, - optional_ptr payload_data_p) - : tuple_count(key_data_p.Count()), current_chunk_idx(DConstants::INVALID_INDEX), key_data(key_data_p), - key_ptrs(FlatVector::GetData(key_scan_state.chunk_state.row_locations)), payload_data(payload_data_p), - keep_pinned(false), pin_payload(false) { - key_data.InitializeScan(key_scan_state); - if (payload_data) { - payload_data->InitializeScan(payload_scan_state); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp b/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp index eff1186b0..a9be45faa 100644 --- a/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp +++ b/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp @@ -9,13 +9,13 @@ namespace duckdb { PartitionedTupleData::PartitionedTupleData(PartitionedTupleDataType type_p, BufferManager &buffer_manager_p, shared_ptr &layout_ptr_p) - : type(type_p), buffer_manager(buffer_manager_p), layout_ptr(layout_ptr_p), layout(*layout_ptr), count(0), - data_size(0) { + : type(type_p), buffer_manager(buffer_manager_p), + 
stl_allocator(make_shared_ptr(buffer_manager.GetBufferAllocator())), layout_ptr(layout_ptr_p), + layout(*layout_ptr), count(0), data_size(0) { } -PartitionedTupleData::PartitionedTupleData(const PartitionedTupleData &other) - : type(other.type), buffer_manager(other.buffer_manager), layout_ptr(other.layout_ptr), layout(*layout_ptr), - count(0), data_size(0) { +PartitionedTupleData::PartitionedTupleData(PartitionedTupleData &other) + : PartitionedTupleData(other.type, other.buffer_manager, other.layout_ptr) { } PartitionedTupleData::~PartitionedTupleData() { diff --git a/src/duckdb/src/common/types/row/row_data_collection.cpp b/src/duckdb/src/common/types/row/row_data_collection.cpp deleted file mode 100644 index b178b7fb5..000000000 --- a/src/duckdb/src/common/types/row/row_data_collection.cpp +++ /dev/null @@ -1,141 +0,0 @@ -#include "duckdb/common/types/row/row_data_collection.hpp" - -namespace duckdb { - -RowDataCollection::RowDataCollection(BufferManager &buffer_manager, idx_t block_capacity, idx_t entry_size, - bool keep_pinned) - : buffer_manager(buffer_manager), count(0), block_capacity(block_capacity), entry_size(entry_size), - keep_pinned(keep_pinned) { - D_ASSERT(block_capacity * entry_size + entry_size > buffer_manager.GetBlockSize()); -} - -idx_t RowDataCollection::AppendToBlock(RowDataBlock &block, BufferHandle &handle, - vector &append_entries, idx_t remaining, idx_t entry_sizes[]) { - idx_t append_count = 0; - data_ptr_t dataptr; - if (entry_sizes) { - D_ASSERT(entry_size == 1); - // compute how many entries fit if entry size is variable - dataptr = handle.Ptr() + block.byte_offset; - for (idx_t i = 0; i < remaining; i++) { - if (block.byte_offset + entry_sizes[i] > block.capacity) { - if (block.count == 0 && append_count == 0 && entry_sizes[i] > block.capacity) { - // special case: single entry is bigger than block capacity - // resize current block to fit the entry, append it, and move to the next block - block.capacity = entry_sizes[i]; - buffer_manager.ReAllocate(block.block, block.capacity); - dataptr = handle.Ptr(); - append_count++; - block.byte_offset += entry_sizes[i]; - } - break; - } - append_count++; - block.byte_offset += entry_sizes[i]; - } - } else { - append_count = MinValue(remaining, block.capacity - block.count); - dataptr = handle.Ptr() + block.count * entry_size; - } - append_entries.emplace_back(dataptr, append_count); - block.count += append_count; - return append_count; -} - -RowDataBlock &RowDataCollection::CreateBlock() { - blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, block_capacity, entry_size)); - return *blocks.back(); -} - -vector RowDataCollection::Build(idx_t added_count, data_ptr_t key_locations[], idx_t entry_sizes[], - const SelectionVector *sel) { - vector handles; - vector append_entries; - - // first allocate space of where to serialize the keys and payload columns - idx_t remaining = added_count; - { - // first append to the last block (if any) - lock_guard append_lock(rdc_lock); - count += added_count; - - if (!blocks.empty()) { - auto &last_block = *blocks.back(); - if (last_block.count < last_block.capacity) { - // last block has space: pin the buffer of this block - auto handle = buffer_manager.Pin(last_block.block); - // now append to the block - idx_t append_count = AppendToBlock(last_block, handle, append_entries, remaining, entry_sizes); - remaining -= append_count; - handles.push_back(std::move(handle)); - } - } - while (remaining > 0) { - // now for the remaining data, allocate new buffers to store the data 
and append there - auto &new_block = CreateBlock(); - auto handle = buffer_manager.Pin(new_block.block); - - // offset the entry sizes array if we have added entries already - idx_t *offset_entry_sizes = entry_sizes ? entry_sizes + added_count - remaining : nullptr; - - idx_t append_count = AppendToBlock(new_block, handle, append_entries, remaining, offset_entry_sizes); - D_ASSERT(new_block.count > 0); - remaining -= append_count; - - if (keep_pinned) { - pinned_blocks.push_back(std::move(handle)); - } else { - handles.push_back(std::move(handle)); - } - } - } - // now set up the key_locations based on the append entries - idx_t append_idx = 0; - for (auto &append_entry : append_entries) { - idx_t next = append_idx + append_entry.count; - if (entry_sizes) { - for (; append_idx < next; append_idx++) { - key_locations[append_idx] = append_entry.baseptr; - append_entry.baseptr += entry_sizes[append_idx]; - } - } else { - for (; append_idx < next; append_idx++) { - auto idx = sel->get_index(append_idx); - key_locations[idx] = append_entry.baseptr; - append_entry.baseptr += entry_size; - } - } - } - // return the unique pointers to the handles because they must stay pinned - return handles; -} - -void RowDataCollection::Merge(RowDataCollection &other) { - if (other.count == 0) { - return; - } - RowDataCollection temp(buffer_manager, buffer_manager.GetBlockSize(), 1); - { - // One lock at a time to avoid deadlocks - lock_guard read_lock(other.rdc_lock); - temp.count = other.count; - temp.block_capacity = other.block_capacity; - temp.entry_size = other.entry_size; - temp.blocks = std::move(other.blocks); - temp.pinned_blocks = std::move(other.pinned_blocks); - } - other.Clear(); - - lock_guard write_lock(rdc_lock); - count += temp.count; - block_capacity = MaxValue(block_capacity, temp.block_capacity); - entry_size = MaxValue(entry_size, temp.entry_size); - for (auto &block : temp.blocks) { - blocks.emplace_back(std::move(block)); - } - for (auto &handle : temp.pinned_blocks) { - pinned_blocks.emplace_back(std::move(handle)); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp b/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp deleted file mode 100644 index 9b3a4be06..000000000 --- a/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp +++ /dev/null @@ -1,330 +0,0 @@ -#include "duckdb/common/types/row/row_data_collection_scanner.hpp" - -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -#include - -namespace duckdb { - -void RowDataCollectionScanner::AlignHeapBlocks(RowDataCollection &swizzled_block_collection, - RowDataCollection &swizzled_string_heap, - RowDataCollection &block_collection, RowDataCollection &string_heap, - const RowLayout &layout) { - if (block_collection.count == 0) { - return; - } - - if (layout.AllConstant()) { - // No heap blocks! 
Just merge fixed-size data - swizzled_block_collection.Merge(block_collection); - return; - } - - // We create one heap block per data block and swizzle the pointers - D_ASSERT(string_heap.keep_pinned == swizzled_string_heap.keep_pinned); - auto &buffer_manager = block_collection.buffer_manager; - auto &heap_blocks = string_heap.blocks; - idx_t heap_block_idx = 0; - idx_t heap_block_remaining = heap_blocks[heap_block_idx]->count; - for (auto &data_block : block_collection.blocks) { - if (heap_block_remaining == 0) { - heap_block_remaining = heap_blocks[++heap_block_idx]->count; - } - - // Pin the data block and swizzle the pointers within the rows - auto data_handle = buffer_manager.Pin(data_block->block); - auto data_ptr = data_handle.Ptr(); - if (!string_heap.keep_pinned) { - D_ASSERT(!data_block->block->IsSwizzled()); - RowOperations::SwizzleColumns(layout, data_ptr, data_block->count); - data_block->block->SetSwizzling(nullptr); - } - // At this point the data block is pinned and the heap pointer is valid - // so we can copy heap data as needed - - // We want to copy as little of the heap data as possible, check how the data and heap blocks line up - if (heap_block_remaining >= data_block->count) { - // Easy: current heap block contains all strings for this data block, just copy (reference) the block - swizzled_string_heap.blocks.emplace_back(heap_blocks[heap_block_idx]->Copy()); - swizzled_string_heap.blocks.back()->count = data_block->count; - - // Swizzle the heap pointer if we are not pinning the heap - auto &heap_block = swizzled_string_heap.blocks.back()->block; - auto heap_handle = buffer_manager.Pin(heap_block); - if (!swizzled_string_heap.keep_pinned) { - auto heap_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_offset = heap_ptr - heap_handle.Ptr(); - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block->count, - NumericCast(heap_offset)); - } else { - swizzled_string_heap.pinned_blocks.emplace_back(std::move(heap_handle)); - } - - // Update counter - heap_block_remaining -= data_block->count; - } else { - // Strings for this data block are spread over the current heap block and the next (and possibly more) - if (string_heap.keep_pinned) { - // The heap is changing underneath the data block, - // so swizzle the string pointers to make them portable. 
- RowOperations::SwizzleColumns(layout, data_ptr, data_block->count); - } - idx_t data_block_remaining = data_block->count; - vector> ptrs_and_sizes; - idx_t total_size = 0; - const auto base_row_ptr = data_ptr; - while (data_block_remaining > 0) { - if (heap_block_remaining == 0) { - heap_block_remaining = heap_blocks[++heap_block_idx]->count; - } - auto next = MinValue(data_block_remaining, heap_block_remaining); - - // Figure out where to start copying strings, and how many bytes we need to copy - auto heap_start_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_end_ptr = - Load(data_ptr + layout.GetHeapOffset() + (next - 1) * layout.GetRowWidth()); - auto size = NumericCast(heap_end_ptr - heap_start_ptr + Load(heap_end_ptr)); - ptrs_and_sizes.emplace_back(heap_start_ptr, size); - D_ASSERT(size <= heap_blocks[heap_block_idx]->byte_offset); - - // Swizzle the heap pointer - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_start_ptr, next, total_size); - total_size += size; - - // Update where we are in the data and heap blocks - data_ptr += next * layout.GetRowWidth(); - data_block_remaining -= next; - heap_block_remaining -= next; - } - - // Finally, we allocate a new heap block and copy data to it - swizzled_string_heap.blocks.emplace_back(make_uniq( - MemoryTag::ORDER_BY, buffer_manager, MaxValue(total_size, buffer_manager.GetBlockSize()), 1U)); - auto new_heap_handle = buffer_manager.Pin(swizzled_string_heap.blocks.back()->block); - auto new_heap_ptr = new_heap_handle.Ptr(); - for (auto &ptr_and_size : ptrs_and_sizes) { - memcpy(new_heap_ptr, ptr_and_size.first, ptr_and_size.second); - new_heap_ptr += ptr_and_size.second; - } - new_heap_ptr = new_heap_handle.Ptr(); - if (swizzled_string_heap.keep_pinned) { - // Since the heap blocks are pinned, we can unswizzle the data again. 
- swizzled_string_heap.pinned_blocks.emplace_back(std::move(new_heap_handle)); - RowOperations::UnswizzlePointers(layout, base_row_ptr, new_heap_ptr, data_block->count); - RowOperations::UnswizzleHeapPointer(layout, base_row_ptr, new_heap_ptr, data_block->count); - } - } - } - - // We're done with variable-sized data, now just merge the fixed-size data - swizzled_block_collection.Merge(block_collection); - D_ASSERT(swizzled_block_collection.blocks.size() == swizzled_string_heap.blocks.size()); - - // Update counts and cleanup - swizzled_string_heap.count = string_heap.count; - string_heap.Clear(); -} - -void RowDataCollectionScanner::ScanState::PinData() { - auto &rows = scanner.rows; - D_ASSERT(block_idx < rows.blocks.size()); - auto &data_block = rows.blocks[block_idx]; - if (!data_handle.IsValid() || data_handle.GetBlockHandle() != data_block->block) { - data_handle = rows.buffer_manager.Pin(data_block->block); - } - if (scanner.layout.AllConstant() || !scanner.external) { - return; - } - - auto &heap = scanner.heap; - D_ASSERT(block_idx < heap.blocks.size()); - auto &heap_block = heap.blocks[block_idx]; - if (!heap_handle.IsValid() || heap_handle.GetBlockHandle() != heap_block->block) { - heap_handle = heap.buffer_manager.Pin(heap_block->block); - } -} - -RowDataCollectionScanner::RowDataCollectionScanner(RowDataCollection &rows_p, RowDataCollection &heap_p, - const RowLayout &layout_p, bool external_p, bool flush_p) - : rows(rows_p), heap(heap_p), layout(layout_p), read_state(*this), total_count(rows.count), total_scanned(0), - external(external_p), flush(flush_p), unswizzling(!layout.AllConstant() && external && !heap.keep_pinned) { - - if (unswizzling) { - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - } - - ValidateUnscannedBlock(); -} - -RowDataCollectionScanner::RowDataCollectionScanner(RowDataCollection &rows_p, RowDataCollection &heap_p, - const RowLayout &layout_p, bool external_p, idx_t block_idx, - bool flush_p) - : rows(rows_p), heap(heap_p), layout(layout_p), read_state(*this), total_count(rows.count), total_scanned(0), - external(external_p), flush(flush_p), unswizzling(!layout.AllConstant() && external && !heap.keep_pinned) { - - if (unswizzling) { - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - } - - D_ASSERT(block_idx < rows.blocks.size()); - read_state.block_idx = block_idx; - read_state.entry_idx = 0; - - // Pretend that we have scanned up to the start block - // and will stop at the end - auto begin = rows.blocks.begin(); - auto end = begin + NumericCast(block_idx); - total_scanned = - std::accumulate(begin, end, idx_t(0), [&](idx_t c, const unique_ptr &b) { return c + b->count; }); - total_count = total_scanned + (*end)->count; - - ValidateUnscannedBlock(); -} - -void RowDataCollectionScanner::SwizzleBlockInternal(RowDataBlock &data_block, RowDataBlock &heap_block) { - // Pin the data block and swizzle the pointers within the rows - D_ASSERT(!data_block.block->IsSwizzled()); - auto data_handle = rows.buffer_manager.Pin(data_block.block); - auto data_ptr = data_handle.Ptr(); - RowOperations::SwizzleColumns(layout, data_ptr, data_block.count); - data_block.block->SetSwizzling(nullptr); - - // Swizzle the heap pointers - auto heap_handle = heap.buffer_manager.Pin(heap_block.block); - auto heap_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_offset = heap_ptr - heap_handle.Ptr(); - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block.count, NumericCast(heap_offset)); -} - -void RowDataCollectionScanner::SwizzleBlock(idx_t 
block_idx) { - if (rows.count == 0) { - return; - } - - if (!unswizzling) { - // No swizzled blocks! - return; - } - - auto &data_block = rows.blocks[block_idx]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlockInternal(*data_block, *heap.blocks[block_idx]); - } -} - -void RowDataCollectionScanner::ReSwizzle() { - if (rows.count == 0) { - return; - } - - if (!unswizzling) { - // No swizzled blocks! - return; - } - - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - for (idx_t i = 0; i < rows.blocks.size(); ++i) { - auto &data_block = rows.blocks[i]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlockInternal(*data_block, *heap.blocks[i]); - } - } -} - -void RowDataCollectionScanner::ValidateUnscannedBlock() const { - if (unswizzling && read_state.block_idx < rows.blocks.size() && Remaining()) { - D_ASSERT(rows.blocks[read_state.block_idx]->block->IsSwizzled()); - } -} - -void RowDataCollectionScanner::Scan(DataChunk &chunk) { - auto count = MinValue((idx_t)STANDARD_VECTOR_SIZE, total_count - total_scanned); - if (count == 0) { - chunk.SetCardinality(count); - return; - } - - // Only flush blocks we processed. - const auto flush_block_idx = read_state.block_idx; - - const idx_t &row_width = layout.GetRowWidth(); - // Set up a batch of pointers to scan data from - idx_t scanned = 0; - auto data_pointers = FlatVector::GetData(addresses); - - // We must pin ALL blocks we are going to gather from - vector pinned_blocks; - while (scanned < count) { - read_state.PinData(); - auto &data_block = rows.blocks[read_state.block_idx]; - idx_t next = MinValue(data_block->count - read_state.entry_idx, count - scanned); - const data_ptr_t data_ptr = read_state.data_handle.Ptr() + read_state.entry_idx * row_width; - // Set up the next pointers - data_ptr_t row_ptr = data_ptr; - for (idx_t i = 0; i < next; i++) { - data_pointers[scanned + i] = row_ptr; - row_ptr += row_width; - } - // Unswizzle the offsets back to pointers (if needed) - if (unswizzling) { - RowOperations::UnswizzlePointers(layout, data_ptr, read_state.heap_handle.Ptr(), next); - rows.blocks[read_state.block_idx]->block->SetSwizzling("RowDataCollectionScanner::Scan"); - } - // Update state indices - read_state.entry_idx += next; - scanned += next; - total_scanned += next; - if (read_state.entry_idx == data_block->count) { - // Pin completed blocks so we don't lose them - pinned_blocks.emplace_back(rows.buffer_manager.Pin(data_block->block)); - if (unswizzling) { - auto &heap_block = heap.blocks[read_state.block_idx]; - pinned_blocks.emplace_back(heap.buffer_manager.Pin(heap_block->block)); - } - read_state.block_idx++; - read_state.entry_idx = 0; - ValidateUnscannedBlock(); - } - } - D_ASSERT(scanned == count); - // Deserialize the payload data - for (idx_t col_no = 0; col_no < layout.ColumnCount(); col_no++) { - RowOperations::Gather(addresses, *FlatVector::IncrementalSelectionVector(), chunk.data[col_no], - *FlatVector::IncrementalSelectionVector(), count, layout, col_no); - } - chunk.SetCardinality(count); - chunk.Verify(); - - // Switch to a new set of pinned blocks - read_state.pinned_blocks.swap(pinned_blocks); - - if (flush) { - // Release blocks we have passed. - for (idx_t i = flush_block_idx; i < read_state.block_idx; ++i) { - rows.blocks[i]->block = nullptr; - if (unswizzling) { - heap.blocks[i]->block = nullptr; - } - } - } else if (unswizzling) { - // Reswizzle blocks we have passed so they can be flushed safely. 
- for (idx_t i = flush_block_idx; i < read_state.block_idx; ++i) { - auto &data_block = rows.blocks[i]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlockInternal(*data_block, *heap.blocks[i]); - } - } - } -} - -void RowDataCollectionScanner::Reset(bool flush_p) { - flush = flush_p; - total_scanned = 0; - - read_state.block_idx = 0; - read_state.entry_idx = 0; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/row_layout.cpp b/src/duckdb/src/common/types/row/row_layout.cpp deleted file mode 100644 index 3add8e425..000000000 --- a/src/duckdb/src/common/types/row/row_layout.cpp +++ /dev/null @@ -1,62 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/row_layout.cpp -// -// -//===----------------------------------------------------------------------===// - -#include "duckdb/common/types/row/row_layout.hpp" - -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" - -namespace duckdb { - -RowLayout::RowLayout() : flag_width(0), data_width(0), row_width(0), all_constant(true), heap_pointer_offset(0) { -} - -void RowLayout::Initialize(vector types_p, bool align) { - offsets.clear(); - types = std::move(types_p); - - // Null mask at the front - 1 bit per value. - flag_width = ValidityBytes::ValidityMaskSize(types.size()); - row_width = flag_width; - - // Whether all columns are constant size. - for (const auto &type : types) { - all_constant = all_constant && TypeIsConstantSize(type.InternalType()); - } - - // This enables pointer swizzling for out-of-core computation. - if (!all_constant) { - // When unswizzled, the pointer lives here. - // When swizzled, the pointer is replaced by an offset. - heap_pointer_offset = row_width; - // The 8 byte pointer will be replaced with an 8 byte idx_t when swizzled. - // However, this cannot be sizeof(data_ptr_t), since 32 bit builds use 4 byte pointers. - row_width += sizeof(idx_t); - } - - // Data columns. No alignment required. - for (const auto &type : types) { - offsets.push_back(row_width); - const auto internal_type = type.InternalType(); - if (TypeIsConstantSize(internal_type) || internal_type == PhysicalType::VARCHAR) { - row_width += GetTypeIdSize(type.InternalType()); - } else { - // Variable size types use pointers to the actual data (can be swizzled). - // Again, we would use sizeof(data_ptr_t), but this is not guaranteed to be equal to sizeof(idx_t). 
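For reference, the width computation in the (now removed) RowLayout::Initialize boils down to: validity bytes up front, an 8-byte heap-pointer slot if any column is variable-size, then one slot per column (its fixed width, or 8 bytes for variable-size data), optionally aligned. A simplified standalone version of that arithmetic, where a width of 0 marks a variable-size column and 8-byte alignment is assumed (the real code uses DuckDB's type system and ValidityBytes):

#include <cstdint>
#include <vector>

struct RowLayoutSketch {
	uint64_t flag_width = 0; // validity bits, one per column, rounded up to bytes
	uint64_t row_width = 0;
	bool all_constant = true;
};

static RowLayoutSketch ComputeRowLayout(const std::vector<uint64_t> &fixed_widths, bool align) {
	RowLayoutSketch layout;
	layout.flag_width = (fixed_widths.size() + 7) / 8;
	layout.row_width = layout.flag_width;
	for (auto width : fixed_widths) {
		layout.all_constant = layout.all_constant && width != 0;
	}
	if (!layout.all_constant) {
		layout.row_width += sizeof(uint64_t); // slot for the heap pointer / swizzled offset
	}
	for (auto width : fixed_widths) {
		layout.row_width += width != 0 ? width : sizeof(uint64_t); // variable-size data lives behind a pointer
	}
	if (align) {
		layout.row_width = (layout.row_width + 7) & ~uint64_t(7); // pad to 8 bytes for the next row
	}
	return layout;
}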
- row_width += sizeof(idx_t); - } - } - - data_width = row_width - flag_width; - - // Alignment padding for the next row - if (align) { - row_width = AlignValue(row_width); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/tuple_data_allocator.cpp b/src/duckdb/src/common/types/row/tuple_data_allocator.cpp index 7c5fcd32b..bec036387 100644 --- a/src/duckdb/src/common/types/row/tuple_data_allocator.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_allocator.cpp @@ -30,12 +30,14 @@ TupleDataBlock &TupleDataBlock::operator=(TupleDataBlock &&other) noexcept { return *this; } -TupleDataAllocator::TupleDataAllocator(BufferManager &buffer_manager, shared_ptr &layout_ptr_p) - : buffer_manager(buffer_manager), layout_ptr(layout_ptr_p), layout(*layout_ptr) { +TupleDataAllocator::TupleDataAllocator(BufferManager &buffer_manager, shared_ptr layout_ptr_p, + shared_ptr stl_allocator_p) + : stl_allocator(std::move(stl_allocator_p)), buffer_manager(buffer_manager), layout_ptr(std::move(layout_ptr_p)), + layout(*layout_ptr), row_blocks(*stl_allocator), heap_blocks(*stl_allocator) { } TupleDataAllocator::TupleDataAllocator(TupleDataAllocator &allocator) - : buffer_manager(allocator.buffer_manager), layout_ptr(allocator.layout_ptr), layout(*layout_ptr) { + : TupleDataAllocator(allocator.buffer_manager, allocator.layout_ptr, allocator.stl_allocator) { } void TupleDataAllocator::SetDestroyBufferUponUnpin() { @@ -82,6 +84,10 @@ Allocator &TupleDataAllocator::GetAllocator() { return buffer_manager.GetBufferAllocator(); } +ArenaAllocator &TupleDataAllocator::GetStlAllocator() { + return *stl_allocator; +} + shared_ptr TupleDataAllocator::GetLayoutPtr() const { return layout_ptr; } @@ -116,12 +122,12 @@ bool TupleDataAllocator::BuildFastPath(TupleDataSegment &segment, TupleDataPinSt return false; } - auto &chunk = chunks.back(); + auto &chunk = *chunks.back(); if (chunk.count + append_count > STANDARD_VECTOR_SIZE) { return false; } - auto &part = segment.chunk_parts[chunk.part_ids.End() - 1]; + auto &part = *segment.chunk_parts[chunk.part_ids.End() - 1]; auto &row_block = row_blocks[part.row_block_index]; const auto row_width = layout.GetRowWidth(); @@ -152,23 +158,23 @@ void TupleDataAllocator::Build(TupleDataSegment &segment, TupleDataPinState &pin D_ASSERT(this == segment.allocator.get()); auto &chunks = segment.chunks; if (!chunks.empty()) { - ReleaseOrStoreHandles(pin_state, segment, chunks.back(), true); + ReleaseOrStoreHandles(pin_state, segment, *chunks.back(), true); } if (!BuildFastPath(segment, pin_state, chunk_state, append_offset, append_count)) { // Build the chunk parts for the incoming data - chunk_part_indices.clear(); + chunk_state.chunk_part_indices.clear(); idx_t offset = 0; while (offset != append_count) { - if (chunks.empty() || chunks.back().count == STANDARD_VECTOR_SIZE) { - chunks.emplace_back(); + if (chunks.empty() || chunks.back()->count == STANDARD_VECTOR_SIZE) { + chunks.push_back(stl_allocator->MakeUnsafePtr(*stl_allocator->Make())); } - auto &chunk = chunks.back(); + auto &chunk = *chunks.back(); // Build the next part auto next = MinValue(append_count - offset, STANDARD_VECTOR_SIZE - chunk.count); - auto &chunk_part = - chunk.AddPart(segment, BuildChunkPart(pin_state, chunk_state, append_offset + offset, next, chunk)); + auto &chunk_part = chunk.AddPart( + segment, BuildChunkPart(segment, pin_state, chunk_state, append_offset + offset, next, chunk)); next = chunk_part.count; segment.count += next; @@ -190,34 +196,37 @@ void 
TupleDataAllocator::Build(TupleDataSegment &segment, TupleDataPinState &pin } offset += next; - chunk_part_indices.emplace_back(chunks.size() - 1, chunk.part_ids.End() - 1); + chunk_state.chunk_part_indices.emplace_back(chunks.size() - 1, chunk.part_ids.End() - 1); } // Now initialize the pointers to write the data to - chunk_parts.clear(); - for (const auto &indices : chunk_part_indices) { - chunk_parts.emplace_back(segment.chunk_parts[indices.second]); + chunk_state.chunk_parts.clear(); + for (const auto &indices : chunk_state.chunk_part_indices) { + chunk_state.chunk_parts.emplace_back(*segment.chunk_parts[indices.second]); } - InitializeChunkStateInternal(pin_state, chunk_state, append_offset, false, true, false, chunk_parts); + InitializeChunkStateInternal(pin_state, chunk_state, append_offset, false, true, false, + chunk_state.chunk_parts); // To reduce metadata, we try to merge chunk parts where possible // Due to the way chunk parts are constructed, only the last part of the first chunk is eligible for merging - segment.chunks[chunk_part_indices[0].first].MergeLastChunkPart(segment); + segment.chunks[chunk_state.chunk_part_indices[0].first]->MergeLastChunkPart(segment); } segment.Verify(); } -TupleDataChunkPart TupleDataAllocator::BuildChunkPart(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - const idx_t append_offset, const idx_t append_count, - TupleDataChunk &chunk) { +unsafe_arena_ptr +TupleDataAllocator::BuildChunkPart(TupleDataSegment &segment, TupleDataPinState &pin_state, + TupleDataChunkState &chunk_state, const idx_t append_offset, + const idx_t append_count, TupleDataChunk &chunk) { D_ASSERT(append_count != 0); - TupleDataChunkPart result(*chunk.lock); + auto result_ptr = stl_allocator->MakeUnsafePtr(chunk.lock.get()); + auto &result = *result_ptr; const auto block_size = buffer_manager.GetBlockSize(); // Allocate row block (if needed) if (row_blocks.empty() || row_blocks.back().RemainingCapacity() < layout.GetRowWidth()) { - row_blocks.emplace_back(buffer_manager, block_size); + CreateRowBlock(segment); if (partition_index.IsValid()) { // Set the eviction queue index logarithmically using RadixBits row_blocks.back().handle->SetEvictionQueueIndex(RadixPartitioning::RadixBits(partition_index.GetIndex())); } @@ -272,7 +281,7 @@ TupleDataChunkPart TupleDataAllocator::BuildChunkPart(TupleDataPinState &pin_sta // Allocate heap block (if needed) if (heap_blocks.empty() || heap_blocks.back().RemainingCapacity() < heap_sizes[append_offset]) { const auto size = MaxValue(block_size, heap_sizes[append_offset]); - heap_blocks.emplace_back(buffer_manager, size); + CreateHeapBlock(segment, size); if (partition_index.IsValid()) { // Set the eviction queue index logarithmically using RadixBits heap_blocks.back().handle->SetEvictionQueueIndex( RadixPartitioning::RadixBits(partition_index.GetIndex())); @@ -293,14 +302,14 @@ TupleDataChunkPart TupleDataAllocator::BuildChunkPart(TupleDataPinState &pin_sta // Mark this portion of the row block as filled row_block.size += result.count * layout.GetRowWidth(); - return result; + return result_ptr; } void TupleDataAllocator::InitializeChunkState(TupleDataSegment &segment, TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, idx_t chunk_idx, bool init_heap) { D_ASSERT(this == segment.allocator.get()); D_ASSERT(chunk_idx < segment.ChunkCount()); - auto &chunk = segment.chunks[chunk_idx]; + auto &chunk = *segment.chunks[chunk_idx]; // Release or store any handles that are no longer required: // We can't release the heap 
here if the current chunk's heap_block_ids is empty, because if we are iterating with @@ -308,12 +317,12 @@ void TupleDataAllocator::InitializeChunkState(TupleDataSegment &segment, TupleDa // when chunk 0 needs heap block 0, chunk 1 does not need any heap blocks, and chunk 2 needs heap block 0 again ReleaseOrStoreHandles(pin_state, segment, chunk, !chunk.heap_block_ids.Empty()); - chunk_state.parts.clear(); + chunk_state.chunk_parts.clear(); for (auto part_id = chunk.part_ids.Start(); part_id < chunk.part_ids.End(); part_id++) { - chunk_state.parts.emplace_back(segment.chunk_parts[part_id]); + chunk_state.chunk_parts.emplace_back(*segment.chunk_parts[part_id]); } - InitializeChunkStateInternal(pin_state, chunk_state, 0, true, init_heap, init_heap, chunk_state.parts); + InitializeChunkStateInternal(pin_state, chunk_state, 0, true, init_heap, init_heap, chunk_state.chunk_parts); } static inline void InitializeHeapSizes(const data_ptr_t row_locations[], idx_t heap_sizes[], const idx_t offset, @@ -670,14 +679,15 @@ void TupleDataAllocator::ReleaseOrStoreHandles(TupleDataPinState &pin_state, Tup } void TupleDataAllocator::ReleaseOrStoreHandles(TupleDataPinState &pin_state, TupleDataSegment &segment) { - static TupleDataChunk DUMMY_CHUNK; + mutex dummy_chunk_mutex; + static TupleDataChunk DUMMY_CHUNK(dummy_chunk_mutex); ReleaseOrStoreHandles(pin_state, segment, DUMMY_CHUNK, true); } void TupleDataAllocator::ReleaseOrStoreHandlesInternal(TupleDataSegment &segment, - unsafe_vector &pinned_handles, + unsafe_arena_vector &pinned_handles, buffer_handle_map_t &handles, const ContinuousIdSet &block_ids, - unsafe_vector &blocks, + unsafe_arena_vector &blocks, TupleDataPinProperties properties) { bool found_handle; do { @@ -691,10 +701,7 @@ void TupleDataAllocator::ReleaseOrStoreHandlesInternal(TupleDataSegment &segment switch (properties) { case TupleDataPinProperties::KEEP_EVERYTHING_PINNED: { lock_guard guard(segment.pinned_handles_lock); - const auto block_count = block_id + 1; - if (block_count > pinned_handles.size()) { - pinned_handles.resize(block_count); - } + D_ASSERT(blocks.size() == pinned_handles.size()); pinned_handles[block_id] = std::move(it->second); break; } @@ -718,6 +725,16 @@ void TupleDataAllocator::ReleaseOrStoreHandlesInternal(TupleDataSegment &segment } while (found_handle); } +void TupleDataAllocator::CreateRowBlock(TupleDataSegment &segment) { + row_blocks.emplace_back(buffer_manager, buffer_manager.GetBlockSize()); + segment.pinned_row_handles.resize(row_blocks.size()); +} + +void TupleDataAllocator::CreateHeapBlock(TupleDataSegment &segment, idx_t size) { + heap_blocks.emplace_back(buffer_manager, size); + segment.pinned_heap_handles.resize(heap_blocks.size()); +} + BufferHandle &TupleDataAllocator::PinRowBlock(TupleDataPinState &pin_state, const TupleDataChunkPart &part) { const auto &row_block_index = part.row_block_index; auto it = pin_state.row_handles.find(row_block_index); diff --git a/src/duckdb/src/common/types/row/tuple_data_collection.cpp b/src/duckdb/src/common/types/row/tuple_data_collection.cpp index ffd4a2b4c..a068e0ab4 100644 --- a/src/duckdb/src/common/types/row/tuple_data_collection.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_collection.cpp @@ -12,12 +12,21 @@ namespace duckdb { using ValidityBytes = TupleDataLayout::ValidityBytes; -TupleDataCollection::TupleDataCollection(BufferManager &buffer_manager, shared_ptr layout_ptr_p) - : layout_ptr(std::move(layout_ptr_p)), layout(*layout_ptr), - allocator(make_shared_ptr(buffer_manager, layout_ptr)) { 
+TupleDataCollection::TupleDataCollection(BufferManager &buffer_manager, shared_ptr layout_ptr_p, + shared_ptr stl_allocator_p) + : stl_allocator(stl_allocator_p ? std::move(stl_allocator_p) + : make_shared_ptr(buffer_manager.GetBufferAllocator())), + layout_ptr(std::move(layout_ptr_p)), layout(*layout_ptr), + allocator(make_shared_ptr(buffer_manager, layout_ptr, stl_allocator)), + segments(*stl_allocator), scatter_functions(*stl_allocator), gather_functions(*stl_allocator) { Initialize(); } +TupleDataCollection::TupleDataCollection(ClientContext &context, shared_ptr layout_ptr, + shared_ptr stl_allocator) + : TupleDataCollection(BufferManager::GetBufferManager(context), std::move(layout_ptr), std::move(stl_allocator)) { +} + TupleDataCollection::~TupleDataCollection() { } @@ -110,13 +119,13 @@ void TupleDataCollection::DestroyChunks(const idx_t chunk_idx_begin, const idx_t D_ASSERT(segments.size() == 1); // Assume 1 segment for now (multi-segment destroys can be implemented if needed) D_ASSERT(chunk_idx_begin <= chunk_idx_end && chunk_idx_end <= ChunkCount()); auto &segment = *segments[0]; - auto &chunk_begin = segment.chunks[chunk_idx_begin]; + auto &chunk_begin = *segment.chunks[chunk_idx_begin]; const auto row_block_begin = chunk_begin.row_block_ids.Start(); if (chunk_idx_end == ChunkCount()) { segment.allocator->DestroyRowBlocks(row_block_begin, segment.allocator->RowBlockCount()); } else { - auto &chunk_end = segment.chunks[chunk_idx_end]; + auto &chunk_end = *segment.chunks[chunk_idx_end]; const auto row_block_end = chunk_end.row_block_ids.Start(); segment.allocator->DestroyRowBlocks(row_block_begin, row_block_end); } @@ -129,7 +138,7 @@ void TupleDataCollection::DestroyChunks(const idx_t chunk_idx_begin, const idx_t if (chunk_idx_end == ChunkCount()) { segment.allocator->DestroyHeapBlocks(heap_block_begin, segment.allocator->HeapBlockCount()); } else { - auto &chunk_end = segment.chunks[chunk_idx_end]; + auto &chunk_end = *segment.chunks[chunk_idx_end]; if (chunk_end.heap_block_ids.Empty()) { return; } @@ -180,7 +189,7 @@ void TupleDataCollection::InitializeAppend(TupleDataAppendState &append_state, v void TupleDataCollection::InitializeAppend(TupleDataPinState &pin_state, TupleDataPinProperties properties) { pin_state.properties = properties; if (segments.empty()) { - segments.emplace_back(make_unsafe_uniq(allocator)); + segments.emplace_back(stl_allocator->MakeUnsafePtr(allocator)); } } @@ -469,7 +478,7 @@ void TupleDataCollection::Combine(TupleDataCollection &other) { other.Reset(); } -void TupleDataCollection::AddSegment(unsafe_unique_ptr segment) { +void TupleDataCollection::AddSegment(unsafe_arena_ptr segment) { count += segment->count; data_size += segment->data_size; segments.emplace_back(std::move(segment)); @@ -504,7 +513,7 @@ void TupleDataCollection::InitializeChunk(DataChunk &chunk, const vectorGetAllocator(), chunk_types); } -void TupleDataCollection::InitializeScanChunk(TupleDataScanState &state, DataChunk &chunk) const { +void TupleDataCollection::InitializeScanChunk(const TupleDataScanState &state, DataChunk &chunk) const { auto &column_ids = state.chunk_state.column_ids; D_ASSERT(!column_ids.empty()); vector chunk_types; @@ -562,11 +571,16 @@ void TupleDataCollection::InitializeScan(TupleDataParallelScanState &state, vect InitializeScan(state.scan_state, std::move(column_ids), properties); } -idx_t TupleDataCollection::FetchChunk(TupleDataScanState &state, const idx_t segment_idx, const idx_t chunk_idx, - const bool init_heap) { - auto &segment = 
*segments[segment_idx]; - allocator->InitializeChunkState(segment, state.pin_state, state.chunk_state, chunk_idx, init_heap); - return segment.chunks[chunk_idx].count; +idx_t TupleDataCollection::FetchChunk(TupleDataScanState &state, idx_t chunk_idx, bool init_heap) { + for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) { + auto &segment = *segments[segment_idx]; + if (chunk_idx < segment.ChunkCount()) { + segment.allocator->InitializeChunkState(segment, state.pin_state, state.chunk_state, chunk_idx, init_heap); + return segment.chunks[chunk_idx]->count; + } + chunk_idx -= segment.ChunkCount(); + } + throw InternalException("Chunk index out of in TupleDataCollection::FetchChunk"); } bool TupleDataCollection::Scan(TupleDataScanState &state, DataChunk &result) { @@ -648,7 +662,7 @@ void TupleDataCollection::ScanAtIndex(TupleDataPinState &pin_state, TupleDataChu const vector &column_ids, idx_t segment_index, idx_t chunk_index, DataChunk &result) { auto &segment = *segments[segment_index]; - auto &chunk = segment.chunks[chunk_index]; + const auto &chunk = *segment.chunks[chunk_index]; segment.allocator->InitializeChunkState(segment, pin_state, chunk_state, chunk_index, false); result.Reset(); diff --git a/src/duckdb/src/common/types/row/tuple_data_iterator.cpp b/src/duckdb/src/common/types/row/tuple_data_iterator.cpp index 03dd5db23..5bbe7841d 100644 --- a/src/duckdb/src/common/types/row/tuple_data_iterator.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_iterator.cpp @@ -74,7 +74,7 @@ void TupleDataChunkIterator::Reset() { } idx_t TupleDataChunkIterator::GetCurrentChunkCount() const { - return collection.segments[current_segment_idx]->chunks[current_chunk_idx].count; + return collection.segments[current_segment_idx]->chunks[current_chunk_idx]->count; } TupleDataChunkState &TupleDataChunkIterator::GetChunkState() { diff --git a/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp b/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp index fe671a46f..3c967d448 100644 --- a/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp @@ -1794,7 +1794,6 @@ static void TupleDataCastToArrayStructGather(const TupleDataLayout &layout, Vect const SelectionVector &scan_sel, const idx_t scan_count, Vector &target, const SelectionVector &target_sel, optional_ptr cached_cast_vector, const vector &child_functions) { - if (cached_cast_vector) { // Reuse the cached cast vector TupleDataStructGather(layout, row_locations, col_idx, scan_sel, scan_count, *cached_cast_vector, target_sel, diff --git a/src/duckdb/src/common/types/row/tuple_data_segment.cpp b/src/duckdb/src/common/types/row/tuple_data_segment.cpp index 462e7e474..be6901670 100644 --- a/src/duckdb/src/common/types/row/tuple_data_segment.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_segment.cpp @@ -15,7 +15,7 @@ void TupleDataChunkPart::SetHeapEmpty() { base_heap_ptr = nullptr; } -TupleDataChunk::TupleDataChunk() : count(0), lock(make_unsafe_uniq()) { +TupleDataChunk::TupleDataChunk(mutex &lock_p) : count(0), lock(lock_p) { } static inline void SwapTupleDataChunk(TupleDataChunk &a, TupleDataChunk &b) noexcept { @@ -26,7 +26,7 @@ static inline void SwapTupleDataChunk(TupleDataChunk &a, TupleDataChunk &b) noex std::swap(a.lock, b.lock); } -TupleDataChunk::TupleDataChunk(TupleDataChunk &&other) noexcept : count(0) { +TupleDataChunk::TupleDataChunk(TupleDataChunk &&other) noexcept : count(0), lock(other.lock) { 
SwapTupleDataChunk(*this, other); } @@ -35,23 +35,24 @@ TupleDataChunk &TupleDataChunk::operator=(TupleDataChunk &&other) noexcept { return *this; } -TupleDataChunkPart &TupleDataChunk::AddPart(TupleDataSegment &segment, TupleDataChunkPart &&part) { +TupleDataChunkPart &TupleDataChunk::AddPart(TupleDataSegment &segment, unsafe_arena_ptr part_ptr) { + auto &part = *part_ptr; count += part.count; row_block_ids.Insert(part.row_block_index); if (!segment.layout.AllConstant() && part.total_heap_size > 0) { heap_block_ids.Insert(part.heap_block_index); } - part.lock = *lock; + part.lock = lock; part_ids.Insert(UnsafeNumericCast(segment.chunk_parts.size())); - segment.chunk_parts.emplace_back(std::move(part)); - return segment.chunk_parts.back(); + segment.chunk_parts.emplace_back(std::move(part_ptr)); + return part; } void TupleDataChunk::Verify(const TupleDataSegment &segment) const { #ifdef D_ASSERT_IS_ENABLED idx_t total_count = 0; for (auto part_id = part_ids.Start(); part_id < part_ids.End(); part_id++) { - total_count += segment.chunk_parts[part_id].count; + total_count += segment.chunk_parts[part_id]->count; } D_ASSERT(this->count == total_count); D_ASSERT(this->count <= STANDARD_VECTOR_SIZE); @@ -63,8 +64,8 @@ void TupleDataChunk::MergeLastChunkPart(TupleDataSegment &segment) { return; } - auto &second_to_last = segment.chunk_parts[part_ids.End() - 2]; - auto &last = segment.chunk_parts[part_ids.End() - 1]; + auto &second_to_last = *segment.chunk_parts[part_ids.End() - 2]; + auto &last = *segment.chunk_parts[part_ids.End() - 1]; auto rows_align = last.row_block_index == second_to_last.row_block_index && @@ -98,11 +99,8 @@ void TupleDataChunk::MergeLastChunkPart(TupleDataSegment &segment) { } TupleDataSegment::TupleDataSegment(shared_ptr allocator_p) - : allocator(std::move(allocator_p)), layout(allocator->GetLayout()), count(0), data_size(0) { - // We initialize these with plenty of room so that we can avoid allocations - static constexpr idx_t CHUNK_RESERVATION = 64; - chunks.reserve(CHUNK_RESERVATION); - chunk_parts.reserve(CHUNK_RESERVATION); + : allocator(std::move(allocator_p)), layout(allocator->GetLayout()), count(0), data_size(0), + pinned_row_handles(allocator->GetStlAllocator()), pinned_heap_handles(allocator->GetStlAllocator()) { } TupleDataSegment::~TupleDataSegment() { @@ -112,7 +110,6 @@ TupleDataSegment::~TupleDataSegment() { } pinned_row_handles.clear(); pinned_heap_handles.clear(); - allocator.reset(); } idx_t TupleDataSegment::ChunkCount() const { @@ -131,18 +128,19 @@ void TupleDataSegment::Unpin() { void TupleDataSegment::Verify() const { #ifdef D_ASSERT_IS_ENABLED - const auto &layout = allocator->GetLayout(); + const auto &allocator_layout = allocator->GetLayout(); idx_t total_count = 0; idx_t total_size = 0; - for (const auto &chunk : chunks) { + for (const auto &chunk_ptr : chunks) { + const auto &chunk = *chunk_ptr; chunk.Verify(*this); total_count += chunk.count; - total_size += chunk.count * layout.GetRowWidth(); - if (!layout.AllConstant()) { + total_size += chunk.count * allocator_layout.GetRowWidth(); + if (!allocator_layout.AllConstant()) { for (auto part_id = chunk.part_ids.Start(); part_id < chunk.part_ids.End(); part_id++) { - total_size += chunk_parts[part_id].total_heap_size; + total_size += chunk_parts[part_id]->total_heap_size; } } } diff --git a/src/duckdb/src/common/types/selection_vector.cpp b/src/duckdb/src/common/types/selection_vector.cpp index 145b6bfa1..a1232340c 100644 --- a/src/duckdb/src/common/types/selection_vector.cpp +++ 
b/src/duckdb/src/common/types/selection_vector.cpp @@ -50,6 +50,14 @@ buffer_ptr SelectionVector::Slice(const SelectionVector &sel, idx return data; } +idx_t SelectionVector::SliceInPlace(const SelectionVector &source, idx_t count) { + for (idx_t i = 0; i < count; ++i) { + set_index(i, get_index(source.get_index(i))); + } + + return count; +} + void SelectionVector::Verify(idx_t count, idx_t vector_size) const { #ifdef DEBUG D_ASSERT(vector_size >= 1); diff --git a/src/duckdb/src/common/types/string_type.cpp b/src/duckdb/src/common/types/string_type.cpp index f5a236557..bea85327a 100644 --- a/src/duckdb/src/common/types/string_type.cpp +++ b/src/duckdb/src/common/types/string_type.cpp @@ -6,6 +6,8 @@ #include "utf8proc_wrapper.hpp" namespace duckdb { +constexpr idx_t string_t::MAX_STRING_SIZE; +constexpr idx_t string_t::INLINE_LENGTH; void string_t::Verify() const { #ifdef DEBUG diff --git a/src/duckdb/src/common/types/value.cpp b/src/duckdb/src/common/types/value.cpp index 2bef3a82d..f1352331f 100644 --- a/src/duckdb/src/common/types/value.cpp +++ b/src/duckdb/src/common/types/value.cpp @@ -919,6 +919,14 @@ Value Value::BIGNUM(const string &data) { return result; } +Value Value::GEOMETRY(const_data_ptr_t data, idx_t len) { + Value result; + result.type_ = LogicalType::GEOMETRY(); // construct type explicitly so that we get the ExtraTypeInfo + result.is_null = false; + result.value_info_ = make_shared_ptr(string(const_char_ptr_cast(data), len)); + return result; +} + Value Value::BLOB(const string &data) { Value result(LogicalType::BLOB); result.is_null = false; diff --git a/src/duckdb/src/common/types/vector.cpp b/src/duckdb/src/common/types/vector.cpp index ad27b162d..070907c72 100644 --- a/src/duckdb/src/common/types/vector.cpp +++ b/src/duckdb/src/common/types/vector.cpp @@ -32,7 +32,8 @@ namespace duckdb { UnifiedVectorFormat::UnifiedVectorFormat() : sel(nullptr), data(nullptr), physical_type(PhysicalType::INVALID) { } -UnifiedVectorFormat::UnifiedVectorFormat(UnifiedVectorFormat &&other) noexcept : sel(nullptr), data(nullptr) { +UnifiedVectorFormat::UnifiedVectorFormat(UnifiedVectorFormat &&other) noexcept + : sel(nullptr), data(nullptr), physical_type(PhysicalType::INVALID) { bool refers_to_self = other.sel == &other.owned_sel; std::swap(sel, other.sel); std::swap(data, other.data); @@ -96,8 +97,7 @@ Vector::Vector(const Value &value) : type(value.type()) { Vector::Vector(Vector &&other) noexcept : vector_type(other.vector_type), type(std::move(other.type)), data(other.data), - validity(std::move(other.validity)), buffer(std::move(other.buffer)), auxiliary(std::move(other.auxiliary)), - cached_hashes(std::move(other.cached_hashes)) { + validity(std::move(other.validity)), buffer(std::move(other.buffer)), auxiliary(std::move(other.auxiliary)) { } void Vector::Reference(const Value &value) { @@ -171,7 +171,6 @@ void Vector::Reinterpret(const Vector &other) { auxiliary = make_shared_ptr(std::move(new_vector)); } else { AssignSharedPointer(auxiliary, other.auxiliary); - AssignSharedPointer(cached_hashes, other.cached_hashes); } data = other.data; validity = other.validity; @@ -235,6 +234,9 @@ void Vector::Slice(const Vector &other, const SelectionVector &sel, idx_t count) } void Vector::Slice(const SelectionVector &sel, idx_t count) { + if (!sel.IsSet() || count == 0) { + return; // Nothing to do here + } if (GetVectorType() == VectorType::CONSTANT_VECTOR) { // dictionary on a constant is just a constant return; @@ -276,7 +278,6 @@ void Vector::Slice(const SelectionVector &sel, idx_t 
count) { vector_type = VectorType::DICTIONARY_VECTOR; buffer = std::move(dict_buffer); auxiliary = std::move(child_ref); - cached_hashes.reset(); } void Vector::Dictionary(idx_t dictionary_size, const SelectionVector &sel, idx_t count) { @@ -287,15 +288,25 @@ void Vector::Dictionary(idx_t dictionary_size, const SelectionVector &sel, idx_t } void Vector::Dictionary(Vector &dict, idx_t dictionary_size, const SelectionVector &sel, idx_t count) { - if (DictionaryVector::CanCacheHashes(dict.GetType()) && !dict.cached_hashes) { - // Create an empty hash vector for this dictionary, potentially to be used for caching hashes later - // This needs to happen here, as we need to add "cached_hashes" to the original input Vector "dict" - dict.cached_hashes = make_buffer(Vector(LogicalType::HASH, false, false, 0)); - } Reference(dict); Dictionary(dictionary_size, sel, count); } +void Vector::Dictionary(buffer_ptr reusable_dict, const SelectionVector &sel) { + D_ASSERT(type.InternalType() != PhysicalType::STRUCT); + D_ASSERT(type == reusable_dict->data.GetType()); + vector_type = VectorType::DICTIONARY_VECTOR; + data = reusable_dict->data.data; + validity.Reset(); + + auto dict_buffer = make_buffer(sel); + dict_buffer->SetDictionarySize(reusable_dict->size.GetIndex()); + dict_buffer->SetDictionaryId(reusable_dict->id); + buffer = std::move(dict_buffer); + + auxiliary = std::move(reusable_dict); +} + void Vector::Slice(const SelectionVector &sel, idx_t count, SelCache &cache) { if (GetVectorType() == VectorType::DICTIONARY_VECTOR && GetType().InternalType() != PhysicalType::STRUCT) { // dictionary vector: need to merge dictionaries @@ -353,7 +364,6 @@ void Vector::Initialize(bool initialize_to_zero, idx_t capacity) { } void Vector::FindResizeInfos(vector &resize_infos, const idx_t multiplier) { - ResizeInfo resize_info(*this, data, buffer.get(), multiplier); resize_infos.emplace_back(resize_info); @@ -724,6 +734,10 @@ Value Vector::GetValueInternal(const Vector &v_p, idx_t index_p) { auto str = reinterpret_cast(data)[index]; return Value::BIGNUM(const_data_ptr_cast(str.data.GetData()), str.data.GetSize()); } + case LogicalTypeId::GEOMETRY: { + auto str = reinterpret_cast(data)[index]; + return Value::GEOMETRY(const_data_ptr_cast(str.GetData()), str.GetSize()); + } case LogicalTypeId::AGGREGATE_STATE: { auto str = reinterpret_cast(data)[index]; return Value::AGGREGATE_STATE(vector->GetType(), const_data_ptr_cast(str.GetData()), str.GetSize()); @@ -802,7 +816,6 @@ Value Vector::GetValue(const Vector &v_p, idx_t index_p) { value.GetTypeMutable().CopyAuxInfo(v_p.GetType()); } if (v_p.GetType().id() != LogicalTypeId::AGGREGATE_STATE && value.type().id() != LogicalTypeId::AGGREGATE_STATE) { - D_ASSERT(v_p.GetType() == value.type()); } return value; @@ -1219,7 +1232,6 @@ void Vector::ToUnifiedFormat(idx_t count, UnifiedVectorFormat &format) { } void Vector::RecursiveToUnifiedFormat(Vector &input, idx_t count, RecursiveUnifiedVectorFormat &data) { - input.ToUnifiedFormat(count, data.unified); data.logical_type = input.GetType(); @@ -1846,7 +1858,9 @@ void Vector::DebugTransformToDictionary(Vector &vector, idx_t count) { inverted_sel.set_index(offset++, current_index); inverted_sel.set_index(offset++, current_index); } - Vector inverted_vector(vector, inverted_sel, verify_count); + auto reusable_dict = DictionaryVector::CreateReusableDictionary(vector.type, verify_count); + auto &inverted_vector = reusable_dict->data; + inverted_vector.Slice(vector, inverted_sel, verify_count); 
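A hedged usage sketch of the reusable-dictionary path introduced in this change (it assumes DuckDB's internal headers and only the signatures visible in this diff, so it is an illustration rather than the canonical API): the dictionary payload is allocated once via DictionaryVector::CreateReusableDictionary, filled, and then any number of result vectors can reference it through a selection vector, which is what lets the dictionary id and the cached hashes be shared.

#include "duckdb/common/types/vector.hpp" // assumed internal header

namespace duckdb {

void ReusableDictionarySketch(Vector &result, const SelectionVector &sel, const LogicalType &type,
                              idx_t dict_size) {
	// allocate the shared dictionary once; it carries its own size and id
	auto dict = DictionaryVector::CreateReusableDictionary(type, dict_size);
	// ... fill dict->data with the dict_size distinct values (e.g. via VectorOperations::Copy) ...
	// point the result at the shared dictionary; rows are mapped through 'sel'
	// (per this diff, the reusable path is not used for STRUCT types)
	result.Dictionary(dict, sel);
}

} // namespace duckdb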
inverted_vector.Flatten(verify_count); // now insert the NULL values at every other position for (idx_t i = 0; i < count; i++) { @@ -1860,8 +1874,13 @@ void Vector::DebugTransformToDictionary(Vector &vector, idx_t count) { original_sel.set_index(offset++, verify_count - 1 - i * 2); } // now slice the inverted vector with the inverted selection vector - vector.Dictionary(inverted_vector, verify_count, original_sel, count); - DictionaryVector::SetDictionaryId(vector, UUID::ToString(UUID::GenerateRandomUUID())); + if (vector.GetType().InternalType() == PhysicalType::STRUCT) { + // Reusable dictionary API does not work for STRUCT + vector.Dictionary(inverted_vector, verify_count, original_sel, count); + vector.buffer->Cast().SetDictionaryId(reusable_dict->id); + } else { + vector.Dictionary(reusable_dict, original_sel); + } vector.Verify(count); } @@ -1922,17 +1941,27 @@ void Vector::DebugShuffleNestedVector(Vector &vector, idx_t count) { //===--------------------------------------------------------------------===// // DictionaryVector //===--------------------------------------------------------------------===// +buffer_ptr DictionaryVector::CreateReusableDictionary(const LogicalType &type, const idx_t &size) { + auto res = make_buffer(Vector(type, size)); + res->size = size; + res->id = UUID::ToString(UUID::GenerateRandomUUID()); + return res; +} + const Vector &DictionaryVector::GetCachedHashes(Vector &input) { D_ASSERT(CanCacheHashes(input)); - auto &dictionary = Child(input); - auto &dictionary_hashes = dictionary.cached_hashes->Cast().data; - if (!dictionary_hashes.data) { + + auto &child = input.auxiliary->Cast(); + lock_guard guard(child.cached_hashes_lock); + + if (!child.cached_hashes.data) { // Uninitialized: hash the dictionary - const auto dictionary_count = DictionarySize(input).GetIndex(); - dictionary_hashes.Initialize(false, dictionary_count); - VectorOperations::Hash(dictionary, dictionary_hashes, dictionary_count); + const auto dictionary_size = DictionarySize(input).GetIndex(); + D_ASSERT(!child.size.IsValid() || child.size.GetIndex() == dictionary_size); + child.cached_hashes.Initialize(false, dictionary_size); + VectorOperations::Hash(child.data, child.cached_hashes, dictionary_size); } - return dictionary.cached_hashes->Cast().data; + return child.cached_hashes; } //===--------------------------------------------------------------------===// @@ -2317,7 +2346,6 @@ const Vector &MapVector::GetValues(const Vector &vector) { } MapInvalidReason MapVector::CheckMapValidity(Vector &map, idx_t count, const SelectionVector &sel) { - D_ASSERT(map.GetType().id() == LogicalTypeId::MAP); // unify the MAP vector, which is a physical LIST vector @@ -2332,7 +2360,6 @@ MapInvalidReason MapVector::CheckMapValidity(Vector &map, idx_t count, const Sel keys.ToUnifiedFormat(maps_length, key_data); for (idx_t row_idx = 0; row_idx < count; row_idx++) { - auto mapped_row = sel.get_index(row_idx); auto map_idx = map_data.sel->get_index(mapped_row); @@ -2503,7 +2530,6 @@ void ListVector::PushBack(Vector &target, const Value &insert) { } idx_t ListVector::GetConsecutiveChildList(Vector &list, Vector &result, idx_t offset, idx_t count) { - auto info = ListVector::GetConsecutiveChildListInfo(list, offset, count); if (info.needs_slicing) { SelectionVector sel(info.child_list_info.length); @@ -2516,7 +2542,6 @@ idx_t ListVector::GetConsecutiveChildList(Vector &list, Vector &result, idx_t of } ConsecutiveChildListInfo ListVector::GetConsecutiveChildListInfo(Vector &list, idx_t offset, idx_t count) { - 
ConsecutiveChildListInfo info; UnifiedVectorFormat unified_list_data; list.ToUnifiedFormat(offset + count, unified_list_data); diff --git a/src/duckdb/src/common/value_operations/comparison_operations.cpp b/src/duckdb/src/common/value_operations/comparison_operations.cpp index ac4c88c01..775757c59 100644 --- a/src/duckdb/src/common/value_operations/comparison_operations.cpp +++ b/src/duckdb/src/common/value_operations/comparison_operations.cpp @@ -141,6 +141,11 @@ static bool TemplatedBooleanOperation(const Value &left, const Value &right) { auto &right_children = StructValue::GetChildren(right); // this should be enforced by the type D_ASSERT(left_children.size() == right_children.size()); + if (left_children.empty()) { + const auto const_true = Value::BOOLEAN(true); + return ValuePositionComparator::Final(const_true, const_true); + } + idx_t i = 0; for (; i < left_children.size() - 1; ++i) { if (ValuePositionComparator::Definite(left_children[i], right_children[i])) { diff --git a/src/duckdb/src/common/vector_operations/is_distinct_from.cpp b/src/duckdb/src/common/vector_operations/is_distinct_from.cpp index e57f9738d..d08a2d276 100644 --- a/src/duckdb/src/common/vector_operations/is_distinct_from.cpp +++ b/src/duckdb/src/common/vector_operations/is_distinct_from.cpp @@ -289,6 +289,7 @@ template idx_t DistinctSelect(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel, optional_ptr null_mask) { if (!sel) { + D_ASSERT(count <= STANDARD_VECTOR_SIZE); sel = FlatVector::IncrementalSelectionVector(); } @@ -468,7 +469,6 @@ using StructEntries = vector>; void ExtractNestedSelection(const SelectionVector &slice_sel, const idx_t count, const SelectionVector &sel, OptionalSelection &opt) { - for (idx_t i = 0; i < count;) { const auto slice_idx = slice_sel.get_index(i); const auto result_idx = sel.get_index(slice_idx); @@ -478,21 +478,21 @@ void ExtractNestedSelection(const SelectionVector &slice_sel, const idx_t count, } void ExtractNestedMask(const SelectionVector &slice_sel, const idx_t count, const SelectionVector &sel, - ValidityMask *child_mask, optional_ptr null_mask) { - - if (!child_mask) { + ValidityMask *child_mask_p, optional_ptr null_mask) { + if (!child_mask_p) { return; } + auto &child_mask = *child_mask_p; for (idx_t i = 0; i < count; ++i) { const auto slice_idx = slice_sel.get_index(i); const auto result_idx = sel.get_index(slice_idx); - if (child_mask && !child_mask->RowIsValid(slice_idx)) { + if (!child_mask.RowIsValid(slice_idx)) { null_mask->SetInvalid(result_idx); } } - child_mask->Reset(null_mask->Capacity()); + child_mask.Reset(null_mask->Capacity()); } void DensifyNestedSelection(const SelectionVector &dense_sel, const idx_t count, SelectionVector &slice_sel) { @@ -767,8 +767,6 @@ idx_t DistinctSelectArray(Vector &left, Vector &right, idx_t count, const Select return count; } - // FIXME: This function can probably be optimized since we know the array size is fixed for every entry. 
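The rewritten DistinctSelectArray loop just below follows the usual lexicographic recipe for fixed-size arrays: walk the positions, settle rows that compare definitely true or definitely false, and only apply the tie-break once every position of the still-undecided rows has been consumed (equal-length arrays make the old per-row exhaustion check unnecessary). A scalar sketch of that recipe for a "<="-style comparison, using plain std::vector in place of DuckDB vectors and selection vectors:

#include <vector>

// "is lhs <= rhs" lexicographically, assuming lhs.size() == rhs.size() as ARRAY types guarantee
bool ArrayLessEqualSketch(const std::vector<int> &lhs, const std::vector<int> &rhs) {
	for (size_t pos = 0; pos < lhs.size(); pos++) {
		if (lhs[pos] < rhs[pos]) {
			return true; // definitely ordered at this position
		}
		if (lhs[pos] > rhs[pos]) {
			return false; // definitely not ordered at this position
		}
		// equal at this position: still undecided, move to the next position
	}
	// every position compared equal: the tie-break decides ("<=" is true here, "<" would be false)
	return true;
}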
- D_ASSERT(ArrayType::GetSize(left.GetType()) == ArrayType::GetSize(right.GetType())); auto array_size = ArrayType::GetSize(left.GetType()); @@ -808,39 +806,13 @@ idx_t DistinctSelectArray(Vector &left, Vector &right, idx_t count, const Select } idx_t match_count = 0; - for (idx_t pos = 0; count > 0; ++pos) { + for (idx_t pos = 0; pos < array_size && count > 0; ++pos) { // Set up the cursors for the current position PositionArrayCursor(lcursor, lvdata, pos, slice_sel, count, array_size); PositionArrayCursor(rcursor, rvdata, pos, slice_sel, count, array_size); - // Tie-break the pairs where one of the LISTs is exhausted. idx_t true_count = 0; idx_t false_count = 0; - idx_t maybe_count = 0; - for (idx_t i = 0; i < count; ++i) { - const auto slice_idx = slice_sel.get_index(i); - if (array_size == pos) { - const auto idx = sel.get_index(slice_idx); - if (PositionComparator::TieBreak(array_size, array_size)) { - true_opt.Append(true_count, idx); - } else { - false_opt.Append(false_count, idx); - } - } else { - true_sel.set_index(maybe_count++, slice_idx); - } - } - true_opt.Advance(true_count); - false_opt.Advance(false_count); - match_count += true_count; - - // Redensify the list cursors - if (maybe_count < count) { - count = maybe_count; - DensifyNestedSelection(true_sel, count, slice_sel); - PositionArrayCursor(lcursor, lvdata, pos, slice_sel, count, array_size); - PositionArrayCursor(rcursor, rvdata, pos, slice_sel, count, array_size); - } // Find everything that definitely matches true_count = @@ -878,6 +850,15 @@ idx_t DistinctSelectArray(Vector &left, Vector &right, idx_t count, const Select count = true_count; } + if (count > 0) { + if (PositionComparator::TieBreak(array_size, array_size)) { + ExtractNestedSelection(slice_sel, count, sel, true_opt); + match_count += count; + } else { + ExtractNestedSelection(slice_sel, count, sel, false_opt); + } + } + return match_count; } @@ -890,6 +871,7 @@ idx_t DistinctSelectNested(Vector &left, Vector &right, optional_ptr(l_not_null, r_not_null, count, match_count, *sel, maybe_vec, true_opt, false_opt, null_mask); - switch (left.GetType().InternalType()) { + auto &left_type = left.GetType(); + switch (left_type.InternalType()) { case PhysicalType::LIST: match_count += DistinctSelectList(l_not_null, r_not_null, unknown, maybe_vec, true_opt, false_opt, null_mask); @@ -1009,7 +992,6 @@ template idx_t TemplatedDistinctSelectOperation(Vector &left, Vector &right, optional_ptr sel, idx_t count, optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask) { - switch (left.GetType().InternalType()) { case PhysicalType::BOOL: case PhysicalType::INT8: diff --git a/src/duckdb/src/common/vector_operations/vector_copy.cpp b/src/duckdb/src/common/vector_operations/vector_copy.cpp index af75d56b9..2b333bc99 100644 --- a/src/duckdb/src/common/vector_operations/vector_copy.cpp +++ b/src/duckdb/src/common/vector_operations/vector_copy.cpp @@ -39,7 +39,6 @@ static const ValidityMask &ExtractValidityMask(const Vector &v) { void VectorOperations::Copy(const Vector &source_p, Vector &target, const SelectionVector &sel_p, idx_t source_count, idx_t source_offset, idx_t target_offset, idx_t copy_count) { - SelectionVector owned_sel; const SelectionVector *sel = &sel_p; diff --git a/src/duckdb/src/common/virtual_file_system.cpp b/src/duckdb/src/common/virtual_file_system.cpp index 7940d0120..bae620775 100644 --- a/src/duckdb/src/common/virtual_file_system.cpp +++ b/src/duckdb/src/common/virtual_file_system.cpp @@ -34,8 +34,9 @@ unique_ptr 
VirtualFileSystem::OpenFileExtended(const OpenFileInfo &f } } // open the base file handle in UNCOMPRESSED mode + flags.SetCompression(FileCompressionType::UNCOMPRESSED); - auto file_handle = FindFileSystem(file.path).OpenFile(file, flags, opener); + auto file_handle = FindFileSystem(file.path, opener).OpenFile(file, flags, opener); if (!file_handle) { return nullptr; } @@ -111,7 +112,7 @@ void VirtualFileSystem::RemoveDirectory(const string &directory, optional_ptr &callback, optional_ptr opener) { - return FindFileSystem(directory).ListFiles(directory, callback, opener); + return FindFileSystem(directory, opener).ListFiles(directory, callback, opener); } void VirtualFileSystem::MoveFile(const string &source, const string &target, optional_ptr opener) { @@ -119,7 +120,7 @@ void VirtualFileSystem::MoveFile(const string &source, const string &target, opt } bool VirtualFileSystem::FileExists(const string &filename, optional_ptr opener) { - return FindFileSystem(filename).FileExists(filename, opener); + return FindFileSystem(filename, opener).FileExists(filename, opener); } bool VirtualFileSystem::IsPipe(const string &filename, optional_ptr opener) { @@ -139,7 +140,7 @@ string VirtualFileSystem::PathSeparator(const string &path) { } vector VirtualFileSystem::Glob(const string &path, FileOpener *opener) { - return FindFileSystem(path).Glob(path, opener); + return FindFileSystem(path, opener).Glob(path, opener); } void VirtualFileSystem::RegisterSubSystem(unique_ptr fs) { @@ -224,16 +225,61 @@ bool VirtualFileSystem::SubSystemIsDisabled(const string &name) { return disabled_file_systems.find(name) != disabled_file_systems.end(); } +FileSystem &VirtualFileSystem::FindFileSystem(const string &path, optional_ptr opener) { + return FindFileSystem(path, FileOpener::TryGetDatabase(opener)); +} + +FileSystem &VirtualFileSystem::FindFileSystem(const string &path, optional_ptr db_instance) { + auto fs = FindFileSystemInternal(path); + + if (!fs && db_instance) { + string required_extension; + + for (const auto &entry : EXTENSION_FILE_PREFIXES) { + if (StringUtil::StartsWith(path, entry.name)) { + required_extension = entry.extension; + } + } + if (!required_extension.empty() && db_instance && !db_instance->ExtensionIsLoaded(required_extension)) { + auto &dbconfig = DBConfig::GetConfig(*db_instance); + if (!ExtensionHelper::CanAutoloadExtension(required_extension) || + !dbconfig.options.autoload_known_extensions) { + auto error_message = "File " + path + " requires the extension " + required_extension + " to be loaded"; + error_message = + ExtensionHelper::AddExtensionInstallHintToErrorMsg(*db_instance, error_message, required_extension); + throw MissingExtensionException(error_message); + } + // an extension is required to read this file, but it is not loaded - try to load it + ExtensionHelper::AutoLoadExtension(*db_instance, required_extension); + } + + // Retry after having autoloaded + fs = FindFileSystem(path); + } + + if (!fs) { + fs = default_fs; + } + if (!disabled_file_systems.empty() && disabled_file_systems.find(fs->GetName()) != disabled_file_systems.end()) { + throw PermissionException("File system %s has been disabled by configuration", fs->GetName()); + } + return *fs; +} + FileSystem &VirtualFileSystem::FindFileSystem(const string &path) { - auto &fs = FindFileSystemInternal(path); - if (!disabled_file_systems.empty() && disabled_file_systems.find(fs.GetName()) != disabled_file_systems.end()) { - throw PermissionException("File system %s has been disabled by configuration", fs.GetName()); 
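The new FindFileSystem overloads above resolve a path prefix to a required extension and, when no registered sub-filesystem can handle the path yet, try to autoload that extension before falling back to the default filesystem. A rough standalone sketch of that decision flow, with toy stand-ins (PrefixEntry, the prefix table, is_loaded, load_extension) for DuckDB's EXTENSION_FILE_PREFIXES and ExtensionHelper:

#include <functional>
#include <stdexcept>
#include <string>
#include <vector>

struct PrefixEntry {
	std::string prefix;
	std::string extension;
};

// toy prefix table; the real mapping lives in EXTENSION_FILE_PREFIXES
static const std::vector<PrefixEntry> kPrefixes = {{"s3://", "httpfs"}, {"https://", "httpfs"}};

std::string ResolveFileSystemSketch(const std::string &path, bool autoload_enabled,
                                    const std::function<bool(const std::string &)> &is_loaded,
                                    const std::function<void(const std::string &)> &load_extension) {
	// 1) figure out which extension (if any) the path prefix requires
	std::string required;
	for (const auto &entry : kPrefixes) {
		if (path.rfind(entry.prefix, 0) == 0) {
			required = entry.extension;
		}
	}
	// 2) if an extension is required but not loaded, autoload it or fail with a hint
	if (!required.empty() && !is_loaded(required)) {
		if (!autoload_enabled) {
			throw std::runtime_error("File " + path + " requires the extension " + required + " to be loaded");
		}
		load_extension(required);
	}
	// 3) the caller retries the sub-filesystem lookup after loading, else falls back to the default
	return required.empty() ? "default" : required;
}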
+ auto fs = FindFileSystemInternal(path); + if (!fs) { + fs = default_fs; + } + if (!disabled_file_systems.empty() && disabled_file_systems.find(fs->GetName()) != disabled_file_systems.end()) { + throw PermissionException("File system %s has been disabled by configuration", fs->GetName()); } - return fs; + return *fs; } -FileSystem &VirtualFileSystem::FindFileSystemInternal(const string &path) { +optional_ptr VirtualFileSystem::FindFileSystemInternal(const string &path) { FileSystem *fs = nullptr; + for (auto &sub_system : sub_systems) { if (sub_system->CanHandleFile(path)) { if (sub_system->IsManuallySet()) { @@ -245,7 +291,9 @@ FileSystem &VirtualFileSystem::FindFileSystemInternal(const string &path) { if (fs) { return *fs; } - return *default_fs; + + // We could use default_fs, that's on the caller + return nullptr; } } // namespace duckdb diff --git a/src/duckdb/src/execution/aggregate_hashtable.cpp b/src/duckdb/src/execution/aggregate_hashtable.cpp index 87c3bc661..af93d3caa 100644 --- a/src/duckdb/src/execution/aggregate_hashtable.cpp +++ b/src/duckdb/src/execution/aggregate_hashtable.cpp @@ -48,7 +48,6 @@ GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context_p, A : BaseAggregateHashTable(context_p, allocator, aggregate_objects_p, std::move(payload_types_p)), context(context_p), radix_bits(radix_bits), count(0), capacity(0), sink_count(0), skip_lookups(false), enable_hll(false), aggregate_allocator(make_shared_ptr(allocator)), state(*aggregate_allocator) { - // Append hash column to the end and initialise the row layout group_types_p.emplace_back(LogicalType::HASH); diff --git a/src/duckdb/src/execution/expression_executor.cpp b/src/duckdb/src/execution/expression_executor.cpp index ec11c1289..0278df506 100644 --- a/src/duckdb/src/execution/expression_executor.cpp +++ b/src/duckdb/src/execution/expression_executor.cpp @@ -181,6 +181,8 @@ void ExpressionExecutor::Verify(const Expression &expr, Vector &vector, idx_t co } else { VectorOperations::DefaultCast(vector, intermediate, count, true); } + intermediate.Verify(count); + //! FIXME: this is probably also where we want to test 'variant_normalize' Vector result(vector.GetType(), true, false, count); //! Then cast back into the original type @@ -190,6 +192,7 @@ void ExpressionExecutor::Verify(const Expression &expr, Vector &vector, idx_t co VectorOperations::DefaultCast(intermediate, result, count, true); } vector.Reference(result); + vector.Verify(count); } } @@ -227,7 +230,6 @@ void ExpressionExecutor::Execute(const Expression &expr, ExpressionState *state, // The result vector must be used for the first time, or must be reset. // Otherwise, the validity mask can contain previous (now incorrect) data. if (result.GetVectorType() == VectorType::FLAT_VECTOR) { - // We do not initialize vector caches for these expressions. 
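The Verify calls added above (intermediate.Verify and vector.Verify) bracket a debug-only round trip: the result is cast to an intermediate type and back, and both sides are checked for consistency. A tiny standalone illustration of that round-trip idea, using an int-to-string conversion in place of DuckDB's DefaultCast:

#include <cassert>
#include <string>

int RoundTripThroughString(int value) {
	std::string intermediate = std::to_string(value); // "cast" to the intermediate representation
	int result = std::stoi(intermediate);             // cast back to the original type
	assert(result == value);                          // verify the round trip is lossless
	return result;
}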
if (expr.GetExpressionClass() != ExpressionClass::BOUND_REF && expr.GetExpressionClass() != ExpressionClass::BOUND_CONSTANT && diff --git a/src/duckdb/src/execution/expression_executor/execute_comparison.cpp b/src/duckdb/src/execution/expression_executor/execute_comparison.cpp index 6e78de49c..fafb24d49 100644 --- a/src/duckdb/src/execution/expression_executor/execute_comparison.cpp +++ b/src/duckdb/src/execution/expression_executor/execute_comparison.cpp @@ -209,7 +209,6 @@ idx_t NestedSelector::Select(Vector &left, Vector &ri static inline idx_t SelectNotNull(Vector &left, Vector &right, const idx_t count, const SelectionVector &sel, SelectionVector &maybe_vec, OptionalSelection &false_opt, optional_ptr null_mask) { - UnifiedVectorFormat lvdata, rvdata; left.ToUnifiedFormat(count, lvdata); right.ToUnifiedFormat(count, rvdata); diff --git a/src/duckdb/src/execution/expression_executor/execute_function.cpp b/src/duckdb/src/execution/expression_executor/execute_function.cpp index a7e99287b..c0b84eddb 100644 --- a/src/duckdb/src/execution/expression_executor/execute_function.cpp +++ b/src/duckdb/src/execution/expression_executor/execute_function.cpp @@ -71,7 +71,7 @@ bool ExecuteFunctionState::TryExecuteDictionaryExpression(const BoundFunctionExp return false; // Dictionary is too large, bail } - if (input_dictionary_id != current_input_dictionary_id) { + if (!output_dictionary || current_input_dictionary_id != input_dictionary_id) { // We haven't seen this dictionary before const auto chunk_fill_ratio = static_cast(args.size()) / STANDARD_VECTOR_SIZE; if (input_dictionary_size > STANDARD_VECTOR_SIZE && chunk_fill_ratio <= CHUNK_FILL_RATIO_THRESHOLD) { @@ -82,9 +82,8 @@ bool ExecuteFunctionState::TryExecuteDictionaryExpression(const BoundFunctionExp } // We can do dictionary optimization! 
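TryExecuteDictionaryExpression, shown here, exploits dictionary encoding: when the input is dictionary-encoded and the dictionary is not too large, the scalar function is evaluated once per distinct dictionary entry, and the per-row result simply reuses the input's selection vector over that output dictionary. A scalar sketch of the idea, with std::vector and an index array standing in for DuckDB vectors and selection vectors:

#include <cstddef>
#include <functional>
#include <vector>

std::vector<int> ExecuteOverDictionarySketch(const std::vector<int> &dictionary,
                                             const std::vector<size_t> &selection, // row -> dictionary index
                                             const std::function<int(int)> &fn) {
	// evaluate the function once per distinct dictionary value
	std::vector<int> dict_result(dictionary.size());
	for (size_t i = 0; i < dictionary.size(); i++) {
		dict_result[i] = fn(dictionary[i]);
	}
	// the per-row result just maps the selection vector over the computed dictionary
	std::vector<int> result(selection.size());
	for (size_t row = 0; row < selection.size(); row++) {
		result[row] = dict_result[selection[row]];
	}
	return result;
}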
Re-initialize + output_dictionary = DictionaryVector::CreateReusableDictionary(result.GetType(), input_dictionary_size); current_input_dictionary_id = input_dictionary_id; - output_dictionary = make_uniq(result.GetType(), input_dictionary_size); - output_dictionary_id = UUID::ToString(UUID::GenerateRandomUUID()); // Set up the input chunk DataChunk input_chunk; @@ -105,16 +104,14 @@ bool ExecuteFunctionState::TryExecuteDictionaryExpression(const BoundFunctionExp input_chunk.SetCardinality(count); // Execute, storing the result in an intermediate vector, and copying it to the output dictionary - Vector output_intermediate(output_dictionary->GetType()); + Vector output_intermediate(result.GetType()); expr.function.function(input_chunk, state, output_intermediate); - VectorOperations::Copy(output_intermediate, *output_dictionary, count, 0, offset); + VectorOperations::Copy(output_intermediate, output_dictionary->data, count, 0, offset); } } - // Create a dictionary result vector and give it an ID - const auto &input_sel_vector = DictionaryVector::SelVector(unary_input); - result.Dictionary(*output_dictionary, input_dictionary_size, input_sel_vector, args.size()); - DictionaryVector::SetDictionaryId(result, output_dictionary_id); + // Result references the dictionary + result.Dictionary(output_dictionary, DictionaryVector::SelVector(unary_input)); return true; } @@ -135,7 +132,7 @@ unique_ptr ExpressionExecutor::InitializeState(const BoundFunct static void VerifyNullHandling(const BoundFunctionExpression &expr, DataChunk &args, Vector &result) { #ifdef DEBUG - if (args.data.empty() || expr.function.null_handling != FunctionNullHandling::DEFAULT_NULL_HANDLING) { + if (args.data.empty() || expr.function.GetNullHandling() != FunctionNullHandling::DEFAULT_NULL_HANDLING) { return; } diff --git a/src/duckdb/src/execution/index/art/art.cpp b/src/duckdb/src/execution/index/art/art.cpp index 87c9cbf9b..f85538b0d 100644 --- a/src/duckdb/src/execution/index/art/art.cpp +++ b/src/duckdb/src/execution/index/art/art.cpp @@ -50,7 +50,6 @@ ART::ART(const string &name, const IndexConstraintType index_constraint_type, co const IndexStorageInfo &info) : BoundIndex(name, ART::TYPE_NAME, index_constraint_type, column_ids, table_io_manager, unbound_expressions, db), allocators(allocators_ptr), owns_data(false), verify_max_key_len(false) { - // FIXME: Use the new byte representation function to support nested types. for (idx_t i = 0; i < types.size(); i++) { switch (types[i]) { @@ -522,7 +521,9 @@ ErrorData ART::Insert(IndexLock &l, DataChunk &chunk, Vector &row_ids, IndexAppe if (keys[i].Empty()) { continue; } - D_ASSERT(ARTOperator::Lookup(*this, tree, keys[i], 0)); + auto leaf = ARTOperator::Lookup(*this, tree, keys[i], 0); + D_ASSERT(leaf); + D_ASSERT(ARTOperator::LookupInLeaf(*this, *leaf, row_id_keys[i])); } #endif return ErrorData(); @@ -602,8 +603,9 @@ void ART::Delete(IndexLock &state, DataChunk &input, Vector &row_ids) { continue; } auto leaf = ARTOperator::Lookup(*this, tree, keys[i], 0); - if (leaf && leaf->GetType() == NType::LEAF_INLINED) { - D_ASSERT(leaf->GetRowId() != row_id_keys[i].GetRowId()); + if (leaf) { + auto contains_row_id = ARTOperator::LookupInLeaf(*this, *leaf, row_id_keys[i]); + D_ASSERT(!contains_row_id); } } #endif @@ -634,7 +636,7 @@ bool ART::SearchGreater(ARTKey &key, bool equal, idx_t max_count, set &ro Iterator it(*this); // Early-out, if the maximum value in the ART is lower than the lower bound. 
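The execute_function.cpp hunk above replaces the hand-rolled output vector and UUID bookkeeping with a reusable output dictionary that is rebuilt only when the input dictionary id changes; the result then just references that dictionary through the input's selection vector. A standalone sketch of the idea with plain containers (illustrative names, not DuckDB's vector API):

```cpp
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>

struct DictionaryVector {
	std::string dictionary_id;       // stable id of the shared dictionary
	std::vector<int64_t> dictionary; // distinct values
	std::vector<uint32_t> sel;       // per-row index into `dictionary`
};

class DictionaryExecutor {
public:
	// Runs `fn` over the distinct dictionary values only, and only when the dictionary id changes.
	std::vector<int64_t> Execute(const DictionaryVector &input, const std::function<int64_t(int64_t)> &fn) {
		if (!cached_output || cached_id != input.dictionary_id) {
			cached_output = std::make_unique<std::vector<int64_t>>();
			cached_output->reserve(input.dictionary.size());
			for (auto value : input.dictionary) {
				cached_output->push_back(fn(value));
			}
			cached_id = input.dictionary_id;
		}
		// The per-row result is the selection vector mapped through the cached output dictionary.
		std::vector<int64_t> result;
		result.reserve(input.sel.size());
		for (auto idx : input.sel) {
			result.push_back((*cached_output)[idx]);
		}
		return result;
	}

private:
	std::string cached_id;
	std::unique_ptr<std::vector<int64_t>> cached_output;
};
```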
- if (!it.LowerBound(tree, key, equal, 0)) { + if (!it.LowerBound(tree, key, equal)) { return true; } @@ -667,7 +669,7 @@ bool ART::SearchCloseRange(ARTKey &lower_bound, ARTKey &upper_bound, bool left_e Iterator it(*this); // Early-out, if the maximum value in the ART is lower than the lower bound. - if (!it.LowerBound(tree, lower_bound, left_equal, 0)) { + if (!it.LowerBound(tree, lower_bound, left_equal)) { return true; } @@ -1047,7 +1049,7 @@ idx_t ART::GetInMemorySize(IndexLock &index_lock) { return in_memory_size; } -//===--------------------------------------------------------------------===// +//===-------------------------------------------------------------------===// // Vacuum //===--------------------------------------------------------------------===// @@ -1205,17 +1207,27 @@ bool ART::MergeIndexes(IndexLock &state, BoundIndex &other_index) { // Verification //===--------------------------------------------------------------------===// -string ART::VerifyAndToString(IndexLock &l, const bool only_verify) { - return VerifyAndToStringInternal(only_verify); +string ART::ToString(IndexLock &l, bool display_ascii) { + return ToStringInternal(display_ascii); } -string ART::VerifyAndToStringInternal(const bool only_verify) { +string ART::ToStringInternal(bool display_ascii) { if (tree.HasMetadata()) { - return "ART: " + tree.VerifyAndToString(*this, only_verify); + return "\nART: \n" + tree.ToString(*this, 0, false, display_ascii); } return "[empty]"; } +void ART::Verify(IndexLock &l) { + VerifyInternal(); +} + +void ART::VerifyInternal() { + if (tree.HasMetadata()) { + tree.Verify(*this); + } +} + void ART::VerifyAllocations(IndexLock &l) { return VerifyAllocationsInternal(); } diff --git a/src/duckdb/src/execution/index/art/art_merger.cpp b/src/duckdb/src/execution/index/art/art_merger.cpp index 70781cbfb..61d2ec317 100644 --- a/src/duckdb/src/execution/index/art/art_merger.cpp +++ b/src/duckdb/src/execution/index/art/art_merger.cpp @@ -217,9 +217,6 @@ void ARTMerger::MergeNodeAndPrefix(Node &node, Node &prefix, const GateStatus pa auto child = node.GetChildMutable(art, byte); // Reduce the prefix to the bytes after pos. - // We always reduce by at least one byte, - // thus, if the prefix was a gate, it no longer is. - prefix.SetGateStatus(GateStatus::GATE_NOT_SET); Prefix::Reduce(art, prefix, pos); if (child) { diff --git a/src/duckdb/src/execution/index/art/base_leaf.cpp b/src/duckdb/src/execution/index/art/base_leaf.cpp index a694ca3b5..4a9332fc9 100644 --- a/src/duckdb/src/execution/index/art/base_leaf.cpp +++ b/src/duckdb/src/execution/index/art/base_leaf.cpp @@ -30,8 +30,10 @@ void BaseLeaf::InsertByteInternal(BaseLeaf &n, const uint8_t byt } template -BaseLeaf &BaseLeaf::DeleteByteInternal(ART &art, Node &node, const uint8_t byte) { - auto &n = Node::Ref(art, node, node.GetType()); +NodeHandle> BaseLeaf::DeleteByteInternal(ART &art, Node &node, + const uint8_t byte) { + NodeHandle> handle(art, node); + auto &n = handle.Get(); uint8_t child_pos = 0; for (; child_pos < n.count; child_pos++) { @@ -45,7 +47,7 @@ BaseLeaf &BaseLeaf::DeleteByteInternal(ART &art, for (uint8_t i = child_pos; i < n.count; i++) { n.key[i] = n.key[i + 1]; } - return n; + return handle; } //===--------------------------------------------------------------------===// @@ -53,27 +55,36 @@ BaseLeaf &BaseLeaf::DeleteByteInternal(ART &art, //===--------------------------------------------------------------------===// void Node7Leaf::InsertByte(ART &art, Node &node, const uint8_t byte) { - // The node is full. 
Grow to Node15. - auto &n7 = Node::Ref(art, node, NODE_7_LEAF); - if (n7.count == CAPACITY) { - auto node7 = node; - Node15Leaf::GrowNode7Leaf(art, node, node7); - Node15Leaf::InsertByte(art, node, byte); - return; - } + { + NodeHandle handle(art, node); + auto &n7 = handle.Get(); - InsertByteInternal(n7, byte); + if (n7.count != CAPACITY) { + InsertByteInternal(n7, byte); + return; + } + } + // The node is full. Grow to Node15. + auto node7 = node; + Node15Leaf::GrowNode7Leaf(art, node, node7); + Node15Leaf::InsertByte(art, node, byte); } void Node7Leaf::DeleteByte(ART &art, Node &node, Node &prefix, const uint8_t byte, const ARTKey &row_id) { - auto &n7 = DeleteByteInternal(art, node, byte); + idx_t remainder; + { + auto n7_handle = DeleteByteInternal(art, node, byte); + auto &n7 = n7_handle.Get(); + + if (n7.count != 1) { + return; + } - // Compress one-way nodes. - if (n7.count == 1) { + // Compress one-way nodes. D_ASSERT(node.GetGateStatus() == GateStatus::GATE_NOT_SET); // Get the remaining row ID. - auto remainder = UnsafeNumericCast(row_id.GetRowId()) & AND_LAST_BYTE; + remainder = UnsafeNumericCast(row_id.GetRowId()) & AND_LAST_BYTE; remainder |= UnsafeNumericCast(n7.key[0]); // Free the prefix (nodes) and inline the remainder. @@ -82,23 +93,27 @@ void Node7Leaf::DeleteByte(ART &art, Node &node, Node &prefix, const uint8_t byt Leaf::New(prefix, UnsafeNumericCast(remainder)); return; } - - // Free the Node7Leaf and inline the remainder. - Node::FreeNode(art, node); - Leaf::New(node, UnsafeNumericCast(remainder)); } + // Free the Node7Leaf and inline the remainder. + Node::FreeNode(art, node); + Leaf::New(node, UnsafeNumericCast(remainder)); } void Node7Leaf::ShrinkNode15Leaf(ART &art, Node &node7_leaf, Node &node15_leaf) { - auto &n7 = New(art, node7_leaf); - auto &n15 = Node::Ref(art, node15_leaf, NType::NODE_15_LEAF); - node7_leaf.SetGateStatus(node15_leaf.GetGateStatus()); + { + auto n7_handle = New(art, node7_leaf); + auto &n7 = n7_handle.Get(); - n7.count = n15.count; - for (uint8_t i = 0; i < n15.count; i++) { - n7.key[i] = n15.key[i]; - } + NodeHandle n15_handle(art, node15_leaf); + auto &n15 = n15_handle.Get(); + node7_leaf.SetGateStatus(node15_leaf.GetGateStatus()); + + n7.count = n15.count; + for (uint8_t i = 0; i < n15.count; i++) { + n7.key[i] = n15.key[i]; + } + } Node::FreeNode(art, node15_leaf); } @@ -107,54 +122,66 @@ void Node7Leaf::ShrinkNode15Leaf(ART &art, Node &node7_leaf, Node &node15_leaf) //===--------------------------------------------------------------------===// void Node15Leaf::InsertByte(ART &art, Node &node, const uint8_t byte) { - // The node is full. Grow to Node256Leaf. - auto &n15 = Node::Ref(art, node, NODE_15_LEAF); - if (n15.count == CAPACITY) { - auto node15 = node; - Node256Leaf::GrowNode15Leaf(art, node, node15); - Node256Leaf::InsertByte(art, node, byte); - return; + { + NodeHandle n15_handle(art, node); + auto &n15 = n15_handle.Get(); + if (n15.count != CAPACITY) { + InsertByteInternal(n15, byte); + return; + } } - - InsertByteInternal(n15, byte); + auto node15 = node; + Node256Leaf::GrowNode15Leaf(art, node, node15); + Node256Leaf::InsertByte(art, node, byte); } void Node15Leaf::DeleteByte(ART &art, Node &node, const uint8_t byte) { - auto &n15 = DeleteByteInternal(art, node, byte); - - // Shrink node to Node7. 
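The base_leaf.cpp hunks wrap node access in a handle and restructure the insert/delete paths so the handle goes out of scope (the inner braces) before the node is grown, shrunk, or freed. A minimal sketch of that scoping discipline with made-up types standing in for the node handle:

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

// A small fixed-capacity leaf; "growing" replaces it with a larger node.
struct SmallLeaf {
	static constexpr size_t CAPACITY = 7;
	std::vector<uint8_t> bytes;
	int handles_open = 0; // emulates the allocator's pin count
};

// RAII handle: while it is alive, the leaf must not be replaced or freed.
class LeafHandle {
public:
	explicit LeafHandle(SmallLeaf &leaf) : leaf(leaf) {
		leaf.handles_open++;
	}
	~LeafHandle() {
		leaf.handles_open--;
	}
	SmallLeaf &Get() {
		return leaf;
	}

private:
	SmallLeaf &leaf;
};

void InsertByte(std::unique_ptr<SmallLeaf> &node, uint8_t byte) {
	{
		LeafHandle handle(*node);
		auto &leaf = handle.Get();
		if (leaf.bytes.size() < SmallLeaf::CAPACITY) {
			leaf.bytes.push_back(byte);
			return; // fast path: the handle is released by the early return
		}
	} // handle released here, before the node is replaced
	assert(node->handles_open == 0);
	auto grown = std::make_unique<SmallLeaf>();
	grown->bytes = node->bytes; // copy the old contents into the larger node
	grown->bytes.push_back(byte);
	node = std::move(grown);
}
```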
- if (n15.count < Node7Leaf::CAPACITY) { - auto node15 = node; - Node7Leaf::ShrinkNode15Leaf(art, node, node15); + { + auto n15_handle = DeleteByteInternal(art, node, byte); + auto &n15 = n15_handle.Get(); + if (n15.count >= Node7Leaf::CAPACITY) { + return; + } } + auto node15 = node; + Node7Leaf::ShrinkNode15Leaf(art, node, node15); } void Node15Leaf::GrowNode7Leaf(ART &art, Node &node15_leaf, Node &node7_leaf) { - auto &n7 = Node::Ref(art, node7_leaf, NType::NODE_7_LEAF); - auto &n15 = New(art, node15_leaf); - node15_leaf.SetGateStatus(node7_leaf.GetGateStatus()); + { + NodeHandle n7_handle(art, node7_leaf); + auto &n7 = n7_handle.Get(); - n15.count = n7.count; - for (uint8_t i = 0; i < n7.count; i++) { - n15.key[i] = n7.key[i]; - } + auto n15_handle = New(art, node15_leaf); + auto &n15 = n15_handle.Get(); + node15_leaf.SetGateStatus(node7_leaf.GetGateStatus()); + n15.count = n7.count; + for (uint8_t i = 0; i < n7.count; i++) { + n15.key[i] = n7.key[i]; + } + } Node::FreeNode(art, node7_leaf); } void Node15Leaf::ShrinkNode256Leaf(ART &art, Node &node15_leaf, Node &node256_leaf) { - auto &n15 = New(art, node15_leaf); - auto &n256 = Node::Ref(art, node256_leaf, NType::NODE_256_LEAF); - node15_leaf.SetGateStatus(node256_leaf.GetGateStatus()); - - ValidityMask mask(&n256.mask[0], Node256::CAPACITY); - for (uint16_t i = 0; i < Node256::CAPACITY; i++) { - if (mask.RowIsValid(i)) { - n15.key[n15.count] = UnsafeNumericCast(i); - n15.count++; + { + auto n15_handle = New(art, node15_leaf); + auto &n15 = n15_handle.Get(); + + NodeHandle n256_handle(art, node256_leaf); + auto &n256 = n256_handle.Get(); + + node15_leaf.SetGateStatus(node256_leaf.GetGateStatus()); + + ValidityMask mask(&n256.mask[0], Node256::CAPACITY); + for (uint16_t i = 0; i < Node256::CAPACITY; i++) { + if (mask.RowIsValid(i)) { + n15.key[n15.count] = UnsafeNumericCast(i); + n15.count++; + } } } - Node::FreeNode(art, node256_leaf); } diff --git a/src/duckdb/src/execution/index/art/iterator.cpp b/src/duckdb/src/execution/index/art/iterator.cpp index 1a88b7262..c8e2d09a9 100644 --- a/src/duckdb/src/execution/index/art/iterator.cpp +++ b/src/duckdb/src/execution/index/art/iterator.cpp @@ -95,125 +95,135 @@ bool Iterator::Scan(const ARTKey &upper_bound, const idx_t max_count, set } void Iterator::FindMinimum(const Node &node) { - D_ASSERT(node.HasMetadata()); + reference ref(node); - // Found the minimum. - if (node.IsAnyLeaf()) { - last_leaf = node; - return; - } + while (ref.get().HasMetadata()) { + // Found the minimum. + if (ref.get().IsAnyLeaf()) { + last_leaf = ref.get(); + return; + } - // We are passing a gate node. - if (node.GetGateStatus() == GateStatus::GATE_SET) { - D_ASSERT(status == GateStatus::GATE_NOT_SET); - status = GateStatus::GATE_SET; - entered_nested_leaf = true; - nested_depth = 0; - } + // We are passing a gate node. + if (ref.get().GetGateStatus() == GateStatus::GATE_SET) { + D_ASSERT(status == GateStatus::GATE_NOT_SET); + status = GateStatus::GATE_SET; + entered_nested_leaf = true; + nested_depth = 0; + } - // Traverse the prefix. - if (node.GetType() == NType::PREFIX) { - Prefix prefix(art, node); - for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { - current_key.Push(prefix.data[i]); - if (status == GateStatus::GATE_SET) { - row_id[nested_depth] = prefix.data[i]; - nested_depth++; - D_ASSERT(nested_depth < Prefix::ROW_ID_SIZE); + // Traverse the prefix. 
+ if (ref.get().GetType() == NType::PREFIX) { + Prefix prefix(art, ref.get()); + for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { + current_key.Push(prefix.data[i]); + if (status == GateStatus::GATE_SET) { + row_id[nested_depth] = prefix.data[i]; + nested_depth++; + D_ASSERT(nested_depth < Prefix::ROW_ID_SIZE); + } } + nodes.emplace(ref.get(), 0); + ref = *prefix.ptr; + continue; } - nodes.emplace(node, 0); - return FindMinimum(*prefix.ptr); - } - // Go to the leftmost entry in the current node. - uint8_t byte = 0; - auto next = node.GetNextChild(art, byte); - D_ASSERT(next); - - // Recurse on the leftmost node. - current_key.Push(byte); - if (status == GateStatus::GATE_SET) { - row_id[nested_depth] = byte; - nested_depth++; - D_ASSERT(nested_depth < Prefix::ROW_ID_SIZE); + // Go to the leftmost entry in the current node. + uint8_t byte = 0; + auto next = ref.get().GetNextChild(art, byte); + D_ASSERT(next); + + // Move to the leftmost node. + current_key.Push(byte); + if (status == GateStatus::GATE_SET) { + row_id[nested_depth] = byte; + nested_depth++; + D_ASSERT(nested_depth < Prefix::ROW_ID_SIZE); + } + nodes.emplace(ref.get(), byte); + ref = *next; } - nodes.emplace(node, byte); - FindMinimum(*next); + // Should always have a node with metadata. + throw InternalException("ART Iterator::FindMinimum: Reached node without metadata"); } -bool Iterator::LowerBound(const Node &node, const ARTKey &key, const bool equal, idx_t depth) { - if (!node.HasMetadata()) { - return false; - } +bool Iterator::LowerBound(const Node &node, const ARTKey &key, const bool equal) { + reference ref(node); + idx_t depth = 0; + + while (ref.get().HasMetadata()) { + // We found any leaf node, or a gate. + if (ref.get().IsAnyLeaf() || ref.get().GetGateStatus() == GateStatus::GATE_SET) { + D_ASSERT(status == GateStatus::GATE_NOT_SET); + D_ASSERT(current_key.Size() == key.len); + if (!equal && current_key.Contains(key)) { + return Next(); + } - // We found any leaf node, or a gate. - if (node.IsAnyLeaf() || node.GetGateStatus() == GateStatus::GATE_SET) { - D_ASSERT(status == GateStatus::GATE_NOT_SET); - D_ASSERT(current_key.Size() == key.len); - if (!equal && current_key.Contains(key)) { - return Next(); + if (ref.get().GetGateStatus() == GateStatus::GATE_SET) { + FindMinimum(ref.get()); + } else { + last_leaf = ref.get(); + } + return true; } - if (node.GetGateStatus() == GateStatus::GATE_SET) { - FindMinimum(node); - } else { - last_leaf = node; - } - return true; - } + D_ASSERT(ref.get().GetGateStatus() == GateStatus::GATE_NOT_SET); + if (ref.get().GetType() != NType::PREFIX) { + auto next_byte = key[depth]; + auto child = ref.get().GetNextChild(art, next_byte); - D_ASSERT(node.GetGateStatus() == GateStatus::GATE_NOT_SET); - if (node.GetType() != NType::PREFIX) { - auto next_byte = key[depth]; - auto child = node.GetNextChild(art, next_byte); + // The key is greater than any key in this subtree. + if (!child) { + return Next(); + } - // The key is greater than any key in this subtree. - if (!child) { - return Next(); - } + current_key.Push(next_byte); + nodes.emplace(ref.get(), next_byte); - current_key.Push(next_byte); - nodes.emplace(node, next_byte); + // We return the minimum because all keys are greater than the lower bound. + if (next_byte > key[depth]) { + FindMinimum(*child); + return true; + } - // We return the minimum because all keys are greater than the lower bound. - if (next_byte > key[depth]) { - FindMinimum(*child); - return true; + // Move to the child and increment depth. 
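Both FindMinimum and LowerBound above are rewritten from recursion into loops that rebind a `reference<const Node>` and track `depth` explicitly, which avoids deep call stacks on long prefix chains. The same transformation on a toy binary tree, with made-up types:

```cpp
#include <cstddef>
#include <functional>
#include <memory>

struct TreeNode {
	int value;
	std::unique_ptr<TreeNode> left;
	std::unique_ptr<TreeNode> right;
};

// Recursive version: descends the left spine.
const TreeNode &FindMinimumRecursive(const TreeNode &node) {
	if (!node.left) {
		return node;
	}
	return FindMinimumRecursive(*node.left);
}

// Iterative version: the recursion becomes a loop that rebinds a reference,
// mirroring the `ref = *child; continue;` pattern in the patch.
const TreeNode &FindMinimumIterative(const TreeNode &node) {
	std::reference_wrapper<const TreeNode> ref(node);
	size_t depth = 0; // explicit depth bookkeeping, as LowerBound now does
	while (ref.get().left) {
		ref = *ref.get().left;
		depth++;
	}
	(void)depth;
	return ref.get();
}
```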
+ ref = *child; + depth++; + continue; } - // We recurse into the child. - return LowerBound(*child, key, equal, depth + 1); - } - - // Push back all prefix bytes. - Prefix prefix(art, node); - for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { - current_key.Push(prefix.data[i]); - } - nodes.emplace(node, 0); - - // We compare the prefix bytes with the key bytes. - for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { - // We found a prefix byte that is less than its corresponding key byte. - // I.e., the subsequent node is lesser than the key. Thus, the next node - // is the lower bound. - if (prefix.data[i] < key[depth + i]) { - return Next(); + // Push back all prefix bytes. + Prefix prefix(art, ref.get()); + for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { + current_key.Push(prefix.data[i]); } + nodes.emplace(ref.get(), 0); - // We found a prefix byte that is greater than its corresponding key byte. - // I.e., the subsequent node is greater than the key. Thus, the minimum is - // the lower bound. - if (prefix.data[i] > key[depth + i]) { - FindMinimum(*prefix.ptr); - return true; + // We compare the prefix bytes with the key bytes. + for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { + // We found a prefix byte that is less than its corresponding key byte. + // I.e., the subsequent node is lesser than the key. Thus, the next node + // is the lower bound. + if (prefix.data[i] < key[depth + i]) { + return Next(); + } + + // We found a prefix byte that is greater than its corresponding key byte. + // I.e., the subsequent node is greater than the key. Thus, the minimum is + // the lower bound. + if (prefix.data[i] > key[depth + i]) { + FindMinimum(*prefix.ptr); + return true; + } } - } - // The prefix matches the key. We recurse into the child. - depth += prefix.data[Prefix::Count(art)]; - return LowerBound(*prefix.ptr, key, equal, depth); + // The prefix matches the key. Move to the child and update depth. + depth += prefix.data[Prefix::Count(art)]; + ref = *prefix.ptr; + } + // Should always have a node with metadata. + throw InternalException("ART Iterator::LowerBound: Reached node without metadata"); } bool Iterator::Next() { diff --git a/src/duckdb/src/execution/index/art/leaf.cpp b/src/duckdb/src/execution/index/art/leaf.cpp index f6c7751d6..31ef52b6c 100644 --- a/src/duckdb/src/execution/index/art/leaf.cpp +++ b/src/duckdb/src/execution/index/art/leaf.cpp @@ -162,7 +162,6 @@ bool Leaf::DeprecatedGetRowIds(ART &art, const Node &node, set &row_ids, reference ref(node); while (ref.get().HasMetadata()) { - auto &leaf = Node::Ref(art, ref, LEAF); if (row_ids.size() + leaf.count > max_count) { return false; @@ -191,15 +190,12 @@ void Leaf::DeprecatedVacuum(ART &art, Node &node) { } } -string Leaf::DeprecatedVerifyAndToString(ART &art, const Node &node, const bool only_verify) { - D_ASSERT(node.GetType() == LEAF); - +string Leaf::DeprecatedToString(ART &art, const Node &node) { string str = ""; reference ref(node); while (ref.get().HasMetadata()) { auto &leaf = Node::Ref(art, ref, LEAF); - D_ASSERT(leaf.count <= LEAF_SIZE); str += "Leaf [count: " + to_string(leaf.count) + ", row IDs: "; for (uint8_t i = 0; i < leaf.count; i++) { @@ -209,7 +205,19 @@ string Leaf::DeprecatedVerifyAndToString(ART &art, const Node &node, const bool ref = leaf.ptr; } - return only_verify ? 
"" : str; + return str; +} + +void Leaf::DeprecatedVerify(ART &art, const Node &node) { + D_ASSERT(node.GetType() == LEAF); + + reference ref(node); + + while (ref.get().HasMetadata()) { + auto &leaf = Node::Ref(art, ref, LEAF); + D_ASSERT(leaf.count <= LEAF_SIZE); + ref = leaf.ptr; + } } void Leaf::DeprecatedVerifyAllocations(ART &art, unordered_map &node_counts) const { diff --git a/src/duckdb/src/execution/index/art/node.cpp b/src/duckdb/src/execution/index/art/node.cpp index 478f18166..1cfca7e40 100644 --- a/src/duckdb/src/execution/index/art/node.cpp +++ b/src/duckdb/src/execution/index/art/node.cpp @@ -391,44 +391,29 @@ void Node::TransformToDeprecated(ART &art, Node &node, // Verification //===--------------------------------------------------------------------===// -string Node::VerifyAndToString(ART &art, const bool only_verify) const { +void Node::Verify(ART &art) const { D_ASSERT(HasMetadata()); auto type = GetType(); switch (type) { case NType::LEAF_INLINED: - return only_verify ? "" : "Inlined Leaf [row ID: " + to_string(GetRowId()) + "]"; + return; case NType::LEAF: - return Leaf::DeprecatedVerifyAndToString(art, *this, only_verify); + Leaf::DeprecatedVerify(art, *this); + return; case NType::PREFIX: { - auto str = Prefix::VerifyAndToString(art, *this, only_verify); - if (GetGateStatus() == GateStatus::GATE_SET) { - str = "Gate [ " + str + " ]"; - } - return only_verify ? "" : "\n" + str; + Prefix::Verify(art, *this); + return; } default: break; } - string str = "Node" + to_string(GetCapacity(type)) + ": [ "; - uint8_t byte = 0; - - if (IsLeafNode()) { - str = "Leaf " + str; - auto has_byte = GetNextByte(art, byte); - while (has_byte) { - str += to_string(byte) + "-"; - if (byte == NumericLimits::Maximum()) { - break; - } - byte++; - has_byte = GetNextByte(art, byte); - } - } else { + if (!IsLeafNode()) { + uint8_t byte = 0; auto child = GetNextChild(art, byte); while (child) { - str += "(" + to_string(byte) + ", " + child->VerifyAndToString(art, only_verify) + ")"; + child->Verify(art); if (byte == NumericLimits::Maximum()) { break; } @@ -436,11 +421,6 @@ string Node::VerifyAndToString(ART &art, const bool only_verify) const { child = GetNextChild(art, byte); } } - - if (GetGateStatus() == GateStatus::GATE_SET) { - str = "Gate [ " + str + " ]"; - } - return only_verify ? "" : "\n" + str + "]"; } void Node::VerifyAllocations(ART &art, unordered_map &node_counts) const { @@ -482,4 +462,87 @@ void Node::VerifyAllocations(ART &art, unordered_map &node_count scanner.Scan(handler); } +//===--------------------------------------------------------------------===// +// Printing +//===--------------------------------------------------------------------===// + +string Node::ToString(ART &art, idx_t indent_level, bool inside_gate, bool display_ascii) const { + auto indent = [](string &str, const idx_t n) { + for (idx_t i = 0; i < n; ++i) { + str += " "; + } + }; + // if inside gate, print byte values not ascii. 
+ auto format_byte = [&](uint8_t byte) { + if (!inside_gate && display_ascii && byte >= 32 && byte <= 126) { + return string(1, static_cast(byte)); + } + return to_string(byte); + }; + auto type = GetType(); + bool is_gate = GetGateStatus() == GateStatus::GATE_SET; + bool propagate_gate = inside_gate || is_gate; + + switch (type) { + case NType::LEAF_INLINED: { + string str = ""; + indent(str, indent_level); + return str + "Inlined Leaf [row ID: " + to_string(GetRowId()) + "]\n"; + } + case NType::LEAF: + return Leaf::DeprecatedToString(art, *this); + case NType::PREFIX: { + string str = Prefix::ToString(art, *this, indent_level, propagate_gate, display_ascii); + if (is_gate) { + string s = ""; + indent(s, indent_level); + s += "Gate\n"; + return s + str; + } + string s = ""; + return s + str; + } + default: + break; + } + string str = ""; + indent(str, indent_level); + str = str + "Node" + to_string(GetCapacity(type)) += "\n"; + uint8_t byte = 0; + + if (IsLeafNode()) { + indent(str, indent_level); + str += "Leaf |"; + auto has_byte = GetNextByte(art, byte); + while (has_byte) { + str += format_byte(byte) + "|"; + if (byte == NumericLimits::Maximum()) { + break; + } + byte++; + has_byte = GetNextByte(art, byte); + } + str += "\n"; + } else { + auto child = GetNextChild(art, byte); + while (child) { + string c = child->ToString(art, indent_level + 2, propagate_gate, display_ascii); + indent(str, indent_level); + str = str + format_byte(byte) + ",\n" + c; + if (byte == NumericLimits::Maximum()) { + break; + } + byte++; + child = GetNextChild(art, byte); + } + } + + if (is_gate) { + string s = ""; + indent(s, indent_level + 2); + str = "Gate\n" + s + str; + } + return str; +} + } // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/prefix.cpp b/src/duckdb/src/execution/index/art/prefix.cpp index 00e94967a..1368f3b33 100644 --- a/src/duckdb/src/execution/index/art/prefix.cpp +++ b/src/duckdb/src/execution/index/art/prefix.cpp @@ -100,6 +100,10 @@ void Prefix::Reduce(ART &art, Node &node, const idx_t pos) { D_ASSERT(node.HasMetadata()); D_ASSERT(pos < Count(art)); + // We always reduce by at least one byte, + // thus, if the prefix was a gate, it no longer is. 
+ node.SetGateStatus(GateStatus::GATE_NOT_SET); + Prefix prefix(art, node); if (pos == idx_t(prefix.data[Count(art)] - 1)) { auto next = *prefix.ptr; @@ -182,23 +186,41 @@ GateStatus Prefix::Split(ART &art, reference &node, Node &child, const uin return GateStatus::GATE_NOT_SET; } -string Prefix::VerifyAndToString(ART &art, const Node &node, const bool only_verify) { +string Prefix::ToString(ART &art, const Node &node, idx_t indent_level, bool inside_gate, bool display_ascii) { + auto indent = [](string &str, const idx_t n) { + for (idx_t i = 0; i < n; ++i) { + str += " "; + } + }; + auto format_byte = [&](uint8_t byte) { + if (!inside_gate && display_ascii && byte >= 32 && byte <= 126) { + return string(1, static_cast(byte)); + } + return to_string(byte); + }; string str = ""; + indent(str, indent_level); + reference ref(node); + Iterator(art, ref, true, false, [&](const Prefix &prefix) { + str += "Prefix: |"; + for (idx_t i = 0; i < prefix.data[Count(art)]; i++) { + str += format_byte(prefix.data[i]) + "|"; + } + }); + + auto child = ref.get().ToString(art, indent_level, inside_gate, display_ascii); + return str + "\n" + child; +} + +void Prefix::Verify(ART &art, const Node &node) { reference ref(node); Iterator(art, ref, true, false, [&](Prefix &prefix) { D_ASSERT(prefix.data[Count(art)] != 0); D_ASSERT(prefix.data[Count(art)] <= Count(art)); - - str += " Prefix :[ "; - for (idx_t i = 0; i < prefix.data[Count(art)]; i++) { - str += to_string(prefix.data[i]) + "-"; - } - str += " ] "; }); - auto child = ref.get().VerifyAndToString(art, only_verify); - return only_verify ? "" : str + child; + ref.get().Verify(art); } void Prefix::TransformToDeprecated(ART &art, Node &node, unsafe_unique_ptr &allocator) { diff --git a/src/duckdb/src/execution/index/bound_index.cpp b/src/duckdb/src/execution/index/bound_index.cpp index 2c0d43d91..7d215c421 100644 --- a/src/duckdb/src/execution/index/bound_index.cpp +++ b/src/duckdb/src/execution/index/bound_index.cpp @@ -18,7 +18,6 @@ BoundIndex::BoundIndex(const string &name, const string &index_type, IndexConstr const vector> &unbound_expressions_p, AttachedDatabase &db) : Index(column_ids, table_io_manager, db), name(name), index_type(index_type), index_constraint_type(index_constraint_type) { - for (auto &expr : unbound_expressions_p) { types.push_back(expr->return_type.InternalType()); logical_types.push_back(expr->return_type); @@ -79,10 +78,16 @@ bool BoundIndex::MergeIndexes(BoundIndex &other_index) { return MergeIndexes(state, other_index); } -string BoundIndex::VerifyAndToString(const bool only_verify) { +void BoundIndex::Verify() { + IndexLock l; + InitializeLock(l); + Verify(l); +} + +string BoundIndex::ToString(bool display_ascii) { IndexLock l; InitializeLock(l); - return VerifyAndToString(l, only_verify); + return ToString(l, display_ascii); } void BoundIndex::VerifyAllocations() { @@ -154,28 +159,39 @@ string BoundIndex::AppendRowError(DataChunk &input, idx_t index) { return error; } -void BoundIndex::ApplyBufferedAppends(const vector &table_types, ColumnDataCollection &buffered_appends, +void BoundIndex::ApplyBufferedReplays(const vector &table_types, + vector &buffered_replays, const vector &mapped_column_ids) { - IndexAppendInfo index_append_info(IndexAppendMode::INSERT_DUPLICATES, nullptr); - - ColumnDataScanState state; - buffered_appends.InitializeScan(state); - - DataChunk scan_chunk; - buffered_appends.InitializeScanChunk(scan_chunk); - DataChunk table_chunk; - table_chunk.InitializeEmpty(table_types); - - while 
(buffered_appends.Scan(state, scan_chunk)) { - for (idx_t i = 0; i < scan_chunk.ColumnCount() - 1; i++) { - auto col_id = mapped_column_ids[i].GetPrimaryIndex(); - table_chunk.data[col_id].Reference(scan_chunk.data[i]); - } - table_chunk.SetCardinality(scan_chunk.size()); - - auto error = Append(table_chunk, scan_chunk.data.back(), index_append_info); - if (error.HasError()) { - throw InternalException("error while applying buffered appends: " + error.Message()); + for (auto &replay : buffered_replays) { + ColumnDataScanState state; + auto &buffered_data = *replay.data; + buffered_data.InitializeScan(state); + + DataChunk scan_chunk; + buffered_data.InitializeScanChunk(scan_chunk); + DataChunk table_chunk; + table_chunk.InitializeEmpty(table_types); + + while (buffered_data.Scan(state, scan_chunk)) { + for (idx_t i = 0; i < scan_chunk.ColumnCount() - 1; i++) { + auto col_id = mapped_column_ids[i].GetPrimaryIndex(); + table_chunk.data[col_id].Reference(scan_chunk.data[i]); + } + table_chunk.SetCardinality(scan_chunk.size()); + + switch (replay.type) { + case BufferedIndexReplay::INSERT_ENTRY: { + IndexAppendInfo index_append_info(IndexAppendMode::INSERT_DUPLICATES, nullptr); + auto error = Append(table_chunk, scan_chunk.data.back(), index_append_info); + if (error.HasError()) { + throw InternalException("error while applying buffered appends: " + error.Message()); + } + continue; + } + case BufferedIndexReplay::DEL_ENTRY: { + Delete(table_chunk, scan_chunk.data.back()); + } + } } } } diff --git a/src/duckdb/src/execution/index/fixed_size_allocator.cpp b/src/duckdb/src/execution/index/fixed_size_allocator.cpp index dd4758bb9..3b1572d75 100644 --- a/src/duckdb/src/execution/index/fixed_size_allocator.cpp +++ b/src/duckdb/src/execution/index/fixed_size_allocator.cpp @@ -4,10 +4,9 @@ namespace duckdb { -FixedSizeAllocator::FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager) - : block_manager(block_manager), buffer_manager(block_manager.buffer_manager), segment_size(segment_size), - total_segment_count(0) { - +FixedSizeAllocator::FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager, MemoryTag memory_tag) + : block_manager(block_manager), buffer_manager(block_manager.buffer_manager), memory_tag(memory_tag), + segment_size(segment_size), total_segment_count(0) { if (segment_size > block_manager.GetBlockSize() - sizeof(validity_t)) { throw InternalException("The maximum segment size of fixed-size allocators is " + to_string(block_manager.GetBlockSize() - sizeof(validity_t))); @@ -48,7 +47,7 @@ IndexPointer FixedSizeAllocator::New() { if (!buffer_with_free_space.IsValid()) { // Add a new buffer. 
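The fixed_size_allocator.cpp change threads a MemoryTag through the allocator into every FixedSizeBuffer it creates, so index memory is accounted under the caller's tag instead of a hard-coded ART_INDEX. A small standalone model of that pattern; the tag enum and registry here are illustrative, not DuckDB's buffer manager:

```cpp
#include <cstddef>
#include <map>
#include <memory>
#include <vector>

enum class Tag { ART_INDEX, OTHER_INDEX };

struct MemoryRegistry {
	std::map<Tag, size_t> used_bytes;
};

struct Buffer {
	Buffer(MemoryRegistry &registry, Tag tag, size_t size) : registry(registry), tag(tag), size(size) {
		registry.used_bytes[tag] += size;
	}
	~Buffer() {
		registry.used_bytes[tag] -= size;
	}
	MemoryRegistry &registry;
	Tag tag;
	size_t size;
};

class FixedSizeAllocatorModel {
public:
	FixedSizeAllocatorModel(MemoryRegistry &registry, size_t segment_size, Tag tag)
	    : registry(registry), segment_size(segment_size), tag(tag) {
	}
	// Every new buffer inherits the allocator's tag, as in the patch.
	void AddBuffer(size_t segment_count) {
		buffers.push_back(std::make_unique<Buffer>(registry, tag, segment_count * segment_size));
	}

private:
	MemoryRegistry &registry;
	size_t segment_size;
	Tag tag;
	std::vector<std::unique_ptr<Buffer>> buffers;
};
```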
auto buffer_id = GetAvailableBufferId(); - buffers[buffer_id] = make_uniq(block_manager); + buffers[buffer_id] = make_uniq(block_manager, memory_tag); buffers_with_free_space.insert(buffer_id); buffer_with_free_space = buffer_id; @@ -321,7 +320,6 @@ void FixedSizeAllocator::Init(const FixedSizeAllocatorInfo &info) { total_segment_count = 0; for (idx_t i = 0; i < info.buffer_ids.size(); i++) { - // read all FixedSizeBuffer data auto buffer_id = info.buffer_ids[i]; diff --git a/src/duckdb/src/execution/index/fixed_size_buffer.cpp b/src/duckdb/src/execution/index/fixed_size_buffer.cpp index 82bbccac2..1cf36c1a5 100644 --- a/src/duckdb/src/execution/index/fixed_size_buffer.cpp +++ b/src/duckdb/src/execution/index/fixed_size_buffer.cpp @@ -35,12 +35,11 @@ void PartialBlockForIndex::Clear() { constexpr idx_t FixedSizeBuffer::BASE[]; constexpr uint8_t FixedSizeBuffer::SHIFT[]; -FixedSizeBuffer::FixedSizeBuffer(BlockManager &block_manager) +FixedSizeBuffer::FixedSizeBuffer(BlockManager &block_manager, MemoryTag memory_tag) : block_manager(block_manager), readers(0), segment_count(0), allocation_size(0), dirty(false), vacuum(false), loaded(false), block_pointer(), block_handle(nullptr) { - auto &buffer_manager = block_manager.buffer_manager; - buffer_handle = buffer_manager.Allocate(MemoryTag::ART_INDEX, &block_manager, false); + buffer_handle = buffer_manager.Allocate(memory_tag, &block_manager, false); block_handle = buffer_handle.GetBlockHandle(); // Zero-initialize the buffer as it might get serialized to storage. @@ -52,7 +51,6 @@ FixedSizeBuffer::FixedSizeBuffer(BlockManager &block_manager, const idx_t segmen const BlockPointer &block_pointer) : block_manager(block_manager), readers(0), segment_count(segment_count), allocation_size(allocation_size), dirty(false), vacuum(false), loaded(false), block_pointer(block_pointer) { - D_ASSERT(block_pointer.IsValid()); block_handle = block_manager.RegisterBlock(block_pointer.block_id); D_ASSERT(block_handle->BlockId() < MAXIMUM_BLOCK); @@ -159,7 +157,6 @@ void FixedSizeBuffer::LoadFromDisk() { } uint32_t FixedSizeBuffer::GetOffset(const idx_t bitmask_count, const idx_t available_segments) { - // Get a handle to the buffer's validity mask (offset 0). SegmentHandle handle(*this, 0); const auto bitmask_ptr = handle.GetPtr(); diff --git a/src/duckdb/src/execution/index/index_type_set.cpp b/src/duckdb/src/execution/index/index_type_set.cpp index 4fe7cda4f..0422f8a02 100644 --- a/src/duckdb/src/execution/index/index_type_set.cpp +++ b/src/duckdb/src/execution/index/index_type_set.cpp @@ -5,7 +5,6 @@ namespace duckdb { IndexTypeSet::IndexTypeSet() { - // Register the ART index type by default IndexType art_index_type; art_index_type.name = ART::TYPE_NAME; diff --git a/src/duckdb/src/execution/index/unbound_index.cpp b/src/duckdb/src/execution/index/unbound_index.cpp index 0d117ca92..f79a5d9b8 100644 --- a/src/duckdb/src/execution/index/unbound_index.cpp +++ b/src/duckdb/src/execution/index/unbound_index.cpp @@ -8,11 +8,14 @@ namespace duckdb { +BufferedIndexData::BufferedIndexData(BufferedIndexReplay replay_type, unique_ptr data_p) + : type(replay_type), data(std::move(data_p)) { +} + UnboundIndex::UnboundIndex(unique_ptr create_info, IndexStorageInfo storage_info_p, TableIOManager &table_io_manager, AttachedDatabase &db) : Index(create_info->Cast().column_ids, table_io_manager, db), create_info(std::move(create_info)), storage_info(std::move(storage_info_p)) { - // Memory safety check. 
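The unbound_index.cpp changes (continuing below) buffer each chunk together with a replay type, and ApplyBufferedReplays earlier in this diff walks the entries in order, appending or deleting per entry instead of replaying inserts only. The shape of that replay log, sketched with plain containers and illustrative names:

```cpp
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

enum class ReplayType { INSERT_ENTRY, DEL_ENTRY };

struct BufferedReplay {
	ReplayType type;
	std::vector<std::pair<std::string, int64_t>> rows; // (key, row id) pairs stand in for the buffered chunk
};

struct IndexModel {
	std::vector<std::pair<std::string, int64_t>> entries;

	void Append(const std::vector<std::pair<std::string, int64_t>> &rows) {
		entries.insert(entries.end(), rows.begin(), rows.end());
	}
	void Delete(const std::vector<std::pair<std::string, int64_t>> &rows) {
		for (const auto &row : rows) {
			for (auto it = entries.begin(); it != entries.end(); ++it) {
				if (*it == row) {
					entries.erase(it);
					break;
				}
			}
		}
	}
};

// Replays the buffered operations in the order they were recorded.
void ApplyBufferedReplays(IndexModel &index, const std::vector<BufferedReplay> &replays) {
	for (const auto &replay : replays) {
		switch (replay.type) {
		case ReplayType::INSERT_ENTRY:
			index.Append(replay.rows);
			break;
		case ReplayType::DEL_ENTRY:
			index.Delete(replay.rows);
			break;
		}
	}
}
```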
for (idx_t info_idx = 0; info_idx < storage_info.allocator_infos.size(); info_idx++) { auto &info = storage_info.allocator_infos[info_idx]; @@ -35,26 +38,33 @@ void UnboundIndex::CommitDrop() { } } -void UnboundIndex::BufferChunk(DataChunk &chunk, Vector &row_ids, const vector &mapped_column_ids_p) { +void UnboundIndex::BufferChunk(DataChunk &index_column_chunk, Vector &row_ids, + const vector &mapped_column_ids_p, BufferedIndexReplay replay_type) { D_ASSERT(!column_ids.empty()); - auto types = chunk.GetTypes(); + auto types = index_column_chunk.GetTypes(); // column types types.push_back(LogicalType::ROW_TYPE); - if (!buffered_appends) { - auto &allocator = Allocator::Get(db); - buffered_appends = make_uniq(allocator, types); + auto &allocator = Allocator::Get(db); + + BufferedIndexData buffered_data(replay_type, make_uniq(allocator, types)); + + //! First time we are buffering data, canonical column_id mapping is stored. + //! This should be a sorted list of all the physical offsets of Indexed columns on this table. + if (mapped_column_ids.empty()) { mapped_column_ids = mapped_column_ids_p; } D_ASSERT(mapped_column_ids == mapped_column_ids_p); + // Combined chunk has all the indexed columns and rowids. DataChunk combined_chunk; combined_chunk.InitializeEmpty(types); - for (idx_t i = 0; i < chunk.ColumnCount(); i++) { - combined_chunk.data[i].Reference(chunk.data[i]); + for (idx_t i = 0; i < index_column_chunk.ColumnCount(); i++) { + combined_chunk.data[i].Reference(index_column_chunk.data[i]); } combined_chunk.data.back().Reference(row_ids); - combined_chunk.SetCardinality(chunk.size()); - buffered_appends->Append(combined_chunk); + combined_chunk.SetCardinality(index_column_chunk.size()); + buffered_data.data->Append(combined_chunk); + buffered_replays.emplace_back(std::move(buffered_data)); } } // namespace duckdb diff --git a/src/duckdb/src/execution/join_hashtable.cpp b/src/duckdb/src/execution/join_hashtable.cpp index f991ead7e..8327049b9 100644 --- a/src/duckdb/src/execution/join_hashtable.cpp +++ b/src/duckdb/src/execution/join_hashtable.cpp @@ -45,7 +45,6 @@ JoinHashTable::JoinHashTable(ClientContext &context_p, const PhysicalOperator &o auto type = condition.left->return_type; if (condition.comparison == ExpressionType::COMPARE_EQUAL || condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { - // ensure that all equality conditions are at the front, // and that all other conditions are at the back D_ASSERT(equality_types.size() == condition_types.size()); @@ -82,7 +81,6 @@ JoinHashTable::JoinHashTable(ClientContext &context_p, const PhysicalOperator &o // Initialize the row matcher that are used for filtering during the probing only if there are non-equality if (!non_equality_predicates.empty()) { - row_matcher_probe = unique_ptr(new RowMatcher()); row_matcher_probe_no_match_sel = unique_ptr(new RowMatcher()); @@ -172,7 +170,6 @@ idx_t GetOptionalIndex(const SelectionVector *sel, const idx_t idx) { static void AddPointerToCompare(JoinHashTable::ProbeState &state, const ht_entry_t &entry, Vector &pointers_result_v, idx_t row_ht_offset, idx_t &keys_to_compare_count, const idx_t &row_index) { - const auto row_ptr_insert_to = FlatVector::GetData(pointers_result_v); const auto ht_offsets_and_salts = FlatVector::GetData(state.ht_offsets_and_salts_v); @@ -189,13 +186,11 @@ static void AddPointerToCompare(JoinHashTable::ProbeState &state, const ht_entry template static idx_t ProbeForPointersInternal(JoinHashTable::ProbeState &state, JoinHashTable &ht, ht_entry_t *entries, 
Vector &pointers_result_v, const SelectionVector *row_sel, idx_t &count) { - auto hashes_dense = FlatVector::GetData(state.hashes_dense_v); idx_t keys_to_compare_count = 0; for (idx_t i = 0; i < count; i++) { - auto row_hash = hashes_dense[i]; // hashes have been flattened before -> always access dense auto row_ht_offset = row_hash & ht.bitmask; @@ -260,7 +255,6 @@ static void GetRowPointersInternal(DataChunk &keys, TupleDataChunkState &key_sta Vector &hashes_v, const SelectionVector *row_sel, idx_t &count, JoinHashTable &ht, ht_entry_t *entries, Vector &pointers_result_v, SelectionVector &match_sel, bool has_row_sel) { - // densify hashes: If there is no sel, flatten the hashes, else densify via UnifiedVectorFormat if (has_row_sel) { UnifiedVectorFormat hashes_unified_v; @@ -339,7 +333,6 @@ inline bool JoinHashTable::UseSalt() const { void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_state, ProbeState &state, Vector &hashes_v, const SelectionVector *sel, idx_t &count, Vector &pointers_result_v, SelectionVector &match_sel, const bool has_sel) { - if (UseSalt()) { GetRowPointersInternal(keys, key_state, state, hashes_v, sel, count, *this, entries, pointers_result_v, match_sel, has_sel); @@ -888,7 +881,6 @@ bool ScanStructure::PointersExhausted() const { } idx_t ScanStructure::ResolvePredicates(DataChunk &keys, SelectionVector &match_sel, SelectionVector *no_match_sel) { - // Initialize the found_match array to the current sel_vector for (idx_t i = 0; i < this->count; ++i) { match_sel.set_index(i, this->sel_vector.get_index(i)); @@ -934,7 +926,6 @@ idx_t ScanStructure::ScanInnerJoin(DataChunk &keys, SelectionVector &result_vect } void ScanStructure::AdvancePointers(const SelectionVector &sel, const idx_t sel_count) { - if (!ht.chains_longer_than_one) { this->count = 0; return; diff --git a/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp b/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp index dc37353f7..0e7910cbd 100644 --- a/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp +++ b/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp @@ -29,7 +29,6 @@ DistinctAggregateCollectionInfo::DistinctAggregateCollectionInfo(const vector(); auto &sink = input.local_state.Cast(); diff --git a/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp index a2a3da965..8bd152bb3 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp @@ -28,7 +28,6 @@ PhysicalUngroupedAggregate::PhysicalUngroupedAggregate(PhysicalPlan &physical_pl : PhysicalOperator(physical_plan, PhysicalOperatorType::UNGROUPED_AGGREGATE, std::move(types), estimated_cardinality), aggregates(std::move(expressions)) { - distinct_collection_info = DistinctAggregateCollectionInfo::Create(aggregates); if (!distinct_collection_info) { return; @@ -239,7 +238,6 @@ class UngroupedAggregateLocalSinkState : public LocalSinkState { public: void InitializeDistinctAggregates(const PhysicalUngroupedAggregate &op, const UngroupedAggregateGlobalSinkState &gstate, ExecutionContext &context) { - if (!op.distinct_data) { return; } @@ -628,7 +626,8 @@ void VerifyNullHandling(DataChunk &chunk, UngroupedAggregateState &state, #ifdef DEBUG for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { auto &aggr = 
aggregates[aggr_idx]->Cast(); - if (state.counts[aggr_idx] == 0 && aggr.function.null_handling == FunctionNullHandling::DEFAULT_NULL_HANDLING) { + if (state.counts[aggr_idx] == 0 && + aggr.function.GetNullHandling() == FunctionNullHandling::DEFAULT_NULL_HANDLING) { // Default is when 0 values go in, NULL comes out UnifiedVectorFormat vdata; chunk.data[aggr_idx].ToUnifiedFormat(1, vdata); diff --git a/src/duckdb/src/execution/operator/aggregate/physical_window.cpp b/src/duckdb/src/execution/operator/aggregate/physical_window.cpp index 102f491f0..98801c042 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_window.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_window.cpp @@ -17,7 +17,7 @@ namespace duckdb { // Global sink state class WindowGlobalSinkState; -enum WindowGroupStage : uint8_t { MASK, SINK, FINALIZE, GETDATA, DONE }; +enum WindowGroupStage : uint8_t { SORT, MATERIALIZE, MASK, SINK, FINALIZE, GETDATA, DONE }; struct WindowSourceTask { WindowSourceTask() { @@ -48,17 +48,28 @@ class WindowHashGroup { using Task = WindowSourceTask; using TaskPtr = optional_ptr; using ScannerPtr = unique_ptr; + using ChunkRow = HashedSort::ChunkRow; - WindowHashGroup(WindowGlobalSinkState &gsink, HashGroupPtr &sorted, const idx_t hash_bin_p); + template + static T BinValue(T n, T val) { + return ((n + (val - 1)) / val); + } + + WindowHashGroup(WindowGlobalSinkState &gsink, const ChunkRow &chunk_row, const idx_t hash_bin_p); void AllocateMasks(); void ComputeMasks(const idx_t begin_idx, const idx_t end_idx); ExecutorGlobalStates &GetGlobalStates(ClientContext &client); + //! The number of chunks in the group + inline idx_t ChunkCount() const { + return blocks; + } + // The total number of tasks we will execute per thread inline idx_t GetTaskCount() const { - return GetThreadCount() * (uint8_t(WindowGroupStage::DONE) - uint8_t(WindowGroupStage::MASK)); + return GetThreadCount() * (uint8_t(WindowGroupStage::DONE) - uint8_t(WindowGroupStage::SORT)); } // The total number of threads we will use inline idx_t GetThreadCount() const { @@ -76,9 +87,24 @@ class WindowHashGroup { return stage; } + void GetColumnData(ExecutionContext &context, const idx_t blocks, OperatorSourceInput &source) { + } + bool TryPrepareNextStage() { lock_guard prepare_guard(lock); switch (stage.load()) { + case WindowGroupStage::SORT: + if (sorted == blocks) { + stage = WindowGroupStage::MATERIALIZE; + return true; + } + return false; + case WindowGroupStage::MATERIALIZE: + if (materialized == blocks && rows.get()) { + stage = WindowGroupStage::MASK; + return true; + } + return false; case WindowGroupStage::MASK: if (masked == blocks) { stage = WindowGroupStage::SINK; @@ -118,7 +144,7 @@ class WindowHashGroup { task.thread_idx = next_task % group_threads; task.group_idx = hash_bin; task.begin_idx = task.thread_idx * per_thread; - task.max_idx = rows->ChunkCount(); + task.max_idx = ChunkCount(); task.end_idx = MinValue(task.begin_idx + per_thread, task.max_idx); ++next_task; return true; @@ -130,13 +156,11 @@ class WindowHashGroup { //! The shared global state from sinking WindowGlobalSinkState &gsink; //! The hash partition data - HashGroupPtr hash_group; + HashGroupPtr rows; //! The size of the group idx_t count = 0; //! The number of blocks in the group idx_t blocks = 0; - unique_ptr rows; - TupleDataLayout layout; //! The partition boundary mask ValidityMask partition_mask; //! The order boundary mask @@ -160,6 +184,10 @@ class WindowHashGroup { idx_t group_threads = 0; //! 
The next task to process idx_t next_task = 0; + //! Count of sorted run blocks + std::atomic sorted; + //! Count of materialized run blocks + std::atomic materialized; //! Count of masked blocks std::atomic masked; //! Count of sunk rows @@ -218,7 +246,6 @@ PhysicalWindow::PhysicalWindow(PhysicalPlan &physical_plan, vector PhysicalOperatorType type) : PhysicalOperator(physical_plan, type, std::move(types), estimated_cardinality), select_list(std::move(select_list_p)), order_idx(0), is_order_dependent(false) { - idx_t max_orders = 0; for (idx_t i = 0; i < select_list.size(); ++i) { auto &expr = select_list[i]; @@ -271,7 +298,6 @@ static unique_ptr WindowExecutorFactory(BoundWindowExpression &w WindowGlobalSinkState::WindowGlobalSinkState(const PhysicalWindow &op, ClientContext &client) : op(op), client(client), count(0) { - D_ASSERT(op.select_list[op.order_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); auto &wexpr = op.select_list[op.order_idx]->Cast(); @@ -324,14 +350,7 @@ SinkFinalizeType PhysicalWindow::Finalize(Pipeline &pipeline, Event &event, Clie auto &hashed_sink = *gsink.hashed_sink; OperatorSinkFinalizeInput hfinalize {hashed_sink, input.interrupt_state}; - auto result = global_partition.Finalize(client, hfinalize); - - // Did we get any data? - if (result != SinkFinalizeType::READY) { - return result; - } - - return global_partition.MaterializeHashGroups(pipeline, event, *this, hfinalize); + return global_partition.Finalize(client, hfinalize); } ProgressData PhysicalWindow::GetSinkProgress(ClientContext &context, GlobalSinkState &gstate, @@ -367,6 +386,8 @@ class WindowGlobalSourceState : public GlobalSourceState { ClientContext &client; //! All the sunk data WindowGlobalSinkState &gsink; + //! The hashed sort global source state for delayed sorting + unique_ptr hashed_source; //! The sorted hash groups vector window_hash_groups; //! 
The total number of blocks to process; @@ -404,20 +425,18 @@ class WindowGlobalSourceState : public GlobalSourceState { WindowGlobalSourceState::WindowGlobalSourceState(ClientContext &client, WindowGlobalSinkState &gsink_p) : client(client), gsink(gsink_p), next_group(0), locals(0), started(0), finished(0), stopped(false), completed(0) { - auto &global_partition = *gsink.global_partition; - auto hashed_source = global_partition.GetGlobalSourceState(client, *gsink.hashed_sink); + hashed_source = global_partition.GetGlobalSourceState(client, *gsink.hashed_sink); auto &hash_groups = global_partition.GetHashGroups(*hashed_source); window_hash_groups.resize(hash_groups.size()); for (idx_t group_idx = 0; group_idx < hash_groups.size(); ++group_idx) { - auto rows = std::move(hash_groups[group_idx]); - if (!rows) { + const auto block_count = hash_groups[group_idx].chunks; + if (!block_count) { continue; } - auto window_hash_group = make_uniq(gsink, rows, group_idx); - const auto block_count = window_hash_group->rows->ChunkCount(); + auto window_hash_group = make_uniq(gsink, hash_groups[group_idx], group_idx); window_hash_group->batch_base = total_blocks; total_blocks += block_count; @@ -438,7 +457,7 @@ void WindowGlobalSourceState::CreateTaskList() { if (!window_hash_group) { continue; } - partition_blocks.emplace_back(window_hash_group->rows->ChunkCount(), group_idx); + partition_blocks.emplace_back(window_hash_group->blocks, group_idx); } std::sort(partition_blocks.begin(), partition_blocks.end(), std::greater()); @@ -453,8 +472,8 @@ void WindowGlobalSourceState::CreateTaskList() { // STANDARD_VECTOR_SIZE >> ValidityMask::BITS_PER_VALUE, but if STANDARD_VECTOR_SIZE is say 2, // we need to align the chunk count to the mask width. const auto aligned_scale = MaxValue(ValidityMask::BITS_PER_VALUE / STANDARD_VECTOR_SIZE, 1); - const auto aligned_count = (max_block.first + aligned_scale - 1) / aligned_scale; - const auto per_thread = aligned_scale * ((aligned_count + threads - 1) / threads); + const auto aligned_count = WindowHashGroup::BinValue(max_block.first, aligned_scale); + const auto per_thread = aligned_scale * WindowHashGroup::BinValue(aligned_count, threads); if (!per_thread) { throw InternalException("No blocks per thread! %ld threads, %ld groups, %ld blocks, %ld hash group", threads, partition_blocks.size(), max_block.first, max_block.second); @@ -465,23 +484,14 @@ void WindowGlobalSourceState::CreateTaskList() { } } -WindowHashGroup::WindowHashGroup(WindowGlobalSinkState &gsink, HashGroupPtr &sorted, const idx_t hash_bin_p) - : gsink(gsink), count(0), blocks(0), rows(std::move(sorted)), stage(WindowGroupStage::MASK), hash_bin(hash_bin_p), - masked(0), sunk(0), finalized(0), completed(0), batch_base(0) { +WindowHashGroup::WindowHashGroup(WindowGlobalSinkState &gsink, const ChunkRow &chunk_row, const idx_t hash_bin_p) + : gsink(gsink), count(chunk_row.count), blocks(chunk_row.chunks), stage(WindowGroupStage::SORT), + hash_bin(hash_bin_p), sorted(0), materialized(0), masked(0), sunk(0), finalized(0), completed(0), batch_base(0) { // There are three types of partitions: // 1. No partition (no sorting) // 2. One partition (sorting, but no hashing) // 3. Multiple partitions (sorting and hashing) - // How big is the partition? 
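The new WindowHashGroup::BinValue helper is plain ceiling division, used twice in CreateTaskList above: once to align the block count to the validity-mask width and once to split the aligned count across threads. A quick worked sketch under assumed sizes (the numbers below are hypothetical, not taken from the patch):

```cpp
#include <cassert>
#include <cstddef>

template <typename T>
static T BinValue(T n, T val) {
	return (n + (val - 1)) / val; // ceiling division
}

int main() {
	// Hypothetical sizes: 100 blocks, 8 chunks per mask word, 6 worker threads.
	const size_t blocks = 100, aligned_scale = 8, threads = 6;
	const size_t aligned_count = BinValue(blocks, aligned_scale);                 // ceil(100 / 8) = 13
	const size_t per_thread = aligned_scale * BinValue(aligned_count, threads);   // 8 * ceil(13 / 6) = 24
	assert(aligned_count == 13);
	assert(per_thread == 24);
	return 0;
}
```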
- auto &gpart = *gsink.global_partition; - layout.Initialize(gpart.payload_types, TupleDataValidityType::CAN_HAVE_NULL_VALUES); - - if (rows) { - count = rows->Count(); - blocks = rows->ChunkCount(); - } - // Set up the collection for any fully materialised data const auto &shared = WindowSharedExpressions::GetSortedExpressions(gsink.shared.coll_shared); vector types; @@ -664,6 +674,10 @@ class WindowLocalSourceState : public LocalSourceState { DataChunk output_chunk; protected: + //! Sort the partition + void Sort(ExecutionContext &context, InterruptState &interrupt); + //! Materialize the sorted run + void Materialize(ExecutionContext &context, InterruptState &interrupt); //! Compute a mask range void Mask(ExecutionContext &context, InterruptState &interrupt); //! Sink tuples into function global states @@ -689,12 +703,50 @@ class WindowLocalSourceState : public LocalSourceState { idx_t WindowHashGroup::InitTasks(idx_t per_thread_p) { per_thread = per_thread_p; - group_threads = (rows->ChunkCount() + per_thread - 1) / per_thread; + group_threads = BinValue(ChunkCount(), per_thread); thread_states.resize(GetThreadCount()); return GetTaskCount(); } +void WindowLocalSourceState::Sort(ExecutionContext &context, InterruptState &interrupt) { + D_ASSERT(task); + D_ASSERT(task->stage == WindowGroupStage::SORT); + + auto &gsink = gsource.gsink; + auto &hashed_sort = *gsink.global_partition; + OperatorSinkFinalizeInput finalize {*gsink.hashed_sink, interrupt}; + hashed_sort.SortColumnData(context, task_local.group_idx, finalize); + + // Mark this range as done + window_hash_group->sorted += (task->end_idx - task->begin_idx); + task->begin_idx = task->end_idx; +} + +void WindowLocalSourceState::Materialize(ExecutionContext &context, InterruptState &interrupt) { + D_ASSERT(task); + D_ASSERT(task->stage == WindowGroupStage::MATERIALIZE); + + auto unused = make_uniq(); + OperatorSourceInput source {*gsource.hashed_source, *unused, interrupt}; + auto &gsink = gsource.gsink; + auto &hashed_sort = *gsink.global_partition; + hashed_sort.MaterializeColumnData(context, task_local.group_idx, source); + + // Mark this range as done + window_hash_group->materialized += (task->end_idx - task->begin_idx); + task->begin_idx = task->end_idx; + + // There is no good place to read the column data, + // and if we do it twice we can split the results. 
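The window operator now pushes sorting and materialization into the per-group stage machine: each worker marks its block range done on an atomic counter, and TryPrepareNextStage advances the group once the counter reaches the block count. A compact model of that handshake, with the types and stage list simplified to keep it self-contained:

```cpp
#include <atomic>
#include <cstddef>
#include <mutex>

enum class Stage { SORT, MATERIALIZE, MASK, DONE };

struct HashGroupModel {
	size_t blocks = 0;
	std::atomic<size_t> sorted {0};
	std::atomic<size_t> materialized {0};
	std::atomic<Stage> stage {Stage::SORT};
	std::mutex lock;

	// Called by a worker after it finishes the blocks [begin, end) of its task.
	void FinishSortRange(size_t begin, size_t end) {
		sorted += end - begin;
	}
	void FinishMaterializeRange(size_t begin, size_t end) {
		materialized += end - begin;
	}

	// Only one thread wins each transition; the rest keep polling.
	bool TryPrepareNextStage() {
		std::lock_guard<std::mutex> guard(lock);
		switch (stage.load()) {
		case Stage::SORT:
			if (sorted == blocks) {
				stage = Stage::MATERIALIZE;
				return true;
			}
			return false;
		case Stage::MATERIALIZE:
			if (materialized == blocks) {
				stage = Stage::MASK;
				return true;
			}
			return false;
		default:
			return false;
		}
	}
};
```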
+ if (window_hash_group->materialized >= window_hash_group->blocks) { + lock_guard prepare_guard(window_hash_group->lock); + if (!window_hash_group->rows) { + window_hash_group->rows = hashed_sort.GetColumnData(task_local.group_idx, source); + } + } +} + void WindowLocalSourceState::Mask(ExecutionContext &context, InterruptState &interrupt) { D_ASSERT(task); D_ASSERT(task->stage == WindowGroupStage::MASK); @@ -921,6 +973,14 @@ void WindowLocalSourceState::ExecuteTask(ExecutionContext &context, DataChunk &r // Process the new state switch (task->stage) { + case WindowGroupStage::SORT: + Sort(context, interrupt); + D_ASSERT(TaskFinished()); + break; + case WindowGroupStage::MATERIALIZE: + Materialize(context, interrupt); + D_ASSERT(TaskFinished()); + break; case WindowGroupStage::MASK: Mask(context, interrupt); D_ASSERT(TaskFinished()); diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp index 5ed14a992..2bddff3ca 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp @@ -639,7 +639,6 @@ void StringValueResult::AddValue(StringValueResult &result, const idx_t buffer_p } void StringValueResult::HandleUnicodeError(idx_t col_idx, LinePosition &error_position) { - bool first_nl = false; auto borked_line = current_line_position.ReconstructCurrentLine(first_nl, buffer_handles, PrintErrorLine()); LinesPerBoundary lines_per_batch(iterator.GetBoundaryIdx(), lines_read); diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp index fc8dc9385..880a06149 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp @@ -146,7 +146,6 @@ void CSVSniffer::GenerateStateMachineSearchSpace(vector(CSVStateMachineCache::ObjectType()); } diff --git a/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp index 8a830c8c9..e293c7337 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp @@ -13,7 +13,6 @@ CSVFileScan::CSVFileScan(ClientContext &context, const OpenFileInfo &file_p, CSV : BaseFileReader(file_p), buffer_manager(std::move(buffer_manager_p)), error_handler(make_shared_ptr(options_p.ignore_errors.GetValue())), options(std::move(options_p)) { - // Initialize Buffer Manager if (!buffer_manager) { buffer_manager = make_shared_ptr(context, options, file, per_file_single_threaded); diff --git a/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_multi_file_info.cpp b/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_multi_file_info.cpp index 260ac35d4..843209eb7 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_multi_file_info.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_multi_file_info.cpp @@ -365,13 +365,15 @@ bool CSVFileScan::TryInitializeScan(ClientContext &context, GlobalTableFunctionS return true; } -void CSVFileScan::Scan(ClientContext &context, GlobalTableFunctionState &global_state, - LocalTableFunctionState &local_state, 
DataChunk &chunk) { +AsyncResult CSVFileScan::Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &chunk) { auto &lstate = local_state.Cast(); if (lstate.csv_reader->FinishedIterator()) { - return; + return AsyncResult(SourceResultType::FINISHED); } lstate.csv_reader->Flush(chunk); + return chunk.size() == 0 ? AsyncResult(SourceResultType::FINISHED) : AsyncResult(SourceResultType::HAVE_MORE_OUTPUT); } void CSVFileScan::FinishFile(ClientContext &context, GlobalTableFunctionState &global_state) { diff --git a/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp b/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp index 7fd64d889..07558c6bc 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp @@ -275,7 +275,7 @@ CSVError::CSVError(string error_message_p, CSVErrorType type_p, LinesPerBoundary CSVError::CSVError(string error_message_p, CSVErrorType type_p, idx_t column_idx_p, string csv_row_p, LinesPerBoundary error_info_p, idx_t row_byte_position, optional_idx byte_position_p, - const CSVReaderOptions &reader_options, const string &fixes, const string &current_path) + const CSVReaderOptions &reader_options, const string &fixes, const String &current_path) : error_message(std::move(error_message_p)), type(type_p), column_idx(column_idx_p), csv_row(std::move(csv_row_p)), error_info(error_info_p), row_byte_position(row_byte_position), byte_position(byte_position_p) { // What were the options @@ -319,7 +319,7 @@ void CSVError::RemoveNewLine(string &error) { CSVError CSVError::CastError(const CSVReaderOptions &options, const string &column_name, string &cast_error, idx_t column_idx, string &csv_row, LinesPerBoundary error_info, idx_t row_byte_position, - optional_idx byte_position, LogicalTypeId type, const string &current_path) { + optional_idx byte_position, LogicalTypeId type, const String &current_path) { std::ostringstream error; // Which column error << "Error when converting column \"" << column_name << "\". "; @@ -350,7 +350,7 @@ CSVError CSVError::CastError(const CSVReaderOptions &options, const string &colu } CSVError CSVError::LineSizeError(const CSVReaderOptions &options, LinesPerBoundary error_info, string &csv_row, - idx_t byte_position, const string &current_path) { + idx_t byte_position, const String &current_path) { std::ostringstream error; error << "Maximum line size of " << options.maximum_line_size.GetValue() << " bytes exceeded. "; error << "Actual Size:" << csv_row.size() << " bytes." << '\n'; @@ -365,7 +365,7 @@ CSVError CSVError::LineSizeError(const CSVReaderOptions &options, LinesPerBounda CSVError CSVError::InvalidState(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, optional_idx byte_position, - const string &current_path) { + const String &current_path) { std::ostringstream error; error << "The CSV Parser state machine reached an invalid state.\nThis can happen when is not possible to parse " "your CSV File with the given options, or the CSV File is not RFC 4180 compliant "; @@ -521,7 +521,7 @@ CSVError CSVError::SniffingError(const CSVReaderOptions &options, const string & } CSVError CSVError::NullPaddingFail(const CSVReaderOptions &options, LinesPerBoundary error_info, - const string &current_path) { + const String &current_path) { std::ostringstream error; error << " The parallel scanner does not support null_padding in conjunction with quoted new lines.
Please " "disable the parallel csv reader with parallel=false" @@ -533,7 +533,7 @@ CSVError CSVError::NullPaddingFail(const CSVReaderOptions &options, LinesPerBoun CSVError CSVError::UnterminatedQuotesError(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, - optional_idx byte_position, const string &current_path) { + optional_idx byte_position, const String &current_path) { std::ostringstream error; error << "Value with unterminated quote found." << '\n'; std::ostringstream how_to_fix_it; @@ -551,7 +551,7 @@ CSVError CSVError::UnterminatedQuotesError(const CSVReaderOptions &options, idx_ CSVError CSVError::IncorrectColumnAmountError(const CSVReaderOptions &options, idx_t actual_columns, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, - optional_idx byte_position, const string &current_path) { + optional_idx byte_position, const String &current_path) { std::ostringstream error; // We don't have a fix for this std::ostringstream how_to_fix_it; @@ -581,7 +581,7 @@ CSVError CSVError::IncorrectColumnAmountError(const CSVReaderOptions &options, i CSVError CSVError::InvalidUTF8(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, optional_idx byte_position, - const string &current_path) { + const String &current_path) { std::ostringstream error; // How many columns were expected and how many were found error << "Invalid unicode (byte sequence mismatch) detected. This file is not " << options.encoding << " encoded." diff --git a/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp b/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp index 5801e99b0..d963df00c 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp @@ -465,7 +465,7 @@ bool CSVReaderOptions::WasTypeManuallySet(idx_t i) const { return was_type_manually_set[i]; } -string CSVReaderOptions::ToString(const string &current_file_path) const { +string CSVReaderOptions::ToString(const String &current_file_path) const { auto &delimiter = dialect_options.state_machine_options.delimiter; auto &quote = dialect_options.state_machine_options.quote; auto &escape = dialect_options.state_machine_options.escape; @@ -475,7 +475,7 @@ string CSVReaderOptions::ToString(const string &current_file_path) const { auto &skip_rows = dialect_options.skip_rows; auto &header = dialect_options.header; - string error = " file = " + current_file_path + "\n "; + string error = " file = " + current_file_path.ToStdString() + "\n "; // Let's first print options that can either be set by the user or by the sniffer // delimiter error += FormatOptionLine("delimiter", delimiter); diff --git a/src/duckdb/src/execution/operator/filter/physical_filter.cpp b/src/duckdb/src/execution/operator/filter/physical_filter.cpp index 2921a0e83..889667e82 100644 --- a/src/duckdb/src/execution/operator/filter/physical_filter.cpp +++ b/src/duckdb/src/execution/operator/filter/physical_filter.cpp @@ -7,7 +7,6 @@ namespace duckdb { PhysicalFilter::PhysicalFilter(PhysicalPlan &physical_plan, vector types, vector> select_list, idx_t estimated_cardinality) : CachingPhysicalOperator(physical_plan, PhysicalOperatorType::FILTER, std::move(types), estimated_cardinality) { - D_ASSERT(!select_list.empty()); if (select_list.size() == 1) { expression = std::move(select_list[0]); diff --git
a/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp index c36e891f2..e79a4d044 100644 --- a/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp @@ -47,7 +47,7 @@ unique_ptr PhysicalBatchCollector::GetGlobalSinkState(ClientCon return make_uniq(context, *this); } -unique_ptr PhysicalBatchCollector::GetResult(GlobalSinkState &state) { +unique_ptr PhysicalBatchCollector::GetResult(GlobalSinkState &state) const { auto &gstate = state.Cast(); D_ASSERT(gstate.result); return std::move(gstate.result); diff --git a/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp index 404d14343..9e3caab6c 100644 --- a/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp @@ -53,7 +53,6 @@ SinkResultType PhysicalBufferedBatchCollector::Sink(ExecutionContext &context, D SinkNextBatchType PhysicalBufferedBatchCollector::NextBatch(ExecutionContext &context, OperatorSinkNextBatchInput &input) const { - auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); @@ -94,11 +93,11 @@ unique_ptr PhysicalBufferedBatchCollector::GetLocalSinkState(Exe unique_ptr PhysicalBufferedBatchCollector::GetGlobalSinkState(ClientContext &context) const { auto state = make_uniq(); state->context = context.shared_from_this(); - state->buffered_data = make_shared_ptr(state->context); + state->buffered_data = make_shared_ptr(context); return std::move(state); } -unique_ptr PhysicalBufferedBatchCollector::GetResult(GlobalSinkState &state) { +unique_ptr PhysicalBufferedBatchCollector::GetResult(GlobalSinkState &state) const { auto &gstate = state.Cast(); auto cc = gstate.context.lock(); auto result = make_uniq(statement_type, properties, types, names, cc->GetClientProperties(), diff --git a/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp index 7795230dc..8ee1e1617 100644 --- a/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp @@ -48,7 +48,7 @@ SinkCombineResultType PhysicalBufferedCollector::Combine(ExecutionContext &conte unique_ptr PhysicalBufferedCollector::GetGlobalSinkState(ClientContext &context) const { auto state = make_uniq(); state->context = context.shared_from_this(); - state->buffered_data = make_shared_ptr(state->context); + state->buffered_data = make_shared_ptr(context); return std::move(state); } @@ -57,7 +57,7 @@ unique_ptr PhysicalBufferedCollector::GetLocalSinkState(Executio return std::move(state); } -unique_ptr PhysicalBufferedCollector::GetResult(GlobalSinkState &state) { +unique_ptr PhysicalBufferedCollector::GetResult(GlobalSinkState &state) const { auto &gstate = state.Cast(); lock_guard l(gstate.glock); // FIXME: maybe we want to check if the execution was successful before creating the StreamQueryResult ? 
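Note on the collector changes above: GetResult is now a const member, so any state it hands back has to live in the GlobalSinkState rather than on the operator itself. A minimal standalone sketch of that ownership pattern, using hypothetical Toy* names rather than DuckDB's actual classes:

#include <memory>
#include <utility>

struct ToyResult {
	// materialized rows would live here
};

struct ToyGlobalSinkState {
	std::unique_ptr<ToyResult> result;
};

class ToyCollector {
public:
	// 'const' is viable because the operator holds no mutable result state;
	// everything produced by Sink/Combine lives in the global sink state,
	// and GetResult simply transfers ownership out of it.
	std::unique_ptr<ToyResult> GetResult(ToyGlobalSinkState &state) const {
		return std::move(state.result);
	}
};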
diff --git a/src/duckdb/src/execution/operator/helper/physical_limit.cpp b/src/duckdb/src/execution/operator/helper/physical_limit.cpp index 5a4339c63..987e6b4eb 100644 --- a/src/duckdb/src/execution/operator/helper/physical_limit.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_limit.cpp @@ -8,6 +8,8 @@ namespace duckdb { +constexpr const idx_t PhysicalLimit::MAX_LIMIT_VALUE; + PhysicalLimit::PhysicalLimit(PhysicalPlan &physical_plan, vector types, BoundLimitNode limit_val_p, BoundLimitNode offset_val_p, idx_t estimated_cardinality) : PhysicalOperator(physical_plan, PhysicalOperatorType::LIMIT, std::move(types), estimated_cardinality), @@ -19,7 +21,8 @@ PhysicalLimit::PhysicalLimit(PhysicalPlan &physical_plan, vector ty //===--------------------------------------------------------------------===// class LimitGlobalState : public GlobalSinkState { public: - explicit LimitGlobalState(ClientContext &context, const PhysicalLimit &op) : data(context, op.types, true) { + explicit LimitGlobalState(ClientContext &context, const PhysicalLimit &op) + : data(context, op.types, ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR) { limit = 0; offset = 0; } @@ -33,7 +36,7 @@ class LimitGlobalState : public GlobalSinkState { class LimitLocalState : public LocalSinkState { public: explicit LimitLocalState(ClientContext &context, const PhysicalLimit &op) - : current_offset(0), data(context, op.types, true) { + : current_offset(0), data(context, op.types, ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR) { PhysicalLimit::SetInitialLimits(op.limit_val, op.offset_val, limit, offset); } @@ -108,7 +111,6 @@ bool PhysicalLimit::ComputeOffset(ExecutionContext &context, DataChunk &input, o } SinkResultType PhysicalLimit::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - D_ASSERT(chunk.size() > 0); auto &state = input.local_state.Cast(); auto &limit = state.limit; diff --git a/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp index a70b914ce..a42e7a4a8 100644 --- a/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp @@ -2,6 +2,7 @@ #include "duckdb/main/materialized_query_result.hpp" #include "duckdb/main/client_context.hpp" +#include "duckdb/main/result_set_manager.hpp" namespace duckdb { @@ -10,6 +11,19 @@ PhysicalMaterializedCollector::PhysicalMaterializedCollector(PhysicalPlan &physi : PhysicalResultCollector(physical_plan, data), parallel(parallel) { } +class MaterializedCollectorGlobalState : public GlobalSinkState { +public: + mutex glock; + unique_ptr collection; + shared_ptr context; +}; + +class MaterializedCollectorLocalState : public LocalSinkState { +public: + unique_ptr collection; + ColumnDataAppendState append_state; +}; + SinkResultType PhysicalMaterializedCollector::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { auto &lstate = input.local_state.Cast(); @@ -43,15 +57,15 @@ unique_ptr PhysicalMaterializedCollector::GetGlobalSinkState(Cl unique_ptr PhysicalMaterializedCollector::GetLocalSinkState(ExecutionContext &context) const { auto state = make_uniq(); - state->collection = make_uniq(Allocator::DefaultAllocator(), types); + state->collection = CreateCollection(context.client); state->collection->InitializeAppend(state->append_state); return std::move(state); } -unique_ptr 
PhysicalMaterializedCollector::GetResult(GlobalSinkState &state) { +unique_ptr PhysicalMaterializedCollector::GetResult(GlobalSinkState &state) const { auto &gstate = state.Cast(); if (!gstate.collection) { - gstate.collection = make_uniq(Allocator::DefaultAllocator(), types); + gstate.collection = CreateCollection(*gstate.context); } auto result = make_uniq(statement_type, properties, names, std::move(gstate.collection), gstate.context->GetClientProperties()); diff --git a/src/duckdb/src/execution/operator/helper/physical_prepare.cpp b/src/duckdb/src/execution/operator/helper/physical_prepare.cpp index 784d6ada4..1d1505743 100644 --- a/src/duckdb/src/execution/operator/helper/physical_prepare.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_prepare.cpp @@ -8,7 +8,7 @@ SourceResultType PhysicalPrepare::GetData(ExecutionContext &context, DataChunk & auto &client = context.client; // store the prepared statement in the context - ClientData::Get(client).prepared_statements[name] = prepared; + ClientData::Get(client).prepared_statements[name.ToStdString()] = prepared; return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/helper/physical_reset.cpp b/src/duckdb/src/execution/operator/helper/physical_reset.cpp index 1f5baf75d..9476b4e23 100644 --- a/src/duckdb/src/execution/operator/helper/physical_reset.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_reset.cpp @@ -36,8 +36,7 @@ SourceResultType PhysicalReset::GetData(ExecutionContext &context, DataChunk &ch auto extension_name = Catalog::AutoloadExtensionByConfigName(context.client, name); entry = config.extension_parameters.find(name.ToStdString()); if (entry == config.extension_parameters.end()) { - throw InvalidInputException("Extension parameter %s was not found after autoloading", - name.ToStdString()); + throw InvalidInputException("Extension parameter %s was not found after autoloading", name); } } ResetExtensionVariable(context, config, entry->second); diff --git a/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp index d78bf225b..df95ce707 100644 --- a/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp @@ -10,13 +10,15 @@ #include "duckdb/parallel/meta_pipeline.hpp" #include "duckdb/main/query_result.hpp" #include "duckdb/parallel/pipeline.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/main/client_context.hpp" namespace duckdb { PhysicalResultCollector::PhysicalResultCollector(PhysicalPlan &physical_plan, PreparedStatementData &data) : PhysicalOperator(physical_plan, PhysicalOperatorType::RESULT_COLLECTOR, {LogicalType::BOOLEAN}, 0), - statement_type(data.statement_type), properties(data.properties), plan(data.physical_plan->Root()), - names(data.names) { + statement_type(data.statement_type), properties(data.properties), memory_type(data.memory_type), + plan(data.physical_plan->Root()), names(data.names) { types = data.types; } @@ -26,7 +28,7 @@ PhysicalOperator &PhysicalResultCollector::GetResultCollector(ClientContext &con if (!PhysicalPlanGenerator::PreserveInsertionOrder(context, root)) { // Not an order-preserving plan: use the parallel materialized collector. 
- if (data.is_streaming) { + if (data.output_type == QueryResultOutputType::ALLOW_STREAMING) { return physical_plan.Make(data, true); } return physical_plan.Make(data, true); @@ -34,14 +36,14 @@ PhysicalOperator &PhysicalResultCollector::GetResultCollector(ClientContext &con if (!PhysicalPlanGenerator::UseBatchIndex(context, root)) { // Order-preserving plan, and we cannot use the batch index: use single-threaded result collector. - if (data.is_streaming) { + if (data.output_type == QueryResultOutputType::ALLOW_STREAMING) { return physical_plan.Make(data, false); } return physical_plan.Make(data, false); } // Order-preserving plan, and we can use the batch index: use a batch collector. - if (data.is_streaming) { + if (data.output_type == QueryResultOutputType::ALLOW_STREAMING) { return physical_plan.Make(data); } return physical_plan.Make(data); @@ -66,4 +68,18 @@ void PhysicalResultCollector::BuildPipelines(Pipeline &current, MetaPipeline &me child_meta_pipeline.Build(plan); } +unique_ptr PhysicalResultCollector::CreateCollection(ClientContext &context) const { + switch (memory_type) { + case QueryResultMemoryType::IN_MEMORY: + return make_uniq(Allocator::DefaultAllocator(), types); + case QueryResultMemoryType::BUFFER_MANAGED: + // Use the DatabaseInstance BufferManager because the query result can outlive the ClientContext + return make_uniq(BufferManager::GetBufferManager(*context.db), types, + ColumnDataCollectionLifetime::THROW_ERROR_AFTER_DATABASE_CLOSES); + default: + throw NotImplementedException("PhysicalResultCollector::CreateCollection for %s", + EnumUtil::ToString(memory_type)); + } +} + } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_set.cpp b/src/duckdb/src/execution/operator/helper/physical_set.cpp index e8362ad9c..7c28d925e 100644 --- a/src/duckdb/src/execution/operator/helper/physical_set.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_set.cpp @@ -6,17 +6,17 @@ namespace duckdb { -void PhysicalSet::SetGenericVariable(ClientContext &context, const string &name, SetScope scope, Value target_value) { +void PhysicalSet::SetGenericVariable(ClientContext &context, const String &name, SetScope scope, Value target_value) { if (scope == SetScope::GLOBAL) { auto &config = DBConfig::GetConfig(context); config.SetOption(name, std::move(target_value)); } else { auto &client_config = ClientConfig::GetConfig(context); - client_config.set_variables[name] = std::move(target_value); + client_config.set_variables[name.ToStdString()] = std::move(target_value); } } -void PhysicalSet::SetExtensionVariable(ClientContext &context, ExtensionOption &extension_option, const string &name, +void PhysicalSet::SetExtensionVariable(ClientContext &context, ExtensionOption &extension_option, const String &name, SetScope scope, const Value &value) { auto &target_type = extension_option.type; Value target_value = value.CastAs(context, target_type); @@ -36,10 +36,10 @@ SourceResultType PhysicalSet::GetData(ExecutionContext &context, DataChunk &chun auto option = DBConfig::GetOptionByName(name); if (!option) { // check if this is an extra extension variable - auto entry = config.extension_parameters.find(name); + auto entry = config.extension_parameters.find(name.ToStdString()); if (entry == config.extension_parameters.end()) { auto extension_name = Catalog::AutoloadExtensionByConfigName(context.client, name); - entry = config.extension_parameters.find(name); + entry = config.extension_parameters.find(name.ToStdString()); if (entry ==
config.extension_parameters.end()) { throw InvalidInputException("Extension parameter %s was not found after autoloading", name); } diff --git a/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp b/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp index 430e7055e..67f1a1615 100644 --- a/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp @@ -1,12 +1,13 @@ #include "duckdb/execution/operator/helper/physical_set_variable.hpp" #include "duckdb/main/client_config.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" namespace duckdb { -PhysicalSetVariable::PhysicalSetVariable(PhysicalPlan &physical_plan, string name_p, idx_t estimated_cardinality) +PhysicalSetVariable::PhysicalSetVariable(PhysicalPlan &physical_plan, const string &name_p, idx_t estimated_cardinality) : PhysicalOperator(physical_plan, PhysicalOperatorType::SET_VARIABLE, {LogicalType::BOOLEAN}, estimated_cardinality), - name(std::move(name_p)) { + name(physical_plan.ArenaRef().MakeString(name_p)) { } SourceResultType PhysicalSetVariable::GetData(ExecutionContext &context, DataChunk &chunk, diff --git a/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp b/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp index a48eaee4f..d9db21dbd 100644 --- a/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp +++ b/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp @@ -125,14 +125,14 @@ bool PerfectHashJoinExecutor::CanDoPerfectHashJoin(const PhysicalHashJoin &op, c //===--------------------------------------------------------------------===// bool PerfectHashJoinExecutor::BuildPerfectHashTable(LogicalType &key_type) { // First, allocate memory for each build column - auto build_size = perfect_join_statistics.build_range + 1; + const auto build_size = perfect_join_statistics.build_range + 1; for (const auto &type : join.rhs_output_columns.col_types) { - perfect_hash_table.emplace_back(type, build_size); + perfect_hash_table.emplace_back(DictionaryVector::CreateReusableDictionary(type, build_size)); } // and for duplicate_checking - bitmap_build_idx = make_unsafe_uniq_array_uninitialized(build_size); - memset(bitmap_build_idx.get(), 0, sizeof(bool) * build_size); // set false + bitmap_build_idx.Initialize(build_size); + bitmap_build_idx.SetAllInvalid(build_size); // Now fill columns with build data return FullScanHashTable(key_type); @@ -168,22 +168,25 @@ bool PerfectHashJoinExecutor::FullScanHashTable(LogicalType &key_type) { if (!success) { return false; } - if (unique_keys == perfect_join_statistics.build_range + 1 && !ht.has_null) { + + const auto build_size = perfect_join_statistics.build_range + 1; + if (unique_keys == build_size && !ht.has_null) { perfect_join_statistics.is_build_dense = true; + bitmap_build_idx.Reset(build_size); // All valid } key_count = unique_keys; // do not consider keys out of the range // Full scan the remaining build columns and fill the perfect hash table - const auto build_size = perfect_join_statistics.build_range + 1; + for (idx_t i = 0; i < join.rhs_output_columns.col_types.size(); i++) { - auto &vector = perfect_hash_table[i]; + auto &vector = perfect_hash_table[i]->data; const auto output_col_idx = ht.output_columns[i]; D_ASSERT(vector.GetType() == ht.layout_ptr->GetTypes()[output_col_idx]); - if (build_size > STANDARD_VECTOR_SIZE) { - auto &col_mask = FlatVector::Validity(vector); - 
col_mask.Initialize(build_size); - } + auto &col_mask = FlatVector::Validity(vector); + col_mask.Reset(build_size); data_collection.Gather(tuples_addresses, sel_tuples, key_count, output_col_idx, vector, sel_build, nullptr); + // This ensures the empty entries are set to NULL, so that the emitted dictionary vectors make sense + col_mask.Combine(bitmap_build_idx, build_size); } return true; @@ -227,19 +230,19 @@ bool PerfectHashJoinExecutor::TemplatedFillSelectionVectorBuild(Vector &source, auto max_value = perfect_join_statistics.build_max.GetValueUnsafe(); UnifiedVectorFormat vector_data; source.ToUnifiedFormat(count, vector_data); - auto data = reinterpret_cast(vector_data.data); + const auto data = vector_data.GetData(); // generate the selection vector for (idx_t i = 0, sel_idx = 0; i < count; ++i) { auto data_idx = vector_data.sel->get_index(i); auto input_value = data[data_idx]; // add index to selection vector if value in the range if (min_value <= input_value && input_value <= max_value) { - auto idx = (idx_t)(input_value - min_value); // subtract min value to get the idx position + auto idx = UnsafeNumericCast(input_value - min_value); // subtract min value to get the idx position sel_vec.set_index(sel_idx, idx); - if (bitmap_build_idx[idx]) { + if (bitmap_build_idx.RowIsValidUnsafe(idx)) { return false; } else { - bitmap_build_idx[idx] = true; + bitmap_build_idx.SetValidUnsafe(idx); unique_keys++; } seq_sel_vec.set_index(sel_idx++, i); @@ -302,9 +305,7 @@ OperatorResultType PerfectHashJoinExecutor::ProbePerfectHashTable(ExecutionConte for (idx_t i = 0; i < join.rhs_output_columns.col_types.size(); i++) { auto &result_vector = result.data[lhs_output_columns.ColumnCount() + i]; D_ASSERT(result_vector.GetType() == ht.layout_ptr->GetTypes()[ht.output_columns[i]]); - auto &build_vec = perfect_hash_table[i]; - result_vector.Reference(build_vec); - result_vector.Slice(state.build_sel_vec, probe_sel_count); + result_vector.Dictionary(perfect_hash_table[i], state.build_sel_vec); } return OperatorResultType::NEED_MORE_INPUT; } @@ -367,9 +368,9 @@ void PerfectHashJoinExecutor::TemplatedFillSelectionVectorProbe(Vector &source, auto input_value = data[data_idx]; // add index to selection vector if value in the range if (min_value <= input_value && input_value <= max_value) { - auto idx = (idx_t)(input_value - min_value); // subtract min value to get the idx position - // check for matches in the build - if (bitmap_build_idx[idx]) { + auto idx = UnsafeNumericCast(input_value - min_value); // subtract min value to get the idx + // position check for matches in the build + if (bitmap_build_idx.RowIsValid(idx)) { build_sel_vec.set_index(sel_idx, idx); probe_sel_vec.set_index(sel_idx++, i); probe_sel_count++; @@ -386,9 +387,9 @@ void PerfectHashJoinExecutor::TemplatedFillSelectionVectorProbe(Vector &source, auto input_value = data[data_idx]; // add index to selection vector if value in the range if (min_value <= input_value && input_value <= max_value) { - auto idx = (idx_t)(input_value - min_value); // subtract min value to get the idx position - // check for matches in the build - if (bitmap_build_idx[idx]) { + auto idx = UnsafeNumericCast(input_value - min_value); // subtract min value to get the idx + // position check for matches in the build + if (bitmap_build_idx.RowIsValid(idx)) { build_sel_vec.set_index(sel_idx, idx); probe_sel_vec.set_index(sel_idx++, i); probe_sel_count++; diff --git a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp 
b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp index 719781992..a672a36e8 100644 --- a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp @@ -1,14 +1,15 @@ #include "duckdb/execution/operator/join/physical_asof_join.hpp" #include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/partition_state.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/common/sorting/hashed_sort.hpp" +#include "duckdb/common/sorting/sort_key.hpp" +#include "duckdb/common/sorting/sorted_run.hpp" +#include "duckdb/common/types/row/block_iterator.hpp" #include "duckdb/execution/operator/join/outer_join_marker.hpp" +#include "duckdb/function/create_sort_key.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/parallel/event.hpp" -#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/parallel/meta_pipeline.hpp" namespace duckdb { @@ -16,9 +17,9 @@ PhysicalAsOfJoin::PhysicalAsOfJoin(PhysicalPlan &physical_plan, LogicalCompariso PhysicalOperator &right) : PhysicalComparisonJoin(physical_plan, op, PhysicalOperatorType::ASOF_JOIN, std::move(op.conditions), op.join_type, op.estimated_cardinality), - comparison_type(ExpressionType::INVALID), predicate(std::move(op.predicate)) { - + comparison_type(ExpressionType::INVALID) { // Convert the conditions partitions and sorts + D_ASSERT(!op.predicate.get()); for (auto &cond : conditions) { D_ASSERT(cond.left->return_type == cond.right->return_type); join_key_types.push_back(cond.left->return_type); @@ -74,51 +75,44 @@ PhysicalAsOfJoin::PhysicalAsOfJoin(PhysicalPlan &physical_plan, LogicalCompariso //===--------------------------------------------------------------------===// class AsOfGlobalSinkState : public GlobalSinkState { public: - AsOfGlobalSinkState(ClientContext &context, const PhysicalAsOfJoin &op) - : rhs_sink(context, op.rhs_partitions, op.rhs_orders, op.children[1].get().GetTypes(), {}, - op.estimated_cardinality), - is_outer(IsRightOuterJoin(op.join_type)), has_null(false) { - } - - idx_t Count() const { - return rhs_sink.count; + using HashedSortPtr = unique_ptr; + using HashedSinkPtr = unique_ptr; + using PartitionMarkers = vector; + + AsOfGlobalSinkState(ClientContext &client, const PhysicalAsOfJoin &op) { + // Set up partitions for both sides + hashed_sorts.reserve(2); + hashed_sinks.reserve(2); + const vector> partitions_stats; + auto &lhs = op.children[0].get(); + auto sort = make_uniq(client, op.lhs_partitions, op.lhs_orders, lhs.GetTypes(), partitions_stats, + lhs.estimated_cardinality, true); + hashed_sinks.emplace_back(sort->GetGlobalSinkState(client)); + hashed_sorts.emplace_back(std::move(sort)); + + auto &rhs = op.children[1].get(); + sort = make_uniq(client, op.rhs_partitions, op.rhs_orders, rhs.GetTypes(), partitions_stats, + rhs.estimated_cardinality, true); + hashed_sinks.emplace_back(sort->GetGlobalSinkState(client)); + hashed_sorts.emplace_back(std::move(sort)); } - PartitionLocalSinkState *RegisterBuffer(ClientContext &context) { - lock_guard guard(lock); - lhs_buffers.emplace_back(make_uniq(context, *lhs_sink)); - return lhs_buffers.back().get(); - } - - PartitionGlobalSinkState rhs_sink; - - // One per partition - const bool is_outer; - vector right_outers; - bool has_null; - - // Left side buffering - unique_ptr lhs_sink; - - mutex lock; - vector> lhs_buffers; + //! 
The child that is being materialised (right/1 then left/0) + size_t child = 1; + //! The child's partitioning description + vector hashed_sorts; + //! The child's partitioning buffer + vector hashed_sinks; }; class AsOfLocalSinkState : public LocalSinkState { public: - explicit AsOfLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p) - : local_partition(context, gstate_p) { - } - - void Sink(DataChunk &input_chunk) { - local_partition.Sink(input_chunk); + AsOfLocalSinkState(ExecutionContext &context, AsOfGlobalSinkState &gsink) { + auto &hashed_sort = *gsink.hashed_sorts[gsink.child]; + local_partition = hashed_sort.GetLocalSinkState(context); } - void Combine() { - local_partition.Combine(); - } - - PartitionLocalSinkState local_partition; + unique_ptr local_partition; }; unique_ptr PhysicalAsOfJoin::GetGlobalSinkState(ClientContext &context) const { @@ -126,411 +120,990 @@ unique_ptr PhysicalAsOfJoin::GetGlobalSinkState(ClientContext & } unique_ptr PhysicalAsOfJoin::GetLocalSinkState(ExecutionContext &context) const { - // We only sink the RHS auto &gsink = sink_state->Cast(); - return make_uniq(context.client, gsink.rhs_sink); + return make_uniq(context, gsink); } -SinkResultType PhysicalAsOfJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast(); +SinkResultType PhysicalAsOfJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &sink) const { + auto &gstate = sink.global_state.Cast(); + auto &lstate = sink.local_state.Cast(); - lstate.Sink(chunk); + auto &hashed_sort = *gstate.hashed_sorts[gstate.child]; + auto &gsink = *gstate.hashed_sinks[gstate.child]; + auto &lsink = *lstate.local_partition; - return SinkResultType::NEED_MORE_INPUT; + OperatorSinkInput hsink {gsink, lsink, sink.interrupt_state}; + return hashed_sort.Sink(context, chunk, hsink); } -SinkCombineResultType PhysicalAsOfJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &lstate = input.local_state.Cast(); - lstate.Combine(); - return SinkCombineResultType::FINISHED; +SinkCombineResultType PhysicalAsOfJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &combine) const { + auto &gstate = combine.global_state.Cast(); + auto &lstate = combine.local_state.Cast(); + + auto &hashed_sort = *gstate.hashed_sorts[gstate.child]; + auto &gsink = *gstate.hashed_sinks[gstate.child]; + auto &lsink = *lstate.local_partition; + + OperatorSinkCombineInput hcombine {gsink, lsink, combine.interrupt_state}; + return hashed_sort.Combine(context, hcombine); } //===--------------------------------------------------------------------===// // Finalize //===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalAsOfJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - - // The data is all in so we can initialise the left partitioning. - const vector> partitions_stats; - gstate.lhs_sink = make_uniq(context, lhs_partitions, lhs_orders, - children[0].get().GetTypes(), partitions_stats, 0U); - gstate.lhs_sink->SyncPartitioning(gstate.rhs_sink); - - // Find the first group to sort - if (!gstate.rhs_sink.HasMergeTasks() && EmptyResultIfRHSIsEmpty()) { - // Empty input! 
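For readers following the new sink flow above: the AsOf sink simply forwards Sink/Combine to the HashedSort of whichever child is currently being consumed, and Finalize switches sides. A simplified, self-contained sketch of that delegation; the TwoSidedSink/InnerSink/Chunk types below are illustrative stand-ins, not DuckDB classes:

#include <cstddef>

struct Chunk {};

class InnerSink {
public:
	void Sink(Chunk &chunk) {
		(void)chunk; // a real sink would partition and buffer the rows here
	}
	void Combine() {
		// a real sink would merge thread-local state here
	}
};

class TwoSidedSink {
public:
	// Child 1 (the build side) is consumed first, then child 0 (the probe side).
	std::size_t child = 1;
	InnerSink inner[2];

	void Sink(Chunk &chunk) {
		inner[child].Sink(chunk); // forward to the side currently being built
	}
	void Combine() {
		inner[child].Combine();
	}
	void Finalize() {
		child = 1 - child; // switch sides before the next child pipeline runs
	}
};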
- return SinkFinalizeType::NO_OUTPUT_POSSIBLE; +SinkFinalizeType PhysicalAsOfJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &client, + OperatorSinkFinalizeInput &finalize) const { + auto &gstate = finalize.global_state.Cast(); + + // The data is all in so we can synchronise the left partitioning. + auto &hashed_sort = *gstate.hashed_sorts[gstate.child]; + auto &hashed_sink = *gstate.hashed_sinks[gstate.child]; + OperatorSinkFinalizeInput hfinalize {hashed_sink, finalize.interrupt_state}; + if (gstate.child == 1) { + auto &lhs_groups = *gstate.hashed_sinks[1 - gstate.child]; + auto &rhs_groups = hashed_sink; + hashed_sort.Synchronize(rhs_groups, lhs_groups); } - // Schedule all the sorts for maximum thread utilisation - auto new_event = make_shared_ptr(gstate.rhs_sink, pipeline, *this); - event.InsertEvent(std::move(new_event)); + // Switch sides + gstate.child = 1 - gstate.child; - return SinkFinalizeType::READY; + return hashed_sort.Finalize(client, hfinalize); +} + +OperatorResultType PhysicalAsOfJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &lstate_p) const { + return OperatorResultType::FINISHED; } //===--------------------------------------------------------------------===// -// Operator +// Source //===--------------------------------------------------------------------===// -class AsOfGlobalState : public GlobalOperatorState { +enum class AsOfJoinSourceStage : uint8_t { INIT, SORT, MATERIALIZE, GET, LEFT, RIGHT, DONE }; + +struct AsOfSourceTask { + AsOfSourceTask() { + } + + AsOfJoinSourceStage stage = AsOfJoinSourceStage::DONE; + //! The hash group + idx_t group_idx = 0; + //! The thread index (for local state) + idx_t thread_idx = 0; + //! The total block index count + idx_t max_idx = 0; + //! The first block index count + idx_t begin_idx = 0; + //! 
The last block index count + idx_t end_idx = 0; +}; + +class AsOfPayloadScanner { public: - explicit AsOfGlobalState(AsOfGlobalSinkState &gsink) { - // for FULL/RIGHT OUTER JOIN, initialize right_outers to false for every tuple - auto &rhs_partition = gsink.rhs_sink; - auto &right_outers = gsink.right_outers; - right_outers.reserve(rhs_partition.hash_groups.size()); - for (const auto &hash_group : rhs_partition.hash_groups) { - right_outers.emplace_back(OuterJoinMarker(gsink.is_outer)); - right_outers.back().Initialize(hash_group->count); + using Types = vector; + using Columns = vector; + + AsOfPayloadScanner(const SortedRun &sorted_run, const HashedSort &hashed_sort); + idx_t Base() const { + return base; + } + idx_t Scanned() const { + return scanned; + } + idx_t Remaining() const { + return count - scanned; + } + idx_t NextSize() const { + return MinValue(Remaining(), STANDARD_VECTOR_SIZE); + } + void SeekBlock(idx_t block_idx) { + chunk_idx = block_idx; + base = MinValue(chunk_idx * STANDARD_VECTOR_SIZE, count); + scanned = base; + } + inline void SeekRow(idx_t row_idx) { + SeekBlock(row_idx / STANDARD_VECTOR_SIZE); + } + bool Scan(DataChunk &chunk) { + // Free the previous blocks + block_state.SetKeepPinned(true); + block_state.SetPinPayload(true); + + base = scanned; + const auto result = (this->*scan_func)(); + chunk.ReferenceColumns(scan_chunk, scan_ids); + scanned += scan_chunk.size(); + ++chunk_idx; + return result; + } + +private: + template + bool TemplatedScan() { + using SORT_KEY = SortKey; + using BLOCK_ITERATOR = block_iterator_t; + BLOCK_ITERATOR itr(block_state, chunk_idx, 0); + + const auto sort_keys = FlatVector::GetData(sort_key_pointers); + const auto result_count = NextSize(); + for (idx_t i = 0; i < result_count; ++i) { + const auto idx = block_state.GetIndex(chunk_idx, i); + sort_keys[i] = &itr[idx]; } + + // Scan + scan_chunk.Reset(); + scan_state.Scan(sorted_run, sort_key_pointers, result_count, scan_chunk); + return scan_chunk.size() > 0; } + + // Only figure out the scan function once. 
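The scanner above resolves its templated scan routine once, up front, and stores it in a member-function pointer so the per-chunk path is a single indirect call. A compact standalone illustration of that idiom (the Tag/Scanner names and widths below are made up, not DuckDB's):

#include <cstdint>
#include <stdexcept>

enum class Tag : uint8_t { FIXED8, FIXED16, VARIABLE };

class Scanner {
public:
	explicit Scanner(Tag tag) {
		// Pick the instantiation once; later calls pay no further branching.
		switch (tag) {
		case Tag::FIXED8:
			scan_func = &Scanner::TemplatedScan<8>;
			break;
		case Tag::FIXED16:
			scan_func = &Scanner::TemplatedScan<16>;
			break;
		case Tag::VARIABLE:
			scan_func = &Scanner::TemplatedScan<0>;
			break;
		default:
			throw std::runtime_error("unknown tag");
		}
	}

	bool Scan() {
		return (this->*scan_func)(); // single indirect call per chunk
	}

private:
	template <int WIDTH>
	bool TemplatedScan() {
		// A real implementation would decode WIDTH-byte sort keys here.
		return WIDTH != 0;
	}

	using scan_t = bool (Scanner::*)();
	scan_t scan_func;
};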
+ using scan_t = bool (duckdb::AsOfPayloadScanner::*)(); + scan_t scan_func; + + const SortedRun &sorted_run; + ExternalBlockIteratorState block_state; + Vector sort_key_pointers = Vector(LogicalType::POINTER); + SortedRunScanState scan_state; + const Columns scan_ids; + DataChunk scan_chunk; + const idx_t count; + idx_t base = 0; + idx_t scanned = 0; + idx_t chunk_idx = 0; }; -unique_ptr PhysicalAsOfJoin::GetGlobalOperatorState(ClientContext &context) const { - auto &gsink = sink_state->Cast(); - return make_uniq(gsink); +AsOfPayloadScanner::AsOfPayloadScanner(const SortedRun &sorted_run, const HashedSort &hashed_sort) + : sorted_run(sorted_run), block_state(*sorted_run.key_data, sorted_run.payload_data.get()), + scan_state(sorted_run.context, sorted_run.sort), scan_ids(hashed_sort.scan_ids), count(sorted_run.Count()) { + scan_chunk.Initialize(sorted_run.context, hashed_sort.payload_types); + const auto sort_key_type = sorted_run.key_data->GetLayout().GetSortKeyType(); + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::NO_PAYLOAD_FIXED_16: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::NO_PAYLOAD_FIXED_24: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::NO_PAYLOAD_FIXED_32: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::PAYLOAD_FIXED_16: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::PAYLOAD_FIXED_24: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::PAYLOAD_FIXED_32: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::PAYLOAD_VARIABLE_32: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + default: + throw NotImplementedException("AsOfPayloadScanner for %s", EnumUtil::ToString(sort_key_type)); + } } -class AsOfLocalState : public CachingOperatorState { +class AsOfHashGroup { public: - AsOfLocalState(ClientContext &context, const PhysicalAsOfJoin &op) - : context(context), allocator(Allocator::Get(context)), op(op), lhs_executor(context), - left_outer(IsLeftOuterJoin(op.join_type)), fetch_next_left(true) { - lhs_keys.Initialize(allocator, op.join_key_types); - for (const auto &cond : op.conditions) { - lhs_executor.AddExpression(*cond.left); - } + using HashGroupPtr = unique_ptr; + using ChunkRow = HashedSort::ChunkRow; + + template + static T BinValue(T n, T val) { + return ((n + (val - 1)) / val); + } - lhs_payload.Initialize(allocator, op.children[0].get().GetTypes()); - lhs_sel.Initialize(); - left_outer.Initialize(STANDARD_VECTOR_SIZE); + AsOfHashGroup(const PhysicalAsOfJoin &op, const ChunkRow &left_stats, const ChunkRow &right_stats, + const idx_t hash_group); - auto &gsink = op.sink_state->Cast(); - lhs_partition_sink = gsink.RegisterBuffer(context); + //! Is this a right join (do we have a RIGHT stage?) + inline bool IsRightOuter() const { + return right_outer.Enabled(); } - bool Sink(DataChunk &input); - OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk); + //! The processing stage for this group + AsOfJoinSourceStage GetStage() const { + return stage; + } - ClientContext &context; - Allocator &allocator; - const PhysicalAsOfJoin &op; + //! 
The total number of tasks we will execute + idx_t GetTaskCount() const { + return stage_begin[size_t(AsOfJoinSourceStage::DONE)]; + } - ExpressionExecutor lhs_executor; - DataChunk lhs_keys; - ValidityMask lhs_valid_mask; - SelectionVector lhs_sel; - DataChunk lhs_payload; + //! The number of left chunks + inline idx_t LeftChunks() const { + return left_stats.chunks; + } - OuterJoinMarker left_outer; - bool fetch_next_left; + //! The number of right chunks + inline idx_t RightChunks() const { + return right_stats.chunks; + } + + // Set up the task parameters + idx_t InitTasks(idx_t per_thread); + + //! The maximum number of chunks that we will scan for each state + idx_t MaximumChunks() const { + return MaxValue(LeftChunks(), RightChunks()); } - optional_ptr lhs_partition_sink; + //! Try to move to the next stage + bool TryPrepareNextStage(); + //! Try to get another task for this group + bool TryNextTask(AsOfSourceTask &task); + //! Finish the given task. Returns true if there are no more tasks. + bool FinishTask(AsOfSourceTask &task); + + //! The parent operator + const PhysicalAsOfJoin &op; + //! The group number + const idx_t group_idx; + //! The number of left chunks/rows + const ChunkRow left_stats; + //! The number of right chunks/rows + const ChunkRow right_stats; + //! The left hash partition data + HashGroupPtr left_group; + //! The right hash partition data + HashGroupPtr right_group; + //! The right outer join markers + OuterJoinMarker right_outer; + // The processing stage for this group + AsOfJoinSourceStage stage; + //! The number of blocks per thread. + idx_t per_thread = 0; + //! The number of tasks per stage. + vector stage_tasks; + //! The first task in the stage. + vector stage_begin; + //! The next task to process + idx_t next_task = 0; + //! Count of sorting tasks completed + std::atomic sorted; + //! Count of materialization tasks completed + std::atomic materialized; + //! Count of get tasks completed + std::atomic gotten; + //! Count of left side tasks completed + std::atomic left_completed; + //! Count of right side tasks completed + std::atomic right_completed; }; -bool AsOfLocalState::Sink(DataChunk &input) { - // Compute the join keys - lhs_keys.Reset(); - lhs_executor.Execute(input, lhs_keys); - lhs_keys.Flatten(); +AsOfHashGroup::AsOfHashGroup(const PhysicalAsOfJoin &op, const ChunkRow &left_stats, const ChunkRow &right_stats, + const idx_t hash_group) + : op(op), group_idx(hash_group), left_stats(left_stats), right_stats(right_stats), + right_outer(IsRightOuterJoin(op.join_type)), stage(AsOfJoinSourceStage::INIT), sorted(0), materialized(0), + gotten(0), left_completed(0), right_completed(0) { + right_outer.Initialize(right_stats.count); +}; - // Combine the NULLs - const auto count = input.size(); - lhs_valid_mask.Reset(); - for (auto col_idx : op.null_sensitive) { - auto &col = lhs_keys.data[col_idx]; - UnifiedVectorFormat unified; - col.ToUnifiedFormat(count, unified); - lhs_valid_mask.Combine(unified.validity, count); +idx_t AsOfHashGroup::InitTasks(idx_t per_thread_p) { + per_thread = per_thread_p; + + // INIT + stage_tasks.emplace_back(0); + + // SORT + auto materialize_tasks = BinValue(LeftChunks(), per_thread); + materialize_tasks += BinValue(RightChunks(), per_thread); + stage_tasks.emplace_back(materialize_tasks); + + // MATERIALIZE + stage_tasks.emplace_back(materialize_tasks); + + // GET + stage_tasks.emplace_back(materialize_tasks ?
1 : 0); + + // LEFT + const auto left_tasks = BinValue(LeftChunks(), per_thread); + stage_tasks.emplace_back(left_tasks); + + // RIGHT + const auto right_chunks = IsRightOuter() ? RightChunks() : 0; + const auto right_tasks = BinValue(right_chunks, per_thread); + stage_tasks.emplace_back(right_tasks); + + // DONE + stage_tasks.emplace_back(0); + + // Accumulate task counts so we can find boundaries reliably + idx_t begin = 0; + for (const auto &stage_task : stage_tasks) { + stage_begin.emplace_back(begin); + begin += stage_task; } - // Convert the mask to a selection vector - // and mark all the rows that cannot match for early return. - idx_t lhs_valid = 0; - const auto entry_count = lhs_valid_mask.EntryCount(count); - idx_t base_idx = 0; - left_outer.Reset(); - for (idx_t entry_idx = 0; entry_idx < entry_count;) { - const auto validity_entry = lhs_valid_mask.GetValidityEntry(entry_idx++); - const auto next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); - if (ValidityMask::AllValid(validity_entry)) { - for (; base_idx < next; ++base_idx) { - lhs_sel.set_index(lhs_valid++, base_idx); - left_outer.SetMatch(base_idx); - } - } else if (ValidityMask::NoneValid(validity_entry)) { - base_idx = next; - } else { - const auto start = base_idx; - for (; base_idx < next; ++base_idx) { - if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { - lhs_sel.set_index(lhs_valid++, base_idx); - left_outer.SetMatch(base_idx); - } - } + stage = AsOfJoinSourceStage(1); + + return GetTaskCount(); +} + +bool AsOfHashGroup::TryPrepareNextStage() { + switch (stage) { + case AsOfJoinSourceStage::INIT: + stage = AsOfJoinSourceStage::SORT; + return true; + case AsOfJoinSourceStage::SORT: + if (sorted >= stage_tasks[size_t(stage)]) { + stage = AsOfJoinSourceStage::MATERIALIZE; + return true; } + break; + case AsOfJoinSourceStage::MATERIALIZE: + if (materialized >= stage_tasks[size_t(stage)]) { + stage = AsOfJoinSourceStage::GET; + return true; + } + break; + case AsOfJoinSourceStage::GET: + if (gotten >= stage_tasks[size_t(stage)]) { + stage = AsOfJoinSourceStage::LEFT; + return true; + } + break; + case AsOfJoinSourceStage::LEFT: + if (left_completed >= stage_tasks[size_t(stage)]) { + stage = stage_tasks[size_t(AsOfJoinSourceStage::RIGHT)] ? 
AsOfJoinSourceStage::RIGHT + : AsOfJoinSourceStage::DONE; + return true; + } + break; + case AsOfJoinSourceStage::RIGHT: + if (right_completed >= stage_tasks[size_t(stage)]) { + stage = AsOfJoinSourceStage::DONE; + return true; + } + break; + case AsOfJoinSourceStage::DONE: + return true; } - // Slice the keys to the ones we can match - lhs_payload.Reset(); - if (lhs_valid == count) { - lhs_payload.Reference(input); - lhs_payload.SetCardinality(input); - } else { - lhs_payload.Slice(input, lhs_sel, lhs_valid); - lhs_payload.SetCardinality(lhs_valid); + return false; +} - // Flush the ones that can't match - fetch_next_left = false; +bool AsOfHashGroup::TryNextTask(AsOfSourceTask &task) { + if (next_task >= GetTaskCount()) { + return false; } - lhs_partition_sink->Sink(lhs_payload); + // Search for where we are in the task list + for (idx_t stage = idx_t(AsOfJoinSourceStage::INIT); stage <= idx_t(AsOfJoinSourceStage::DONE); ++stage) { + if (next_task < stage_begin[stage]) { + task.stage = AsOfJoinSourceStage(stage - 1); + task.thread_idx = next_task - stage_begin[size_t(task.stage)]; + break; + } + } - return false; + if (task.stage != GetStage()) { + return false; + } + + task.group_idx = group_idx; + task.begin_idx = 0; + task.end_idx = 0; + + switch (task.stage) { + case AsOfJoinSourceStage::SORT: + task.begin_idx = task.thread_idx * per_thread; + task.max_idx = LeftChunks() + RightChunks(); + task.end_idx = MinValue(task.begin_idx + per_thread, task.max_idx); + break; + case AsOfJoinSourceStage::MATERIALIZE: + if (!left_group || !right_group) { + task.begin_idx = task.thread_idx * per_thread; + task.max_idx = LeftChunks() + RightChunks(); + task.end_idx = MinValue(task.begin_idx + per_thread, task.max_idx); + } + break; + case AsOfJoinSourceStage::GET: + task.begin_idx = 0; + task.end_idx = 1; + task.max_idx = 1; + break; + case AsOfJoinSourceStage::LEFT: + if (left_group) { + task.begin_idx = task.thread_idx * per_thread; + task.max_idx = LeftChunks(); + task.end_idx = MinValue(task.begin_idx + per_thread, task.max_idx); + } + break; + case AsOfJoinSourceStage::RIGHT: + if (right_group) { + task.begin_idx = task.thread_idx * per_thread; + task.max_idx = RightChunks(); + task.end_idx = MinValue(task.begin_idx + per_thread, task.max_idx); + } + break; + case AsOfJoinSourceStage::INIT: + case AsOfJoinSourceStage::DONE: + break; + } + + ++next_task; + + return true; +} + +bool AsOfHashGroup::FinishTask(AsOfSourceTask &task) { + // Inside the lock + switch (task.stage) { + case AsOfJoinSourceStage::SORT: + case AsOfJoinSourceStage::MATERIALIZE: + case AsOfJoinSourceStage::GET: + break; + case AsOfJoinSourceStage::LEFT: + if (left_completed >= stage_tasks[size_t(task.stage)]) { + left_group.reset(); + if (!IsRightOuter()) { + right_group.reset(); + } + } + break; + case AsOfJoinSourceStage::RIGHT: + if (right_completed >= stage_tasks[size_t(task.stage)]) { + right_group.reset(); + } + break; + case AsOfJoinSourceStage::INIT: + case AsOfJoinSourceStage::DONE: + break; + } + + return (materialized + gotten + left_completed + right_completed) >= GetTaskCount(); } -OperatorResultType AsOfLocalState::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk) { - input.Verify(); - Sink(input); +class AsOfLocalSourceState; - // If there were any unmatchable rows, return them now so we can forget about them. 
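As a worked example of the task bookkeeping above (stage_tasks holds per-stage task counts, stage_begin their running prefix sum, and TryNextTask maps a flat task index back to a stage and an offset within it), here is a small self-contained check; the counts are made up:

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
	// Per-stage task counts: INIT, SORT, MATERIALIZE, GET, LEFT, RIGHT, DONE
	std::vector<std::size_t> stage_tasks {0, 6, 6, 1, 3, 0, 0};

	// Running prefix sum: stage_begin[s] is the first flat task index of stage s.
	std::vector<std::size_t> stage_begin;
	std::size_t begin = 0;
	for (const auto count : stage_tasks) {
		stage_begin.push_back(begin);
		begin += count;
	}
	assert(begin == 16); // total number of tasks

	// Map a flat task index back to its stage and thread offset,
	// mirroring the scan in TryNextTask above.
	const std::size_t next_task = 13; // first LEFT task: 0 + 6 + 6 + 1
	std::size_t stage = 0;
	for (std::size_t s = 0; s < stage_begin.size(); ++s) {
		if (next_task < stage_begin[s]) {
			stage = s - 1;
			break;
		}
	}
	assert(stage == 4);                          // LEFT
	assert(next_task - stage_begin[stage] == 0); // first task in that stage
	return 0;
}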
- if (!fetch_next_left) { - fetch_next_left = true; - left_outer.ConstructLeftJoinResult(input, chunk); - left_outer.Reset(); +class AsOfGlobalSourceState : public GlobalSourceState { +public: + using AsOfHashGroupPtr = unique_ptr; + using AsOfHashGroups = vector; + using HashedSourceStatePtr = unique_ptr; + using Task = AsOfSourceTask; + using TaskPtr = optional_ptr; + using PartitionBlock = std::pair; + + AsOfGlobalSourceState(ClientContext &client, const PhysicalAsOfJoin &op); + + //! Are there any more tasks? + bool HasMoreTasks() const { + return !stopped && started < total_tasks; + } + bool HasUnfinishedTasks() const { + return !stopped && finished < total_tasks; } - // Just keep asking for data and buffering it - return OperatorResultType::NEED_MORE_INPUT; + //! Assign a new task to the local state + bool TryNextTask(TaskPtr &task, Task &task_local); + + //! The parent operator + const PhysicalAsOfJoin &op; + //! The source states for the hashed sort + vector hashed_sources; + //! The hash groups + AsOfHashGroups asof_groups; + //! The sorted list of (blocks, group_idx) pairs + vector partition_blocks; + //! The ordered set of active groups + vector active_groups; + //! The next group to start + atomic next_group; + //! The total number of tasks + idx_t total_tasks = 0; + //! The number of started tasks + atomic started; + //! The number of tasks finished. + atomic finished; + //! Stop producing tasks + atomic stopped; + +public: + idx_t MaxThreads() override { + return total_tasks; + } + +protected: + //! Build task list + void CreateTaskList(ClientContext &client); + //! Finish a task + void FinishTask(TaskPtr task); +}; + +AsOfGlobalSourceState::AsOfGlobalSourceState(ClientContext &client, const PhysicalAsOfJoin &op) + : op(op), next_group(0), started(0), finished(0), stopped(false) { + // Take ownership of the hash groups + auto &gsink = op.sink_state->Cast(); + + using ChunkRow = HashedSort::ChunkRow; + using ChunkRows = HashedSort::ChunkRows; + vector child_groups(2); + for (idx_t child = 0; child < child_groups.size(); ++child) { + auto &hashed_sort = *gsink.hashed_sorts[child]; + auto &hashed_sink = *gsink.hashed_sinks[child]; + auto hashed_source = hashed_sort.GetGlobalSourceState(client, hashed_sink); + child_groups[child] = hashed_sort.GetHashGroups(*hashed_source); + hashed_sources.emplace_back(std::move(hashed_source)); + } + + // Pivot into AsOfHashGroups + auto &lhs_groups = child_groups[0]; + auto &rhs_groups = child_groups[1]; + const auto group_count = MaxValue(lhs_groups.size(), rhs_groups.size()); + for (idx_t group_idx = 0; group_idx < group_count; ++group_idx) { + ChunkRow lhs_stats; + if (group_idx < lhs_groups.size()) { + lhs_stats = lhs_groups[group_idx]; + } + ChunkRow rhs_stats; + if (group_idx < rhs_groups.size()) { + rhs_stats = rhs_groups[group_idx]; + } + auto asof_group = make_uniq(op, lhs_stats, rhs_stats, group_idx); + asof_groups.emplace_back(std::move(asof_group)); + } + + CreateTaskList(client); } -OperatorResultType PhysicalAsOfJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &lstate_p) const { - auto &gsink = sink_state->Cast(); - auto &lstate = lstate_p.Cast(); +void AsOfGlobalSourceState::CreateTaskList(ClientContext &client) { + // Sort the groups from largest to smallest + if (asof_groups.empty()) { + return; + } - if (gsink.rhs_sink.count == 0) { - // empty RHS - if (!EmptyResultIfRHSIsEmpty()) { - ConstructEmptyJoinResult(join_type, gsink.has_null, input, 
chunk); - return OperatorResultType::NEED_MORE_INPUT; - } else { - return OperatorResultType::FINISHED; + // Count chunks, not rows (otherwise left and right raggedness could give the wrong answer + for (idx_t group_idx = 0; group_idx < asof_groups.size(); ++group_idx) { + auto &asof_hash_group = asof_groups[group_idx]; + if (!asof_hash_group) { + continue; } + partition_blocks.emplace_back(asof_hash_group->MaximumChunks(), group_idx); + } + std::sort(partition_blocks.begin(), partition_blocks.end(), std::greater()); + const auto &max_block = partition_blocks.front(); + + // Schedule the largest group on as many threads as possible + auto &ts = TaskScheduler::GetScheduler(client); + const auto threads = NumericCast(ts.NumberOfThreads()); + + const auto per_thread = AsOfHashGroup::BinValue(max_block.first, threads); + if (!per_thread) { + throw InternalException("No blocks per AsOf thread! %ld threads, %ld groups, %ld blocks, %ld hash group", + threads, partition_blocks.size(), max_block.first, max_block.second); } - return lstate.ExecuteInternal(context, input, chunk); + for (const auto &b : partition_blocks) { + total_tasks += asof_groups[b.second]->InitTasks(per_thread); + } } -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// +enum class SortKeyPrefixComparisonType : uint8_t { FIXED, VARCHAR, NESTED }; + +struct SortKeyPrefixComparisonColumn { + SortKeyPrefixComparisonType type; + idx_t size; +}; + +struct SortKeyPrefixComparisonResult { + //! The column at which the sides are no longer equal, + //! e.g., Compare([42, 84], [42, 83]) would return {1, COMPARE_GREATERTHAN} + idx_t column_index; + //! Either COMPARE_EQUAL, COMPARE_LESSTHAN, COMPARE_GREATERTHAN + ExpressionType type; +}; + +struct SortKeyPrefixComparison { + unsafe_vector columns; + //! Two row buffer for measuring lhs and rhs widths for nested types. + //! Gross, but there is currently no way to measure the width of a single key + //! except as a side-effect of decoding it... + DataChunk decoded; + + template + SortKeyPrefixComparisonResult Compare(const SORT_KEY &lhs, const SORT_KEY &rhs) { + SortKeyPrefixComparisonResult result {0, ExpressionType::COMPARE_EQUAL}; + + auto lhs_copy = lhs; + string_t lhs_key; + lhs_copy.Deconstruct(lhs_key); + auto lhs_ptr = lhs_key.GetData(); + + auto rhs_copy = rhs; + string_t rhs_key; + rhs_copy.Deconstruct(rhs_key); + auto rhs_ptr = rhs_key.GetData(); + + // Partition keys are always sorted this way. + OrderModifiers modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST); + + for (column_t col_idx = 0; col_idx < columns.size(); ++col_idx) { + const auto &col = columns[col_idx]; + auto &vec = decoded.data[col_idx]; + auto lhs_width = col.size; + auto rhs_width = col.size; + int cmp = 1; + switch (col.type) { + case SortKeyPrefixComparisonType::FIXED: + cmp = memcmp(lhs_ptr, rhs_ptr, lhs_width); + break; + case SortKeyPrefixComparisonType::VARCHAR: + // Include first null byte. 
+ lhs_width = 1 + strlen(lhs_ptr); + rhs_width = 1 + strlen(rhs_ptr); + cmp = memcmp(lhs_ptr, rhs_ptr, MinValue(lhs_width, rhs_width)); + break; + case SortKeyPrefixComparisonType::NESTED: + decoded.Reset(); + lhs_width = CreateSortKeyHelpers::DecodeSortKey(lhs_key, vec, 0, modifiers); + rhs_width = CreateSortKeyHelpers::DecodeSortKey(rhs_key, vec, 1, modifiers); + cmp = memcmp(lhs_ptr, rhs_ptr, MinValue(lhs_width, rhs_width)); + if (!cmp) { + cmp = (rhs_width < lhs_width) - (lhs_width < rhs_width); + } + break; + } + + if (cmp) { + result.type = (cmp < 0) ? ExpressionType::COMPARE_LESSTHAN : ExpressionType::COMPARE_GREATERTHAN; + return result; + } + + ++result.column_index; + lhs_ptr += lhs_width; + rhs_ptr += rhs_width; + } + + return result; + } +}; + class AsOfProbeBuffer { public: using Orders = vector; + using Task = AsOfSourceTask; + using TaskPtr = optional_ptr; - static bool IsExternal(ClientContext &context) { - return ClientConfig::GetConfig(context).force_external; - } - - AsOfProbeBuffer(ClientContext &context, const PhysicalAsOfJoin &op); + AsOfProbeBuffer(ClientContext &client, const PhysicalAsOfJoin &op, AsOfGlobalSourceState &gsource); public: - void ResolveJoin(bool *found_matches, idx_t *matches = nullptr); - bool Scanning() const { - return lhs_scanner.get(); + // Comparison utilities + static bool IsStrictComparison(ExpressionType comparison) { + switch (comparison) { + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_GREATERTHAN: + return true; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return false; + default: + throw NotImplementedException("Unsupported comparison type for ASOF join"); + } } - void BeginLeftScan(hash_t scan_bin); + + //! Is left cmp right? + template + static inline bool Compare(const T &lhs, const T &rhs, const bool strict) { + const bool less_than = lhs < rhs; + if (!less_than && !strict) { + return !(rhs < lhs); + } + return less_than; + } + + template + void ResolveJoin(idx_t *matches); + + using resolve_join_t = void (duckdb::AsOfProbeBuffer::*)(idx_t *); + resolve_join_t resolve_join_func; + + void BeginLeftScan(TaskPtr task); bool NextLeft(); + void ScanLeft(); void EndLeftScan(); + //! Create a new iterator for the sorted run + static unique_ptr CreateIteratorState(SortedRun &sorted) { + auto state = make_uniq(*sorted.key_data, sorted.payload_data.get()); + + // Unless we do this, we will only get values from the first chunk + Repin(*state); + + return state; + } + //! Reset the pins for an iterator so we release memory in a timely manner + static void Repin(ExternalBlockIteratorState &iter) { + // Don't pin the payload because we are not using it here. 
+ iter.SetKeepPinned(true); + } // resolve joins that output max N elements (SEMI, ANTI, MARK) void ResolveSimpleJoin(ExecutionContext &context, DataChunk &chunk); - // resolve joins that can potentially output N*M elements (INNER, LEFT, FULL) + // resolve joins that can potentially output N*M elements (LEFT, LEFT, FULL) void ResolveComplexJoin(ExecutionContext &context, DataChunk &chunk); // Chunk may be empty void GetData(ExecutionContext &context, DataChunk &chunk); bool HasMoreData() const { - return !fetch_next_left || (lhs_scanner && lhs_scanner->Remaining()); + return !fetch_next_left || (task->begin_idx < task->end_idx); } - ClientContext &context; - Allocator &allocator; + ClientContext &client; const PhysicalAsOfJoin &op; - BufferManager &buffer_manager; - const bool force_external; - const idx_t memory_per_thread; - Orders lhs_orders; + //! The source state + AsOfGlobalSourceState &gsource; + //! Is the inequality strict? + const bool strict; + //! The current hash group + optional_ptr asof_hash_group; + //! The task we are processing + TaskPtr task; // LHS scanning - SelectionVector lhs_sel; - optional_ptr left_hash; + optional_ptr left_group; OuterJoinMarker left_outer; - unique_ptr left_itr; - unique_ptr lhs_scanner; + unique_ptr left_itr; + unique_ptr lhs_scanner; DataChunk lhs_payload; - idx_t left_group = 0; + ExpressionExecutor lhs_executor; + DataChunk lhs_keys; + ValidityMask lhs_valid_mask; + idx_t left_bin = 0; + SelectionVector lhs_match_sel; // RHS scanning - optional_ptr right_hash; + optional_ptr right_group; optional_ptr right_outer; - unique_ptr right_itr; - unique_ptr rhs_scanner; + unique_ptr right_itr; + idx_t right_pos; // ExternalBlockIteratorState doesn't know this... + unique_ptr rhs_scanner; DataChunk rhs_payload; - idx_t right_group = 0; + DataChunk rhs_input; + SelectionVector rhs_match_sel; + idx_t right_bin = 0; // Predicate evaluation - SelectionVector filter_sel; - ExpressionExecutor filterer; - idx_t lhs_match_count; bool fetch_next_left; + + SortKeyPrefixComparison prefix; }; -AsOfProbeBuffer::AsOfProbeBuffer(ClientContext &context, const PhysicalAsOfJoin &op) - : context(context), allocator(Allocator::Get(context)), op(op), - buffer_manager(BufferManager::GetBufferManager(context)), force_external(IsExternal(context)), - memory_per_thread(op.GetMaxThreadMemory(context)), left_outer(IsLeftOuterJoin(op.join_type)), filterer(context), - fetch_next_left(true) { - vector> partition_stats; - Orders partitions; // Not used. 
- PartitionGlobalSinkState::GenerateOrderings(partitions, lhs_orders, op.lhs_partitions, op.lhs_orders, - partition_stats); - - // We sort the row numbers of the incoming block, not the rows - lhs_payload.Initialize(allocator, op.children[0].get().GetTypes()); - rhs_payload.Initialize(allocator, op.children[1].get().GetTypes()); - - lhs_sel.Initialize(); +AsOfProbeBuffer::AsOfProbeBuffer(ClientContext &client, const PhysicalAsOfJoin &op, AsOfGlobalSourceState &gsource) + : client(client), op(op), gsource(gsource), strict(IsStrictComparison(op.comparison_type)), + left_outer(IsLeftOuterJoin(op.join_type)), lhs_executor(client), fetch_next_left(true) { + lhs_keys.Initialize(client, op.join_key_types); + for (const auto &cond : op.conditions) { + lhs_executor.AddExpression(*cond.left); + } + + lhs_payload.Initialize(client, op.children[0].get().GetTypes()); + rhs_payload.Initialize(client, op.children[1].get().GetTypes()); + rhs_input.Initialize(client, op.children[1].get().GetTypes()); + + lhs_match_sel.Initialize(); + rhs_match_sel.Initialize(); left_outer.Initialize(STANDARD_VECTOR_SIZE); - if (op.predicate) { - filter_sel.Initialize(); - filterer.AddExpression(*op.predicate); + // If we have equality predicates, set up the prefix data. + vector prefix_types; + for (idx_t i = 0; i < op.conditions.size() - 1; ++i) { + const auto &cond = op.conditions[i]; + const auto &type = cond.left->return_type; + prefix_types.emplace_back(type); + SortKeyPrefixComparisonColumn col; + col.size = DConstants::INVALID_INDEX; + switch (type.id()) { + case LogicalTypeId::VARCHAR: + case LogicalTypeId::BLOB: + col.type = SortKeyPrefixComparisonType::VARCHAR; + break; + case LogicalTypeId::STRUCT: + case LogicalTypeId::LIST: + case LogicalTypeId::ARRAY: + col.type = SortKeyPrefixComparisonType::NESTED; + break; + default: + col.type = SortKeyPrefixComparisonType::FIXED; + col.size = 1 + GetTypeIdSize(type.InternalType()); + break; + } + prefix.columns.emplace_back(col); + } + if (!prefix_types.empty()) { + // LHS, RHS + prefix.decoded.Initialize(client, prefix_types, 2); } } -void AsOfProbeBuffer::BeginLeftScan(hash_t scan_bin) { +void AsOfProbeBuffer::BeginLeftScan(TaskPtr task_p) { auto &gsink = op.sink_state->Cast(); + task = task_p; + const auto scan_bin = task->group_idx; - auto &lhs_sink = *gsink.lhs_sink; - left_group = lhs_sink.bin_groups[scan_bin]; + asof_hash_group = gsource.asof_groups[scan_bin].get(); - // Always set right_group too for memory management - auto &rhs_sink = gsink.rhs_sink; - if (scan_bin < rhs_sink.bin_groups.size()) { - right_group = rhs_sink.bin_groups[scan_bin]; - } else { - right_group = rhs_sink.bin_groups.size(); - } + // Always set right_bin too for memory management + right_group = asof_hash_group->right_group; + right_bin = right_group ? scan_bin : gsource.asof_groups.size(); - if (left_group >= lhs_sink.bin_groups.size()) { + left_group = asof_hash_group->left_group; + left_bin = left_group ? 
scan_bin : gsource.asof_groups.size(); + if (!left_group || !left_group->Count()) { return; } - auto iterator_comp = ExpressionType::INVALID; - switch (op.comparison_type) { - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - iterator_comp = ExpressionType::COMPARE_LESSTHANOREQUALTO; + // Set up function pointer for sort type + const auto sort_key_type = left_group->key_data->GetLayout().GetSortKeyType(); + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; + break; + case SortKeyType::NO_PAYLOAD_FIXED_16: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; + break; + case SortKeyType::NO_PAYLOAD_FIXED_24: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; break; - case ExpressionType::COMPARE_GREATERTHAN: - iterator_comp = ExpressionType::COMPARE_LESSTHAN; + case SortKeyType::NO_PAYLOAD_FIXED_32: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - iterator_comp = ExpressionType::COMPARE_GREATERTHANOREQUALTO; + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; break; - case ExpressionType::COMPARE_LESSTHAN: - iterator_comp = ExpressionType::COMPARE_GREATERTHAN; + case SortKeyType::PAYLOAD_FIXED_16: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; + break; + case SortKeyType::PAYLOAD_FIXED_24: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; + break; + case SortKeyType::PAYLOAD_FIXED_32: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; + break; + case SortKeyType::PAYLOAD_VARIABLE_32: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; break; default: throw NotImplementedException("Unsupported comparison type for ASOF join"); } - left_hash = lhs_sink.hash_groups[left_group].get(); - auto &left_sort = *(left_hash->global_sort); - if (left_sort.sorted_blocks.empty()) { - return; - } - lhs_scanner = make_uniq(left_sort, false); - left_itr = make_uniq(left_sort, iterator_comp); + lhs_scanner = make_uniq(*left_group, *gsink.hashed_sorts[0]); + lhs_scanner->SeekBlock(task->begin_idx); + left_itr = CreateIteratorState(*left_group); // We are only probing the corresponding right side bin, which may be empty - // If they are empty, we leave the iterator as null so we can emit left matches - if (right_group < rhs_sink.bin_groups.size()) { - right_hash = rhs_sink.hash_groups[right_group].get(); - right_outer = gsink.right_outers.data() + right_group; - auto &right_sort = *(right_hash->global_sort); - right_itr = make_uniq(right_sort, iterator_comp); - rhs_scanner = make_uniq(right_sort, false); + // If it is empty, we leave the iterator as null so we can emit left matches + right_pos = 0; + if (right_group) { + right_outer = &asof_hash_group->right_outer; + if (right_group && right_group->Count()) { + right_itr = CreateIteratorState(*right_group); + rhs_scanner = make_uniq(*right_group, *gsink.hashed_sorts[1]); + } } } bool AsOfProbeBuffer::NextLeft() { - if (!HasMoreData()) { - return false; - } + return task->begin_idx < task->end_idx; +} +void AsOfProbeBuffer::ScanLeft() { // Scan the next sorted chunk lhs_payload.Reset(); - left_itr->SetIndex(lhs_scanner->Scanned()); lhs_scanner->Scan(lhs_payload); + ++task->begin_idx; - return true; + // Compute the join keys + lhs_keys.Reset(); + lhs_executor.Execute(lhs_payload, lhs_keys); + lhs_keys.Flatten(); + + // Combine the NULLs + const auto count = lhs_payload.size(); + lhs_valid_mask.Reset(); + for (auto col_idx : op.null_sensitive) { + auto &col = 
lhs_keys.data[col_idx]; + UnifiedVectorFormat unified; + col.ToUnifiedFormat(count, unified); + lhs_valid_mask.Combine(unified.validity, count); + } + + // Filter out NULL matches + if (!lhs_valid_mask.AllValid()) { + const auto count = lhs_match_count; + lhs_match_count = 0; + for (idx_t i = 0; i < count; ++i) { + const auto idx = lhs_match_sel.get_index(i); + if (lhs_valid_mask.RowIsValidUnsafe(idx)) { + lhs_match_sel.set_index(lhs_match_count++, idx); + } + } + } } void AsOfProbeBuffer::EndLeftScan() { - auto &gsink = op.sink_state->Cast(); + if (task->stage != AsOfJoinSourceStage::LEFT) { + return; + } + task->stage = AsOfJoinSourceStage::DONE; + + D_ASSERT(asof_hash_group); + asof_hash_group->left_completed++; - right_hash = nullptr; + right_group = nullptr; right_itr.reset(); rhs_scanner.reset(); right_outer = nullptr; - auto &rhs_sink = gsink.rhs_sink; - if (!gsink.is_outer && right_group < rhs_sink.bin_groups.size()) { - rhs_sink.hash_groups[right_group].reset(); - } - - left_hash = nullptr; + left_group = nullptr; left_itr.reset(); lhs_scanner.reset(); - - auto &lhs_sink = *gsink.lhs_sink; - if (left_group < lhs_sink.bin_groups.size()) { - lhs_sink.hash_groups[left_group].reset(); - } } -void AsOfProbeBuffer::ResolveJoin(bool *found_match, idx_t *matches) { +template +void AsOfProbeBuffer::ResolveJoin(idx_t *matches) { + using SORT_KEY = SortKey; + using BLOCKS_ITERATOR = block_iterator_t; + // If there was no right partition, there are no matches lhs_match_count = 0; if (!right_itr) { return; } - const auto count = lhs_payload.size(); - const auto left_base = left_itr->GetIndex(); + Repin(*left_itr); + BLOCKS_ITERATOR left_key(*left_itr); + + Repin(*right_itr); + BLOCKS_ITERATOR right_key(*right_itr); + + const auto count = lhs_scanner->NextSize(); + const auto left_base = lhs_scanner->Scanned(); // Searching for right <= left for (idx_t i = 0; i < count; ++i) { - left_itr->SetIndex(left_base + i); - // If right > left, then there is no match - if (!right_itr->Compare(*left_itr)) { + const auto left_pos = left_base + i; + if (!Compare(right_key[right_pos], left_key[left_pos], strict)) { continue; } // Exponential search forward for a non-matching value using radix iterators // (We use exponential search to avoid thrashing the block manager on large probes) idx_t bound = 1; - idx_t begin = right_itr->GetIndex(); - right_itr->SetIndex(begin + bound); - while (right_itr->GetIndex() < right_hash->count) { - if (right_itr->Compare(*left_itr)) { + idx_t begin = right_pos; + while (begin + bound < right_group->Count()) { + if (Compare(right_key[begin + bound], left_key[left_pos], strict)) { // If right <= left, jump ahead bound *= 2; - right_itr->SetIndex(begin + bound); } else { break; } @@ -539,43 +1112,46 @@ void AsOfProbeBuffer::ResolveJoin(bool *found_match, idx_t *matches) { // Binary search for the first non-matching value using radix iterators // The previous value (which we know exists) is the match auto first = begin + bound / 2; - auto last = MinValue(begin + bound, right_hash->count); + auto last = MinValue(begin + bound, right_group->Count()); while (first < last) { const auto mid = first + (last - first) / 2; - right_itr->SetIndex(mid); - if (right_itr->Compare(*left_itr)) { + if (Compare(right_key[mid], left_key[left_pos], strict)) { // If right <= left, new lower bound first = mid + 1; } else { last = mid; } } - right_itr->SetIndex(--first); + right_pos = --first; // Check partitions for strict equality - if (right_hash->ComparePartitions(*left_itr, *right_itr)) { - 
continue; + if (!prefix.columns.empty()) { + const auto cmp = prefix.Compare(left_key[left_pos], right_key[right_pos]); + if (cmp.column_index < prefix.columns.size()) { + continue; + } } // Emit match data - if (found_match) { - found_match[i] = true; - } if (matches) { matches[i] = first; } - lhs_sel.set_index(lhs_match_count++, i); + lhs_match_sel.set_index(lhs_match_count++, i); } } -unique_ptr PhysicalAsOfJoin::GetOperatorState(ExecutionContext &context) const { - return make_uniq(context.client, *this); -} - void AsOfProbeBuffer::ResolveSimpleJoin(ExecutionContext &context, DataChunk &chunk) { // perform the actual join + (this->*resolve_join_func)(nullptr); + + // Scan the lhs values (after comparing keys) and filter out the LHS NULLs + ScanLeft(); + + // Convert the match selection to simple join mask bool found_match[STANDARD_VECTOR_SIZE] = {false}; - ResolveJoin(found_match); + for (idx_t i = 0; i < lhs_match_count; ++i) { + found_match[lhs_match_sel.get_index(i)] = true; + } // now construct the result based on the join result switch (op.join_type) { @@ -593,43 +1169,51 @@ void AsOfProbeBuffer::ResolveSimpleJoin(ExecutionContext &context, DataChunk &ch void AsOfProbeBuffer::ResolveComplexJoin(ExecutionContext &context, DataChunk &chunk) { // perform the actual join idx_t matches[STANDARD_VECTOR_SIZE]; - ResolveJoin(nullptr, matches); + (this->*resolve_join_func)(matches); + + // Scan the lhs values (after comparing keys) and filter out the LHS NULLs + ScanLeft(); + // Extract the rhs input columns from the match + rhs_input.Reset(); + idx_t rhs_match_count = 0; for (idx_t i = 0; i < lhs_match_count; ++i) { - const auto idx = lhs_sel[i]; + const auto idx = lhs_match_sel[i]; const auto match_pos = matches[idx]; // Skip to the range containing the match - while (match_pos >= rhs_scanner->Scanned()) { + if (match_pos >= rhs_scanner->Scanned()) { + if (rhs_match_count) { + rhs_input.Append(rhs_payload, false, &rhs_match_sel, rhs_match_count); + rhs_match_count = 0; + } rhs_payload.Reset(); + rhs_scanner->SeekRow(match_pos); rhs_scanner->Scan(rhs_payload); } - // Append the individual values - // TODO: Batch the copies - const auto source_offset = match_pos - (rhs_scanner->Scanned() - rhs_payload.size()); - for (column_t col_idx = 0; col_idx < op.right_projection_map.size(); ++col_idx) { - const auto rhs_idx = op.right_projection_map[col_idx]; - auto &source = rhs_payload.data[rhs_idx]; - auto &target = chunk.data[lhs_payload.ColumnCount() + col_idx]; - VectorOperations::Copy(source, target, source_offset + 1, source_offset, i); - } + // Select the individual values + const auto source_offset = match_pos - rhs_scanner->Base(); + rhs_match_sel.set_index(rhs_match_count++, source_offset); } + rhs_input.Append(rhs_payload, false, &rhs_match_sel, rhs_match_count); // Slice the left payload into the result for (column_t i = 0; i < lhs_payload.ColumnCount(); ++i) { - chunk.data[i].Slice(lhs_payload.data[i], lhs_sel, lhs_match_count); + chunk.data[i].Slice(lhs_payload.data[i], lhs_match_sel, lhs_match_count); } - chunk.SetCardinality(lhs_match_count); - auto match_sel = &lhs_sel; - if (filterer.expressions.size() == 1) { - lhs_match_count = filterer.SelectExpression(chunk, filter_sel); - chunk.Slice(filter_sel, lhs_match_count); - match_sel = &filter_sel; + + // Reference the projected right payload into the result + for (column_t col_idx = 0; col_idx < op.right_projection_map.size(); ++col_idx) { + const auto rhs_idx = op.right_projection_map[col_idx]; + auto &source = 
rhs_input.data[rhs_idx]; + auto &target = chunk.data[lhs_payload.ColumnCount() + col_idx]; + target.Reference(source); } + chunk.SetCardinality(lhs_match_count); // Update the match masks for the rows we ended up with left_outer.Reset(); for (idx_t i = 0; i < lhs_match_count; ++i) { - const auto idx = match_sel->get_index(i); + const auto idx = lhs_match_sel.get_index(i); left_outer.SetMatch(idx); const auto first = matches[idx]; right_outer->SetMatch(first); @@ -675,241 +1259,412 @@ void AsOfProbeBuffer::GetData(ExecutionContext &context, DataChunk &chunk) { } } -class AsOfGlobalSourceState : public GlobalSourceState { -public: - explicit AsOfGlobalSourceState(AsOfGlobalSinkState &gsink_p) - : gsink(gsink_p), next_combine(0), combined(0), merged(0), mergers(0), next_left(0), flushed(0), next_right(0) { - } - - PartitionGlobalMergeStates &GetMergeStates() { - lock_guard guard(lock); - if (!merge_states) { - merge_states = make_uniq(*gsink.lhs_sink); - } - return *merge_states; - } - - AsOfGlobalSinkState &gsink; - //! The next buffer to combine - atomic next_combine; - //! The number of combined buffers - atomic combined; - //! The number of combined buffers - atomic merged; - //! The number of combined buffers - atomic mergers; - //! The next buffer to flush - atomic next_left; - //! The number of flushed buffers - atomic flushed; - //! The right outer output read position. - atomic next_right; - //! The merge handler - mutex lock; - unique_ptr merge_states; - -public: - idx_t MaxThreads() override { - return gsink.lhs_buffers.size(); - } -}; - -unique_ptr PhysicalAsOfJoin::GetGlobalSourceState(ClientContext &context) const { - auto &gsink = sink_state->Cast(); - return make_uniq(gsink); +unique_ptr PhysicalAsOfJoin::GetGlobalSourceState(ClientContext &client) const { + return make_uniq(client, *this); } class AsOfLocalSourceState : public LocalSourceState { public: - using HashGroupPtr = unique_ptr; - - AsOfLocalSourceState(AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op, ClientContext &client_p); + using HashGroupPtr = optional_ptr; + using Task = AsOfSourceTask; + using TaskPtr = optional_ptr; + + AsOfLocalSourceState(ExecutionContext &context, AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op); + + //! Task management + bool TaskFinished() const; + //! 
Assign the next task + bool TryAssignTask(); + + void ExecuteSortTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source); + void ExecuteMaterializeTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source); + void ExecuteGetTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source); + void ExecuteLeftTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source); + void ExecuteRightTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source); + + void ExecuteTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source) { + switch (task->stage) { + case AsOfJoinSourceStage::SORT: + ExecuteSortTask(context, chunk, source); + break; + case AsOfJoinSourceStage::MATERIALIZE: + ExecuteMaterializeTask(context, chunk, source); + break; + case AsOfJoinSourceStage::GET: + ExecuteGetTask(context, chunk, source); + break; + case AsOfJoinSourceStage::LEFT: + ExecuteLeftTask(context, chunk, source); + break; + case AsOfJoinSourceStage::RIGHT: + ExecuteRightTask(context, chunk, source); + break; + case AsOfJoinSourceStage::INIT: + case AsOfJoinSourceStage::DONE: + throw InternalException("Invalid state for AsOf Task"); + } - // Return true if we were not interrupted (another thread died) - bool CombineLeftPartitions(); - bool MergeLeftPartitions(); + if (TaskFinished()) { + ++gsource.finished; + } + } - idx_t BeginRightScan(const idx_t hash_bin); + void BeginRightScan(); + void EndRightScan(); AsOfGlobalSourceState &gsource; - ClientContext &client; + ExecutionContext &context; //! The left side partition being probed AsOfProbeBuffer probe_buffer; - //! The read partition - idx_t hash_bin; + //! The task this thread is working on + TaskPtr task; + //! The task storage + Task task_local; + //! The rhs group HashGroupPtr hash_group; //! The read cursor - unique_ptr scanner; - //! Pointer to the matches - const bool *found_match = {}; + unique_ptr scanner; + //! The right outer buffer + DataChunk rhs_chunk; + //! The right outer slicer + SelectionVector rsel; + //! 
Pointer to the right marker + const bool *rhs_matches = {}; }; -AsOfLocalSourceState::AsOfLocalSourceState(AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op, - ClientContext &client_p) - : gsource(gsource), client(client_p), probe_buffer(gsource.gsink.lhs_sink->context, op) { - gsource.mergers++; +AsOfLocalSourceState::AsOfLocalSourceState(ExecutionContext &context, AsOfGlobalSourceState &gsource, + const PhysicalAsOfJoin &op) + : gsource(gsource), context(context), probe_buffer(context.client, op, gsource), rsel(STANDARD_VECTOR_SIZE) { + rhs_chunk.Initialize(context.client, op.children[1].get().GetTypes()); } -bool AsOfLocalSourceState::CombineLeftPartitions() { - const auto buffer_count = gsource.gsink.lhs_buffers.size(); - while (gsource.combined < buffer_count && !client.interrupted) { - const auto next_combine = gsource.next_combine++; - if (next_combine < buffer_count) { - gsource.gsink.lhs_buffers[next_combine]->Combine(); - ++gsource.combined; - } else { - TaskScheduler::GetScheduler(client).YieldThread(); - } +bool AsOfLocalSourceState::TaskFinished() const { + if (!task) { + return true; } - return !client.interrupted; -} - -bool AsOfLocalSourceState::MergeLeftPartitions() { - PartitionGlobalMergeStates::Callback local_callback; - PartitionLocalMergeState local_merge(*gsource.gsink.lhs_sink); - gsource.GetMergeStates().ExecuteTask(local_merge, local_callback); - gsource.merged++; - while (gsource.merged < gsource.mergers && !client.interrupted) { - TaskScheduler::GetScheduler(client).YieldThread(); + if (task->stage == AsOfJoinSourceStage::LEFT && !probe_buffer.fetch_next_left) { + return false; } - return !client.interrupted; + + return task->begin_idx >= task->end_idx; } -idx_t AsOfLocalSourceState::BeginRightScan(const idx_t hash_bin_p) { - hash_bin = hash_bin_p; +void AsOfLocalSourceState::BeginRightScan() { + const auto hash_bin = task->group_idx; - hash_group = std::move(gsource.gsink.rhs_sink.hash_groups[hash_bin]); - if (hash_group->global_sort->sorted_blocks.empty()) { - return 0; + auto &asof_groups = gsource.asof_groups; + if (hash_bin >= asof_groups.size()) { + return; } - scanner = make_uniq(*hash_group->global_sort); - found_match = gsource.gsink.right_outers[hash_bin].GetMatches(); - return scanner->Remaining(); + hash_group = asof_groups[hash_bin]->right_group.get(); + if (!hash_group || !hash_group->Count()) { + return; + } + auto &gsink = gsource.op.sink_state->Cast(); + scanner = make_uniq(*hash_group, *gsink.hashed_sorts[1]); + scanner->SeekBlock(task->begin_idx); + + rhs_matches = asof_groups[hash_bin]->right_outer.GetMatches(); +} + +void AsOfLocalSourceState::EndRightScan() { + D_ASSERT(task->stage == AsOfJoinSourceStage::RIGHT); + + auto &asof_groups = gsource.asof_groups; + const auto hash_bin = task->group_idx; + const auto &asof_hash_group = asof_groups[hash_bin]; + asof_hash_group->right_completed++; } unique_ptr PhysicalAsOfJoin::GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const { auto &gsource = gstate.Cast(); - return make_uniq(gsource, *this, context.client); + return make_uniq(context, gsource, *this); } -SourceResultType PhysicalAsOfJoin::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &gsource = input.global_state.Cast(); - auto &lsource = input.local_state.Cast(); - auto &rhs_sink = gsource.gsink.rhs_sink; - auto &client = context.client; - - // Step 1: Combine the partitions - if (!lsource.CombineLeftPartitions()) { - return SourceResultType::FINISHED; - } 
- - // Step 2: Sort on all threads - if (!lsource.MergeLeftPartitions()) { - return SourceResultType::FINISHED; - } - - // Step 3: Join the partitions - auto &lhs_sink = *gsource.gsink.lhs_sink; - const auto left_bins = lhs_sink.grouping_data ? lhs_sink.grouping_data->GetPartitions().size() : 1; - while (gsource.flushed < left_bins) { - // Make sure we have something to flush - if (!lsource.probe_buffer.Scanning()) { - const auto left_bin = gsource.next_left++; - if (left_bin < left_bins) { - // More to flush - lsource.probe_buffer.BeginLeftScan(left_bin); - } else if (!IsRightOuterJoin(join_type) || client.interrupted) { - return SourceResultType::FINISHED; - } else { - // Wait for all threads to finish - // TODO: How to implement a spin wait correctly? - // Returning BLOCKED seems to hang the system. - TaskScheduler::GetScheduler(client).YieldThread(); - continue; - } +void AsOfGlobalSourceState::FinishTask(TaskPtr task) { + // Inside the lock + if (!task) { + return; + } + + const auto group_idx = task->group_idx; + auto &finished_hash_group = asof_groups[group_idx]; + D_ASSERT(finished_hash_group); + + if (finished_hash_group->FinishTask(*task)) { + // Remove it from the active groups + auto &v = active_groups; + v.erase(std::remove(v.begin(), v.end(), group_idx), v.end()); + } +} + +bool AsOfLocalSourceState::TryAssignTask() { + D_ASSERT(TaskFinished()); + // Because downstream operators may be using our internal buffers, + // we can't "finish" a task until we are about to get the next one. + if (task) { + switch (task->stage) { + case AsOfJoinSourceStage::SORT: + gsource.asof_groups[task_local.group_idx]->sorted++; + break; + case AsOfJoinSourceStage::MATERIALIZE: + gsource.asof_groups[task_local.group_idx]->materialized++; + break; + case AsOfJoinSourceStage::GET: + gsource.asof_groups[task_local.group_idx]->gotten++; + break; + case AsOfJoinSourceStage::LEFT: + probe_buffer.EndLeftScan(); + break; + case AsOfJoinSourceStage::RIGHT: + EndRightScan(); + break; + case AsOfJoinSourceStage::INIT: + case AsOfJoinSourceStage::DONE: + break; } + } - lsource.probe_buffer.GetData(context, chunk); - if (chunk.size()) { - return SourceResultType::HAVE_MORE_OUTPUT; - } else if (lsource.probe_buffer.HasMoreData()) { - // Join the next partition - continue; - } else { - lsource.probe_buffer.EndLeftScan(); - gsource.flushed++; + if (!gsource.TryNextTask(task, task_local)) { + return false; + } + + switch (task->stage) { + case AsOfJoinSourceStage::SORT: + case AsOfJoinSourceStage::MATERIALIZE: + case AsOfJoinSourceStage::GET: + break; + case AsOfJoinSourceStage::LEFT: + probe_buffer.BeginLeftScan(*task); + break; + case AsOfJoinSourceStage::RIGHT: + BeginRightScan(); + break; + case AsOfJoinSourceStage::INIT: + case AsOfJoinSourceStage::DONE: + break; + } + + return true; +} + +bool AsOfGlobalSourceState::TryNextTask(TaskPtr &task, Task &task_local) { + auto guard = Lock(); + FinishTask(task); + + if (!HasMoreTasks()) { + task = nullptr; + return false; + } + + // Run through the active groups looking for one that can assign a task + for (const auto &group_idx : active_groups) { + auto &asof_group = asof_groups[group_idx]; + if (asof_group->TryPrepareNextStage()) { + UnblockTasks(guard); + } + if (asof_group->TryNextTask(task_local)) { + task = task_local; + ++started; + return true; } } - // Step 4: Emit right join matches - if (!IsRightOuterJoin(join_type)) { - return SourceResultType::FINISHED; + // All active groups are busy or blocked, so start the next one (if any) + while (next_group < 
partition_blocks.size()) { + const auto group_idx = partition_blocks[next_group++].second; + active_groups.emplace_back(group_idx); + + auto &asof_group = asof_groups[group_idx]; + if (asof_group->TryPrepareNextStage()) { + UnblockTasks(guard); + } + if (!asof_group->TryNextTask(task_local)) { + // Group has no tasks (empty?) + continue; + } + + task = task_local; + ++started; + return true; } - auto &hash_groups = rhs_sink.hash_groups; - const auto right_groups = hash_groups.size(); + task = nullptr; - DataChunk rhs_chunk; - rhs_chunk.Initialize(Allocator::Get(context.client), rhs_sink.payload_types); - SelectionVector rsel(STANDARD_VECTOR_SIZE); - - while (chunk.size() == 0) { - // Move to the next bin if we are done. - while (!lsource.scanner || !lsource.scanner->Remaining()) { - lsource.scanner.reset(); - lsource.hash_group.reset(); - auto hash_bin = gsource.next_right++; - if (hash_bin >= right_groups) { - return SourceResultType::FINISHED; + return false; +} + +void AsOfLocalSourceState::ExecuteSortTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source) { + auto &asof_group = *gsource.asof_groups[task_local.group_idx]; + + // Left or right? + const idx_t child = task_local.begin_idx >= asof_group.LeftChunks(); + const auto &gsink = gsource.op.sink_state->Cast(); + auto &hashed_sort = *gsink.hashed_sorts[child]; + auto &hashed_sink = *gsink.hashed_sinks[child]; + + OperatorSinkFinalizeInput finalize {hashed_sink, source.interrupt_state}; + hashed_sort.SortColumnData(context, task_local.group_idx, finalize); + + // Mark this range as done + task->begin_idx = task->end_idx; +} + +void AsOfLocalSourceState::ExecuteMaterializeTask(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &source) { + auto &asof_group = *gsource.asof_groups[task_local.group_idx]; + + // Left or right? 
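+ // Chunk offsets at or past LeftChunks() belong to the RHS input, so the comparison yields the child index (0 = LHS, 1 = RHS).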
+ const idx_t child = task_local.begin_idx >= asof_group.LeftChunks(); + const auto &gsink = gsource.op.sink_state->Cast(); + auto &hashed_sort = *gsink.hashed_sorts[child]; + auto &hashed_source = *gsource.hashed_sources[child]; + + auto unused = make_uniq(); + OperatorSourceInput hsource {hashed_source, *unused, source.interrupt_state}; + hashed_sort.MaterializeSortedRun(context, task_local.group_idx, hsource); + + // Mark this range as done + task->begin_idx = task->end_idx; +} + +void AsOfLocalSourceState::ExecuteGetTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source) { + auto &asof_group = *gsource.asof_groups[task_local.group_idx]; + + const auto &gsink = gsource.op.sink_state->Cast(); + auto unused = make_uniq(); + + for (idx_t child = 0; child < gsink.hashed_sorts.size(); ++child) { + // Don't get children that don't exist + if (child) { + if (!asof_group.RightChunks()) { + continue; + } + } else { + if (!asof_group.LeftChunks()) { + continue; } + } - for (; hash_bin < hash_groups.size(); hash_bin = gsource.next_right++) { - if (hash_groups[hash_bin]) { - break; - } + auto &hashed_sort = *gsink.hashed_sorts[child]; + auto &hashed_source = *gsource.hashed_sources[child]; + OperatorSourceInput hsource {hashed_source, *unused, source.interrupt_state}; + + auto group = hashed_sort.GetSortedRun(context.client, task_local.group_idx, hsource); + if (group) { + if (child) { + asof_group.right_group = std::move(group); + } else { + asof_group.left_group = std::move(group); } - lsource.BeginRightScan(hash_bin); } - const auto rhs_position = lsource.scanner->Scanned(); - lsource.scanner->Scan(rhs_chunk); + } - const auto count = rhs_chunk.size(); - if (count == 0) { - return SourceResultType::FINISHED; + // Mark this range as done + task->begin_idx = task->end_idx; +} + +void AsOfLocalSourceState::ExecuteLeftTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source) { + while (probe_buffer.HasMoreData()) { + probe_buffer.GetData(context, chunk); + if (chunk.size()) { + return; + } + } +} + +SourceResultType PhysicalAsOfJoin::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &gsource = input.global_state.Cast(); + auto &lsource = input.local_state.Cast(); + + // Any call to GetData must produce tuples, otherwise the pipeline executor thinks that we're done + // Therefore, we loop until we've produced tuples, or until the operator is actually done + while (gsource.HasUnfinishedTasks() && chunk.size() == 0) { + if (!lsource.TaskFinished() || lsource.TryAssignTask()) { + lsource.ExecuteTask(context, chunk, input); + } else { + auto guard = gsource.Lock(); + if (!gsource.HasMoreTasks()) { + gsource.UnblockTasks(guard); + } else { + // there are more tasks available, but we can't execute them yet + // block the source + return gsource.BlockSource(guard, input.interrupt_state); + } } + } + + return chunk.size() > 0 ? 
SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; +} + +void AsOfLocalSourceState::ExecuteRightTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) { + while (task->begin_idx < task->end_idx) { + const auto rhs_position = scanner->Scanned(); + scanner->Scan(rhs_chunk); + ++task->begin_idx; // figure out which tuples didn't find a match in the RHS - auto found_match = lsource.found_match; + const auto count = rhs_chunk.size(); idx_t result_count = 0; for (idx_t i = 0; i < count; i++) { - if (!found_match[rhs_position + i]) { + if (!rhs_matches[rhs_position + i]) { rsel.set_index(result_count++, i); } } + if (!result_count) { + continue; + } - if (result_count > 0) { - // if there were any tuples that didn't find a match, output them - const idx_t left_column_count = children[0].get().GetTypes().size(); - for (idx_t col_idx = 0; col_idx < left_column_count; ++col_idx) { - chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(chunk.data[col_idx], true); - } - for (idx_t col_idx = 0; col_idx < right_projection_map.size(); ++col_idx) { - const auto rhs_idx = right_projection_map[col_idx]; - chunk.data[left_column_count + col_idx].Slice(rhs_chunk.data[rhs_idx], rsel, result_count); - } - chunk.SetCardinality(result_count); - break; + // if there were any tuples that didn't find a match, output them + const auto &op = gsource.op; + const idx_t left_column_count = op.children[0].get().GetTypes().size(); + for (idx_t col_idx = 0; col_idx < left_column_count; ++col_idx) { + chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(chunk.data[col_idx], true); } + for (idx_t col_idx = 0; col_idx < op.right_projection_map.size(); ++col_idx) { + const auto rhs_idx = op.right_projection_map[col_idx]; + chunk.data[left_column_count + col_idx].Slice(rhs_chunk.data[rhs_idx], rsel, result_count); + } + chunk.SetCardinality(result_count); + return; } - return chunk.size() > 0 ? SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; + // Exhausted the task data + scanner.reset(); +} + +//===--------------------------------------------------------------------===// +// Pipeline Construction +//===--------------------------------------------------------------------===// +void PhysicalAsOfJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + D_ASSERT(children.size() == 2); + if (meta_pipeline.HasRecursiveCTE()) { + throw NotImplementedException("AsOf joins are not supported in recursive CTEs yet"); + } + + // becomes a source after both children fully sink their data + meta_pipeline.GetState().SetPipelineSource(current, *this); + + // Create one child meta pipeline that will hold the LHS and RHS pipelines + auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); + + // Build out RHS first because that is the order the join planner expects. 
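+ // The base pipeline of the child meta pipeline sinks the build (RHS) side, children[1].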
+ auto rhs_pipeline = child_meta_pipeline.GetBasePipeline(); + children[1].get().BuildPipelines(*rhs_pipeline, child_meta_pipeline); + + // Build out LHS + auto &lhs_pipeline = child_meta_pipeline.CreatePipeline(); + children[0].get().BuildPipelines(lhs_pipeline, child_meta_pipeline); + + // Despite having the same sink, LHS and everything created after it need their own (same) PipelineFinishEvent + child_meta_pipeline.AddFinishEvent(lhs_pipeline); } } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp index 9513bded8..ad9a7841c 100644 --- a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp @@ -36,7 +36,6 @@ PhysicalHashJoin::PhysicalHashJoin(PhysicalPlan &physical_plan, LogicalOperator : PhysicalComparisonJoin(physical_plan, op, PhysicalOperatorType::HASH_JOIN, std::move(cond), join_type, estimated_cardinality), delim_types(std::move(delim_types)) { - filter_pushdown = std::move(pushdown_info_p); children.push_back(left); @@ -283,7 +282,7 @@ unique_ptr PhysicalHashJoin::InitializeHashTable(ClientContext &c auto count_fun = CountFunctionBase::GetFunction(); vector> children; // this is a dummy but we need it to make the hash table understand whats going on - children.push_back(make_uniq_base(count_fun.return_type, 0U)); + children.push_back(make_uniq_base(count_fun.GetReturnType(), 0U)); aggr = function_binder.BindAggregateFunction(count_fun, std::move(children), nullptr, AggregateType::NON_DISTINCT); correlated_aggregates.push_back(&*aggr); @@ -391,7 +390,6 @@ static bool KeysAreSkewed(const HashJoinGlobalSinkState &sink) { //! If we have only one thread, always finalize single-threaded. Otherwise, we finalize in parallel if we //! have more than 1M rows or if we want to verify parallelism. static bool FinalizeSingleThreaded(const HashJoinGlobalSinkState &sink, const bool consider_skew) { - // if only one thread, finalize single-threaded const auto num_threads = NumericCast(sink.num_threads); if (num_threads == 1) { @@ -1159,7 +1157,8 @@ unique_ptr PhysicalHashJoin::GetLocalSourceState(ExecutionCont HashJoinGlobalSourceState::HashJoinGlobalSourceState(const PhysicalHashJoin &op, const ClientContext &context) : op(op), global_stage(HashJoinSourceStage::INIT), build_chunk_count(0), build_chunk_done(0), probe_chunk_count(0), probe_chunk_done(0), probe_count(op.children[0].get().estimated_cardinality), - parallel_scan_chunk_count(context.config.verify_parallelism ? 1 : 120) { + parallel_scan_chunk_count(context.config.verify_parallelism ? 
1 : 120), full_outer_chunk_count(0), + full_outer_chunk_done(0) { } void HashJoinGlobalSourceState::Initialize(HashJoinGlobalSinkState &sink) { diff --git a/src/duckdb/src/execution/operator/join/physical_iejoin.cpp b/src/duckdb/src/execution/operator/join/physical_iejoin.cpp index 90aba4722..cffb4eb43 100644 --- a/src/duckdb/src/execution/operator/join/physical_iejoin.cpp +++ b/src/duckdb/src/execution/operator/join/physical_iejoin.cpp @@ -1,15 +1,8 @@ -#include - #include "duckdb/execution/operator/join/physical_iejoin.hpp" #include "duckdb/common/atomic.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" #include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/sort/sorted_block.hpp" -#include "duckdb/common/thread.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/common/sorting/sort_key.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/parallel/event.hpp" #include "duckdb/parallel/meta_pipeline.hpp" @@ -17,6 +10,8 @@ #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" +#include + namespace duckdb { PhysicalIEJoin::PhysicalIEJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, PhysicalOperator &left, @@ -24,7 +19,6 @@ PhysicalIEJoin::PhysicalIEJoin(PhysicalPlan &physical_plan, LogicalComparisonJoi idx_t estimated_cardinality, unique_ptr pushdown_info) : PhysicalRangeJoin(physical_plan, op, PhysicalOperatorType::IE_JOIN, left, right, std::move(cond), join_type, estimated_cardinality, std::move(pushdown_info)) { - // 1. let L1 (resp. L2) be the array of column X (resp. Y) D_ASSERT(conditions.size() >= 2); for (idx_t i = 0; i < 2; ++i) { @@ -82,17 +76,15 @@ class IEJoinGlobalState : public GlobalSinkState { public: IEJoinGlobalState(ClientContext &context, const PhysicalIEJoin &op) : child(1) { tables.resize(2); - RowLayout lhs_layout; - lhs_layout.Initialize(op.children[0].get().GetTypes()); + const auto &lhs_types = op.children[0].get().GetTypes(); vector lhs_order; lhs_order.emplace_back(op.lhs_orders[0].Copy()); - tables[0] = make_uniq(context, lhs_order, lhs_layout, op); + tables[0] = make_uniq(context, lhs_order, lhs_types, op); - RowLayout rhs_layout; - rhs_layout.Initialize(op.children[1].get().GetTypes()); + const auto &rhs_types = op.children[1].get().GetTypes(); vector rhs_order; rhs_order.emplace_back(op.rhs_orders[0].Copy()); - tables[1] = make_uniq(context, rhs_order, rhs_layout, op); + tables[1] = make_uniq(context, rhs_order, rhs_types, op); if (op.filter_pushdown) { skip_filter_pushdown = op.filter_pushdown->probe_info.empty(); @@ -100,11 +92,18 @@ class IEJoinGlobalState : public GlobalSinkState { } } - void Sink(DataChunk &input, IEJoinLocalState &lstate); - void Finalize(Pipeline &pipeline, Event &event) { + void Sink(ExecutionContext &context, DataChunk &input, IEJoinLocalState &lstate); + + void Finalize(ClientContext &client, InterruptState &interrupt) { // Sort the current input child D_ASSERT(child < tables.size()); - tables[child]->Finalize(pipeline, event); + tables[child]->Finalize(client, interrupt); + }; + + void Materialize(Pipeline &pipeline, Event &event) { + // Sort the current input child + D_ASSERT(child < tables.size()); + tables[child]->Materialize(pipeline, event); child = child ? 
0 : 2; skip_filter_pushdown = true; }; @@ -123,9 +122,8 @@ class IEJoinLocalState : public LocalSinkState { public: using LocalSortedTable = PhysicalRangeJoin::LocalSortedTable; - IEJoinLocalState(ClientContext &context, const PhysicalRangeJoin &op, IEJoinGlobalState &gstate) - : table(context, op, gstate.child) { - + IEJoinLocalState(ExecutionContext &context, const PhysicalRangeJoin &op, IEJoinGlobalState &gstate) + : table(context, *gstate.tables[gstate.child], gstate.child) { if (op.filter_pushdown) { local_filter_state = op.filter_pushdown->GetLocalState(*gstate.global_filter_state); } @@ -144,32 +142,23 @@ unique_ptr PhysicalIEJoin::GetGlobalSinkState(ClientContext &co unique_ptr PhysicalIEJoin::GetLocalSinkState(ExecutionContext &context) const { auto &ie_sink = sink_state->Cast(); - return make_uniq(context.client, *this, ie_sink); + return make_uniq(context, *this, ie_sink); } -void IEJoinGlobalState::Sink(DataChunk &input, IEJoinLocalState &lstate) { - auto &table = *tables[child]; - auto &global_sort_state = table.global_sort_state; - auto &local_sort_state = lstate.table.local_sort_state; - +void IEJoinGlobalState::Sink(ExecutionContext &context, DataChunk &input, IEJoinLocalState &lstate) { // Sink the data into the local sort state - lstate.table.Sink(input, global_sort_state); - - // When sorting data reaches a certain size, we sort it - if (local_sort_state.SizeInBytes() >= table.memory_per_thread) { - local_sort_state.Sort(global_sort_state, true); - } + lstate.table.Sink(context, input); } SinkResultType PhysicalIEJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); - if (gstate.child == 0 && gstate.tables[1]->global_sort_state.sorted_blocks.empty() && EmptyResultIfRHSIsEmpty()) { + if (gstate.child == 0 && gstate.tables[1]->Count() == 0 && EmptyResultIfRHSIsEmpty()) { return SinkResultType::FINISHED; } - gstate.Sink(chunk, lstate); + gstate.Sink(context, chunk, lstate); if (filter_pushdown && !gstate.skip_filter_pushdown) { filter_pushdown->Sink(lstate.table.keys, *lstate.local_filter_state); @@ -181,7 +170,7 @@ SinkResultType PhysicalIEJoin::Sink(ExecutionContext &context, DataChunk &chunk, SinkCombineResultType PhysicalIEJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); - gstate.tables[gstate.child]->Combine(lstate.table); + gstate.tables[gstate.child]->Combine(context, lstate.table); auto &client_profiler = QueryProfiler::Get(context.client); context.thread.profiler.Flush(*this); @@ -197,14 +186,13 @@ SinkCombineResultType PhysicalIEJoin::Combine(ExecutionContext &context, Operato //===--------------------------------------------------------------------===// // Finalize //===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalIEJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, +SinkFinalizeType PhysicalIEJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &client, OperatorSinkFinalizeInput &input) const { auto &gstate = input.global_state.Cast(); if (filter_pushdown && !gstate.skip_filter_pushdown) { - (void)filter_pushdown->Finalize(context, nullptr, *gstate.global_filter_state, *this); + (void)filter_pushdown->Finalize(client, nullptr, *gstate.global_filter_state, *this); } auto &table = *gstate.tables[gstate.child]; - auto &global_sort_state = 
table.global_sort_state; if ((gstate.child == 1 && PropagatesBuildSide(join_type)) || (gstate.child == 0 && IsLeftOuterJoin(join_type))) { // for FULL/LEFT/RIGHT OUTER JOIN, initialize found_match to false for every tuple @@ -212,15 +200,18 @@ SinkFinalizeType PhysicalIEJoin::Finalize(Pipeline &pipeline, Event &event, Clie } SinkFinalizeType res; - if (gstate.child == 1 && global_sort_state.sorted_blocks.empty() && EmptyResultIfRHSIsEmpty()) { + if (gstate.child == 1 && table.Count() == 0 && EmptyResultIfRHSIsEmpty()) { // Empty input! res = SinkFinalizeType::NO_OUTPUT_POSSIBLE; } else { res = SinkFinalizeType::READY; } + // Clean up the current table + gstate.Finalize(client, input.interrupt_state); + // Move to the next input child - gstate.Finalize(pipeline, event); + gstate.Materialize(pipeline, event); return res; } @@ -236,43 +227,70 @@ OperatorResultType PhysicalIEJoin::ExecuteInternal(ExecutionContext &context, Da //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// +enum class IEJoinSourceStage : uint8_t { INIT, INNER, OUTER, DONE }; + struct IEJoinUnion { using SortedTable = PhysicalRangeJoin::GlobalSortedTable; + using ChunkRange = std::pair; - static idx_t AppendKey(SortedTable &table, ExpressionExecutor &executor, SortedTable &marked, int64_t increment, - int64_t base, const idx_t block_idx); - - static void Sort(SortedTable &table) { - auto &global_sort_state = table.global_sort_state; - global_sort_state.PrepareMergePhase(); - while (global_sort_state.sorted_blocks.size() > 1) { - global_sort_state.InitializeMergeRound(); - MergeSorter merge_sorter(global_sort_state, global_sort_state.buffer_manager); - merge_sorter.PerformInMergeRound(); - global_sort_state.CompleteMergeRound(true); + // Comparison utilities + static bool IsStrictComparison(ExpressionType comparison) { + switch (comparison) { + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_GREATERTHAN: + return true; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return false; + default: + throw InternalException("Unimplemented comparison type for IEJoin!"); } } + template + static inline bool Compare(const T &lhs, const T &rhs, const bool strict) { + const bool less_than = lhs < rhs; + if (!less_than && !strict) { + return !(rhs < lhs); + } + return less_than; + } + + template + static bool TemplatedCompareKeys(ExternalBlockIteratorState &state1, const idx_t pos1, + ExternalBlockIteratorState &state2, const idx_t pos2, bool strict); + + static bool CompareKeys(ExternalBlockIteratorState &state1, const idx_t pos1, ExternalBlockIteratorState &state2, + const idx_t pos2, bool strict, const SortKeyType &sort_key_type); + + static bool CompareBounds(SortedTable &t1, const ChunkRange &b1, SortedTable &t2, const ChunkRange &b2, + bool strict); + + static idx_t AppendKey(ExecutionContext &context, InterruptState &interrupt, SortedTable &table, + ExpressionExecutor &executor, SortedTable &marked, int64_t increment, int64_t rid, + const ChunkRange &range); + + static void Sort(ExecutionContext &context, InterruptState &interrupt, SortedTable &table) { + table.Finalize(context.client, interrupt); + table.Materialize(context, interrupt); + } + template static vector ExtractColumn(SortedTable &table, idx_t col_idx) { vector result; result.reserve(table.count); - auto &gstate = table.global_sort_state; - auto &blocks = 
*gstate.sorted_blocks[0]->payload_data; - PayloadScanner scanner(blocks, gstate, false); + auto &collection = *table.sorted->payload_data; + vector scan_ids(1, col_idx); + TupleDataScanState state; + collection.InitializeScan(state, scan_ids); DataChunk payload; - payload.Initialize(Allocator::DefaultAllocator(), gstate.payload_layout.GetTypes()); - for (;;) { - payload.Reset(); - scanner.Scan(payload); - const auto count = payload.size(); - if (!count) { - break; - } + collection.InitializeScanChunk(state, payload); - const auto data_ptr = FlatVector::GetData(payload.data[col_idx]); + while (collection.Scan(state, payload)) { + const auto count = payload.size(); + const auto data_ptr = FlatVector::GetData(payload.data[0]); for (idx_t i = 0; i < count; i++) { result.push_back(UnsafeNumericCast(data_ptr[i])); } @@ -281,12 +299,40 @@ struct IEJoinUnion { return result; } - IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, SortedTable &t1, const idx_t b1, SortedTable &t2, - const idx_t b2); + class UnionIterator { + public: + UnionIterator(SortedTable &table, bool strict) : state(table.CreateIteratorState()), strict(strict) { + } + + inline idx_t GetIndex() const { + return index; + } + + inline void SetIndex(idx_t i) { + index = i; + } + + UnionIterator &operator++() { + ++index; + return *this; + } + + unique_ptr state; + idx_t index = 0; + const bool strict; + }; + + IEJoinUnion(ExecutionContext &context, const PhysicalIEJoin &op, SortedTable &t1, const ChunkRange &b1, + SortedTable &t2, const ChunkRange &b2); idx_t SearchL1(idx_t pos); + + template bool NextRow(); + using next_row_t = bool (duckdb::IEJoinUnion::*)(); + next_row_t next_row_func; + //! Inverted loop idx_t JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rsel); @@ -314,49 +360,64 @@ struct IEJoinUnion { idx_t n; idx_t i; idx_t j; - unique_ptr op1; - unique_ptr off1; - unique_ptr op2; - unique_ptr off2; + unique_ptr op1; + unique_ptr off1; + unique_ptr op2; + unique_ptr off2; int64_t lrid; }; -idx_t IEJoinUnion::AppendKey(SortedTable &table, ExpressionExecutor &executor, SortedTable &marked, int64_t increment, - int64_t base, const idx_t block_idx) { - LocalSortState local_sort_state; - local_sort_state.Initialize(marked.global_sort_state, marked.global_sort_state.buffer_manager); +idx_t IEJoinUnion::AppendKey(ExecutionContext &context, InterruptState &interrupt, SortedTable &table, + ExpressionExecutor &executor, SortedTable &marked, int64_t increment, int64_t rid, + const ChunkRange &chunk_range) { + const auto chunk_begin = chunk_range.first; + const auto chunk_end = chunk_range.second; // Reading const auto valid = table.count - table.has_null; - auto &gstate = table.global_sort_state; - PayloadScanner scanner(gstate, block_idx); - auto table_idx = block_idx * gstate.block_capacity; + auto &source = *table.sorted->payload_data; + TupleDataScanState scanner; + source.InitializeScan(scanner); DataChunk scanned; - scanned.Initialize(Allocator::DefaultAllocator(), scanner.GetPayloadTypes()); + source.InitializeScanChunk(scanner, scanned); - // Writing - auto types = local_sort_state.sort_layout->logical_types; - const idx_t payload_idx = types.size(); + // TODO: Random access into TupleDataCollection (NextScanIndex is private...) 
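+ // Work around this by scanning and discarding the chunks before chunk_begin, accumulating the row offset in table_idx.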
+ idx_t table_idx = 0; + for (idx_t i = 0; i < chunk_begin; ++i) { + source.Scan(scanner, scanned); + table_idx += scanned.size(); + } - const auto &payload_types = local_sort_state.payload_layout->GetTypes(); - types.insert(types.end(), payload_types.begin(), payload_types.end()); - const idx_t rid_idx = types.size() - 1; + // Writing + auto &sort = *marked.sort; + auto local_sort_state = sort.GetLocalSinkState(context); + vector types; + for (const auto &expr : executor.expressions) { + types.emplace_back(expr->return_type); + } + const idx_t rid_idx = types.size(); + types.emplace_back(LogicalType::BIGINT); DataChunk keys; DataChunk payload; keys.Initialize(Allocator::DefaultAllocator(), types); + OperatorSinkInput sink {*marked.global_sink, *local_sort_state, interrupt}; idx_t inserted = 0; - for (auto rid = base; table_idx < valid;) { - scanned.Reset(); - scanner.Scan(scanned); + for (auto chunk_idx = chunk_begin; chunk_idx < chunk_end; ++chunk_idx) { + source.Scan(scanner, scanned); // NULLs are at the end, so stop when we reach them auto scan_count = scanned.size(); if (table_idx + scan_count > valid) { - scan_count = valid - table_idx; - scanned.SetCardinality(scan_count); + if (table_idx >= valid) { + scan_count = 0; + ; + } else { + scan_count = valid - table_idx; + scanned.SetCardinality(scan_count); + } } if (scan_count == 0) { break; @@ -375,43 +436,88 @@ idx_t IEJoinUnion::AppendKey(SortedTable &table, ExpressionExecutor &executor, S rid += increment * UnsafeNumericCast(scan_count); // Sort on the sort columns (which will no longer be needed) - keys.Split(payload, payload_idx); - local_sort_state.SinkChunk(keys, payload); + sort.Sink(context, keys, sink); inserted += scan_count; - keys.Fuse(payload); - - // Flush when we have enough data - if (local_sort_state.SizeInBytes() >= marked.memory_per_thread) { - local_sort_state.Sort(marked.global_sort_state, true); - } } - marked.global_sort_state.AddLocalState(local_sort_state); + OperatorSinkCombineInput combine {*marked.global_sink, *local_sort_state, interrupt}; + sort.Combine(context, combine); marked.count += inserted; return inserted; } -IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, SortedTable &t1, const idx_t b1, - SortedTable &t2, const idx_t b2) +// TODO: Function pointers? 
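+// Compare the sort keys at two positions of (possibly different) sorted runs using typed block iterators; returns whether the first key orders before (or ties with, when not strict) the second.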
+template +bool IEJoinUnion::TemplatedCompareKeys(ExternalBlockIteratorState &state1, const idx_t pos1, + ExternalBlockIteratorState &state2, const idx_t pos2, bool strict) { + using SORT_KEY = SortKey; + using BLOCKS_ITERATOR = block_iterator_t; + + BLOCKS_ITERATOR bounds1(state1, pos1); + BLOCKS_ITERATOR bounds2(state2, pos2); + + return Compare(*bounds1, *bounds2, strict); +} + +bool IEJoinUnion::CompareKeys(ExternalBlockIteratorState &state1, const idx_t pos1, ExternalBlockIteratorState &state2, + const idx_t pos2, bool strict, const SortKeyType &sort_key_type) { + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + case SortKeyType::NO_PAYLOAD_FIXED_16: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + case SortKeyType::NO_PAYLOAD_FIXED_24: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + case SortKeyType::NO_PAYLOAD_FIXED_32: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + case SortKeyType::PAYLOAD_FIXED_16: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + case SortKeyType::PAYLOAD_FIXED_24: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + case SortKeyType::PAYLOAD_FIXED_32: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + case SortKeyType::PAYLOAD_VARIABLE_32: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + default: + throw NotImplementedException("IEJoinUnion::CompareKeys for %s", EnumUtil::ToString(sort_key_type)); + } +} + +bool IEJoinUnion::CompareBounds(SortedTable &t1, const ChunkRange &b1, SortedTable &t2, const ChunkRange &b2, + bool strict) { + auto &keys1 = *t1.sorted->key_data; + ExternalBlockIteratorState state1(keys1, nullptr); + const idx_t pos1 = t1.BlockStart(b1.first); + + auto &keys2 = *t2.sorted->key_data; + ExternalBlockIteratorState state2(keys2, nullptr); + const idx_t pos2 = t2.BlockEnd(b2.second - 1); + + const auto sort_key_type = t1.GetSortKeyType(); + D_ASSERT(sort_key_type == t2.GetSortKeyType()); + return CompareKeys(state1, pos1, state2, pos2, strict, sort_key_type); +} + +IEJoinUnion::IEJoinUnion(ExecutionContext &context, const PhysicalIEJoin &op, SortedTable &t1, const ChunkRange &b1, + SortedTable &t2, const ChunkRange &b2) : n(0), i(0) { // input : query Q with 2 join predicates t1.X op1 t2.X' and t1.Y op2 t2.Y', tables T, T' of sizes m and n resp. // output: a list of tuple pairs (ti , tj) // Note that T/T' are already sorted on X/X' and contain the payload data // We only join the two block numbers and use the sizes of the blocks as the counts + InterruptState interrupt; + // 0. 
Filter out tables with no overlap - if (!t1.BlockSize(b1) || !t2.BlockSize(b2)) { + if (t1.sorted->key_data->ChunkCount() <= b1.first || t2.sorted->key_data->ChunkCount() <= b2.first) { return; } - const auto &cmp1 = op.conditions[0].comparison; - SBIterator bounds1(t1.global_sort_state, cmp1); - SBIterator bounds2(t2.global_sort_state, cmp1); - // t1.X[0] op1 t2.X'[-1] - bounds1.SetIndex(bounds1.block_capacity * b1); - bounds2.SetIndex(bounds2.block_capacity * b2 + t2.BlockSize(b2) - 1); - if (!bounds1.Compare(bounds2)) { + const auto strict1 = IsStrictComparison(op.conditions[0].comparison); + if (!CompareBounds(t1, b1, t2, b2, strict1)) { return; } @@ -428,8 +534,6 @@ IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, Sorte vector types; types.emplace_back(order2.expression->return_type); types.emplace_back(LogicalType::BIGINT); - RowLayout payload_layout; - payload_layout.Initialize(types); // Sort on the first expression auto ref = make_uniq(order1.expression->return_type, 0U); @@ -451,37 +555,37 @@ IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, Sorte // Using this OrderType, if i < j then value[i] (from left table) and value[j] (from right table) match // the condition (t1.time <= t2.time or t1.time < t2.time), then from_left will force them into the correct order. auto from_left = make_uniq(Value::BOOLEAN(true)); - orders.emplace_back(SBIterator::ComparisonValue(cmp1) == 0 ? OrderType::DESCENDING : OrderType::ASCENDING, - OrderByNullType::ORDER_DEFAULT, std::move(from_left)); + orders.emplace_back(!strict1 ? OrderType::DESCENDING : OrderType::ASCENDING, OrderByNullType::ORDER_DEFAULT, + std::move(from_left)); - l1 = make_uniq(context, orders, payload_layout, op); + l1 = make_uniq(context.client, orders, types, op); // LHS has positive rids - ExpressionExecutor l_executor(context); + ExpressionExecutor l_executor(context.client); l_executor.AddExpression(*order1.expression); // add const column true auto left_const = make_uniq(Value::BOOLEAN(true)); l_executor.AddExpression(*left_const); l_executor.AddExpression(*order2.expression); - AppendKey(t1, l_executor, *l1, 1, 1, b1); + AppendKey(context, interrupt, t1, l_executor, *l1, 1, 1, b1); // RHS has negative rids - ExpressionExecutor r_executor(context); + ExpressionExecutor r_executor(context.client); r_executor.AddExpression(*op.rhs_orders[0].expression); // add const column flase auto right_const = make_uniq(Value::BOOLEAN(false)); r_executor.AddExpression(*right_const); r_executor.AddExpression(*op.rhs_orders[1].expression); - AppendKey(t2, r_executor, *l1, -1, -1, b2); + AppendKey(context, interrupt, t2, r_executor, *l1, -1, -1, b2); - if (l1->global_sort_state.sorted_blocks.empty()) { + if (!l1->Count()) { return; } - Sort(*l1); + Sort(context, interrupt, *l1); - op1 = make_uniq(l1->global_sort_state, cmp1); - off1 = make_uniq(l1->global_sort_state, cmp1); + op1 = make_uniq(*l1, strict1); + off1 = make_uniq(*l1, strict1); // We don't actually need the L1 column, just its sort key, which is in the sort blocks li = ExtractColumn(*l1, types.size() - 1); @@ -493,22 +597,19 @@ IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, Sorte // For this we just need a two-column table of Y, P types.clear(); types.emplace_back(LogicalType::BIGINT); - payload_layout.Initialize(types); // Sort on the first expression orders.clear(); ref = make_uniq(order2.expression->return_type, 0U); orders.emplace_back(order2.type, order2.null_order, std::move(ref)); - 
ExpressionExecutor executor(context); + ExpressionExecutor executor(context.client); executor.AddExpression(*orders[0].expression); - l2 = make_uniq(context, orders, payload_layout, op); - for (idx_t base = 0, block_idx = 0; block_idx < l1->BlockCount(); ++block_idx) { - base += AppendKey(*l1, executor, *l2, 1, NumericCast(base), block_idx); - } + l2 = make_uniq(context.client, orders, types, op); + AppendKey(context, interrupt, *l1, executor, *l2, 1, 0, {0, l1->BlockCount()}); - Sort(*l2); + Sort(context, interrupt, *l2); // We don't actually need the L2 column, just its sort key, which is in the sort blocks @@ -526,15 +627,57 @@ IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, Sorte bloom_filter.Initialize(bloom_array.data(), bloom_count); // 11. for(i←1 to n) do - const auto &cmp2 = op.conditions[1].comparison; - op2 = make_uniq(l2->global_sort_state, cmp2); - off2 = make_uniq(l2->global_sort_state, cmp2); + const auto strict2 = IsStrictComparison(op.conditions[1].comparison); + op2 = make_uniq(*l2, strict2); + off2 = make_uniq(*l2, strict2); i = 0; j = 0; - (void)NextRow(); + + const auto sort_key_type = l2->GetSortKeyType(); + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::NO_PAYLOAD_FIXED_16: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::NO_PAYLOAD_FIXED_24: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::NO_PAYLOAD_FIXED_32: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::PAYLOAD_FIXED_16: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::PAYLOAD_FIXED_24: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::PAYLOAD_FIXED_32: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::PAYLOAD_VARIABLE_32: + next_row_func = &IEJoinUnion::NextRow; + break; + default: + throw NotImplementedException("IEJoinUnion for %s", EnumUtil::ToString(sort_key_type)); + } + + (this->*next_row_func)(); } +template bool IEJoinUnion::NextRow() { + using SORT_KEY = SortKey; + using BLOCKS_ITERATOR = block_iterator_t; + + BLOCKS_ITERATOR off2_itr(*off2->state); + BLOCKS_ITERATOR op2_itr(*op2->state); + const auto strict = off2->strict; + for (; i < n; ++i) { // 12. pos ← P[i] auto pos = p[i]; @@ -546,7 +689,7 @@ bool IEJoinUnion::NextRow() { // 16. 
B[pos] ← 1 op2->SetIndex(i); for (; off2->GetIndex() < n; ++(*off2)) { - if (!off2->Compare(*op2)) { + if (!Compare(off2_itr[off2->GetIndex()], op2_itr[op2->GetIndex()], strict)) { break; } const auto p2 = p[off2->GetIndex()]; @@ -652,7 +795,7 @@ idx_t IEJoinUnion::JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rse } ++i; - if (!NextRow()) { + if (!(this->*next_row_func)()) { break; } } @@ -660,13 +803,83 @@ idx_t IEJoinUnion::JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rse return result_count; } +class IEJoinLocalSourceState; + +class IEJoinGlobalSourceState : public GlobalSourceState { +public: + IEJoinGlobalSourceState(const PhysicalIEJoin &op, IEJoinGlobalState &gsink) + : op(op), gsink(gsink), stage(IEJoinSourceStage::INIT), next_pair(0), completed(0), left_outers(0), + next_left(0), right_outers(0), next_right(0) { + auto &left_table = *gsink.tables[0]; + auto &right_table = *gsink.tables[1]; + + left_blocks = left_table.BlockCount(); + left_ranges = (left_blocks + left_per_thread - 1) / left_per_thread; + + right_blocks = right_table.BlockCount(); + right_ranges = (right_blocks + right_per_thread - 1) / right_per_thread; + + pair_count = left_ranges * right_ranges; + } + + void Initialize(); + bool TryPrepareNextStage(); + bool AssignTask(ExecutionContext &context, IEJoinLocalSourceState &lstate); + +public: + idx_t MaxThreads() override; + + ProgressData GetProgress() const; + + const PhysicalIEJoin &op; + IEJoinGlobalState &gsink; + + atomic stage; + + // Join queue state + idx_t left_blocks = 0; + idx_t left_ranges = 0; + const idx_t left_per_thread = 1024; + idx_t right_blocks = 0; + idx_t right_ranges = 0; + const idx_t right_per_thread = 1024; + idx_t pair_count; + atomic next_pair; + atomic completed; + + // Outer joins + atomic left_outers; + atomic next_left; + + atomic right_outers; + atomic next_right; +}; + class IEJoinLocalSourceState : public LocalSourceState { public: - explicit IEJoinLocalSourceState(ClientContext &context, const PhysicalIEJoin &op) - : op(op), true_sel(STANDARD_VECTOR_SIZE), left_executor(context), right_executor(context), - left_matches(nullptr), right_matches(nullptr) { - auto &allocator = Allocator::Get(context); - unprojected.Initialize(allocator, op.unprojected_types); + IEJoinLocalSourceState(ClientContext &client, IEJoinGlobalSourceState &gsource) + : gsource(gsource), lsel(STANDARD_VECTOR_SIZE), rsel(STANDARD_VECTOR_SIZE), true_sel(STANDARD_VECTOR_SIZE), + left_executor(client), right_executor(client), left_matches(nullptr), right_matches(nullptr) + + { + auto &op = gsource.op; + auto &allocator = Allocator::Get(client); + unprojected.InitializeEmpty(op.unprojected_types); + lpayload.Initialize(allocator, op.children[0].get().GetTypes()); + rpayload.Initialize(allocator, op.children[1].get().GetTypes()); + + auto &ie_sink = op.sink_state->Cast(); + auto &left_table = *ie_sink.tables[0]; + auto &right_table = *ie_sink.tables[1]; + + left_iterator = left_table.CreateIteratorState(); + right_iterator = right_table.CreateIteratorState(); + + left_table.InitializePayloadState(left_chunk_state); + right_table.InitializePayloadState(right_chunk_state); + + left_scan_state = left_table.CreateScanState(client); + right_scan_state = right_table.CreateScanState(client); if (op.conditions.size() < 3) { return; @@ -703,16 +916,40 @@ class IEJoinLocalSourceState : public LocalSourceState { return count; } - const PhysicalIEJoin &op; + // Are we executing a task? 
+ bool TaskFinished() const { + return !joiner && !left_matches && !right_matches; + } + + // resolve joins that can potentially output N*M elements (INNER, LEFT, FULL) + void ResolveComplexJoin(ExecutionContext &context, DataChunk &result); + // Resolve left join results + void ExecuteLeftTask(ExecutionContext &context, DataChunk &result); + // Resolve right join results + void ExecuteRightTask(ExecutionContext &context, DataChunk &result); + // Execute the current task + void ExecuteTask(ExecutionContext &context, DataChunk &result); + + IEJoinGlobalSourceState &gsource; // Joining unique_ptr joiner; idx_t left_base; idx_t left_block_index; + unique_ptr left_iterator; + TupleDataChunkState left_chunk_state; + SelectionVector lsel; + DataChunk lpayload; + unique_ptr left_scan_state; idx_t right_base; idx_t right_block_index; + unique_ptr right_iterator; + TupleDataChunkState right_chunk_state; + SelectionVector rsel; + DataChunk rpayload; + unique_ptr right_scan_state; // Trailing predicates SelectionVector true_sel; @@ -732,254 +969,246 @@ class IEJoinLocalSourceState : public LocalSourceState { bool *right_matches; }; -void PhysicalIEJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &result, LocalSourceState &state_p) const { - auto &state = state_p.Cast(); - auto &ie_sink = sink_state->Cast(); +void IEJoinLocalSourceState::ExecuteTask(ExecutionContext &context, DataChunk &result) { + if (joiner) { + ResolveComplexJoin(context, result); + } else if (left_matches != nullptr) { + ExecuteLeftTask(context, result); + } else if (right_matches != nullptr) { + ExecuteRightTask(context, result); + } +} + +void IEJoinLocalSourceState::ResolveComplexJoin(ExecutionContext &context, DataChunk &result) { + auto &op = gsource.op; + auto &ie_sink = op.sink_state->Cast(); + const auto &conditions = op.conditions; + + auto &chunk = unprojected; + auto &left_table = *ie_sink.tables[0]; + const auto left_cols = op.children[0].get().GetTypes().size(); + auto &right_table = *ie_sink.tables[1]; - const auto left_cols = children[0].get().GetTypes().size(); - auto &chunk = state.unprojected; do { - SelectionVector lsel(STANDARD_VECTOR_SIZE); - SelectionVector rsel(STANDARD_VECTOR_SIZE); - auto result_count = state.joiner->JoinComplexBlocks(lsel, rsel); + auto result_count = joiner->JoinComplexBlocks(lsel, rsel); if (result_count == 0) { // exhausted this pair + joiner.reset(); + ++gsource.completed; return; } // found matches: extract them - chunk.Reset(); - SliceSortedPayload(chunk, left_table.global_sort_state, state.left_block_index, lsel, result_count, 0); - SliceSortedPayload(chunk, right_table.global_sort_state, state.right_block_index, rsel, result_count, - left_cols); - chunk.SetCardinality(result_count); + left_table.Repin(*left_iterator); + right_table.Repin(*right_iterator); + + op.SliceSortedPayload(lpayload, left_table, *left_iterator, left_chunk_state, left_block_index, lsel, + result_count, *left_scan_state); + op.SliceSortedPayload(rpayload, right_table, *right_iterator, right_chunk_state, right_block_index, rsel, + result_count, *right_scan_state); auto sel = FlatVector::IncrementalSelectionVector(); if (conditions.size() > 2) { // If there are more expressions to compute, - // split the result chunk into the left and right halves - // so we can compute the values for comparison. + // use the left and right payloads + // so we can compute the values for comparison.
const auto tail_cols = conditions.size() - 2; - DataChunk right_chunk; - chunk.Split(right_chunk, left_cols); - state.left_executor.SetChunk(chunk); - state.right_executor.SetChunk(right_chunk); + left_executor.SetChunk(lpayload); + right_executor.SetChunk(rpayload); auto tail_count = result_count; - auto true_sel = &state.true_sel; + auto match_sel = &true_sel; for (size_t cmp_idx = 0; cmp_idx < tail_cols; ++cmp_idx) { - auto &left = state.left_keys.data[cmp_idx]; - state.left_executor.ExecuteExpression(cmp_idx, left); + auto &left = left_keys.data[cmp_idx]; + left_executor.ExecuteExpression(cmp_idx, left); - auto &right = state.right_keys.data[cmp_idx]; - state.right_executor.ExecuteExpression(cmp_idx, right); + auto &right = right_keys.data[cmp_idx]; + right_executor.ExecuteExpression(cmp_idx, right); if (tail_count < result_count) { left.Slice(*sel, tail_count); right.Slice(*sel, tail_count); } - tail_count = SelectJoinTail(conditions[cmp_idx + 2].comparison, left, right, sel, tail_count, true_sel); - sel = true_sel; + tail_count = + op.SelectJoinTail(conditions[cmp_idx + 2].comparison, left, right, sel, tail_count, match_sel); + sel = match_sel; } - chunk.Fuse(right_chunk); if (tail_count < result_count) { result_count = tail_count; - chunk.Slice(*sel, result_count); + lpayload.Slice(*sel, result_count); + rpayload.Slice(*sel, result_count); + } + } + + // Merge the payloads + chunk.Reset(); + for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { + if (col_idx < left_cols) { + chunk.data[col_idx].Reference(lpayload.data[col_idx]); + } else { + chunk.data[col_idx].Reference(rpayload.data[col_idx - left_cols]); } } + chunk.SetCardinality(result_count); // We need all of the data to compute other predicates, // but we only return what is in the projection map - ProjectResult(chunk, result); + op.ProjectResult(chunk, result); // found matches: mark the found matches if required if (left_table.found_match) { for (idx_t i = 0; i < result_count; i++) { - left_table.found_match[state.left_base + lsel[sel->get_index(i)]] = true; + left_table.found_match[left_base + lsel[sel->get_index(i)]] = true; } } if (right_table.found_match) { for (idx_t i = 0; i < result_count; i++) { - right_table.found_match[state.right_base + rsel[sel->get_index(i)]] = true; + right_table.found_match[right_base + rsel[sel->get_index(i)]] = true; } } result.Verify(); } while (result.size() == 0); } -class IEJoinGlobalSourceState : public GlobalSourceState { -public: - explicit IEJoinGlobalSourceState(const PhysicalIEJoin &op, IEJoinGlobalState &gsink) - : op(op), gsink(gsink), initialized(false), next_pair(0), completed(0), left_outers(0), next_left(0), - right_outers(0), next_right(0) { +void IEJoinGlobalSourceState::Initialize() { + auto guard = Lock(); + if (stage != IEJoinSourceStage::INIT) { + return; } - void Initialize() { - auto guard = Lock(); - if (initialized) { - return; - } + // Compute the starting row for each block + auto &left_table = *gsink.tables[0]; + const auto left_blocks = left_table.BlockCount(); - // Compute the starting row for reach block - // (In theory these are all the same size, but you never know...) 
- auto &left_table = *gsink.tables[0]; - const auto left_blocks = left_table.BlockCount(); - idx_t left_base = 0; + auto &right_table = *gsink.tables[1]; + const auto right_blocks = right_table.BlockCount(); - for (size_t lhs = 0; lhs < left_blocks; ++lhs) { - left_bases.emplace_back(left_base); - left_base += left_table.BlockSize(lhs); - } + // Outer join block counts + if (left_table.found_match) { + left_outers = left_blocks; + } - auto &right_table = *gsink.tables[1]; - const auto right_blocks = right_table.BlockCount(); - idx_t right_base = 0; - for (size_t rhs = 0; rhs < right_blocks; ++rhs) { - right_bases.emplace_back(right_base); - right_base += right_table.BlockSize(rhs); - } + if (right_table.found_match) { + right_outers = right_blocks; + } - // Outer join block counts - if (left_table.found_match) { - left_outers = left_blocks; + // Ready for action + stage = IEJoinSourceStage::INNER; +} +bool IEJoinGlobalSourceState::TryPrepareNextStage() { + // Inside lock + switch (stage.load()) { + case IEJoinSourceStage::INNER: + if (completed >= pair_count) { + stage = IEJoinSourceStage::OUTER; + return true; } - - if (right_table.found_match) { - right_outers = right_blocks; + break; + case IEJoinSourceStage::OUTER: + if (next_left >= left_outers && next_right >= right_outers) { + stage = IEJoinSourceStage::DONE; + return true; } - - // Ready for action - initialized = true; - } - -public: - idx_t MaxThreads() override { - // We can't leverage any more threads than block pairs. - const auto &sink_state = (op.sink_state->Cast()); - return sink_state.tables[0]->BlockCount() * sink_state.tables[1]->BlockCount(); + break; + default: + break; } - void GetNextPair(ClientContext &client, IEJoinLocalSourceState &lstate) { - auto &left_table = *gsink.tables[0]; - auto &right_table = *gsink.tables[1]; + return false; +} - const auto left_blocks = left_table.BlockCount(); - const auto right_blocks = right_table.BlockCount(); - const auto pair_count = left_blocks * right_blocks; +idx_t IEJoinGlobalSourceState::MaxThreads() { + // We can't leverage any more threads than block pairs. 
+ const auto &sink_state = (op.sink_state->Cast()); + return sink_state.tables[0]->BlockCount() * sink_state.tables[1]->BlockCount(); +} - // Regular block - const auto i = next_pair++; - if (i < pair_count) { - const auto b1 = i / right_blocks; - const auto b2 = i % right_blocks; +bool IEJoinGlobalSourceState::AssignTask(ExecutionContext &context, IEJoinLocalSourceState &lstate) { + auto guard = Lock(); - lstate.left_block_index = b1; - lstate.left_base = left_bases[b1]; + using ChunkRange = IEJoinUnion::ChunkRange; + auto &left_table = *gsink.tables[0]; + auto &right_table = *gsink.tables[1]; - lstate.right_block_index = b2; - lstate.right_base = right_bases[b2]; + // Regular block + switch (stage.load()) { + case IEJoinSourceStage::INNER: + if (next_pair < pair_count) { + const auto i = next_pair++; + const auto b1 = (i / right_ranges) * left_per_thread; + const auto b2 = (i % right_ranges) * right_per_thread; - lstate.joiner = make_uniq(client, op, left_table, b1, right_table, b2); - return; - } + ChunkRange l_range {b1, MinValue(left_blocks, b1 + left_per_thread)}; + lstate.left_block_index = l_range.first; + lstate.left_base = left_table.BlockStart(l_range.first); - // Outer joins - if (!left_outers && !right_outers) { - return; - } + ChunkRange r_range {b2, MinValue(right_blocks, b2 + right_per_thread)}; + lstate.right_block_index = r_range.first; + lstate.right_base = right_table.BlockStart(r_range.first); - // Spin wait for regular blocks to finish(!) - while (completed < pair_count) { - std::this_thread::yield(); + lstate.joiner = make_uniq(context, op, left_table, l_range, right_table, r_range); + return true; } - + break; + case IEJoinSourceStage::OUTER: // Left outer blocks - const auto l = next_left++; - if (l < left_outers) { + if (next_left < left_outers) { + const auto l = next_left++; lstate.joiner = nullptr; lstate.left_block_index = l; - lstate.left_base = left_bases[l]; + lstate.left_base = left_table.BlockStart(l); lstate.left_matches = left_table.found_match.get() + lstate.left_base; lstate.outer_idx = 0; lstate.outer_count = left_table.BlockSize(l); - return; + return true; } else { lstate.left_matches = nullptr; } - // Right outer block - const auto r = next_right++; - if (r < right_outers) { + // Right outer blocks + if (next_right < right_outers) { + const auto r = next_right++; lstate.joiner = nullptr; lstate.right_block_index = r; - lstate.right_base = right_bases[r]; + lstate.right_base = right_table.BlockStart(r); lstate.right_matches = right_table.found_match.get() + lstate.right_base; lstate.outer_idx = 0; lstate.outer_count = right_table.BlockSize(r); - return; + return true; } else { lstate.right_matches = nullptr; } + break; + default: + break; } - void PairCompleted(ClientContext &client, IEJoinLocalSourceState &lstate) { - lstate.joiner.reset(); - ++completed; - GetNextPair(client, lstate); - } - - ProgressData GetProgress() const { - auto &left_table = *gsink.tables[0]; - auto &right_table = *gsink.tables[1]; - - const auto left_blocks = left_table.BlockCount(); - const auto right_blocks = right_table.BlockCount(); - const auto pair_count = left_blocks * right_blocks; + return false; +} - const auto count = pair_count + left_outers + right_outers; +ProgressData IEJoinGlobalSourceState::GetProgress() const { + const auto count = pair_count + left_outers + right_outers; - const auto l = MinValue(next_left.load(), left_outers.load()); - const auto r = MinValue(next_right.load(), right_outers.load()); - const auto returned = completed.load() + l + r; + 
const auto l = MinValue(next_left.load(), left_outers.load()); + const auto r = MinValue(next_right.load(), right_outers.load()); + const auto returned = completed.load() + l + r; - ProgressData res; - if (count) { - res.done = double(returned); - res.total = double(count); - } else { - res.SetInvalid(); - } - return res; + ProgressData res; + if (count) { + res.done = double(returned); + res.total = double(count); + } else { + res.SetInvalid(); } - - const PhysicalIEJoin &op; - IEJoinGlobalState &gsink; - - bool initialized; - - // Join queue state - atomic next_pair; - atomic completed; - - // Block base row number - vector left_bases; - vector right_bases; - - // Outer joins - atomic left_outers; - atomic next_left; - - atomic right_outers; - atomic next_right; -}; - + return res; +} unique_ptr PhysicalIEJoin::GetGlobalSourceState(ClientContext &context) const { auto &gsink = sink_state->Cast(); return make_uniq(*this, gsink); @@ -987,7 +1216,8 @@ unique_ptr PhysicalIEJoin::GetGlobalSourceState(ClientContext unique_ptr PhysicalIEJoin::GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const { - return make_uniq(context.client, *this); + auto &gsource = gstate.Cast(); + return make_uniq(context.client, gsource); } ProgressData PhysicalIEJoin::GetProgress(ClientContext &context, GlobalSourceState &gsource_p) const { @@ -997,80 +1227,97 @@ ProgressData PhysicalIEJoin::GetProgress(ClientContext &context, GlobalSourceSta SourceResultType PhysicalIEJoin::GetData(ExecutionContext &context, DataChunk &result, OperatorSourceInput &input) const { - auto &ie_sink = sink_state->Cast(); - auto &ie_gstate = input.global_state.Cast(); - auto &ie_lstate = input.local_state.Cast(); + auto &gsource = input.global_state.Cast(); + auto &lsource = input.local_state.Cast(); - ie_gstate.Initialize(); + gsource.Initialize(); - if (!ie_lstate.joiner && !ie_lstate.left_matches && !ie_lstate.right_matches) { - ie_gstate.GetNextPair(context.client, ie_lstate); + // Any call to GetData must produce tuples, otherwise the pipeline executor thinks that we're done + // Therefore, we loop until we've produced tuples, or until the operator is actually done + while (gsource.stage != IEJoinSourceStage::DONE && result.size() == 0) { + if (!lsource.TaskFinished() || gsource.AssignTask(context, lsource)) { + lsource.ExecuteTask(context, result); + } else { + auto guard = gsource.Lock(); + if (gsource.TryPrepareNextStage() || gsource.stage == IEJoinSourceStage::DONE) { + gsource.UnblockTasks(guard); + } else { + return gsource.BlockSource(guard, input.interrupt_state); + } + } } + return result.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} - // Process INNER results - while (ie_lstate.joiner) { - ResolveComplexJoin(context, result, ie_lstate); +void IEJoinLocalSourceState::ExecuteLeftTask(ExecutionContext &context, DataChunk &result) { + auto &op = gsource.op; + auto &ie_sink = op.sink_state->Cast(); - if (result.size()) { - return SourceResultType::HAVE_MORE_OUTPUT; - } + const auto left_cols = op.children[0].get().GetTypes().size(); + auto &chunk = unprojected; - ie_gstate.PairCompleted(context.client, ie_lstate); + const idx_t count = SelectOuterRows(left_matches); + if (!count) { + left_matches = nullptr; + return; } - // Process LEFT OUTER results - const auto left_cols = children[0].get().GetTypes().size(); - while (ie_lstate.left_matches) { - const idx_t count = ie_lstate.SelectOuterRows(ie_lstate.left_matches); - if (!count) { - ie_gstate.GetNextPair(context.client, ie_lstate); - continue; - } - auto &chunk = ie_lstate.unprojected; - chunk.Reset(); - SliceSortedPayload(chunk, ie_sink.tables[0]->global_sort_state, ie_lstate.left_block_index, ie_lstate.true_sel, - count); + auto &left_table = *ie_sink.tables[0]; + + left_table.Repin(*left_iterator); + op.SliceSortedPayload(lpayload, left_table, *left_iterator, left_chunk_state, left_block_index, true_sel, count, + *left_scan_state); - // Fill in NULLs to the right - for (auto col_idx = left_cols; col_idx < chunk.ColumnCount(); ++col_idx) { + // Fill in NULLs to the right + chunk.Reset(); + for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { + if (col_idx < left_cols) { + chunk.data[col_idx].Reference(lpayload.data[col_idx]); + } else { chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); ConstantVector::SetNull(chunk.data[col_idx], true); } + } - ProjectResult(chunk, result); - result.SetCardinality(count); - result.Verify(); + op.ProjectResult(chunk, result); + result.SetCardinality(count); + result.Verify(); +} + +void IEJoinLocalSourceState::ExecuteRightTask(ExecutionContext &context, DataChunk &result) { + auto &op = gsource.op; + auto &ie_sink = op.sink_state->Cast(); + const auto left_cols = op.children[0].get().GetTypes().size(); - return result.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; + auto &chunk = unprojected; + + const idx_t count = SelectOuterRows(right_matches); + if (!count) { + right_matches = nullptr; + return; } - // Process RIGHT OUTER results - while (ie_lstate.right_matches) { - const idx_t count = ie_lstate.SelectOuterRows(ie_lstate.right_matches); - if (!count) { - ie_gstate.GetNextPair(context.client, ie_lstate); - continue; - } + auto &right_table = *ie_sink.tables[1]; + auto &rsel = true_sel; - auto &chunk = ie_lstate.unprojected; - chunk.Reset(); - SliceSortedPayload(chunk, ie_sink.tables[1]->global_sort_state, ie_lstate.right_block_index, ie_lstate.true_sel, - count, left_cols); + right_table.Repin(*right_iterator); + op.SliceSortedPayload(rpayload, right_table, *right_iterator, right_chunk_state, right_block_index, rsel, count, + *right_scan_state); - // Fill in NULLs to the left - for (idx_t col_idx = 0; col_idx < left_cols; ++col_idx) { + // Fill in NULLs to the left + chunk.Reset(); + for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { + if (col_idx < left_cols) { chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); ConstantVector::SetNull(chunk.data[col_idx], true); + } else { + chunk.data[col_idx].Reference(rpayload.data[col_idx - left_cols]); } - - ProjectResult(chunk, result); - result.SetCardinality(count); - result.Verify(); - - break; } - return result.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; + op.ProjectResult(chunk, result); + result.SetCardinality(count); + result.Verify(); } //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp b/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp index d96cda05d..7793a0ba2 100644 --- a/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp @@ -1,7 +1,5 @@ #include "duckdb/execution/operator/join/physical_nested_loop_join.hpp" #include "duckdb/parallel/thread_context.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/execution/nested_loop_join.hpp" #include "duckdb/main/client_context.hpp" @@ -9,20 +7,22 @@ namespace duckdb { -PhysicalNestedLoopJoin::PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalOperator &op, PhysicalOperator &left, - PhysicalOperator &right, vector cond, JoinType join_type, +PhysicalNestedLoopJoin::PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, + PhysicalOperator &left, PhysicalOperator &right, + vector cond, JoinType join_type, idx_t estimated_cardinality, unique_ptr pushdown_info_p) : PhysicalComparisonJoin(physical_plan, op, PhysicalOperatorType::NESTED_LOOP_JOIN, std::move(cond), join_type, - estimated_cardinality) { - + estimated_cardinality), + predicate(std::move(op.predicate)) { filter_pushdown = std::move(pushdown_info_p); children.push_back(left); children.push_back(right); } -PhysicalNestedLoopJoin::PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalOperator &op, PhysicalOperator &left, - PhysicalOperator &right, vector cond, JoinType join_type, +PhysicalNestedLoopJoin::PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, + PhysicalOperator &left, PhysicalOperator &right, + vector cond, JoinType join_type, 
idx_t estimated_cardinality) : PhysicalNestedLoopJoin(physical_plan, op, left, right, std::move(cond), join_type, estimated_cardinality, nullptr) { @@ -273,7 +273,7 @@ class PhysicalNestedLoopJoinState : public CachingOperatorState { PhysicalNestedLoopJoinState(ClientContext &context, const PhysicalNestedLoopJoin &op, const vector &conditions) : fetch_next_left(true), fetch_next_right(false), lhs_executor(context), left_tuple(0), right_tuple(0), - left_outer(IsLeftOuterJoin(op.join_type)) { + left_outer(IsLeftOuterJoin(op.join_type)), pred_executor(context) { vector condition_types; for (auto &cond : conditions) { lhs_executor.AddExpression(*cond.left); @@ -284,6 +284,11 @@ class PhysicalNestedLoopJoinState : public CachingOperatorState { right_condition.Initialize(allocator, condition_types); right_payload.Initialize(allocator, op.children[1].get().GetTypes()); left_outer.Initialize(STANDARD_VECTOR_SIZE); + + if (op.predicate) { + pred_executor.AddExpression(*op.predicate); + pred_matches.Initialize(); + } } bool fetch_next_left; @@ -302,6 +307,10 @@ class PhysicalNestedLoopJoinState : public CachingOperatorState { OuterJoinMarker left_outer; + //! Predicate + ExpressionExecutor pred_executor; + SelectionVector pred_matches; + public: void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { context.thread.profiler.Flush(op); @@ -438,11 +447,20 @@ OperatorResultType PhysicalNestedLoopJoin::ResolveComplexJoin(ExecutionContext & if (match_count > 0) { // we have matching tuples! // construct the result - state.left_outer.SetMatches(lvector, match_count); - gstate.right_outer.SetMatches(rvector, match_count, state.condition_scan_state.current_row_index); - chunk.Slice(input, lvector, match_count); chunk.Slice(right_payload, rvector, match_count, input.ColumnCount()); + + // If we have a predicate, apply it to the result + if (predicate) { + auto &sel = state.pred_matches; + match_count = state.pred_executor.SelectExpression(chunk, sel); + chunk.Slice(sel, match_count); + lvector.SliceInPlace(sel, match_count); + rvector.SliceInPlace(sel, match_count); + } + + state.left_outer.SetMatches(lvector, match_count); + gstate.right_outer.SetMatches(rvector, match_count, state.condition_scan_state.current_row_index); } // check if we exhausted the RHS, if we did we need to move to the next right chunk in the next iteration diff --git a/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp b/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp index 1bd48ab62..e4faffac1 100644 --- a/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp @@ -1,11 +1,8 @@ #include "duckdb/execution/operator/join/physical_piecewise_merge_join.hpp" -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" #include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/sorting/sort_key.hpp" +#include "duckdb/common/types/row/block_iterator.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/execution/operator/join/outer_join_marker.hpp" #include "duckdb/main/client_context.hpp" @@ -21,7 +18,6 @@ PhysicalPiecewiseMergeJoin::PhysicalPiecewiseMergeJoin(PhysicalPlan &physical_pl unique_ptr pushdown_info_p) : PhysicalRangeJoin(physical_plan, 
op, PhysicalOperatorType::PIECEWISE_MERGE_JOIN, left, right, std::move(cond), join_type, estimated_cardinality, std::move(pushdown_info_p)) { - for (auto &join_cond : conditions) { D_ASSERT(join_cond.left->return_type == join_cond.right->return_type); join_key_types.push_back(join_cond.left->return_type); @@ -65,15 +61,14 @@ class MergeJoinGlobalState : public GlobalSinkState { using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; public: - MergeJoinGlobalState(ClientContext &context, const PhysicalPiecewiseMergeJoin &op) { - RowLayout rhs_layout; - rhs_layout.Initialize(op.children[1].get().GetTypes()); + MergeJoinGlobalState(ClientContext &client, const PhysicalPiecewiseMergeJoin &op) { + const auto &rhs_types = op.children[1].get().GetTypes(); vector rhs_order; rhs_order.emplace_back(op.rhs_orders[0].Copy()); - table = make_uniq(context, rhs_order, rhs_layout, op); + table = make_uniq(client, rhs_order, rhs_types, op); if (op.filter_pushdown) { skip_filter_pushdown = op.filter_pushdown->probe_info.empty(); - global_filter_state = op.filter_pushdown->GetGlobalState(context, op); + global_filter_state = op.filter_pushdown->GetGlobalState(client, op); } } @@ -81,8 +76,9 @@ class MergeJoinGlobalState : public GlobalSinkState { return table->count; } - void Sink(DataChunk &input, MergeJoinLocalState &lstate); + void Sink(ExecutionContext &context, DataChunk &input, MergeJoinLocalState &lstate); + //! The sorted table unique_ptr table; //! Should we not bother pushing down filters? bool skip_filter_pushdown = false; @@ -92,16 +88,19 @@ class MergeJoinGlobalState : public GlobalSinkState { class MergeJoinLocalState : public LocalSinkState { public: - explicit MergeJoinLocalState(ClientContext &context, const PhysicalRangeJoin &op, MergeJoinGlobalState &gstate, - const idx_t child) - : table(context, op, child) { + using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; + using LocalSortedTable = PhysicalRangeJoin::LocalSortedTable; + + MergeJoinLocalState(ExecutionContext &context, MergeJoinGlobalState &gstate, const idx_t child) + : table(context, *gstate.table, child) { + auto &op = gstate.table->op; if (op.filter_pushdown) { local_filter_state = op.filter_pushdown->GetLocalState(*gstate.global_filter_state); } } //! The local sort state - PhysicalRangeJoin::LocalSortedTable table; + LocalSortedTable table; //! 
Local state for accumulating filter statistics unique_ptr local_filter_state; }; @@ -113,20 +112,12 @@ unique_ptr PhysicalPiecewiseMergeJoin::GetGlobalSinkState(Clien unique_ptr PhysicalPiecewiseMergeJoin::GetLocalSinkState(ExecutionContext &context) const { // We only sink the RHS auto &gstate = sink_state->Cast(); - return make_uniq(context.client, *this, gstate, 1U); + return make_uniq(context, gstate, 1U); } -void MergeJoinGlobalState::Sink(DataChunk &input, MergeJoinLocalState &lstate) { - auto &global_sort_state = table->global_sort_state; - auto &local_sort_state = lstate.table.local_sort_state; - +void MergeJoinGlobalState::Sink(ExecutionContext &context, DataChunk &input, MergeJoinLocalState &lstate) { // Sink the data into the local sort state - lstate.table.Sink(input, global_sort_state); - - // When sorting data reaches a certain size, we sort it - if (local_sort_state.SizeInBytes() >= table->memory_per_thread) { - local_sort_state.Sort(global_sort_state, true); - } + lstate.table.Sink(context, input); } SinkResultType PhysicalPiecewiseMergeJoin::Sink(ExecutionContext &context, DataChunk &chunk, @@ -134,7 +125,7 @@ SinkResultType PhysicalPiecewiseMergeJoin::Sink(ExecutionContext &context, DataC auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); - gstate.Sink(chunk, lstate); + gstate.Sink(context, chunk, lstate); if (filter_pushdown && !gstate.skip_filter_pushdown) { filter_pushdown->Sink(lstate.table.keys, *lstate.local_filter_state); @@ -147,7 +138,7 @@ SinkCombineResultType PhysicalPiecewiseMergeJoin::Combine(ExecutionContext &cont OperatorSinkCombineInput &input) const { auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); - gstate.table->Combine(lstate.table); + gstate.table->Combine(context, lstate.table); auto &client_profiler = QueryProfiler::Get(context.client); context.thread.profiler.Flush(*this); @@ -162,25 +153,28 @@ SinkCombineResultType PhysicalPiecewiseMergeJoin::Combine(ExecutionContext &cont //===--------------------------------------------------------------------===// // Finalize //===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalPiecewiseMergeJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, +SinkFinalizeType PhysicalPiecewiseMergeJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &client, OperatorSinkFinalizeInput &input) const { auto &gstate = input.global_state.Cast(); if (filter_pushdown && !gstate.skip_filter_pushdown) { - (void)filter_pushdown->Finalize(context, nullptr, *gstate.global_filter_state, *this); + (void)filter_pushdown->Finalize(client, nullptr, *gstate.global_filter_state, *this); } - auto &global_sort_state = gstate.table->global_sort_state; + + gstate.table->Finalize(client, input.interrupt_state); if (PropagatesBuildSide(join_type)) { // for FULL/RIGHT OUTER JOIN, initialize found_match to false for every tuple gstate.table->IntializeMatches(); } - if (global_sort_state.sorted_blocks.empty() && EmptyResultIfRHSIsEmpty()) { + + if (gstate.table->Count() == 0 && EmptyResultIfRHSIsEmpty()) { // Empty input! 
+ gstate.table->MaterializeEmpty(client); return SinkFinalizeType::NO_OUTPUT_POSSIBLE; } // Sort the current input child - gstate.table->Finalize(pipeline, event); + gstate.table->Materialize(pipeline, event); return SinkFinalizeType::READY; } @@ -191,46 +185,50 @@ SinkFinalizeType PhysicalPiecewiseMergeJoin::Finalize(Pipeline &pipeline, Event class PiecewiseMergeJoinState : public CachingOperatorState { public: using LocalSortedTable = PhysicalRangeJoin::LocalSortedTable; + using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; - PiecewiseMergeJoinState(ClientContext &context, const PhysicalPiecewiseMergeJoin &op, bool force_external) - : context(context), allocator(Allocator::Get(context)), op(op), - buffer_manager(BufferManager::GetBufferManager(context)), force_external(force_external), - left_outer(IsLeftOuterJoin(op.join_type)), left_position(0), first_fetch(true), finished(true), - right_position(0), right_chunk_index(0), rhs_executor(context) { - vector condition_types; - for (auto &order : op.lhs_orders) { - condition_types.push_back(order.expression->return_type); - } + PiecewiseMergeJoinState(ClientContext &client, const PhysicalPiecewiseMergeJoin &op) + : client(client), allocator(Allocator::Get(client)), op(op), left_outer(IsLeftOuterJoin(op.join_type)), + left_position(0), first_fetch(true), finished(true), right_position(0), right_chunk_index(0), + rhs_executor(client) { left_outer.Initialize(STANDARD_VECTOR_SIZE); - lhs_layout.Initialize(op.children[0].get().GetTypes()); - lhs_payload.Initialize(allocator, op.children[0].get().GetTypes()); + lhs_payload.Initialize(client, op.children[0].get().GetTypes()); + // Sort on the first column lhs_order.emplace_back(op.lhs_orders[0].Copy()); // Set up shared data for multiple predicates sel.Initialize(STANDARD_VECTOR_SIZE); - condition_types.clear(); + vector condition_types; for (auto &order : op.rhs_orders) { rhs_executor.AddExpression(*order.expression); condition_types.push_back(order.expression->return_type); } - rhs_keys.Initialize(allocator, condition_types); + rhs_keys.Initialize(client, condition_types); + rhs_input.Initialize(client, op.children[1].get().GetTypes()); + + auto &gsink = op.sink_state->Cast(); + auto &rhs_table = *gsink.table; + rhs_iterator = rhs_table.CreateIteratorState(); + rhs_table.InitializePayloadState(rhs_chunk_state); + rhs_scan_state = rhs_table.CreateScanState(client); + + // Since we have now materialized the payload, the keys will not have payloads? 
+ sort_key_type = rhs_table.GetSortKeyType(); } - ClientContext &context; + ClientContext &client; Allocator &allocator; const PhysicalPiecewiseMergeJoin &op; - BufferManager &buffer_manager; - bool force_external; // Block sorting DataChunk lhs_payload; OuterJoinMarker left_outer; vector lhs_order; - RowLayout lhs_layout; + unique_ptr lhs_global_table; unique_ptr lhs_local_table; - unique_ptr lhs_global_state; - unique_ptr scanner; + SortKeyType sort_key_type; + TupleDataScanState lhs_scan; // Simple scans idx_t left_position; @@ -238,178 +236,127 @@ class PiecewiseMergeJoinState : public CachingOperatorState { // Complex scans bool first_fetch; bool finished; + unique_ptr lhs_iterator; + unique_ptr rhs_iterator; idx_t right_position; idx_t right_chunk_index; idx_t right_base; idx_t prev_left_index; + TupleDataChunkState rhs_chunk_state; + unique_ptr rhs_scan_state; // Secondary predicate shared data SelectionVector sel; DataChunk rhs_keys; DataChunk rhs_input; ExpressionExecutor rhs_executor; - vector payload_heap_handles; public: - void ResolveJoinKeys(DataChunk &input) { + void ResolveJoinKeys(ExecutionContext &context, DataChunk &input) { // sort by join key - lhs_global_state = make_uniq(context, lhs_order, lhs_layout); - lhs_local_table = make_uniq(context, op, 0U); - lhs_local_table->Sink(input, *lhs_global_state); - - // Set external (can be forced with the PRAGMA) - lhs_global_state->external = force_external; - lhs_global_state->AddLocalState(lhs_local_table->local_sort_state); - lhs_global_state->PrepareMergePhase(); - while (lhs_global_state->sorted_blocks.size() > 1) { - MergeSorter merge_sorter(*lhs_global_state, buffer_manager); - merge_sorter.PerformInMergeRound(); - lhs_global_state->CompleteMergeRound(); - } - - // Scan the sorted payload - D_ASSERT(lhs_global_state->sorted_blocks.size() == 1); - - scanner = make_uniq(*lhs_global_state->sorted_blocks[0]->payload_data, *lhs_global_state); - lhs_payload.Reset(); - scanner->Scan(lhs_payload); + const auto &lhs_types = lhs_payload.GetTypes(); + lhs_global_table = make_uniq(context.client, lhs_order, lhs_types, op); + lhs_local_table = make_uniq(context, *lhs_global_table, 0U); + lhs_local_table->Sink(context, input); + lhs_global_table->Combine(context, *lhs_local_table); + + InterruptState interrupt; + lhs_global_table->Finalize(context.client, interrupt); + lhs_global_table->Materialize(context, interrupt); + + // Scan the sorted payload (minus the primary sort column) + auto &lhs_table = *lhs_global_table; + auto &lhs_payload_data = *lhs_table.sorted->payload_data; + lhs_payload_data.InitializeScan(lhs_scan); + lhs_payload_data.Scan(lhs_scan, lhs_payload); // Recompute the sorted keys from the sorted input - lhs_local_table->keys.Reset(); - lhs_local_table->executor.Execute(lhs_payload, lhs_local_table->keys); - } + auto &lhs_keys = lhs_local_table->keys; + lhs_keys.Reset(); + lhs_local_table->executor.Execute(lhs_payload, lhs_keys); - void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - if (lhs_local_table) { - context.thread.profiler.Flush(op); - } + lhs_iterator = lhs_table.CreateIteratorState(); } }; unique_ptr PhysicalPiecewiseMergeJoin::GetOperatorState(ExecutionContext &context) const { - bool force_external = ClientConfig::GetConfig(context.client).force_external; - return make_uniq(context.client, *this, force_external); + return make_uniq(context.client, *this); } -static inline idx_t SortedBlockNotNull(const idx_t base, const idx_t count, const idx_t not_null) { - return 
MinValue(base + count, MaxValue(base, not_null)) - base; +static inline idx_t SortedChunkNotNull(const idx_t chunk_idx, const idx_t count, const idx_t has_null) { + const auto chunk_begin = chunk_idx * STANDARD_VECTOR_SIZE; + const auto chunk_end = MinValue(chunk_begin + STANDARD_VECTOR_SIZE, count); + const auto not_null = count - has_null; + return MinValue(chunk_end, MaxValue(chunk_begin, not_null)) - chunk_begin; } -static int MergeJoinComparisonValue(ExpressionType comparison) { +static bool MergeJoinStrictComparison(ExpressionType comparison) { switch (comparison) { case ExpressionType::COMPARE_LESSTHAN: case ExpressionType::COMPARE_GREATERTHAN: - return -1; + return true; case ExpressionType::COMPARE_LESSTHANOREQUALTO: case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return 0; + return false; default: throw InternalException("Unimplemented comparison type for merge join!"); } } -struct BlockMergeInfo { - GlobalSortState &state; - //! The block being scanned - const idx_t block_idx; - //! The number of not-NULL values in the block (they are at the end) - const idx_t not_null; - //! The current offset in the block - idx_t &entry_idx; - SelectionVector result; - - BlockMergeInfo(GlobalSortState &state, idx_t block_idx, idx_t &entry_idx, idx_t not_null) - : state(state), block_idx(block_idx), not_null(not_null), entry_idx(entry_idx), result(STANDARD_VECTOR_SIZE) { - } -}; - -static void MergeJoinPinSortingBlock(SBScanState &scan, const idx_t block_idx) { - scan.SetIndices(block_idx, 0); - scan.PinRadix(block_idx); - - auto &sd = *scan.sb->blob_sorting_data; - if (block_idx < sd.data_blocks.size()) { - scan.PinData(sd); +// Compare using +bool MergeJoinBefore(const T &lhs, const T &rhs, const bool strict) { + const bool less_than = lhs < rhs; + if (!less_than && !strict) { + return !(rhs < lhs); } + return less_than; } -static data_ptr_t MergeJoinRadixPtr(SBScanState &scan, const idx_t entry_idx) { - scan.entry_idx = entry_idx; - return scan.RadixPtr(); -} +template +static idx_t TemplatedMergeJoinSimpleBlocks(PiecewiseMergeJoinState &lstate, MergeJoinGlobalState &gstate, + bool *found_match, const bool strict) { + using SORT_KEY = SortKey; + using BLOCK_ITERATOR = block_iterator_t; -static idx_t MergeJoinSimpleBlocks(PiecewiseMergeJoinState &lstate, MergeJoinGlobalState &rstate, bool *found_match, - const ExpressionType comparison) { - const auto cmp = MergeJoinComparisonValue(comparison); - - // The sort parameters should all be the same - auto &lsort = *lstate.lhs_global_state; - auto &rsort = rstate.table->global_sort_state; - D_ASSERT(lsort.sort_layout.all_constant == rsort.sort_layout.all_constant); - const auto all_constant = lsort.sort_layout.all_constant; - D_ASSERT(lsort.external == rsort.external); - const auto external = lsort.external; - - // There should only be one sorted block if they have been sorted - D_ASSERT(lsort.sorted_blocks.size() == 1); - SBScanState lread(lsort.buffer_manager, lsort); - lread.sb = lsort.sorted_blocks[0].get(); - - const idx_t l_block_idx = 0; - idx_t l_entry_idx = 0; - const auto lhs_not_null = lstate.lhs_local_table->count - lstate.lhs_local_table->has_null; - MergeJoinPinSortingBlock(lread, l_block_idx); - auto l_ptr = MergeJoinRadixPtr(lread, l_entry_idx); - - D_ASSERT(rsort.sorted_blocks.size() == 1); - SBScanState rread(rsort.buffer_manager, rsort); - rread.sb = rsort.sorted_blocks[0].get(); + // We only need the keys because we are extracting the row numbers + auto &lhs_table = *lstate.lhs_global_table; + D_ASSERT(SORT_KEY_TYPE == 
lhs_table.GetSortKeyType()); + auto &lhs_iterator = *lstate.lhs_iterator; + const auto lhs_not_null = lhs_table.count - lhs_table.has_null; - const auto cmp_size = lsort.sort_layout.comparison_size; - const auto entry_size = lsort.sort_layout.entry_size; + auto &rhs_table = *gstate.table; + auto &rhs_iterator = *lstate.rhs_iterator; + const auto rhs_not_null = rhs_table.count - rhs_table.has_null; - idx_t right_base = 0; - for (idx_t r_block_idx = 0; r_block_idx < rread.sb->radix_sorting_data.size(); r_block_idx++) { - // we only care about the BIGGEST value in each of the RHS data blocks + idx_t l_entry_idx = 0; + BLOCK_ITERATOR lhs_itr(lhs_iterator); + BLOCK_ITERATOR rhs_itr(rhs_iterator); + for (idx_t r_idx = 0; r_idx < rhs_not_null; r_idx += STANDARD_VECTOR_SIZE) { + // Repin the RHS to release memory + // This is safe because we only return the LHS values + // Note we only do this for the RHS because the LHS is only one chunk. + rhs_table.Repin(rhs_iterator); + + // we only care about the BIGGEST value in the RHS // because we want to figure out if the LHS values are less than [or equal] to ANY value - // get the biggest value from the RHS chunk - MergeJoinPinSortingBlock(rread, r_block_idx); - - auto &rblock = *rread.sb->radix_sorting_data[r_block_idx]; - const auto r_not_null = - SortedBlockNotNull(right_base, rblock.count, rstate.table->count - rstate.table->has_null); - if (r_not_null == 0) { - break; - } - const auto r_entry_idx = r_not_null - 1; - right_base += rblock.count; - - auto r_ptr = MergeJoinRadixPtr(rread, r_entry_idx); + const auto r_entry_idx = MinValue(r_idx + STANDARD_VECTOR_SIZE, rhs_not_null) - 1; // now we start from the current lpos value and check if we found a new value that is [<= OR <] the max RHS // value while (true) { - int comp_res; - if (all_constant) { - comp_res = FastMemcmp(l_ptr, r_ptr, cmp_size); - } else { - lread.entry_idx = l_entry_idx; - rread.entry_idx = r_entry_idx; - comp_res = Comparators::CompareTuple(lread, rread, l_ptr, r_ptr, lsort.sort_layout, external); - } - - if (comp_res <= cmp) { + // Note that both subscripts here are table indices, not chunk indices. + if (MergeJoinBefore(lhs_itr[l_entry_idx], rhs_itr[r_entry_idx], strict)) { // found a match for lpos, set it in the found_match vector found_match[l_entry_idx] = true; l_entry_idx++; - l_ptr += entry_size; if (l_entry_idx >= lhs_not_null) { // early out: we exhausted the entire LHS and they all match return 0; } } else { // we found no match: any subsequent value from the LHS we scan now will be bigger and thus also not - // match move to the next RHS chunk + // match. 
Move to the next RHS chunk break; } } @@ -417,13 +364,42 @@ static idx_t MergeJoinSimpleBlocks(PiecewiseMergeJoinState &lstate, MergeJoinGlo return 0; } +static idx_t MergeJoinSimpleBlocks(PiecewiseMergeJoinState &lstate, MergeJoinGlobalState &gstate, bool *match, + const ExpressionType comparison) { + const auto strict = MergeJoinStrictComparison(comparison); + + switch (lstate.sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::NO_PAYLOAD_FIXED_16: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::NO_PAYLOAD_FIXED_24: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::NO_PAYLOAD_FIXED_32: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::PAYLOAD_FIXED_16: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::PAYLOAD_FIXED_24: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::PAYLOAD_FIXED_32: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::PAYLOAD_VARIABLE_32: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + default: + throw NotImplementedException("MergeJoinSimpleBlocks for %s", EnumUtil::ToString(lstate.sort_key_type)); + } +} + void PhysicalPiecewiseMergeJoin::ResolveSimpleJoin(ExecutionContext &context, DataChunk &input, DataChunk &chunk, OperatorState &state_p) const { auto &state = state_p.Cast(); auto &gstate = sink_state->Cast(); - state.ResolveJoinKeys(input); - auto &lhs_table = *state.lhs_local_table; + state.ResolveJoinKeys(context, input); + auto &lhs_table = *state.lhs_global_table; + auto &lhs_keys = state.lhs_local_table->keys; // perform the actual join bool found_match[STANDARD_VECTOR_SIZE]; @@ -439,8 +415,8 @@ void PhysicalPiecewiseMergeJoin::ResolveSimpleJoin(ExecutionContext &context, Da case JoinType::MARK: { // The only part of the join keys that is actually used is the validity mask. // Since the payload is sorted, we can just set the tail end of the validity masks to invalid. 
- for (auto &key : lhs_table.keys.data) { - key.Flatten(lhs_table.keys.size()); + for (auto &key : lhs_keys.data) { + key.Flatten(lhs_keys.size()); auto &mask = FlatVector::Validity(key); if (mask.AllValid()) { continue; @@ -451,7 +427,7 @@ void PhysicalPiecewiseMergeJoin::ResolveSimpleJoin(ExecutionContext &context, Da } } // So we make a set of keys that have the validity mask set for the - PhysicalJoin::ConstructMarkJoinResult(lhs_table.keys, payload, chunk, found_match, gstate.table->has_null); + PhysicalJoin::ConstructMarkJoinResult(lhs_keys, payload, chunk, found_match, gstate.table->has_null); break; } case JoinType::SEMI: @@ -465,40 +441,40 @@ void PhysicalPiecewiseMergeJoin::ResolveSimpleJoin(ExecutionContext &context, Da } } -static idx_t MergeJoinComplexBlocks(BlockMergeInfo &l, BlockMergeInfo &r, const ExpressionType comparison, - idx_t &prev_left_index) { - const auto cmp = MergeJoinComparisonValue(comparison); - - // The sort parameters should all be the same - D_ASSERT(l.state.sort_layout.all_constant == r.state.sort_layout.all_constant); - const auto all_constant = r.state.sort_layout.all_constant; - D_ASSERT(l.state.external == r.state.external); - const auto external = l.state.external; - - // There should only be one sorted block if they have been sorted - D_ASSERT(l.state.sorted_blocks.size() == 1); - SBScanState lread(l.state.buffer_manager, l.state); - lread.sb = l.state.sorted_blocks[0].get(); - D_ASSERT(lread.sb->radix_sorting_data.size() == 1); - MergeJoinPinSortingBlock(lread, l.block_idx); - auto l_start = MergeJoinRadixPtr(lread, 0); - auto l_ptr = MergeJoinRadixPtr(lread, l.entry_idx); - - D_ASSERT(r.state.sorted_blocks.size() == 1); - SBScanState rread(r.state.buffer_manager, r.state); - rread.sb = r.state.sorted_blocks[0].get(); +struct ChunkMergeInfo { + //! The iteration state + ExternalBlockIteratorState &state; + //! The block being scanned + const idx_t block_idx; + //! The number of not-NULL values in the chunk (they are at the end) + const idx_t not_null; + //! The current offset in the chunk + idx_t &entry_idx; + //! The offsets that match + SelectionVector result; - if (r.entry_idx >= r.not_null) { - return 0; + ChunkMergeInfo(ExternalBlockIteratorState &state, idx_t block_idx, idx_t &entry_idx, idx_t not_null) + : state(state), block_idx(block_idx), not_null(not_null), entry_idx(entry_idx), result(STANDARD_VECTOR_SIZE) { } - MergeJoinPinSortingBlock(rread, r.block_idx); - auto r_ptr = MergeJoinRadixPtr(rread, r.entry_idx); + idx_t GetIndex() const { + return state.GetIndex(block_idx, entry_idx); + } +}; - const auto cmp_size = l.state.sort_layout.comparison_size; - const auto entry_size = l.state.sort_layout.entry_size; +template +static idx_t TemplatedMergeJoinComplexBlocks(ChunkMergeInfo &l, ChunkMergeInfo &r, const bool strict, + idx_t &prev_left_index) { + using SORT_KEY = SortKey; + using BLOCK_ITERATOR = block_iterator_t; + + if (r.entry_idx >= r.not_null) { + return 0; + } idx_t result_count = 0; + BLOCK_ITERATOR l_ptr(l.state); + BLOCK_ITERATOR r_ptr(r.state); while (true) { if (l.entry_idx < prev_left_index) { // left side smaller: found match @@ -507,7 +483,7 @@ static idx_t MergeJoinComplexBlocks(BlockMergeInfo &l, BlockMergeInfo &r, const result_count++; // move left side forward l.entry_idx++; - l_ptr += entry_size; + ++l_ptr; if (result_count == STANDARD_VECTOR_SIZE) { // out of space! 
break; @@ -515,22 +491,14 @@ static idx_t MergeJoinComplexBlocks(BlockMergeInfo &l, BlockMergeInfo &r, const continue; } if (l.entry_idx < l.not_null) { - int comp_res; - if (all_constant) { - comp_res = FastMemcmp(l_ptr, r_ptr, cmp_size); - } else { - lread.entry_idx = l.entry_idx; - rread.entry_idx = r.entry_idx; - comp_res = Comparators::CompareTuple(lread, rread, l_ptr, r_ptr, l.state.sort_layout, external); - } - if (comp_res <= cmp) { + if (MergeJoinBefore(l_ptr[l.GetIndex()], r_ptr[r.GetIndex()], strict)) { // left side smaller: found match l.result.set_index(result_count, sel_t(l.entry_idx)); r.result.set_index(result_count, sel_t(r.entry_idx)); result_count++; // move left side forward l.entry_idx++; - l_ptr += entry_size; + ++l_ptr; if (result_count == STANDARD_VECTOR_SIZE) { // out of space! break; @@ -546,27 +514,53 @@ static idx_t MergeJoinComplexBlocks(BlockMergeInfo &l, BlockMergeInfo &r, const if (r.entry_idx >= r.not_null) { break; } - r_ptr += entry_size; + ++r_ptr; - l_ptr = l_start; l.entry_idx = 0; } return result_count; } +static idx_t MergeJoinComplexBlocks(const SortKeyType &sort_key_type, ChunkMergeInfo &l, ChunkMergeInfo &r, + const ExpressionType comparison, idx_t &prev_left_index) { + const auto strict = MergeJoinStrictComparison(comparison); + + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::NO_PAYLOAD_FIXED_16: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::NO_PAYLOAD_FIXED_24: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::NO_PAYLOAD_FIXED_32: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::PAYLOAD_FIXED_16: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::PAYLOAD_FIXED_24: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::PAYLOAD_FIXED_32: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::PAYLOAD_VARIABLE_32: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + default: + throw NotImplementedException("MergeJoinSimpleBlocks for %s", EnumUtil::ToString(sort_key_type)); + } +} + OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &input, DataChunk &chunk, OperatorState &state_p) const { auto &state = state_p.Cast(); auto &gstate = sink_state->Cast(); - auto &rsorted = *gstate.table->global_sort_state.sorted_blocks[0]; const auto left_cols = input.ColumnCount(); const auto tail_cols = conditions.size() - 1; - state.payload_heap_handles.clear(); do { if (state.first_fetch) { - state.ResolveJoinKeys(input); + state.ResolveJoinKeys(context, input); + state.lhs_payload.Verify(); state.right_chunk_index = 0; state.right_base = 0; @@ -588,36 +582,44 @@ OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionConte return OperatorResultType::NEED_MORE_INPUT; } - auto &lhs_table = *state.lhs_local_table; + auto &lhs_table = *state.lhs_global_table; const auto lhs_not_null = lhs_table.count - lhs_table.has_null; - BlockMergeInfo left_info(*state.lhs_global_state, 0, state.left_position, lhs_not_null); + ChunkMergeInfo left_info(*state.lhs_iterator, 0, state.left_position, lhs_not_null); + + auto 
&rhs_table = *gstate.table; + auto &rhs_iterator = *state.rhs_iterator; + const auto rhs_not_null = SortedChunkNotNull(state.right_chunk_index, rhs_table.count, rhs_table.has_null); + ChunkMergeInfo right_info(rhs_iterator, state.right_chunk_index, state.right_position, rhs_not_null); - const auto &rblock = *rsorted.radix_sorting_data[state.right_chunk_index]; - const auto rhs_not_null = - SortedBlockNotNull(state.right_base, rblock.count, gstate.table->count - gstate.table->has_null); - BlockMergeInfo right_info(gstate.table->global_sort_state, state.right_chunk_index, state.right_position, - rhs_not_null); + // Repin so we don't hang on to data after we have scanned it + // Note we only do this for the RHS because the LHS is only one chunk. + rhs_table.Repin(rhs_iterator); - idx_t result_count = - MergeJoinComplexBlocks(left_info, right_info, conditions[0].comparison, state.prev_left_index); + idx_t result_count = MergeJoinComplexBlocks(state.sort_key_type, left_info, right_info, + conditions[0].comparison, state.prev_left_index); if (result_count == 0) { // exhausted this chunk on the right side // move to the next right chunk state.left_position = 0; state.right_position = 0; - state.right_base += rsorted.radix_sorting_data[state.right_chunk_index]->count; + state.right_base += STANDARD_VECTOR_SIZE; state.right_chunk_index++; - if (state.right_chunk_index >= rsorted.radix_sorting_data.size()) { + if (state.right_chunk_index >= rhs_table.BlockCount()) { state.finished = true; } } else { // found matches: extract them + SliceSortedPayload(state.rhs_input, rhs_table, rhs_iterator, state.rhs_chunk_state, right_info.block_idx, + right_info.result, result_count, *state.rhs_scan_state); + chunk.Reset(); - for (idx_t c = 0; c < state.lhs_payload.ColumnCount(); ++c) { - chunk.data[c].Slice(state.lhs_payload.data[c], left_info.result, result_count); + for (idx_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { + if (col_idx < left_cols) { + chunk.data[col_idx].Slice(state.lhs_payload.data[col_idx], left_info.result, result_count); + } else { + chunk.data[col_idx].Reference(state.rhs_input.data[col_idx - left_cols]); + } } - state.payload_heap_handles.push_back(SliceSortedPayload(chunk, right_info.state, right_info.block_idx, - right_info.result, result_count, left_cols)); chunk.SetCardinality(result_count); auto sel = FlatVector::IncrementalSelectionVector(); @@ -625,13 +627,12 @@ OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionConte // If there are more expressions to compute, // split the result chunk into the left and right halves // so we can compute the values for comparison. 
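The output assembly above slices the probe columns through the match selection while referencing the freshly gathered build columns. As a loose illustration with plain vectors instead of DuckDB's Vector/SelectionVector: a sliced column is just the original data plus the matching indices, so no probe values are copied.

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// A zero-copy "slice": the original column plus the selection of matching rows.
struct SlicedColumn {
	const std::vector<int32_t> *data; // untouched probe-side column
	std::vector<size_t> sel;          // indices of the matching rows
	int32_t operator[](size_t i) const {
		return (*data)[sel[i]];
	}
};

struct OutputColumns {
	SlicedColumn probe;         // view over the existing probe chunk
	std::vector<int32_t> build; // gathered build-side values, one per match
};

static OutputColumns AssembleOutput(const std::vector<int32_t> &probe_column, std::vector<size_t> matches,
                                    std::vector<int32_t> gathered_build_values) {
	return OutputColumns {SlicedColumn {&probe_column, std::move(matches)}, std::move(gathered_build_values)};
}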
- chunk.Split(state.rhs_input, left_cols); state.rhs_executor.SetChunk(state.rhs_input); state.rhs_keys.Reset(); auto tail_count = result_count; for (size_t cmp_idx = 1; cmp_idx < conditions.size(); ++cmp_idx) { - Vector left(lhs_table.keys.data[cmp_idx]); + Vector left(state.lhs_local_table->keys.data[cmp_idx]); left.Slice(left_info.result, result_count); auto &right = state.rhs_keys.data[cmp_idx]; @@ -645,7 +646,6 @@ OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionConte SelectJoinTail(conditions[cmp_idx].comparison, left, right, sel, tail_count, &state.sel); sel = &state.sel; } - chunk.Fuse(state.rhs_input); if (tail_count < result_count) { result_count = tail_count; @@ -713,54 +713,78 @@ OperatorResultType PhysicalPiecewiseMergeJoin::ExecuteInternal(ExecutionContext //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -class PiecewiseJoinScanState : public GlobalSourceState { +class PiecewiseJoinGlobalScanState : public GlobalSourceState { +public: + explicit PiecewiseJoinGlobalScanState(TupleDataCollection &payload) : payload(payload), right_outer_position(0) { + payload.InitializeScan(parallel_scan); + } + + idx_t Scan(TupleDataLocalScanState &local_scan, DataChunk &chunk) { + lock_guard guard(lock); + const auto result = right_outer_position; + payload.Scan(parallel_scan, local_scan, chunk); + right_outer_position += chunk.size(); + return result; + } + + TupleDataCollection &payload; + public: - explicit PiecewiseJoinScanState(const PhysicalPiecewiseMergeJoin &op) : op(op), right_outer_position(0) { + idx_t MaxThreads() override { + return payload.ChunkCount(); } +private: mutex lock; - const PhysicalPiecewiseMergeJoin &op; - unique_ptr scanner; + TupleDataParallelScanState parallel_scan; idx_t right_outer_position; +}; +class PiecewiseJoinLocalScanState : public LocalSourceState { public: - idx_t MaxThreads() override { - auto &sink = op.sink_state->Cast(); - return sink.Count() / (STANDARD_VECTOR_SIZE * idx_t(10)); + explicit PiecewiseJoinLocalScanState(PiecewiseJoinGlobalScanState &gstate) : rsel(STANDARD_VECTOR_SIZE) { + gstate.payload.InitializeScan(scanner); + gstate.payload.InitializeChunk(rhs_chunk); } + + TupleDataLocalScanState scanner; + DataChunk rhs_chunk; + SelectionVector rsel; }; unique_ptr PhysicalPiecewiseMergeJoin::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); + auto &gsink = sink_state->Cast(); + return make_uniq(*gsink.table->sorted->payload_data); +} + +unique_ptr PhysicalPiecewiseMergeJoin::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const { + return make_uniq(gstate.Cast()); } SourceResultType PhysicalPiecewiseMergeJoin::GetData(ExecutionContext &context, DataChunk &result, - OperatorSourceInput &input) const { + OperatorSourceInput &source) const { D_ASSERT(PropagatesBuildSide(join_type)); // check if we need to scan any unmatched tuples from the RHS for the full/right outer join - auto &sink = sink_state->Cast(); - auto &state = input.global_state.Cast(); - - lock_guard l(state.lock); - if (!state.scanner) { - // Initialize scanner (if not yet initialized) - auto &sort_state = sink.table->global_sort_state; - if (sort_state.sorted_blocks.empty()) { - return SourceResultType::FINISHED; - } - state.scanner = make_uniq(*sort_state.sorted_blocks[0]->payload_data, sort_state); + auto &gsink = sink_state->Cast(); + auto &gsource = source.global_state.Cast(); 
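For the extra join conditions, the tail loop above keeps narrowing one selection until every predicate has been applied. A small stand-alone version of that narrowing, using std::function predicates over row indices rather than SelectJoinTail and Vectors:

#include <cstddef>
#include <functional>
#include <vector>

// Apply each remaining condition to the surviving indices, compacting in place.
static size_t FilterTail(const std::vector<std::function<bool(size_t)>> &conditions, std::vector<size_t> &sel) {
	for (const auto &condition : conditions) {
		size_t kept = 0;
		for (size_t i = 0; i < sel.size(); i++) {
			if (condition(sel[i])) {
				sel[kept++] = sel[i];
			}
		}
		sel.resize(kept);
		if (sel.empty()) {
			break; // no candidates left for the later conditions
		}
	}
	return sel.size();
}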
+ + // RHS was empty, so nothing to do? + if (!gsink.table->count) { + return SourceResultType::FINISHED; } // if the LHS is exhausted in a FULL/RIGHT OUTER JOIN, we scan the found_match for any chunks we // still need to output - const auto found_match = sink.table->found_match.get(); + const auto found_match = gsink.table->found_match.get(); - DataChunk rhs_chunk; - rhs_chunk.Initialize(Allocator::Get(context.client), sink.table->global_sort_state.payload_layout.GetTypes()); - SelectionVector rsel(STANDARD_VECTOR_SIZE); + auto &lsource = source.local_state.Cast(); + auto &rhs_chunk = lsource.rhs_chunk; + auto &rsel = lsource.rsel; for (;;) { // Read the next sorted chunk - state.scanner->Scan(rhs_chunk); + rhs_chunk.Reset(); + const auto rhs_pos = gsource.Scan(lsource.scanner, rhs_chunk); const auto count = rhs_chunk.size(); if (count == 0) { @@ -770,11 +794,10 @@ SourceResultType PhysicalPiecewiseMergeJoin::GetData(ExecutionContext &context, idx_t result_count = 0; // figure out which tuples didn't find a match in the RHS for (idx_t i = 0; i < count; i++) { - if (!found_match[state.right_outer_position + i]) { + if (!found_match[rhs_pos + i]) { rsel.set_index(result_count++, i); } } - state.right_outer_position += count; if (result_count > 0) { // if there were any tuples that didn't find a match, output them diff --git a/src/duckdb/src/execution/operator/join/physical_range_join.cpp b/src/duckdb/src/execution/operator/join/physical_range_join.cpp index 4fefafbd4..41abaeca9 100644 --- a/src/duckdb/src/execution/operator/join/physical_range_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_range_join.cpp @@ -1,10 +1,7 @@ #include "duckdb/execution/operator/join/physical_range_join.hpp" -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" #include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/sorting/sort_key.hpp" #include "duckdb/common/types/validity_mask.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/common/unordered_map.hpp" @@ -14,15 +11,15 @@ #include "duckdb/parallel/base_pipeline_event.hpp" #include "duckdb/parallel/thread_context.hpp" #include "duckdb/parallel/executor_task.hpp" - -#include +#include "duckdb/planner/expression/bound_reference_expression.hpp" namespace duckdb { -PhysicalRangeJoin::LocalSortedTable::LocalSortedTable(ClientContext &context, const PhysicalRangeJoin &op, +PhysicalRangeJoin::LocalSortedTable::LocalSortedTable(ExecutionContext &context, GlobalSortedTable &global_table, const idx_t child) - : op(op), executor(context), has_null(0), count(0) { + : global_table(global_table), executor(context.client), has_null(0), count(0) { // Initialize order clause expression executor and key DataChunk + const auto &op = global_table.op; vector types; for (const auto &cond : op.conditions) { const auto &expr = child ? 
cond.right : cond.left; @@ -30,16 +27,19 @@ PhysicalRangeJoin::LocalSortedTable::LocalSortedTable(ClientContext &context, co types.push_back(expr->return_type); } - auto &allocator = Allocator::Get(context); + auto &allocator = Allocator::Get(context.client); keys.Initialize(allocator, types); -} -void PhysicalRangeJoin::LocalSortedTable::Sink(DataChunk &input, GlobalSortState &global_sort_state) { - // Initialize local state (if necessary) - if (!local_sort_state.initialized) { - local_sort_state.Initialize(global_sort_state, global_sort_state.buffer_manager); - } + local_sink = global_table.sort->GetLocalSinkState(context); + + // Only sort the primary key + types.resize(1); + const auto &payload_types = op.children[child].get().types; + types.insert(types.end(), payload_types.begin(), payload_types.end()); + sort_chunk.InitializeEmpty(types); +} +void PhysicalRangeJoin::LocalSortedTable::Sink(ExecutionContext &context, DataChunk &input) { // Obtain sorting columns keys.Reset(); executor.Execute(input, keys); @@ -47,121 +47,179 @@ void PhysicalRangeJoin::LocalSortedTable::Sink(DataChunk &input, GlobalSortState // Do not operate on primary key directly to avoid modifying the input chunk Vector primary = keys.data[0]; // Count the NULLs so we can exclude them later - has_null += MergeNulls(primary, op.conditions); + has_null += MergeNulls(primary, global_table.op.conditions); count += keys.size(); // Only sort the primary key - DataChunk join_head; - join_head.data.emplace_back(primary); - join_head.SetCardinality(keys.size()); + sort_chunk.data[0].Reference(primary); + for (column_t col_idx = 0; col_idx < input.ColumnCount(); ++col_idx) { + sort_chunk.data[col_idx + 1].Reference(input.data[col_idx]); + } + sort_chunk.SetCardinality(input); // Sink the data into the local sort state - local_sort_state.SinkChunk(join_head, input); + InterruptState interrupt; + OperatorSinkInput sink {*global_table.global_sink, *local_sink, interrupt}; + global_table.sort->Sink(context, sort_chunk, sink); } -PhysicalRangeJoin::GlobalSortedTable::GlobalSortedTable(ClientContext &context, const vector &orders, - RowLayout &payload_layout, const PhysicalOperator &op_p) - : op(op_p), global_sort_state(context, orders, payload_layout), has_null(0), count(0), memory_per_thread(0) { +PhysicalRangeJoin::GlobalSortedTable::GlobalSortedTable(ClientContext &client, + const vector &order_bys, + const vector &payload_types, + const PhysicalRangeJoin &op) + : op(op), has_null(0), count(0), tasks_completed(0) { + // Set up the sort. We will materialize keys ourselves, so just set up references. 
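The new source states above replace the single payload scanner with a shared, lock-protected cursor: each worker claims a chunk, learns its base row id, and emits only the rows whose found_match flag was never set. A simplified analogue over plain vectors, not DuckDB's TupleDataCollection scan states:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <vector>

struct GlobalScan {
	std::mutex lock;
	size_t next_row = 0;
	const std::vector<int> *rows = nullptr; // materialized build side
	std::vector<uint8_t> found_match;       // one flag per build row

	// Claim the next chunk under the lock; return its base offset into `rows`.
	size_t Scan(std::vector<int> &chunk, size_t chunk_size) {
		std::lock_guard<std::mutex> guard(lock);
		const size_t base = next_row;
		const size_t end = std::min(rows->size(), base + chunk_size);
		chunk.assign(rows->begin() + static_cast<std::ptrdiff_t>(base),
		             rows->begin() + static_cast<std::ptrdiff_t>(end));
		next_row = end;
		return base;
	}
};

// Each worker keeps only the rows that never matched (FULL/RIGHT OUTER output).
static std::vector<int> ScanUnmatched(GlobalScan &gstate, size_t chunk_size) {
	std::vector<int> out, chunk;
	for (;;) {
		const size_t base = gstate.Scan(chunk, chunk_size);
		if (chunk.empty()) {
			break;
		}
		for (size_t i = 0; i < chunk.size(); i++) {
			if (!gstate.found_match[base + i]) {
				out.push_back(chunk[i]);
			}
		}
	}
	return out;
}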
+ vector orders; + vector input_types; + for (const auto &order_by : order_bys) { + auto order = order_by.Copy(); + const auto type = order.expression->return_type; + input_types.emplace_back(type); + order.expression = make_uniq(type, orders.size()); + orders.emplace_back(std::move(order)); + } + + vector projection_map; + for (const auto &type : payload_types) { + projection_map.emplace_back(input_types.size()); + input_types.emplace_back(type); + } + + sort = make_uniq(client, orders, input_types, projection_map); - // Set external (can be forced with the PRAGMA) - global_sort_state.external = ClientConfig::GetConfig(context).force_external; - memory_per_thread = PhysicalRangeJoin::GetMaxThreadMemory(context); + global_sink = sort->GetGlobalSinkState(client); } -void PhysicalRangeJoin::GlobalSortedTable::Combine(LocalSortedTable <able) { - global_sort_state.AddLocalState(ltable.local_sort_state); +void PhysicalRangeJoin::GlobalSortedTable::Combine(ExecutionContext &context, LocalSortedTable <able) { + InterruptState interrupt; + OperatorSinkCombineInput combine {*global_sink, *ltable.local_sink, interrupt}; + sort->Combine(context, combine); has_null += ltable.has_null; count += ltable.count; } +void PhysicalRangeJoin::GlobalSortedTable::Finalize(ClientContext &client, InterruptState &interrupt) { + OperatorSinkFinalizeInput finalize {*global_sink, interrupt}; + sort->Finalize(client, finalize); +} + void PhysicalRangeJoin::GlobalSortedTable::IntializeMatches() { found_match = make_unsafe_uniq_array_uninitialized(Count()); memset(found_match.get(), 0, sizeof(bool) * Count()); } +void PhysicalRangeJoin::GlobalSortedTable::MaterializeEmpty(ClientContext &client) { + D_ASSERT(!sorted); + sorted = make_uniq(client, *sort, false); +} + void PhysicalRangeJoin::GlobalSortedTable::Print() { - global_sort_state.Print(); + D_ASSERT(sorted); + auto &collection = *sorted->payload_data; + TupleDataScanState scanner; + collection.InitializeScan(scanner); + + DataChunk payload; + collection.InitializeScanChunk(scanner, payload); + + while (collection.Scan(scanner, payload)) { + payload.Print(); + } } -class RangeJoinMergeTask : public ExecutorTask { +//===--------------------------------------------------------------------===// +// RangeJoinMaterializeTask +//===--------------------------------------------------------------------===// +class RangeJoinMaterializeTask : public ExecutorTask { public: using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; public: - RangeJoinMergeTask(shared_ptr event_p, ClientContext &context, GlobalSortedTable &table) - : ExecutorTask(context, std::move(event_p), table.op), context(context), table(table) { + RangeJoinMaterializeTask(Pipeline &pipeline, shared_ptr event, ClientContext &client, + GlobalSortedTable &table, idx_t tasks_scheduled) + : ExecutorTask(client, std::move(event), table.op), pipeline(pipeline), table(table), + tasks_scheduled(tasks_scheduled) { } TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - // Initialize iejoin sorted and iterate until done - auto &global_sort_state = table.global_sort_state; - MergeSorter merge_sorter(global_sort_state, BufferManager::GetBufferManager(context)); - merge_sorter.PerformInMergeRound(); - event->FinishTask(); + ExecutionContext execution(pipeline.GetClientContext(), *thread_context, &pipeline); + auto &sort = *table.sort; + auto &sort_global = *table.global_source; + auto sort_local = sort.GetLocalSourceState(execution, sort_global); + InterruptState 
interrupt((weak_ptr(shared_from_this()))); + OperatorSourceInput input {sort_global, *sort_local, interrupt}; + sort.MaterializeSortedRun(execution, input); + if (++table.tasks_completed == tasks_scheduled) { + table.sorted = sort.GetSortedRun(sort_global); + if (!table.sorted) { + table.MaterializeEmpty(execution.client); + } + } + event->FinishTask(); return TaskExecutionResult::TASK_FINISHED; } string TaskType() const override { - return "RangeJoinMergeTask"; + return "RangeJoinMaterializeTask"; } private: - ClientContext &context; + Pipeline &pipeline; GlobalSortedTable &table; + const idx_t tasks_scheduled; }; -class RangeJoinMergeEvent : public BasePipelineEvent { +//===--------------------------------------------------------------------===// +// RangeJoinMaterializeEvent +//===--------------------------------------------------------------------===// +class RangeJoinMaterializeEvent : public BasePipelineEvent { public: using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; public: - RangeJoinMergeEvent(GlobalSortedTable &table_p, Pipeline &pipeline_p) - : BasePipelineEvent(pipeline_p), table(table_p) { + RangeJoinMaterializeEvent(GlobalSortedTable &table, Pipeline &pipeline) + : BasePipelineEvent(pipeline), table(table) { } GlobalSortedTable &table; public: void Schedule() override { - auto &context = pipeline->GetClientContext(); + auto &client = pipeline->GetClientContext(); - // Schedule tasks equal to the number of threads, which will each merge multiple partitions - auto &ts = TaskScheduler::GetScheduler(context); + // Schedule as many tasks as the sort will allow + auto &ts = TaskScheduler::GetScheduler(client); auto num_threads = NumericCast(ts.NumberOfThreads()); - - vector> iejoin_tasks; - for (idx_t tnum = 0; tnum < num_threads; tnum++) { - iejoin_tasks.push_back(make_uniq(shared_from_this(), context, table)); + vector> tasks; + + auto &sort = *table.sort; + auto &global_sink = *table.global_sink; + table.global_source = sort.GetGlobalSourceState(client, global_sink); + const auto tasks_scheduled = MinValue(num_threads, table.global_source->MaxThreads()); + for (idx_t tnum = 0; tnum < tasks_scheduled; ++tnum) { + tasks.push_back( + make_uniq(*pipeline, shared_from_this(), client, table, tasks_scheduled)); } - SetTasks(std::move(iejoin_tasks)); - } - void FinishEvent() override { - auto &global_sort_state = table.global_sort_state; - - global_sort_state.CompleteMergeRound(true); - if (global_sort_state.sorted_blocks.size() > 1) { - // Multiple blocks remaining: Schedule the next round - table.ScheduleMergeTasks(*pipeline, *this); - } + SetTasks(std::move(tasks)); } }; -void PhysicalRangeJoin::GlobalSortedTable::ScheduleMergeTasks(Pipeline &pipeline, Event &event) { - // Initialize global sort state for a round of merging - global_sort_state.InitializeMergeRound(); - auto new_event = make_shared_ptr(*this, pipeline); - event.InsertEvent(std::move(new_event)); +void PhysicalRangeJoin::GlobalSortedTable::Materialize(Pipeline &pipeline, Event &event) { + // Schedule all the sorts for maximum thread utilisation + auto sort_event = make_shared_ptr(*this, pipeline); + event.InsertEvent(std::move(sort_event)); } -void PhysicalRangeJoin::GlobalSortedTable::Finalize(Pipeline &pipeline, Event &event) { - // Prepare for merge sort phase - global_sort_state.PrepareMergePhase(); - - // Start the merge phase or finish if a merge is not necessary - if (global_sort_state.sorted_blocks.size() > 1) { - ScheduleMergeTasks(pipeline, event); +void 
PhysicalRangeJoin::GlobalSortedTable::Materialize(ExecutionContext &context, InterruptState &interrupt) { + global_source = sort->GetGlobalSourceState(context.client, *global_sink); + auto local_source = sort->GetLocalSourceState(context, *global_source); + OperatorSourceInput source {*global_source, *local_source, interrupt}; + sort->MaterializeSortedRun(context, source); + sorted = sort->GetSortedRun(*global_source); + if (!sorted) { + MaterializeEmpty(context.client); } } @@ -336,56 +394,74 @@ void PhysicalRangeJoin::ProjectResult(DataChunk &chunk, DataChunk &result) const result.SetCardinality(chunk); } -BufferHandle PhysicalRangeJoin::SliceSortedPayload(DataChunk &payload, GlobalSortState &state, const idx_t block_idx, - const SelectionVector &result, const idx_t result_count, - const idx_t left_cols) { - // There should only be one sorted block if they have been sorted - D_ASSERT(state.sorted_blocks.size() == 1); - SBScanState read_state(state.buffer_manager, state); - read_state.sb = state.sorted_blocks[0].get(); - auto &sorted_data = *read_state.sb->payload_data; - - read_state.SetIndices(block_idx, 0); - read_state.PinData(sorted_data); - const auto data_ptr = read_state.DataPtr(sorted_data); - data_ptr_t heap_ptr = nullptr; - - // Set up a batch of pointers to scan data from - Vector addresses(LogicalType::POINTER, result_count); - auto data_pointers = FlatVector::GetData(addresses); - - // Set up the data pointers for the values that are actually referenced - const idx_t &row_width = sorted_data.layout.GetRowWidth(); - - auto prev_idx = result.get_index(0); - SelectionVector gsel(result_count); - idx_t addr_count = 0; - gsel.set_index(0, addr_count); - data_pointers[addr_count] = data_ptr + prev_idx * row_width; - for (idx_t i = 1; i < result_count; ++i) { - const auto row_idx = result.get_index(i); - if (row_idx != prev_idx) { - data_pointers[++addr_count] = data_ptr + row_idx * row_width; - prev_idx = row_idx; - } - gsel.set_index(i, addr_count); +template +static void TemplatedSliceSortedPayload(DataChunk &chunk, const SortedRun &sorted_run, + ExternalBlockIteratorState &state, Vector &sort_key_pointers, + SortedRunScanState &scan_state, const idx_t chunk_idx, SelectionVector &result, + const idx_t result_count) { + using SORT_KEY = SortKey; + using BLOCK_ITERATOR = block_iterator_t; + BLOCK_ITERATOR itr(state, chunk_idx, 0); + + const auto sort_keys = FlatVector::GetData(sort_key_pointers); + for (idx_t i = 0; i < result_count; ++i) { + const auto idx = state.GetIndex(chunk_idx, result.get_index(i)); + sort_keys[i] = &itr[idx]; } - ++addr_count; - // Unswizzle the offsets back to pointers (if needed) - if (!sorted_data.layout.AllConstant() && state.external) { - heap_ptr = read_state.payload_heap_handle.Ptr(); - } + // Scan + chunk.Reset(); + scan_state.Scan(sorted_run, sort_key_pointers, result_count, chunk); +} - // Deserialize the payload data - auto sel = FlatVector::IncrementalSelectionVector(); - for (idx_t col_no = 0; col_no < sorted_data.layout.ColumnCount(); col_no++) { - auto &col = payload.data[left_cols + col_no]; - RowOperations::Gather(addresses, *sel, col, *sel, addr_count, sorted_data.layout, col_no, 0, heap_ptr); - col.Slice(gsel, result_count); +void PhysicalRangeJoin::SliceSortedPayload(DataChunk &chunk, GlobalSortedTable &table, + ExternalBlockIteratorState &state, TupleDataChunkState &chunk_state, + const idx_t chunk_idx, SelectionVector &result, const idx_t result_count, + SortedRunScanState &scan_state) { + auto &sorted = *table.sorted; + auto 
&sort_keys = chunk_state.row_locations; + const auto sort_key_type = table.GetSortKeyType(); + + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::NO_PAYLOAD_FIXED_16: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::NO_PAYLOAD_FIXED_24: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::NO_PAYLOAD_FIXED_32: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::PAYLOAD_FIXED_16: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::PAYLOAD_FIXED_24: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::PAYLOAD_FIXED_32: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::PAYLOAD_VARIABLE_32: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + default: + throw NotImplementedException("MergeJoinSimpleBlocks for %s", EnumUtil::ToString(sort_key_type)); } - - return std::move(read_state.payload_heap_handle); } idx_t PhysicalRangeJoin::SelectJoinTail(const ExpressionType &condition, Vector &left, Vector &right, diff --git a/src/duckdb/src/execution/operator/order/physical_top_n.cpp b/src/duckdb/src/execution/operator/order/physical_top_n.cpp index ec082601c..579bd189a 100644 --- a/src/duckdb/src/execution/operator/order/physical_top_n.cpp +++ b/src/duckdb/src/execution/operator/order/physical_top_n.cpp @@ -1,6 +1,7 @@ #include "duckdb/execution/operator/order/physical_top_n.hpp" #include "duckdb/common/assert.hpp" +#include "duckdb/common/arena_containers/arena_vector.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/function/create_sort_key.hpp" #include "duckdb/storage/data_table.hpp" @@ -85,7 +86,8 @@ class TopNHeap { Allocator &allocator; BufferManager &buffer_manager; - unsafe_vector heap; + ArenaAllocator arena_allocator; + unsafe_arena_vector heap; const vector &payload_types; const vector &orders; vector modifiers; @@ -162,10 +164,11 @@ class TopNHeap { //===--------------------------------------------------------------------===// TopNHeap::TopNHeap(ClientContext &context, Allocator &allocator, const vector &payload_types_p, const vector &orders_p, idx_t limit, idx_t offset) - : allocator(allocator), buffer_manager(BufferManager::GetBufferManager(context)), payload_types(payload_types_p), - orders(orders_p), limit(limit), offset(offset), heap_size(limit + offset), executor(context), - sort_key_heap(allocator), matching_sel(STANDARD_VECTOR_SIZE), final_sel(STANDARD_VECTOR_SIZE), - true_sel(STANDARD_VECTOR_SIZE), false_sel(STANDARD_VECTOR_SIZE), new_remaining_sel(STANDARD_VECTOR_SIZE) { + : allocator(allocator), buffer_manager(BufferManager::GetBufferManager(context)), arena_allocator(allocator), + heap(arena_allocator), payload_types(payload_types_p), 
orders(orders_p), limit(limit), offset(offset), + heap_size(limit + offset), executor(context), sort_key_heap(allocator), matching_sel(STANDARD_VECTOR_SIZE), + final_sel(STANDARD_VECTOR_SIZE), true_sel(STANDARD_VECTOR_SIZE), false_sel(STANDARD_VECTOR_SIZE), + new_remaining_sel(STANDARD_VECTOR_SIZE) { // initialize the executor and the sort_chunk vector sort_types; for (auto &order : orders) { diff --git a/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp b/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp index 95b519d4d..025af5ba1 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp @@ -90,7 +90,7 @@ class CollectionMerger { auto &collection = data_table.GetOptimisticCollection(context, collection_indexes[i]); TableScanState scan_state; scan_state.Initialize(column_ids); - collection.collection->InitializeScan(scan_state.local_state, column_ids, nullptr); + collection.collection->InitializeScan(context, scan_state.local_state, column_ids, nullptr); while (true) { scan_chunk.Reset(); @@ -194,7 +194,10 @@ class BatchInsertLocalState : public LocalSinkState { void CreateNewCollection(ClientContext &context, DuckTableEntry &table_entry, const vector &insert_types) { - auto collection = OptimisticDataWriter::CreateCollection(table_entry.GetStorage(), insert_types); + if (!optimistic_writer) { + optimistic_writer = make_uniq(context, table_entry.GetStorage()); + } + auto collection = optimistic_writer->CreateCollection(table_entry.GetStorage(), insert_types); auto &row_collection = *collection->collection; row_collection.InitializeEmpty(); row_collection.InitializeAppend(current_append_state); @@ -526,9 +529,6 @@ SinkResultType PhysicalBatchInsert::Sink(ExecutionContext &context, DataChunk &i lock_guard l(gstate.lock); // no collection yet: create a new one lstate.CreateNewCollection(context.client, table, insert_types); - if (!lstate.optimistic_writer) { - lstate.optimistic_writer = make_uniq(context.client, table.GetStorage()); - } } if (lstate.current_index != batch_index) { diff --git a/src/duckdb/src/execution/operator/persistent/physical_delete.cpp b/src/duckdb/src/execution/operator/persistent/physical_delete.cpp index 13458c923..4fb5df814 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_delete.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_delete.cpp @@ -27,7 +27,6 @@ class DeleteGlobalState : public GlobalSinkState { explicit DeleteGlobalState(ClientContext &context, const vector &return_types, TableCatalogEntry &table, const vector> &bound_constraints) : deleted_count(0), return_collection(context, return_types), has_unique_indexes(false) { - // We need to append deletes to the local delete-ART. 
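RangeJoinMaterializeTask earlier in this diff increments an atomic tasks_completed counter so that whichever task finishes last extracts the sorted run (or materializes an empty one). The last-finisher pattern on its own, with hypothetical names and plain threads:

#include <atomic>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

struct SharedMaterializeState {
	std::atomic<size_t> tasks_completed {0};
	size_t tasks_scheduled = 0;
	std::atomic<bool> finalized {false};
};

static void MaterializeTask(SharedMaterializeState &state) {
	// ... this task materializes its share of the sorted runs ...
	if (++state.tasks_completed == state.tasks_scheduled) {
		// Exactly one task observes the final count and runs the wrap-up step.
		state.finalized = true;
	}
}

int main() {
	SharedMaterializeState state;
	state.tasks_scheduled = 4;
	std::vector<std::thread> workers;
	for (size_t t = 0; t < state.tasks_scheduled; t++) {
		workers.emplace_back(MaterializeTask, std::ref(state));
	}
	for (auto &worker : workers) {
		worker.join();
	}
	return state.finalized ? 0 : 1;
}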
auto &storage = table.GetStorage(); if (storage.HasUniqueIndexes()) { diff --git a/src/duckdb/src/execution/operator/persistent/physical_insert.cpp b/src/duckdb/src/execution/operator/persistent/physical_insert.cpp index 97c31c4ba..25ac6576e 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_insert.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_insert.cpp @@ -36,7 +36,6 @@ PhysicalInsert::PhysicalInsert(PhysicalPlan &physical_plan, vector set_expressions(std::move(set_expressions)), set_columns(std::move(set_columns)), set_types(std::move(set_types)), on_conflict_condition(std::move(on_conflict_condition_p)), do_update_condition(std::move(do_update_condition_p)), conflict_target(std::move(conflict_target_p)), update_is_del_and_insert(update_is_del_and_insert) { - if (action_type == OnConflictAction::THROW) { return; } @@ -82,7 +81,6 @@ InsertGlobalState::InsertGlobalState(ClientContext &context, const vector &types, const vector> &bound_constraints) : collection_index(DConstants::INVALID_INDEX), bound_constraints(bound_constraints) { - auto &allocator = Allocator::Get(context); update_chunk.Initialize(allocator, types); append_chunk.Initialize(allocator, types); @@ -189,7 +187,6 @@ static void CombineExistingAndInsertTuples(DataChunk &result, DataChunk &scan_ch static void CreateUpdateChunk(ExecutionContext &context, DataChunk &chunk, TableCatalogEntry &table, Vector &row_ids, DataChunk &update_chunk, const PhysicalInsert &op) { - auto &do_update_condition = op.do_update_condition; auto &set_types = op.set_types; auto &set_expressions = op.set_expressions; @@ -651,14 +648,14 @@ SinkResultType PhysicalInsert::Sink(ExecutionContext &context, DataChunk &insert D_ASSERT(!return_chunk); auto &data_table = gstate.table.GetStorage(); if (!lstate.collection_index.IsValid()) { + lock_guard l(gstate.lock); + lstate.optimistic_writer = make_uniq(context.client, data_table); // Create the local row group collection. 
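TopNHeap above now backs its heap with an ArenaAllocator-owned vector, presumably so growth stops hitting the general allocator and the memory is released in one shot. DuckDB's arena types are its own; the closest standard-library analogue is std::pmr:

#include <cstdio>
#include <memory_resource>
#include <vector>

int main() {
	// All of the vector's allocations come out of one arena ...
	std::pmr::monotonic_buffer_resource arena;
	std::pmr::vector<int> heap(&arena);
	for (int i = 0; i < 1024; i++) {
		heap.push_back(i);
	}
	std::printf("%zu entries\n", heap.size());
	return 0;
	// ... and the arena releases every block at once when it goes out of scope.
}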
- auto optimistic_collection = OptimisticDataWriter::CreateCollection(storage, insert_types); + auto optimistic_collection = lstate.optimistic_writer->CreateCollection(storage, insert_types); auto &collection = *optimistic_collection->collection; collection.InitializeEmpty(); collection.InitializeAppend(lstate.local_append_state); - lock_guard l(gstate.lock); - lstate.optimistic_writer = make_uniq(context.client, data_table); lstate.collection_index = data_table.CreateOptimisticCollection(context.client, std::move(optimistic_collection)); } diff --git a/src/duckdb/src/execution/operator/persistent/physical_merge_into.cpp b/src/duckdb/src/execution/operator/persistent/physical_merge_into.cpp index 04a5f3dca..64f02cbd4 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_merge_into.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_merge_into.cpp @@ -10,7 +10,6 @@ PhysicalMergeInto::PhysicalMergeInto(PhysicalPlan &physical_plan, vector ranges; for (auto &entry : actions_p) { MergeActionRange range; diff --git a/src/duckdb/src/execution/operator/persistent/physical_update.cpp b/src/duckdb/src/execution/operator/persistent/physical_update.cpp index f96dba699..8f9d7ecab 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_update.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_update.cpp @@ -25,7 +25,6 @@ PhysicalUpdate::PhysicalUpdate(PhysicalPlan &physical_plan, vector tableref(tableref), table(table), columns(std::move(columns)), expressions(std::move(expressions)), bound_defaults(std::move(bound_defaults)), bound_constraints(std::move(bound_constraints)), return_chunk(return_chunk), index_update(false) { - auto &indexes = table.GetDataTableInfo().get()->GetIndexes(); auto index_columns = indexes.GetRequiredColumns(); @@ -67,7 +66,6 @@ class UpdateLocalState : public LocalSinkState { const vector &table_types, const vector> &bound_defaults, const vector> &bound_constraints) : default_executor(context, bound_defaults), bound_constraints(bound_constraints) { - // Initialize the update chunk. 
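Both insert paths above now take the global lock before creating the OptimisticDataWriter and registering the local collection, and create the writer lazily on the first chunk. A bare-bones sketch of that shape, with hypothetical Writer/Collection types:

#include <memory>
#include <mutex>
#include <vector>

struct Writer {};     // stands in for the optimistic writer
struct Collection {}; // stands in for a thread-local row group collection

struct GlobalSink {
	std::mutex lock;
	std::vector<std::unique_ptr<Collection>> collections;
};

struct LocalSink {
	std::unique_ptr<Writer> writer; // created lazily on the first chunk
	Collection *collection = nullptr;

	void EnsureCollection(GlobalSink &gstate) {
		if (collection) {
			return;
		}
		// Take the shared lock before creating the writer and registering the
		// collection, so both stay consistent with the global state.
		std::lock_guard<std::mutex> guard(gstate.lock);
		if (!writer) {
			writer = std::make_unique<Writer>();
		}
		gstate.collections.push_back(std::make_unique<Collection>());
		collection = gstate.collections.back().get();
	}
};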
auto &allocator = Allocator::Get(context); vector update_types; diff --git a/src/duckdb/src/execution/operator/projection/physical_pivot.cpp b/src/duckdb/src/execution/operator/projection/physical_pivot.cpp index ffd9ec565..68aa95b18 100644 --- a/src/duckdb/src/execution/operator/projection/physical_pivot.cpp +++ b/src/duckdb/src/execution/operator/projection/physical_pivot.cpp @@ -9,7 +9,6 @@ PhysicalPivot::PhysicalPivot(PhysicalPlan &physical_plan, vector ty BoundPivotInfo bound_pivot_p) : PhysicalOperator(physical_plan, PhysicalOperatorType::PIVOT, std::move(types_p), child.estimated_cardinality), bound_pivot(std::move(bound_pivot_p)) { - children.push_back(child); for (idx_t p = 0; p < bound_pivot.pivot_values.size(); p++) { auto entry = pivot_map.find(bound_pivot.pivot_values[p]); diff --git a/src/duckdb/src/execution/operator/projection/physical_unnest.cpp b/src/duckdb/src/execution/operator/projection/physical_unnest.cpp index 62164d95b..e1ce4bb05 100644 --- a/src/duckdb/src/execution/operator/projection/physical_unnest.cpp +++ b/src/duckdb/src/execution/operator/projection/physical_unnest.cpp @@ -13,7 +13,6 @@ class UnnestOperatorState : public OperatorState { public: UnnestOperatorState(ClientContext &context, const vector> &select_list) : current_row(0), list_position(0), first_fetch(true), input_sel(STANDARD_VECTOR_SIZE), executor(context) { - // for each UNNEST in the select_list, we add the child expression to the expression executor // and set the return type in the list_data chunk, which will contain the evaluated expression results vector list_data_types; @@ -139,7 +138,6 @@ OperatorResultType PhysicalUnnest::ExecuteInternal(ExecutionContext &context, Da OperatorState &state_p, const vector> &select_list, bool include_input) { - auto &state = state_p.Cast(); do { diff --git a/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp index bff24d785..3ea2328da 100644 --- a/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp +++ b/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp @@ -14,7 +14,6 @@ PhysicalPositionalScan::PhysicalPositionalScan(PhysicalPlan &physical_plan, vect PhysicalOperator &left, PhysicalOperator &right) : PhysicalOperator(physical_plan, PhysicalOperatorType::POSITIONAL_SCAN, std::move(types), MaxValue(left.estimated_cardinality, right.estimated_cardinality)) { - // Manage the children ourselves if (left.type == PhysicalOperatorType::TABLE_SCAN) { child_tables.emplace_back(left); diff --git a/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp index e9f66bea4..4d86eace7 100644 --- a/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp +++ b/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp @@ -4,6 +4,9 @@ #include "duckdb/common/string_util.hpp" #include "duckdb/planner/expression/bound_conjunction_expression.hpp" #include "duckdb/transaction/transaction.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/execution/physical_table_scan_enum.hpp" +#include "duckdb/main/settings.hpp" #include @@ -16,6 +19,7 @@ PhysicalTableScan::PhysicalTableScan(PhysicalPlan &physical_plan, vector parameters_p, virtual_column_map_t virtual_columns_p) : PhysicalOperator(physical_plan, PhysicalOperatorType::TABLE_SCAN, std::move(types), estimated_cardinality), + function(std::move(function_p)), bind_data(std::move(bind_data_p)), 
returned_types(std::move(returned_types_p)), column_ids(std::move(column_ids_p)), projection_ids(std::move(projection_ids_p)), names(std::move(names_p)), table_filters(std::move(table_filters_p)), extra_info(std::move(extra_info)), parameters(std::move(parameters_p)), @@ -25,6 +29,9 @@ PhysicalTableScan::PhysicalTableScan(PhysicalPlan &physical_plan, vector(context); + if (op.dynamic_filters && op.dynamic_filters->HasFilters()) { table_filters = op.dynamic_filters->GetFinalTableFilters(op, op.table_filters.get()); } @@ -56,6 +63,7 @@ class TableScanGlobalSourceState : public GlobalSourceState { } idx_t max_threads = 0; + PhysicalTableScanExecutionStrategy physical_table_scan_execution_strategy; unique_ptr global_state; bool in_out_final = false; DataChunk input_chunk; @@ -93,6 +101,61 @@ unique_ptr PhysicalTableScan::GetGlobalSourceState(ClientCont return make_uniq(context, *this); } +static void ValidateAsyncStrategyResult(const PhysicalTableScanExecutionStrategy &strategy, + const AsyncResultsExecutionMode &execution_mode_pre, + const AsyncResultsExecutionMode &execution_mode_post, + const AsyncResultType &result_pre, const AsyncResultType &result_post, + const idx_t output_chunk_size) { + auto execution_mode_pre_computed = AsyncResult::ConvertToAsyncResultExecutionMode(strategy); + if (execution_mode_pre_computed != execution_mode_pre) { + throw InternalException("ValidateAsyncStrategyResult: invalid conversion PhysicalTableScanExecutionStrategy to " + "AsyncResultsExecutionMode, from '%s', to '%s'", + EnumUtil::ToChars(strategy), EnumUtil::ToChars(execution_mode_pre)); + } + + if (execution_mode_pre != execution_mode_post) { + throw InternalException("ValidateAsyncStrategyResult: results_execution_mode changed within table API's " + "`function` call, before '%s', after '%s'", + EnumUtil::ToChars(execution_mode_pre), EnumUtil::ToChars(execution_mode_post)); + } + if (result_pre != AsyncResultType::IMPLICIT) { + throw InternalException("ValidateAsyncStrategyResult: async_result is supposed to be IMPLICIT, was '%s', " + "before table API's `function` call", + EnumUtil::ToChars(result_pre)); + } + switch (strategy) { + case PhysicalTableScanExecutionStrategy::TASK_EXECUTOR_BUT_FORCE_SYNC_CHECKS: + // This is a funny one, expected to throw on non-trivial workflows in this function + case PhysicalTableScanExecutionStrategy::SYNCHRONOUS: + switch (result_post) { + case AsyncResultType::INVALID: + throw InternalException("ValidateAsyncStrategyResult: found INVALID"); + case AsyncResultType::BLOCKED: + throw InternalException("ValidateAsyncStrategyResult: found BLOCKED"); + case AsyncResultType::FINISHED: + if (output_chunk_size > 0) { + throw InternalException("ValidateAsyncStrategyResult: found FINISHED with non-empty chunk"); + } + break; + case AsyncResultType::HAVE_MORE_OUTPUT: + if (output_chunk_size == 0) { + throw InternalException("ValidateAsyncStrategyResult: found HAVE_MORE_OUTPUT with empty chunk"); + } + break; + case AsyncResultType::IMPLICIT: + break; + } + break; + default: + if (result_post == AsyncResultType::BLOCKED) { + if (output_chunk_size > 0) { + throw InternalException("ValidateAsyncStrategyResult: found BLOCKED with non-empty chunk"); + } + } + break; + } +} + SourceResultType PhysicalTableScan::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { D_ASSERT(!column_ids.empty()); @@ -102,15 +165,55 @@ SourceResultType PhysicalTableScan::GetData(ExecutionContext &context, DataChunk TableFunctionInput data(bind_data.get(), 
l_state.local_state.get(), g_state.global_state.get()); if (function.function) { + data.async_result = AsyncResultType::IMPLICIT; + + const auto initial_async_result = data.async_result.GetResultType(); + const auto execution_strategy = g_state.physical_table_scan_execution_strategy; + const auto input_execution_mode = AsyncResult::ConvertToAsyncResultExecutionMode(execution_strategy); + data.results_execution_mode = input_execution_mode; + + // Actually call the function function.function(context.client, data, chunk); - return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; + + const auto output_async_result = data.async_result.GetResultType(); + + // Compare and check whether state before and after function.function call is compatible, will throw in case of + // inconsistencies + ValidateAsyncStrategyResult(execution_strategy, input_execution_mode, data.results_execution_mode, + initial_async_result, output_async_result, chunk.size()); + + // Handle results + switch (output_async_result) { + case AsyncResultType::BLOCKED: { + D_ASSERT(data.async_result.HasTasks()); + auto guard = g_state.Lock(); + if (g_state.CanBlock(guard)) { + data.async_result.ScheduleTasks(input.interrupt_state, context.pipeline->executor); + return SourceResultType::BLOCKED; + } + return SourceResultType::FINISHED; + } + case AsyncResultType::IMPLICIT: + if (chunk.size() > 0) { + return SourceResultType::HAVE_MORE_OUTPUT; + } + return SourceResultType::FINISHED; + case AsyncResultType::FINISHED: + return SourceResultType::FINISHED; + case AsyncResultType::HAVE_MORE_OUTPUT: + return SourceResultType::HAVE_MORE_OUTPUT; + default: + throw InternalException( + "PhysicalTableScan::GetData call of function.function returned unexpected return '%'", + EnumUtil::ToChars(data.async_result.GetResultType())); + } + throw InternalException("PhysicalTableScan::GetData hasn't handled a function.function return"); } if (g_state.in_out_final) { function.in_out_function_final(context, data, chunk); } switch (function.in_out_function(context, data, g_state.input_chunk, chunk)) { - case OperatorResultType::BLOCKED: { auto guard = g_state.Lock(); return g_state.BlockSource(guard, input.interrupt_state); diff --git a/src/duckdb/src/execution/operator/schema/physical_attach.cpp b/src/duckdb/src/execution/operator/schema/physical_attach.cpp index 48e687703..df066c5a0 100644 --- a/src/duckdb/src/execution/operator/schema/physical_attach.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_attach.cpp @@ -40,7 +40,6 @@ SourceResultType PhysicalAttach::GetData(ExecutionContext &context, DataChunk &c if (existing_db) { if ((existing_db->IsReadOnly() && options.access_mode == AccessMode::READ_WRITE) || (!existing_db->IsReadOnly() && options.access_mode == AccessMode::READ_ONLY)) { - auto existing_mode = existing_db->IsReadOnly() ? 
AccessMode::READ_ONLY : AccessMode::READ_WRITE; auto existing_mode_str = EnumUtil::ToString(existing_mode); auto attached_mode = EnumUtil::ToString(options.access_mode); diff --git a/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp b/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp index d21b7bcf1..9e92a787a 100644 --- a/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp @@ -23,7 +23,6 @@ PhysicalCreateARTIndex::PhysicalCreateARTIndex(PhysicalPlan &physical_plan, Logi : PhysicalOperator(physical_plan, PhysicalOperatorType::CREATE_INDEX, op.types, estimated_cardinality), table(table_p.Cast()), info(std::move(info)), unbound_expressions(std::move(unbound_expressions)), sorted(sorted), alter_table_info(std::move(alter_table_info)) { - // Convert the logical column ids to physical column ids. for (auto &column_id : column_ids) { storage_ids.push_back(table.GetColumns().LogicalToPhysical(LogicalIndex(column_id)).index); @@ -85,7 +84,6 @@ unique_ptr PhysicalCreateARTIndex::GetLocalSinkState(ExecutionCo } SinkResultType PhysicalCreateARTIndex::SinkUnsorted(OperatorSinkInput &input) const { - auto &l_state = input.local_state.Cast(); auto row_count = l_state.key_chunk.size(); auto &art = l_state.local_index->Cast(); @@ -105,7 +103,6 @@ SinkResultType PhysicalCreateARTIndex::SinkUnsorted(OperatorSinkInput &input) co } SinkResultType PhysicalCreateARTIndex::SinkSorted(OperatorSinkInput &input) const { - auto &l_state = input.local_state.Cast(); auto &storage = table.GetStorage(); auto &l_index = l_state.local_index; @@ -172,7 +169,7 @@ SinkFinalizeType PhysicalCreateARTIndex::Finalize(Pipeline &pipeline, Event &eve // Vacuum excess memory and verify. 
state.global_index->Vacuum(); - D_ASSERT(!state.global_index->VerifyAndToString(true).empty()); + state.global_index->Verify(); state.global_index->VerifyAllocations(); auto &storage = table.GetStorage(); diff --git a/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp b/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp index 79420e902..9fcbd3876 100644 --- a/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp +++ b/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp @@ -34,7 +34,6 @@ class RecursiveCTEState : public GlobalSinkState { public: explicit RecursiveCTEState(ClientContext &context, const PhysicalRecursiveCTE &op) : intermediate_table(context, op.GetTypes()), new_groups(STANDARD_VECTOR_SIZE) { - vector payload_aggregates_ptr; for (idx_t i = 0; i < op.payload_aggregates.size(); i++) { auto &dat = op.payload_aggregates[i]; diff --git a/src/duckdb/src/execution/physical_operator.cpp b/src/duckdb/src/execution/physical_operator.cpp index ad51afa31..48d7118e7 100644 --- a/src/duckdb/src/execution/physical_operator.cpp +++ b/src/duckdb/src/execution/physical_operator.cpp @@ -301,7 +301,6 @@ bool CachingPhysicalOperator::CanCacheType(const LogicalType &type) { CachingPhysicalOperator::CachingPhysicalOperator(PhysicalPlan &physical_plan, PhysicalOperatorType type, vector types_p, idx_t estimated_cardinality) : PhysicalOperator(physical_plan, type, std::move(types_p), estimated_cardinality) { - caching_supported = true; for (auto &col_type : types) { if (!CanCacheType(col_type)) { diff --git a/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp b/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp index 5759583c5..3d84506fa 100644 --- a/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp @@ -13,13 +13,13 @@ #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" #include "duckdb/main/settings.hpp" namespace duckdb { optional_ptr PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOperator &probe, PhysicalOperator &build) { - // Plan a inverse nested loop join, then aggregate the values to choose the optimal match for each probe row. // Use a row number primary key to handle duplicate probe values. // aggregate the fields to produce at most one match per probe row, @@ -27,7 +27,7 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera // // ∏ * \ pk // | - // Γ pk;first(P),arg_xxx(B,inequality) + // Γ pk;first(P),arg_xxx_null(B,inequality) // | // ∏ *,inequality // | @@ -43,10 +43,9 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera const auto &probe_types = op.children[0]->types; join_op.types.insert(join_op.types.end(), probe_types.begin(), probe_types.end()); - // TODO: We can't handle predicates right now because we would have to remap column references. 
- if (op.predicate) { - return nullptr; - } + // Project pk + LogicalType pk_type = LogicalType::BIGINT; + join_op.types.emplace_back(pk_type); // Fill in the projection maps to simplify the code below // Since NLJ doesn't support projection, but ASOF does, @@ -65,9 +64,25 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera } } - // Project pk - LogicalType pk_type = LogicalType::BIGINT; - join_op.types.emplace_back(pk_type); + // Remap predicate column references. + if (op.predicate) { + vector swap_projection_map; + const auto rhs_width = op.children[1]->types.size(); + for (const auto &l : join_op.right_projection_map) { + swap_projection_map.emplace_back(l + rhs_width); + } + for (const auto &r : join_op.left_projection_map) { + swap_projection_map.emplace_back(r); + } + join_op.predicate = op.predicate->Copy(); + ExpressionIterator::EnumerateExpression(join_op.predicate, [&](Expression &child) { + if (child.GetExpressionClass() == ExpressionClass::BOUND_REF) { + auto &col_idx = child.Cast().index; + const auto new_idx = swap_projection_map[col_idx]; + col_idx = new_idx; + } + }); + } auto binder = Binder::CreateBinder(context); FunctionBinder function_binder(*binder); @@ -88,13 +103,13 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera case ExpressionType::COMPARE_GREATERTHAN: D_ASSERT(asof_idx == op.conditions.size()); asof_idx = i; - arg_min_max = "arg_max"; + arg_min_max = "arg_max_null"; break; case ExpressionType::COMPARE_LESSTHANOREQUALTO: case ExpressionType::COMPARE_LESSTHAN: D_ASSERT(asof_idx == op.conditions.size()); asof_idx = i; - arg_min_max = "arg_min"; + arg_min_max = "arg_min_null"; break; case ExpressionType::COMPARE_EQUAL: case ExpressionType::COMPARE_NOTEQUAL: @@ -208,7 +223,7 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera auto window_types = probe.GetTypes(); window_types.emplace_back(pk_type); - idx_t probe_cardinality = op.children[0]->EstimateCardinality(context); + const auto probe_cardinality = op.EstimateCardinality(context); auto &window = Make(window_types, std::move(window_select), probe_cardinality); window.children.emplace_back(probe); @@ -275,10 +290,12 @@ PhysicalOperator &PhysicalPlanGenerator::PlanAsOfJoin(LogicalComparisonJoin &op) } D_ASSERT(asof_idx < op.conditions.size()); - bool force_asof_join = DBConfig::GetSetting(context); - if (!force_asof_join) { - idx_t asof_join_threshold = DBConfig::GetSetting(context); - if (op.children[0]->has_estimated_cardinality && lhs_cardinality < asof_join_threshold) { + // If there is a non-comparison predicate, we have to use NLJ. 
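The predicate handling above walks the copied expression with ExpressionIterator and rewrites each BoundReferenceExpression index through swap_projection_map, because the loop join is built with the sides swapped. The same remap over a toy expression tree:

#include <cstddef>
#include <memory>
#include <vector>

// A toy expression node: either a column reference or an operator with children.
struct Expr {
	bool is_column_ref = false;
	size_t column_index = 0;
	std::vector<std::unique_ptr<Expr>> children;
};

// Rewrite every column reference through `mapping`, e.g. when the probe and build
// sides are swapped and the predicate must follow the new column layout.
static void RemapColumnRefs(Expr &expr, const std::vector<size_t> &mapping) {
	if (expr.is_column_ref) {
		expr.column_index = mapping[expr.column_index];
		return;
	}
	for (auto &child : expr.children) {
		RemapColumnRefs(*child, mapping);
	}
}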
+ const bool has_predicate = op.predicate.get(); + const bool force_asof_join = DBConfig::GetSetting(context); + if (!force_asof_join || has_predicate) { + const idx_t asof_join_threshold = DBConfig::GetSetting(context); + if (has_predicate || (op.children[0]->has_estimated_cardinality && lhs_cardinality < asof_join_threshold)) { auto result = PlanAsOfLoopJoin(op, left, right); if (result) { return *result; diff --git a/src/duckdb/src/execution/radix_partitioned_hashtable.cpp b/src/duckdb/src/execution/radix_partitioned_hashtable.cpp index 7265a23e4..eb40a50cb 100644 --- a/src/duckdb/src/execution/radix_partitioned_hashtable.cpp +++ b/src/duckdb/src/execution/radix_partitioned_hashtable.cpp @@ -208,7 +208,6 @@ RadixHTGlobalSinkState::RadixHTGlobalSinkState(ClientContext &context_p, const R any_combined(false), radix_ht(radix_ht_p), config(*this), stored_allocators_size(0), finalize_done(0), scan_pin_properties(TupleDataPinProperties::DESTROY_AFTER_DONE), count_before_combining(0), max_partition_size(0) { - // Compute minimum reservation auto block_alloc_size = BufferManager::GetBufferManager(context).GetBlockAllocSize(); auto tuples_per_block = block_alloc_size / radix_ht.GetLayout().GetRowWidth(); diff --git a/src/duckdb/src/execution/sample/base_reservoir_sample.cpp b/src/duckdb/src/execution/sample/base_reservoir_sample.cpp index 35de6d54f..3be480eaf 100644 --- a/src/duckdb/src/execution/sample/base_reservoir_sample.cpp +++ b/src/duckdb/src/execution/sample/base_reservoir_sample.cpp @@ -60,7 +60,6 @@ void BaseReservoirSampling::SetNextEntry() { } void BaseReservoirSampling::ReplaceElementWithIndex(idx_t entry_index, double with_weight, bool pop) { - if (pop) { reservoir_weights.pop(); } diff --git a/src/duckdb/src/execution/sample/reservoir_sample.cpp b/src/duckdb/src/execution/sample/reservoir_sample.cpp index cb52f3f2b..a603bc19f 100644 --- a/src/duckdb/src/execution/sample/reservoir_sample.cpp +++ b/src/duckdb/src/execution/sample/reservoir_sample.cpp @@ -190,7 +190,6 @@ void ReservoirSample::Vacuum() { } unique_ptr ReservoirSample::Copy() const { - auto ret = make_uniq(sample_count); ret->stats_sample = stats_sample; @@ -271,7 +270,7 @@ void ReservoirSample::SimpleMerge(ReservoirSample &other) { auto weight_tuples_this = static_cast(GetTuplesSeen()) / static_cast(total_seen); auto weight_tuples_other = static_cast(other.GetTuplesSeen()) / static_cast(total_seen); - // If weights don't add up to 1, most likely a simple merge occured and no new samples were added. + // If weights don't add up to 1, most likely a simple merge occurred and no new samples were added. // if that is the case, add the missing weight to the lower weighted sample to adjust. // this is to avoid cases where if you have a 20k row table and add another 20k rows row by row // then eventually the missing weights will add up, and get you a more even distribution @@ -564,7 +563,6 @@ T ReservoirSample::GetReservoirChunkCapacity() const { } idx_t ReservoirSample::FillReservoir(DataChunk &chunk) { - idx_t ingested_count = 0; if (!reservoir_chunk) { if (chunk.size() > FIXED_SAMPLE_SIZE) { @@ -609,7 +607,6 @@ SelectionVectorHelper ReservoirSample::GetReplacementIndexes(idx_t sample_chunk_ } SelectionVectorHelper ReservoirSample::GetReplacementIndexesFast(idx_t sample_chunk_offset, idx_t chunk_length) { - // how much weight to the other tuples have compared to the ones in this chunk? 
auto weight_tuples_other = static_cast(chunk_length) / static_cast(GetTuplesSeen() + chunk_length); auto num_to_pop = static_cast(round(weight_tuples_other * static_cast(sample_count))); diff --git a/src/duckdb/src/function/aggregate/distributive/count.cpp b/src/duckdb/src/function/aggregate/distributive/count.cpp index 41af395aa..665fc70ce 100644 --- a/src/duckdb/src/function/aggregate/distributive/count.cpp +++ b/src/duckdb/src/function/aggregate/distributive/count.cpp @@ -239,7 +239,7 @@ AggregateFunction CountFunctionBase::GetFunction() { AggregateFunction CountStarFun::GetFunction() { auto fun = AggregateFunction::NullaryAggregate(LogicalType::BIGINT); fun.name = "count_star"; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; fun.window = CountStarFunction::Window; return fun; diff --git a/src/duckdb/src/function/aggregate/distributive/first_last_any.cpp b/src/duckdb/src/function/aggregate/distributive/first_last_any.cpp index 442eec461..2873195be 100644 --- a/src/duckdb/src/function/aggregate/distributive/first_last_any.cpp +++ b/src/duckdb/src/function/aggregate/distributive/first_last_any.cpp @@ -213,7 +213,7 @@ struct FirstVectorFunction : FirstFunctionStringBase { static unique_ptr Bind(ClientContext &context, AggregateFunction &function, vector> &arguments) { function.arguments[0] = arguments[0]->return_type; - function.return_type = arguments[0]->return_type; + function.SetReturnType(arguments[0]->return_type); return nullptr; } }; @@ -260,7 +260,7 @@ AggregateFunction GetFirstFunction(const LogicalType &type) { type.Verify(); AggregateFunction function = GetDecimalFirstFunction(type); function.arguments[0] = type; - function.return_type = type; + function.SetReturnType(type); return function; } switch (type.InternalType()) { @@ -318,7 +318,7 @@ unique_ptr BindDecimalFirst(ClientContext &context, AggregateFunct function = GetFirstFunction(decimal_type); function.name = std::move(name); function.distinct_dependent = AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT; - function.return_type = decimal_type; + function.SetReturnType(decimal_type); return nullptr; } diff --git a/src/duckdb/src/function/aggregate/distributive/minmax.cpp b/src/duckdb/src/function/aggregate/distributive/minmax.cpp index ce5ef12af..d56101cad 100644 --- a/src/duckdb/src/function/aggregate/distributive/minmax.cpp +++ b/src/duckdb/src/function/aggregate/distributive/minmax.cpp @@ -296,7 +296,7 @@ struct VectorMinMaxBase { static unique_ptr Bind(ClientContext &context, AggregateFunction &function, vector> &arguments) { function.arguments[0] = arguments[0]->return_type; - function.return_type = arguments[0]->return_type; + function.SetReturnType(arguments[0]->return_type); return nullptr; } }; @@ -367,8 +367,8 @@ unique_ptr BindMinMax(ClientContext &context, AggregateFunction &f // Bind function like arg_min/arg_max. 
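// --- Illustrative sketch (not part of the patch) ---
// The bind callbacks above (first/last, vector min/max) all resolve the
// aggregate's argument and return type from the first bound argument, now via
// the SetReturnType accessor instead of assigning return_type directly. A
// minimal version of that pattern; the name BindFromFirstArgument is
// hypothetical.
#include "duckdb/function/aggregate_function.hpp"
#include "duckdb/planner/expression.hpp"

static duckdb::unique_ptr<duckdb::FunctionData>
BindFromFirstArgument(duckdb::ClientContext &context, duckdb::AggregateFunction &function,
                      duckdb::vector<duckdb::unique_ptr<duckdb::Expression>> &arguments) {
	// take both the argument type and the return type from the first argument
	function.arguments[0] = arguments[0]->return_type;
	function.SetReturnType(arguments[0]->return_type);
	return nullptr;
}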
function.arguments[0] = arguments[0]->return_type; - function.return_type = arguments[0]->return_type; - return nullptr; + function.SetReturnType(arguments[0]->return_type); + return make_uniq(); } } @@ -431,7 +431,6 @@ class MinMaxNState { template void MinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector, idx_t count) { - auto &val_vector = inputs[0]; auto &n_vector = inputs[1]; @@ -441,7 +440,7 @@ void MinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_ auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(val_vector, count); - STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format); + STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format, true); n_vector.ToUnifiedFormat(count, n_format); state_vector.ToUnifiedFormat(count, state_format); @@ -520,7 +519,6 @@ void SpecializeMinMaxNFunction(PhysicalType arg_type, AggregateFunction &functio template unique_ptr MinMaxNBind(ClientContext &context, AggregateFunction &function, vector> &arguments) { - for (auto &arg : arguments) { if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { throw ParameterNotResolvedException(); @@ -532,7 +530,7 @@ unique_ptr MinMaxNBind(ClientContext &context, AggregateFunction & // Specialize the function based on the input types SpecializeMinMaxNFunction(val_type, function); - function.return_type = LogicalType::LIST(arguments[0]->return_type); + function.SetReturnType(LogicalType::LIST(arguments[0]->return_type)); return nullptr; } diff --git a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp index bde4c1479..193898272 100644 --- a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp +++ b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp @@ -10,6 +10,7 @@ #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/parser/expression_map.hpp" #include "duckdb/parallel/thread_context.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/main/settings.hpp" namespace duckdb { @@ -24,7 +25,6 @@ struct SortedAggregateBindData : public FunctionData { BindInfoPtr &bind_info, OrderBys &order_bys) : context(context), function(aggregate), bind_info(std::move(bind_info)), threshold(DBConfig::GetSetting(context)) { - // Describe the arguments. 
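// --- Illustrative sketch (not part of the patch) ---
// Summary of the accessor migration this change applies throughout the tree:
// direct writes to return_type / null_handling / errors / stability are
// replaced by setters, and reads by getters. The function name "example_fun"
// below is hypothetical and only shows the before/after shape of the API.
#include "duckdb/function/scalar_function.hpp"

static duckdb::ScalarFunction MakeExampleFunction(duckdb::scalar_function_t fun) {
	using namespace duckdb;
	ScalarFunction result("example_fun", {LogicalType::VARCHAR}, LogicalType::VARCHAR, std::move(fun));
	result.SetReturnType(LogicalType::VARCHAR);                     // was: result.return_type = ...
	result.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); // was: result.null_handling = ...
	result.SetFallible();                                           // was: BaseScalarFunction::SetReturnsError(result)
	result.SetVolatile();                                           // was: result.stability = FunctionStability::VOLATILE
	auto return_type = result.GetReturnType();                      // was: result.return_type
	(void)return_type;
	return result;
}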
for (const auto &child : children) { buffered_cols.emplace_back(buffered_cols.size()); @@ -433,7 +433,6 @@ struct SortedAggregateFunction { static void ProjectInputs(Vector inputs[], const SortedAggregateBindData &order_bind, idx_t input_count, idx_t count, DataChunk &buffered) { - // Only reference the buffered columns buffered.InitializeEmpty(order_bind.buffered_types); const auto &buffered_cols = order_bind.buffered_cols; @@ -709,13 +708,14 @@ void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateE // Replace the aggregate with the wrapper AggregateFunction ordered_aggregate( - bound_function.name, arguments, bound_function.return_type, AggregateFunction::StateSize, + bound_function.name, arguments, bound_function.GetReturnType(), + AggregateFunction::StateSize, AggregateFunction::StateInitialize, SortedAggregateFunction::ScatterUpdate, AggregateFunction::StateCombine, - SortedAggregateFunction::Finalize, bound_function.null_handling, SortedAggregateFunction::SimpleUpdate, nullptr, - AggregateFunction::StateDestroy, nullptr, + SortedAggregateFunction::Finalize, bound_function.GetNullHandling(), SortedAggregateFunction::SimpleUpdate, + nullptr, AggregateFunction::StateDestroy, nullptr, SortedAggregateFunction::Window); expr.function = std::move(ordered_aggregate); @@ -765,12 +765,12 @@ void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundWindowExpr // Replace the aggregate with the wrapper AggregateFunction ordered_aggregate( - aggregate.name, arguments, aggregate.return_type, AggregateFunction::StateSize, + aggregate.name, arguments, aggregate.GetReturnType(), AggregateFunction::StateSize, AggregateFunction::StateInitialize, SortedAggregateFunction::ScatterUpdate, AggregateFunction::StateCombine, - SortedAggregateFunction::Finalize, aggregate.null_handling, SortedAggregateFunction::SimpleUpdate, nullptr, + SortedAggregateFunction::Finalize, aggregate.GetNullHandling(), SortedAggregateFunction::SimpleUpdate, nullptr, AggregateFunction::StateDestroy, nullptr, SortedAggregateFunction::Window); diff --git a/src/duckdb/src/function/cast/array_casts.cpp b/src/duckdb/src/function/cast/array_casts.cpp index 2357a2c2c..ad243ef2c 100644 --- a/src/duckdb/src/function/cast/array_casts.cpp +++ b/src/duckdb/src/function/cast/array_casts.cpp @@ -39,7 +39,6 @@ unique_ptr ArrayBoundCastData::InitArrayLocalState(CastLocal // ARRAY -> ARRAY //------------------------------------------------------------------------------ static bool ArrayToArrayCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto source_array_size = ArrayType::GetSize(source.GetType()); auto target_array_size = ArrayType::GetSize(result.GetType()); if (source_array_size != target_array_size) { diff --git a/src/duckdb/src/function/cast/cast_function_set.cpp b/src/duckdb/src/function/cast/cast_function_set.cpp index 4e6ed8b99..606fa9010 100644 --- a/src/duckdb/src/function/cast/cast_function_set.cpp +++ b/src/duckdb/src/function/cast/cast_function_set.cpp @@ -184,7 +184,9 @@ int64_t CastFunctionSet::ImplicitCastCost(optional_ptr context, c old_implicit_casting = DBConfig::GetSetting(*config); } if (old_implicit_casting) { - score = 149; + // very high cost to avoid choosing this cast if any other option is available + // (it should be more costly than casting to TEMPLATE if that is available) + score = 10000000000; } } return score; diff --git a/src/duckdb/src/function/cast/default_casts.cpp b/src/duckdb/src/function/cast/default_casts.cpp index 0c0c1c058..558329f70 
100644 --- a/src/duckdb/src/function/cast/default_casts.cpp +++ b/src/duckdb/src/function/cast/default_casts.cpp @@ -162,6 +162,8 @@ BoundCastInfo DefaultCasts::GetDefaultCastFunction(BindCastInput &input, const L return EnumCastSwitch(input, source, target); case LogicalTypeId::ARRAY: return ArrayCastSwitch(input, source, target); + case LogicalTypeId::GEOMETRY: + return GeoCastSwitch(input, source, target); case LogicalTypeId::BIGNUM: return BignumCastSwitch(input, source, target); case LogicalTypeId::AGGREGATE_STATE: diff --git a/src/duckdb/src/function/cast/geo_casts.cpp b/src/duckdb/src/function/cast/geo_casts.cpp new file mode 100644 index 000000000..59595359f --- /dev/null +++ b/src/duckdb/src/function/cast/geo_casts.cpp @@ -0,0 +1,23 @@ +#include "duckdb/common/types/geometry.hpp" +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/function/cast/vector_cast_helpers.hpp" + +namespace duckdb { + +static bool GeometryToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + UnaryExecutor::Execute( + source, result, count, [&](const string_t &input) -> string_t { return Geometry::ToString(result, input); }); + return true; +} + +BoundCastInfo DefaultCasts::GeoCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::VARCHAR: + return GeometryToVarcharCast; + default: + return TryVectorNullCast; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/string_cast.cpp b/src/duckdb/src/function/cast/string_cast.cpp index 511d09a86..930231808 100644 --- a/src/duckdb/src/function/cast/string_cast.cpp +++ b/src/duckdb/src/function/cast/string_cast.cpp @@ -490,6 +490,8 @@ BoundCastInfo DefaultCasts::StringCastSwitch(BindCastInput &input, const Logical return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); case LogicalTypeId::UUID: return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); + case LogicalTypeId::GEOMETRY: + return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); case LogicalTypeId::SQLNULL: return &DefaultCasts::TryVectorNullCast; case LogicalTypeId::VARCHAR: diff --git a/src/duckdb/src/function/cast/struct_cast.cpp b/src/duckdb/src/function/cast/struct_cast.cpp index 97a9354d1..12c60bd75 100644 --- a/src/duckdb/src/function/cast/struct_cast.cpp +++ b/src/duckdb/src/function/cast/struct_cast.cpp @@ -12,8 +12,8 @@ unique_ptr StructBoundCastData::BindStructToStructCast(BindCastIn auto &source_children = StructType::GetChildTypes(source); auto &target_children = StructType::GetChildTypes(target); - auto target_is_unnamed = StructType::IsUnnamed(target); - auto source_is_unnamed = StructType::IsUnnamed(source); + auto target_is_unnamed = target_children.empty() || StructType::IsUnnamed(target); + auto source_is_unnamed = source_children.empty() || StructType::IsUnnamed(source); auto is_unnamed = target_is_unnamed || source_is_unnamed; if (is_unnamed && source_children.size() != target_children.size()) { @@ -268,7 +268,6 @@ StructToMapBoundCastData::InitStructToMapCastLocalState(CastLocalStateParameters } static bool StructToMapCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { // Optimization: if the source vector is constant, we only have a single physical element, so we can set the // result vectortype to ConstantVector as well and set the (logical) count to 1 diff --git 
a/src/duckdb/src/function/cast/union_casts.cpp b/src/duckdb/src/function/cast/union_casts.cpp index 65f018d2b..5a7b5d466 100644 --- a/src/duckdb/src/function/cast/union_casts.cpp +++ b/src/duckdb/src/function/cast/union_casts.cpp @@ -56,7 +56,6 @@ static unique_ptr BindToUnionCast(BindCastInput &input, const Log // check if the cast is ambiguous (2 or more casts have the same cost) if (candidates.size() > 1 && candidates[1].cost == selected_cost) { - // collect all the ambiguous types auto message = StringUtil::Format( "Type %s can't be cast as %s. The cast is ambiguous, multiple possible members in target: ", source, @@ -107,7 +106,6 @@ static bool ToUnionCast(Vector &source, Vector &result, idx_t count, CastParamet BoundCastInfo DefaultCasts::ImplicitToUnionCast(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - D_ASSERT(target.id() == LogicalTypeId::UNION); if (StructToUnionCast::AllowImplicitCastFromStruct(source, target)) { return StructToUnionCast::Bind(input, source, target); @@ -130,7 +128,6 @@ BoundCastInfo DefaultCasts::ImplicitToUnionCast(BindCastInput &input, const Logi // INVALID: UNION(A, B, D) -> UNION(A, B, C) struct UnionUnionBoundCastData : public BoundCastData { - // mapping from source member index to target member index // these are always the same size as the source member count // (since all source members must be present in the target) @@ -284,7 +281,6 @@ static bool UnionToUnionCast(Vector &source, Vector &result, idx_t count, CastPa FlatVector::GetData(result_tag_vector)[row_idx] = UnsafeNumericCast(target_tag); } else { - // Issue: The members of the result is not always flatvectors // In the case of TryNullCast, the result member is constant. FlatVector::SetNull(result, row_idx, true); diff --git a/src/duckdb/src/function/cast/variant/from_variant.cpp b/src/duckdb/src/function/cast/variant/from_variant.cpp index ca377b326..f29db2b85 100644 --- a/src/duckdb/src/function/cast/variant/from_variant.cpp +++ b/src/duckdb/src/function/cast/variant/from_variant.cpp @@ -1,3 +1,4 @@ +#include "yyjson_utils.hpp" #include "duckdb/function/cast/default_casts.hpp" #include "duckdb/common/types/variant.hpp" #include "duckdb/function/scalar/variant_utils.hpp" @@ -49,22 +50,6 @@ struct DecimalConversionPayloadFromVariant { idx_t scale; }; -struct ConvertedJSONHolder { -public: - ~ConvertedJSONHolder() { - if (doc) { - yyjson_mut_doc_free(doc); - } - if (stringified_json) { - free(stringified_json); - } - } - -public: - yyjson_mut_doc *doc = nullptr; - char *stringified_json = nullptr; -}; - } // namespace //===--------------------------------------------------------------------===// @@ -364,6 +349,14 @@ static bool ConvertVariantToStruct(FromVariantConversionData &conversion_data, V SelectionVector child_values_sel; child_values_sel.Initialize(count); + SelectionVector row_sel(0, count); + if (row.IsValid()) { + auto row_index = row.GetIndex(); + for (idx_t i = 0; i < count; i++) { + row_sel[i] = static_cast(row_index); + } + } + for (idx_t child_idx = 0; child_idx < child_types.size(); child_idx++) { auto &child_name = child_types[child_idx].first; @@ -372,14 +365,21 @@ static bool ConvertVariantToStruct(FromVariantConversionData &conversion_data, V VariantPathComponent component; component.key = child_name; component.lookup_mode = VariantChildLookupMode::BY_KEY; - auto collection_result = - VariantUtils::FindChildValues(conversion_data.variant, component, row, child_values_sel, child_data, count); - if (!collection_result.Success()) { - 
D_ASSERT(collection_result.type == VariantChildDataCollectionResult::Type::COMPONENT_NOT_FOUND); - auto nested_index = collection_result.nested_data_index; - auto row_index = row.IsValid() ? row.GetIndex() : nested_index; + ValidityMask lookup_validity(count); + VariantUtils::FindChildValues(conversion_data.variant, component, row_sel, child_values_sel, lookup_validity, + child_data, count); + if (!lookup_validity.AllValid()) { + optional_idx nested_index; + for (idx_t i = 0; i < count; i++) { + if (!lookup_validity.RowIsValid(i)) { + nested_index = i; + break; + } + } + D_ASSERT(nested_index.IsValid()); + auto row_index = row.IsValid() ? row.GetIndex() : nested_index.GetIndex(); auto object_keys = - VariantUtils::GetObjectKeys(conversion_data.variant, row_index, child_data[nested_index]); + VariantUtils::GetObjectKeys(conversion_data.variant, row_index, child_data[nested_index.GetIndex()]); conversion_data.error = StringUtil::Format("VARIANT(OBJECT(%s)) is missing key '%s'", StringUtil::Join(object_keys, ","), component.key); return false; @@ -550,6 +550,11 @@ static bool CastVariant(FromVariantConversionData &conversion_data, Vector &resu return CastVariantToPrimitive>( conversion_data, result, sel, offset, count, row, string_payload); } + case LogicalTypeId::GEOMETRY: { + StringConversionPayload string_payload(result); + return CastVariantToPrimitive>( + conversion_data, result, sel, offset, count, row, string_payload); + } case LogicalTypeId::VARCHAR: { if (target_type.IsJSONType()) { return CastVariantToJSON(conversion_data, result, sel, offset, count, row); @@ -686,6 +691,8 @@ BoundCastInfo DefaultCasts::VariantCastSwitch(BindCastInput &input, const Logica case LogicalTypeId::UUID: case LogicalTypeId::ARRAY: return BoundCastInfo(CastFromVARIANT); + case LogicalTypeId::GEOMETRY: + return BoundCastInfo(CastFromVARIANT); case LogicalTypeId::VARCHAR: { return BoundCastInfo(CastFromVARIANT); } diff --git a/src/duckdb/src/function/cast/variant/to_json.cpp b/src/duckdb/src/function/cast/variant/to_json.cpp index 9d35c142c..482fa90c2 100644 --- a/src/duckdb/src/function/cast/variant/to_json.cpp +++ b/src/duckdb/src/function/cast/variant/to_json.cpp @@ -10,6 +10,7 @@ #include "duckdb/common/types/decimal.hpp" #include "duckdb/common/types/time.hpp" #include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/variant_visitor.hpp" using namespace duckdb_yyjson; // NOLINT @@ -17,256 +18,211 @@ namespace duckdb { //! ------------ Variant -> JSON ------------ -yyjson_mut_val *VariantCasts::ConvertVariantToJSON(yyjson_mut_doc *doc, const RecursiveUnifiedVectorFormat &source, - idx_t row, uint32_t values_idx) { - auto index = source.unified.sel->get_index(row); - if (!source.unified.validity.RowIsValid(index)) { - return yyjson_mut_null(doc); - } +namespace { + +struct JSONConverter { + using result_type = yyjson_mut_val *; - //! values - auto &values = UnifiedVariantVector::GetValues(source); - auto values_data = values.GetData(values); - - //! type_ids - auto &type_ids = UnifiedVariantVector::GetValuesTypeId(source); - auto type_ids_data = type_ids.GetData(type_ids); - - //! byte_offsets - auto &byte_offsets = UnifiedVariantVector::GetValuesByteOffset(source); - auto byte_offsets_data = byte_offsets.GetData(byte_offsets); - - //! children - auto &children = UnifiedVariantVector::GetChildren(source); - auto children_data = children.GetData(children); - - //! 
values_index - auto &values_index = UnifiedVariantVector::GetChildrenValuesIndex(source); - auto values_index_data = values_index.GetData(values_index); - - //! keys_index - auto &keys_index = UnifiedVariantVector::GetChildrenKeysIndex(source); - auto keys_index_data = keys_index.GetData(keys_index); - - //! keys - auto &keys = UnifiedVariantVector::GetKeys(source); - auto keys_data = keys.GetData(keys); - auto &keys_entry = UnifiedVariantVector::GetKeysEntry(source); - auto keys_entry_data = keys_entry.GetData(keys_entry); - - //! list entries - auto keys_list_entry = keys_data[keys.sel->get_index(row)]; - auto children_list_entry = children_data[children.sel->get_index(row)]; - auto values_list_entry = values_data[values.sel->get_index(row)]; - - //! The 'values' data of the value we're currently converting - values_idx += values_list_entry.offset; - auto type_id = static_cast(type_ids_data[type_ids.sel->get_index(values_idx)]); - auto byte_offset = byte_offsets_data[byte_offsets.sel->get_index(values_idx)]; - - //! The blob data of the Variant, accessed by byte offset retrieved above ^ - auto &value = UnifiedVariantVector::GetData(source); - auto value_data = value.GetData(value); - auto &blob = value_data[value.sel->get_index(row)]; - auto blob_data = const_data_ptr_cast(blob.GetData()); - - auto ptr = const_data_ptr_cast(blob_data + byte_offset); - switch (type_id) { - case VariantLogicalType::VARIANT_NULL: + static yyjson_mut_val *VisitNull(yyjson_mut_doc *doc) { return yyjson_mut_null(doc); - case VariantLogicalType::BOOL_TRUE: - return yyjson_mut_true(doc); - case VariantLogicalType::BOOL_FALSE: - return yyjson_mut_false(doc); - case VariantLogicalType::INT8: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); - } - case VariantLogicalType::INT16: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); } - case VariantLogicalType::INT32: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); + + static yyjson_mut_val *VisitBoolean(bool val, yyjson_mut_doc *doc) { + return val ? 
yyjson_mut_true(doc) : yyjson_mut_false(doc); } - case VariantLogicalType::INT64: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); + + template + static yyjson_mut_val *VisitInteger(T val, yyjson_mut_doc *doc) { + throw InternalException("JSONConverter::VisitInteger not implemented!"); } - case VariantLogicalType::INT128: { - auto val = Load(ptr); - auto val_str = val.ToString(); - return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); + + static yyjson_mut_val *VisitTime(dtime_t val, yyjson_mut_doc *doc) { + auto val_str = Time::ToString(val); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UINT8: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); + + static yyjson_mut_val *VisitTimeNanos(dtime_ns_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIME_NS(val).ToString(); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UINT16: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); + + static yyjson_mut_val *VisitTimeTZ(dtime_tz_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIMETZ(val).ToString(); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UINT32: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); + + static yyjson_mut_val *VisitTimestampSec(timestamp_sec_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIMESTAMPSEC(val).ToString(); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UINT64: { - auto val = Load(ptr); - return yyjson_mut_uint(doc, val); + + static yyjson_mut_val *VisitTimestampMs(timestamp_ms_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIMESTAMPMS(val).ToString(); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UINT128: { - auto val = Load(ptr); - auto val_str = val.ToString(); - return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); + + static yyjson_mut_val *VisitTimestamp(timestamp_t val, yyjson_mut_doc *doc) { + auto val_str = Timestamp::ToString(val); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UUID: { - auto val = Value::UUID(Load(ptr)); - auto val_str = val.ToString(); + + static yyjson_mut_val *VisitTimestampNanos(timestamp_ns_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIMESTAMPNS(val).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::INTERVAL: { - auto val = Value::INTERVAL(Load(ptr)); - auto val_str = val.ToString(); + + static yyjson_mut_val *VisitTimestampTZ(timestamp_tz_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIMESTAMPTZ(val).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::FLOAT: { - auto val = Load(ptr); + + static yyjson_mut_val *VisitFloat(float val, yyjson_mut_doc *doc) { return yyjson_mut_real(doc, val); } - case VariantLogicalType::DOUBLE: { - auto val = Load(ptr); + + static yyjson_mut_val *VisitDouble(double val, yyjson_mut_doc *doc) { return yyjson_mut_real(doc, val); } - case VariantLogicalType::DATE: { - auto val = Load(ptr); - auto val_str = Date::ToString(date_t(val)); + + static yyjson_mut_val *VisitUUID(hugeint_t val, yyjson_mut_doc *doc) { + auto val_str = Value::UUID(val).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::BLOB: { - auto string_length = VarintDecode(ptr); - auto 
string_data = reinterpret_cast(ptr); - auto val_str = Value::BLOB(const_data_ptr_cast(string_data), string_length).ToString(); + + static yyjson_mut_val *VisitDate(date_t val, yyjson_mut_doc *doc) { + auto val_str = Date::ToString(val); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::VARCHAR: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - return yyjson_mut_strncpy(doc, string_data, static_cast(string_length)); - } - case VariantLogicalType::DECIMAL: { - auto width = NumericCast(VarintDecode(ptr)); - auto scale = NumericCast(VarintDecode(ptr)); - string val_str; - if (width > DecimalWidth::max) { - val_str = Decimal::ToString(Load(ptr), width, scale); - } else if (width > DecimalWidth::max) { - val_str = Decimal::ToString(Load(ptr), width, scale); - } else if (width > DecimalWidth::max) { - val_str = Decimal::ToString(Load(ptr), width, scale); - } else { - val_str = Decimal::ToString(Load(ptr), width, scale); - } - return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); - } - case VariantLogicalType::TIME_MICROS: { - auto val = Load(ptr); - auto val_str = Time::ToString(val); + static yyjson_mut_val *VisitInterval(interval_t val, yyjson_mut_doc *doc) { + auto val_str = Value::INTERVAL(val).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::TIME_MICROS_TZ: { - auto val = Value::TIMETZ(Load(ptr)); - auto val_str = val.ToString(); - return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); + + static yyjson_mut_val *VisitString(const string_t &str, yyjson_mut_doc *doc) { + return yyjson_mut_strncpy(doc, str.GetData(), str.GetSize()); } - case VariantLogicalType::TIMESTAMP_MICROS: { - auto val = Load(ptr); - auto val_str = Timestamp::ToString(val); + + static yyjson_mut_val *VisitBlob(const string_t &str, yyjson_mut_doc *doc) { + auto val_str = Value::BLOB(const_data_ptr_cast(str.GetData()), str.GetSize()).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::TIMESTAMP_SEC: { - auto val = Value::TIMESTAMPSEC(Load(ptr)); - auto val_str = val.ToString(); - return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); + + static yyjson_mut_val *VisitBignum(const string_t &str, yyjson_mut_doc *doc) { + auto val_str = Value::BIGNUM(const_data_ptr_cast(str.GetData()), str.GetSize()).ToString(); + return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::TIMESTAMP_NANOS: { - auto val = Value::TIMESTAMPNS(Load(ptr)); - auto val_str = val.ToString(); + + static yyjson_mut_val *VisitGeometry(const string_t &str, yyjson_mut_doc *doc) { + auto val_str = Value::GEOMETRY(const_data_ptr_cast(str.GetData()), str.GetSize()).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::TIMESTAMP_MILIS: { - auto val = Value::TIMESTAMPMS(Load(ptr)); - auto val_str = val.ToString(); + + static yyjson_mut_val *VisitBitstring(const string_t &str, yyjson_mut_doc *doc) { + auto val_str = Value::BIT(const_data_ptr_cast(str.GetData()), str.GetSize()).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::TIMESTAMP_MICROS_TZ: { - auto val = Value::TIMESTAMPTZ(Load(ptr)); - auto val_str = val.ToString(); - return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); + + template + static yyjson_mut_val *VisitDecimal(T val, uint32_t width, uint32_t scale, yyjson_mut_doc *doc) { + 
string val_str; + if (std::is_same::value) { + val_str = Decimal::ToString(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + val_str = Decimal::ToString(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + val_str = Decimal::ToString(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + val_str = Decimal::ToString(val, static_cast(width), static_cast(scale)); + } else { + throw InternalException("Unhandled decimal type"); + } + return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::ARRAY: { - auto count = VarintDecode(ptr); + + static yyjson_mut_val *VisitArray(const UnifiedVariantVectorData &variant, idx_t row, + const VariantNestedData &nested_data, yyjson_mut_doc *doc) { auto arr = yyjson_mut_arr(doc); - if (!count) { - return arr; - } - auto child_index_start = VarintDecode(ptr); - for (idx_t i = 0; i < count; i++) { - auto index = values_index.sel->get_index(children_list_entry.offset + child_index_start + i); - auto child_index = values_index_data[index]; -#ifdef DEBUG - auto key_id_index = keys_index.sel->get_index(children_list_entry.offset + child_index_start + i); - D_ASSERT(!keys_index.validity.RowIsValid(key_id_index)); -#endif - auto val = ConvertVariantToJSON(doc, source, row, child_index); - if (!val) { - return nullptr; - } - yyjson_mut_arr_add_val(arr, val); + auto array_items = VariantVisitor::VisitArrayItems(variant, row, nested_data, doc); + for (auto &entry : array_items) { + yyjson_mut_arr_add_val(arr, entry); } return arr; } - case VariantLogicalType::OBJECT: { - auto count = VarintDecode(ptr); + + static yyjson_mut_val *VisitObject(const UnifiedVariantVectorData &variant, idx_t row, + const VariantNestedData &nested_data, yyjson_mut_doc *doc) { auto obj = yyjson_mut_obj(doc); - if (!count) { - return obj; - } - auto child_index_start = VarintDecode(ptr); - - for (idx_t i = 0; i < count; i++) { - auto children_index = values_index.sel->get_index(children_list_entry.offset + child_index_start + i); - auto child_value_idx = values_index_data[children_index]; - auto val = ConvertVariantToJSON(doc, source, row, child_value_idx); - if (!val) { - return nullptr; - } - auto keys_index_index = keys_index.sel->get_index(children_list_entry.offset + child_index_start + i); - D_ASSERT(keys_index.validity.RowIsValid(keys_index_index)); - auto child_key_id = keys_index_data[keys_index_index]; - auto &key = keys_entry_data[keys_entry.sel->get_index(keys_list_entry.offset + child_key_id)]; - yyjson_mut_obj_put(obj, yyjson_mut_strncpy(doc, key.GetData(), key.GetSize()), val); + auto object_items = VariantVisitor::VisitObjectItems(variant, row, nested_data, doc); + for (auto &entry : object_items) { + yyjson_mut_obj_put(obj, yyjson_mut_strncpy(doc, entry.first.c_str(), entry.first.size()), entry.second); } return obj; } - case VariantLogicalType::BITSTRING: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - auto val_str = Value::BIT(const_data_ptr_cast(string_data), string_length).ToString(); - return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); - } - case VariantLogicalType::BIGNUM: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - auto val_str = Value::BIGNUM(const_data_ptr_cast(string_data), string_length).ToString(); - return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); - } - default: - throw InternalException("VariantLogicalType(%d) not 
handled", static_cast(type_id)); + + static yyjson_mut_val *VisitDefault(VariantLogicalType type_id, const_data_ptr_t, yyjson_mut_doc *) { + throw InternalException("VariantLogicalType(%s) not handled", EnumUtil::ToString(type_id)); } +}; + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(int8_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} - return nullptr; +template <> +yyjson_mut_val *JSONConverter::VisitInteger(int16_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(int32_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(int64_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(hugeint_t val, yyjson_mut_doc *doc) { + auto val_str = val.ToString(); + return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(uint8_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(uint16_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(uint32_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(uint64_t val, yyjson_mut_doc *doc) { + return yyjson_mut_uint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(uhugeint_t val, yyjson_mut_doc *doc) { + auto val_str = val.ToString(); + return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); +} + +} // namespace + +yyjson_mut_val *VariantCasts::ConvertVariantToJSON(yyjson_mut_doc *doc, const RecursiveUnifiedVectorFormat &source, + idx_t row, uint32_t values_idx) { + UnifiedVariantVectorData variant(source); + return VariantVisitor::Visit(variant, row, values_idx, doc); } } // namespace duckdb diff --git a/src/duckdb/src/function/cast/variant/to_variant.cpp b/src/duckdb/src/function/cast/variant/to_variant.cpp index ad1962d37..4cfb290f4 100644 --- a/src/duckdb/src/function/cast/variant/to_variant.cpp +++ b/src/duckdb/src/function/cast/variant/to_variant.cpp @@ -8,7 +8,6 @@ #include "duckdb/function/cast/variant/to_variant.hpp" namespace duckdb { - namespace variant { static void InitializeOffsets(DataChunk &offsets, idx_t count) { @@ -84,39 +83,6 @@ static void InitializeVariants(DataChunk &offsets, Vector &result, SelectionVect selvec_size = keys_offset; } -static void FinalizeVariantKeys(Vector &variant, OrderedOwningStringMap &dictionary, SelectionVector &sel, - idx_t sel_size) { - auto &keys = VariantVector::GetKeys(variant); - auto &keys_entry = ListVector::GetEntry(keys); - auto keys_entry_data = FlatVector::GetData(keys_entry); - - bool already_sorted = true; - - vector unsorted_to_sorted(dictionary.size()); - auto it = dictionary.begin(); - for (uint32_t sorted_idx = 0; sorted_idx < dictionary.size(); sorted_idx++) { - auto unsorted_idx = it->second; - if (unsorted_idx != sorted_idx) { - already_sorted = false; - } - unsorted_to_sorted[unsorted_idx] = sorted_idx; - D_ASSERT(sorted_idx < ListVector::GetListSize(keys)); - keys_entry_data[sorted_idx] = it->first; - auto size = static_cast(keys_entry_data[sorted_idx].GetSize()); - keys_entry_data[sorted_idx].SetSizeAndFinalize(size, size); - it++; - } - - if (!already_sorted) { - //! 
Adjust the selection vector to point to the right dictionary index - for (idx_t i = 0; i < sel_size; i++) { - auto &entry = sel[i]; - auto sorted_idx = unsorted_to_sorted[entry]; - entry = sorted_idx; - } - } -} - static bool GatherOffsetsAndSizes(ToVariantSourceData &source, ToVariantGlobalResultData &result, idx_t count) { InitializeOffsets(result.offsets, count); //! First pass - collect sizes/offsets @@ -130,6 +96,9 @@ static bool WriteVariantResultData(ToVariantSourceData &source, ToVariantGlobalR } static bool CastToVARIANT(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + if (!count) { + return true; + } DataChunk offsets; offsets.Initialize(Allocator::DefaultAllocator(), {LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER}, @@ -163,7 +132,7 @@ static bool CastToVARIANT(Vector &source, Vector &result, idx_t count, CastParam } } - FinalizeVariantKeys(result, dictionary, keys_selvec, keys_selvec_size); + VariantUtils::FinalizeVariantKeys(result, dictionary, keys_selvec, keys_selvec_size); //! Finalize the 'data' auto &blob = VariantVector::GetData(result); auto blob_data = FlatVector::GetData(blob); diff --git a/src/duckdb/src/function/cast_rules.cpp b/src/duckdb/src/function/cast_rules.cpp index d73bc38e5..c077cd87c 100644 --- a/src/duckdb/src/function/cast_rules.cpp +++ b/src/duckdb/src/function/cast_rules.cpp @@ -146,7 +146,6 @@ static int64_t ImplicitCastUSmallint(const LogicalType &to) { static int64_t ImplicitCastUInteger(const LogicalType &to) { switch (to.id()) { - case LogicalTypeId::UBIGINT: case LogicalTypeId::BIGINT: case LogicalTypeId::UHUGEINT: @@ -187,7 +186,6 @@ static int64_t ImplicitCastFloat(const LogicalType &to) { static int64_t ImplicitCastDouble(const LogicalType &to) { switch (to.id()) { - case LogicalTypeId::BIGNUM: return TargetTypeCost(to); default: @@ -500,7 +498,6 @@ int64_t CastRules::ImplicitCast(const LogicalType &from, const LogicalType &to) int64_t cost = -1; if (named_struct_cast) { - // Collect the target members in a map for easy lookup case_insensitive_map_t target_members; for (idx_t target_idx = 0; target_idx < target_children.size(); target_idx++) { diff --git a/src/duckdb/src/function/copy_blob.cpp b/src/duckdb/src/function/copy_blob.cpp index 2af12a8c3..398eb9534 100644 --- a/src/duckdb/src/function/copy_blob.cpp +++ b/src/duckdb/src/function/copy_blob.cpp @@ -61,7 +61,6 @@ struct WriteBlobGlobalState final : public GlobalFunctionData { unique_ptr WriteBlobInitializeGlobal(ClientContext &context, FunctionData &bind_data, const string &file_path) { - auto &bdata = bind_data.Cast(); auto &fs = FileSystem::GetFileSystem(context); @@ -102,7 +101,6 @@ void WriteBlobSink(ExecutionContext &context, FunctionData &bind_data, GlobalFun for (idx_t row_idx = 0; row_idx < input.size(); row_idx++) { const auto out_idx = vdata.sel->get_index(row_idx); if (vdata.validity.RowIsValid(out_idx)) { - auto &blob = blobs[out_idx]; auto blob_len = blob.GetSize(); auto blob_ptr = blob.GetDataWriteable(); diff --git a/src/duckdb/src/function/function_binder.cpp b/src/duckdb/src/function/function_binder.cpp index c9ad55d92..fa48257fd 100644 --- a/src/duckdb/src/function/function_binder.cpp +++ b/src/duckdb/src/function/function_binder.cpp @@ -334,8 +334,8 @@ unique_ptr FunctionBinder::BindScalarFunction(ScalarFunctionCatalogE // Some functions may have an invalid default return type, as they must be bound to infer the return type. // In those cases, we default to SQLNULL. 
const auto return_type_if_null = - bound_function.return_type.IsComplete() ? bound_function.return_type : LogicalType::SQLNULL; - if (bound_function.null_handling == FunctionNullHandling::DEFAULT_NULL_HANDLING) { + bound_function.GetReturnType().IsComplete() ? bound_function.GetReturnType() : LogicalType::SQLNULL; + if (bound_function.GetNullHandling() == FunctionNullHandling::DEFAULT_NULL_HANDLING) { for (auto &child : children) { if (child->return_type == LogicalTypeId::SQLNULL) { return make_uniq(Value(return_type_if_null)); @@ -378,7 +378,7 @@ static string ExtractCollation(const vector> &children) { static void PropagateCollations(ClientContext &, ScalarFunction &bound_function, vector> &children) { - if (!RequiresCollationPropagation(bound_function.return_type)) { + if (!RequiresCollationPropagation(bound_function.GetReturnType())) { // we only need to propagate if the function returns a varchar return; } @@ -389,7 +389,7 @@ static void PropagateCollations(ClientContext &, ScalarFunction &bound_function, } // propagate the collation to the return type auto collation_type = LogicalType::VARCHAR_COLLATION(std::move(collation)); - bound_function.return_type = std::move(collation_type); + bound_function.SetReturnType(std::move(collation_type)); } static void PushCollations(ClientContext &context, ScalarFunction &bound_function, @@ -401,8 +401,8 @@ static void PushCollations(ClientContext &context, ScalarFunction &bound_functio } // push collation into the return type if required auto collation_type = LogicalType::VARCHAR_COLLATION(std::move(collation)); - if (RequiresCollationPropagation(bound_function.return_type)) { - bound_function.return_type = collation_type; + if (RequiresCollationPropagation(bound_function.GetReturnType())) { + bound_function.SetReturnType(collation_type); } // push collations to the children for (auto &arg : children) { @@ -417,7 +417,7 @@ static void PushCollations(ClientContext &context, ScalarFunction &bound_functio static void HandleCollations(ClientContext &context, ScalarFunction &bound_function, vector> &children) { - switch (bound_function.collation_handling) { + switch (bound_function.GetCollationHandling()) { case FunctionCollationHandling::IGNORE_COLLATIONS: // explicitly ignoring collation handling break; @@ -436,7 +436,6 @@ static void HandleCollations(ClientContext &context, ScalarFunction &bound_funct static void InferTemplateType(ClientContext &context, const LogicalType &source, const LogicalType &target, case_insensitive_map_t> &bindings, const Expression ¤t_expr, const BaseScalarFunction &function) { - if (target.id() == LogicalTypeId::UNKNOWN || target.id() == LogicalTypeId::SQLNULL) { // If the actual type is unknown, we cannot infer anything more. // Therefore, we map all remaining templates in the source to UNKNOWN or SQLNULL, if not already inferred to @@ -517,7 +516,6 @@ static void InferTemplateType(ClientContext &context, const LogicalType &source, case LogicalTypeId::ARRAY: { if ((source.id() == LogicalTypeId::ARRAY || source.id() == LogicalTypeId::LIST) && (target.id() == LogicalTypeId::LIST || target.id() == LogicalTypeId::ARRAY)) { - const auto &source_child = source.id() == LogicalTypeId::LIST ? 
ListType::GetChildType(source) : ArrayType::GetChildType(source); const auto &target_child = @@ -565,7 +563,6 @@ static void InferTemplateType(ClientContext &context, const LogicalType &source, static void SubstituteTemplateType(LogicalType &type, case_insensitive_map_t> &bindings, const string &function_name) { - // Replace all template types in with their bound concrete types. type = TypeVisitor::VisitReplace(type, [&](const LogicalType &t) -> LogicalType { if (t.id() == LogicalTypeId::TEMPLATE) { @@ -614,8 +611,8 @@ void FunctionBinder::ResolveTemplateTypes(BaseScalarFunction &bound_function, } // If the return type is templated, we need to subsitute it as well - if (bound_function.return_type.IsTemplated()) { - to_substitute.emplace_back(bound_function.return_type); + if (bound_function.GetReturnType().IsTemplated()) { + to_substitute.emplace_back(bound_function.GetReturnType()); } // Finally, substitute all template types in the bound function with their concrete types. @@ -641,13 +638,12 @@ void FunctionBinder::CheckTemplateTypesResolved(const BaseScalarFunction &bound_ VerifyTemplateType(arg, bound_function.name); } VerifyTemplateType(bound_function.varargs, bound_function.name); - VerifyTemplateType(bound_function.return_type, bound_function.name); + VerifyTemplateType(bound_function.GetReturnType(), bound_function.name); } unique_ptr FunctionBinder::BindScalarFunction(ScalarFunction bound_function, vector> children, bool is_operator, optional_ptr binder) { - // Attempt to resolve template types, before we call the "Bind" callback. ResolveTemplateTypes(bound_function, children); @@ -679,7 +675,7 @@ unique_ptr FunctionBinder::BindScalarFunction(ScalarFunction bound_f // check if we need to add casts to the children CastToFunctionArguments(bound_function, children); - auto return_type = bound_function.return_type; + auto return_type = bound_function.GetReturnType(); unique_ptr result; auto result_func = make_uniq(std::move(return_type), std::move(bound_function), std::move(children), std::move(bind_info), is_operator); @@ -698,7 +694,6 @@ unique_ptr FunctionBinder::BindAggregateFunction(Aggre vector> children, unique_ptr filter, AggregateType aggr_type) { - ResolveTemplateTypes(bound_function, children); unique_ptr bind_info; diff --git a/src/duckdb/src/function/function_list.cpp b/src/duckdb/src/function/function_list.cpp index d73467d3a..180303399 100644 --- a/src/duckdb/src/function/function_list.cpp +++ b/src/duckdb/src/function/function_list.cpp @@ -4,6 +4,7 @@ #include "duckdb/function/scalar/compressed_materialization_functions.hpp" #include "duckdb/function/scalar/date_functions.hpp" #include "duckdb/function/scalar/generic_functions.hpp" +#include "duckdb/function/scalar/geometry_functions.hpp" #include "duckdb/function/scalar/list_functions.hpp" #include "duckdb/function/scalar/map_functions.hpp" #include "duckdb/function/scalar/variant_functions.hpp" @@ -15,6 +16,7 @@ #include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" #include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" + namespace duckdb { // Scalar Function @@ -45,6 +47,7 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION(NotLikeFun), DUCKDB_SCALAR_FUNCTION(NotILikeFun), DUCKDB_SCALAR_FUNCTION_SET(OperatorModuloFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(StIntersectsExtentFunAlias), DUCKDB_SCALAR_FUNCTION_SET(OperatorMultiplyFun), DUCKDB_SCALAR_FUNCTION_SET(OperatorAddFun), DUCKDB_SCALAR_FUNCTION_SET(OperatorSubtractFun), @@ -151,6 +154,12 @@ static const 
StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION_SET(SHA1Fun), DUCKDB_SCALAR_FUNCTION_SET(SHA256Fun), DUCKDB_SCALAR_FUNCTION_ALIAS(SplitFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(StAsbinaryFun), + DUCKDB_SCALAR_FUNCTION(StAstextFun), + DUCKDB_SCALAR_FUNCTION(StAswkbFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(StAswktFun), + DUCKDB_SCALAR_FUNCTION(StGeomfromwkbFun), + DUCKDB_SCALAR_FUNCTION(StIntersectsExtentFun), DUCKDB_SCALAR_FUNCTION_ALIAS(StrSplitFun), DUCKDB_SCALAR_FUNCTION_SET_ALIAS(StrSplitRegexFun), DUCKDB_SCALAR_FUNCTION_SET(StrfTimeFun), @@ -177,6 +186,7 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION_ALIAS(UcaseFun), DUCKDB_SCALAR_FUNCTION(UpperFun), DUCKDB_SCALAR_FUNCTION_SET(VariantExtractFun), + DUCKDB_SCALAR_FUNCTION(VariantNormalizeFun), DUCKDB_SCALAR_FUNCTION(VariantTypeofFun), DUCKDB_SCALAR_FUNCTION_SET(WriteLogFun), DUCKDB_SCALAR_FUNCTION(ConcatOperatorFun), diff --git a/src/duckdb/src/function/macro_function.cpp b/src/duckdb/src/function/macro_function.cpp index 2f407c025..90edb191f 100644 --- a/src/duckdb/src/function/macro_function.cpp +++ b/src/duckdb/src/function/macro_function.cpp @@ -45,9 +45,8 @@ MacroBindResult MacroFunction::BindMacroFunction( Binder &binder, const vector> &functions, const string &name, FunctionExpression &function_expr, vector> &positional_arguments, InsertionOrderPreservingMap> &named_arguments, idx_t depth) { - ExpressionBinder expr_binder(binder, binder.context); - + expr_binder.lambda_bindings = binder.lambda_bindings; // Find argument types and separate positional and default arguments vector positional_arg_types; InsertionOrderPreservingMap named_arg_types; diff --git a/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp b/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp index 740924397..5bd4f7ebc 100644 --- a/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp +++ b/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp @@ -190,7 +190,7 @@ scalar_function_t GetIntegralDecompressFunctionInputSwitch(const LogicalType &in void CMIntegralSerialize(Serializer &serializer, const optional_ptr bind_data, const ScalarFunction &function) { serializer.WriteProperty(100, "arguments", function.arguments); - serializer.WriteProperty(101, "return_type", function.return_type); + serializer.WriteProperty(101, "return_type", function.GetReturnType()); } template @@ -229,9 +229,9 @@ ScalarFunction CMIntegralCompressFun::GetFunction(const LogicalType &input_type, result.serialize = CMIntegralSerialize; result.deserialize = CMIntegralDeserialize; #if defined(D_ASSERT_IS_ENABLED) - result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; // Can only throw runtime error when assertions are enabled + result.SetFallible(); // Can only throw runtime error when assertions are enabled #else - result.errors = FunctionErrors::CANNOT_ERROR; + result.SetErrorMode(FunctionErrors::CANNOT_ERROR); #endif return result; } diff --git a/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp b/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp index 39821858d..21e92c5c3 100644 --- a/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp +++ b/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp @@ -198,7 +198,7 @@ scalar_function_t GetStringDecompressFunctionSwitch(const LogicalType &input_typ case LogicalTypeId::UHUGEINT: return 
GetStringDecompressFunction(input_type); case LogicalTypeId::HUGEINT: - return GetStringCompressFunction(input_type); + return GetStringDecompressFunction(input_type); default: throw InternalException("Unexpected type in GetStringDecompressFunctionSwitch"); } @@ -207,7 +207,7 @@ scalar_function_t GetStringDecompressFunctionSwitch(const LogicalType &input_typ void CMStringCompressSerialize(Serializer &serializer, const optional_ptr bind_data, const ScalarFunction &function) { serializer.WriteProperty(100, "arguments", function.arguments); - serializer.WriteProperty(101, "return_type", function.return_type); + serializer.WriteProperty(101, "return_type", function.GetReturnType()); } unique_ptr CMStringCompressDeserialize(Deserializer &deserializer, ScalarFunction &function) { @@ -225,7 +225,7 @@ void CMStringDecompressSerialize(Serializer &serializer, const optional_ptr CMStringDecompressDeserialize(Deserializer &deserializer, ScalarFunction &function) { function.arguments = deserializer.ReadProperty>(100, "arguments"); function.function = GetStringDecompressFunctionSwitch(function.arguments[0]); - function.return_type = deserializer.Get(); + function.SetReturnType(deserializer.Get()); return nullptr; } @@ -248,9 +248,9 @@ ScalarFunction CMStringCompressFun::GetFunction(const LogicalType &result_type) result.serialize = CMStringCompressSerialize; result.deserialize = CMStringCompressDeserialize; #if defined(D_ASSERT_IS_ENABLED) - result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; // Can only throw runtime error when assertions are enabled + result.SetFallible(); // Can only throw runtime error when assertions are enabled #else - result.errors = FunctionErrors::CANNOT_ERROR; + result.SetErrorMode(FunctionErrors::CANNOT_ERROR); #endif return result; } diff --git a/src/duckdb/src/function/scalar/create_sort_key.cpp b/src/duckdb/src/function/scalar/create_sort_key.cpp index 2f5463e3f..d93e20d71 100644 --- a/src/duckdb/src/function/scalar/create_sort_key.cpp +++ b/src/duckdb/src/function/scalar/create_sort_key.cpp @@ -63,7 +63,7 @@ unique_ptr CreateSortKeyBind(ClientContext &context, ScalarFunctio } if (all_constant) { if (constant_size <= sizeof(int64_t)) { - bound_function.return_type = LogicalType::BIGINT; + bound_function.SetReturnType(LogicalType::BIGINT); } } return std::move(result); @@ -696,13 +696,15 @@ void PrepareSortData(Vector &result, idx_t size, SortKeyLengthInfo &key_lengths, } } -void FinalizeSortData(Vector &result, idx_t size) { +void FinalizeSortData(Vector &result, idx_t size, const SortKeyLengthInfo &key_lengths, + const unsafe_vector &offsets) { switch (result.GetType().id()) { case LogicalTypeId::BLOB: { auto result_data = FlatVector::GetData(result); // call Finalize on the result for (idx_t r = 0; r < size; r++) { - result_data[r].Finalize(); + result_data[r].SetSizeAndFinalize(NumericCast(offsets[r]), + key_lengths.variable_lengths[r] + key_lengths.constant_length); } break; } @@ -739,7 +741,7 @@ void CreateSortKeyInternal(vector> &sort_key_data, SortKeyConstructInfo info(modifiers[c], offsets, data_pointers.get()); ConstructSortKey(*sort_key_data[c], info); } - FinalizeSortData(result, row_count); + FinalizeSortData(result, row_count, key_lengths, offsets); } } // namespace @@ -861,7 +863,7 @@ unique_ptr DecodeSortKeyBind(ClientContext &context, ScalarFunctio throw BinderException("sort_key must be either BIGINT or BLOB, got %s instead", sort_key_arg.return_type.ToString()); } - bound_function.return_type = LogicalType::STRUCT(std::move(children)); + 
bound_function.SetReturnType(LogicalType::STRUCT(std::move(children))); return std::move(result); } @@ -1156,11 +1158,13 @@ void DecodeSortKeyRecursive(DecodeSortKeyData decode_data[], DecodeSortKeyVector } // namespace -void CreateSortKeyHelpers::DecodeSortKey(string_t sort_key, Vector &result, idx_t result_idx, - OrderModifiers modifiers) { +idx_t CreateSortKeyHelpers::DecodeSortKey(string_t sort_key, Vector &result, idx_t result_idx, + OrderModifiers modifiers) { DecodeSortKeyVectorData sort_key_data(result.GetType(), modifiers); DecodeSortKeyData decode_data(sort_key); DecodeSortKeyRecursive(&decode_data, sort_key_data, result, result_idx, 1); + + return decode_data.position; } void CreateSortKeyHelpers::DecodeSortKey(string_t sort_key, DataChunk &result, idx_t result_idx, @@ -1242,7 +1246,7 @@ ScalarFunction CreateSortKeyFun::GetFunction() { ScalarFunction sort_key_function("create_sort_key", {LogicalType::ANY}, LogicalType::BLOB, CreateSortKeyFunction, CreateSortKeyBind); sort_key_function.varargs = LogicalType::ANY; - sort_key_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + sort_key_function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return sort_key_function; } diff --git a/src/duckdb/src/function/scalar/date/strftime.cpp b/src/duckdb/src/function/scalar/date/strftime.cpp index 66a044f34..bdf2931b8 100644 --- a/src/duckdb/src/function/scalar/date/strftime.cpp +++ b/src/duckdb/src/function/scalar/date/strftime.cpp @@ -148,7 +148,6 @@ inline bool StrpTimeTryResult(StrpTimeFormat &format, string_t &input, timestamp } struct StrpTimeFunction { - template static void Parse(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); @@ -225,9 +224,9 @@ struct StrpTimeFunction { error); } if (format.HasFormatSpecifier(StrTimeSpecifier::UTC_OFFSET)) { - bound_function.return_type = LogicalType::TIMESTAMP_TZ; + bound_function.SetReturnType(LogicalType::TIMESTAMP_TZ); } else if (format.HasFormatSpecifier(StrTimeSpecifier::NANOSECOND_PADDED)) { - bound_function.return_type = LogicalType::TIMESTAMP_NS; + bound_function.SetReturnType(LogicalType::TIMESTAMP_NS); if (bound_function.name == "strptime") { bound_function.function = Parse; } else { @@ -261,11 +260,11 @@ struct StrpTimeFunction { if (has_offset) { // If any format has UTC offsets, then we have to produce TSTZ - bound_function.return_type = LogicalType::TIMESTAMP_TZ; + bound_function.SetReturnType(LogicalType::TIMESTAMP_TZ); } else if (has_nanos) { // If any format has nanoseconds, then we have to produce TSNS // unless there is an offset, in which case we produce - bound_function.return_type = LogicalType::TIMESTAMP_NS; + bound_function.SetReturnType(LogicalType::TIMESTAMP_NS); if (bound_function.name == "strptime") { bound_function.function = Parse; } else { @@ -304,14 +303,14 @@ ScalarFunctionSet StrpTimeFun::GetFunctions() { const auto list_type = LogicalType::LIST(LogicalType::VARCHAR); auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::TIMESTAMP, StrpTimeFunction::Parse, StrpTimeFunction::Bind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - BaseScalarFunction::SetReturnsError(fun); + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + fun.SetFallible(); strptime.AddFunction(fun); fun = ScalarFunction({LogicalType::VARCHAR, list_type}, LogicalType::TIMESTAMP, StrpTimeFunction::Parse, StrpTimeFunction::Bind); - BaseScalarFunction::SetReturnsError(fun); - fun.null_handling = 
FunctionNullHandling::SPECIAL_HANDLING; + fun.SetFallible(); + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); strptime.AddFunction(fun); return strptime; } @@ -322,12 +321,12 @@ ScalarFunctionSet TryStrpTimeFun::GetFunctions() { const auto list_type = LogicalType::LIST(LogicalType::VARCHAR); auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::TIMESTAMP, StrpTimeFunction::TryParse, StrpTimeFunction::Bind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); try_strptime.AddFunction(fun); fun = ScalarFunction({LogicalType::VARCHAR, list_type}, LogicalType::TIMESTAMP, StrpTimeFunction::TryParse, StrpTimeFunction::Bind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); try_strptime.AddFunction(fun); return try_strptime; diff --git a/src/duckdb/src/function/scalar/generic/constant_or_null.cpp b/src/duckdb/src/function/scalar/generic/constant_or_null.cpp index c5c4307bd..533aef715 100644 --- a/src/duckdb/src/function/scalar/generic/constant_or_null.cpp +++ b/src/duckdb/src/function/scalar/generic/constant_or_null.cpp @@ -81,7 +81,7 @@ unique_ptr ConstantOrNullBind(ClientContext &context, ScalarFuncti } D_ASSERT(arguments.size() >= 2); auto value = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); - bound_function.return_type = arguments[0]->return_type; + bound_function.SetReturnType(arguments[0]->return_type); return make_uniq(std::move(value)); } diff --git a/src/duckdb/src/function/scalar/generic/error.cpp b/src/duckdb/src/function/scalar/generic/error.cpp index 30d2a5f13..f2847786c 100644 --- a/src/duckdb/src/function/scalar/generic/error.cpp +++ b/src/duckdb/src/function/scalar/generic/error.cpp @@ -26,8 +26,8 @@ static void ErrorFunction(DataChunk &args, ExpressionState &state, Vector &resul ScalarFunction ErrorFun::GetFunction() { auto fun = ScalarFunction("error", {LogicalType::VARCHAR}, LogicalType::SQLNULL, ErrorFunction); // Set the function with side effects to avoid the optimization. 
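// --- Illustrative sketch (not part of the patch) ---
// The strptime bind above selects its return type from the format specifiers:
// a UTC offset forces TIMESTAMP_TZ (even when nanoseconds are also present),
// nanosecond padding alone forces TIMESTAMP_NS, and everything else stays
// TIMESTAMP. The helper name below is hypothetical.
static duckdb::LogicalType SelectStrpTimeReturnType(bool has_utc_offset, bool has_nanos) {
	using namespace duckdb;
	if (has_utc_offset) {
		return LogicalType::TIMESTAMP_TZ;
	}
	if (has_nanos) {
		return LogicalType::TIMESTAMP_NS;
	}
	return LogicalType::TIMESTAMP;
}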
- fun.stability = FunctionStability::VOLATILE; - BaseScalarFunction::SetReturnsError(fun); + fun.SetVolatile(); + fun.SetFallible(); return fun; } diff --git a/src/duckdb/src/function/scalar/generic/getvariable.cpp b/src/duckdb/src/function/scalar/generic/getvariable.cpp index 52c63488f..1f1975892 100644 --- a/src/duckdb/src/function/scalar/generic/getvariable.cpp +++ b/src/duckdb/src/function/scalar/generic/getvariable.cpp @@ -36,7 +36,7 @@ unique_ptr GetVariableBind(ClientContext &context, ScalarFunction if (!variable_name.IsNull()) { ClientConfig::GetConfig(context).GetUserVariable(variable_name.ToString(), value); } - function.return_type = value.type(); + function.SetReturnType(value.type()); return make_uniq(std::move(value)); } diff --git a/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp b/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp new file mode 100644 index 000000000..18a259460 --- /dev/null +++ b/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp @@ -0,0 +1,65 @@ +#include "duckdb/function/scalar/geometry_functions.hpp" +#include "duckdb/common/types/geometry.hpp" +#include "duckdb/common/vector_operations/binary_executor.hpp" + +namespace duckdb { + +static void FromWKBFunction(DataChunk &input, ExpressionState &state, Vector &result) { + Geometry::FromBinary(input.data[0], result, input.size(), true); +} + +ScalarFunction StGeomfromwkbFun::GetFunction() { + ScalarFunction function({LogicalType::BLOB}, LogicalType::GEOMETRY(), FromWKBFunction); + return function; +} + +static void ToWKBFunction(DataChunk &input, ExpressionState &state, Vector &result) { + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](const string_t &geom) { + // TODO: convert to internal representation + return geom; + }); + // Add a heap reference to the input WKB to prevent it from being freed + StringVector::AddHeapReference(input.data[0], result); +} + +ScalarFunction StAswkbFun::GetFunction() { + ScalarFunction function({LogicalType::GEOMETRY()}, LogicalType::BLOB, ToWKBFunction); + return function; +} + +static void ToWKTFunction(DataChunk &input, ExpressionState &state, Vector &result) { + UnaryExecutor::Execute(input.data[0], result, input.size(), + [&](const string_t &geom) { return Geometry::ToString(result, geom); }); +} + +ScalarFunction StAstextFun::GetFunction() { + ScalarFunction function({LogicalType::GEOMETRY()}, LogicalType::VARCHAR, ToWKTFunction); + return function; +} + +static void IntersectsExtentFunction(DataChunk &input, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute( + input.data[0], input.data[1], result, input.size(), [](const string_t &lhs_geom, const string_t &rhs_geom) { + auto lhs_extent = GeometryExtent::Empty(); + auto rhs_extent = GeometryExtent::Empty(); + + const auto lhs_is_empty = Geometry::GetExtent(lhs_geom, lhs_extent) == 0; + const auto rhs_is_empty = Geometry::GetExtent(rhs_geom, rhs_extent) == 0; + + if (lhs_is_empty || rhs_is_empty) { + // One of the geometries is empty + return false; + } + + // Don't take Z and M into account for intersection test + return lhs_extent.IntersectsXY(rhs_extent); + }); +} + +ScalarFunction StIntersectsExtentFun::GetFunction() { + ScalarFunction function({LogicalType::GEOMETRY(), LogicalType::GEOMETRY()}, LogicalType::BOOLEAN, + IntersectsExtentFunction); + return function; +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/list/contains_or_position.cpp b/src/duckdb/src/function/scalar/list/contains_or_position.cpp index 
064bd4b00..bd4a2de51 100644 --- a/src/duckdb/src/function/scalar/list/contains_or_position.cpp +++ b/src/duckdb/src/function/scalar/list/contains_or_position.cpp @@ -34,7 +34,7 @@ ScalarFunction ListContainsFun::GetFunction() { ScalarFunction ListPositionFun::GetFunction() { auto fun = ScalarFunction({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::TEMPLATE("T")}, LogicalType::INTEGER, ListSearchFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/src/function/scalar/list/list_extract.cpp b/src/duckdb/src/function/scalar/list/list_extract.cpp index fd79249d9..d4ed220dd 100644 --- a/src/duckdb/src/function/scalar/list/list_extract.cpp +++ b/src/duckdb/src/function/scalar/list/list_extract.cpp @@ -157,8 +157,8 @@ ScalarFunctionSet ListExtractFun::GetFunctions() { LogicalType::TEMPLATE("T"), ListExtractFunction, ListExtractBind, nullptr, ListExtractStats); ScalarFunction sfun({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, ListExtractFunction); - BaseScalarFunction::SetReturnsError(lfun); - BaseScalarFunction::SetReturnsError(sfun); + lfun.SetFallible(); + sfun.SetFallible(); list_extract_set.AddFunction(lfun); list_extract_set.AddFunction(sfun); return list_extract_set; diff --git a/src/duckdb/src/function/scalar/list/list_resize.cpp b/src/duckdb/src/function/scalar/list/list_resize.cpp index d159a7204..19fd149e3 100644 --- a/src/duckdb/src/function/scalar/list/list_resize.cpp +++ b/src/duckdb/src/function/scalar/list/list_resize.cpp @@ -8,7 +8,6 @@ namespace duckdb { static void ListResizeFunction(DataChunk &args, ExpressionState &, Vector &result) { - // Early-out, if the return value is a constant NULL. if (result.GetType().id() == LogicalTypeId::SQLNULL) { result.SetVectorType(VectorType::CONSTANT_VECTOR); @@ -63,7 +62,6 @@ static void ListResizeFunction(DataChunk &args, ExpressionState &, Vector &resul idx_t offset = 0; for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { - auto list_idx = lists_data.sel->get_index(row_idx); auto new_size_idx = new_sizes_data.sel->get_index(row_idx); @@ -134,14 +132,14 @@ static unique_ptr ListResizeBind(ClientContext &context, ScalarFun // Early-out, if the first argument is a constant NULL. if (arguments[0]->return_type == LogicalType::SQLNULL) { bound_function.arguments[0] = LogicalType::SQLNULL; - bound_function.return_type = LogicalType::SQLNULL; - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::SQLNULL); + return make_uniq(bound_function.GetReturnType()); } // Early-out, if the first argument is a prepared statement. if (arguments[0]->return_type == LogicalType::UNKNOWN) { - bound_function.return_type = arguments[0]->return_type; - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(arguments[0]->return_type); + return make_uniq(bound_function.GetReturnType()); } // Attempt implicit casting, if the default type does not match list the list child type. 
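The bind functions touched here share one pattern worth spelling out: the resolved type is written once through SetReturnType() and read back through GetReturnType() when the bind data is built, instead of assigning and re-reading bound_function.return_type. A minimal sketch under that assumption, modelled on ListResizeBind; the bind-data class name is assumed (its template argument is not visible above), so substitute whatever bind data the function actually uses:

static unique_ptr<FunctionData> ExampleListBind(ClientContext &context, ScalarFunction &bound_function,
                                                vector<unique_ptr<Expression>> &arguments) {
	// NULL and prepared-statement (UNKNOWN) inputs simply propagate, as in ListResizeBind.
	bound_function.SetReturnType(arguments[0]->return_type);
	// Read the stored type back rather than recomputing it when constructing the bind data.
	// VariableReturnBindData is an assumption here, not taken from the hunk above.
	return make_uniq<VariableReturnBindData>(bound_function.GetReturnType());
}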
@@ -151,19 +149,19 @@ static unique_ptr ListResizeBind(ClientContext &context, ScalarFun bound_function.arguments[2] = ListType::GetChildType(arguments[0]->return_type); } - bound_function.return_type = arguments[0]->return_type; - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(arguments[0]->return_type); + return make_uniq(bound_function.GetReturnType()); } ScalarFunctionSet ListResizeFun::GetFunctions() { ScalarFunction simple_fun({LogicalType::LIST(LogicalTypeId::ANY), LogicalTypeId::ANY}, LogicalType::LIST(LogicalTypeId::ANY), ListResizeFunction, ListResizeBind); - simple_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - BaseScalarFunction::SetReturnsError(simple_fun); + simple_fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + simple_fun.SetFallible(); ScalarFunction default_value_fun({LogicalType::LIST(LogicalTypeId::ANY), LogicalTypeId::ANY, LogicalTypeId::ANY}, LogicalType::LIST(LogicalTypeId::ANY), ListResizeFunction, ListResizeBind); - default_value_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - BaseScalarFunction::SetReturnsError(default_value_fun); + default_value_fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + default_value_fun.SetFallible(); ScalarFunctionSet list_resize_set("list_resize"); list_resize_set.AddFunction(simple_fun); list_resize_set.AddFunction(default_value_fun); diff --git a/src/duckdb/src/function/scalar/list/list_zip.cpp b/src/duckdb/src/function/scalar/list/list_zip.cpp index ef39a989d..2f83b61d6 100644 --- a/src/duckdb/src/function/scalar/list/list_zip.cpp +++ b/src/duckdb/src/function/scalar/list/list_zip.cpp @@ -155,15 +155,14 @@ static unique_ptr ListZipBind(ClientContext &context, ScalarFuncti throw BinderException("Parameter type needs to be List"); } } - bound_function.return_type = LogicalType::LIST(LogicalType::STRUCT(struct_children)); - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::LIST(LogicalType::STRUCT(struct_children))); + return make_uniq(bound_function.GetReturnType()); } ScalarFunction ListZipFun::GetFunction() { - auto fun = ScalarFunction({}, LogicalType::LIST(LogicalTypeId::STRUCT), ListZipFunction, ListZipBind); fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/src/function/scalar/nested_functions.cpp b/src/duckdb/src/function/scalar/nested_functions.cpp index 2d5359c4e..b09f04275 100644 --- a/src/duckdb/src/function/scalar/nested_functions.cpp +++ b/src/duckdb/src/function/scalar/nested_functions.cpp @@ -3,21 +3,22 @@ namespace duckdb { void MapUtil::ReinterpretMap(Vector &result, Vector &input, idx_t count) { + // Copy the list size + const auto list_size = ListVector::GetListSize(input); + ListVector::SetListSize(result, list_size); + UnifiedVectorFormat input_data; input.ToUnifiedFormat(count, input_data); + // Copy the list validity FlatVector::SetValidity(result, input_data.validity); // Copy the struct validity UnifiedVectorFormat input_struct_data; - ListVector::GetEntry(input).ToUnifiedFormat(count, input_struct_data); + ListVector::GetEntry(input).ToUnifiedFormat(list_size, input_struct_data); auto &result_struct = ListVector::GetEntry(result); FlatVector::SetValidity(result_struct, input_struct_data.validity); - // Copy the list size - auto list_size = ListVector::GetListSize(input); - ListVector::SetListSize(result, list_size); - // Copy the list buffer 
(the list_entry_t data) result.CopyBuffer(input); diff --git a/src/duckdb/src/function/scalar/operator/arithmetic.cpp b/src/duckdb/src/function/scalar/operator/arithmetic.cpp index 82cd9b5b7..2440f7560 100644 --- a/src/duckdb/src/function/scalar/operator/arithmetic.cpp +++ b/src/duckdb/src/function/scalar/operator/arithmetic.cpp @@ -239,7 +239,7 @@ unique_ptr BindDecimalArithmetic(ClientContext &conte bound_function.arguments[i] = result_type; } } - bound_function.return_type = result_type; + bound_function.SetReturnType(result_type); return bind_data; } @@ -249,7 +249,7 @@ unique_ptr BindDecimalAddSubtract(ClientContext &context, ScalarFu auto bind_data = BindDecimalArithmetic(context, bound_function, arguments); // now select the physical function to execute - auto &result_type = bound_function.return_type; + auto &result_type = bound_function.GetReturnType(); if (bind_data->check_overflow) { bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); } else { @@ -270,14 +270,13 @@ void SerializeDecimalArithmetic(Serializer &serializer, const optional_ptrCast(); serializer.WriteProperty(100, "check_overflow", bind_data.check_overflow); - serializer.WriteProperty(101, "return_type", function.return_type); + serializer.WriteProperty(101, "return_type", function.GetReturnType()); serializer.WriteProperty(102, "arguments", function.arguments); } // TODO this is partially duplicated from the bind template unique_ptr DeserializeDecimalArithmetic(Deserializer &deserializer, ScalarFunction &bound_function) { - // // re-change the function pointers auto check_overflow = deserializer.ReadProperty(100, "check_overflow"); auto return_type = deserializer.ReadProperty(101, "return_type"); @@ -288,7 +287,7 @@ unique_ptr DeserializeDecimalArithmetic(Deserializer &deserializer bound_function.function = GetScalarBinaryFunction(return_type.InternalType()); } bound_function.statistics = nullptr; // TODO we likely dont want to do stats prop again - bound_function.return_type = return_type; + bound_function.SetReturnType(return_type); bound_function.arguments = arguments; auto bind_data = make_uniq(); @@ -298,7 +297,7 @@ unique_ptr DeserializeDecimalArithmetic(Deserializer &deserializer unique_ptr NopDecimalBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - bound_function.return_type = arguments[0]->return_type; + bound_function.SetReturnType(arguments[0]->return_type); bound_function.arguments[0] = arguments[0]->return_type; return nullptr; } @@ -353,7 +352,7 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi if (left_type.id() == LogicalTypeId::DECIMAL) { auto function = ScalarFunction("+", {left_type, right_type}, left_type, nullptr, BindDecimalAddSubtract); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); function.serialize = SerializeDecimalArithmetic; function.deserialize = DeserializeDecimalArithmetic; return function; @@ -362,12 +361,12 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi GetScalarIntegerFunction(left_type.InternalType()), nullptr, nullptr, PropagateNumericStats); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else { ScalarFunction function("+", {left_type, right_type}, left_type, GetScalarBinaryFunction(left_type.InternalType())); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } } @@ -376,7 +375,7 @@ ScalarFunction AddFunction::GetFunction(const 
LogicalType &left_type, const Logi case LogicalTypeId::BIGNUM: if (right_type.id() == LogicalTypeId::BIGNUM) { ScalarFunction function("+", {left_type, right_type}, LogicalType::BIGNUM, BignumAdd); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -385,22 +384,22 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi if (right_type.id() == LogicalTypeId::INTEGER) { ScalarFunction function("+", {left_type, right_type}, LogicalType::DATE, ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::INTERVAL) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::TIME) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::TIME_TZ) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP_TZ, ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -408,7 +407,7 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi if (right_type.id() == LogicalTypeId::DATE) { ScalarFunction function("+", {left_type, right_type}, right_type, ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -416,28 +415,28 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi if (right_type.id() == LogicalTypeId::INTERVAL) { ScalarFunction function("+", {left_type, right_type}, LogicalType::INTERVAL, ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::DATE) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::TIME) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIME, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::TIME_TZ) { ScalarFunction function( "+", {left_type, right_type}, LogicalType::TIME_TZ, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::TIMESTAMP) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -445,12 +444,12 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi if (right_type.id() == LogicalTypeId::INTERVAL) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIME, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == 
LogicalTypeId::DATE) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -458,13 +457,13 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi if (right_type.id() == LogicalTypeId::DATE) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP_TZ, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::INTERVAL) { ScalarFunction function( "+", {left_type, right_type}, LogicalType::TIME_TZ, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -472,7 +471,7 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi if (right_type.id() == LogicalTypeId::INTERVAL) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -589,7 +588,6 @@ struct DecimalNegateBindData : public FunctionData { unique_ptr DecimalNegateBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - auto bind_data = make_uniq(); auto &decimal_type = arguments[0]->return_type; @@ -606,7 +604,7 @@ unique_ptr DecimalNegateBind(ClientContext &context, ScalarFunctio } decimal_type.Verify(); bound_function.arguments[0] = decimal_type; - bound_function.return_type = decimal_type; + bound_function.SetReturnType(decimal_type); return nullptr; } @@ -672,7 +670,7 @@ unique_ptr NegateBindStatistics(ClientContext &context, Function ScalarFunction SubtractFunction::GetFunction(const LogicalType &type) { if (type.id() == LogicalTypeId::INTERVAL) { ScalarFunction func("-", {type}, type, ScalarFunction::UnaryFunction); - ScalarFunction::SetReturnsError(func); + func.SetFallible(); return func; } else if (type.id() == LogicalTypeId::DECIMAL) { ScalarFunction func("-", {type}, type, nullptr, DecimalNegateBind, nullptr, NegateBindStatistics); @@ -684,7 +682,7 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &type) { D_ASSERT(type.IsNumeric()); ScalarFunction func("-", {type}, type, ScalarFunction::GetScalarUnaryFunction(type), nullptr, nullptr, NegateBindStatistics); - ScalarFunction::SetReturnsError(func); + func.SetFallible(); return func; } } @@ -694,7 +692,7 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &left_type, const if (left_type.id() == LogicalTypeId::DECIMAL) { ScalarFunction function("-", {left_type, right_type}, left_type, nullptr, BindDecimalAddSubtract); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); function.serialize = SerializeDecimalArithmetic; function.deserialize = DeserializeDecimalArithmetic; return function; @@ -703,13 +701,13 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &left_type, const "-", {left_type, right_type}, left_type, GetScalarIntegerFunction(left_type.InternalType()), nullptr, nullptr, PropagateNumericStats); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else { ScalarFunction function("-", {left_type, right_type}, left_type, GetScalarBinaryFunction(left_type.InternalType())); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } } @@ -723,18 +721,18 
@@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &left_type, const if (right_type.id() == LogicalTypeId::DATE) { ScalarFunction function("-", {left_type, right_type}, LogicalType::BIGINT, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::INTEGER) { ScalarFunction function("-", {left_type, right_type}, LogicalType::DATE, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::INTERVAL) { ScalarFunction function("-", {left_type, right_type}, LogicalType::TIMESTAMP, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -743,13 +741,13 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &left_type, const ScalarFunction function( "-", {left_type, right_type}, LogicalType::INTERVAL, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::INTERVAL) { ScalarFunction function( "-", {left_type, right_type}, LogicalType::TIMESTAMP, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -758,7 +756,7 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &left_type, const ScalarFunction function( "-", {left_type, right_type}, LogicalType::INTERVAL, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -766,7 +764,7 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &left_type, const if (right_type.id() == LogicalTypeId::INTERVAL) { ScalarFunction function("-", {left_type, right_type}, LogicalType::TIME, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -775,7 +773,7 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &left_type, const ScalarFunction function( "-", {left_type, right_type}, LogicalType::TIME_TZ, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -861,7 +859,6 @@ struct MultiplyPropagateStatistics { unique_ptr BindDecimalMultiply(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - auto bind_data = make_uniq(); uint8_t result_width = 0, result_scale = 0; @@ -915,7 +912,7 @@ unique_ptr BindDecimalMultiply(ClientContext &context, ScalarFunct } } result_type.Verify(); - bound_function.return_type = result_type; + bound_function.SetReturnType(result_type); // now select the physical function to execute if (bind_data->check_overflow) { bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); @@ -962,7 +959,7 @@ ScalarFunctionSet OperatorMultiplyFun::GetFunctions() { ScalarFunction({LogicalType::INTERVAL, LogicalType::BIGINT}, LogicalType::INTERVAL, ScalarFunction::BinaryFunction)); for (auto &func : multiply.functions) { - ScalarFunction::SetReturnsError(func); + func.SetFallible(); } return multiply; @@ -1096,9 +1093,9 @@ template unique_ptr BindBinaryFloatingPoint(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { if (DBConfig::GetSetting(context)) { - bound_function.function = 
GetScalarBinaryFunction(bound_function.return_type.InternalType()); + bound_function.function = GetScalarBinaryFunction(bound_function.GetReturnType().InternalType()); } else { - bound_function.function = GetBinaryFunctionIgnoreZero(bound_function.return_type.InternalType()); + bound_function.function = GetBinaryFunctionIgnoreZero(bound_function.GetReturnType().InternalType()); } return nullptr; } @@ -1114,7 +1111,7 @@ ScalarFunctionSet OperatorFloatDivideFun::GetFunctions() { ScalarFunction({LogicalType::INTERVAL, LogicalType::DOUBLE}, LogicalType::INTERVAL, BinaryScalarFunctionIgnoreZero)); for (auto &func : fp_divide.functions) { - ScalarFunction::SetReturnsError(func); + func.SetFallible(); } return fp_divide; } @@ -1130,7 +1127,7 @@ ScalarFunctionSet OperatorIntegerDivideFun::GetFunctions() { } } for (auto &func : full_divide.functions) { - ScalarFunction::SetReturnsError(func); + func.SetFallible(); } return full_divide; } @@ -1148,9 +1145,9 @@ static unique_ptr BindDecimalModulo(ClientContext &context, Scalar for (auto &arg : bound_function.arguments) { arg = LogicalType::DOUBLE; } - bound_function.return_type = LogicalType::DOUBLE; + bound_function.SetReturnType(LogicalType::DOUBLE); } - auto &result_type = bound_function.return_type; + auto &result_type = bound_function.GetReturnType(); bound_function.function = GetBinaryFunctionIgnoreZero(result_type.InternalType()); return std::move(bind_data); } @@ -1188,7 +1185,7 @@ ScalarFunctionSet OperatorModuloFun::GetFunctions() { } } for (auto &func : modulo.functions) { - ScalarFunction::SetReturnsError(func); + func.SetFallible(); } return modulo; @@ -1220,7 +1217,7 @@ hugeint_t InterpolateOperator::Operation(const hugeint_t &lo, const double d, co template <> uhugeint_t InterpolateOperator::Operation(const uhugeint_t &lo, const double d, const uhugeint_t &hi) { - return Hugeint::Convert(Operation(Uhugeint::Cast(lo), d, Uhugeint::Cast(hi))); + return Uhugeint::Convert(Operation(Uhugeint::Cast(lo), d, Uhugeint::Cast(hi))); } static interval_t MultiplyByDouble(const interval_t &i, const double &d) { // NOLINT diff --git a/src/duckdb/src/function/scalar/sequence/nextval.cpp b/src/duckdb/src/function/scalar/sequence/nextval.cpp index c053bb7f6..1defb3653 100644 --- a/src/duckdb/src/function/scalar/sequence/nextval.cpp +++ b/src/duckdb/src/function/scalar/sequence/nextval.cpp @@ -141,12 +141,12 @@ ScalarFunction NextvalFun::GetFunction() { ScalarFunction next_val("nextval", {LogicalType::VARCHAR}, LogicalType::BIGINT, NextValFunction, nullptr, nullptr); next_val.bind_extended = NextValBind; - next_val.stability = FunctionStability::VOLATILE; next_val.serialize = Serialize; next_val.deserialize = Deserialize; next_val.get_modified_databases = NextValModifiedDatabases; next_val.init_local_state = NextValLocalFunction; - BaseScalarFunction::SetReturnsError(next_val); + next_val.SetVolatile(); + next_val.SetFallible(); return next_val; } @@ -154,11 +154,11 @@ ScalarFunction CurrvalFun::GetFunction() { ScalarFunction curr_val("currval", {LogicalType::VARCHAR}, LogicalType::BIGINT, NextValFunction, nullptr, nullptr); curr_val.bind_extended = NextValBind; - curr_val.stability = FunctionStability::VOLATILE; curr_val.serialize = Serialize; curr_val.deserialize = Deserialize; curr_val.init_local_state = NextValLocalFunction; - BaseScalarFunction::SetReturnsError(curr_val); + curr_val.SetVolatile(); + curr_val.SetFallible(); return curr_val; } diff --git a/src/duckdb/src/function/scalar/string/concat.cpp 
b/src/duckdb/src/function/scalar/string/concat.cpp index cae184de9..97a74cebe 100644 --- a/src/duckdb/src/function/scalar/string/concat.cpp +++ b/src/duckdb/src/function/scalar/string/concat.cpp @@ -208,9 +208,13 @@ void ListConcatFunction(DataChunk &args, ExpressionState &state, Vector &result, void ConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); auto &info = func_expr.bind_info->Cast(); + if (info.return_type.id() == LogicalTypeId::SQLNULL) { + return; + } if (info.return_type.id() == LogicalTypeId::LIST) { return ListConcatFunction(args, state, result, info.is_operator); - } else if (info.is_operator) { + } + if (info.is_operator) { return ConcatOperator(args, state, result); } return StringConcatFunction(args, state, result); @@ -220,7 +224,7 @@ void SetArgumentType(ScalarFunction &bound_function, const LogicalType &type, bo if (is_operator) { bound_function.arguments[0] = type; bound_function.arguments[1] = type; - bound_function.return_type = type; + bound_function.SetReturnType(type); return; } @@ -228,7 +232,7 @@ void SetArgumentType(ScalarFunction &bound_function, const LogicalType &type, bo arg = type; } bound_function.varargs = type; - bound_function.return_type = type; + bound_function.SetReturnType(type); } unique_ptr BindListConcat(ClientContext &context, ScalarFunction &bound_function, @@ -277,17 +281,18 @@ unique_ptr BindListConcat(ClientContext &context, ScalarFunction & if (all_null) { // all arguments are NULL SetArgumentType(bound_function, LogicalTypeId::SQLNULL, is_operator); - return make_uniq(bound_function.return_type, is_operator); + return make_uniq(bound_function.GetReturnType(), is_operator); } auto list_type = LogicalType::LIST(child_type); SetArgumentType(bound_function, list_type, is_operator); - return make_uniq(bound_function.return_type, is_operator); + return make_uniq(bound_function.GetReturnType(), is_operator); } unique_ptr BindConcatFunctionInternal(ClientContext &context, ScalarFunction &bound_function, vector> &arguments, bool is_operator) { bool list_concat = false; + bool all_null = true; // blob concat is only supported for the concat operator - regular concat converts to varchar bool all_blob = is_operator ? true : false; for (auto &arg : arguments) { @@ -300,15 +305,18 @@ unique_ptr BindConcatFunctionInternal(ClientContext &context, Scal if (arg->return_type.id() != LogicalTypeId::BLOB) { all_blob = false; } + if (arg->return_type.id() != LogicalTypeId::SQLNULL) { + all_null = false; + } } - if (list_concat) { + if (list_concat || all_null) { return BindListConcat(context, bound_function, arguments, is_operator); } auto return_type = all_blob ? 
LogicalType::BLOB : LogicalType::VARCHAR; // we can now assume that the input is a string or castable to a string SetArgumentType(bound_function, return_type, is_operator); - return make_uniq(bound_function.return_type, is_operator); + return make_uniq(bound_function.GetReturnType(), is_operator); } unique_ptr BindConcatFunction(ClientContext &context, ScalarFunction &bound_function, @@ -337,7 +345,7 @@ ScalarFunction ListConcatFun::GetFunction() { auto fun = ScalarFunction({}, LogicalType::LIST(LogicalType::ANY), ConcatFunction, BindConcatFunction, nullptr, ListConcatStats); fun.varargs = LogicalType::LIST(LogicalType::ANY); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } @@ -353,7 +361,7 @@ ScalarFunction ConcatFun::GetFunction() { ScalarFunction concat = ScalarFunction("concat", {LogicalType::ANY}, LogicalType::ANY, ConcatFunction, BindConcatFunction); concat.varargs = LogicalType::ANY; - concat.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + concat.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return concat; } diff --git a/src/duckdb/src/function/scalar/string/concat_ws.cpp b/src/duckdb/src/function/scalar/string/concat_ws.cpp index ebc1e8b3a..9b67878cd 100644 --- a/src/duckdb/src/function/scalar/string/concat_ws.cpp +++ b/src/duckdb/src/function/scalar/string/concat_ws.cpp @@ -142,7 +142,7 @@ ScalarFunction ConcatWsFun::GetFunction() { ScalarFunction concat_ws = ScalarFunction("concat_ws", {LogicalType::VARCHAR, LogicalType::ANY}, LogicalType::VARCHAR, ConcatWSFunction, BindConcatWSFunction); concat_ws.varargs = LogicalType::ANY; - concat_ws.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + concat_ws.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return ScalarFunction(concat_ws); } diff --git a/src/duckdb/src/function/scalar/string/contains.cpp b/src/duckdb/src/function/scalar/string/contains.cpp index fb496b1fd..95c53ff01 100644 --- a/src/duckdb/src/function/scalar/string/contains.cpp +++ b/src/duckdb/src/function/scalar/string/contains.cpp @@ -121,7 +121,7 @@ idx_t FindStrInStr(const string_t &haystack_s, const string_t &needle_s) { ScalarFunction GetStringContains() { ScalarFunction string_fun("contains", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, ScalarFunction::BinaryFunction); - string_fun.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + string_fun.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return string_fun; } diff --git a/src/duckdb/src/function/scalar/string/length.cpp b/src/duckdb/src/function/scalar/string/length.cpp index 66542af3c..2f0792a65 100644 --- a/src/duckdb/src/function/scalar/string/length.cpp +++ b/src/duckdb/src/function/scalar/string/length.cpp @@ -248,7 +248,7 @@ ScalarFunctionSet ArrayLengthFun::GetFunctions() { array_length.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::BIGINT}, LogicalType::BIGINT, nullptr, ArrayOrListLengthBinaryBind)); for (auto &func : array_length.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return (array_length); } diff --git a/src/duckdb/src/function/scalar/string/like.cpp b/src/duckdb/src/function/scalar/string/like.cpp index ba974f9d2..9c65cae17 100644 --- a/src/duckdb/src/function/scalar/string/like.cpp +++ b/src/duckdb/src/function/scalar/string/like.cpp @@ -524,14 +524,14 @@ void RegularLikeFunction(DataChunk &input, ExpressionState &state, Vector 
&resul ScalarFunction NotLikeFun::GetFunction() { ScalarFunction not_like("!~~", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, RegularLikeFunction, LikeBindFunction); - not_like.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + not_like.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return not_like; } ScalarFunction GlobPatternFun::GetFunction() { ScalarFunction glob("~~~", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, ScalarFunction::BinaryFunction); - glob.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + glob.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return glob; } @@ -539,7 +539,7 @@ ScalarFunction ILikeFun::GetFunction() { ScalarFunction ilike("~~*", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, ScalarFunction::BinaryFunction, nullptr, nullptr, ILikePropagateStats); - ilike.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + ilike.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return ilike; } @@ -547,14 +547,14 @@ ScalarFunction NotILikeFun::GetFunction() { ScalarFunction not_ilike("!~~*", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, ScalarFunction::BinaryFunction, nullptr, nullptr, ILikePropagateStats); - not_ilike.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + not_ilike.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return not_ilike; } ScalarFunction LikeFun::GetFunction() { ScalarFunction like("~~", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, RegularLikeFunction, LikeBindFunction); - like.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + like.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return like; } @@ -562,14 +562,14 @@ ScalarFunction NotLikeEscapeFun::GetFunction() { ScalarFunction not_like_escape("not_like_escape", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, LikeEscapeFunction); - not_like_escape.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + not_like_escape.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return not_like_escape; } ScalarFunction IlikeEscapeFun::GetFunction() { ScalarFunction ilike_escape("ilike_escape", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, LikeEscapeFunction); - ilike_escape.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + ilike_escape.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return ilike_escape; } @@ -577,13 +577,13 @@ ScalarFunction NotIlikeEscapeFun::GetFunction() { ScalarFunction not_ilike_escape("not_ilike_escape", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, LikeEscapeFunction); - not_ilike_escape.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + not_ilike_escape.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return not_ilike_escape; } ScalarFunction LikeEscapeFun::GetFunction() { ScalarFunction like_escape("like_escape", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, LikeEscapeFunction); - like_escape.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + 
like_escape.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return like_escape; } diff --git a/src/duckdb/src/function/scalar/string/regexp.cpp b/src/duckdb/src/function/scalar/string/regexp.cpp index f91121a07..347fdfaa0 100644 --- a/src/duckdb/src/function/scalar/string/regexp.cpp +++ b/src/duckdb/src/function/scalar/string/regexp.cpp @@ -245,6 +245,11 @@ static void RegexExtractFunction(DataChunk &args, ExpressionState &state, Vector // Regexp Extract Struct //===--------------------------------------------------------------------===// static void RegexExtractStructFunction(DataChunk &args, ExpressionState &state, Vector &result) { + // This function assumes a constant pre-compiled pattern stored in the local state. + // If a non-constant pattern reaches here it indicates a binder bug. Return a clean error instead of crashing. + if (!ExecuteFunctionState::GetFunctionState(state)) { + throw InternalException("REGEXP_EXTRACT struct variant executed without constant pattern state"); + } auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); const auto count = args.size(); @@ -346,32 +351,13 @@ static unique_ptr RegexExtractBind(ClientContext &context, ScalarF group_string = ""; } else if (group.type().id() == LogicalTypeId::LIST) { if (!constant_pattern) { - throw BinderException("%s with LIST requires a constant pattern", bound_function.name); - } - auto &list_children = ListValue::GetChildren(group); - if (list_children.empty()) { - throw BinderException("%s requires non-empty lists of capture names", bound_function.name); + throw BinderException("%s with LIST of group names requires a constant pattern", bound_function.name); } - case_insensitive_set_t name_collision_set; + vector dummy_names; // not reused after bind child_list_t struct_children; - for (const auto &child : list_children) { - if (child.IsNull()) { - throw BinderException("NULL group name in %s", bound_function.name); - } - const auto group_name = child.ToString(); - if (name_collision_set.find(group_name) != name_collision_set.end()) { - throw BinderException("Duplicate group name \"%s\" in %s", group_name, bound_function.name); - } - name_collision_set.insert(group_name); - struct_children.emplace_back(make_pair(group_name, LogicalType::VARCHAR)); - } - bound_function.return_type = LogicalType::STRUCT(struct_children); - - duckdb_re2::StringPiece constant_piece(constant_string.c_str(), constant_string.size()); - RE2 constant_pattern(constant_piece, options); - if (size_t(constant_pattern.NumberOfCapturingGroups()) < list_children.size()) { - throw BinderException("Not enough group names in %s", bound_function.name); - } + regexp_util::ParseGroupNameList(context, bound_function.name, *arguments[2], constant_string, options, + constant_pattern, dummy_names, struct_children); + bound_function.SetReturnType(LogicalType::STRUCT(struct_children)); } else { auto group_idx = group.GetValue(); if (group_idx < 0 || group_idx > 9) { @@ -409,7 +395,7 @@ ScalarFunctionSet RegexpMatchesFun::GetFunctions() { RegexpMatchesFunction, RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); for (auto &func : regexp_partial_match.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return (regexp_partial_match); } @@ -467,6 +453,19 @@ ScalarFunctionSet RegexpExtractAllFun::GetFunctions() { LogicalType::LIST(LogicalType::VARCHAR), RegexpExtractAll::Execute, 
RegexpExtractAll::Bind, nullptr, nullptr, RegexpExtractAll::InitLocalState, LogicalType::INVALID, FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); + // Struct multi-match variant(s): pattern must be constant due to bind-time struct shape inference + regexp_extract_all.AddFunction( + ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::VARCHAR)}, + LogicalType::LIST(LogicalType::VARCHAR), // temporary, replaced in bind + RegexpExtractAllStruct::Execute, RegexpExtractAllStruct::Bind, nullptr, nullptr, + RegexpExtractAllStruct::InitLocalState, LogicalType::INVALID, FunctionStability::CONSISTENT, + FunctionNullHandling::SPECIAL_HANDLING)); + regexp_extract_all.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::VARCHAR), LogicalType::VARCHAR}, + LogicalType::LIST(LogicalType::VARCHAR), // temporary, replaced in bind + RegexpExtractAllStruct::Execute, RegexpExtractAllStruct::Bind, nullptr, nullptr, + RegexpExtractAllStruct::InitLocalState, LogicalType::INVALID, FunctionStability::CONSISTENT, + FunctionNullHandling::SPECIAL_HANDLING)); return (regexp_extract_all); } diff --git a/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp b/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp index 144dcff03..151b7c599 100644 --- a/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp +++ b/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp @@ -4,6 +4,7 @@ #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/function/scalar/string_functions.hpp" #include "re2/re2.h" +#include "re2/stringpiece.h" namespace duckdb { @@ -21,10 +22,19 @@ RegexpExtractAll::InitLocalState(ExpressionState &state, const BoundFunctionExpr return nullptr; } +unique_ptr RegexpExtractAllStruct::InitLocalState(ExpressionState &state, + const BoundFunctionExpression &expr, + FunctionData *bind_data) { + auto &info = bind_data->Cast(); + if (info.constant_pattern) { + return make_uniq(info, true); + } + return nullptr; +} + // Forwards startpos automatically bool ExtractAll(duckdb_re2::StringPiece &input, duckdb_re2::RE2 &pattern, idx_t *startpos, duckdb_re2::StringPiece *groups, int ngroups) { - D_ASSERT(pattern.ok()); D_ASSERT(pattern.NumberOfCapturingGroups() == ngroups); @@ -33,13 +43,8 @@ bool ExtractAll(duckdb_re2::StringPiece &input, duckdb_re2::RE2 &pattern, idx_t } idx_t consumed = static_cast(groups[0].end() - (input.begin() + *startpos)); if (!consumed) { - // Empty match found, have to manually forward the input - // to avoid an infinite loop - // FIXME: support unicode characters - consumed++; - while (*startpos + consumed < input.length() && !IsCharacter(input[*startpos + consumed])) { - consumed++; - } + // Empty match: advance exactly one UTF-8 codepoint + consumed = regexp_util::AdvanceOneUTF8Basic(input, *startpos); } *startpos += consumed; return true; @@ -228,6 +233,136 @@ void RegexpExtractAll::Execute(DataChunk &args, ExpressionState &state, Vector & } } +static inline bool ExtractAllStruct(duckdb_re2::StringPiece &input, duckdb_re2::RE2 &re, idx_t &startpos, + duckdb_re2::StringPiece *groups, int provided_groups) { + D_ASSERT(re.ok()); + if (!re.Match(input, startpos, input.size(), re.UNANCHORED, groups, provided_groups + 1)) { + return false; + } + idx_t consumed = static_cast(groups[0].end() - (input.begin() + startpos)); + if (!consumed) { + consumed = regexp_util::AdvanceOneUTF8Basic(input, startpos); + 
} + startpos += consumed; + return true; +} + +static void ExtractStructAllSingleTuple(const string_t &string_val, duckdb_re2::RE2 &re, + vector &group_spans, + vector> &child_entries, Vector &result, idx_t row) { + const idx_t group_count = child_entries.size(); + auto list_entries = FlatVector::GetData(result); + idx_t current_list_size = ListVector::GetListSize(result); + list_entries[row].offset = current_list_size; + + auto input_piece = CreateStringPiece(string_val); + idx_t startpos = 0; + for (; ExtractAllStruct(input_piece, re, startpos, group_spans.data(), UnsafeNumericCast(group_count));) { + // Ensure capacity + if (current_list_size + 1 >= ListVector::GetListCapacity(result)) { + ListVector::Reserve(result, ListVector::GetListCapacity(result) * 2); + } + // Write each selected group + for (idx_t g = 0; g < group_count; g++) { + auto &child_vec = *child_entries[g]; + child_vec.SetVectorType(VectorType::FLAT_VECTOR); + auto cdata = FlatVector::GetData(child_vec); + auto &span = group_spans[g + 1]; + if (span.empty()) { + if (span.begin() == nullptr) { + // Unmatched optional group -> always NULL + FlatVector::Validity(child_vec).SetInvalid(current_list_size); + } + cdata[current_list_size] = string_t(string_val.GetData(), 0); + } else { + auto offset = span.begin() - string_val.GetData(); + cdata[current_list_size] = + string_t(string_val.GetData() + offset, UnsafeNumericCast(span.size())); + } + } + current_list_size++; + if (startpos > input_piece.size()) { + break; // empty match at end + } + } + list_entries[row].length = current_list_size - list_entries[row].offset; + ListVector::SetListSize(result, current_list_size); +} + +void RegexpExtractAllStruct::Execute(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + const auto &info = func_expr.bind_info->Cast(); + // Struct multi-match variant only supports constant pattern (enforced in Bind) + D_ASSERT(info.constant_pattern); + + // Expect arguments: string, pattern, list_of_group_names [, options] + auto &strings = args.data[0]; + + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + auto &struct_vector = ListVector::GetEntry(result); + D_ASSERT(struct_vector.GetType().id() == LogicalTypeId::STRUCT); + auto &child_entries = StructVector::GetEntries(struct_vector); + const idx_t group_count = child_entries.size(); + + // Reference original string buffer for zero-copy substring assignment + for (auto &child : child_entries) { + child->SetAuxiliary(strings.GetAuxiliary()); + child->SetVectorType(VectorType::FLAT_VECTOR); + } + + UnifiedVectorFormat strings_data; + strings.ToUnifiedFormat(args.size(), strings_data); + ListVector::Reserve(result, STANDARD_VECTOR_SIZE); + idx_t tuple_count = args.AllConstant() ? 
1 : args.size(); + + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + + auto &list_validity = FlatVector::Validity(result); + auto list_entries = FlatVector::GetData(result); + + vector group_spans(group_count + 1); + + for (idx_t row = 0; row < tuple_count; row++) { + auto sindex = strings_data.sel->get_index(row); + if (!strings_data.validity.RowIsValid(sindex)) { + list_entries[row].offset = ListVector::GetListSize(result); + list_entries[row].length = 0; + list_validity.SetInvalid(row); + continue; + } + auto &string_val = UnifiedVectorFormat::GetData(strings_data)[sindex]; + ExtractStructAllSingleTuple(string_val, lstate.constant_pattern, group_spans, child_entries, result, row); + } + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +unique_ptr RegexpExtractAllStruct::Bind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + // arguments: string, pattern, LIST group_names [, options] + if (arguments.size() < 3) { + throw BinderException("regexp_extract_all struct variant requires at least 3 arguments"); + } + duckdb_re2::RE2::Options options; + string constant_string; + bool constant_pattern = TryParseConstantPattern(context, *arguments[1], constant_string); + if (!constant_pattern) { + throw BinderException("%s with LIST requires a constant pattern", bound_function.name); + } + if (arguments.size() >= 4) { + ParseRegexOptions(context, *arguments[3], options); + } + options.set_log_errors(false); + vector group_names; + child_list_t struct_children; + regexp_util::ParseGroupNameList(context, bound_function.name, *arguments[2], constant_string, options, true, + group_names, struct_children); + bound_function.SetReturnType(LogicalType::LIST(LogicalType::STRUCT(struct_children))); + return make_uniq(options, std::move(constant_string), constant_pattern, + std::move(group_names)); +} + unique_ptr RegexpExtractAll::Bind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { D_ASSERT(arguments.size() >= 2); diff --git a/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp b/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp index 4e485195c..2bac42104 100644 --- a/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp +++ b/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp @@ -1,5 +1,7 @@ #include "duckdb/function/scalar/regexp.hpp" #include "duckdb/execution/expression_executor.hpp" +#include "re2/re2.h" +#include "re2/stringpiece.h" namespace duckdb { @@ -78,6 +80,76 @@ void ParseRegexOptions(ClientContext &context, Expression &expr, RE2::Options &t ParseRegexOptions(StringValue::Get(options_str), target, global_replace); } +void ParseGroupNameList(ClientContext &context, const string &function_name, Expression &group_expr, + const string &pattern_string, RE2::Options &options, bool require_constant_pattern, + vector &out_names, child_list_t &out_struct_children) { + if (group_expr.HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!group_expr.IsFoldable()) { + throw InvalidInputException("Group specification field must be a constant list"); + } + Value list_val = ExpressionExecutor::EvaluateScalar(context, group_expr); + if (list_val.IsNull() || list_val.type().id() != LogicalTypeId::LIST) { + throw BinderException("Group specification must be a non-NULL LIST"); + } + auto &children = ListValue::GetChildren(list_val); + if (children.empty()) { + throw BinderException("Group name list must be non-empty"); + } + 
case_insensitive_set_t name_set; + for (auto &child : children) { + if (child.IsNull()) { + throw BinderException("NULL group name in %s", function_name); + } + auto name = child.ToString(); + if (name_set.find(name) != name_set.end()) { + throw BinderException("Duplicate group name '%s' in %s", name, function_name); + } + name_set.insert(name); + out_names.push_back(name); + out_struct_children.emplace_back(make_pair(name, LogicalType::VARCHAR)); + } + if (require_constant_pattern) { + duckdb_re2::StringPiece const_piece(pattern_string.c_str(), pattern_string.size()); + RE2 constant_re(const_piece, options); + auto group_cnt = constant_re.NumberOfCapturingGroups(); + if (group_cnt == -1) { + throw BinderException("Pattern failed to parse: %s", constant_re.error()); + } + if ((idx_t)group_cnt < out_names.size()) { + throw BinderException("Not enough capturing groups (%d) for provided names (%llu)", group_cnt, + NumericCast(out_names.size())); + } + } +} + +// Advance exactly one UTF-8 codepoint starting at 'base'. Falls back to single byte on invalid lead. +// Does not do a full validation of UTF-8 sequence, assumes input is mostly valid UTF-8. +idx_t AdvanceOneUTF8Basic(const duckdb_re2::StringPiece &input, idx_t base) { + if (base >= input.length()) { + return 1; // Out of bounds, just advance one byte + } + unsigned char first = static_cast(input[base]); + idx_t char_len = 1; + if ((first & 0x80) == 0) { + char_len = 1; // ASCII + } else if ((first & 0xE0) == 0xC0) { + char_len = 2; + } else if ((first & 0xF0) == 0xE0) { + char_len = 3; + } else if ((first & 0xF8) == 0xF0) { + char_len = 4; + } else { + // This should be impossible since RE2 operates on codepoints + throw InternalException("Invalid UTF-8 lead byte in regexp_extract_all"); + } + if (base + char_len > input.length()) { + throw InternalException("Invalid UTF-8 sequence in regexp_extract_all"); + } + return char_len; +} + } // namespace regexp_util } // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/string_split.cpp b/src/duckdb/src/function/scalar/string/string_split.cpp index 070886d8c..2438dfe81 100644 --- a/src/duckdb/src/function/scalar/string/string_split.cpp +++ b/src/duckdb/src/function/scalar/string/string_split.cpp @@ -181,7 +181,7 @@ ScalarFunction StringSplitFun::GetFunction() { auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); ScalarFunction string_split({LogicalType::VARCHAR, LogicalType::VARCHAR}, varchar_list_type, StringSplitFunction); - string_split.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + string_split.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return string_split; } diff --git a/src/duckdb/src/function/scalar/string/substring.cpp b/src/duckdb/src/function/scalar/string/substring.cpp index b82e6871b..de85777a5 100644 --- a/src/duckdb/src/function/scalar/string/substring.cpp +++ b/src/duckdb/src/function/scalar/string/substring.cpp @@ -16,7 +16,6 @@ static const int64_t SUPPORTED_UPPER_BOUND = NumericLimits::Maximum(); static const int64_t SUPPORTED_LOWER_BOUND = -SUPPORTED_UPPER_BOUND - 1; static inline void AssertInSupportedRange(idx_t input_size, int64_t offset, int64_t length) { - if (input_size > (uint64_t)SUPPORTED_UPPER_BOUND) { throw OutOfRangeException("Substring input size is too large (> %d)", SUPPORTED_UPPER_BOUND); } diff --git a/src/duckdb/src/function/scalar/struct/remap_struct.cpp b/src/duckdb/src/function/scalar/struct/remap_struct.cpp index 136a89165..e926a7bec 100644 --- 
a/src/duckdb/src/function/scalar/struct/remap_struct.cpp +++ b/src/duckdb/src/function/scalar/struct/remap_struct.cpp @@ -11,6 +11,10 @@ namespace duckdb { namespace { +static bool IsRemappable(const LogicalType &type) { + return type.IsNested() && type.id() != LogicalTypeId::VARIANT; +} + struct RemapColumnInfo { optional_idx index; optional_idx default_index; @@ -230,7 +234,7 @@ void RemapStruct(Vector &input, Vector &default_vector, Vector &result, idx_t re void RemapNested(Vector &input, Vector &default_vector, Vector &result, idx_t result_size, const vector &remap_info) { auto &source_type = input.GetType(); - D_ASSERT(source_type.IsNested()); + D_ASSERT(IsRemappable(source_type)); switch (source_type.id()) { case LogicalTypeId::STRUCT: return RemapStruct(input, default_vector, result, result_size, remap_info); @@ -293,7 +297,7 @@ struct RemapIndex { RemapIndex index; index.index = idx; index.type = type; - if (type.IsNested()) { + if (IsRemappable(type)) { index.child_map = make_uniq>(GetMap(type)); } return index; @@ -344,8 +348,8 @@ struct RemapEntry { auto &source_type = entry->second.type; auto &target_type = target_entry->second.type; - bool source_is_nested = source_type.IsNested(); - bool target_is_nested = target_type.IsNested(); + bool source_is_nested = IsRemappable(source_type); + bool target_is_nested = IsRemappable(target_type); RemapEntry remap; remap.index = entry->second.index; remap.target_type = target_entry->second.type; @@ -387,7 +391,7 @@ struct RemapEntry { remap.default_index = default_idx; if (default_type.id() == LogicalTypeId::STRUCT) { // nested remap - recurse - if (!target_type.IsNested()) { + if (!IsRemappable(target_type)) { throw BinderException("Default value is a struct - target value should be a nested type, is '%s'", target_type.ToString()); } @@ -436,7 +440,7 @@ struct RemapEntry { RemapColumnInfo info; info.index = entry->second.index; info.default_index = entry->second.default_index; - if (child_type.IsNested() && entry->second.child_remaps) { + if (IsRemappable(child_type) && entry->second.child_remaps) { // type is nested and a mapping for it is given - recurse info.child_remap_info = ConstructMap(child_type, *entry->second.child_remaps); } @@ -447,7 +451,7 @@ struct RemapEntry { static vector ConstructMap(const LogicalType &type, const case_insensitive_map_t &remap_map) { - D_ASSERT(type.IsNested()); + D_ASSERT(IsRemappable(type)); switch (type.id()) { case LogicalTypeId::STRUCT: { auto &target_children = StructType::GetChildTypes(type); @@ -484,7 +488,7 @@ struct RemapEntry { auto remap_entry = remap_map.find(entry->second); D_ASSERT(remap_entry != remap_map.end()); // this entry is remapped - fetch the target type - if (child_type.IsNested() && remap_entry->second.child_remaps) { + if (IsRemappable(child_type) && remap_entry->second.child_remaps) { // type is nested and a mapping for it is given - recurse new_source_children.emplace_back(child_name, RemapCast(child_type, *remap_entry->second.child_remaps)); @@ -552,7 +556,7 @@ unique_ptr RemapStructBind(ClientContext &context, ScalarFunction // remap target can be NULL continue; } - if (!arg->return_type.IsNested()) { + if (!IsRemappable(arg->return_type)) { throw BinderException("Struct remap can only remap nested types, not '%s'", arg->return_type.ToString()); } else if (arg->return_type.id() == LogicalTypeId::STRUCT && StructType::IsUnnamed(arg->return_type)) { throw BinderException("Struct remap can only remap named structs"); @@ -569,7 +573,7 @@ unique_ptr RemapStructBind(ClientContext 
&context, ScalarFunction throw BinderException("The defaults have to be either NULL or a named STRUCT, not an unnamed struct"); } - if ((from_type.IsNested() || to_type.IsNested()) && from_type.id() != to_type.id()) { + if ((IsRemappable(from_type) || IsRemappable(to_type)) && from_type.id() != to_type.id()) { throw BinderException("Can't change source type (%s) to target type (%s), type conversion not allowed", from_type.ToString(), to_type.ToString()); } @@ -617,7 +621,7 @@ unique_ptr RemapStructBind(ClientContext &context, ScalarFunction bound_function.arguments[1] = arguments[1]->return_type; bound_function.arguments[2] = arguments[2]->return_type; bound_function.arguments[3] = arguments[3]->return_type; - bound_function.return_type = arguments[1]->return_type; + bound_function.SetReturnType(arguments[1]->return_type); return make_uniq(std::move(remap)); } @@ -628,7 +632,7 @@ ScalarFunction RemapStructFun::GetFunction() { ScalarFunction remap("remap_struct", {LogicalTypeId::ANY, LogicalTypeId::ANY, LogicalTypeId::ANY, LogicalTypeId::ANY}, LogicalTypeId::ANY, RemapStructFunction, RemapStructBind); - remap.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + remap.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return remap; } diff --git a/src/duckdb/src/function/scalar/struct/struct_concat.cpp b/src/duckdb/src/function/scalar/struct/struct_concat.cpp index ccfe7d363..153319891 100644 --- a/src/duckdb/src/function/scalar/struct/struct_concat.cpp +++ b/src/duckdb/src/function/scalar/struct/struct_concat.cpp @@ -33,7 +33,6 @@ static void StructConcatFunction(DataChunk &args, ExpressionState &state, Vector static unique_ptr StructConcatBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - // collect names and deconflict, construct return type if (arguments.empty()) { throw InvalidInputException("struct_concat: At least one argument is required"); @@ -80,7 +79,7 @@ static unique_ptr StructConcatBind(ClientContext &context, ScalarF throw InvalidInputException("struct_concat: Cannot mix named and unnamed STRUCTs"); } - bound_function.return_type = LogicalType::STRUCT(combined_children); + bound_function.SetReturnType(LogicalType::STRUCT(combined_children)); return nullptr; } @@ -108,7 +107,7 @@ ScalarFunction StructConcatFun::GetFunction() { ScalarFunction fun("struct_concat", {}, LogicalTypeId::STRUCT, StructConcatFunction, StructConcatBind, nullptr, StructConcatStats); fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/src/function/scalar/struct/struct_contains.cpp b/src/duckdb/src/function/scalar/struct/struct_contains.cpp index 3f8b39aa9..db9fd3554 100644 --- a/src/duckdb/src/function/scalar/struct/struct_contains.cpp +++ b/src/duckdb/src/function/scalar/struct/struct_contains.cpp @@ -204,7 +204,7 @@ static unique_ptr StructContainsBind(ClientContext &context, Scala if (child_type.id() == LogicalTypeId::SQLNULL) { bound_function.arguments[0] = LogicalTypeId::UNKNOWN; bound_function.arguments[1] = LogicalTypeId::UNKNOWN; - bound_function.return_type = LogicalType::SQLNULL; + bound_function.SetReturnType(LogicalType::SQLNULL); return nullptr; } @@ -248,7 +248,7 @@ ScalarFunction StructContainsFun::GetFunction() { ScalarFunction StructPositionFun::GetFunction() { ScalarFunction fun("struct_contains", {LogicalTypeId::STRUCT, LogicalType::ANY}, LogicalType::INTEGER, StructSearchFunction, StructContainsBind); 
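// --- Illustrative sketch, not part of this patch ---------------------------------------
// The recurring mechanical change in this diff is that the function fields return_type,
// null_handling and stability are now accessed through setters/getters instead of direct
// field writes. A hypothetical extension-side scalar function (the names "my_identity",
// MyIdentityFun, MyIdentityBind and GetMyIdentityFun are made up for illustration) would
// follow the same pattern after this change:
//
//   static unique_ptr<FunctionData> MyIdentityBind(ClientContext &context, ScalarFunction &bound_function,
//                                                  vector<unique_ptr<Expression>> &arguments) {
//       // was: bound_function.return_type = arguments[0]->return_type;
//       bound_function.SetReturnType(arguments[0]->return_type);
//       return nullptr;
//   }
//
//   static ScalarFunction GetMyIdentityFun() {
//       ScalarFunction fun("my_identity", {LogicalType::ANY}, LogicalType::ANY, MyIdentityFun, MyIdentityBind);
//       // was: fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
//       fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING);
//       return fun;
//   }
//
// Reads go through the matching getters, e.g. bound_function.GetReturnType() rather than
// bound_function.return_type, as seen in aggregate_export.cpp and duckdb_functions.cpp below.
// ----------------------------------------------------------------------------------------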
- fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/src/function/scalar/struct/struct_extract.cpp b/src/duckdb/src/function/scalar/struct/struct_extract.cpp index 23c5419cd..5da4a265e 100644 --- a/src/duckdb/src/function/scalar/struct/struct_extract.cpp +++ b/src/duckdb/src/function/scalar/struct/struct_extract.cpp @@ -83,7 +83,7 @@ static unique_ptr StructExtractBind(ClientContext &context, Scalar throw BinderException("Could not find key \"%s\" in struct\n%s", key, message); } - bound_function.return_type = std::move(return_type); + bound_function.SetReturnType(std::move(return_type)); return StructExtractAtFun::GetBindData(key_index); } @@ -120,7 +120,7 @@ static unique_ptr StructExtractBindInternal(ClientContext &context throw BinderException("Key index %lld for struct_extract out of range - expected an index between 1 and %llu", index, struct_children.size()); } - bound_function.return_type = struct_children[NumericCast(index - 1)].second; + bound_function.SetReturnType(struct_children[NumericCast(index - 1)].second); return StructExtractAtFun::GetBindData(NumericCast(index - 1)); } diff --git a/src/duckdb/src/function/scalar/struct/struct_pack.cpp b/src/duckdb/src/function/scalar/struct/struct_pack.cpp index dfbabcca0..f5993ab38 100644 --- a/src/duckdb/src/function/scalar/struct/struct_pack.cpp +++ b/src/duckdb/src/function/scalar/struct/struct_pack.cpp @@ -56,8 +56,8 @@ static unique_ptr StructPackBind(ClientContext &context, ScalarFun } // this is more for completeness reasons - bound_function.return_type = LogicalType::STRUCT(struct_children); - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::STRUCT(struct_children)); + return make_uniq(bound_function.GetReturnType()); } static unique_ptr StructPackStats(ClientContext &context, FunctionStatisticsInput &input) { @@ -75,7 +75,7 @@ static ScalarFunction GetStructPackFunction() { ScalarFunction fun(IS_STRUCT_PACK ? 
"struct_pack" : "row", {}, LogicalTypeId::STRUCT, StructPackFunction, StructPackBind, nullptr, StructPackStats); fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.serialize = VariableReturnBindData::Serialize; fun.deserialize = VariableReturnBindData::Deserialize; return fun; diff --git a/src/duckdb/src/function/scalar/system/aggregate_export.cpp b/src/duckdb/src/function/scalar/system/aggregate_export.cpp index e6e8e22c0..7ad356723 100644 --- a/src/duckdb/src/function/scalar/system/aggregate_export.cpp +++ b/src/duckdb/src/function/scalar/system/aggregate_export.cpp @@ -185,7 +185,6 @@ void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Vector &r unique_ptr BindAggregateState(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - // grab the aggregate type and bind the aggregate again // the aggregate name and types are in the logical type of the aggregate state, make sure its sane @@ -241,15 +240,16 @@ unique_ptr BindAggregateState(ClientContext &context, ScalarFuncti } } - if (bound_aggr.return_type != state_type.return_type || bound_aggr.arguments != state_type.bound_argument_types) { + if (bound_aggr.GetReturnType() != state_type.return_type || + bound_aggr.arguments != state_type.bound_argument_types) { throw InternalException("Type mismatch for exported aggregate %s", state_type.function_name); } if (bound_function.name == "finalize") { - bound_function.return_type = bound_aggr.return_type; + bound_function.SetReturnType(bound_aggr.GetReturnType()); } else { D_ASSERT(bound_function.name == "combine"); - bound_function.return_type = arg_return_type; + bound_function.SetReturnType(arg_return_type); } return make_uniq(bound_aggr, bound_aggr.state_size(bound_aggr)); @@ -304,14 +304,14 @@ ExportAggregateFunction::Bind(unique_ptr child_aggrega D_ASSERT(bound_function.state_size); D_ASSERT(bound_function.finalize); - D_ASSERT(child_aggregate->function.return_type.id() != LogicalTypeId::INVALID); + D_ASSERT(child_aggregate->function.GetReturnType().id() != LogicalTypeId::INVALID); #ifdef DEBUG for (auto &arg_type : child_aggregate->function.arguments) { D_ASSERT(arg_type.id() != LogicalTypeId::INVALID); } #endif auto export_bind_data = make_uniq(child_aggregate->Copy()); - aggregate_state_t state_type(child_aggregate->function.name, child_aggregate->function.return_type, + aggregate_state_t state_type(child_aggregate->function.name, child_aggregate->function.GetReturnType(), child_aggregate->function.arguments); auto return_type = LogicalType::AGGREGATE_STATE(std::move(state_type)); @@ -321,7 +321,7 @@ ExportAggregateFunction::Bind(unique_ptr child_aggrega bound_function.combine, ExportAggregateFinalize, bound_function.simple_update, /* can't bind this again */ nullptr, /* no dynamic state yet */ nullptr, /* can't propagate statistics */ nullptr, nullptr); - export_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + export_function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); export_function.serialize = ExportStateAggregateSerialize; export_function.deserialize = ExportStateAggregateDeserialize; @@ -347,7 +347,7 @@ bool ExportAggregateFunctionBindData::Equals(const FunctionData &other_p) const ScalarFunction FinalizeFun::GetFunction() { auto result = ScalarFunction("finalize", {LogicalTypeId::AGGREGATE_STATE}, LogicalTypeId::INVALID, AggregateStateFinalize, BindAggregateState, nullptr, nullptr, 
InitFinalizeState); - result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + result.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); result.serialize = ExportStateScalarSerialize; result.deserialize = ExportStateScalarDeserialize; return result; @@ -357,7 +357,7 @@ ScalarFunction CombineFun::GetFunction() { auto result = ScalarFunction("combine", {LogicalTypeId::AGGREGATE_STATE, LogicalTypeId::ANY}, LogicalTypeId::AGGREGATE_STATE, AggregateStateCombine, BindAggregateState, nullptr, nullptr, InitCombineState); - result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + result.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); result.serialize = ExportStateScalarSerialize; result.deserialize = ExportStateScalarDeserialize; return result; diff --git a/src/duckdb/src/function/scalar/system/parse_log_message.cpp b/src/duckdb/src/function/scalar/system/parse_log_message.cpp index d5e336165..e6625a5c2 100644 --- a/src/duckdb/src/function/scalar/system/parse_log_message.cpp +++ b/src/duckdb/src/function/scalar/system/parse_log_message.cpp @@ -29,7 +29,6 @@ struct ParseLogMessageData : FunctionData { unique_ptr ParseLogMessageBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - if (arguments.size() != 2) { throw BinderException("structured_log_schema: expects 1 argument", arguments[0]->alias); } @@ -53,9 +52,9 @@ unique_ptr ParseLogMessageBind(ClientContext &context, ScalarFunct if (!lookup->is_structured) { // Unstructured types we simply wrap in a struct with a single field called message child_list_t children = {{"message", LogicalType::VARCHAR}}; - bound_function.return_type = LogicalType::STRUCT(children); + bound_function.SetReturnType(LogicalType::STRUCT(children)); } else { - bound_function.return_type = lookup->type; + bound_function.SetReturnType(lookup->type); } return make_uniq(*lookup); diff --git a/src/duckdb/src/function/scalar/system/write_log.cpp b/src/duckdb/src/function/scalar/system/write_log.cpp index fa67ec089..6fd7aba8d 100644 --- a/src/duckdb/src/function/scalar/system/write_log.cpp +++ b/src/duckdb/src/function/scalar/system/write_log.cpp @@ -65,7 +65,7 @@ unique_ptr WriteLogBind(ClientContext &context, ScalarFunction &bo auto result = make_uniq(); // Default return type - bound_function.return_type = LogicalType::VARCHAR; + bound_function.SetReturnType(LogicalType::VARCHAR); for (idx_t i = 1; i < arguments.size(); i++) { auto &arg = arguments[i]; @@ -100,7 +100,7 @@ unique_ptr WriteLogBind(ClientContext &context, ScalarFunction &bo } else if (arg->alias == "return_value") { result->return_type = arg->return_type; result->output_col = i; - bound_function.return_type = result->return_type; + bound_function.SetReturnType(result->return_type); } else { throw BinderException(StringUtil::Format("write_log: Unknown argument '%s'", arg->alias)); } diff --git a/src/duckdb/src/function/scalar/variant/variant_extract.cpp b/src/duckdb/src/function/scalar/variant/variant_extract.cpp index e0c10fa73..76d5c84cd 100644 --- a/src/duckdb/src/function/scalar/variant/variant_extract.cpp +++ b/src/duckdb/src/function/scalar/variant/variant_extract.cpp @@ -12,6 +12,7 @@ struct BindData : public FunctionData { public: explicit BindData(const string &str); explicit BindData(uint32_t index); + BindData(const BindData &other) = default; public: unique_ptr Copy() const override; @@ -28,15 +29,15 @@ BindData::BindData(const string &str) : FunctionData() { component.key = str; } BindData::BindData(uint32_t index) : FunctionData() { + if 
(index == 0) { + throw BinderException("Extracting index 0 from VARIANT(ARRAY) is invalid, indexes are 1-based"); + } component.lookup_mode = VariantChildLookupMode::BY_INDEX; - component.index = index; + component.index = index - 1; } unique_ptr BindData::Copy() const { - if (component.lookup_mode == VariantChildLookupMode::BY_INDEX) { - return make_uniq(component.index); - } - return make_uniq(component.key); + return make_uniq(*this); } bool BindData::Equals(const FunctionData &other) const { @@ -142,22 +143,26 @@ static void VariantExtractFunction(DataChunk &input, ExpressionState &state, Vec } //! Look up the value_index of the child we're extracting - auto child_collection_result = - VariantUtils::FindChildValues(variant, component, optional_idx(), new_value_index_sel, nested_data, count); - if (!child_collection_result.Success()) { - if (child_collection_result.type == VariantChildDataCollectionResult::Type::INDEX_ZERO) { - throw InvalidInputException("Extracting index 0 from VARIANT(ARRAY) is invalid, indexes are 1-based"); + ValidityMask lookup_validity(count); + VariantUtils::FindChildValues(variant, component, nullptr, new_value_index_sel, lookup_validity, nested_data, + count); + if (!lookup_validity.AllValid()) { + optional_idx index; + for (idx_t i = 0; i < count; i++) { + if (!lookup_validity.RowIsValid(i)) { + index = i; + break; + } } + D_ASSERT(index.IsValid()); switch (component.lookup_mode) { case VariantChildLookupMode::BY_INDEX: { - D_ASSERT(child_collection_result.type == VariantChildDataCollectionResult::Type::COMPONENT_NOT_FOUND); - auto nested_index = child_collection_result.nested_data_index; + auto nested_index = index.GetIndex(); throw InvalidInputException("VARIANT(ARRAY(%d)) is missing index %d", nested_data[nested_index].child_count, component.index); } case VariantChildLookupMode::BY_KEY: { - D_ASSERT(child_collection_result.type == VariantChildDataCollectionResult::Type::COMPONENT_NOT_FOUND); - auto nested_index = child_collection_result.nested_data_index; + auto nested_index = index.GetIndex(); auto row_index = nested_index; auto object_keys = VariantUtils::GetObjectKeys(variant, row_index, nested_data[nested_index]); throw InvalidInputException("VARIANT(OBJECT(%s)) is missing key '%s'", StringUtil::Join(object_keys, ","), diff --git a/src/duckdb/src/function/scalar/variant/variant_normalize.cpp b/src/duckdb/src/function/scalar/variant/variant_normalize.cpp new file mode 100644 index 000000000..ef79e38e3 --- /dev/null +++ b/src/duckdb/src/function/scalar/variant/variant_normalize.cpp @@ -0,0 +1,311 @@ +#include "duckdb/function/scalar/variant_utils.hpp" +#include "duckdb/function/scalar/variant_functions.hpp" +#include "duckdb/function/scalar/regexp.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/execution/expression_executor.hpp" + +#include "duckdb/function/cast/variant/to_variant_fwd.hpp" +#include "duckdb/common/types/variant_visitor.hpp" + +namespace duckdb { + +namespace { + +struct VariantNormalizerState { +public: + VariantNormalizerState(idx_t result_row, VariantVectorData &source, OrderedOwningStringMap &dictionary, + SelectionVector &keys_selvec) + : source(source), dictionary(dictionary), keys_selvec(keys_selvec), + keys_index_validity(source.keys_index_validity) { + auto keys_list_entry = source.keys_data[result_row]; + auto values_list_entry = source.values_data[result_row]; + auto children_list_entry = source.children_data[result_row]; + + keys_offset = keys_list_entry.offset; + children_offset = 
children_list_entry.offset; + + blob_data = data_ptr_cast(source.blob_data[result_row].GetDataWriteable()); + type_ids = source.type_ids_data + values_list_entry.offset; + byte_offsets = source.byte_offset_data + values_list_entry.offset; + values_indexes = source.values_index_data + children_list_entry.offset; + keys_indexes = source.keys_index_data + children_list_entry.offset; + } + +public: + data_ptr_t GetDestination() { + return blob_data + blob_size; + } + uint32_t GetOrCreateIndex(const string_t &key) { + auto unsorted_idx = dictionary.size(); + //! This will later be remapped to the sorted idx (see FinalizeVariantKeys in 'to_variant.cpp') + return dictionary.emplace(std::make_pair(key, unsorted_idx)).first->second; + } + +public: + uint32_t keys_size = 0; + uint32_t children_size = 0; + uint32_t values_size = 0; + uint32_t blob_size = 0; + + VariantVectorData &source; + OrderedOwningStringMap &dictionary; + SelectionVector &keys_selvec; + + uint64_t keys_offset; + uint64_t children_offset; + ValidityMask &keys_index_validity; + + data_ptr_t blob_data; + uint8_t *type_ids; + uint32_t *byte_offsets; + uint32_t *values_indexes; + uint32_t *keys_indexes; +}; + +struct VariantNormalizer { + using result_type = void; + + static void VisitNull(VariantNormalizerState &state) { + return; + } + static void VisitBoolean(bool val, VariantNormalizerState &state) { + return; + } + + static void VisitMetadata(VariantLogicalType type_id, VariantNormalizerState &state) { + state.type_ids[state.values_size] = static_cast(type_id); + state.byte_offsets[state.values_size] = state.blob_size; + state.values_size++; + } + + template + static void VisitInteger(T val, VariantNormalizerState &state) { + Store(val, state.GetDestination()); + state.blob_size += sizeof(T); + } + static void VisitFloat(float val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitDouble(double val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitUUID(hugeint_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitDate(date_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitInterval(interval_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitTime(dtime_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitTimeNanos(dtime_ns_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitTimeTZ(dtime_tz_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitTimestampSec(timestamp_sec_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitTimestampMs(timestamp_ms_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitTimestamp(timestamp_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitTimestampNanos(timestamp_ns_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + static void VisitTimestampTZ(timestamp_tz_t val, VariantNormalizerState &state) { + VisitInteger(val, state); + } + + static void WriteStringInternal(const string_t &str, VariantNormalizerState &state) { + } + + static void VisitString(const string_t &str, VariantNormalizerState &state) { + auto length = str.GetSize(); + state.blob_size += VarintEncode(length, state.GetDestination()); + memcpy(state.GetDestination(), str.GetData(), length); + state.blob_size 
+= length; + } + static void VisitBlob(const string_t &blob, VariantNormalizerState &state) { + return VisitString(blob, state); + } + static void VisitBignum(const string_t &bignum, VariantNormalizerState &state) { + return VisitString(bignum, state); + } + static void VisitGeometry(const string_t &geom, VariantNormalizerState &state) { + return VisitString(geom, state); + } + static void VisitBitstring(const string_t &bits, VariantNormalizerState &state) { + return VisitString(bits, state); + } + + template + static void VisitDecimal(T val, uint32_t width, uint32_t scale, VariantNormalizerState &state) { + state.blob_size += VarintEncode(width, state.GetDestination()); + state.blob_size += VarintEncode(scale, state.GetDestination()); + Store(val, state.GetDestination()); + state.blob_size += sizeof(T); + } + + static void VisitArray(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + VariantNormalizerState &state) { + state.blob_size += VarintEncode(nested_data.child_count, state.GetDestination()); + if (!nested_data.child_count) { + return; + } + idx_t result_children_idx = state.children_size; + state.blob_size += VarintEncode(result_children_idx, state.GetDestination()); + state.children_size += nested_data.child_count; + + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto source_children_idx = nested_data.children_idx + i; + auto values_index = variant.GetValuesIndex(row, source_children_idx); + + //! Set the 'values_index' for the child, and set the 'keys_index' to NULL + state.values_indexes[result_children_idx] = state.values_size; + state.keys_index_validity.SetInvalid(state.children_offset + result_children_idx); + result_children_idx++; + + //! Visit the child value + VariantVisitor::Visit(variant, row, values_index, state); + } + } + + static void VisitObject(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + VariantNormalizerState &state) { + state.blob_size += VarintEncode(nested_data.child_count, state.GetDestination()); + if (!nested_data.child_count) { + return; + } + uint32_t children_idx = state.children_size; + uint32_t keys_idx = state.keys_size; + state.blob_size += VarintEncode(children_idx, state.GetDestination()); + state.children_size += nested_data.child_count; + state.keys_size += nested_data.child_count; + + //! First iterate through all fields to populate the map of key -> field + map sorted_fields; + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto keys_index = variant.GetKeysIndex(row, nested_data.children_idx + i); + auto &key = variant.GetKey(row, keys_index); + sorted_fields.emplace(key, i); + } + + //! Then visit the fields in sorted order + for (auto &entry : sorted_fields) { + auto source_children_idx = nested_data.children_idx + entry.second; + + //! Add the key of the field to the result + auto keys_index = variant.GetKeysIndex(row, source_children_idx); + auto &key = variant.GetKey(row, keys_index); + auto dict_index = state.GetOrCreateIndex(key); + state.keys_selvec.set_index(state.keys_offset + keys_idx, dict_index); + + //! 
Visit the child value + auto values_index = variant.GetValuesIndex(row, source_children_idx); + state.values_indexes[children_idx] = state.values_size; + state.keys_indexes[children_idx] = keys_idx; + children_idx++; + keys_idx++; + VariantVisitor::Visit(variant, row, values_index, state); + } + } + + static void VisitDefault(VariantLogicalType type_id, const_data_ptr_t, VariantNormalizerState &state) { + throw InternalException("VariantLogicalType(%s) not handled", EnumUtil::ToString(type_id)); + } +}; + +} // namespace + +static void VariantNormalizeFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto count = input.size(); + + D_ASSERT(input.ColumnCount() == 1); + auto &variant_vec = input.data[0]; + D_ASSERT(variant_vec.GetType() == LogicalType::VARIANT()); + + //! Set up the access helper for the source VARIANT + RecursiveUnifiedVectorFormat source_format; + Vector::RecursiveToUnifiedFormat(variant_vec, count, source_format); + UnifiedVariantVectorData variant(source_format); + + //! Take the original sizes of the lists, the result will be similar size, never bigger + auto original_keys_size = ListVector::GetListSize(VariantVector::GetKeys(variant_vec)); + auto original_children_size = ListVector::GetListSize(VariantVector::GetChildren(variant_vec)); + auto original_values_size = ListVector::GetListSize(VariantVector::GetValues(variant_vec)); + + auto &keys = VariantVector::GetKeys(result); + auto &children = VariantVector::GetChildren(result); + auto &values = VariantVector::GetValues(result); + auto &data = VariantVector::GetData(result); + + ListVector::Reserve(keys, original_keys_size); + ListVector::SetListSize(keys, 0); + ListVector::Reserve(children, original_children_size); + ListVector::SetListSize(children, 0); + ListVector::Reserve(values, original_values_size); + ListVector::SetListSize(values, 0); + + //! Initialize the dictionary + auto &keys_entry = ListVector::GetEntry(keys); + OrderedOwningStringMap dictionary(StringVector::GetStringBuffer(keys_entry).GetStringAllocator()); + + VariantVectorData variant_data(result); + SelectionVector keys_selvec; + keys_selvec.Initialize(original_keys_size); + + for (idx_t i = 0; i < count; i++) { + if (!variant.RowIsValid(i)) { + FlatVector::SetNull(result, i, true); + continue; + } + //! Allocate for the new data, use the same size as source + auto &blob_data = variant_data.blob_data[i]; + auto original_data = variant.GetData(i); + blob_data = StringVector::EmptyString(data, original_data.GetSize()); + + auto &keys_list_entry = variant_data.keys_data[i]; + keys_list_entry.offset = ListVector::GetListSize(keys); + + auto &children_list_entry = variant_data.children_data[i]; + children_list_entry.offset = ListVector::GetListSize(children); + + auto &values_list_entry = variant_data.values_data[i]; + values_list_entry.offset = ListVector::GetListSize(values); + + //! 
Visit the source to populate the result + VariantNormalizerState visitor_state(i, variant_data, dictionary, keys_selvec); + VariantVisitor::Visit(variant, i, 0, visitor_state); + + blob_data.SetSizeAndFinalize(visitor_state.blob_size, original_data.GetSize()); + keys_list_entry.length = visitor_state.keys_size; + children_list_entry.length = visitor_state.children_size; + values_list_entry.length = visitor_state.values_size; + + ListVector::SetListSize(keys, ListVector::GetListSize(keys) + visitor_state.keys_size); + ListVector::SetListSize(children, ListVector::GetListSize(children) + visitor_state.children_size); + ListVector::SetListSize(values, ListVector::GetListSize(values) + visitor_state.values_size); + } + + VariantUtils::FinalizeVariantKeys(result, dictionary, keys_selvec, ListVector::GetListSize(keys)); + keys_entry.Slice(keys_selvec, ListVector::GetListSize(keys)); + + if (input.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + result.Verify(count); +} + +ScalarFunction VariantNormalizeFun::GetFunction() { + auto variant_type = LogicalType::VARIANT(); + return ScalarFunction("variant_normalize", {variant_type}, variant_type, VariantNormalizeFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/variant/variant_typeof.cpp b/src/duckdb/src/function/scalar/variant/variant_typeof.cpp index d1767736e..19526a653 100644 --- a/src/duckdb/src/function/scalar/variant/variant_typeof.cpp +++ b/src/duckdb/src/function/scalar/variant/variant_typeof.cpp @@ -63,7 +63,7 @@ static void VariantTypeofFunction(DataChunk &input, ExpressionState &state, Vect ScalarFunction VariantTypeofFun::GetFunction() { auto variant_type = LogicalType::VARIANT(); auto res = ScalarFunction("variant_typeof", {variant_type}, LogicalType::VARCHAR, VariantTypeofFunction); - res.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + res.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return res; } diff --git a/src/duckdb/src/function/scalar/variant/variant_utils.cpp b/src/duckdb/src/function/scalar/variant/variant_utils.cpp index 44a370251..c2e118204 100644 --- a/src/duckdb/src/function/scalar/variant/variant_utils.cpp +++ b/src/duckdb/src/function/scalar/variant/variant_utils.cpp @@ -4,9 +4,22 @@ #include "duckdb/common/types/string_type.hpp" #include "duckdb/common/types/decimal.hpp" #include "duckdb/common/serializer/varint.hpp" +#include "duckdb/common/types/variant_visitor.hpp" namespace duckdb { +PhysicalType VariantDecimalData::GetPhysicalType() const { + if (width > DecimalWidth::max) { + return PhysicalType::INT128; + } else if (width > DecimalWidth::max) { + return PhysicalType::INT64; + } else if (width > DecimalWidth::max) { + return PhysicalType::INT32; + } else { + return PhysicalType::INT16; + } +} + bool VariantUtils::IsNestedType(const UnifiedVariantVectorData &variant, idx_t row, uint32_t value_index) { auto type_id = variant.GetTypeId(row, value_index); return type_id == VariantLogicalType::ARRAY || type_id == VariantLogicalType::OBJECT; @@ -19,10 +32,19 @@ VariantDecimalData VariantUtils::DecodeDecimalData(const UnifiedVariantVectorDat auto data = const_data_ptr_cast(variant.GetData(row).GetData()); auto ptr = data + byte_offset; - VariantDecimalData result; - result.width = VarintDecode(ptr); - result.scale = VarintDecode(ptr); - return result; + auto width = VarintDecode(ptr); + auto scale = VarintDecode(ptr); + auto value_ptr = ptr; + return VariantDecimalData(width, scale, value_ptr); +} + +string_t 
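// --- Illustrative sketch, not part of this patch ---------------------------------------
// variant_normalize() above is registered as VARIANT -> VARIANT. Based on the visitor code,
// it rewrites the value losslessly while emitting object keys in sorted order and rebuilding
// the keys dictionary from the visited values, so VARIANTs that differ only in object key
// order should normalize to the same encoding. Assuming VARIANT casts behave as elsewhere in
// this changeset (the CAST syntax for constructing a VARIANT is an assumption here), usage
// would look roughly like:
//
//   SELECT variant_normalize(CAST({'b': 2, 'a': 1} AS VARIANT));
//
// ----------------------------------------------------------------------------------------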
VariantUtils::DecodeStringData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t value_index) { + auto byte_offset = variant.GetByteOffset(row, value_index); + auto data = const_data_ptr_cast(variant.GetData(row).GetData()); + auto ptr = data + byte_offset; + + auto length = VarintDecode(ptr); + return string_t(reinterpret_cast(ptr), length); } VariantNestedData VariantUtils::DecodeNestedData(const UnifiedVariantVectorData &variant, idx_t row, @@ -53,13 +75,12 @@ vector VariantUtils::GetObjectKeys(const UnifiedVariantVectorData &varia return object_keys; } -VariantChildDataCollectionResult VariantUtils::FindChildValues(const UnifiedVariantVectorData &variant, - const VariantPathComponent &component, optional_idx row, - SelectionVector &res, VariantNestedData *nested_data, - idx_t count) { - +//! FIXME: this shouldn't return a "result", it should populate a validity mask instead. +void VariantUtils::FindChildValues(const UnifiedVariantVectorData &variant, const VariantPathComponent &component, + optional_ptr sel, SelectionVector &res, + ValidityMask &res_validity, VariantNestedData *nested_data, idx_t count) { for (idx_t i = 0; i < count; i++) { - auto row_index = row.IsValid() ? row.GetIndex() : i; + auto row_index = sel ? sel->get_index(i) : i; auto &nested_data_entry = nested_data[i]; if (nested_data_entry.is_null) { @@ -67,13 +88,10 @@ VariantChildDataCollectionResult VariantUtils::FindChildValues(const UnifiedVari } if (component.lookup_mode == VariantChildLookupMode::BY_INDEX) { auto child_idx = component.index; - if (child_idx == 0) { - return VariantChildDataCollectionResult::IndexZero(); - } - child_idx--; if (child_idx >= nested_data_entry.child_count) { //! The list is too small to contain this index - return VariantChildDataCollectionResult::NotFound(i); + res_validity.SetInvalid(i); + continue; } auto value_id = variant.GetValuesIndex(row_index, nested_data_entry.children_idx + child_idx); res[i] = static_cast(value_id); @@ -93,10 +111,9 @@ VariantChildDataCollectionResult VariantUtils::FindChildValues(const UnifiedVari } } if (!found_child) { - return VariantChildDataCollectionResult::NotFound(i); + res_validity.SetInvalid(i); } } - return VariantChildDataCollectionResult(); } vector VariantUtils::ValueIsNull(const UnifiedVariantVectorData &variant, const SelectionVector &sel, @@ -146,133 +163,204 @@ VariantUtils::CollectNestedData(const UnifiedVariantVectorData &variant, Variant return VariantNestedDataCollectionResult(); } -Value VariantUtils::ConvertVariantToValue(const UnifiedVariantVectorData &variant, idx_t row, idx_t values_idx) { - if (!variant.RowIsValid(row)) { - return Value(LogicalTypeId::SQLNULL); +namespace { + +struct ValueConverter { + using result_type = Value; + + static Value VisitNull() { + return Value(LogicalType::SQLNULL); } - //! The 'values' data of the value we're currently converting - auto type_id = variant.GetTypeId(row, values_idx); - auto byte_offset = variant.GetByteOffset(row, values_idx); + static Value VisitBoolean(bool val) { + return Value::BOOLEAN(val); + } - //! 
The blob data of the Variant, accessed by byte offset retrieved above ^ - auto blob_data = const_data_ptr_cast(variant.GetData(row).GetData()); + template + static Value VisitInteger(T val) { + throw InternalException("ValueConverter::VisitInteger not implemented!"); + } - auto ptr = const_data_ptr_cast(blob_data + byte_offset); - switch (type_id) { - case VariantLogicalType::VARIANT_NULL: - return Value(LogicalType::SQLNULL); - case VariantLogicalType::BOOL_TRUE: - return Value::BOOLEAN(true); - case VariantLogicalType::BOOL_FALSE: - return Value::BOOLEAN(false); - case VariantLogicalType::INT8: - return Value::TINYINT(Load(ptr)); - case VariantLogicalType::INT16: - return Value::SMALLINT(Load(ptr)); - case VariantLogicalType::INT32: - return Value::INTEGER(Load(ptr)); - case VariantLogicalType::INT64: - return Value::BIGINT(Load(ptr)); - case VariantLogicalType::INT128: - return Value::HUGEINT(Load(ptr)); - case VariantLogicalType::UINT8: - return Value::UTINYINT(Load(ptr)); - case VariantLogicalType::UINT16: - return Value::USMALLINT(Load(ptr)); - case VariantLogicalType::UINT32: - return Value::UINTEGER(Load(ptr)); - case VariantLogicalType::UINT64: - return Value::UBIGINT(Load(ptr)); - case VariantLogicalType::UINT128: - return Value::UHUGEINT(Load(ptr)); - case VariantLogicalType::UUID: - return Value::UUID(Load(ptr)); - case VariantLogicalType::INTERVAL: - return Value::INTERVAL(Load(ptr)); - case VariantLogicalType::FLOAT: - return Value::FLOAT(Load(ptr)); - case VariantLogicalType::DOUBLE: - return Value::DOUBLE(Load(ptr)); - case VariantLogicalType::DATE: - return Value::DATE(date_t(Load(ptr))); - case VariantLogicalType::BLOB: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - return Value::BLOB(const_data_ptr_cast(string_data), string_length); - } - case VariantLogicalType::VARCHAR: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - return Value(string_t(string_data, string_length)); - } - case VariantLogicalType::DECIMAL: { - auto width = NumericCast(VarintDecode(ptr)); - auto scale = NumericCast(VarintDecode(ptr)); - - if (width > DecimalWidth::max) { - return Value::DECIMAL(Load(ptr), width, scale); - } else if (width > DecimalWidth::max) { - return Value::DECIMAL(Load(ptr), width, scale); - } else if (width > DecimalWidth::max) { - return Value::DECIMAL(Load(ptr), width, scale); + static Value VisitTime(dtime_t val) { + return Value::TIME(val); + } + + static Value VisitTimeNanos(dtime_ns_t val) { + return Value::TIME_NS(val); + } + + static Value VisitTimeTZ(dtime_tz_t val) { + return Value::TIMETZ(val); + } + + static Value VisitTimestampSec(timestamp_sec_t val) { + return Value::TIMESTAMPSEC(val); + } + + static Value VisitTimestampMs(timestamp_ms_t val) { + return Value::TIMESTAMPMS(val); + } + + static Value VisitTimestamp(timestamp_t val) { + return Value::TIMESTAMP(val); + } + + static Value VisitTimestampNanos(timestamp_ns_t val) { + return Value::TIMESTAMPNS(val); + } + + static Value VisitTimestampTZ(timestamp_tz_t val) { + return Value::TIMESTAMPTZ(val); + } + + static Value VisitFloat(float val) { + return Value::FLOAT(val); + } + static Value VisitDouble(double val) { + return Value::DOUBLE(val); + } + static Value VisitUUID(hugeint_t val) { + return Value::UUID(val); + } + static Value VisitDate(date_t val) { + return Value::DATE(val); + } + static Value VisitInterval(interval_t val) { + return Value::INTERVAL(val); + } + + static Value VisitString(const string_t &str) { + return 
Value(str); + } + static Value VisitBlob(const string_t &str) { + return Value::BLOB(const_data_ptr_cast(str.GetData()), str.GetSize()); + } + static Value VisitBignum(const string_t &str) { + return Value::BIGNUM(const_data_ptr_cast(str.GetData()), str.GetSize()); + } + static Value VisitGeometry(const string_t &str) { + return Value::GEOMETRY(const_data_ptr_cast(str.GetData()), str.GetSize()); + } + static Value VisitBitstring(const string_t &str) { + return Value::BIT(const_data_ptr_cast(str.GetData()), str.GetSize()); + } + + template + static Value VisitDecimal(T val, uint32_t width, uint32_t scale) { + if (std::is_same::value) { + return Value::DECIMAL(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + return Value::DECIMAL(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + return Value::DECIMAL(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + return Value::DECIMAL(val, static_cast(width), static_cast(scale)); } else { - return Value::DECIMAL(Load(ptr), width, scale); + throw InternalException("Unhandled decimal type"); } } - case VariantLogicalType::TIME_MICROS: - return Value::TIME(Load(ptr)); - case VariantLogicalType::TIME_MICROS_TZ: - return Value::TIMETZ(Load(ptr)); - case VariantLogicalType::TIMESTAMP_MICROS: - return Value::TIMESTAMP(Load(ptr)); - case VariantLogicalType::TIMESTAMP_SEC: - return Value::TIMESTAMPSEC(Load(ptr)); - case VariantLogicalType::TIMESTAMP_NANOS: - return Value::TIMESTAMPNS(Load(ptr)); - case VariantLogicalType::TIMESTAMP_MILIS: - return Value::TIMESTAMPMS(Load(ptr)); - case VariantLogicalType::TIMESTAMP_MICROS_TZ: - return Value::TIMESTAMPTZ(Load(ptr)); - case VariantLogicalType::ARRAY: { - auto count = VarintDecode(ptr); - vector array_items; - if (count) { - auto child_index_start = VarintDecode(ptr); - for (idx_t i = 0; i < count; i++) { - auto child_index = variant.GetValuesIndex(row, child_index_start + i); - array_items.emplace_back(ConvertVariantToValue(variant, row, child_index)); - } - } + + static Value VisitArray(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data) { + auto array_items = VariantVisitor::VisitArrayItems(variant, row, nested_data); return Value::LIST(LogicalType::VARIANT(), std::move(array_items)); } - case VariantLogicalType::OBJECT: { - auto count = VarintDecode(ptr); - child_list_t object_children; - if (count) { - auto child_index_start = VarintDecode(ptr); - for (idx_t i = 0; i < count; i++) { - auto child_value_idx = variant.GetValuesIndex(row, child_index_start + i); - auto val = ConvertVariantToValue(variant, row, child_value_idx); - auto child_key_id = variant.GetKeysIndex(row, child_index_start + i); - auto &key = variant.GetKey(row, child_key_id); - - object_children.emplace_back(key.GetString(), std::move(val)); - } - } + static Value VisitObject(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data) { + auto object_children = VariantVisitor::VisitObjectItems(variant, row, nested_data); return Value::STRUCT(std::move(object_children)); } - case VariantLogicalType::BITSTRING: { - auto string_length = VarintDecode(ptr); - return Value::BIT(ptr, string_length); + + static Value VisitDefault(VariantLogicalType type_id, const_data_ptr_t) { + throw InternalException("VariantLogicalType(%s) not handled", EnumUtil::ToString(type_id)); } - case VariantLogicalType::BIGNUM: { - auto string_length = VarintDecode(ptr); - return Value::BIGNUM(ptr, 
string_length); +}; + +template <> +Value ValueConverter::VisitInteger(int8_t val) { + return Value::TINYINT(val); +} + +template <> +Value ValueConverter::VisitInteger(int16_t val) { + return Value::SMALLINT(val); +} + +template <> +Value ValueConverter::VisitInteger(int32_t val) { + return Value::INTEGER(val); +} + +template <> +Value ValueConverter::VisitInteger(int64_t val) { + return Value::BIGINT(val); +} + +template <> +Value ValueConverter::VisitInteger(hugeint_t val) { + return Value::HUGEINT(val); +} + +template <> +Value ValueConverter::VisitInteger(uint8_t val) { + return Value::UTINYINT(val); +} + +template <> +Value ValueConverter::VisitInteger(uint16_t val) { + return Value::USMALLINT(val); +} + +template <> +Value ValueConverter::VisitInteger(uint32_t val) { + return Value::UINTEGER(val); +} + +template <> +Value ValueConverter::VisitInteger(uint64_t val) { + return Value::UBIGINT(val); +} + +template <> +Value ValueConverter::VisitInteger(uhugeint_t val) { + return Value::UHUGEINT(val); +} + +} // namespace + +Value VariantUtils::ConvertVariantToValue(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx) { + return VariantVisitor::Visit(variant, row, values_idx); +} + +void VariantUtils::FinalizeVariantKeys(Vector &variant, OrderedOwningStringMap &dictionary, + SelectionVector &sel, idx_t sel_size) { + auto &keys = VariantVector::GetKeys(variant); + auto &keys_entry = ListVector::GetEntry(keys); + auto keys_entry_data = FlatVector::GetData(keys_entry); + + bool already_sorted = true; + + vector unsorted_to_sorted(dictionary.size()); + auto it = dictionary.begin(); + for (uint32_t sorted_idx = 0; sorted_idx < dictionary.size(); sorted_idx++) { + auto unsorted_idx = it->second; + if (unsorted_idx != sorted_idx) { + already_sorted = false; + } + unsorted_to_sorted[unsorted_idx] = sorted_idx; + D_ASSERT(sorted_idx < ListVector::GetListSize(keys)); + keys_entry_data[sorted_idx] = it->first; + auto size = static_cast(keys_entry_data[sorted_idx].GetSize()); + keys_entry_data[sorted_idx].SetSizeAndFinalize(size, size); + it++; } - default: - throw InternalException("VariantLogicalType(%d) not handled", static_cast(type_id)); + + if (!already_sorted) { + //! 
Adjust the selection vector to point to the right dictionary index + for (idx_t i = 0; i < sel_size; i++) { + auto &entry = sel[i]; + auto sorted_idx = unsorted_to_sorted[entry]; + entry = sorted_idx; + } } } diff --git a/src/duckdb/src/function/table/arrow_conversion.cpp b/src/duckdb/src/function/table/arrow_conversion.cpp index e194852f0..65f617e83 100644 --- a/src/duckdb/src/function/table/arrow_conversion.cpp +++ b/src/duckdb/src/function/table/arrow_conversion.cpp @@ -257,7 +257,6 @@ static void ArrowToDuckDBList(Vector &vector, ArrowArray &array, idx_t chunk_off static void ArrowToDuckDBArray(Vector &vector, ArrowArray &array, idx_t chunk_offset, ArrowArrayScanState &array_state, idx_t size, const ArrowType &arrow_type, int64_t nested_offset, const ValidityMask *parent_mask, int64_t parent_offset) { - auto &array_info = arrow_type.GetTypeInfo(); auto array_size = array_info.FixedSize(); auto child_count = array_size * size; @@ -695,7 +694,6 @@ template void ConvertDecimal(SRC src_ptr, Vector &vector, ArrowArray &array, idx_t size, int64_t nested_offset, uint64_t parent_offset, idx_t chunk_offset, ValidityMask &val_mask, DecimalBitWidth arrow_bit_width) { - switch (vector.GetType().InternalType()) { case PhysicalType::INT16: { auto tgt_ptr = FlatVector::GetData(vector); @@ -1184,7 +1182,6 @@ static void SetSelectionVectorLoop(SelectionVector &sel, data_ptr_t indices_p, i template static void SetSelectionVectorLoopWithChecks(SelectionVector &sel, data_ptr_t indices_p, idx_t size) { - auto indices = reinterpret_cast(indices_p); for (idx_t row = 0; row < size; row++) { if (indices[row] > NumericLimits::Maximum()) { diff --git a/src/duckdb/src/function/table/copy_csv.cpp b/src/duckdb/src/function/table/copy_csv.cpp index 600e50f49..1ffd6e7ee 100644 --- a/src/duckdb/src/function/table/copy_csv.cpp +++ b/src/duckdb/src/function/table/copy_csv.cpp @@ -280,7 +280,31 @@ struct GlobalWriteCSVData : public GlobalFunctionData { return writer.FileSize(); } + unique_ptr GetLocalState(ClientContext &context, const idx_t flush_size) { + { + lock_guard guard(local_state_lock); + if (!local_states.empty()) { + auto result = std::move(local_states.back()); + local_states.pop_back(); + return result; + } + } + auto result = make_uniq(context, flush_size); + result->require_manual_flush = true; + return result; + } + + void StoreLocalState(unique_ptr lstate) { + lock_guard guard(local_state_lock); + lstate->Reset(); + local_states.push_back(std::move(lstate)); + } + CSVWriter writer; + +private: + mutex local_state_lock; + vector> local_states; }; static unique_ptr WriteCSVInitializeLocal(ExecutionContext &context, FunctionData &bind_data) { @@ -371,9 +395,7 @@ CopyFunctionExecutionMode WriteCSVExecutionMode(bool preserve_insertion_order, b // Prepare Batch //===--------------------------------------------------------------------===// struct WriteCSVBatchData : public PreparedBatchData { - explicit WriteCSVBatchData(ClientContext &context, const idx_t flush_size) - : writer_local_state(make_uniq(context, flush_size)) { - writer_local_state->require_manual_flush = true; + explicit WriteCSVBatchData(unique_ptr writer_state) : writer_local_state(std::move(writer_state)) { } //! 
The thread-local buffer to write data into @@ -397,7 +419,8 @@ unique_ptr WriteCSVPrepareBatch(ClientContext &context, Funct auto &global_state = gstate.Cast(); // write CSV chunks to the batch data - auto batch = make_uniq(context, NextPowerOfTwo(collection->SizeInBytes())); + auto local_writer_state = global_state.GetLocalState(context, NextPowerOfTwo(collection->SizeInBytes())); + auto batch = make_uniq(std::move(local_writer_state)); for (auto &chunk : collection->Chunks()) { WriteCSVChunkInternal(global_state.writer, *batch->writer_local_state, cast_chunk, chunk, executor); } @@ -412,6 +435,7 @@ void WriteCSVFlushBatch(ClientContext &context, FunctionData &bind_data, GlobalF auto &csv_batch = batch.Cast(); auto &global_state = gstate.Cast(); global_state.writer.Flush(*csv_batch.writer_local_state); + global_state.StoreLocalState(std::move(csv_batch.writer_local_state)); } //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/function/table/direct_file_reader.cpp b/src/duckdb/src/function/table/direct_file_reader.cpp index 8aa6aba35..e28d78218 100644 --- a/src/duckdb/src/function/table/direct_file_reader.cpp +++ b/src/duckdb/src/function/table/direct_file_reader.cpp @@ -44,14 +44,16 @@ static inline void VERIFY(const string &filename, const string_t &content) { } } -void DirectFileReader::Scan(ClientContext &context, GlobalTableFunctionState &global_state, - LocalTableFunctionState &local_state, DataChunk &output) { +AsyncResult DirectFileReader::Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &output) { auto &state = global_state.Cast(); if (done || file_list_idx.GetIndex() >= state.file_list->GetTotalFileCount()) { - return; + return AsyncResult(SourceResultType::FINISHED); } auto files = state.file_list; + + auto ®ular_fs = FileSystem::GetFileSystem(context); auto fs = CachingFileSystem::Get(context); idx_t out_idx = 0; @@ -65,6 +67,14 @@ void DirectFileReader::Scan(ClientContext &context, GlobalTableFunctionState &gl flags |= FileFlags::FILE_FLAGS_DIRECT_IO; } file_handle = fs.OpenFile(QueryContext(context), file, flags); + } else { + // At least verify that the file exist + // The globbing behavior in remote filesystems can lead to files being listed that do not actually exist + if (FileSystem::IsRemoteFile(file.path) && !regular_fs.FileExists(file.path)) { + output.SetCardinality(0); + done = true; + return SourceResultType::FINISHED; + } } for (idx_t col_idx = 0; col_idx < state.column_ids.size(); col_idx++) { @@ -163,6 +173,7 @@ void DirectFileReader::Scan(ClientContext &context, GlobalTableFunctionState &gl } output.SetCardinality(1); done = true; + return AsyncResult(SourceResultType::HAVE_MORE_OUTPUT); }; void DirectFileReader::FinishFile(ClientContext &context, GlobalTableFunctionState &gstate) { diff --git a/src/duckdb/src/function/table/read_duckdb.cpp b/src/duckdb/src/function/table/read_duckdb.cpp index c68f1c32e..4f8cfac66 100644 --- a/src/duckdb/src/function/table/read_duckdb.cpp +++ b/src/duckdb/src/function/table/read_duckdb.cpp @@ -87,8 +87,8 @@ class DuckDBReader : public BaseFileReader { public: bool TryInitializeScan(ClientContext &context, GlobalTableFunctionState &gstate, LocalTableFunctionState &lstate) override; - void Scan(ClientContext &context, GlobalTableFunctionState &global_state, LocalTableFunctionState &local_state, - DataChunk &chunk) override; + AsyncResult Scan(ClientContext &context, GlobalTableFunctionState &global_state, + 
LocalTableFunctionState &local_state, DataChunk &chunk) override; shared_ptr GetUnionData(idx_t file_idx) override; void FinishFile(ClientContext &context, GlobalTableFunctionState &gstate) override; double GetProgressInFile(ClientContext &context) override; @@ -300,14 +300,38 @@ bool DuckDBReader::TryInitializeScan(ClientContext &context, GlobalTableFunction return true; } -void DuckDBReader::Scan(ClientContext &context, GlobalTableFunctionState &gstate_p, LocalTableFunctionState &lstate_p, - DataChunk &chunk) { +AsyncResult DuckDBReader::Scan(ClientContext &context, GlobalTableFunctionState &gstate_p, + LocalTableFunctionState &lstate_p, DataChunk &chunk) { chunk.Reset(); auto &lstate = lstate_p.Cast(); TableFunctionInput input(bind_data.get(), lstate.local_state, global_state); - scan_function.function(context, input, chunk); - if (chunk.size() == 0) { - finished = true; + + if (!scan_function.function) { + throw InternalException("DuckDBReader works only with simple table functions"); + } else { + input.async_result = AsyncResultType::IMPLICIT; + input.results_execution_mode = AsyncResultsExecutionMode::TASK_EXECUTOR; + scan_function.function(context, input, chunk); + + switch (input.async_result.GetResultType()) { + case AsyncResultType::BLOCKED: + return std::move(input.async_result); + case AsyncResultType::HAVE_MORE_OUTPUT: + return SourceResultType::HAVE_MORE_OUTPUT; + case AsyncResultType::IMPLICIT: + if (chunk.size() > 0) { + return SourceResultType::HAVE_MORE_OUTPUT; + } + finished = true; + return SourceResultType::FINISHED; + case AsyncResultType::FINISHED: + finished = true; + return SourceResultType::FINISHED; + default: + throw InternalException("DuckDBReader call of scan_function.function returned unexpected return '%'", + EnumUtil::ToChars(input.async_result.GetResultType())); + } + throw InternalException("DuckDBReader hasn't handled a scan_function.function return"); } } diff --git a/src/duckdb/src/function/table/read_file.cpp b/src/duckdb/src/function/table/read_file.cpp index d0481cc23..d929e8074 100644 --- a/src/duckdb/src/function/table/read_file.cpp +++ b/src/duckdb/src/function/table/read_file.cpp @@ -10,10 +10,43 @@ namespace duckdb { +namespace { + //------------------------------------------------------------------------------ // DirectMultiFileInfo //------------------------------------------------------------------------------ +template +struct DirectMultiFileInfo : MultiFileReaderInterface { + static unique_ptr CreateInterface(ClientContext &context); + unique_ptr InitializeOptions(ClientContext &context, + optional_ptr info) override; + bool ParseCopyOption(ClientContext &context, const string &key, const vector &values, + BaseFileReaderOptions &options, vector &expected_names, + vector &expected_types) override; + bool ParseOption(ClientContext &context, const string &key, const Value &val, MultiFileOptions &file_options, + BaseFileReaderOptions &options) override; + unique_ptr InitializeBindData(MultiFileBindData &multi_file_data, + unique_ptr options) override; + void BindReader(ClientContext &context, vector &return_types, vector &names, + MultiFileBindData &bind_data) override; + optional_idx MaxThreads(const MultiFileBindData &bind_data_p, const MultiFileGlobalState &global_state, + FileExpandResult expand_result) override; + unique_ptr InitializeGlobalState(ClientContext &context, MultiFileBindData &bind_data, + MultiFileGlobalState &global_state) override; + unique_ptr InitializeLocalState(ExecutionContext &, GlobalTableFunctionState &) override; + 
shared_ptr CreateReader(ClientContext &context, GlobalTableFunctionState &gstate, + BaseUnionData &union_data, const MultiFileBindData &bind_data_p) override; + shared_ptr CreateReader(ClientContext &context, GlobalTableFunctionState &gstate, + const OpenFileInfo &file, idx_t file_idx, + const MultiFileBindData &bind_data) override; + shared_ptr CreateReader(ClientContext &context, const OpenFileInfo &file, + BaseFileReaderOptions &options, + const MultiFileOptions &file_options) override; + unique_ptr GetCardinality(const MultiFileBindData &bind_data, idx_t file_count) override; + FileGlobInput GetGlobInput() override; +}; + template unique_ptr DirectMultiFileInfo::CreateInterface(ClientContext &context) { return make_uniq(); @@ -132,14 +165,45 @@ FileGlobInput DirectMultiFileInfo::GetGlobInput() { } //------------------------------------------------------------------------------ -// Register +// Operations //------------------------------------------------------------------------------ + +struct ReadBlobOperation { + static constexpr const char *NAME = "read_blob"; + static constexpr const char *FILE_TYPE = "blob"; + + static inline LogicalType TYPE() { + return LogicalType::BLOB; + } +}; + +struct ReadTextOperation { + static constexpr const char *NAME = "read_text"; + static constexpr const char *FILE_TYPE = "text"; + + static inline LogicalType TYPE() { + return LogicalType::VARCHAR; + } +}; + template static TableFunction GetFunction() { MultiFileFunction> table_function(OP::NAME); + // Erase extra multi file reader options + table_function.named_parameters.erase("filename"); + table_function.named_parameters.erase("hive_partitioning"); + table_function.named_parameters.erase("union_by_name"); + table_function.named_parameters.erase("hive_types"); + table_function.named_parameters.erase("hive_types_autocast"); return table_function; } +} // namespace + +//------------------------------------------------------------------------------ +// Register +//------------------------------------------------------------------------------ + void ReadBlobFunction::RegisterFunction(BuiltinFunctions &set) { auto scan_fun = GetFunction(); set.AddFunction(MultiFileReader::CreateFunctionSet(scan_fun)); diff --git a/src/duckdb/src/function/table/summary.cpp b/src/duckdb/src/function/table/summary.cpp index d6c4615e4..8c12148ca 100644 --- a/src/duckdb/src/function/table/summary.cpp +++ b/src/duckdb/src/function/table/summary.cpp @@ -9,7 +9,6 @@ namespace duckdb { static unique_ptr SummaryFunctionBind(ClientContext &context, TableFunctionBindInput &input, vector &return_types, vector &names) { - return_types.emplace_back(LogicalType::VARCHAR); names.emplace_back("summary"); diff --git a/src/duckdb/src/function/table/system/duckdb_columns.cpp b/src/duckdb/src/function/table/system/duckdb_columns.cpp index fe958ea3f..ff14fdd73 100644 --- a/src/duckdb/src/function/table/system/duckdb_columns.cpp +++ b/src/duckdb/src/function/table/system/duckdb_columns.cpp @@ -196,7 +196,8 @@ unique_ptr ColumnHelper::Create(CatalogEntry &entry) { case CatalogType::VIEW_ENTRY: return make_uniq(entry.Cast()); default: - throw NotImplementedException("Unsupported catalog type for duckdb_columns"); + throw NotImplementedException({{"catalog_type", CatalogTypeToString(entry.type)}}, + "Unsupported catalog type for duckdb_columns"); } } diff --git a/src/duckdb/src/function/table/system/duckdb_connection_count.cpp b/src/duckdb/src/function/table/system/duckdb_connection_count.cpp new file mode 100644 index 000000000..ce7857f3b --- 
/dev/null +++ b/src/duckdb/src/function/table/system/duckdb_connection_count.cpp @@ -0,0 +1,45 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/connection_manager.hpp" + +namespace duckdb { + +struct DuckDBConnectionCountData : public GlobalTableFunctionState { + DuckDBConnectionCountData() : count(0), finished(false) { + } + idx_t count; + bool finished; +}; + +static unique_ptr DuckDBConnectionCountBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("count"); + return_types.emplace_back(LogicalType::UBIGINT); + return nullptr; +} + +unique_ptr DuckDBConnectionCountInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + auto &conn_manager = context.db->GetConnectionManager(); + result->count = conn_manager.GetConnectionCount(); + return std::move(result); +} + +void DuckDBConnectionCountFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.finished) { + return; + } + output.SetValue(0, 0, Value::UBIGINT(data.count)); + output.SetCardinality(1); + data.finished = true; +} + +void DuckDBConnectionCountFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("duckdb_connection_count", {}, DuckDBConnectionCountFunction, + DuckDBConnectionCountBind, DuckDBConnectionCountInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_functions.cpp b/src/duckdb/src/function/table/system/duckdb_functions.cpp index b0c7656fe..09ce83bcd 100644 --- a/src/duckdb/src/function/table/system/duckdb_functions.cpp +++ b/src/duckdb/src/function/table/system/duckdb_functions.cpp @@ -15,14 +15,20 @@ #include "duckdb/common/optional_idx.hpp" #include "duckdb/common/types.hpp" #include "duckdb/main/client_data.hpp" +#include "duckdb/parser/expression/window_expression.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" +#include "duckdb/function/scalar_function.hpp" namespace duckdb { +constexpr const char *AggregateFunctionCatalogEntry::Name; struct DuckDBFunctionsData : public GlobalTableFunctionState { - DuckDBFunctionsData() : offset(0), offset_in_entry(0) { + DuckDBFunctionsData() : window_iterator(WindowExpression::WindowFunctions()), offset(0), offset_in_entry(0) { } vector> entries; + const WindowFunctionDefinition *window_iterator; idx_t offset; idx_t offset_in_entry; }; @@ -141,7 +147,7 @@ struct ScalarFunctionExtractor { } static Value GetReturnType(ScalarFunctionCatalogEntry &entry, idx_t offset) { - return Value(entry.functions.GetFunctionByOffset(offset).return_type.ToString()); + return Value(entry.functions.GetFunctionByOffset(offset).GetReturnType().ToString()); } static vector GetParameters(ScalarFunctionCatalogEntry &entry, idx_t offset) { @@ -176,11 +182,84 @@ struct ScalarFunctionExtractor { } static Value IsVolatile(ScalarFunctionCatalogEntry &entry, idx_t offset) { - return Value::BOOLEAN(entry.functions.GetFunctionByOffset(offset).stability == FunctionStability::VOLATILE); + return Value::BOOLEAN(entry.functions.GetFunctionByOffset(offset).GetStability() == + FunctionStability::VOLATILE); } static Value ResultType(ScalarFunctionCatalogEntry &entry, idx_t offset) { - return FunctionStabilityToValue(entry.functions.GetFunctionByOffset(offset).stability); + return 
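// --- Illustrative sketch, not part of this patch ---------------------------------------
// The new duckdb_connection_count() table function registered above takes no arguments and
// returns a single row with one UBIGINT column named "count", the number of connections the
// ConnectionManager reports when the scan is initialized. Expected usage:
//
//   SELECT * FROM duckdb_connection_count();
//
// ----------------------------------------------------------------------------------------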
FunctionStabilityToValue(entry.functions.GetFunctionByOffset(offset).GetStability()); + } +}; + +namespace { + +struct WindowFunctionCatalogEntry : CatalogEntry { +public: + WindowFunctionCatalogEntry(const SchemaCatalogEntry &schema, const string &name, vector arguments, + LogicalType return_type) + : CatalogEntry(CatalogType::AGGREGATE_FUNCTION_ENTRY, name, 0), schema(schema), arguments(std::move(arguments)), + return_type(std::move(return_type)) { + internal = true; + } + +public: + const SchemaCatalogEntry &schema; + vector arguments; + LogicalType return_type; + vector descriptions; + string alias_of; +}; + +} // namespace + +struct WindowFunctionExtractor { + static idx_t FunctionCount(WindowFunctionCatalogEntry &entry) { + return 1; + } + + static Value GetFunctionType() { + //! FIXME: should be 'window' but requires adapting generation scripts + return Value("aggregate"); + } + + static Value GetReturnType(WindowFunctionCatalogEntry &entry, idx_t offset) { + return Value(entry.return_type.ToString()); + } + + static vector GetParameters(WindowFunctionCatalogEntry &entry, idx_t offset) { + vector results; + for (idx_t i = 0; i < entry.arguments.size(); i++) { + results.emplace_back("col" + to_string(i)); + } + return results; + } + + static Value GetParameterTypes(WindowFunctionCatalogEntry &entry, idx_t offset) { + vector results; + for (idx_t i = 0; i < entry.arguments.size(); i++) { + results.emplace_back(entry.arguments[i].ToString()); + } + return Value::LIST(LogicalType::VARCHAR, std::move(results)); + } + + static vector GetParameterLogicalTypes(WindowFunctionCatalogEntry &entry, idx_t offset) { + return entry.arguments; + } + + static Value GetVarArgs(WindowFunctionCatalogEntry &entry, idx_t offset) { + return Value(); + } + + static Value GetMacroDefinition(WindowFunctionCatalogEntry &entry, idx_t offset) { + return Value(); + } + + static Value IsVolatile(WindowFunctionCatalogEntry &entry, idx_t offset) { + return Value::BOOLEAN(false); + } + + static Value ResultType(WindowFunctionCatalogEntry &entry, idx_t offset) { + return FunctionStabilityToValue(FunctionStability::CONSISTENT); } }; @@ -194,7 +273,7 @@ struct AggregateFunctionExtractor { } static Value GetReturnType(AggregateFunctionCatalogEntry &entry, idx_t offset) { - return Value(entry.functions.GetFunctionByOffset(offset).return_type.ToString()); + return Value(entry.functions.GetFunctionByOffset(offset).GetReturnType().ToString()); } static vector GetParameters(AggregateFunctionCatalogEntry &entry, idx_t offset) { @@ -229,11 +308,12 @@ struct AggregateFunctionExtractor { } static Value IsVolatile(AggregateFunctionCatalogEntry &entry, idx_t offset) { - return Value::BOOLEAN(entry.functions.GetFunctionByOffset(offset).stability == FunctionStability::VOLATILE); + return Value::BOOLEAN(entry.functions.GetFunctionByOffset(offset).GetStability() == + FunctionStability::VOLATILE); } static Value ResultType(AggregateFunctionCatalogEntry &entry, idx_t offset) { - return FunctionStabilityToValue(entry.functions.GetFunctionByOffset(offset).stability); + return FunctionStabilityToValue(entry.functions.GetFunctionByOffset(offset).GetStability()); } }; @@ -497,7 +577,7 @@ static vector ToValueVector(vector &string_vector) { } template -static Value GetParameterNames(FunctionEntry &entry, idx_t function_idx, FunctionDescription &function_description, +static Value GetParameterNames(CatalogEntry &entry, idx_t function_idx, FunctionDescription &function_description, Value ¶meter_types) { vector parameter_names; if 
(!function_description.parameter_names.empty()) { @@ -566,13 +646,13 @@ static optional_idx GetFunctionDescriptionIndex(vector &fun } template -bool ExtractFunctionData(FunctionEntry &entry, idx_t function_idx, DataChunk &output, idx_t output_offset) { +bool ExtractFunctionData(CatalogEntry &entry, idx_t function_idx, DataChunk &output, idx_t output_offset) { auto &function = entry.Cast(); vector parameter_types_vector = OP::GetParameterLogicalTypes(function, function_idx); Value parameter_types_value = OP::GetParameterTypes(function, function_idx); - optional_idx description_idx = GetFunctionDescriptionIndex(entry.descriptions, parameter_types_vector); + optional_idx description_idx = GetFunctionDescriptionIndex(function.descriptions, parameter_types_vector); FunctionDescription function_description = - description_idx.IsValid() ? entry.descriptions[description_idx.GetIndex()] : FunctionDescription(); + description_idx.IsValid() ? function.descriptions[description_idx.GetIndex()] : FunctionDescription(); idx_t col = 0; @@ -601,10 +681,10 @@ bool ExtractFunctionData(FunctionEntry &entry, idx_t function_idx, DataChunk &ou (function_description.description.empty()) ? Value() : Value(function_description.description)); // comment, LogicalType::VARCHAR - output.SetValue(col++, output_offset, entry.comment); + output.SetValue(col++, output_offset, function.comment); // tags, LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR) - output.SetValue(col++, output_offset, Value::MAP(entry.tags)); + output.SetValue(col++, output_offset, Value::MAP(function.tags)); // return_type, LogicalType::VARCHAR output.SetValue(col++, output_offset, OP::GetReturnType(function, function_idx)); @@ -645,9 +725,75 @@ bool ExtractFunctionData(FunctionEntry &entry, idx_t function_idx, DataChunk &ou return function_idx + 1 == OP::FunctionCount(function); } +void ExtractWindowFunctionData(ClientContext &context, const WindowFunctionDefinition *it, DataChunk &output, + idx_t output_offset) { + D_ASSERT(it && it->name != nullptr); + string name(it->name); + + auto &system_catalog = Catalog::GetSystemCatalog(DatabaseInstance::GetDatabase(context)); + string schema_name(DEFAULT_SCHEMA); + EntryLookupInfo schema_lookup(CatalogType::SCHEMA_ENTRY, schema_name); + auto &default_schema = system_catalog.GetSchema(context, schema_lookup); + + switch (it->expression_type) { + case ExpressionType::WINDOW_FILL: + case ExpressionType::WINDOW_LAST_VALUE: + case ExpressionType::WINDOW_FIRST_VALUE: { + WindowFunctionCatalogEntry function(default_schema, name, {LogicalType::TEMPLATE("T")}, + LogicalType::TEMPLATE("T")); + ExtractFunctionData(function, 0, output, output_offset); + break; + } + case ExpressionType::WINDOW_NTH_VALUE: { + WindowFunctionCatalogEntry function(default_schema, name, {LogicalType::TEMPLATE("T"), LogicalType::BIGINT}, + LogicalType::TEMPLATE("T")); + ExtractFunctionData(function, 0, output, output_offset); + break; + } + case ExpressionType::WINDOW_ROW_NUMBER: + case ExpressionType::WINDOW_RANK: + case ExpressionType::WINDOW_RANK_DENSE: { + WindowFunctionCatalogEntry function(default_schema, name, {}, LogicalType::BIGINT); + ExtractFunctionData(function, 0, output, output_offset); + break; + } + case ExpressionType::WINDOW_NTILE: { + WindowFunctionCatalogEntry function(default_schema, name, {LogicalType::BIGINT}, LogicalType::BIGINT); + ExtractFunctionData(function, 0, output, output_offset); + break; + } + case ExpressionType::WINDOW_PERCENT_RANK: + case ExpressionType::WINDOW_CUME_DIST: { + 
WindowFunctionCatalogEntry function(default_schema, name, {}, LogicalType::DOUBLE); + ExtractFunctionData(function, 0, output, output_offset); + break; + } + case ExpressionType::WINDOW_LAG: + case ExpressionType::WINDOW_LEAD: { + WindowFunctionCatalogEntry function( + default_schema, name, {LogicalType::TEMPLATE("T"), LogicalType::BIGINT, LogicalType::TEMPLATE("T")}, + LogicalType::TEMPLATE("T")); + ExtractFunctionData(function, 0, output, output_offset); + break; + } + default: + throw InternalException("Window function '%s' not implemented", name); + } +} + +static bool Finished(const DuckDBFunctionsData &data) { + if (data.offset < data.entries.size()) { + return false; + } + if (data.window_iterator->name == nullptr) { + return true; + } + return false; +} + void DuckDBFunctionsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { + if (Finished(data)) { // finished returning values return; } @@ -696,6 +842,11 @@ void DuckDBFunctionsFunction(ClientContext &context, TableFunctionInput &data_p, } count++; } + while (data.window_iterator->name != nullptr && count < STANDARD_VECTOR_SIZE) { + ExtractWindowFunctionData(context, data.window_iterator, output, count); + count++; + data.window_iterator++; + } output.SetCardinality(count); } diff --git a/src/duckdb/src/function/table/system/duckdb_secrets.cpp b/src/duckdb/src/function/table/system/duckdb_secrets.cpp index 6069344bf..ae7f3104a 100644 --- a/src/duckdb/src/function/table/system/duckdb_secrets.cpp +++ b/src/duckdb/src/function/table/system/duckdb_secrets.cpp @@ -37,6 +37,9 @@ static unique_ptr DuckDBSecretsBind(ClientContext &context, TableF auto entry = input.named_parameters.find("redact"); if (entry != input.named_parameters.end()) { + if (entry->second.IsNull()) { + throw InvalidInputException("Cannot use NULL as argument for redact"); + } if (BooleanValue::Get(entry->second)) { result->redact = SecretDisplayType::REDACTED; } else { diff --git a/src/duckdb/src/function/table/system/pragma_storage_info.cpp b/src/duckdb/src/function/table/system/pragma_storage_info.cpp index 5500c1c5d..7ba6cfd69 100644 --- a/src/duckdb/src/function/table/system/pragma_storage_info.cpp +++ b/src/duckdb/src/function/table/system/pragma_storage_info.cpp @@ -88,7 +88,7 @@ static unique_ptr PragmaStorageInfoBind(ClientContext &context, Ta Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema); auto &table_entry = Catalog::GetEntry(context, qname.catalog, qname.schema, qname.name); auto result = make_uniq(table_entry); - result->column_segments_info = table_entry.GetColumnSegmentInfo(); + result->column_segments_info = table_entry.GetColumnSegmentInfo(context); return std::move(result); } diff --git a/src/duckdb/src/function/table/system/pragma_table_sample.cpp b/src/duckdb/src/function/table/system/pragma_table_sample.cpp index ce083d92c..cf5a9ccfb 100644 --- a/src/duckdb/src/function/table/system/pragma_table_sample.cpp +++ b/src/duckdb/src/function/table/system/pragma_table_sample.cpp @@ -32,7 +32,6 @@ struct DuckDBTableSampleOperatorData : public GlobalTableFunctionState { static unique_ptr DuckDBTableSampleBind(ClientContext &context, TableFunctionBindInput &input, vector &return_types, vector &names) { - // look up the table name in the catalog auto qname = QualifiedName::Parse(input.inputs[0].GetValue()); Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema); diff --git 
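The duckdb_functions() changes above append built-in window functions after the catalog entries by walking the null-terminated WindowFunctionDefinition array. A minimal sketch of how this could be observed from the C++ API; the query and the selected names are illustrative, and per the FIXME above the function_type is currently reported as 'aggregate':

```cpp
#include "duckdb.hpp"

int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	// After this change, window functions such as row_number, lag and nth_value are expected to be
	// listed next to the catalog functions, with templated argument/return types where applicable.
	auto result = con.Query("SELECT function_name, function_type, return_type "
	                        "FROM duckdb_functions() "
	                        "WHERE function_name IN ('row_number', 'lag', 'nth_value')");
	result->Print();
	return 0;
}
```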
a/src/duckdb/src/function/table/system/pragma_user_agent.cpp b/src/duckdb/src/function/table/system/pragma_user_agent.cpp index 3803f7195..6448422bf 100644 --- a/src/duckdb/src/function/table/system/pragma_user_agent.cpp +++ b/src/duckdb/src/function/table/system/pragma_user_agent.cpp @@ -13,7 +13,6 @@ struct PragmaUserAgentData : public GlobalTableFunctionState { static unique_ptr PragmaUserAgentBind(ClientContext &context, TableFunctionBindInput &input, vector &return_types, vector &names) { - names.emplace_back("user_agent"); return_types.emplace_back(LogicalType::VARCHAR); diff --git a/src/duckdb/src/function/table/system/test_all_types.cpp b/src/duckdb/src/function/table/system/test_all_types.cpp index cd4ba3964..94952d8d0 100644 --- a/src/duckdb/src/function/table/system/test_all_types.cpp +++ b/src/duckdb/src/function/table/system/test_all_types.cpp @@ -319,10 +319,16 @@ static unique_ptr TestAllTypesBind(ClientContext &context, TableFu bool use_large_bignum = false; auto entry = input.named_parameters.find("use_large_enum"); if (entry != input.named_parameters.end()) { + if (entry->second.IsNull()) { + throw InvalidInputException("Cannot use NULL as argument for use_large_enum"); + } use_large_enum = BooleanValue::Get(entry->second); } entry = input.named_parameters.find("use_large_bignum"); if (entry != input.named_parameters.end()) { + if (entry->second.IsNull()) { + throw InvalidInputException("Cannot use NULL as argument for use_large_bignum"); + } use_large_bignum = BooleanValue::Get(entry->second); } result->test_types = TestAllTypesFun::GetTestTypes(use_large_enum, use_large_bignum); diff --git a/src/duckdb/src/function/table/system/test_vector_types.cpp b/src/duckdb/src/function/table/system/test_vector_types.cpp index 23dab8758..5c5c073be 100644 --- a/src/duckdb/src/function/table/system/test_vector_types.cpp +++ b/src/duckdb/src/function/table/system/test_vector_types.cpp @@ -277,6 +277,9 @@ static unique_ptr TestVectorTypesBind(ClientContext &context, Tabl } for (auto &entry : input.named_parameters) { if (entry.first == "all_flat") { + if (entry.second.IsNull()) { + throw InvalidInputException("Cannot use NULL as argument for all_flat"); + } result->all_flat = BooleanValue::Get(entry.second); } else { throw InternalException("Unrecognized named parameter for test_vector_types"); diff --git a/src/duckdb/src/function/table/system_functions.cpp b/src/duckdb/src/function/table/system_functions.cpp index d10ec5d31..0a6a03507 100644 --- a/src/duckdb/src/function/table/system_functions.cpp +++ b/src/duckdb/src/function/table/system_functions.cpp @@ -18,6 +18,7 @@ void BuiltinFunctions::RegisterSQLiteFunctions() { PragmaDatabaseSize::RegisterFunction(*this); PragmaUserAgent::RegisterFunction(*this); + DuckDBConnectionCountFun::RegisterFunction(*this); DuckDBApproxDatabaseCountFun::RegisterFunction(*this); DuckDBColumnsFun::RegisterFunction(*this); DuckDBConstraintsFun::RegisterFunction(*this); diff --git a/src/duckdb/src/function/table/table_scan.cpp b/src/duckdb/src/function/table/table_scan.cpp index 99a9bcf79..fc24ec702 100644 --- a/src/duckdb/src/function/table/table_scan.cpp +++ b/src/duckdb/src/function/table/table_scan.cpp @@ -249,6 +249,11 @@ class DuckTableScanState : public TableScanGlobalState { storage_ids.push_back(GetStorageIndex(bind_data.table, col)); } + if (bind_data.order_options) { + l_state->scan_state.table_state.reorderer = make_uniq(*bind_data.order_options); + l_state->scan_state.local_state.reorderer = make_uniq(*bind_data.order_options); + } + 
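The duckdb_connection_count table function added and registered above returns a single row with one UBIGINT column named count. A minimal usage sketch through the C++ API:

```cpp
#include "duckdb.hpp"

int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	// One row, one UBIGINT column "count", as defined in DuckDBConnectionCountBind.
	auto result = con.Query("SELECT count FROM duckdb_connection_count()");
	result->Print();
	return 0;
}
```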
l_state->scan_state.Initialize(std::move(storage_ids), context.client, input.filters, input.sample_options); storage.NextParallelScan(context.client, state, l_state->scan_state); @@ -329,6 +334,11 @@ static unique_ptr TableScanInitLocal(ExecutionContext & unique_ptr DuckTableScanInitGlobal(ClientContext &context, TableFunctionInitInput &input, DataTable &storage, const TableScanBindData &bind_data) { auto g_state = make_uniq(context, input.bind_data.get()); + if (bind_data.order_options) { + g_state->state.scan_state.reorderer = make_uniq(*bind_data.order_options); + g_state->state.local_state.reorderer = make_uniq(*bind_data.order_options); + } + storage.InitializeParallelScan(context, g_state->state); if (!input.CanRemoveFilterColumns()) { return std::move(g_state); @@ -740,6 +750,11 @@ vector TableScanGetRowIdColumns(ClientContext &context, optional_ptr order_options, optional_ptr bind_data_p) { + auto &bind_data = bind_data_p->Cast(); + bind_data.order_options = std::move(order_options); +} + TableFunction TableScanFunction::GetFunction() { TableFunction scan_function("seq_scan", {}, TableScanFunc); scan_function.init_local = TableScanInitLocal; @@ -763,6 +778,7 @@ TableFunction TableScanFunction::GetFunction() { scan_function.pushdown_expression = TableScanPushdownExpression; scan_function.get_virtual_columns = TableScanGetVirtualColumns; scan_function.get_row_id_columns = TableScanGetRowIdColumns; + scan_function.set_scan_order = SetScanOrder; return scan_function; } diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index c3f5f0a6b..fb39b9550 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -1,5 +1,5 @@ #ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "0-dev383" +#define DUCKDB_PATCH_VERSION "0-dev2368" #endif #ifndef DUCKDB_MINOR_VERSION #define DUCKDB_MINOR_VERSION 5 @@ -8,10 +8,10 @@ #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.5.0-dev383" +#define DUCKDB_VERSION "v1.5.0-dev2368" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "07d170f87e" +#define DUCKDB_SOURCE_ID "44b706b2b7" #endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" @@ -91,6 +91,9 @@ const char *DuckDB::ReleaseCodename() { if (StringUtil::StartsWith(DUCKDB_VERSION, "v1.4.")) { return "Andium"; } + if (StringUtil::StartsWith(DUCKDB_VERSION, "v1.5.")) { + return "Variegata"; + } // add new version names here // we should not get here, but let's not fail because of it because tags on forks can be whatever diff --git a/src/duckdb/src/function/table_function.cpp b/src/duckdb/src/function/table_function.cpp index 310f75b58..b3835befc 100644 --- a/src/duckdb/src/function/table_function.cpp +++ b/src/duckdb/src/function/table_function.cpp @@ -14,11 +14,26 @@ PartitionStatistics::PartitionStatistics() : row_start(0), count(0), count_type( TableFunctionInfo::~TableFunctionInfo() { } -TableFunction::TableFunction(string name, vector arguments, table_function_t function, +TableFunction::TableFunction(string name, const vector &arguments, table_function_t function_, table_function_bind_t bind, table_function_init_global_t init_global, table_function_init_local_t init_local) - : SimpleNamedParameterFunction(std::move(name), std::move(arguments)), bind(bind), bind_replace(nullptr), - bind_operator(nullptr), init_global(init_global), init_local(init_local), 
function(function), + : SimpleNamedParameterFunction(std::move(name), arguments), bind(bind), bind_replace(nullptr), + bind_operator(nullptr), init_global(init_global), init_local(init_local), function(function_), + in_out_function(nullptr), in_out_function_final(nullptr), statistics(nullptr), dependency(nullptr), + cardinality(nullptr), pushdown_complex_filter(nullptr), pushdown_expression(nullptr), to_string(nullptr), + dynamic_to_string(nullptr), table_scan_progress(nullptr), get_partition_data(nullptr), get_bind_info(nullptr), + type_pushdown(nullptr), get_multi_file_reader(nullptr), supports_pushdown_type(nullptr), + get_partition_info(nullptr), get_partition_stats(nullptr), get_virtual_columns(nullptr), + get_row_id_columns(nullptr), set_scan_order(nullptr), serialize(nullptr), deserialize(nullptr), + projection_pushdown(false), filter_pushdown(false), filter_prune(false), sampling_pushdown(false), + late_materialization(false) { +} + +TableFunction::TableFunction(string name, const vector &arguments, std::nullptr_t function_, + table_function_bind_t bind, table_function_init_global_t init_global, + table_function_init_local_t init_local) + : SimpleNamedParameterFunction(std::move(name), arguments), bind(bind), bind_replace(nullptr), + bind_operator(nullptr), init_global(init_global), init_local(init_local), function(nullptr), in_out_function(nullptr), in_out_function_final(nullptr), statistics(nullptr), dependency(nullptr), cardinality(nullptr), pushdown_complex_filter(nullptr), pushdown_expression(nullptr), to_string(nullptr), dynamic_to_string(nullptr), table_scan_progress(nullptr), get_partition_data(nullptr), get_bind_info(nullptr), @@ -28,10 +43,15 @@ TableFunction::TableFunction(string name, vector arguments, table_f filter_pushdown(false), filter_prune(false), sampling_pushdown(false), late_materialization(false) { } -TableFunction::TableFunction(const vector &arguments, table_function_t function, +TableFunction::TableFunction(const vector &arguments, table_function_t function_, table_function_bind_t bind, table_function_init_global_t init_global, table_function_init_local_t init_local) - : TableFunction(string(), arguments, function, bind, init_global, init_local) { + : TableFunction("", arguments, function_, bind, init_global, init_local) { +} + +TableFunction::TableFunction(const vector &arguments, std::nullptr_t function_, table_function_bind_t bind, + table_function_init_global_t init_global, table_function_init_local_t init_local) + : TableFunction("", arguments, function_, bind, init_global, init_local) { } TableFunction::TableFunction() : TableFunction("", {}, nullptr, nullptr, nullptr, nullptr) { @@ -56,4 +76,22 @@ bool TableFunction::Equal(const TableFunction &rhs) const { return true; // they are equal } +bool ExtractSourceResultType(AsyncResultType in, SourceResultType &out) { + switch (in) { + case AsyncResultType::IMPLICIT: + case AsyncResultType::INVALID: + return false; + case AsyncResultType::HAVE_MORE_OUTPUT: + out = SourceResultType::HAVE_MORE_OUTPUT; + break; + case AsyncResultType::FINISHED: + out = SourceResultType::FINISHED; + break; + case AsyncResultType::BLOCKED: + out = SourceResultType::BLOCKED; + break; + } + return true; +} + } // namespace duckdb diff --git a/src/duckdb/src/function/udf_function.cpp b/src/duckdb/src/function/udf_function.cpp index 3c03dbbe3..55ba9385f 100644 --- a/src/duckdb/src/function/udf_function.cpp +++ b/src/duckdb/src/function/udf_function.cpp @@ -9,10 +9,9 @@ namespace duckdb { void UDFWrapper::RegisterFunction(string 
name, vector args, LogicalType ret_type, scalar_function_t udf_function, ClientContext &context, LogicalType varargs) { - ScalarFunction scalar_function(std::move(name), std::move(args), std::move(ret_type), std::move(udf_function)); scalar_function.varargs = std::move(varargs); - scalar_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + scalar_function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); CreateScalarFunctionInfo info(scalar_function); info.schema = DEFAULT_SCHEMA; context.RegisterFunction(info); diff --git a/src/duckdb/src/function/window/window_aggregate_function.cpp b/src/duckdb/src/function/window/window_aggregate_function.cpp index 95c8a5059..1d0e01e2f 100644 --- a/src/duckdb/src/function/window/window_aggregate_function.cpp +++ b/src/duckdb/src/function/window/window_aggregate_function.cpp @@ -52,7 +52,6 @@ static BoundWindowExpression &SimplifyWindowedAggregate(BoundWindowExpression &w WindowAggregateExecutor::WindowAggregateExecutor(BoundWindowExpression &wexpr, ClientContext &client, WindowSharedExpressions &shared, WindowAggregationMode mode) : WindowExecutor(SimplifyWindowedAggregate(wexpr, client), shared), mode(mode) { - // Force naive for SEPARATE mode or for (currently!) unsupported functionality if (!ClientConfig::GetConfig(client).enable_optimizer || mode == WindowAggregationMode::SEPARATE) { if (!WindowNaiveAggregator::CanAggregate(wexpr)) { @@ -111,7 +110,6 @@ class WindowAggregateExecutorLocalState : public WindowExecutorBoundsLocalState const WindowAggregator &aggregator) : WindowExecutorBoundsLocalState(context, gstate.Cast()), filter_executor(context.client) { - auto &gastate = gstate.Cast(); aggregator_state = aggregator.GetLocalState(context, *gastate.gsink); diff --git a/src/duckdb/src/function/window/window_aggregator.cpp b/src/duckdb/src/function/window/window_aggregator.cpp index 3ac9c91c9..107a4d31c 100644 --- a/src/duckdb/src/function/window/window_aggregator.cpp +++ b/src/duckdb/src/function/window/window_aggregator.cpp @@ -12,7 +12,6 @@ namespace duckdb { WindowAggregator::WindowAggregator(const BoundWindowExpression &wexpr) : wexpr(wexpr), aggr(wexpr), result_type(wexpr.return_type), state_size(aggr.function.state_size(aggr.function)), exclude_mode(wexpr.exclude_clause) { - for (auto &child : wexpr.children) { arg_types.emplace_back(child->return_type); } @@ -32,7 +31,6 @@ WindowAggregatorGlobalState::WindowAggregatorGlobalState(ClientContext &client, idx_t group_count) : client(client), allocator(Allocator::DefaultAllocator()), aggregator(aggregator_p), aggr(aggregator.wexpr), locals(0), finalized(0) { - if (aggr.filter) { // Start with all invalid and set the ones that pass filter_mask.Initialize(group_count, false); diff --git a/src/duckdb/src/function/window/window_boundaries_state.cpp b/src/duckdb/src/function/window/window_boundaries_state.cpp index 84ae8abb2..84fbc7929 100644 --- a/src/duckdb/src/function/window/window_boundaries_state.cpp +++ b/src/duckdb/src/function/window/window_boundaries_state.cpp @@ -620,7 +620,6 @@ void WindowBoundariesState::PartitionEnd(DataChunk &bounds, idx_t row_idx, const void WindowBoundariesState::PeerBegin(DataChunk &bounds, idx_t row_idx, const idx_t count, bool is_jump, const ValidityMask &partition_mask, const ValidityMask &order_mask) { - auto peer_begin_data = FlatVector::GetData(bounds.data[PEER_BEGIN]); // OVER() diff --git a/src/duckdb/src/function/window/window_constant_aggregator.cpp b/src/duckdb/src/function/window/window_constant_aggregator.cpp index 
d35c90b4f..359f4ef8e 100644 --- a/src/duckdb/src/function/window/window_constant_aggregator.cpp +++ b/src/duckdb/src/function/window/window_constant_aggregator.cpp @@ -35,7 +35,6 @@ WindowConstantAggregatorGlobalState::WindowConstantAggregatorGlobalState(ClientC idx_t group_count, const ValidityMask &partition_mask) : WindowAggregatorGlobalState(context, aggregator, STANDARD_VECTOR_SIZE), statef(aggr) { - // Locate the partition boundaries if (partition_mask.AllValid()) { partition_offsets.emplace_back(0); @@ -201,7 +200,6 @@ BoundWindowExpression &WindowConstantAggregator::RebindAggregate(ClientContext & WindowConstantAggregator::WindowConstantAggregator(BoundWindowExpression &wexpr, WindowSharedExpressions &shared, ClientContext &context) : WindowAggregator(RebindAggregate(context, wexpr)) { - // We only need these values for Sink for (auto &child : wexpr.children) { child_idx.emplace_back(shared.RegisterSink(child)); diff --git a/src/duckdb/src/function/window/window_distinct_aggregator.cpp b/src/duckdb/src/function/window/window_distinct_aggregator.cpp index 063e25a80..04f1624f8 100644 --- a/src/duckdb/src/function/window/window_distinct_aggregator.cpp +++ b/src/duckdb/src/function/window/window_distinct_aggregator.cpp @@ -124,7 +124,6 @@ WindowDistinctAggregatorGlobalState::WindowDistinctAggregatorGlobalState(ClientC idx_t group_count) : WindowAggregatorGlobalState(client, aggregator, group_count), stage(WindowDistinctSortStage::INIT), tasks_assigned(0), tasks_completed(0), merge_sort_tree(*this, group_count), levels_flat_native(aggr) { - // 1: functionComputePrevIdcs(𝑖𝑛) // 2: sorted ← [] // We sort the aggregate arguments and use the partition index as a tie-breaker. @@ -704,7 +703,6 @@ unique_ptr WindowDistinctAggregator::GetLocalState(ExecutionCont void WindowDistinctAggregator::Evaluate(ExecutionContext &context, const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx, OperatorSinkInput &sink) const { - const auto &gdstate = sink.global_state.Cast(); auto &ldstate = sink.local_state.Cast(); ldstate.Evaluate(context, gdstate, bounds, result, count, row_idx); diff --git a/src/duckdb/src/function/window/window_merge_sort_tree.cpp b/src/duckdb/src/function/window/window_merge_sort_tree.cpp index 6af3d0e5b..5943e6228 100644 --- a/src/duckdb/src/function/window/window_merge_sort_tree.cpp +++ b/src/duckdb/src/function/window/window_merge_sort_tree.cpp @@ -1,5 +1,8 @@ #include "duckdb/function/window/window_merge_sort_tree.hpp" +#include "duckdb/main/client_config.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" #include #include diff --git a/src/duckdb/src/function/window/window_naive_aggregator.cpp b/src/duckdb/src/function/window/window_naive_aggregator.cpp index a01e4813c..4d0e77f21 100644 --- a/src/duckdb/src/function/window/window_naive_aggregator.cpp +++ b/src/duckdb/src/function/window/window_naive_aggregator.cpp @@ -13,7 +13,6 @@ namespace duckdb { //===--------------------------------------------------------------------===// WindowNaiveAggregator::WindowNaiveAggregator(const WindowAggregateExecutor &executor, WindowSharedExpressions &shared) : WindowAggregator(executor.wexpr, shared), executor(executor) { - for (const auto &order : wexpr.arg_orders) { arg_order_idx.emplace_back(shared.RegisterCollection(order.expression, false)); } diff --git a/src/duckdb/src/function/window/window_rank_function.cpp 
b/src/duckdb/src/function/window/window_rank_function.cpp index af70521a0..644190940 100644 --- a/src/duckdb/src/function/window/window_rank_function.cpp +++ b/src/duckdb/src/function/window/window_rank_function.cpp @@ -1,6 +1,7 @@ #include "duckdb/function/window/window_rank_function.hpp" #include "duckdb/function/window/window_shared_expressions.hpp" #include "duckdb/function/window/window_token_tree.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" namespace duckdb { @@ -103,7 +104,6 @@ void WindowPeerLocalState::NextRank(idx_t partition_begin, idx_t peer_begin, idx //===--------------------------------------------------------------------===// WindowPeerExecutor::WindowPeerExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) : WindowExecutor(wexpr, shared) { - for (const auto &order : wexpr.arg_orders) { arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); } diff --git a/src/duckdb/src/function/window/window_rownumber_function.cpp b/src/duckdb/src/function/window/window_rownumber_function.cpp index f0929d642..27e7adecc 100644 --- a/src/duckdb/src/function/window/window_rownumber_function.cpp +++ b/src/duckdb/src/function/window/window_rownumber_function.cpp @@ -1,6 +1,7 @@ #include "duckdb/function/window/window_rownumber_function.hpp" #include "duckdb/function/window/window_shared_expressions.hpp" #include "duckdb/function/window/window_token_tree.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" namespace duckdb { @@ -91,7 +92,6 @@ void WindowRowNumberLocalState::Finalize(ExecutionContext &context, CollectionPt //===--------------------------------------------------------------------===// WindowRowNumberExecutor::WindowRowNumberExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) : WindowExecutor(wexpr, shared) { - for (const auto &order : wexpr.arg_orders) { arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); } @@ -141,7 +141,6 @@ void WindowRowNumberExecutor::EvaluateInternal(ExecutionContext &context, DataCh //===--------------------------------------------------------------------===// WindowNtileExecutor::WindowNtileExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) : WindowRowNumberExecutor(wexpr, shared) { - // NTILE has one argument ntile_idx = shared.RegisterEvaluate(wexpr.children[0]); } diff --git a/src/duckdb/src/function/window/window_segment_tree.cpp b/src/duckdb/src/function/window/window_segment_tree.cpp index f62a0a856..22de75480 100644 --- a/src/duckdb/src/function/window/window_segment_tree.cpp +++ b/src/duckdb/src/function/window/window_segment_tree.cpp @@ -164,7 +164,6 @@ WindowSegmentTreePart::WindowSegmentTreePart(ArenaAllocator &allocator, const Ag filter_mask(filter_mask), state_size(aggr.function.state_size(aggr.function)), state(state_size * STANDARD_VECTOR_SIZE), cursor(std::move(cursor_p)), statep(LogicalType::POINTER), statel(LogicalType::POINTER), statef(LogicalType::POINTER), flush_count(0) { - auto &inputs = cursor->chunk; if (inputs.ColumnCount() > 0) { leaves.Initialize(Allocator::DefaultAllocator(), inputs.GetTypes()); @@ -298,7 +297,6 @@ void WindowSegmentTreePart::Finalize(Vector &result, idx_t count) { WindowSegmentTreeGlobalState::WindowSegmentTreeGlobalState(ClientContext &context, const WindowSegmentTree &aggregator, idx_t group_count) : WindowAggregatorGlobalState(context, aggregator, group_count), tree(aggregator), 
levels_flat_native(aggr) { - D_ASSERT(!aggregator.wexpr.children.empty()); // compute space required to store internal nodes of segment tree @@ -570,7 +568,6 @@ void WindowSegmentTreePart::EvaluateUpperLevels(const WindowSegmentTreeGlobalSta void WindowSegmentTreePart::EvaluateLeaves(const WindowSegmentTreeGlobalState &tree, const idx_t *begins, const idx_t *ends, const idx_t *bounds, idx_t count, idx_t row_idx, FramePart frame_part, FramePart leaf_part) { - auto fdata = FlatVector::GetData(statef); // For order-sensitive aggregates, we have to process the ragged leaves in two pieces. diff --git a/src/duckdb/src/function/window/window_value_function.cpp b/src/duckdb/src/function/window/window_value_function.cpp index 0258b7d6b..adf60be11 100644 --- a/src/duckdb/src/function/window/window_value_function.cpp +++ b/src/duckdb/src/function/window/window_value_function.cpp @@ -23,7 +23,6 @@ class WindowValueGlobalState : public WindowExecutorGlobalState { const ValidityMask &partition_mask, const ValidityMask &order_mask) : WindowExecutorGlobalState(client, executor, payload_count, partition_mask, order_mask), ignore_nulls(&all_valid), child_idx(executor.child_idx) { - if (!executor.arg_order_idx.empty()) { value_tree = make_uniq(client, executor.wexpr.arg_orders, executor.arg_order_idx, payload_count); @@ -139,7 +138,6 @@ void WindowValueLocalState::Finalize(ExecutionContext &context, CollectionPtr co //===--------------------------------------------------------------------===// WindowValueExecutor::WindowValueExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) : WindowExecutor(wexpr, shared) { - for (const auto &order : wexpr.arg_orders) { arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); } @@ -200,7 +198,6 @@ class WindowLeadLagGlobalState : public WindowValueGlobalState { const idx_t payload_count, const ValidityMask &partition_mask, const ValidityMask &order_mask) : WindowValueGlobalState(client, executor, payload_count, partition_mask, order_mask) { - if (value_tree) { use_framing = true; @@ -842,7 +839,6 @@ static fill_value_t GetFillValueFunction(const LogicalType &type) { WindowFillExecutor::WindowFillExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) : WindowValueExecutor(wexpr, shared) { - // We need the sort values for interpolation, so either use the range or the secondary ordering expression if (arg_order_idx.empty()) { // We use the range ordering, even if it has not been defined @@ -918,7 +914,6 @@ unique_ptr WindowFillExecutor::GetLocalState(ExecutionContext &c void WindowFillExecutor::EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx, OperatorSinkInput &sink) const { - auto &lfstate = sink.local_state.Cast(); auto &cursor = *lfstate.cursor; diff --git a/src/duckdb/src/include/duckdb.h b/src/duckdb/src/include/duckdb.h index ccf5ad5ac..d11de040f 100644 --- a/src/duckdb/src/include/duckdb.h +++ b/src/duckdb/src/include/duckdb.h @@ -255,6 +255,21 @@ typedef enum duckdb_file_flag { DUCKDB_FILE_FLAG_APPEND = 5, } duckdb_file_flag; +//! An enum over DuckDB's configuration option scopes. +//! This enum can be used to specify the default scope when creating a custom configuration option, +//! but it is also be used to determine the scope in which a configuration option is set when it is +//! changed or retrieved. +typedef enum duckdb_config_option_scope { + DUCKDB_CONFIG_OPTION_SCOPE_INVALID = 0, + // The option is set for the duration of the current transaction only. 
+ // !! CURRENTLY NOT IMPLEMENTED !! + DUCKDB_CONFIG_OPTION_SCOPE_LOCAL = 1, + // The option is set for the current session/connection only. + DUCKDB_CONFIG_OPTION_SCOPE_SESSION = 2, + // Set the option globally for all sessions/connections. + DUCKDB_CONFIG_OPTION_SCOPE_GLOBAL = 3, +} duckdb_config_option_scope; + //===--------------------------------------------------------------------===// // General type definitions //===--------------------------------------------------------------------===// @@ -548,6 +563,12 @@ typedef struct _duckdb_config { void *internal_ptr; } * duckdb_config; +//! A custom configuration option instance. Used to register custom options that can be set on a duckdb_config. +//! or by the user in SQL using `SET = `. +typedef struct _duckdb_config_option { + void *internal_ptr; +} * duckdb_config_option; + //! A logical type. //! Must be destroyed with `duckdb_destroy_logical_type`. typedef struct _duckdb_logical_type { @@ -699,6 +720,47 @@ typedef void (*duckdb_table_function_init_t)(duckdb_init_info info); //! The function to generate an output chunk during table function execution. typedef void (*duckdb_table_function_t)(duckdb_function_info info, duckdb_data_chunk output); +//===--------------------------------------------------------------------===// +// Copy function types +//===--------------------------------------------------------------------===// + +//! A COPY function. Must be destroyed with `duckdb_destroy_copy_function`. +typedef struct _duckdb_copy_function { + void *internal_ptr; +} * duckdb_copy_function; + +//! Info for the bind function of a COPY function. +typedef struct _duckdb_copy_function_bind_info { + void *internal_ptr; +} * duckdb_copy_function_bind_info; + +//! Info for the global initialization function of a COPY function. +typedef struct _duckdb_copy_function_global_init_info { + void *internal_ptr; +} * duckdb_copy_function_global_init_info; + +//! Info for the sink function of a COPY function. +typedef struct _duckdb_copy_function_sink_info { + void *internal_ptr; +} * duckdb_copy_function_sink_info; + +//! Info for the finalize function of a COPY function. +typedef struct _duckdb_copy_function_finalize_info { + void *internal_ptr; +} * duckdb_copy_function_finalize_info; + +//! The bind function to use when binding a COPY ... TO function. +typedef void (*duckdb_copy_function_bind_t)(duckdb_copy_function_bind_info info); + +//! The initialization function to use when initializing a COPY ... TO function. +typedef void (*duckdb_copy_function_global_init_t)(duckdb_copy_function_global_init_info info); + +//! The function to sink an input chunk into during execution of a COPY ... TO function. +typedef void (*duckdb_copy_function_sink_t)(duckdb_copy_function_sink_info info, duckdb_data_chunk input); + +//! The function to finalize the COPY ... TO function execution. 
+typedef void (*duckdb_copy_function_finalize_t)(duckdb_copy_function_finalize_info info); + //===--------------------------------------------------------------------===// // Cast types //===--------------------------------------------------------------------===// @@ -806,9 +868,12 @@ struct duckdb_extension_access { // Functions //===--------------------------------------------------------------------===// -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Open Connect -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to operate on the instance cache, databases, connections, as well as some metadata functions. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates a new database instance cache. @@ -896,10 +961,10 @@ Interrupt running query DUCKDB_C_API void duckdb_interrupt(duckdb_connection connection); /*! -Get progress of the running query +Get the progress of the running query. -* @param connection The working connection -* @return -1 if no progress or a percentage of the progress +* @param connection The connection running the query. +* @return The query progress type containing progress information. */ DUCKDB_C_API duckdb_query_progress_type duckdb_query_progress(duckdb_connection connection); @@ -968,9 +1033,12 @@ with duckdb_destroy_value. */ DUCKDB_C_API duckdb_value duckdb_get_table_names(duckdb_connection connection, const char *query, bool qualified); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Configuration -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to interact with a `duckdb_config`, which is the configuration parameter for opening a database. +//---------------------------------------------------------------------------------------------------------------------- /*! Initializes an empty configuration object that can be used to provide start-up options for the DuckDB instance @@ -1031,12 +1099,13 @@ Destroys the specified configuration object and de-allocates all memory allocate */ DUCKDB_C_API void duckdb_destroy_config(duckdb_config *config); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Error Data -//===--------------------------------------------------------------------===// - -// Functions that can throw DuckDB errors must return duckdb_error_data. -// Please use this interface for all new functions, as it deprecates all previous error handling approaches. +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to operate on `duckdb_error_data`, which contains, for example, the error type and message. 
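The COPY function typedefs introduced above only fix the callback shapes; the registration entry points are not shown in this hunk. A hedged sketch of user callbacks matching those typedefs, with made-up names and placeholder bodies:

```cpp
#include "duckdb.h"

// Skeletons matching the duckdb_copy_function_*_t typedefs above; bodies are placeholders.
static void MyCopyBind(duckdb_copy_function_bind_info info) {
	// Inspect bound columns/options and allocate bind data.
}

static void MyCopyInitGlobal(duckdb_copy_function_global_init_info info) {
	// Open the output target and allocate global state.
}

static void MyCopySink(duckdb_copy_function_sink_info info, duckdb_data_chunk input) {
	// Serialize one chunk of input rows to the output.
}

static void MyCopyFinalize(duckdb_copy_function_finalize_info info) {
	// Flush and close the output.
}
```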
Please use this +// interface for all new C API functions, as it supersedes previous error handling approaches. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates duckdb_error_data. @@ -1079,9 +1148,12 @@ Returns whether the error data contains an error or not. */ DUCKDB_C_API bool duckdb_error_data_has_error(duckdb_error_data error_data); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Query Execution -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to obtain a `duckdb_result` and to retrieve metadata from it. +//---------------------------------------------------------------------------------------------------------------------- /*! Executes a SQL query within a connection and stores the full (materialized) result in the out_result pointer. @@ -1254,10 +1326,6 @@ Returns the result error type contained within the result. The error is only set */ DUCKDB_C_API duckdb_error_type duckdb_result_error_type(duckdb_result *result); -//===--------------------------------------------------------------------===// -// Result Functions -//===--------------------------------------------------------------------===// - #ifndef DUCKDB_API_NO_DEPRECATED /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. @@ -1310,14 +1378,20 @@ Returns the return_type of the given result, or DUCKDB_RETURN_TYPE_INVALID on er */ DUCKDB_C_API duckdb_result_type duckdb_result_return_type(duckdb_result result); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Safe Fetch Functions -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Deprecated functions to interact with a `duckdb_result`. +// +// DEPRECATION NOTICE: +// This function group is deprecated and scheduled for removal. +// +// USE INSTEAD: +// To access the values in a result, use `duckdb_fetch_chunk` repeatedly. For each chunk, use the `duckdb_data_chunk` +// interface to access any columns and their values. +//---------------------------------------------------------------------------------------------------------------------- -// These functions will perform conversions if necessary. -// On failure (e.g. if conversion cannot be performed or if the value is NULL) a default value is returned. -// Note that these functions are slow since they perform bounds checking and conversion -// For fast access of values prefer using `duckdb_result_get_chunk` #ifndef DUCKDB_API_NO_DEPRECATED /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. @@ -1446,8 +1520,7 @@ DUCKDB_C_API duckdb_timestamp duckdb_value_timestamp(duckdb_result *result, idx_ DUCKDB_C_API duckdb_interval duckdb_value_interval(duckdb_result *result, idx_t col, idx_t row); /*! -**DEPRECATED**: Use duckdb_value_string instead. This function does not work correctly if the string contains null -bytes. 
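As a companion to the deprecation note above that points to duckdb_fetch_chunk: a small sketch of the recommended chunk-based access pattern, using only existing C API calls; the query is arbitrary and NULL handling is omitted:

```cpp
#include "duckdb.h"
#include <stdio.h>

static void print_first_column(duckdb_connection con) {
	duckdb_result res;
	if (duckdb_query(con, "SELECT range AS i FROM range(5)", &res) == DuckDBError) {
		fprintf(stderr, "%s\n", duckdb_result_error(&res));
		duckdb_destroy_result(&res);
		return;
	}
	duckdb_data_chunk chunk;
	// duckdb_fetch_chunk returns NULL once the result is exhausted.
	while ((chunk = duckdb_fetch_chunk(res)) != NULL) {
		idx_t count = duckdb_data_chunk_get_size(chunk);
		int64_t *data = (int64_t *)duckdb_vector_get_data(duckdb_data_chunk_get_vector(chunk, 0));
		for (idx_t i = 0; i < count; i++) {
			printf("%lld\n", (long long)data[i]);
		}
		duckdb_destroy_data_chunk(&chunk);
	}
	duckdb_destroy_result(&res);
}
```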
+**DEPRECATION NOTICE**: This method is scheduled for removal in a future release. * @return The text value at the specified location as a null-terminated string, or nullptr if the value cannot be converted. The result must be freed with `duckdb_free`. @@ -1457,16 +1530,12 @@ DUCKDB_C_API char *duckdb_value_varchar(duckdb_result *result, idx_t col, idx_t /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. -No support for nested types, and for other complex types. -The resulting field "string.data" must be freed with `duckdb_free.` - * @return The string value at the specified location. Attempts to cast the result value to string. */ DUCKDB_C_API duckdb_string duckdb_value_string(duckdb_result *result, idx_t col, idx_t row); /*! -**DEPRECATED**: Use duckdb_value_string_internal instead. This function does not work correctly if the string contains -null bytes. +**DEPRECATION NOTICE**: This method is scheduled for removal in a future release. * @return The char* value at the specified location. ONLY works on VARCHAR columns and does not auto-cast. If the column is NOT a VARCHAR column this function will return NULL. @@ -1476,8 +1545,8 @@ The result must NOT be freed. DUCKDB_C_API char *duckdb_value_varchar_internal(duckdb_result *result, idx_t col, idx_t row); /*! -**DEPRECATED**: Use duckdb_value_string_internal instead. This function does not work correctly if the string contains -null bytes. +**DEPRECATION NOTICE**: This method is scheduled for removal in a future release. + * @return The char* value at the specified location. ONLY works on VARCHAR columns and does not auto-cast. If the column is NOT a VARCHAR column this function will return NULL. @@ -1502,9 +1571,12 @@ DUCKDB_C_API bool duckdb_value_is_null(duckdb_result *result, idx_t col, idx_t r #endif -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Helpers -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Generic and `duckdb_string_t` helper functions. +//---------------------------------------------------------------------------------------------------------------------- /*! Allocate `size` bytes of memory using the duckdb internal malloc function. Any memory allocated in this manner @@ -1554,9 +1626,13 @@ Get a pointer to the string data of a string_t */ DUCKDB_C_API const char *duckdb_string_t_data(duckdb_string_t *string); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Date Time Timestamp Helpers -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to convert from and to `duckdb_[date, time, time_tz, timestamp]`. +// `duckdb_is_finite_timestamp[_s, _ms, _ns]` helper functions. +//---------------------------------------------------------------------------------------------------------------------- /*! Decompose a `duckdb_date` object into year, month and date (stored as `duckdb_date_struct`). 
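A tiny sketch for the date helpers grouped above, round-tripping a duckdb_date through its struct form; the formatting is illustrative only:

```cpp
#include "duckdb.h"
#include <stdio.h>

static void print_date_parts(duckdb_date d) {
	duckdb_date_struct parts = duckdb_from_date(d);
	printf("%d-%02d-%02d\n", parts.year, parts.month, parts.day);
	// Recompose; for finite dates the round trip preserves the day count.
	duckdb_date roundtrip = duckdb_to_date(parts);
	(void)roundtrip;
}
```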
@@ -1664,9 +1740,12 @@ Test a `duckdb_timestamp_ns` to see if it is a finite value. */ DUCKDB_C_API bool duckdb_is_finite_timestamp_ns(duckdb_timestamp_ns ts); -//===--------------------------------------------------------------------===// -// Hugeint Helpers -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// Hugeint and Uhugeint Helpers +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to convert from and to `duckdb_[hugeint, uhugeint]`. +//---------------------------------------------------------------------------------------------------------------------- /*! Converts a duckdb_hugeint object (as obtained from a `DUCKDB_TYPE_HUGEINT` column) into a double. @@ -1686,10 +1765,6 @@ If the conversion fails because the double value is too big the result will be 0 */ DUCKDB_C_API duckdb_hugeint duckdb_double_to_hugeint(double val); -//===--------------------------------------------------------------------===// -// Unsigned Hugeint Helpers -//===--------------------------------------------------------------------===// - /*! Converts a duckdb_uhugeint object (as obtained from a `DUCKDB_TYPE_UHUGEINT` column) into a double. @@ -1708,9 +1783,12 @@ If the conversion fails because the double value is too big the result will be 0 */ DUCKDB_C_API duckdb_uhugeint duckdb_double_to_uhugeint(double val); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Decimal Helpers -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to convert from and to `duckdb_decimal`. +//---------------------------------------------------------------------------------------------------------------------- /*! Converts a double value to a duckdb_decimal object. @@ -1730,19 +1808,21 @@ Converts a duckdb_decimal object (as obtained from a `DUCKDB_TYPE_DECIMAL` colum */ DUCKDB_C_API double duckdb_decimal_to_double(duckdb_decimal val); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Prepared Statements -//===--------------------------------------------------------------------===// - -// A prepared statement is a parameterized query that allows you to bind parameters to it. -// * This is useful to easily supply parameters to functions and avoid SQL injection attacks. -// * This is useful to speed up queries that you will execute several times with different parameters. -// Because the query will only be parsed, bound, optimized and planned once during the prepare stage, -// rather than once per execution. +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// A prepared statement is a parameterized query, and you can bind parameters to it. Prepared statements are commonly +// used to easily supply parameters to functions and avoid SQL injection attacks. 
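For the merged hugeint/uhugeint and decimal helper groups above, a short sketch of the double conversions; the width/scale pair is an arbitrary illustration, not taken from this patch:

```cpp
#include "duckdb.h"

static void conversion_sketch(void) {
	// Hugeint <-> double round trip.
	duckdb_hugeint h = duckdb_double_to_hugeint(1e20);
	double h_back = duckdb_hugeint_to_double(h);

	// Double -> DECIMAL(18, 3) -> double; width/scale chosen for illustration.
	duckdb_decimal dec = duckdb_double_to_decimal(123.456, 18, 3);
	double d_back = duckdb_decimal_to_double(dec);

	(void)h_back;
	(void)d_back;
}
```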
They also speed up queries that are +// executed repeatedly with different parameters. That is because the query is only parsed, bound, optimized and planned +// once during the prepare stage, rather than once per execution, if it is possible to resolve all parameter types. +// // For example: -// SELECT * FROM tbl WHERE id=? +// SELECT * FROM tbl WHERE id = ? // Or a query with multiple parameters: -// SELECT * FROM tbl WHERE id=$1 OR name=$2 +// SELECT * FROM tbl WHERE id = $1 OR name = $2 +//---------------------------------------------------------------------------------------------------------------------- + /*! Create a prepared statement object from a query. @@ -1881,9 +1961,13 @@ Returns `DUCKDB_TYPE_INVALID` if the column is out of range. DUCKDB_C_API duckdb_type duckdb_prepared_statement_column_type(duckdb_prepared_statement prepared_statement, idx_t col_idx); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Bind Values to Prepared Statements -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to bind values to prepared statements. Try to use `duckdb_bind_value` and the `duckdb_create_...` interface +// for all types. +//---------------------------------------------------------------------------------------------------------------------- /*! Binds a value to the prepared statement at the specified index. @@ -2026,9 +2110,12 @@ Binds a NULL value to the prepared statement at the specified index. */ DUCKDB_C_API duckdb_state duckdb_bind_null(duckdb_prepared_statement prepared_statement, idx_t param_idx); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Execute Prepared Statements -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to execute a prepared statement. +//---------------------------------------------------------------------------------------------------------------------- /*! Executes the prepared statement with the given bound parameters, and returns a materialized query result. @@ -2066,11 +2153,14 @@ DUCKDB_C_API duckdb_state duckdb_execute_prepared_streaming(duckdb_prepared_stat #endif -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Extract Statements -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// A query string can be extracted into multiple SQL statements. Each statement should be prepared and executed +// separately. +//---------------------------------------------------------------------------------------------------------------------- -// A query string can be extracted into multiple SQL statements. Each statement can be prepared and executed separately. /*! 
Extract all statements from a query. Note that after calling `duckdb_extract_statements`, the extracted statements should always be destroyed using @@ -2119,9 +2209,12 @@ De-allocates all memory allocated for the extracted statements. */ DUCKDB_C_API void duckdb_destroy_extracted(duckdb_extracted_statements *extracted_statements); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Pending Result Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to interact with a pending result. First, prepare a pending result, and then execute it. +//---------------------------------------------------------------------------------------------------------------------- /*! Executes the prepared statement with the given bound parameters, and returns a pending result. @@ -2224,9 +2317,14 @@ DUCKDB_PENDING_RESULT_READY, this function will return true. */ DUCKDB_C_API bool duckdb_pending_execution_is_finished(duckdb_pending_state pending_state); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Value Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create a `duckdb_value` for each of DuckDB's supported data types, and to access the contents of a +// `duckdb_value`. The `duckdb_value` wrapper allows handling of primitive and arbitrarily (nested) types through the +// same interface. +//---------------------------------------------------------------------------------------------------------------------- /*! Destroys the value and de-allocates all memory allocated for that type. @@ -2869,9 +2967,12 @@ Returns the SQL string representation of the given value. */ DUCKDB_C_API char *duckdb_value_to_string(duckdb_value value); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Logical Type Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create and interact with `duckdb_logical_type`. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates a `duckdb_logical_type` from a primitive type. 
@@ -3159,9 +3260,17 @@ The type must have an alias DUCKDB_C_API duckdb_state duckdb_register_logical_type(duckdb_connection con, duckdb_logical_type type, duckdb_create_type_info info); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Data Chunk Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to interact with `duckdb_data_chunk`. Data chunks pass through the different operators of DuckDB's +// execution engine, when, e.g., executing a scalar function. Additionally, a query result is composed of a sequence of +// data chunks. +// +// A data chunk contains a number of vectors, which, in turn, contain data in a columnar format. For the query result, +// the vectors are the result columns, and they contain the query result for each row. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates an empty data chunk with the specified column types. @@ -3224,9 +3333,13 @@ Sets the current number of tuples in a data chunk. */ DUCKDB_C_API void duckdb_data_chunk_set_size(duckdb_data_chunk chunk, idx_t size); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Vector Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to interact with `duckdb_vector`. A vector typically (but not always) lives in a data chunk and contains a +// subset of the rows of a column. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates a flat vector. Must be destroyed with `duckdb_destroy_vector`. @@ -3336,23 +3449,26 @@ Returns the size of the child vector of the list. DUCKDB_C_API idx_t duckdb_list_vector_get_size(duckdb_vector vector); /*! -Sets the total size of the underlying child-vector of a list vector. +Sets the size of the underlying child-vector of a list vector. +Note that this does NOT reserve the memory in the child buffer, +and that it is possible to set a size exceeding the capacity. +To set the capacity, use `duckdb_list_vector_reserve`. * @param vector The list vector. * @param size The size of the child list. -* @return The duckdb state. Returns DuckDBError if the vector is nullptr. +* @return The duckdb state. Returns DuckDBError, if the vector is nullptr. */ DUCKDB_C_API duckdb_state duckdb_list_vector_set_size(duckdb_vector vector, idx_t size); /*! -Sets the total capacity of the underlying child-vector of a list. - -After calling this method, you must call `duckdb_vector_get_validity` and `duckdb_vector_get_data` to obtain current -data and validity pointers +Sets the capacity of the underlying child-vector of a list vector. +We increment to the next power of two, based on the required capacity. +Thus, the capacity might not match the size of the list (capacity >= size), +which is set via `duckdb_list_vector_set_size`. * @param vector The list vector. 
-* @param required_capacity the total capacity to reserve. -* @return The duckdb state. Returns DuckDBError if the vector is nullptr. +* @param required_capacity The child buffer capacity to reserve. +* @return The duckdb state. Returns DuckDBError, if the vector is nullptr. */ DUCKDB_C_API duckdb_state duckdb_list_vector_reserve(duckdb_vector vector, idx_t required_capacity); @@ -3419,9 +3535,13 @@ Changes `to_vector` to reference `from_vector. After, the vectors share ownershi */ DUCKDB_C_API void duckdb_vector_reference_vector(duckdb_vector to_vector, duckdb_vector from_vector); -//===--------------------------------------------------------------------===// // Validity Mask Functions -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to interact with the validity mask of a vector. The validity mask is a bitmask determining whether a row in +// a vector is `NULL` or not. +//---------------------------------------------------------------------------------------------------------------------- /*! Returns whether or not a row is valid (i.e. not NULL) in the given validity mask. @@ -3464,9 +3584,14 @@ Equivalent to `duckdb_validity_set_row_validity` with valid set to true. */ DUCKDB_C_API void duckdb_validity_set_row_valid(uint64_t *validity, idx_t row); -//===--------------------------------------------------------------------===// // Scalar Functions -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create, execute, and register custom scalar functions. Scalar functions take one or more input +// parameters, and return a single output parameter. Consider using a table function, if your scalar function does not +// take any input parameters. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates a new empty scalar function. @@ -3702,9 +3827,14 @@ Returns the input argument at index of the scalar function. */ DUCKDB_C_API duckdb_expression duckdb_scalar_function_bind_get_argument(duckdb_bind_info info, idx_t index); -//===--------------------------------------------------------------------===// // Selection Vector Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to interact with `duckdb_selection_vector`. Selection vectors define a selection on top of a vector. Let's +// say that a filter filters out all `VARCHAR`-rows containing `hello`. Then, instead of creating a full new copy of the +// remaining rows, it is possible to use a selection vector that selects only the rows satisfying the filter.
+//---------------------------------------------------------------------------------------------------------------------- /*! Creates a new selection vector of size `size`. @@ -3730,9 +3860,13 @@ Access the data pointer of a selection vector. */ DUCKDB_C_API sel_t *duckdb_selection_vector_get_data_ptr(duckdb_selection_vector sel); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Aggregate Functions -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create, execute, and register custom aggregate functions. Aggregate functions aggregate the values of a +// column into an output value. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates a new empty aggregate function. @@ -3887,9 +4021,13 @@ If the set is incomplete or a function with this name already exists DuckDBError DUCKDB_C_API duckdb_state duckdb_register_aggregate_function_set(duckdb_connection con, duckdb_aggregate_function_set set); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Table Functions -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create, execute, and register custom table functions. Table functions take one or more input parameters, +// and return one or more output parameters. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates a new empty table function. @@ -4005,9 +4143,13 @@ If the function is incomplete or a function with this name already exists DuckDB */ DUCKDB_C_API duckdb_state duckdb_register_table_function(duckdb_connection con, duckdb_table_function function); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Table Function Bind -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to implement the bind-phase of a table function. The bind-phase happens once before the execution of the +// table function. It is useful to, e.g., set up any read-only information for the different threads during execution. +//---------------------------------------------------------------------------------------------------------------------- /*! Retrieves the extra info of the function as set in `duckdb_table_function_set_extra_info`. @@ -4090,9 +4232,13 @@ Report that an error has occurred while calling bind on a table function. 
*/ DUCKDB_C_API void duckdb_bind_set_error(duckdb_bind_info info, const char *error); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Table Function Init -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to implement the init-phase of a table function. The init-phase happens once for each thread and +// initializes thread-local information prior to execution. +//---------------------------------------------------------------------------------------------------------------------- /*! Retrieves the extra info of the function as set in `duckdb_table_function_set_extra_info`. @@ -4159,9 +4305,13 @@ Report that an error has occurred while calling init. */ DUCKDB_C_API void duckdb_init_set_error(duckdb_init_info info, const char *error); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Table Function -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to implement the execution callback of a table function. The execution callback (i.e., the main function) +// produces a data chunk output based on a data chunk input, and has access to both the bind and init data. +//---------------------------------------------------------------------------------------------------------------------- /*! Retrieves the extra info of the function as set in `duckdb_table_function_set_extra_info`. @@ -4206,9 +4356,13 @@ Report that an error has occurred while executing the function. */ DUCKDB_C_API void duckdb_function_set_error(duckdb_function_info info, const char *error); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Replacement Scans -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create, execute, and register a custom replacement scan. A replacement scan is a callback replacing a +// scan of a table that does not exist in the catalog. +//---------------------------------------------------------------------------------------------------------------------- /*! Add a replacement scan definition to the specified database. @@ -4247,9 +4401,12 @@ Report that an error has occurred while executing the replacement scan. 
*/ DUCKDB_C_API void duckdb_replacement_scan_set_error(duckdb_replacement_scan_info info, const char *error); -//===--------------------------------------------------------------------===// // Profiling Info -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to access the post-execution profiling information of a query. Only available, if profiling is enabled. +//---------------------------------------------------------------------------------------------------------------------- /*! Returns the root node of the profiling information. Returns nullptr, if profiling is not enabled. @@ -4296,23 +4453,17 @@ Returns the child node at the specified index. */ DUCKDB_C_API duckdb_profiling_info duckdb_profiling_info_get_child(duckdb_profiling_info info, idx_t index); -//===--------------------------------------------------------------------===// // Appender -//===--------------------------------------------------------------------===// - -// Appenders are the most efficient way of loading data into DuckDB from within the C API. -// They are recommended for fast data loading as they perform better than prepared statements or individual `INSERT -// INTO` statements. - -// Appends are possible in row-wise format, and by appending entire data chunks. - -// Row-wise: for every column, a `duckdb_append_[type]` call should be made. After finishing all appends to a row, call -// `duckdb_appender_end_row`. - -// Chunk-wise: Consecutively call `duckdb_append_data_chunk` until all chunks have been appended. - -// After all data has been appended, call `duckdb_appender_close` to finalize the appender followed by -// `duckdb_appender_destroy` to clean up the memory. +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Appenders are the most efficient way of bulk-loading data into DuckDB. They are recommended for fast data loading as +// they perform better than prepared statements or individual `INSERT INTO` statements. Appends are possible in row-wise +// format, and by appending entire data chunks. Try to use chunk-wise appends via `duckdb_append_data_chunk` to ensure +// support for all of DuckDB's data types. Chunk-wise appends consecutively call `duckdb_append_data_chunk` until all +// chunks have been appended. Afterward, call `duckdb_appender_destroy` to flush any outstanding data and to destroy the +// appender instance. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates an appender object. @@ -4421,6 +4572,15 @@ duckdb_appender_destroy to destroy the invalidated appender. */ DUCKDB_C_API duckdb_state duckdb_appender_flush(duckdb_appender appender); +/*! +Clears all buffered data from the appender without flushing it to the table. This discards any data that has been +appended but not yet written. The appender can continue to be used after clearing. + +* @param appender The appender to clear. +* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
+*/ +DUCKDB_C_API duckdb_state duckdb_appender_clear(duckdb_appender appender); + /*! Closes the appender by flushing all intermediate states and closing it for further appends. If flushing the data triggers a constraint violation or any other error, then all data is invalidated, and this function returns DuckDBError. @@ -4616,9 +4776,12 @@ Appends a pre-filled data chunk to the specified appender. */ DUCKDB_C_API duckdb_state duckdb_append_data_chunk(duckdb_appender appender, duckdb_data_chunk chunk); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Table Description -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create and access a `duckdb_table_description` instance. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates a table description object. Note that `duckdb_table_description_destroy` should always be called on the @@ -4675,6 +4838,14 @@ Check if the column at 'index' index of the table has a DEFAULT expression. */ DUCKDB_C_API duckdb_state duckdb_column_has_default(duckdb_table_description table_description, idx_t index, bool *out); +/*! +Return the number of columns of the described table. + +* @param table_description The table_description to query. +* @return The column count. +*/ +DUCKDB_C_API idx_t duckdb_table_description_get_column_count(duckdb_table_description table_description); + /*! Obtain the column name at 'index'. The out result must be destroyed with `duckdb_free`. @@ -4685,9 +4856,23 @@ The out result must be destroyed with `duckdb_free`. */ DUCKDB_C_API char *duckdb_table_description_get_column_name(duckdb_table_description table_description, idx_t index); -//===--------------------------------------------------------------------===// +/*! +Obtain the column type at 'index'. +The return value must be destroyed with `duckdb_destroy_logical_type`. + +* @param table_description The table_description to query. +* @param index The index of the column to query. +* @return The column type. +*/ +DUCKDB_C_API duckdb_logical_type duckdb_table_description_get_column_type(duckdb_table_description table_description, + idx_t index); + +//---------------------------------------------------------------------------------------------------------------------- // Arrow Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to convert from and to Arrow. +//---------------------------------------------------------------------------------------------------------------------- /*! 
Transforms a DuckDB Schema into an Arrow Schema @@ -4927,9 +5112,12 @@ DUCKDB_C_API duckdb_state duckdb_arrow_array_scan(duckdb_connection connection, #endif -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Threading Information -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create and execute tasks. +//---------------------------------------------------------------------------------------------------------------------- /*! Execute DuckDB tasks on this thread. @@ -5008,9 +5196,12 @@ Returns true if the execution of the current query is finished. */ DUCKDB_C_API bool duckdb_execution_is_finished(duckdb_connection con); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Streaming Result Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to stream a `duckdb_result`. Call `duckdb_fetch_chunk` until the result is exhausted. +//---------------------------------------------------------------------------------------------------------------------- #ifndef DUCKDB_API_NO_DEPRECATED /*! @@ -5047,9 +5238,12 @@ It is not known beforehand how many chunks will be returned by this result. */ DUCKDB_C_API duckdb_data_chunk duckdb_fetch_chunk(duckdb_result result); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Cast Functions -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create, execute, and register custom cast functions. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates a new cast function object. @@ -5153,9 +5347,13 @@ Destroys the cast function object. */ DUCKDB_C_API void duckdb_destroy_cast_function(duckdb_cast_function *cast_function); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Expression Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create and access expressions. Expressions are widespread in DuckDB, especially during query planning. +// E.g., scalar function parameters are expressions, and can be inspected during the bind-phase. +//---------------------------------------------------------------------------------------------------------------------- /*! Destroys the expression and de-allocates its memory. 
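The streaming-result description above boils down to a simple drain loop. A minimal sketch, assuming a populated `duckdb_result result` (for example obtained from `duckdb_execute_prepared`):

while (true) {
    duckdb_data_chunk chunk = duckdb_fetch_chunk(result);
    if (!chunk) {
        break; // either the result is exhausted or an error occurred
    }
    idx_t count = duckdb_data_chunk_get_size(chunk);
    // ... read the chunk's vectors for `count` rows ...
    duckdb_destroy_data_chunk(&chunk);
}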
@@ -5191,9 +5389,13 @@ Folds an expression creating a folded value. DUCKDB_C_API duckdb_error_data duckdb_expression_fold(duckdb_client_context context, duckdb_expression expr, duckdb_value *out_value); -//===--------------------------------------------------------------------===// // File System Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to access the file system of a connection and to interact with file handles. File handle instances allow +// operations such as reading, writing, and seeking in a file. +//---------------------------------------------------------------------------------------------------------------------- /*! Get a file system instance associated with the given client context. @@ -5335,6 +5537,438 @@ Closes the given file handle. */ DUCKDB_C_API duckdb_state duckdb_file_handle_close(duckdb_file_handle file_handle); +//---------------------------------------------------------------------------------------------------------------------- +// Config Options Interface +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create, configure, and register custom configuration options. +//---------------------------------------------------------------------------------------------------------------------- + +/*! +Creates a configuration option instance. + +* @return The resulting configuration option instance. Must be destroyed with `duckdb_destroy_config_option`. +*/ +DUCKDB_C_API duckdb_config_option duckdb_create_config_option(); + +/*! +Destroys the given configuration option instance. +* @param option The configuration option instance to destroy. +*/ +DUCKDB_C_API void duckdb_destroy_config_option(duckdb_config_option *option); + +/*! +Sets the name of the configuration option. + +* @param option The configuration option instance. +* @param name The name to set. +*/ +DUCKDB_C_API void duckdb_config_option_set_name(duckdb_config_option option, const char *name); + +/*! +Sets the type of the configuration option. + +* @param option The configuration option instance. +* @param type The type to set. +*/ +DUCKDB_C_API void duckdb_config_option_set_type(duckdb_config_option option, duckdb_logical_type type); + +/*! +Sets the default value of the configuration option. +If the type of this option has already been set with `duckdb_config_option_set_type`, the value is cast to the type. +Otherwise, the type is inferred from the value. + +* @param option The configuration option instance. +* @param default_value The default value to set. +*/ +DUCKDB_C_API void duckdb_config_option_set_default_value(duckdb_config_option option, duckdb_value default_value); + +/*! +Sets the default scope of the configuration option. +If not set, this defaults to `DUCKDB_CONFIG_OPTION_SCOPE_SESSION`. + +* @param option The configuration option instance. +* @param default_scope The default scope to set. +*/ +DUCKDB_C_API void duckdb_config_option_set_default_scope(duckdb_config_option option, + duckdb_config_option_scope default_scope); + +/*! +Sets the description of the configuration option. + +* @param option The configuration option instance.
+* @param description The description to set. +*/ +DUCKDB_C_API void duckdb_config_option_set_description(duckdb_config_option option, const char *description); + +/*! +Registers the given configuration option on the specified connection. + +* @param connection The connection to register the option on. +* @param option The configuration option instance to register. +* @return A duckdb_state indicating success or failure. +*/ +DUCKDB_C_API duckdb_state duckdb_register_config_option(duckdb_connection connection, duckdb_config_option option); + +/*! +Retrieves the value of a configuration option by name from the given client context. + +* @param context The client context. +* @param name The name of the configuration option to retrieve. +* @param out_scope Output parameter to optionally store the scope that the configuration option was retrieved from. +If this is `nullptr`, the scope is not returned. +If the requested option does not exist the scope is set to `DUCKDB_CONFIG_OPTION_SCOPE_INVALID`. +* @return The value of the configuration option. Returns `nullptr` if the option does not exist. +*/ +DUCKDB_C_API duckdb_value duckdb_client_context_get_config_option(duckdb_client_context context, const char *name, + duckdb_config_option_scope *out_scope); + +//---------------------------------------------------------------------------------------------------------------------- +// Copy Functions +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to copy data from and to external file formats. +//---------------------------------------------------------------------------------------------------------------------- + +/*! +Creates a new empty copy function. + +The return value must be destroyed with `duckdb_destroy_copy_function`. + +* @return The copy function object. +*/ +DUCKDB_C_API duckdb_copy_function duckdb_create_copy_function(); + +/*! +Sets the name of the copy function. + +* @param copy_function The copy function +* @param name The name to set +*/ +DUCKDB_C_API void duckdb_copy_function_set_name(duckdb_copy_function copy_function, const char *name); + +/*! +Sets the extra info pointer of the copy function, which can be used to store arbitrary data. + +* @param copy_function The copy function +* @param extra_info The extra info pointer +* @param destructor A destructor function to call to destroy the extra info +*/ +DUCKDB_C_API void duckdb_copy_function_set_extra_info(duckdb_copy_function copy_function, void *extra_info, + duckdb_delete_callback_t destructor); + +/*! +Registers the given copy function on the database connection under the specified name. + +* @param connection The database connection +* @param copy_function The copy function to register +*/ +DUCKDB_C_API duckdb_state duckdb_register_copy_function(duckdb_connection connection, + duckdb_copy_function copy_function); + +/*! +Destroys the given copy function object. +* @param copy_function The copy function to destroy. +*/ +DUCKDB_C_API void duckdb_destroy_copy_function(duckdb_copy_function *copy_function); + +/*! +Sets the bind function of the copy function, to use when binding `COPY ... TO`. + +* @param bind The bind function +*/ +DUCKDB_C_API void duckdb_copy_function_set_bind(duckdb_copy_function copy_function, duckdb_copy_function_bind_t bind); + +/*! +Report that an error occurred during the binding-phase of a `COPY ... TO` function. 
+ +* @param info The bind info provided to the bind function +* @param error The error message +*/ +DUCKDB_C_API void duckdb_copy_function_bind_set_error(duckdb_copy_function_bind_info info, const char *error); + +/*! +Retrieves the extra info pointer of the copy function. + +* @param info The bind info provided to the bind function +* @return The extra info pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_bind_get_extra_info(duckdb_copy_function_bind_info info); + +/*! +Retrieves the client context of the current connection binding the `COPY ... TO` function. + +Must be destroyed with `duckdb_destroy_client_context` + +* @param info The bind info provided to the bind function +* @return The client context. +*/ +DUCKDB_C_API duckdb_client_context duckdb_copy_function_bind_get_client_context(duckdb_copy_function_bind_info info); + +/*! +Retrieves the number of columns that will be provided to the `COPY ... TO` function. + +* @param info The bind info provided to the bind function +* @return The number of columns. +*/ +DUCKDB_C_API idx_t duckdb_copy_function_bind_get_column_count(duckdb_copy_function_bind_info info); + +/*! +Retrieves the type of a column that will be provided to the `COPY ... TO` function. + +* @param info The bind info provided to the bind function +* @param col_idx The index of the column to retrieve the type for +* @return The type of the column. Must be destroyed with `duckdb_destroy_logical_type`. +*/ +DUCKDB_C_API duckdb_logical_type duckdb_copy_function_bind_get_column_type(duckdb_copy_function_bind_info info, + idx_t col_idx); + +/*! +Retrieves all values for the given options provided to the `COPY ... TO` function. + +* @param info The bind info provided to the bind function +* @return A STRUCT value containing all options as fields. Must be destroyed with `duckdb_destroy_value`. +*/ +DUCKDB_C_API duckdb_value duckdb_copy_function_bind_get_options(duckdb_copy_function_bind_info info); + +/*! +Sets the bind data of the copy function, to be provided to the init, sink and finalize functions. + +* @param info The bind info provided to the bind function +* @param bind_data The bind data pointer +* @param destructor A destructor function to call to destroy the bind data +*/ +DUCKDB_C_API void duckdb_copy_function_bind_set_bind_data(duckdb_copy_function_bind_info info, void *bind_data, + duckdb_delete_callback_t destructor); + +/*! +Sets the initialization function of the copy function, called right before executing `COPY ... TO`. + +* @param init The init function +*/ +DUCKDB_C_API void duckdb_copy_function_set_global_init(duckdb_copy_function copy_function, + duckdb_copy_function_global_init_t init); + +/*! +Report that an error occurred during the initialization-phase of a `COPY ... TO` function. + +* @param info The init info provided to the init function +* @param error The error message +*/ +DUCKDB_C_API void duckdb_copy_function_global_init_set_error(duckdb_copy_function_global_init_info info, + const char *error); + +/*! +Retrieves the extra info pointer of the copy function. + +* @param info The init info provided to the init function +* @return The extra info pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_global_init_get_extra_info(duckdb_copy_function_global_init_info info); + +/*! +Retrieves the client context of the current connection initializing the `COPY ... TO` function. + +Must be destroyed with `duckdb_destroy_client_context` + +* @param info The init info provided to the init function +* @return The client context. 
+*/ +DUCKDB_C_API duckdb_client_context +duckdb_copy_function_global_init_get_client_context(duckdb_copy_function_global_init_info info); + +/*! +Retrieves the bind data provided during the binding-phase of a `COPY ... TO` function. + +* @param info The init info provided to the init function +* @return The bind data pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_global_init_get_bind_data(duckdb_copy_function_global_init_info info); + +/*! +Retrieves the file path provided to the `COPY ... TO` function. + +Lives for the duration of the initialization callback, must not be destroyed. + +* @param info The init info provided to the init function +* @return The file path. +*/ +DUCKDB_C_API const char *duckdb_copy_function_global_init_get_file_path(duckdb_copy_function_global_init_info info); + +/*! +Sets the global state of the copy function, to be provided to all subsequent local init, sink and finalize functions. + +* @param info The init info provided to the init function +* @param global_state The global state pointer +* @param destructor A destructor function to call to destroy the global state +*/ +DUCKDB_C_API void duckdb_copy_function_global_init_set_global_state(duckdb_copy_function_global_init_info info, + void *global_state, + duckdb_delete_callback_t destructor); + +/*! +Sets the sink function of the copy function, called during `COPY ... TO`. + +* @param function The sink function +*/ +DUCKDB_C_API void duckdb_copy_function_set_sink(duckdb_copy_function copy_function, + duckdb_copy_function_sink_t function); + +/*! +Report that an error occurred during the sink-phase of a `COPY ... TO` function. + +* @param info The sink info provided to the sink function +* @param error The error message +*/ +DUCKDB_C_API void duckdb_copy_function_sink_set_error(duckdb_copy_function_sink_info info, const char *error); + +/*! +Retrieves the extra info pointer of the copy function. + +* @param info The sink info provided to the sink function +* @return The extra info pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_sink_get_extra_info(duckdb_copy_function_sink_info info); + +/*! +Retrieves the client context of the current connection during the sink-phase of the `COPY ... TO` function. + +Must be destroyed with `duckdb_destroy_client_context` + +* @param info The sink info provided to the sink function +* @return The client context. +*/ +DUCKDB_C_API duckdb_client_context duckdb_copy_function_sink_get_client_context(duckdb_copy_function_sink_info info); + +/*! +Retrieves the bind data provided during the binding-phase of a `COPY ... TO` function. + +* @param info The sink info provided to the sink function +* @return The bind data pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_sink_get_bind_data(duckdb_copy_function_sink_info info); + +/*! +Retrieves the global state provided during the init-phase of a `COPY ... TO` function. + +* @param info The sink info provided to the sink function +* @return The global state pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_sink_get_global_state(duckdb_copy_function_sink_info info); + +/*! +Sets the finalize function of the copy function, called at the end of `COPY ... TO`. + +* @param finalize The finalize function +*/ +DUCKDB_C_API void duckdb_copy_function_set_finalize(duckdb_copy_function copy_function, + duckdb_copy_function_finalize_t finalize); + +/*! +Report that an error occurred during the finalize-phase of a `COPY ... 
TO` function. + +* @param info The finalize info provided to the finalize function +* @param error The error message +*/ +DUCKDB_C_API void duckdb_copy_function_finalize_set_error(duckdb_copy_function_finalize_info info, const char *error); + +/*! +Retrieves the extra info pointer of the copy function. + +* @param info The finalize info provided to the finalize function +* @return The extra info pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_finalize_get_extra_info(duckdb_copy_function_finalize_info info); + +/*! +Retrieves the client context of the current connection during the finalize-phase of the `COPY ... TO` function. + +Must be destroyed with `duckdb_destroy_client_context` + +* @param info The finalize info provided to the finalize function +* @return The client context. +*/ +DUCKDB_C_API duckdb_client_context +duckdb_copy_function_finalize_get_client_context(duckdb_copy_function_finalize_info info); + +/*! +Retrieves the bind data provided during the binding-phase of a `COPY ... TO` function. + +* @param info The finalize info provided to the finalize function +* @return The bind data pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_finalize_get_bind_data(duckdb_copy_function_finalize_info info); + +/*! +Retrieves the global state provided during the init-phase of a `COPY ... TO` function. + +* @param info The finalize info provided to the finalize function +* @return The global state pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_finalize_get_global_state(duckdb_copy_function_finalize_info info); + +/*! +Sets the table function to use when executing a `COPY ... FROM (...)` statement with this copy function. + +The table function must have a `duckdb_table_function_bind_t`, `duckdb_table_function_init_t` and +`duckdb_table_function_t` set. + +The table function must take a single VARCHAR parameter (the file path). + +Options passed to the `COPY ... FROM (...)` statement are forwarded as named parameters to the table function. + +Since `COPY ... FROM` copies into an already existing table, the table function should not define its own result columns +using `duckdb_bind_add_result_column` when binding. Instead, use `duckdb_table_function_bind_get_result_column_count` +and related functions in the bind callback of the table function to retrieve the schema of the target table of the `COPY +... FROM` statement. + +* @param copy_function The copy function +* @param table_function The table function to use for `COPY ... FROM` +*/ +DUCKDB_C_API void duckdb_copy_function_set_copy_from_function(duckdb_copy_function copy_function, + duckdb_table_function table_function); + +/*! +Retrieves the number of result columns of a table function. + +If the table function is used in a `COPY ... FROM` statement, this can be used to retrieve the number of columns in the +target table at the start of the bind callback. + +* @param info The bind info provided to the bind function +* @return The number of result columns. +*/ +DUCKDB_C_API idx_t duckdb_table_function_bind_get_result_column_count(duckdb_bind_info info); + +/*! +Retrieves the name of a result column of a table function. + +If the table function is used in a `COPY ... FROM` statement, this can be used to retrieve the names of the columns in +the target table at the start of the bind callback. + +The result is valid for the duration of the bind callback or until the next call to `duckdb_bind_add_result_column`, so +it must not be destroyed.
+ +* @param info The bind info provided to the bind function +* @param col_idx The index of the result column to retrieve the name for +* @return The name of the result column. +*/ +DUCKDB_C_API const char *duckdb_table_function_bind_get_result_column_name(duckdb_bind_info info, idx_t col_idx); + +/*! +Retrieves the type of a result column of a table function. + +If the table function is used in a `COPY ... FROM` statement, this can be used to retrieve the types of the columns in +the target table at the start of the bind callback. + +The result must be destroyed with `duckdb_destroy_logical_type`. + +* @param info The bind info provided to the bind function +* @param col_idx The index of the result column to retrieve the type for +* @return The type of the result column. +*/ +DUCKDB_C_API duckdb_logical_type duckdb_table_function_bind_get_result_column_type(duckdb_bind_info info, + idx_t col_idx); + #endif #ifdef __cplusplus diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp index 0cf71fa73..7f206a43d 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp @@ -47,7 +47,7 @@ class DuckTableEntry : public TableCatalogEntry { TableFunction GetScanFunction(ClientContext &context, unique_ptr &bind_data) override; - vector GetColumnSegmentInfo() override; + vector GetColumnSegmentInfo(const QueryContext &context) override; TableStorageInfo GetStorageInfo(ClientContext &context) override; diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp index 5cab72c59..1e40319cf 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp @@ -111,7 +111,7 @@ class TableCatalogEntry : public StandardEntry { static string ColumnNamesToSQL(const ColumnList &columns); //! Returns a list of segment information for this table, if exists - virtual vector GetColumnSegmentInfo(); + virtual vector GetColumnSegmentInfo(const QueryContext &context); //! 
Returns the storage info of this table virtual TableStorageInfo GetStorageInfo(ClientContext &context) = 0; diff --git a/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp b/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp index e2265c8c7..f3a71b594 100644 --- a/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp +++ b/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp @@ -19,7 +19,7 @@ struct DefaultType { LogicalTypeId type; }; -using builtin_type_array = std::array; +using builtin_type_array = std::array; static constexpr const builtin_type_array BUILTIN_TYPES{{ {"decimal", LogicalTypeId::DECIMAL}, @@ -97,7 +97,8 @@ static constexpr const builtin_type_array BUILTIN_TYPES{{ {"real", LogicalTypeId::FLOAT}, {"float4", LogicalTypeId::FLOAT}, {"double", LogicalTypeId::DOUBLE}, - {"float8", LogicalTypeId::DOUBLE} + {"float8", LogicalTypeId::DOUBLE}, + {"geometry", LogicalTypeId::GEOMETRY} }}; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arena_containers/arena_ptr.hpp b/src/duckdb/src/include/duckdb/common/arena_containers/arena_ptr.hpp new file mode 100644 index 000000000..501d0fe37 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arena_containers/arena_ptr.hpp @@ -0,0 +1,29 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/arena_containers/arena_ptr.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unique_ptr.hpp" + +namespace duckdb { + +//! Call destructor without attempting to free the memory +template +struct arena_deleter { // NOLINT: match stl case + void operator()(T *pointer) { + pointer->~T(); + } +}; + +template +using arena_ptr = unique_ptr>; + +template +using unsafe_arena_ptr = unique_ptr, false>; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arena_containers/arena_vector.hpp b/src/duckdb/src/include/duckdb/common/arena_containers/arena_vector.hpp new file mode 100644 index 000000000..4c05cc37b --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arena_containers/arena_vector.hpp @@ -0,0 +1,21 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/arena_containers/arena_vector.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/arena_stl_allocator.hpp" + +namespace duckdb { + +template +using arena_vector = vector>; + +template +using unsafe_arena_vector = vector>; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arena_stl_allocator.hpp b/src/duckdb/src/include/duckdb/common/arena_stl_allocator.hpp new file mode 100644 index 000000000..5f7582df6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arena_stl_allocator.hpp @@ -0,0 +1,92 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/arena_stl_allocator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/arena_allocator.hpp" + +namespace duckdb { + +template +class arena_stl_allocator { // NOLINT: match stl case +public: + //! 
Typedefs + typedef T value_type; + typedef std::size_t size_type; + typedef std::ptrdiff_t difference_type; + typedef value_type &reference; + typedef value_type const &const_reference; + typedef value_type *pointer; + typedef value_type const *const_pointer; + + //! Propagation traits + using propagate_on_container_copy_assignment = std::true_type; + using propagate_on_container_move_assignment = std::true_type; + using propagate_on_container_swap = std::true_type; + using is_always_equal = std::false_type; + + //! Rebind + template + struct rebind { + using other = arena_stl_allocator; + }; + +public: + arena_stl_allocator(ArenaAllocator &arena_allocator_p) noexcept // NOLINT: allow implicit conversion + : arena_allocator(arena_allocator_p) { + } + template + arena_stl_allocator(const arena_stl_allocator &other) noexcept // NOLINT: allow implicit conversion + : arena_allocator(other.GetAllocator()) { + } + +public: + pointer allocate(size_type n) { // NOLINT: match stl case + arena_allocator.get().AlignNext(); + return reinterpret_cast(arena_allocator.get().Allocate(n * sizeof(T))); + } + + void deallocate(pointer p, size_type n) noexcept { // NOLINT: match stl case + } + + template + void construct(U *p, Args &&...args) { // NOLINT: match stl case + ::new (p) U(std::forward(args)...); + } + + template + void destroy(U *p) noexcept { // NOLINT: match stl case + p->~U(); + } + + pointer address(reference x) const { // NOLINT: match stl case + return &x; + } + + const_pointer address(const_reference x) const { // NOLINT: match stl case + return &x; + } + + ArenaAllocator &GetAllocator() const { + return arena_allocator.get(); + } + +public: + bool operator==(const arena_stl_allocator &other) const noexcept { + return RefersToSameObject(arena_allocator, other.arena_allocator); + } + bool operator!=(const arena_stl_allocator &other) const noexcept { + return !(*this == other); + } + +private: + //! Need to use std::reference_wrapper because "reference" is already a typedef + std::reference_wrapper arena_allocator; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/arrow_query_result.hpp b/src/duckdb/src/include/duckdb/common/arrow/arrow_query_result.hpp index 811a410a5..3f736731b 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/arrow_query_result.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/arrow_query_result.hpp @@ -31,10 +31,6 @@ class ArrowQueryResult : public QueryResult { DUCKDB_API explicit ArrowQueryResult(ErrorData error); public: - //! Fetches a DataChunk from the query result. - //! This will consume the result (i.e. the result can only be scanned once with this function) - DUCKDB_API unique_ptr Fetch() override; - DUCKDB_API unique_ptr FetchRaw() override; //! 
Converts the QueryResult to a string DUCKDB_API string ToString() override; @@ -44,6 +40,9 @@ class ArrowQueryResult : public QueryResult { void SetArrowData(vector> arrays); idx_t BatchSize() const; +protected: + DUCKDB_API unique_ptr FetchInternal() override; + private: vector> arrays; idx_t batch_size; diff --git a/src/duckdb/src/include/duckdb/common/arrow/physical_arrow_collector.hpp b/src/duckdb/src/include/duckdb/common/arrow/physical_arrow_collector.hpp index 3bd89e67f..d659235d2 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/physical_arrow_collector.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/physical_arrow_collector.hpp @@ -47,7 +47,7 @@ class PhysicalArrowCollector : public PhysicalResultCollector { static PhysicalOperator &Create(ClientContext &context, PreparedStatementData &data, idx_t batch_size); SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; - unique_ptr GetResult(GlobalSinkState &state) override; + unique_ptr GetResult(GlobalSinkState &state) const override; unique_ptr GetGlobalSinkState(ClientContext &context) const override; unique_ptr GetLocalSinkState(ExecutionContext &context) const override; SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, diff --git a/src/duckdb/src/include/duckdb/common/assert.hpp b/src/duckdb/src/include/duckdb/common/assert.hpp index dbf0744e7..4bf4b90e5 100644 --- a/src/duckdb/src/include/duckdb/common/assert.hpp +++ b/src/duckdb/src/include/duckdb/common/assert.hpp @@ -38,3 +38,6 @@ DUCKDB_API void DuckDBAssertInternal(bool condition, const char *condition_name, #define D_ASSERT_IS_ENABLED #endif + +//! Force assertion implementation, which always asserts whatever build type is used. +#define ALWAYS_ASSERT(condition) duckdb::DuckDBAssertInternal(bool(condition), #condition, __FILE__, __LINE__) diff --git a/src/duckdb/src/include/duckdb/common/bitpacking.hpp b/src/duckdb/src/include/duckdb/common/bitpacking.hpp index 06d12882a..618caf0e8 100644 --- a/src/duckdb/src/include/duckdb/common/bitpacking.hpp +++ b/src/duckdb/src/include/duckdb/common/bitpacking.hpp @@ -25,7 +25,6 @@ struct HugeIntPacker { }; class BitpackingPrimitives { - public: static constexpr const idx_t BITPACKING_ALGORITHM_GROUP_SIZE = 32; static constexpr const idx_t BITPACKING_HEADER_SIZE = sizeof(uint64_t); @@ -61,7 +60,6 @@ class BitpackingPrimitives { template inline static void UnPackBuffer(data_ptr_t dst, data_ptr_t src, idx_t count, bitpacking_width_t width, bool skip_sign_extension = false) { - for (idx_t i = 0; i < count; i += BITPACKING_ALGORITHM_GROUP_SIZE) { UnPackGroup(dst + i * sizeof(T), src + (i * width) / 8, width, skip_sign_extension); } diff --git a/src/duckdb/src/include/duckdb/common/csv_writer.hpp b/src/duckdb/src/include/duckdb/common/csv_writer.hpp index b2d0e066e..188d1de7f 100644 --- a/src/duckdb/src/include/duckdb/common/csv_writer.hpp +++ b/src/duckdb/src/include/duckdb/common/csv_writer.hpp @@ -90,9 +90,6 @@ class CSVWriter { //! 
Closes the writer, optionally writes a postfix void Close(); - unique_ptr InitializeLocalWriteState(ClientContext &context, idx_t flush_size); - unique_ptr InitializeLocalWriteState(DatabaseInstance &db, idx_t flush_size); - vector> string_casts; idx_t BytesWritten(); diff --git a/src/duckdb/src/include/duckdb/common/deque.hpp b/src/duckdb/src/include/duckdb/common/deque.hpp index f5c8ba990..6b5d38826 100644 --- a/src/duckdb/src/include/duckdb/common/deque.hpp +++ b/src/duckdb/src/include/duckdb/common/deque.hpp @@ -8,8 +8,115 @@ #pragma once +#include "duckdb/common/assert.hpp" +#include "duckdb/common/typedefs.hpp" +#include "duckdb/common/likely.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/memory_safety.hpp" #include namespace duckdb { -using std::deque; -} + +template +class deque : public std::deque> { // NOLINT: matching name of std +public: + using original = std::deque>; + using original::original; + using value_type = typename original::value_type; + using allocator_type = typename original::allocator_type; + using size_type = typename original::size_type; + using difference_type = typename original::difference_type; + using reference = typename original::reference; + using const_reference = typename original::const_reference; + using pointer = typename original::pointer; + using const_pointer = typename original::const_pointer; + using iterator = typename original::iterator; + using const_iterator = typename original::const_iterator; + using reverse_iterator = typename original::reverse_iterator; + using const_reverse_iterator = typename original::const_reverse_iterator; + +private: + static inline void AssertIndexInBounds(idx_t index, idx_t size) { +#if defined(DUCKDB_DEBUG_NO_SAFETY) || defined(DUCKDB_CLANG_TIDY) + return; +#else + if (DUCKDB_UNLIKELY(index >= size)) { + throw InternalException("Attempted to access index %ld within deque of size %ld", index, size); + } +#endif + } + +public: +#ifdef DUCKDB_CLANG_TIDY + [[clang::reinitializes]] +#endif + inline void + clear() noexcept { // NOLINT: hiding on purpose + original::clear(); + } + + // Because we create the other constructor, the implicitly created constructor + // gets deleted, so we have to be explicit + deque() = default; + deque(original &&other) : original(std::move(other)) { // NOLINT: allow implicit conversion + } + template + deque(deque &&other) : original(std::move(other)) { // NOLINT: allow implicit conversion + } + + template + inline typename original::reference get(typename original::size_type __n) { // NOLINT: hiding on purpose + if (MemorySafety::ENABLED) { + AssertIndexInBounds(__n, original::size()); + } + return original::operator[](__n); + } + + template + inline typename original::const_reference get(typename original::size_type __n) const { // NOLINT: hiding on purpose + if (MemorySafety::ENABLED) { + AssertIndexInBounds(__n, original::size()); + } + return original::operator[](__n); + } + + typename original::reference operator[](typename original::size_type __n) { // NOLINT: hiding on purpose + return get(__n); + } + typename original::const_reference operator[](typename original::size_type __n) const { // NOLINT: hiding on purpose + return get(__n); + } + + typename original::reference front() { // NOLINT: hiding on purpose + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'front' called on an empty deque!"); + } + return get(0); + } + + typename original::const_reference front() const { // NOLINT: hiding on purpose + if 
(MemorySafety::ENABLED && original::empty()) { + throw InternalException("'front' called on an empty deque!"); + } + return get(0); + } + + typename original::reference back() { // NOLINT: hiding on purpose + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'back' called on an empty deque!"); + } + return get(original::size() - 1); + } + + typename original::const_reference back() const { // NOLINT: hiding on purpose + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'back' called on an empty deque!"); + } + return get(original::size() - 1); + } +}; + +template +using unsafe_deque = deque; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/encryption_functions.hpp b/src/duckdb/src/include/duckdb/common/encryption_functions.hpp index 07f50b98c..b43544b92 100644 --- a/src/duckdb/src/include/duckdb/common/encryption_functions.hpp +++ b/src/duckdb/src/include/duckdb/common/encryption_functions.hpp @@ -27,7 +27,6 @@ struct EncryptionNonce { }; class EncryptionEngine { - public: EncryptionEngine(); ~EncryptionEngine(); diff --git a/src/duckdb/src/include/duckdb/common/encryption_key_manager.hpp b/src/duckdb/src/include/duckdb/common/encryption_key_manager.hpp index 55c3aed75..9e41ee8c1 100644 --- a/src/duckdb/src/include/duckdb/common/encryption_key_manager.hpp +++ b/src/duckdb/src/include/duckdb/common/encryption_key_manager.hpp @@ -17,7 +17,6 @@ namespace duckdb { class EncryptionKey { - public: explicit EncryptionKey(data_ptr_t encryption_key); ~EncryptionKey(); @@ -42,7 +41,6 @@ class EncryptionKey { }; class EncryptionKeyManager : public ObjectCacheEntry { - public: static EncryptionKeyManager &GetInternal(ObjectCache &cache); static EncryptionKeyManager &Get(ClientContext &context); diff --git a/src/duckdb/src/include/duckdb/common/encryption_state.hpp b/src/duckdb/src/include/duckdb/common/encryption_state.hpp index 32c0597a9..2563c0bde 100644 --- a/src/duckdb/src/include/duckdb/common/encryption_state.hpp +++ b/src/duckdb/src/include/duckdb/common/encryption_state.hpp @@ -14,7 +14,6 @@ namespace duckdb { class EncryptionTypes { - public: enum CipherType : uint8_t { INVALID = 0, GCM = 1, CTR = 2, CBC = 3 }; enum KeyDerivationFunction : uint8_t { DEFAULT = 0, SHA256 = 1, PBKDF2 = 2 }; @@ -27,7 +26,6 @@ class EncryptionTypes { }; class EncryptionState { - public: DUCKDB_API explicit EncryptionState(EncryptionTypes::CipherType cipher_p, idx_t key_len); DUCKDB_API virtual ~EncryptionState(); @@ -47,7 +45,6 @@ class EncryptionState { }; class EncryptionUtil { - public: DUCKDB_API explicit EncryptionUtil() {}; diff --git a/src/duckdb/src/include/duckdb/common/enum_util.hpp b/src/duckdb/src/include/duckdb/common/enum_util.hpp index d07e93d02..de103462f 100644 --- a/src/duckdb/src/include/duckdb/common/enum_util.hpp +++ b/src/duckdb/src/include/duckdb/common/enum_util.hpp @@ -78,6 +78,10 @@ enum class ArrowTypeInfoType : uint8_t; enum class ArrowVariableSizeType : uint8_t; +enum class AsyncResultType : uint8_t; + +enum class AsyncResultsExecutionMode : uint8_t; + enum class BinderType : uint8_t; enum class BindingMode : uint8_t; @@ -88,6 +92,8 @@ enum class BlockIteratorStateType : int8_t; enum class BlockState : uint8_t; +enum class BufferedIndexReplay : uint8_t; + enum class CAPIResultSetType : uint8_t; enum class CSVState : uint8_t; @@ -202,6 +208,8 @@ enum class FunctionStability : uint8_t; enum class GateStatus : uint8_t; +enum class GeometryType : uint8_t; + enum class HLLStorageType : uint8_t; enum class 
HTTPStatusCode : uint16_t; @@ -294,8 +302,6 @@ enum class ParseInfoType : uint8_t; enum class ParserExtensionResultType : uint8_t; -enum class PartitionSortStage : uint8_t; - enum class PartitionedColumnDataType : uint8_t; enum class PartitionedTupleDataType : uint8_t; @@ -304,6 +310,8 @@ enum class PendingExecutionResult : uint8_t; enum class PhysicalOperatorType : uint8_t; +enum class PhysicalTableScanExecutionStrategy : uint8_t; + enum class PhysicalType : uint8_t; enum class PragmaType : uint8_t; @@ -322,8 +330,14 @@ enum class QuantileSerializationType : uint8_t; enum class QueryNodeType : uint8_t; +enum class QueryResultMemoryType : uint8_t; + +enum class QueryResultOutputType : uint8_t; + enum class QueryResultType : uint8_t; +enum class RecoveryMode : uint8_t; + enum class RelationType : uint8_t; enum class RenderMode : uint8_t; @@ -440,6 +454,8 @@ enum class VerificationType : uint8_t; enum class VerifyExistenceType : uint8_t; +enum class VertexType : uint8_t; + enum class WALType : uint8_t; enum class WindowAggregationMode : uint32_t; @@ -520,6 +536,12 @@ const char* EnumUtil::ToChars(ArrowTypeInfoType value); template<> const char* EnumUtil::ToChars(ArrowVariableSizeType value); +template<> +const char* EnumUtil::ToChars(AsyncResultType value); + +template<> +const char* EnumUtil::ToChars(AsyncResultsExecutionMode value); + template<> const char* EnumUtil::ToChars(BinderType value); @@ -535,6 +557,9 @@ const char* EnumUtil::ToChars(BlockIteratorStateType val template<> const char* EnumUtil::ToChars(BlockState value); +template<> +const char* EnumUtil::ToChars(BufferedIndexReplay value); + template<> const char* EnumUtil::ToChars(CAPIResultSetType value); @@ -706,6 +731,9 @@ const char* EnumUtil::ToChars(FunctionStability value); template<> const char* EnumUtil::ToChars(GateStatus value); +template<> +const char* EnumUtil::ToChars(GeometryType value); + template<> const char* EnumUtil::ToChars(HLLStorageType value); @@ -844,9 +872,6 @@ const char* EnumUtil::ToChars(ParseInfoType value); template<> const char* EnumUtil::ToChars(ParserExtensionResultType value); -template<> -const char* EnumUtil::ToChars(PartitionSortStage value); - template<> const char* EnumUtil::ToChars(PartitionedColumnDataType value); @@ -859,6 +884,9 @@ const char* EnumUtil::ToChars(PendingExecutionResult val template<> const char* EnumUtil::ToChars(PhysicalOperatorType value); +template<> +const char* EnumUtil::ToChars(PhysicalTableScanExecutionStrategy value); + template<> const char* EnumUtil::ToChars(PhysicalType value); @@ -886,9 +914,18 @@ const char* EnumUtil::ToChars(QuantileSerializationTy template<> const char* EnumUtil::ToChars(QueryNodeType value); +template<> +const char* EnumUtil::ToChars(QueryResultMemoryType value); + +template<> +const char* EnumUtil::ToChars(QueryResultOutputType value); + template<> const char* EnumUtil::ToChars(QueryResultType value); +template<> +const char* EnumUtil::ToChars(RecoveryMode value); + template<> const char* EnumUtil::ToChars(RelationType value); @@ -1063,6 +1100,9 @@ const char* EnumUtil::ToChars(VerificationType value); template<> const char* EnumUtil::ToChars(VerifyExistenceType value); +template<> +const char* EnumUtil::ToChars(VertexType value); + template<> const char* EnumUtil::ToChars(WALType value); @@ -1148,6 +1188,12 @@ ArrowTypeInfoType EnumUtil::FromString(const char *value); template<> ArrowVariableSizeType EnumUtil::FromString(const char *value); +template<> +AsyncResultType EnumUtil::FromString(const char *value); + +template<> 
+AsyncResultsExecutionMode EnumUtil::FromString(const char *value); + template<> BinderType EnumUtil::FromString(const char *value); @@ -1163,6 +1209,9 @@ BlockIteratorStateType EnumUtil::FromString(const char * template<> BlockState EnumUtil::FromString(const char *value); +template<> +BufferedIndexReplay EnumUtil::FromString(const char *value); + template<> CAPIResultSetType EnumUtil::FromString(const char *value); @@ -1334,6 +1383,9 @@ FunctionStability EnumUtil::FromString(const char *value); template<> GateStatus EnumUtil::FromString(const char *value); +template<> +GeometryType EnumUtil::FromString(const char *value); + template<> HLLStorageType EnumUtil::FromString(const char *value); @@ -1472,9 +1524,6 @@ ParseInfoType EnumUtil::FromString(const char *value); template<> ParserExtensionResultType EnumUtil::FromString(const char *value); -template<> -PartitionSortStage EnumUtil::FromString(const char *value); - template<> PartitionedColumnDataType EnumUtil::FromString(const char *value); @@ -1487,6 +1536,9 @@ PendingExecutionResult EnumUtil::FromString(const char * template<> PhysicalOperatorType EnumUtil::FromString(const char *value); +template<> +PhysicalTableScanExecutionStrategy EnumUtil::FromString(const char *value); + template<> PhysicalType EnumUtil::FromString(const char *value); @@ -1514,9 +1566,18 @@ QuantileSerializationType EnumUtil::FromString(const template<> QueryNodeType EnumUtil::FromString(const char *value); +template<> +QueryResultMemoryType EnumUtil::FromString(const char *value); + +template<> +QueryResultOutputType EnumUtil::FromString(const char *value); + template<> QueryResultType EnumUtil::FromString(const char *value); +template<> +RecoveryMode EnumUtil::FromString(const char *value); + template<> RelationType EnumUtil::FromString(const char *value); @@ -1691,6 +1752,9 @@ VerificationType EnumUtil::FromString(const char *value); template<> VerifyExistenceType EnumUtil::FromString(const char *value); +template<> +VertexType EnumUtil::FromString(const char *value); + template<> WALType EnumUtil::FromString(const char *value); diff --git a/src/duckdb/src/include/duckdb/common/enums/compression_type.hpp b/src/duckdb/src/include/duckdb/common/enums/compression_type.hpp index 1dda5ee64..5198f7627 100644 --- a/src/duckdb/src/include/duckdb/common/enums/compression_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/compression_type.hpp @@ -36,8 +36,48 @@ enum class CompressionType : uint8_t { COMPRESSION_COUNT // This has to stay the last entry of the type! }; -bool CompressionTypeIsDeprecated(CompressionType compression_type, - optional_ptr storage_manager = nullptr); +struct CompressionAvailabilityResult { +private: + enum class UnavailableReason : uint8_t { + AVAILABLE, + //! Introduced later, not available to this version + NOT_AVAILABLE_YET, + //! 
Used to be available, but isnt anymore + DEPRECATED + }; + +public: + CompressionAvailabilityResult() = default; + static CompressionAvailabilityResult Deprecated() { + return CompressionAvailabilityResult(UnavailableReason::DEPRECATED); + } + static CompressionAvailabilityResult NotAvailableYet() { + return CompressionAvailabilityResult(UnavailableReason::NOT_AVAILABLE_YET); + } + +public: + bool IsAvailable() const { + return reason == UnavailableReason::AVAILABLE; + } + bool IsDeprecated() { + D_ASSERT(!IsAvailable()); + return reason == UnavailableReason::DEPRECATED; + } + bool IsNotAvailableYet() { + D_ASSERT(!IsAvailable()); + return reason == UnavailableReason::NOT_AVAILABLE_YET; + } + +private: + explicit CompressionAvailabilityResult(UnavailableReason reason) : reason(reason) { + } + +public: + UnavailableReason reason = UnavailableReason::AVAILABLE; +}; + +CompressionAvailabilityResult CompressionTypeIsAvailable(CompressionType compression_type, + optional_ptr storage_manager = nullptr); vector ListCompressionTypes(void); CompressionType CompressionTypeFromString(const string &str); string CompressionTypeToString(CompressionType type); diff --git a/src/duckdb/src/include/duckdb/common/enums/explain_format.hpp b/src/duckdb/src/include/duckdb/common/enums/explain_format.hpp index 6635ca454..149e8c33f 100644 --- a/src/duckdb/src/include/duckdb/common/enums/explain_format.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/explain_format.hpp @@ -12,6 +12,6 @@ namespace duckdb { -enum class ExplainFormat : uint8_t { DEFAULT, TEXT, JSON, HTML, GRAPHVIZ, YAML }; +enum class ExplainFormat : uint8_t { DEFAULT, TEXT, JSON, HTML, GRAPHVIZ, YAML, MERMAID }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp b/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp index 8fd2790ab..bc4900d9c 100644 --- a/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp @@ -20,32 +20,38 @@ namespace duckdb { enum class MetricsType : uint8_t { - QUERY_NAME, + ATTACH_LOAD_STORAGE_LATENCY, + ATTACH_REPLAY_WAL_LATENCY, BLOCKED_THREAD_TIME, + CHECKPOINT_LATENCY, + COMMIT_WRITE_WAL_LATENCY, CPU_TIME, - EXTRA_INFO, CUMULATIVE_CARDINALITY, - OPERATOR_TYPE, - OPERATOR_CARDINALITY, CUMULATIVE_ROWS_SCANNED, + EXTRA_INFO, + LATENCY, + OPERATOR_CARDINALITY, + OPERATOR_NAME, OPERATOR_ROWS_SCANNED, OPERATOR_TIMING, + OPERATOR_TYPE, + QUERY_NAME, RESULT_SET_SIZE, - LATENCY, ROWS_RETURNED, - OPERATOR_NAME, SYSTEM_PEAK_BUFFER_MEMORY, SYSTEM_PEAK_TEMP_DIR_SIZE, TOTAL_BYTES_READ, TOTAL_BYTES_WRITTEN, + WAITING_TO_ATTACH_LATENCY, + WAL_REPLAY_ENTRY_COUNT, ALL_OPTIMIZERS, CUMULATIVE_OPTIMIZER_TIMING, - PLANNER, - PLANNER_BINDING, PHYSICAL_PLANNER, PHYSICAL_PLANNER_COLUMN_BINDING, - PHYSICAL_PLANNER_RESOLVE_TYPES, PHYSICAL_PLANNER_CREATE_PLAN, + PHYSICAL_PLANNER_RESOLVE_TYPES, + PLANNER, + PLANNER_BINDING, OPTIMIZER_EXPRESSION_REWRITER, OPTIMIZER_FILTER_PULLUP, OPTIMIZER_FILTER_PUSHDOWN, @@ -64,6 +70,7 @@ enum class MetricsType : uint8_t { OPTIMIZER_BUILD_SIDE_PROBE_SIDE, OPTIMIZER_LIMIT_PUSHDOWN, OPTIMIZER_TOP_N, + OPTIMIZER_TOP_N_WINDOW_ELIMINATION, OPTIMIZER_COMPRESSED_MATERIALIZATION, OPTIMIZER_DUPLICATE_GROUPS, OPTIMIZER_REORDER_FILTER, @@ -74,6 +81,7 @@ enum class MetricsType : uint8_t { OPTIMIZER_SUM_REWRITER, OPTIMIZER_LATE_MATERIALIZATION, OPTIMIZER_CTE_INLINING, + OPTIMIZER_COMMON_SUBPLAN, }; struct MetricsTypeHashFunction { diff --git 
a/src/duckdb/src/include/duckdb/common/enums/operator_result_type.hpp b/src/duckdb/src/include/duckdb/common/enums/operator_result_type.hpp index c539a5636..5a161dcbe 100644 --- a/src/duckdb/src/include/duckdb/common/enums/operator_result_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/operator_result_type.hpp @@ -42,6 +42,20 @@ enum class OperatorFinalResultType : uint8_t { FINISHED, BLOCKED }; //! BLOCKED means the source is currently blocked, e.g. by some async I/O enum class SourceResultType : uint8_t { HAVE_MORE_OUTPUT, FINISHED, BLOCKED }; +//! AsyncResultType is used to indicate the result of an AsyncResult, in the context of a wider operation being executed +enum class AsyncResultType : uint8_t { + INVALID, // current result is in an invalid state (e.g. it's in the process of being initialized) + IMPLICIT, // current result depends on external context (e.g. in the context of TableFunctions, either FINISHED or + // HAVE_MORE_OUTPUT depending on output_chunk.size()) + HAVE_MORE_OUTPUT, // current result is not yet complete (e.g. in the context of TableFunctions, the function + // accepts further calls and might produce more results) + FINISHED, // current result is complete, no subsequent calls on the same state should be attempted + BLOCKED // current result is blocked, no subsequent calls on the same state should be attempted (e.g. in the context + // of AsyncResult, BLOCKED will be associated with a vector of AsyncTasks to be scheduled) +}; + +bool ExtractSourceResultType(AsyncResultType in, SourceResultType &out); + //! The SinkResultType is used to indicate the result of data flowing into a sink //! There are three possible results: //! NEED_MORE_INPUT means the sink needs more input diff --git a/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp b/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp index b57823028..82675c7d5 100644 --- a/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp @@ -33,6 +33,7 @@ enum class OptimizerType : uint32_t { BUILD_SIDE_PROBE_SIDE, LIMIT_PUSHDOWN, TOP_N, + TOP_N_WINDOW_ELIMINATION, COMPRESSED_MATERIALIZATION, DUPLICATE_GROUPS, REORDER_FILTER, @@ -42,7 +43,8 @@ enum class OptimizerType : uint32_t { MATERIALIZED_CTE, SUM_REWRITER, LATE_MATERIALIZATION, - CTE_INLINING + CTE_INLINING, + COMMON_SUBPLAN, }; string OptimizerTypeToString(OptimizerType type); diff --git a/src/duckdb/src/include/duckdb/common/enums/profiler_format.hpp b/src/duckdb/src/include/duckdb/common/enums/profiler_format.hpp index 1a416d546..9cd1206b9 100644 --- a/src/duckdb/src/include/duckdb/common/enums/profiler_format.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/profiler_format.hpp @@ -12,6 +12,6 @@ namespace duckdb { -enum class ProfilerPrintFormat : uint8_t { QUERY_TREE, JSON, QUERY_TREE_OPTIMIZER, NO_OUTPUT, HTML, GRAPHVIZ }; +enum class ProfilerPrintFormat : uint8_t { QUERY_TREE, JSON, QUERY_TREE_OPTIMIZER, NO_OUTPUT, HTML, GRAPHVIZ, MERMAID }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/statement_type.hpp b/src/duckdb/src/include/duckdb/common/enums/statement_type.hpp index b6c9d08ae..1acf98b24 100644 --- a/src/duckdb/src/include/duckdb/common/enums/statement_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/statement_type.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/constants.hpp" #include "duckdb/common/optional_idx.hpp" #include "duckdb/common/unordered_set.hpp" +#include "duckdb/main/query_parameters.hpp" 
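// --- Editorial aside (not part of the patch): a usage sketch for the new AsyncResultType ---
// Everything below is illustrative. Only AsyncResultType, SourceResultType and the declaration of
// ExtractSourceResultType(AsyncResultType, SourceResultType &) come from the header changed above;
// MapAsyncToSource is a hypothetical helper, and the assumption (not stated in the header) is that
// ExtractSourceResultType returns true when the async state maps directly onto a SourceResultType
// and false for states such as INVALID or IMPLICIT.
#include "duckdb/common/enums/operator_result_type.hpp"
#include "duckdb/common/typedefs.hpp"

namespace duckdb {

static SourceResultType MapAsyncToSource(AsyncResultType async_state, idx_t rows_in_chunk) {
	SourceResultType mapped;
	if (ExtractSourceResultType(async_state, mapped)) {
		// HAVE_MORE_OUTPUT, FINISHED and BLOCKED presumably translate one-to-one
		return mapped;
	}
	// IMPLICIT (or any state without a direct equivalent): fall back to the table-function
	// convention described in the enum comments and decide based on the produced chunk size
	return rows_in_chunk == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT;
}

} // namespace duckdb
// --- End editorial aside ---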
namespace duckdb { @@ -67,8 +68,9 @@ class ClientContext; //! A struct containing various properties of a SQL statement struct StatementProperties { StatementProperties() - : requires_valid_transaction(true), allow_stream_result(false), bound_all_parameters(true), - return_type(StatementReturnType::QUERY_RESULT), parameter_count(0), always_require_rebind(false) { + : requires_valid_transaction(true), output_type(QueryResultOutputType::FORCE_MATERIALIZED), + bound_all_parameters(true), return_type(StatementReturnType::QUERY_RESULT), parameter_count(0), + always_require_rebind(false) { } struct CatalogIdentity { @@ -92,7 +94,7 @@ struct StatementProperties { //! exception of ROLLBACK bool requires_valid_transaction; //! Whether or not the result can be streamed to the client - bool allow_stream_result; + QueryResultOutputType output_type; //! Whether or not all parameters have successfully had their types determined bool bound_all_parameters; //! What type of data the statement returns diff --git a/src/duckdb/src/include/duckdb/common/exception.hpp b/src/duckdb/src/include/duckdb/common/exception.hpp index 480dd2385..17d430201 100644 --- a/src/duckdb/src/include/duckdb/common/exception.hpp +++ b/src/duckdb/src/include/duckdb/common/exception.hpp @@ -94,15 +94,16 @@ enum class ExceptionType : uint8_t { class Exception : public std::runtime_error { public: DUCKDB_API Exception(ExceptionType exception_type, const string &message); - DUCKDB_API Exception(ExceptionType exception_type, const string &message, - const unordered_map &extra_info); + + DUCKDB_API Exception(const unordered_map &extra_info, ExceptionType exception_type, + const string &message); public: DUCKDB_API static string ExceptionTypeToString(ExceptionType type); DUCKDB_API static ExceptionType StringToExceptionType(const string &type); template - static string ConstructMessage(const string &msg, ARGS... params) { + static string ConstructMessage(const string &msg, ARGS const &...params) { const std::size_t num_args = sizeof...(ARGS); if (num_args == 0) { return msg; @@ -122,8 +123,9 @@ class Exception : public std::runtime_error { //! Whether this exception type can occur during execution of a query DUCKDB_API static bool IsExecutionError(ExceptionType type); DUCKDB_API static string ToJSON(ExceptionType type, const string &message); - DUCKDB_API static string ToJSON(ExceptionType type, const string &message, - const unordered_map &extra_info); + + DUCKDB_API static string ToJSON(const unordered_map &extra_info, ExceptionType type, + const string &message); DUCKDB_API static bool InvalidatesTransaction(ExceptionType exception_type); DUCKDB_API static bool InvalidatesDatabase(ExceptionType exception_type); @@ -131,8 +133,8 @@ class Exception : public std::runtime_error { DUCKDB_API static string ConstructMessageRecursive(const string &msg, std::vector &values); template - static string ConstructMessageRecursive(const string &msg, std::vector &values, T param, - ARGS... params) { + static string ConstructMessageRecursive(const string &msg, std::vector &values, + const T ¶m, ARGS &&...params) { values.push_back(ExceptionFormatValue::CreateFormatValue(param)); return ConstructMessageRecursive(msg, values, params...); } @@ -155,8 +157,8 @@ class ConnectionException : public Exception { DUCKDB_API explicit ConnectionException(const string &msg); template - explicit ConnectionException(const string &msg, ARGS... 
params) - : ConnectionException(ConstructMessage(msg, params...)) { + explicit ConnectionException(const string &msg, ARGS &&...params) + : ConnectionException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -165,8 +167,8 @@ class PermissionException : public Exception { DUCKDB_API explicit PermissionException(const string &msg); template - explicit PermissionException(const string &msg, ARGS... params) - : PermissionException(ConstructMessage(msg, params...)) { + explicit PermissionException(const string &msg, ARGS &&...params) + : PermissionException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -175,8 +177,8 @@ class OutOfRangeException : public Exception { DUCKDB_API explicit OutOfRangeException(const string &msg); template - explicit OutOfRangeException(const string &msg, ARGS... params) - : OutOfRangeException(ConstructMessage(msg, params...)) { + explicit OutOfRangeException(const string &msg, ARGS &&...params) + : OutOfRangeException(ConstructMessage(msg, std::forward(params)...)) { } DUCKDB_API OutOfRangeException(const int64_t value, const PhysicalType orig_type, const PhysicalType new_type); DUCKDB_API OutOfRangeException(const hugeint_t value, const PhysicalType orig_type, const PhysicalType new_type); @@ -189,8 +191,8 @@ class OutOfMemoryException : public Exception { DUCKDB_API explicit OutOfMemoryException(const string &msg); template - explicit OutOfMemoryException(const string &msg, ARGS... params) - : OutOfMemoryException(ConstructMessage(msg, params...)) { + explicit OutOfMemoryException(const string &msg, ARGS &&...params) + : OutOfMemoryException(ConstructMessage(msg, std::forward(params)...)) { } private: @@ -202,7 +204,8 @@ class SyntaxException : public Exception { DUCKDB_API explicit SyntaxException(const string &msg); template - explicit SyntaxException(const string &msg, ARGS... params) : SyntaxException(ConstructMessage(msg, params...)) { + explicit SyntaxException(const string &msg, ARGS &&...params) + : SyntaxException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -211,8 +214,8 @@ class ConstraintException : public Exception { DUCKDB_API explicit ConstraintException(const string &msg); template - explicit ConstraintException(const string &msg, ARGS... params) - : ConstraintException(ConstructMessage(msg, params...)) { + explicit ConstraintException(const string &msg, ARGS &&...params) + : ConstraintException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -221,25 +224,27 @@ class DependencyException : public Exception { DUCKDB_API explicit DependencyException(const string &msg); template - explicit DependencyException(const string &msg, ARGS... params) - : DependencyException(ConstructMessage(msg, params...)) { + explicit DependencyException(const string &msg, ARGS &&...params) + : DependencyException(ConstructMessage(msg, std::forward(params)...)) { } }; class IOException : public Exception { public: DUCKDB_API explicit IOException(const string &msg); - DUCKDB_API explicit IOException(const string &msg, const unordered_map &extra_info); + + DUCKDB_API explicit IOException(const unordered_map &extra_info, const string &msg); explicit IOException(ExceptionType exception_type, const string &msg) : Exception(exception_type, msg) { } template - explicit IOException(const string &msg, ARGS... 
params) : IOException(ConstructMessage(msg, params...)) { + explicit IOException(const string &msg, ARGS &&...params) + : IOException(ConstructMessage(msg, std::forward(params)...)) { } template - explicit IOException(const string &msg, const unordered_map &extra_info, ARGS... params) - : IOException(ConstructMessage(msg, params...), extra_info) { + explicit IOException(const unordered_map &extra_info, const string &msg, ARGS &&...params) + : IOException(extra_info, ConstructMessage(msg, std::forward(params)...)) { } }; @@ -248,18 +253,24 @@ class MissingExtensionException : public Exception { DUCKDB_API explicit MissingExtensionException(const string &msg); template - explicit MissingExtensionException(const string &msg, ARGS... params) - : MissingExtensionException(ConstructMessage(msg, params...)) { + explicit MissingExtensionException(const string &msg, ARGS &&...params) + : MissingExtensionException(ConstructMessage(msg, std::forward(params)...)) { } }; class NotImplementedException : public Exception { public: DUCKDB_API explicit NotImplementedException(const string &msg); + explicit NotImplementedException(const unordered_map &extra_info, const string &msg); template - explicit NotImplementedException(const string &msg, ARGS... params) - : NotImplementedException(ConstructMessage(msg, params...)) { + explicit NotImplementedException(const string &msg, ARGS &&...params) + : NotImplementedException(ConstructMessage(msg, std::forward(params)...)) { + } + template + explicit NotImplementedException(const unordered_map &extra_info, const string &msg, + ARGS &&...params) + : NotImplementedException(extra_info, ConstructMessage(msg, std::forward(params)...)) { } }; @@ -273,8 +284,8 @@ class SerializationException : public Exception { DUCKDB_API explicit SerializationException(const string &msg); template - explicit SerializationException(const string &msg, ARGS... params) - : SerializationException(ConstructMessage(msg, params...)) { + explicit SerializationException(const string &msg, ARGS &&...params) + : SerializationException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -283,8 +294,8 @@ class SequenceException : public Exception { DUCKDB_API explicit SequenceException(const string &msg); template - explicit SequenceException(const string &msg, ARGS... params) - : SequenceException(ConstructMessage(msg, params...)) { + explicit SequenceException(const string &msg, ARGS &&...params) + : SequenceException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -298,39 +309,48 @@ class FatalException : public Exception { explicit FatalException(const string &msg) : FatalException(ExceptionType::FATAL, msg) { } template - explicit FatalException(const string &msg, ARGS... params) : FatalException(ConstructMessage(msg, params...)) { + explicit FatalException(const string &msg, ARGS &&...params) + : FatalException(ConstructMessage(msg, std::forward(params)...)) { } protected: DUCKDB_API explicit FatalException(ExceptionType type, const string &msg); template - explicit FatalException(ExceptionType type, const string &msg, ARGS... 
params) - : FatalException(type, ConstructMessage(msg, params...)) { + explicit FatalException(ExceptionType type, const string &msg, ARGS &&...params) + : FatalException(type, ConstructMessage(msg, std::forward(params)...)) { } }; class InternalException : public Exception { public: DUCKDB_API explicit InternalException(const string &msg); + InternalException(const unordered_map &extra_info, const string &msg); template - explicit InternalException(const string &msg, ARGS... params) - : InternalException(ConstructMessage(msg, params...)) { + explicit InternalException(const string &msg, ARGS &&...params) + : InternalException(ConstructMessage(msg, std::forward(params)...)) { + } + + template + explicit InternalException(const unordered_map &extra_info, const string &msg, ARGS &&...params) + : InternalException(extra_info, ConstructMessage(msg, std::forward(params)...)) { } }; class InvalidInputException : public Exception { public: DUCKDB_API explicit InvalidInputException(const string &msg); - DUCKDB_API explicit InvalidInputException(const string &msg, const unordered_map &extra_info); + DUCKDB_API explicit InvalidInputException(const unordered_map &extra_info, const string &msg); template - explicit InvalidInputException(const string &msg, ARGS... params) - : InvalidInputException(ConstructMessage(msg, params...)) { + explicit InvalidInputException(const string &msg, ARGS &&...params) + : InvalidInputException(ConstructMessage(msg, std::forward(params)...)) { } + template - explicit InvalidInputException(const Expression &expr, const string &msg, ARGS... params) - : InvalidInputException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(expr)) { + explicit InvalidInputException(const Expression &expr, const string &msg, ARGS &&...params) + : InvalidInputException(Exception::InitializeExtraInfo(expr), + ConstructMessage(msg, std::forward(params)...)) { } }; @@ -339,24 +359,26 @@ class ExecutorException : public Exception { DUCKDB_API explicit ExecutorException(const string &msg); template - explicit ExecutorException(const string &msg, ARGS... params) - : ExecutorException(ConstructMessage(msg, params...)) { + explicit ExecutorException(const string &msg, ARGS &&...params) + : ExecutorException(ConstructMessage(msg, std::forward(params)...)) { } }; class InvalidConfigurationException : public Exception { public: DUCKDB_API explicit InvalidConfigurationException(const string &msg); - DUCKDB_API explicit InvalidConfigurationException(const string &msg, - const unordered_map &extra_info); + + DUCKDB_API explicit InvalidConfigurationException(const unordered_map &extra_info, + const string &msg); template - explicit InvalidConfigurationException(const string &msg, ARGS... params) - : InvalidConfigurationException(ConstructMessage(msg, params...)) { + explicit InvalidConfigurationException(const string &msg, ARGS &&...params) + : InvalidConfigurationException(ConstructMessage(msg, std::forward(params)...)) { } template - explicit InvalidConfigurationException(const Expression &expr, const string &msg, ARGS... 
params) - : InvalidConfigurationException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(expr)) { + explicit InvalidConfigurationException(const Expression &expr, const string &msg, ARGS &&...params) + : InvalidConfigurationException(ConstructMessage(msg, std::forward(params)...), + Exception::InitializeExtraInfo(expr)) { } }; @@ -381,8 +403,8 @@ class ParameterNotAllowedException : public Exception { DUCKDB_API explicit ParameterNotAllowedException(const string &msg); template - explicit ParameterNotAllowedException(const string &msg, ARGS... params) - : ParameterNotAllowedException(ConstructMessage(msg, params...)) { + explicit ParameterNotAllowedException(const string &msg, ARGS &&...params) + : ParameterNotAllowedException(ConstructMessage(msg, std::forward(params)...)) { } }; diff --git a/src/duckdb/src/include/duckdb/common/exception/binder_exception.hpp b/src/duckdb/src/include/duckdb/common/exception/binder_exception.hpp index 2590cb094..fd7158f87 100644 --- a/src/duckdb/src/include/duckdb/common/exception/binder_exception.hpp +++ b/src/duckdb/src/include/duckdb/common/exception/binder_exception.hpp @@ -15,31 +15,39 @@ namespace duckdb { class BinderException : public Exception { public: - DUCKDB_API explicit BinderException(const string &msg, const unordered_map &extra_info); DUCKDB_API explicit BinderException(const string &msg); + DUCKDB_API explicit BinderException(const unordered_map &extra_info, const string &msg); + template - explicit BinderException(const string &msg, ARGS... params) : BinderException(ConstructMessage(msg, params...)) { + explicit BinderException(const string &msg, ARGS &&...params) + : BinderException(ConstructMessage(msg, std::forward(params)...)) { } + template - explicit BinderException(const TableRef &ref, const string &msg, ARGS... params) - : BinderException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(ref)) { + explicit BinderException(const TableRef &ref, const string &msg, ARGS &&...params) + : BinderException(Exception::InitializeExtraInfo(ref), ConstructMessage(msg, std::forward(params)...)) { } template - explicit BinderException(const ParsedExpression &expr, const string &msg, ARGS... params) - : BinderException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(expr)) { + explicit BinderException(const ParsedExpression &expr, const string &msg, ARGS &&...params) + : BinderException(Exception::InitializeExtraInfo(expr), ConstructMessage(msg, std::forward(params)...)) { } + template - explicit BinderException(const Expression &expr, const string &msg, ARGS... params) - : BinderException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(expr)) { + explicit BinderException(const Expression &expr, const string &msg, ARGS &&...params) + : BinderException(Exception::InitializeExtraInfo(expr), ConstructMessage(msg, std::forward(params)...)) { } + template - explicit BinderException(QueryErrorContext error_context, const string &msg, ARGS... params) - : BinderException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(error_context)) { + explicit BinderException(QueryErrorContext error_context, const string &msg, ARGS &&...params) + : BinderException(Exception::InitializeExtraInfo(error_context), + ConstructMessage(msg, std::forward(params)...)) { } + template - explicit BinderException(optional_idx error_location, const string &msg, ARGS... 
params) - : BinderException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(error_location)) { + explicit BinderException(optional_idx error_location, const string &msg, ARGS &&...params) + : BinderException(Exception::InitializeExtraInfo(error_location), + ConstructMessage(msg, std::forward(params)...)) { } static BinderException ColumnNotFound(const string &name, const vector &similar_bindings, diff --git a/src/duckdb/src/include/duckdb/common/exception/catalog_exception.hpp b/src/duckdb/src/include/duckdb/common/exception/catalog_exception.hpp index 498fafd19..1095531d0 100644 --- a/src/duckdb/src/include/duckdb/common/exception/catalog_exception.hpp +++ b/src/duckdb/src/include/duckdb/common/exception/catalog_exception.hpp @@ -19,14 +19,18 @@ struct EntryLookupInfo; class CatalogException : public Exception { public: DUCKDB_API explicit CatalogException(const string &msg); - DUCKDB_API explicit CatalogException(const string &msg, const unordered_map &extra_info); + + DUCKDB_API explicit CatalogException(const unordered_map &extra_info, const string &msg); template - explicit CatalogException(const string &msg, ARGS... params) : CatalogException(ConstructMessage(msg, params...)) { + explicit CatalogException(const string &msg, ARGS &&...params) + : CatalogException(ConstructMessage(msg, std::forward(params)...)) { } + template - explicit CatalogException(QueryErrorContext error_context, const string &msg, ARGS... params) - : CatalogException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(error_context)) { + explicit CatalogException(QueryErrorContext error_context, const string &msg, ARGS &&...params) + : CatalogException(Exception::InitializeExtraInfo(error_context), + ConstructMessage(msg, std::forward(params)...)) { } static CatalogException MissingEntry(const EntryLookupInfo &lookup_info, const string &suggestion); diff --git a/src/duckdb/src/include/duckdb/common/exception/conversion_exception.hpp b/src/duckdb/src/include/duckdb/common/exception/conversion_exception.hpp index 5330f46e6..9252d0790 100644 --- a/src/duckdb/src/include/duckdb/common/exception/conversion_exception.hpp +++ b/src/duckdb/src/include/duckdb/common/exception/conversion_exception.hpp @@ -12,22 +12,24 @@ #include "duckdb/common/optional_idx.hpp" namespace duckdb { - class ConversionException : public Exception { public: DUCKDB_API explicit ConversionException(const string &msg); + DUCKDB_API explicit ConversionException(optional_idx error_location, const string &msg); + DUCKDB_API ConversionException(const PhysicalType orig_type, const PhysicalType new_type); + DUCKDB_API ConversionException(const LogicalType &orig_type, const LogicalType &new_type); template - explicit ConversionException(const string &msg, ARGS... params) - : ConversionException(ConstructMessage(msg, params...)) { + explicit ConversionException(const string &msg, ARGS &&...params) + : ConversionException(ConstructMessage(msg, std::forward(params)...)) { } + template - explicit ConversionException(optional_idx error_location, const string &msg, ARGS... 
params) - : ConversionException(error_location, ConstructMessage(msg, params...)) { + explicit ConversionException(optional_idx error_location, const string &msg, ARGS &&...params) + : ConversionException(error_location, ConstructMessage(msg, std::forward(params)...)) { } }; - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/exception/http_exception.hpp b/src/duckdb/src/include/duckdb/common/exception/http_exception.hpp index aff00d23d..b0d0e9c2d 100644 --- a/src/duckdb/src/include/duckdb/common/exception/http_exception.hpp +++ b/src/duckdb/src/include/duckdb/common/exception/http_exception.hpp @@ -24,9 +24,9 @@ class HTTPException : public Exception { } template ::status = 0, typename... ARGS> - explicit HTTPException(RESPONSE &response, const string &msg, ARGS... params) + explicit HTTPException(RESPONSE &response, const string &msg, ARGS &&...params) : HTTPException(static_cast(response.status), response.body, response.headers, response.reason, msg, - params...) { + std::forward(params)...) { } template @@ -35,16 +35,16 @@ class HTTPException : public Exception { }; template ::code = 0, typename... ARGS> - explicit HTTPException(RESPONSE &response, const string &msg, ARGS... params) + explicit HTTPException(RESPONSE &response, const string &msg, ARGS &&...params) : HTTPException(static_cast(response.code), response.body, response.headers, response.error, msg, - params...) { + std::forward(params)...) { } template explicit HTTPException(int status_code, const string &response_body, const HEADERS &headers, const string &reason, - const string &msg, ARGS... params) - : Exception(ExceptionType::HTTP, ConstructMessage(msg, params...), - HTTPExtraInfo(status_code, response_body, headers, reason)) { + const string &msg, ARGS &&...params) + : Exception(HTTPExtraInfo(status_code, response_body, headers, reason), ExceptionType::HTTP, + ConstructMessage(msg, std::forward(params)...)) { } template diff --git a/src/duckdb/src/include/duckdb/common/exception/parser_exception.hpp b/src/duckdb/src/include/duckdb/common/exception/parser_exception.hpp index 363a34457..26ce6c585 100644 --- a/src/duckdb/src/include/duckdb/common/exception/parser_exception.hpp +++ b/src/duckdb/src/include/duckdb/common/exception/parser_exception.hpp @@ -17,18 +17,21 @@ namespace duckdb { class ParserException : public Exception { public: DUCKDB_API explicit ParserException(const string &msg); - DUCKDB_API explicit ParserException(const string &msg, const unordered_map &extra_info); + + DUCKDB_API explicit ParserException(const unordered_map &extra_info, const string &msg); template - explicit ParserException(const string &msg, ARGS... params) : ParserException(ConstructMessage(msg, params...)) { + explicit ParserException(const string &msg, ARGS &&...params) + : ParserException(ConstructMessage(msg, std::forward(params)...)) { } template - explicit ParserException(optional_idx error_location, const string &msg, ARGS... params) - : ParserException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(error_location)) { + explicit ParserException(optional_idx error_location, const string &msg, ARGS &&...params) + : ParserException(Exception::InitializeExtraInfo(error_location), + ConstructMessage(msg, std::forward(params)...)) { } template - explicit ParserException(const ParsedExpression &expr, const string &msg, ARGS... 
params) - : ParserException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(expr)) { + explicit ParserException(const ParsedExpression &expr, const string &msg, ARGS &&...params) + : ParserException(Exception::InitializeExtraInfo(expr), ConstructMessage(msg, std::forward(params)...)) { } static ParserException SyntaxError(const string &query, const string &error_message, optional_idx error_location); diff --git a/src/duckdb/src/include/duckdb/common/exception/transaction_exception.hpp b/src/duckdb/src/include/duckdb/common/exception/transaction_exception.hpp index f0164df69..5ca0be62b 100644 --- a/src/duckdb/src/include/duckdb/common/exception/transaction_exception.hpp +++ b/src/duckdb/src/include/duckdb/common/exception/transaction_exception.hpp @@ -11,15 +11,13 @@ #include "duckdb/common/exception.hpp" namespace duckdb { - class TransactionException : public Exception { public: DUCKDB_API explicit TransactionException(const string &msg); template - explicit TransactionException(const string &msg, ARGS... params) - : TransactionException(ConstructMessage(msg, params...)) { + explicit TransactionException(const string &msg, ARGS &&...params) + : TransactionException(ConstructMessage(msg, std::forward(params)...)) { } }; - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/exception_format_value.hpp b/src/duckdb/src/include/duckdb/common/exception_format_value.hpp index 3693db54c..7beeead6e 100644 --- a/src/duckdb/src/include/duckdb/common/exception_format_value.hpp +++ b/src/duckdb/src/include/duckdb/common/exception_format_value.hpp @@ -49,13 +49,13 @@ enum class ExceptionFormatValueType : uint8_t { }; struct ExceptionFormatValue { - DUCKDB_API ExceptionFormatValue(double dbl_val); // NOLINT - DUCKDB_API ExceptionFormatValue(int64_t int_val); // NOLINT - DUCKDB_API ExceptionFormatValue(idx_t uint_val); // NOLINT - DUCKDB_API ExceptionFormatValue(string str_val); // NOLINT - DUCKDB_API ExceptionFormatValue(String str_val); // NOLINT - DUCKDB_API ExceptionFormatValue(hugeint_t hg_val); // NOLINT - DUCKDB_API ExceptionFormatValue(uhugeint_t uhg_val); // NOLINT + DUCKDB_API ExceptionFormatValue(double dbl_val); // NOLINT + DUCKDB_API ExceptionFormatValue(int64_t int_val); // NOLINT + DUCKDB_API ExceptionFormatValue(idx_t uint_val); // NOLINT + DUCKDB_API ExceptionFormatValue(string str_val); // NOLINT + DUCKDB_API ExceptionFormatValue(const String &str_val); // NOLINT + DUCKDB_API ExceptionFormatValue(hugeint_t hg_val); // NOLINT + DUCKDB_API ExceptionFormatValue(uhugeint_t uhg_val); // NOLINT ExceptionFormatValueType type; @@ -65,37 +65,37 @@ struct ExceptionFormatValue { public: template - static ExceptionFormatValue CreateFormatValue(T value) { + static ExceptionFormatValue CreateFormatValue(const T &value) { return int64_t(value); } static string Format(const string &msg, std::vector &values); }; template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(PhysicalType value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const PhysicalType &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(SQLString value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const SQLString &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(SQLIdentifier value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const SQLIdentifier &value); template <> -DUCKDB_API ExceptionFormatValue 
ExceptionFormatValue::CreateFormatValue(LogicalType value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const LogicalType &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(float value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const float &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(double value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const double &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(string value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const string &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(String value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const String &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *const &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *const &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(idx_t value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const idx_t &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(hugeint_t value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const hugeint_t &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(uhugeint_t value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const uhugeint_t &value); } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/extra_type_info.hpp b/src/duckdb/src/include/duckdb/common/extra_type_info.hpp index d5e35ee96..02348d69b 100644 --- a/src/duckdb/src/include/duckdb/common/extra_type_info.hpp +++ b/src/duckdb/src/include/duckdb/common/extra_type_info.hpp @@ -28,7 +28,8 @@ enum class ExtraTypeInfoType : uint8_t { ARRAY_TYPE_INFO = 9, ANY_TYPE_INFO = 10, INTEGER_LITERAL_TYPE_INFO = 11, - TEMPLATE_TYPE_INFO = 12 + TEMPLATE_TYPE_INFO = 12, + GEO_TYPE_INFO = 13 }; struct ExtraTypeInfo { @@ -261,7 +262,6 @@ struct IntegerLiteralTypeInfo : public ExtraTypeInfo { }; struct TemplateTypeInfo : public ExtraTypeInfo { - explicit TemplateTypeInfo(string name_p); // The name of the template, e.g. `T`, or `KEY_TYPE`. 
Used to distinguish between different template types within @@ -278,4 +278,16 @@ struct TemplateTypeInfo : public ExtraTypeInfo { TemplateTypeInfo(); }; +struct GeoTypeInfo : public ExtraTypeInfo { +public: + GeoTypeInfo(); + + void Serialize(Serializer &serializer) const override; + static shared_ptr Deserialize(Deserializer &source); + shared_ptr Copy() const override; + +protected: + bool EqualsInternal(ExtraTypeInfo *other_p) const override; +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/hugeint.hpp b/src/duckdb/src/include/duckdb/common/hugeint.hpp index c9b54bd95..acdc4fb4b 100644 --- a/src/duckdb/src/include/duckdb/common/hugeint.hpp +++ b/src/duckdb/src/include/duckdb/common/hugeint.hpp @@ -76,7 +76,7 @@ struct hugeint_t { // NOLINT: use numeric casing DUCKDB_API explicit operator int16_t() const; DUCKDB_API explicit operator int32_t() const; DUCKDB_API explicit operator int64_t() const; - DUCKDB_API operator uhugeint_t() const; // NOLINT: Allow implicit conversion from `hugeint_t` + DUCKDB_API explicit operator uhugeint_t() const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/insertion_order_preserving_map.hpp b/src/duckdb/src/include/duckdb/common/insertion_order_preserving_map.hpp index 7240d62d8..5b6f11730 100644 --- a/src/duckdb/src/include/duckdb/common/insertion_order_preserving_map.hpp +++ b/src/duckdb/src/include/duckdb/common/insertion_order_preserving_map.hpp @@ -95,6 +95,10 @@ class InsertionOrderPreservingMap { map.resize(nz); } + void clear() { // NOLINT: match stl API + map.clear(); + } + void insert(const string &key, V &&value) { // NOLINT: match stl API if (contains(key)) { return; diff --git a/src/duckdb/src/include/duckdb/common/multi_file/base_file_reader.hpp b/src/duckdb/src/include/duckdb/common/multi_file/base_file_reader.hpp index f0a29c7af..c9ed4da21 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/base_file_reader.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/base_file_reader.hpp @@ -79,8 +79,8 @@ class BaseFileReader : public enable_shared_from_this { virtual bool TryInitializeScan(ClientContext &context, GlobalTableFunctionState &gstate, LocalTableFunctionState &lstate) = 0; //! Scan a chunk from the read state - virtual void Scan(ClientContext &context, GlobalTableFunctionState &global_state, - LocalTableFunctionState &local_state, DataChunk &chunk) = 0; + virtual AsyncResult Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &chunk) = 0; //! Finish scanning a given file DUCKDB_API virtual void FinishFile(ClientContext &context, GlobalTableFunctionState &gstate); //! 
Get progress within a given file diff --git a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_data.hpp b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_data.hpp index fd6380a7e..523084d6e 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_data.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_data.hpp @@ -139,7 +139,7 @@ struct MultiFileLocalColumnId { } public: - operator idx_t() { // NOLINT: allow implicit conversion + operator idx_t() const { // NOLINT: allow implicit conversion return column_id; } idx_t GetId() const { @@ -170,7 +170,7 @@ struct MultiFileLocalIndex { } public: - operator idx_t() { // NOLINT: allow implicit conversion + operator idx_t() const { // NOLINT: allow implicit conversion return index; } idx_t GetIndex() const { diff --git a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp index 1ed169568..9cebe4fc8 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp @@ -590,28 +590,77 @@ class MultiFileFunction : public TableFunction { static void MultiFileScan(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { if (!data_p.local_state) { + data_p.async_result = SourceResultType::FINISHED; return; } auto &data = data_p.local_state->Cast(); auto &gstate = data_p.global_state->Cast(); auto &bind_data = data_p.bind_data->CastNoConst(); + if (gstate.finished) { + data_p.async_result = SourceResultType::FINISHED; + return; + } + do { auto &scan_chunk = data.scan_chunk; scan_chunk.Reset(); - data.reader->Scan(context, *gstate.global_state, *data.local_state, scan_chunk); + auto res = data.reader->Scan(context, *gstate.global_state, *data.local_state, scan_chunk); + + if (res.GetResultType() == AsyncResultType::BLOCKED) { + if (scan_chunk.size() != 0) { + throw InternalException("Unexpected behaviour from Scan, no rows should be returned"); + } + switch (data_p.results_execution_mode) { + case AsyncResultsExecutionMode::TASK_EXECUTOR: + data_p.async_result = std::move(res); + return; + case AsyncResultsExecutionMode::SYNCHRONOUS: + res.ExecuteTasksSynchronously(); + if (res.GetResultType() != AsyncResultType::HAVE_MORE_OUTPUT) { + throw InternalException("Unexpected behaviour from ExecuteTasksSynchronously"); + } + // scan_chunk.size() is 0, see check above, and result is HAVE_MORE_OUTPUT, we need to loop again + continue; + } + } + output.SetCardinality(scan_chunk.size()); + if (scan_chunk.size() > 0) { bind_data.multi_file_reader->FinalizeChunk(context, bind_data, *data.reader, *data.reader_data, scan_chunk, output, data.executor, gstate.multi_file_reader_state); + } + if (res.GetResultType() == AsyncResultType::HAVE_MORE_OUTPUT) { + // Loop back to the same block + if (scan_chunk.size() == 0 && data_p.results_execution_mode == AsyncResultsExecutionMode::SYNCHRONOUS) { + continue; + } + data_p.async_result = SourceResultType::HAVE_MORE_OUTPUT; return; } - scan_chunk.Reset(); + + if (res.GetResultType() != AsyncResultType::FINISHED) { + throw InternalException("Unexpected result in MultiFileScan, must be FINISHED, is %s", + EnumUtil::ToChars(res.GetResultType())); + } + if (!TryInitializeNextBatch(context, bind_data, data, gstate)) { - return; + if (scan_chunk.size() > 0 && data_p.results_execution_mode == AsyncResultsExecutionMode::SYNCHRONOUS) { + gstate.finished = true; + data_p.async_result = 
SourceResultType::HAVE_MORE_OUTPUT; + } else { + data_p.async_result = SourceResultType::FINISHED; + } + } else { + if (scan_chunk.size() == 0 && data_p.results_execution_mode == AsyncResultsExecutionMode::SYNCHRONOUS) { + continue; + } + data_p.async_result = SourceResultType::HAVE_MORE_OUTPUT; } + return; } while (true); } @@ -672,7 +721,8 @@ class MultiFileFunction : public TableFunction { continue; } auto &reader_data = *reader_data_ptr; - double progress_in_file; + // Initialize progress_in_file with a default value to avoid uninitialized variable usage + double progress_in_file = 0.0; if (reader_data.file_state == MultiFileFileState::OPEN) { // file is currently open - get the progress within the file progress_in_file = reader_data.reader->GetProgressInFile(context); @@ -686,9 +736,6 @@ class MultiFileFunction : public TableFunction { // file is still being read progress_in_file = reader->GetProgressInFile(context); } - } else { - // file has not been opened yet - progress in this file is zero - progress_in_file = 0; } progress_in_file = MaxValue(0.0, MinValue(100.0, progress_in_file)); total_progress += progress_in_file; diff --git a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_states.hpp b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_states.hpp index d9801e5d0..556a33d3b 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_states.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_states.hpp @@ -166,6 +166,7 @@ struct MultiFileGlobalState : public GlobalTableFunctionState { vector scanned_types; vector column_indexes; optional_ptr filters; + atomic finished {false}; unique_ptr global_state; diff --git a/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp b/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp index ac55e5a69..e495e9760 100644 --- a/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp +++ b/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp @@ -1070,6 +1070,19 @@ bool TryCastBlobToUUID::Operation(string_t input, hugeint_t &result, Vector &res template <> bool TryCastBlobToUUID::Operation(string_t input, hugeint_t &result, bool strict); +//===--------------------------------------------------------------------===// +// GEOMETRY +//===--------------------------------------------------------------------===// +struct TryCastToGeometry { + template + static inline bool Operation(SRC input, DST &result, Vector &result_vector, CastParameters ¶meters) { + throw InternalException("Unsupported type for try cast to geometry"); + } +}; + +template <> +bool TryCastToGeometry::Operation(string_t input, string_t &result, Vector &result_vector, CastParameters ¶meters); + //===--------------------------------------------------------------------===// // Pointers //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/include/duckdb/common/operator/comparison_operators.hpp b/src/duckdb/src/include/duckdb/common/operator/comparison_operators.hpp index f1a6f6eb3..a847e217b 100644 --- a/src/duckdb/src/include/duckdb/common/operator/comparison_operators.hpp +++ b/src/duckdb/src/include/duckdb/common/operator/comparison_operators.hpp @@ -210,15 +210,4 @@ inline bool GreaterThan::Operation(const interval_t &left, const interval_t &rig return Interval::GreaterThan(left, right); } -//===--------------------------------------------------------------------===// -// Specialized Hugeint Comparison Operators 
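// --- Editorial note (not part of the patch): the removal that starts above and continues below
// drops the explicit hugeint_t specializations of Equals and GreaterThan. With them gone,
// hugeint_t comparisons presumably resolve through the generic comparison templates defined
// earlier in this header, whose shape is roughly the sketch below (illustrative only, and it
// relies on hugeint_t providing its own comparison operators):
//
//     struct GreaterThan {
//         template <class T>
//         static inline bool Operation(const T &left, const T &right) {
//             return left > right;
//         }
//     };
// --- End editorial note ---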
-//===--------------------------------------------------------------------===// -template <> -inline bool Equals::Operation(const hugeint_t &left, const hugeint_t &right) { - return Hugeint::Equals(left, right); -} -template <> -inline bool GreaterThan::Operation(const hugeint_t &left, const hugeint_t &right) { - return Hugeint::GreaterThan(left, right); -} } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/primitive_dictionary.hpp b/src/duckdb/src/include/duckdb/common/primitive_dictionary.hpp index 1376cc1b5..6d4c79194 100644 --- a/src/duckdb/src/include/duckdb/common/primitive_dictionary.hpp +++ b/src/duckdb/src/include/duckdb/common/primitive_dictionary.hpp @@ -19,6 +19,14 @@ struct PrimitiveCastOperator { static TGT Operation(SRC input) { return TGT(input); } + template + static constexpr idx_t WriteSize(const TGT &input) { + return sizeof(TGT); + } + template + static void WriteToStream(const TGT &input, WriteStream &ser) { + ser.WriteData(const_data_ptr_cast(&input), sizeof(TGT)); + } }; template @@ -51,21 +59,19 @@ class PrimitiveDictionary { : capacity * sizeof(TGT))), target_stream(allocated_target.get(), allocated_target.GetSize()), dictionary(reinterpret_cast(allocated_dictionary.get())), full(false) { - // Initialize empty - for (idx_t i = 0; i < capacity; i++) { - dictionary[i].index = INVALID_INDEX; - } + Clear(); } public: //! Insert value into dictionary (if not full) + template void Insert(SRC value) { if (full) { return; } auto &entry = Lookup(value); if (entry.IsEmpty()) { - if (size + 1 > maximum_size || !AddToTarget(value)) { + if (size + 1 > maximum_size || (ADD_TO_TARGET && !AddToTarget(value))) { full = true; return; } @@ -128,7 +134,13 @@ class PrimitiveDictionary { allocated_target.Reset(); } -private: + void Clear() { + for (idx_t i = 0; i < capacity; i++) { + dictionary[i].index = INVALID_INDEX; + } + size = 0; + full = false; + } //! Look up a value in the dictionary using linear probing primitive_dictionary_entry_t &Lookup(const SRC &value) const { auto offset = Hash(value) & capacity_mask; @@ -138,6 +150,7 @@ class PrimitiveDictionary { return dictionary[offset]; } +private: //! Write a value to the target data (if source is not string) template ::value, int>::type = 0> bool AddToTarget(const SRC &src_value) { @@ -205,7 +218,7 @@ class PrimitiveDictionary { //! Maximum size and current size const idx_t maximum_size; - idx_t size; + uint32_t size; //! Dictionary capacity (power of two) and corresponding mask const idx_t capacity; diff --git a/src/duckdb/src/include/duckdb/common/profiler.hpp b/src/duckdb/src/include/duckdb/common/profiler.hpp index 5fb65337a..3a5cb402e 100644 --- a/src/duckdb/src/include/duckdb/common/profiler.hpp +++ b/src/duckdb/src/include/duckdb/common/profiler.hpp @@ -13,36 +13,52 @@ namespace duckdb { -//! The profiler can be used to measure elapsed time +//! Profiler class to measure the elapsed time. template class BaseProfiler { public: - //! Starts the timer + //! Start the timer. void Start() { finished = false; + ran = true; start = Tick(); } - //! Finishes timing + //! End the timer. void End() { end = Tick(); finished = true; } + //! Reset the timer. + void Reset() { + finished = false; + ran = false; + } - //! Returns the elapsed time in seconds. If End() has been called, returns - //! the total elapsed time. Otherwise returns how far along the timer is - //! right now. + //! Returns the elapsed time in seconds. + //! If ran is false, it returns 0. + //! 
If End() has been called, it returns the total elapsed time, + //! otherwise, it returns how far along the timer is right now. double Elapsed() const { + if (!ran) { + return 0; + } auto measured_end = finished ? end : Tick(); return std::chrono::duration_cast>(measured_end - start).count(); } private: + //! Current time point. time_point Tick() const { return T::now(); } + //! Start time point. time_point start; + //! End time point. time_point end; + //! True, if End() has been called. bool finished = false; + //! True, if the timer has run. + bool ran = false; }; using Profiler = BaseProfiler; diff --git a/src/duckdb/src/include/duckdb/common/progress_bar/display/terminal_progress_bar_display.hpp b/src/duckdb/src/include/duckdb/common/progress_bar/display/terminal_progress_bar_display.hpp index ae6be8549..a4ad9444c 100644 --- a/src/duckdb/src/include/duckdb/common/progress_bar/display/terminal_progress_bar_display.hpp +++ b/src/duckdb/src/include/duckdb/common/progress_bar/display/terminal_progress_bar_display.hpp @@ -13,7 +13,6 @@ #include "duckdb/common/unicode_bar.hpp" #include "duckdb/common/progress_bar/unscented_kalman_filter.hpp" #include -#include namespace duckdb { @@ -30,21 +29,26 @@ struct TerminalProgressBarDisplayedProgressInfo { } }; -class TerminalProgressBarDisplay : public ProgressBarDisplay { -private: - UnscentedKalmanFilter ukf; - std::chrono::steady_clock::time_point start_time; - - // track the progress info that has been previously - // displayed to prevent redundant updates - struct TerminalProgressBarDisplayedProgressInfo displayed_progress_info; - - double GetElapsedDuration() { - auto now = std::chrono::steady_clock::now(); - return std::chrono::duration(now - start_time).count(); - } - void StopPeriodicUpdates(); +struct ProgressBarDisplayInfo { + idx_t width = 38; +#ifndef DUCKDB_ASCII_TREE_RENDERER + const char *progress_empty = " "; + const char *const *progress_partial = UnicodeBar::PartialBlocks(); + idx_t partial_block_count = UnicodeBar::PartialBlocksCount(); + const char *progress_block = UnicodeBar::FullBlock(); + const char *progress_start = "\xE2\x96\x95"; + const char *progress_end = "\xE2\x96\x8F"; +#else + const char *progress_empty = " "; + const char *const progress_partial[PARTIAL_BLOCK_COUNT] = {" ", " ", " ", " ", " ", " ", " ", " "}; + idx_t partial_block_count = 8; + const char *progress_block = "="; + const char *progress_start = "["; + const char *progress_end = "]"; +#endif +}; +class TerminalProgressBarDisplay : public ProgressBarDisplay { public: TerminalProgressBarDisplay() { start_time = std::chrono::steady_clock::now(); @@ -57,32 +61,33 @@ class TerminalProgressBarDisplay : public ProgressBarDisplay { public: void Update(double percentage) override; void Finish() override; + static string FormatETA(double seconds, bool elapsed = false); + static string FormatProgressBar(const ProgressBarDisplayInfo &display_info, int32_t percentage); private: - std::mutex mtx; - std::thread periodic_update_thread; - std::condition_variable cv; void PeriodicUpdate(); - static constexpr const idx_t PARTIAL_BLOCK_COUNT = UnicodeBar::PartialBlocksCount(); -#ifndef DUCKDB_ASCII_TREE_RENDERER - const char *PROGRESS_EMPTY = " "; // NOLINT - const char *const *PROGRESS_PARTIAL = UnicodeBar::PartialBlocks(); // NOLINT - const char *PROGRESS_BLOCK = UnicodeBar::FullBlock(); // NOLINT - const char *PROGRESS_START = "\xE2\x96\x95"; // NOLINT - const char *PROGRESS_END = "\xE2\x96\x8F"; // NOLINT -#else - const char *PROGRESS_EMPTY = " "; - const char *const 
PROGRESS_PARTIAL[PARTIAL_BLOCK_COUNT] = {" ", " ", " ", " ", " ", " ", " ", " "}; - const char *PROGRESS_BLOCK = "="; - const char *PROGRESS_START = "["; - const char *PROGRESS_END = "]"; -#endif - static constexpr const idx_t PROGRESS_BAR_WIDTH = 38; +public: + ProgressBarDisplayInfo display_info; + +protected: + virtual void PrintProgressInternal(int32_t percentage, double estimated_remaining_seconds, + bool is_finished = false); -private: static int32_t NormalizePercentage(double percentage); - void PrintProgressInternal(int32_t percentage, double estimated_remaining_seconds, bool is_finished = false); + double GetElapsedDuration() { + auto now = std::chrono::steady_clock::now(); + return std::chrono::duration(now - start_time).count(); + } + void StopPeriodicUpdates(); + +private: + UnscentedKalmanFilter ukf; + std::chrono::steady_clock::time_point start_time; + + // track the progress info that has been previously + // displayed to prevent redundant updates + struct TerminalProgressBarDisplayedProgressInfo displayed_progress_info; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/queue.hpp b/src/duckdb/src/include/duckdb/common/queue.hpp index d3e28d982..e768490cc 100644 --- a/src/duckdb/src/include/duckdb/common/queue.hpp +++ b/src/duckdb/src/include/duckdb/common/queue.hpp @@ -8,8 +8,77 @@ #pragma once +#include "duckdb/common/assert.hpp" +#include "duckdb/common/typedefs.hpp" +#include "duckdb/common/likely.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/memory_safety.hpp" #include namespace duckdb { -using std::queue; + +template , bool SAFE = true> +class queue : public std::queue { // NOLINT: matching name of std +public: + using original = std::queue; + using original::original; + using container_type = typename original::container_type; + using value_type = typename original::value_type; + using size_type = typename container_type::size_type; + using reference = typename container_type::reference; + using const_reference = typename container_type::const_reference; + +public: + // Because we create the other constructor, the implicitly created constructor + // gets deleted, so we have to be explicit + queue() = default; + queue(original &&other) : original(std::move(other)) { // NOLINT: allow implicit conversion + } + template + queue(queue &&other) : original(std::move(other)) { // NOLINT + } + + inline void clear() noexcept { + original::c.clear(); + } + + reference front() { + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'front' called on an empty queue!"); + } + return original::front(); + } + + const_reference front() const { + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'front' called on an empty queue!"); + } + return original::front(); + } + + reference back() { + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'back' called on an empty queue!"); + } + return original::back(); + } + + const_reference back() const { + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'back' called on an empty queue!"); + } + return original::back(); + } + + void pop() { + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'pop' called on an empty queue!"); + } + original::pop(); + } +}; + +template > +using unsafe_queue = queue; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/radix_partitioning.hpp b/src/duckdb/src/include/duckdb/common/radix_partitioning.hpp index 
fef1847bd..cd1e23ecd 100644 --- a/src/duckdb/src/include/duckdb/common/radix_partitioning.hpp +++ b/src/duckdb/src/include/duckdb/common/radix_partitioning.hpp @@ -109,7 +109,7 @@ class RadixPartitionedTupleData : public PartitionedTupleData { public: RadixPartitionedTupleData(BufferManager &buffer_manager, shared_ptr layout_ptr, idx_t radix_bits_p, idx_t hash_col_idx_p); - RadixPartitionedTupleData(const RadixPartitionedTupleData &other); + RadixPartitionedTupleData(RadixPartitionedTupleData &other); ~RadixPartitionedTupleData() override; idx_t GetRadixBits() const { diff --git a/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp b/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp index 557e9cd5b..ee1d11afb 100644 --- a/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp +++ b/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp @@ -25,22 +25,6 @@ struct SelectionVector; class StringHeap; struct UnifiedVectorFormat; -// The NestedValidity class help to set/get the validity from inside nested vectors -class NestedValidity { - data_ptr_t list_validity_location; - data_ptr_t *struct_validity_locations; - idx_t entry_idx; - idx_t idx_in_entry; - idx_t list_validity_offset; - -public: - explicit NestedValidity(data_ptr_t validitymask_location); - NestedValidity(data_ptr_t *validitymask_locations, idx_t child_vector_index); - void SetInvalid(idx_t idx); - bool IsValid(idx_t idx); - void OffsetListBy(idx_t offset); -}; - struct RowOperationsState { explicit RowOperationsState(ArenaAllocator &allocator) : allocator(allocator) { } @@ -49,7 +33,7 @@ struct RowOperationsState { unique_ptr addresses; // Re-usable vector for row_aggregate.cpp }; -// RowOperations contains a set of operations that operate on data using a RowLayout +// RowOperations contains a set of operations that operate on data using a TupleDataLayout struct RowOperations { //===--------------------------------------------------------------------===// // Aggregation Operators @@ -70,66 +54,6 @@ struct RowOperations { //! finalize - unaligned addresses, updated static void FinalizeStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, DataChunk &result, idx_t aggr_idx); - - //===--------------------------------------------------------------------===// - // Read/Write Operators - //===--------------------------------------------------------------------===// - //! Scatter group data to the rows. Initialises the ValidityMask. - static void Scatter(DataChunk &columns, UnifiedVectorFormat col_data[], const RowLayout &layout, Vector &rows, - RowDataCollection &string_heap, const SelectionVector &sel, idx_t count); - //! Gather a single column. - //! If heap_ptr is not null, then the data is assumed to contain swizzled pointers, - //! which will be unswizzled in memory. - static void Gather(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, - const idx_t count, const RowLayout &layout, const idx_t col_no, const idx_t build_size = 0, - data_ptr_t heap_ptr = nullptr); - - //===--------------------------------------------------------------------===// - // Heap Operators - //===--------------------------------------------------------------------===// - //! Compute the entry sizes of a vector with variable size type (used before building heap buffer space). - static void ComputeEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, - const SelectionVector &sel, idx_t offset = 0); - //! 
Compute the entry sizes of vector data with variable size type (used before building heap buffer space). - static void ComputeEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t vcount, - idx_t ser_count, const SelectionVector &sel, idx_t offset = 0); - //! Scatter vector with variable size type to the heap. - static void HeapScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, idx_t offset = 0); - //! Scatter vector data with variable size type to the heap. - static void HeapScatterVData(UnifiedVectorFormat &vdata, PhysicalType type, const SelectionVector &sel, - idx_t ser_count, data_ptr_t *key_locations, - optional_ptr parent_validity, idx_t offset = 0); - //! Gather a single column with variable size type from the heap. - static void HeapGather(Vector &v, const idx_t &vcount, const SelectionVector &sel, data_ptr_t key_locations[], - optional_ptr parent_validity); - - //===--------------------------------------------------------------------===// - // Sorting Operators - //===--------------------------------------------------------------------===// - //! Scatter vector data to the rows in radix-sortable format. - static void RadixScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t key_locations[], bool desc, bool has_null, bool nulls_first, idx_t prefix_len, - idx_t width, idx_t offset = 0); - - //===--------------------------------------------------------------------===// - // Out-of-Core Operators - //===--------------------------------------------------------------------===// - //! Swizzles blob pointers to offset within heap row - static void SwizzleColumns(const RowLayout &layout, const data_ptr_t base_row_ptr, const idx_t count); - //! Swizzles the base pointer of each row to offset within heap block - static void SwizzleHeapPointer(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - const idx_t count, const idx_t base_offset = 0); - //! Copies 'count' heap rows that are pointed to by the rows at 'row_ptr' to 'heap_ptr' and swizzles the pointers - static void CopyHeapAndSwizzle(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - data_ptr_t heap_ptr, const idx_t count); - - //! Unswizzles the base offset within heap block the rows to pointers - static void UnswizzleHeapPointer(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count); - //! 
Unswizzles all offsets back to pointers - static void UnswizzlePointers(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/encoding_util.hpp b/src/duckdb/src/include/duckdb/common/serializer/encoding_util.hpp index f30cf5790..5248b04a6 100644 --- a/src/duckdb/src/include/duckdb/common/serializer/encoding_util.hpp +++ b/src/duckdb/src/include/duckdb/common/serializer/encoding_util.hpp @@ -14,7 +14,6 @@ namespace duckdb { struct EncodingUtil { - // Encode unsigned integer, returns the number of bytes written template static idx_t EncodeUnsignedLEB128(data_ptr_t target, T value) { diff --git a/src/duckdb/src/include/duckdb/common/serializer/serialization_traits.hpp b/src/duckdb/src/include/duckdb/common/serializer/serialization_traits.hpp index 5bde0f9a1..bdb82b0c9 100644 --- a/src/duckdb/src/include/duckdb/common/serializer/serialization_traits.hpp +++ b/src/duckdb/src/include/duckdb/common/serializer/serialization_traits.hpp @@ -186,7 +186,6 @@ struct is_atomic> : std::true_type { // NOLINTEND struct SerializationDefaultValue { - template static inline typename std::enable_if::value, T>::type GetDefault() { using INNER = typename is_atomic::TYPE; diff --git a/src/duckdb/src/include/duckdb/common/serializer/varint.hpp b/src/duckdb/src/include/duckdb/common/serializer/varint.hpp index 8d0316a32..8cccd6f56 100644 --- a/src/duckdb/src/include/duckdb/common/serializer/varint.hpp +++ b/src/duckdb/src/include/duckdb/common/serializer/varint.hpp @@ -35,7 +35,8 @@ uint8_t GetVarintSize(T val) { } template -void VarintEncode(T val, data_ptr_t ptr) { +idx_t VarintEncode(T val, data_ptr_t ptr) { + idx_t size = 0; do { uint8_t byte = val & 127; val >>= 7; @@ -44,11 +45,14 @@ void VarintEncode(T val, data_ptr_t ptr) { } *ptr = byte; ptr++; + size++; } while (val != 0); + return size; } template -void VarintEncode(T val, MemoryStream &ser) { +idx_t VarintEncode(T val, MemoryStream &ser) { + idx_t size = 0; do { uint8_t byte = val & 127; val >>= 7; @@ -56,7 +60,9 @@ void VarintEncode(T val, MemoryStream &ser) { byte |= 128; } ser.WriteData(&byte, sizeof(uint8_t)); + size++; } while (val != 0); + return size; } } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sort/comparators.hpp b/src/duckdb/src/include/duckdb/common/sort/comparators.hpp deleted file mode 100644 index 5f3cd3807..000000000 --- a/src/duckdb/src/include/duckdb/common/sort/comparators.hpp +++ /dev/null @@ -1,65 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/sort/comparators.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/types.hpp" -#include "duckdb/common/types/row/row_layout.hpp" - -namespace duckdb { - -struct SortLayout; -struct SBScanState; - -using ValidityBytes = RowLayout::ValidityBytes; - -struct Comparators { -public: - //! Whether a tie between two blobs can be broken - static bool TieIsBreakable(const idx_t &col_idx, const data_ptr_t &row_ptr, const SortLayout &sort_layout); - //! Compares the tuples that a being read from in the 'left' and 'right blocks during merge sort - //! 
(only in case we cannot simply 'memcmp' - if there are blob columns) - static int CompareTuple(const SBScanState &left, const SBScanState &right, const data_ptr_t &l_ptr, - const data_ptr_t &r_ptr, const SortLayout &sort_layout, const bool &external_sort); - //! Compare two blob values - static int CompareVal(const data_ptr_t l_ptr, const data_ptr_t r_ptr, const LogicalType &type); - -private: - //! Compares two blob values that were initially tied by their prefix - static int BreakBlobTie(const idx_t &tie_col, const SBScanState &left, const SBScanState &right, - const SortLayout &sort_layout, const bool &external); - //! Compare two fixed-size values - template - static int TemplatedCompareVal(const data_ptr_t &left_ptr, const data_ptr_t &right_ptr); - - //! Compare two values at the pointers (can be recursive if nested type) - static int CompareValAndAdvance(data_ptr_t &l_ptr, data_ptr_t &r_ptr, const LogicalType &type, bool valid); - //! Compares two fixed-size values at the given pointers - template - static int TemplatedCompareAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr); - //! Compares two string values at the given pointers - static int CompareStringAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, bool valid); - //! Compares two struct values at the given pointers (recursive) - static int CompareStructAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, - const child_list_t &types, bool valid); - static int CompareArrayAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, bool valid, - idx_t array_size); - //! Compare two list values at the pointers (can be recursive if nested type) - static int CompareListAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, bool valid); - //! Compares a list of fixed-size values - template - static int TemplatedCompareListLoop(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const ValidityBytes &left_validity, - const ValidityBytes &right_validity, const idx_t &count); - - //! Unwizzles an offset into a pointer - static void UnswizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type); - //! Swizzles a pointer into an offset - static void SwizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type); -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sort/duckdb_pdqsort.hpp b/src/duckdb/src/include/duckdb/common/sort/duckdb_pdqsort.hpp deleted file mode 100644 index c935a713a..000000000 --- a/src/duckdb/src/include/duckdb/common/sort/duckdb_pdqsort.hpp +++ /dev/null @@ -1,710 +0,0 @@ -/* -pdqsort.h - Pattern-defeating quicksort. - -Copyright (c) 2021 Orson Peters - -This software is provided 'as-is', without any express or implied warranty. In no event will the -authors be held liable for any damages arising from the use of this software. - -Permission is granted to anyone to use this software for any purpose, including commercial -applications, and to alter it and redistribute it freely, subject to the following restrictions: - -1. The origin of this software must not be misrepresented; you must not claim that you wrote the - original software. If you use this software in a product, an acknowledgment in the product - documentation would be appreciated but is not required. - -2. Altered source versions must be plainly marked as such, and must not be misrepresented as - being the original software. - -3. This notice may not be removed or altered from any source distribution. 
-*/ - -#pragma once - -#include "duckdb/common/constants.hpp" -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/unique_ptr.hpp" - -#include -#include -#include -#include -#include - -namespace duckdb_pdqsort { - -using duckdb::data_ptr_t; -using duckdb::data_t; -using duckdb::FastMemcmp; -using duckdb::FastMemcpy; -using duckdb::idx_t; -using duckdb::make_unsafe_uniq_array_uninitialized; -using duckdb::unique_ptr; -using duckdb::unsafe_unique_array; - -// NOLINTBEGIN - -enum { - // Partitions below this size are sorted using insertion sort. - insertion_sort_threshold = 24, - - // Partitions above this size use Tukey's ninther to select the pivot. - ninther_threshold = 128, - - // When we detect an already sorted partition, attempt an insertion sort that allows this - // amount of element moves before giving up. - partial_insertion_sort_limit = 8, - - // Must be multiple of 8 due to loop unrolling, and < 256 to fit in unsigned char. - block_size = 64, - - // Cacheline size, assumes power of two. - cacheline_size = 64 - -}; - -// Returns floor(log2(n)), assumes n > 0. -template -inline int log2(T n) { - int log = 0; - while (n >>= 1) { - ++log; - } - return log; -} - -struct PDQConstants { - PDQConstants(idx_t entry_size, idx_t comp_offset, idx_t comp_size, data_ptr_t end) - : entry_size(entry_size), comp_offset(comp_offset), comp_size(comp_size), - tmp_buf_ptr(make_unsafe_uniq_array_uninitialized(entry_size)), tmp_buf(tmp_buf_ptr.get()), - iter_swap_buf_ptr(make_unsafe_uniq_array_uninitialized(entry_size)), - iter_swap_buf(iter_swap_buf_ptr.get()), - swap_offsets_buf_ptr(make_unsafe_uniq_array_uninitialized(entry_size)), - swap_offsets_buf(swap_offsets_buf_ptr.get()), end(end) { - } - - const duckdb::idx_t entry_size; - const idx_t comp_offset; - const idx_t comp_size; - - unsafe_unique_array tmp_buf_ptr; - const data_ptr_t tmp_buf; - - unsafe_unique_array iter_swap_buf_ptr; - const data_ptr_t iter_swap_buf; - - unsafe_unique_array swap_offsets_buf_ptr; - const data_ptr_t swap_offsets_buf; - - const data_ptr_t end; -}; - -struct PDQIterator { - PDQIterator(data_ptr_t ptr, const idx_t &entry_size) : ptr(ptr), entry_size(entry_size) { - } - - inline PDQIterator(const PDQIterator &other) : ptr(other.ptr), entry_size(other.entry_size) { - } - - inline const data_ptr_t &operator*() const { - return ptr; - } - - inline PDQIterator &operator++() { - ptr += entry_size; - return *this; - } - - inline PDQIterator &operator--() { - ptr -= entry_size; - return *this; - } - - inline PDQIterator operator++(int) { - auto tmp = *this; - ptr += entry_size; - return tmp; - } - - inline PDQIterator operator--(int) { - auto tmp = *this; - ptr -= entry_size; - return tmp; - } - - inline PDQIterator operator+(const idx_t &i) const { - auto result = *this; - result.ptr += i * entry_size; - return result; - } - - inline PDQIterator operator-(const idx_t &i) const { - PDQIterator result = *this; - result.ptr -= i * entry_size; - return result; - } - - inline PDQIterator &operator=(const PDQIterator &other) { - D_ASSERT(entry_size == other.entry_size); - ptr = other.ptr; - return *this; - } - - inline friend idx_t operator-(const PDQIterator &lhs, const PDQIterator &rhs) { - D_ASSERT(duckdb::NumericCast(*lhs - *rhs) % lhs.entry_size == 0); - D_ASSERT(*lhs - *rhs >= 0); - return duckdb::NumericCast(*lhs - *rhs) / lhs.entry_size; - } - - inline friend bool operator<(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs < 
*rhs; - } - - inline friend bool operator>(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs > *rhs; - } - - inline friend bool operator>=(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs >= *rhs; - } - - inline friend bool operator<=(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs <= *rhs; - } - - inline friend bool operator==(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs == *rhs; - } - - inline friend bool operator!=(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs != *rhs; - } - -private: - data_ptr_t ptr; - const idx_t &entry_size; -}; - -static inline bool comp(const data_ptr_t &l, const data_ptr_t &r, const PDQConstants &constants) { - D_ASSERT(l == constants.tmp_buf || l == constants.swap_offsets_buf || l < constants.end); - D_ASSERT(r == constants.tmp_buf || r == constants.swap_offsets_buf || r < constants.end); - return FastMemcmp(l + constants.comp_offset, r + constants.comp_offset, constants.comp_size) < 0; -} - -static inline const data_ptr_t &GET_TMP(const data_ptr_t &src, const PDQConstants &constants) { - D_ASSERT(src != constants.tmp_buf && src != constants.swap_offsets_buf && src < constants.end); - FastMemcpy(constants.tmp_buf, src, constants.entry_size); - return constants.tmp_buf; -} - -static inline const data_ptr_t &SWAP_OFFSETS_GET_TMP(const data_ptr_t &src, const PDQConstants &constants) { - D_ASSERT(src != constants.tmp_buf && src != constants.swap_offsets_buf && src < constants.end); - FastMemcpy(constants.swap_offsets_buf, src, constants.entry_size); - return constants.swap_offsets_buf; -} - -static inline void MOVE(const data_ptr_t &dest, const data_ptr_t &src, const PDQConstants &constants) { - D_ASSERT(dest == constants.tmp_buf || dest == constants.swap_offsets_buf || dest < constants.end); - D_ASSERT(src == constants.tmp_buf || src == constants.swap_offsets_buf || src < constants.end); - FastMemcpy(dest, src, constants.entry_size); -} - -static inline void iter_swap(const PDQIterator &lhs, const PDQIterator &rhs, const PDQConstants &constants) { - D_ASSERT(*lhs < constants.end); - D_ASSERT(*rhs < constants.end); - FastMemcpy(constants.iter_swap_buf, *lhs, constants.entry_size); - FastMemcpy(*lhs, *rhs, constants.entry_size); - FastMemcpy(*rhs, constants.iter_swap_buf, constants.entry_size); -} - -// Sorts [begin, end) using insertion sort with the given comparison function. -inline void insertion_sort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return; - } - - for (PDQIterator cur = begin + 1; cur != end; ++cur) { - PDQIterator sift = cur; - PDQIterator sift_1 = cur - 1; - - // Compare first so we can avoid 2 moves for an element already positioned correctly. - if (comp(*sift, *sift_1, constants)) { - const auto &tmp = GET_TMP(*sift, constants); - - do { - MOVE(*sift--, *sift_1, constants); - } while (sift != begin && comp(tmp, *--sift_1, constants)); - - MOVE(*sift, tmp, constants); - } - } -} - -// Sorts [begin, end) using insertion sort with the given comparison function. Assumes -// *(begin - 1) is an element smaller than or equal to any element in [begin, end). 
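For intuition, a minimal standalone sketch of the same idea on plain ints (not the PDQIterator / byte-wise comparator machinery used in this header): because the caller guarantees a sentinel element before `begin`, the inner shifting loop can drop the `sift != begin` bound check that the guarded variant above needs.

void unguarded_insertion_sort_ints(int *begin, int *end) {
	// Caller guarantees *(begin - 1) exists and is <= every element in [begin, end).
	for (int *cur = begin + 1; cur < end; ++cur) {
		int tmp = *cur;
		int *sift = cur;
		// No "sift != begin" test: the sentinel before begin is guaranteed to stop the loop.
		while (tmp < *(sift - 1)) {
			*sift = *(sift - 1);
			--sift;
		}
		*sift = tmp;
	}
}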
-inline void unguarded_insertion_sort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return; - } - - for (PDQIterator cur = begin + 1; cur != end; ++cur) { - PDQIterator sift = cur; - PDQIterator sift_1 = cur - 1; - - // Compare first so we can avoid 2 moves for an element already positioned correctly. - if (comp(*sift, *sift_1, constants)) { - const auto &tmp = GET_TMP(*sift, constants); - - do { - MOVE(*sift--, *sift_1, constants); - } while (comp(tmp, *--sift_1, constants)); - - MOVE(*sift, tmp, constants); - } - } -} - -// Attempts to use insertion sort on [begin, end). Will return false if more than -// partial_insertion_sort_limit elements were moved, and abort sorting. Otherwise it will -// successfully sort and return true. -inline bool partial_insertion_sort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return true; - } - - std::size_t limit = 0; - for (PDQIterator cur = begin + 1; cur != end; ++cur) { - PDQIterator sift = cur; - PDQIterator sift_1 = cur - 1; - - // Compare first so we can avoid 2 moves for an element already positioned correctly. - if (comp(*sift, *sift_1, constants)) { - const auto &tmp = GET_TMP(*sift, constants); - - do { - MOVE(*sift--, *sift_1, constants); - } while (sift != begin && comp(tmp, *--sift_1, constants)); - - MOVE(*sift, tmp, constants); - limit += cur - sift; - } - - if (limit > partial_insertion_sort_limit) { - return false; - } - } - - return true; -} - -inline void sort2(const PDQIterator &a, const PDQIterator &b, const PDQConstants &constants) { - if (comp(*b, *a, constants)) { - iter_swap(a, b, constants); - } -} - -// Sorts the elements *a, *b and *c using comparison function comp. -inline void sort3(const PDQIterator &a, const PDQIterator &b, const PDQIterator &c, const PDQConstants &constants) { - sort2(a, b, constants); - sort2(b, c, constants); - sort2(a, b, constants); -} - -template -inline T *align_cacheline(T *p) { -#if defined(UINTPTR_MAX) && __cplusplus >= 201103L - std::uintptr_t ip = reinterpret_cast(p); -#else - std::size_t ip = reinterpret_cast(p); -#endif - ip = (ip + cacheline_size - 1) & -duckdb::UnsafeNumericCast(cacheline_size); - return reinterpret_cast(ip); -} - -inline void swap_offsets(const PDQIterator &first, const PDQIterator &last, unsigned char *offsets_l, - unsigned char *offsets_r, size_t num, bool use_swaps, const PDQConstants &constants) { - if (use_swaps) { - // This case is needed for the descending distribution, where we need - // to have proper swapping for pdqsort to remain O(n). - for (size_t i = 0; i < num; ++i) { - iter_swap(first + offsets_l[i], last - offsets_r[i], constants); - } - } else if (num > 0) { - PDQIterator l = first + offsets_l[0]; - PDQIterator r = last - offsets_r[0]; - const auto &tmp = SWAP_OFFSETS_GET_TMP(*l, constants); - MOVE(*l, *r, constants); - for (size_t i = 1; i < num; ++i) { - l = first + offsets_l[i]; - MOVE(*r, *l, constants); - r = last - offsets_r[i]; - MOVE(*l, *r, constants); - } - MOVE(*r, tmp, constants); - } -} - -// Partitions [begin, end) around pivot *begin using comparison function comp. Elements equal -// to the pivot are put in the right-hand partition. Returns the position of the pivot after -// partitioning and whether the passed sequence already was correctly partitioned. Assumes the -// pivot is a median of at least 3 elements and that [begin, end) is at least -// insertion_sort_threshold long. Uses branchless partitioning. 
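The branch-free part of the scheme below (derived from BlockQuicksort, as the in-code comment notes) is easiest to see on plain ints. This is a hedged sketch of the idea rather than the actual implementation: candidate offsets are stored unconditionally and only the counter advances on a "wrong side" comparison, so misclassified slots are simply overwritten instead of being handled by a branch.

#include <cstddef>

// Record the offsets of up to `block` left-side elements that are >= pivot, branch-free.
inline std::size_t fill_left_offsets(const int *first, int pivot, unsigned char *offsets, std::size_t block) {
	std::size_t num = 0;
	for (std::size_t i = 0; i < block; ++i) {
		offsets[num] = static_cast<unsigned char>(i);          // unconditional store
		num += static_cast<std::size_t>(!(first[i] < pivot));  // advance only if element belongs on the right
	}
	return num; // number of valid offsets, i.e. elements that must be swapped to the right side
}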
-inline std::pair partition_right_branchless(const PDQIterator &begin, const PDQIterator &end, - const PDQConstants &constants) { - // Move pivot into local for speed. - const auto &pivot = GET_TMP(*begin, constants); - PDQIterator first = begin; - PDQIterator last = end; - - // Find the first element greater than or equal than the pivot (the median of 3 guarantees - // this exists). - while (comp(*++first, pivot, constants)) { - } - - // Find the first element strictly smaller than the pivot. We have to guard this search if - // there was no element before *first. - if (first - 1 == begin) { - while (first < last && !comp(*--last, pivot, constants)) { - } - } else { - while (!comp(*--last, pivot, constants)) { - } - } - - // If the first pair of elements that should be swapped to partition are the same element, - // the passed in sequence already was correctly partitioned. - bool already_partitioned = first >= last; - if (!already_partitioned) { - iter_swap(first, last, constants); - ++first; - - // The following branchless partitioning is derived from "BlockQuicksort: How Branch - // Mispredictions don’t affect Quicksort" by Stefan Edelkamp and Armin Weiss, but - // heavily micro-optimized. - unsigned char offsets_l_storage[block_size + cacheline_size]; - unsigned char offsets_r_storage[block_size + cacheline_size]; - unsigned char *offsets_l = align_cacheline(offsets_l_storage); - unsigned char *offsets_r = align_cacheline(offsets_r_storage); - - PDQIterator offsets_l_base = first; - PDQIterator offsets_r_base = last; - size_t num_l, num_r, start_l, start_r; - num_l = num_r = start_l = start_r = 0; - - while (first < last) { - // Fill up offset blocks with elements that are on the wrong side. - // First we determine how much elements are considered for each offset block. - size_t num_unknown = last - first; - size_t left_split = num_l == 0 ? (num_r == 0 ? num_unknown / 2 : num_unknown) : 0; - size_t right_split = num_r == 0 ? (num_unknown - left_split) : 0; - - // Fill the offset blocks. 
- if (left_split >= block_size) { - for (unsigned char i = 0; i < block_size;) { - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - } - } else { - for (unsigned char i = 0; i < left_split;) { - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - } - } - - if (right_split >= block_size) { - for (unsigned char i = 0; i < block_size;) { - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - } - } else { - for (unsigned char i = 0; i < right_split;) { - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - } - } - - // Swap elements and update block sizes and first/last boundaries. - size_t num = std::min(num_l, num_r); - swap_offsets(offsets_l_base, offsets_r_base, offsets_l + start_l, offsets_r + start_r, num, num_l == num_r, - constants); - num_l -= num; - num_r -= num; - start_l += num; - start_r += num; - - if (num_l == 0) { - start_l = 0; - offsets_l_base = first; - } - - if (num_r == 0) { - start_r = 0; - offsets_r_base = last; - } - } - - // We have now fully identified [first, last)'s proper position. Swap the last elements. - if (num_l) { - offsets_l += start_l; - while (num_l--) { - iter_swap(offsets_l_base + offsets_l[num_l], --last, constants); - } - first = last; - } - if (num_r) { - offsets_r += start_r; - while (num_r--) { - iter_swap(offsets_r_base - offsets_r[num_r], first, constants), ++first; - } - last = first; - } - } - - // Put the pivot in the right place. - PDQIterator pivot_pos = first - 1; - MOVE(*begin, *pivot_pos, constants); - MOVE(*pivot_pos, pivot, constants); - - return std::make_pair(pivot_pos, already_partitioned); -} - -// Partitions [begin, end) around pivot *begin using comparison function comp. Elements equal -// to the pivot are put in the right-hand partition. Returns the position of the pivot after -// partitioning and whether the passed sequence already was correctly partitioned. Assumes the -// pivot is a median of at least 3 elements and that [begin, end) is at least -// insertion_sort_threshold long. -inline std::pair partition_right(const PDQIterator &begin, const PDQIterator &end, - const PDQConstants &constants) { - // Move pivot into local for speed. - const auto &pivot = GET_TMP(*begin, constants); - - PDQIterator first = begin; - PDQIterator last = end; - - // Find the first element greater than or equal than the pivot (the median of 3 guarantees - // this exists). 
- while (comp(*++first, pivot, constants)) { - } - - // Find the first element strictly smaller than the pivot. We have to guard this search if - // there was no element before *first. - if (first - 1 == begin) { - while (first < last && !comp(*--last, pivot, constants)) { - } - } else { - while (!comp(*--last, pivot, constants)) { - } - } - - // If the first pair of elements that should be swapped to partition are the same element, - // the passed in sequence already was correctly partitioned. - bool already_partitioned = first >= last; - - // Keep swapping pairs of elements that are on the wrong side of the pivot. Previously - // swapped pairs guard the searches, which is why the first iteration is special-cased - // above. - while (first < last) { - iter_swap(first, last, constants); - while (comp(*++first, pivot, constants)) { - } - while (!comp(*--last, pivot, constants)) { - } - } - - // Put the pivot in the right place. - PDQIterator pivot_pos = first - 1; - MOVE(*begin, *pivot_pos, constants); - MOVE(*pivot_pos, pivot, constants); - - return std::make_pair(pivot_pos, already_partitioned); -} - -// Similar function to the one above, except elements equal to the pivot are put to the left of -// the pivot and it doesn't check or return if the passed sequence already was partitioned. -// Since this is rarely used (the many equal case), and in that case pdqsort already has O(n) -// performance, no block quicksort is applied here for simplicity. -inline PDQIterator partition_left(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - const auto &pivot = GET_TMP(*begin, constants); - PDQIterator first = begin; - PDQIterator last = end; - - while (comp(pivot, *--last, constants)) { - } - - if (last + 1 == end) { - while (first < last && !comp(pivot, *++first, constants)) { - } - } else { - while (!comp(pivot, *++first, constants)) { - } - } - - while (first < last) { - iter_swap(first, last, constants); - while (comp(pivot, *--last, constants)) { - } - while (!comp(pivot, *++first, constants)) { - } - } - - PDQIterator pivot_pos = last; - MOVE(*begin, *pivot_pos, constants); - MOVE(*pivot_pos, pivot, constants); - - return pivot_pos; -} - -template -inline void pdqsort_loop(PDQIterator begin, const PDQIterator &end, const PDQConstants &constants, int bad_allowed, - bool leftmost = true) { - // Use a while loop for tail recursion elimination. - while (true) { - idx_t size = end - begin; - - // Insertion sort is faster for small arrays. - if (size < insertion_sort_threshold) { - if (leftmost) { - insertion_sort(begin, end, constants); - } else { - unguarded_insertion_sort(begin, end, constants); - } - return; - } - - // Choose pivot as median of 3 or pseudomedian of 9. - idx_t s2 = size / 2; - if (size > ninther_threshold) { - sort3(begin, begin + s2, end - 1, constants); - sort3(begin + 1, begin + (s2 - 1), end - 2, constants); - sort3(begin + 2, begin + (s2 + 1), end - 3, constants); - sort3(begin + (s2 - 1), begin + s2, begin + (s2 + 1), constants); - iter_swap(begin, begin + s2, constants); - } else { - sort3(begin + s2, begin, end - 1, constants); - } - - // If *(begin - 1) is the end of the right partition of a previous partition operation - // there is no element in [begin, end) that is smaller than *(begin - 1). Then if our - // pivot compares equal to *(begin - 1) we change strategy, putting equal elements in - // the left partition, greater elements in the right partition. 
We do not have to - // recurse on the left partition, since it's sorted (all equal). - if (!leftmost && !comp(*(begin - 1), *begin, constants)) { - begin = partition_left(begin, end, constants) + 1; - continue; - } - - // Partition and get results. - std::pair part_result = - Branchless ? partition_right_branchless(begin, end, constants) : partition_right(begin, end, constants); - PDQIterator pivot_pos = part_result.first; - bool already_partitioned = part_result.second; - - // Check for a highly unbalanced partition. - idx_t l_size = pivot_pos - begin; - idx_t r_size = end - (pivot_pos + 1); - bool highly_unbalanced = l_size < size / 8 || r_size < size / 8; - - // If we got a highly unbalanced partition we shuffle elements to break many patterns. - if (highly_unbalanced) { - // If we had too many bad partitions, switch to heapsort to guarantee O(n log n). - // if (--bad_allowed == 0) { - // std::make_heap(begin, end, comp); - // std::sort_heap(begin, end, comp); - // return; - // } - - if (l_size >= insertion_sort_threshold) { - iter_swap(begin, begin + l_size / 4, constants); - iter_swap(pivot_pos - 1, pivot_pos - l_size / 4, constants); - - if (l_size > ninther_threshold) { - iter_swap(begin + 1, begin + (l_size / 4 + 1), constants); - iter_swap(begin + 2, begin + (l_size / 4 + 2), constants); - iter_swap(pivot_pos - 2, pivot_pos - (l_size / 4 + 1), constants); - iter_swap(pivot_pos - 3, pivot_pos - (l_size / 4 + 2), constants); - } - } - - if (r_size >= insertion_sort_threshold) { - iter_swap(pivot_pos + 1, pivot_pos + (1 + r_size / 4), constants); - iter_swap(end - 1, end - r_size / 4, constants); - - if (r_size > ninther_threshold) { - iter_swap(pivot_pos + 2, pivot_pos + (2 + r_size / 4), constants); - iter_swap(pivot_pos + 3, pivot_pos + (3 + r_size / 4), constants); - iter_swap(end - 2, end - (1 + r_size / 4), constants); - iter_swap(end - 3, end - (2 + r_size / 4), constants); - } - } - } else { - // If we were decently balanced and we tried to sort an already partitioned - // sequence try to use insertion sort. - if (already_partitioned && partial_insertion_sort(begin, pivot_pos, constants) && - partial_insertion_sort(pivot_pos + 1, end, constants)) { - return; - } - } - - // Sort the left partition first using recursion and do tail recursion elimination for - // the right-hand partition. 
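		// (The same pattern on a toy quicksort, to make the control flow explicit; `toy_partition`
		//  stands in for any partition routine that returns the final pivot position:
		//
		//      void toy_quicksort(int *lo, int *hi) {        // sorts [lo, hi)
		//          while (hi - lo > 1) {
		//              int *mid = toy_partition(lo, hi);     // pivot lands at *mid
		//              toy_quicksort(lo, mid);               // recurse only on the left part
		//              lo = mid + 1;                         // continue the loop on the right part,
		//          }                                         // i.e. without a new stack frame
		//      }
		//  )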
- pdqsort_loop(begin, pivot_pos, constants, bad_allowed, leftmost); - begin = pivot_pos + 1; - leftmost = false; - } -} - -inline void pdqsort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return; - } - pdqsort_loop(begin, end, constants, log2(end - begin)); -} - -inline void pdqsort_branchless(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return; - } - pdqsort_loop(begin, end, constants, log2(end - begin)); -} -// NOLINTEND - -} // namespace duckdb_pdqsort diff --git a/src/duckdb/src/include/duckdb/common/sort/partition_state.hpp b/src/duckdb/src/include/duckdb/common/sort/partition_state.hpp deleted file mode 100644 index 8170875e8..000000000 --- a/src/duckdb/src/include/duckdb/common/sort/partition_state.hpp +++ /dev/null @@ -1,245 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/sort/partition_state.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/radix_partitioning.hpp" -#include "duckdb/parallel/base_pipeline_event.hpp" - -namespace duckdb { - -class PartitionGlobalHashGroup { -public: - using GlobalSortStatePtr = unique_ptr; - using Orders = vector; - using Types = vector; - using OrderMasks = unordered_map; - - PartitionGlobalHashGroup(ClientContext &context, const Orders &partitions, const Orders &orders, - const Types &payload_types, bool external); - - inline int ComparePartitions(const SBIterator &left, const SBIterator &right) { - int part_cmp = 0; - if (partition_layout.all_constant) { - part_cmp = FastMemcmp(left.entry_ptr, right.entry_ptr, partition_layout.comparison_size); - } else { - part_cmp = Comparators::CompareTuple(left.scan, right.scan, left.entry_ptr, right.entry_ptr, - partition_layout, left.external); - } - return part_cmp; - } - - void ComputeMasks(ValidityMask &partition_mask, OrderMasks &order_masks); - - GlobalSortStatePtr global_sort; - atomic count; - - // Mask computation - SortLayout partition_layout; -}; - -class PartitionGlobalSinkState { -public: - using HashGroupPtr = unique_ptr; - using Orders = vector; - using Types = vector; - - using GroupingPartition = unique_ptr; - using GroupingAppend = unique_ptr; - - static void GenerateOrderings(Orders &partitions, Orders &orders, - const vector> &partition_bys, const Orders &order_bys, - const vector> &partitions_stats); - - PartitionGlobalSinkState(ClientContext &context, const vector> &partition_bys, - const vector &order_bys, const Types &payload_types, - const vector> &partitions_stats, idx_t estimated_cardinality); - virtual ~PartitionGlobalSinkState() = default; - - bool HasMergeTasks() const; - - unique_ptr CreatePartition(idx_t new_bits) const; - void SyncPartitioning(const PartitionGlobalSinkState &other); - - void UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); - void CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); - - virtual void OnBeginMerge() {}; - virtual void OnSortedPartition(const idx_t hash_bin_p) {}; - - ClientContext &context; - BufferManager &buffer_manager; - Allocator &allocator; - mutex lock; - - // OVER(PARTITION BY...) (hash grouping) - unique_ptr grouping_data; - //! Payload plus hash column - shared_ptr grouping_types_ptr; - //! 
The number of radix bits if this partition is being synced with another - idx_t fixed_bits; - - // OVER(...) (sorting) - Orders partitions; - Orders orders; - const Types payload_types; - vector hash_groups; - bool external; - // Reverse lookup from hash bins to non-empty hash groups - vector bin_groups; - - // OVER() (no sorting) - unique_ptr rows; - unique_ptr strings; - - // Threading - idx_t memory_per_thread; - idx_t max_bits; - atomic count; - -private: - void ResizeGroupingData(idx_t cardinality); - void SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); -}; - -class PartitionLocalSinkState { -public: - using LocalSortStatePtr = unique_ptr; - - PartitionLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p); - - // Global state - PartitionGlobalSinkState &gstate; - Allocator &allocator; - - // Shared expression evaluation - ExpressionExecutor executor; - DataChunk group_chunk; - DataChunk payload_chunk; - size_t sort_cols; - - // OVER(PARTITION BY...) (hash grouping) - unique_ptr local_partition; - unique_ptr local_append; - - // OVER(ORDER BY...) (only sorting) - LocalSortStatePtr local_sort; - - // OVER() (no sorting) - RowLayout payload_layout; - unique_ptr rows; - unique_ptr strings; - - //! Compute the hash values - void Hash(DataChunk &input_chunk, Vector &hash_vector); - //! Sink an input chunk - void Sink(DataChunk &input_chunk); - //! Merge the state into the global state. - void Combine(); -}; - -enum class PartitionSortStage : uint8_t { INIT, SCAN, PREPARE, MERGE, SORTED, FINISHED }; - -class PartitionLocalMergeState; - -class PartitionGlobalMergeState { -public: - using GroupDataPtr = unique_ptr; - - // OVER(PARTITION BY...) - PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data, hash_t hash_bin); - - // OVER(ORDER BY...) 
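	// (i.e. the case without PARTITION BY, where all rows are merged as a single group)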
- explicit PartitionGlobalMergeState(PartitionGlobalSinkState &sink); - - bool IsFinished() const { - return stage == PartitionSortStage::FINISHED; - } - - bool AssignTask(PartitionLocalMergeState &local_state); - bool TryPrepareNextStage(); - void CompleteTask(); - - PartitionGlobalSinkState &sink; - GroupDataPtr group_data; - PartitionGlobalHashGroup *hash_group; - const idx_t group_idx; - vector column_ids; - TupleDataParallelScanState chunk_state; - GlobalSortState *global_sort; - const idx_t memory_per_thread; - const idx_t num_threads; - -private: - mutable mutex lock; - atomic stage; - idx_t total_tasks; - idx_t tasks_assigned; - idx_t tasks_completed; -}; - -class PartitionLocalMergeState { -public: - explicit PartitionLocalMergeState(PartitionGlobalSinkState &gstate); - - bool TaskFinished() { - return finished; - } - - void Prepare(); - void Scan(); - void Merge(); - void Sorted(); - - void ExecuteTask(); - - PartitionGlobalMergeState *merge_state; - PartitionSortStage stage; - atomic finished; - - // Sorting buffers - ExpressionExecutor executor; - DataChunk sort_chunk; - DataChunk payload_chunk; -}; - -class PartitionGlobalMergeStates { -public: - struct Callback { - virtual ~Callback() = default; - - virtual bool HasError() const { - return false; - } - }; - - using PartitionGlobalMergeStatePtr = unique_ptr; - - explicit PartitionGlobalMergeStates(PartitionGlobalSinkState &sink); - - bool ExecuteTask(PartitionLocalMergeState &local_state, Callback &callback); - - vector states; -}; - -class PartitionMergeEvent : public BasePipelineEvent { -public: - PartitionMergeEvent(PartitionGlobalSinkState &gstate_p, Pipeline &pipeline_p, const PhysicalOperator &op_p) - : BasePipelineEvent(pipeline_p), gstate(gstate_p), merge_states(gstate_p), op(op_p) { - } - - PartitionGlobalSinkState &gstate; - PartitionGlobalMergeStates merge_states; - const PhysicalOperator &op; - -public: - void Schedule() override; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sort/sort.hpp b/src/duckdb/src/include/duckdb/common/sort/sort.hpp deleted file mode 100644 index 188ea2127..000000000 --- a/src/duckdb/src/include/duckdb/common/sort/sort.hpp +++ /dev/null @@ -1,290 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/sort/sort.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/sort/sorted_block.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/planner/bound_query_node.hpp" - -namespace duckdb { - -class RowLayout; -struct LocalSortState; - -struct SortConstants { - static constexpr idx_t VALUES_PER_RADIX = 256; - static constexpr idx_t MSD_RADIX_LOCATIONS = VALUES_PER_RADIX + 1; - static constexpr idx_t INSERTION_SORT_THRESHOLD = 24; - static constexpr idx_t MSD_RADIX_SORT_SIZE_THRESHOLD = 4; -}; - -struct SortLayout { -public: - SortLayout() { - } - explicit SortLayout(const vector &orders); - SortLayout GetPrefixComparisonLayout(idx_t num_prefix_cols) const; - -public: - idx_t column_count; - vector order_types; - vector order_by_null_types; - vector logical_types; - - bool all_constant; - vector constant_size; - vector column_sizes; - vector prefix_lengths; - vector stats; - vector has_null; - - idx_t comparison_size; - idx_t entry_size; - - RowLayout blob_layout; - unordered_map sorting_to_blob_col; -}; - -struct GlobalSortState { -public: - GlobalSortState(ClientContext &context, const 
vector &orders, RowLayout &payload_layout); - - //! Add local state sorted data to this global state - void AddLocalState(LocalSortState &local_sort_state); - //! Prepares the GlobalSortState for the merge sort phase (after completing radix sort phase) - void PrepareMergePhase(); - //! Initializes the global sort state for another round of merging - void InitializeMergeRound(); - //! Completes the cascaded merge sort round. - //! Pass true if you wish to use the radix data for further comparisons. - void CompleteMergeRound(bool keep_radix_data = false); - //! Print the sorted data to the console. - void Print(); - -public: - //! The client context - ClientContext &context; - //! The lock for updating the order global state - mutex lock; - //! The buffer manager - BufferManager &buffer_manager; - - //! Sorting and payload layouts - const SortLayout sort_layout; - const RowLayout payload_layout; - - //! Sorted data - vector> sorted_blocks; - vector>> sorted_blocks_temp; - unique_ptr odd_one_out; - - //! Pinned heap data (if sorting in memory) - vector> heap_blocks; - vector pinned_blocks; - - //! Capacity (number of rows) used to initialize blocks - idx_t block_capacity; - //! Whether we are doing an external sort - bool external; - - //! Progress in merge path stage - idx_t pair_idx; - idx_t num_pairs; - idx_t l_start; - idx_t r_start; -}; - -struct LocalSortState { -public: - LocalSortState(); - - //! Initialize the layouts and RowDataCollections - void Initialize(GlobalSortState &global_sort_state, BufferManager &buffer_manager_p); - //! Sink one DataChunk into the local sort state - void SinkChunk(DataChunk &sort, DataChunk &payload); - //! Size of accumulated data in bytes - idx_t SizeInBytes() const; - //! Sort the data accumulated so far - void Sort(GlobalSortState &global_sort_state, bool reorder_heap); - //! Concatenate the blocks held by a RowDataCollection into a single block - static unique_ptr ConcatenateBlocks(RowDataCollection &row_data); - -private: - //! Sorts the data in the newly created SortedBlock - void SortInMemory(); - //! Re-order the local state after sorting - void ReOrder(GlobalSortState &gstate, bool reorder_heap); - //! Re-order a SortedData object after sorting - void ReOrder(SortedData &sd, data_ptr_t sorting_ptr, RowDataCollection &heap, GlobalSortState &gstate, - bool reorder_heap); - -public: - //! Whether this local state has been initialized - bool initialized; - //! The buffer manager - BufferManager *buffer_manager; - //! The sorting and payload layouts - const SortLayout *sort_layout; - const RowLayout *payload_layout; - //! Radix/memcmp sortable data - unique_ptr radix_sorting_data; - //! Variable sized sorting data and accompanying heap - unique_ptr blob_sorting_data; - unique_ptr blob_sorting_heap; - //! Payload data and accompanying heap - unique_ptr payload_data; - unique_ptr payload_heap; - //! Sorted data - vector> sorted_blocks; - -private: - //! Selection vector and addresses for scattering the data to rows - const SelectionVector &sel_ptr = *FlatVector::IncrementalSelectionVector(); - Vector addresses = Vector(LogicalType::POINTER); -}; - -struct MergeSorter { -public: - MergeSorter(GlobalSortState &state, BufferManager &buffer_manager); - - //! Finds and merges partitions until the current cascaded merge round is finished - void PerformInMergeRound(); - -private: - //! The global sorting state - GlobalSortState &state; - //! The sorting and payload layouts - BufferManager &buffer_manager; - const SortLayout &sort_layout; - - //! 
The left and right reader - unique_ptr left; - unique_ptr right; - - //! Input and output blocks - unique_ptr left_input; - unique_ptr right_input; - SortedBlock *result; - -private: - //! Computes the left and right block that will be merged next (Merge Path partition) - void GetNextPartition(); - //! Finds the boundary of the next partition using binary search - void GetIntersection(const idx_t diagonal, idx_t &l_idx, idx_t &r_idx); - //! Compare values within SortedBlocks using a global index - int CompareUsingGlobalIndex(SBScanState &l, SBScanState &r, const idx_t l_idx, const idx_t r_idx); - - //! Finds the next partition and merges it - void MergePartition(); - - //! Computes how the next 'count' tuples should be merged by setting the 'left_smaller' array - void ComputeMerge(const idx_t &count, bool left_smaller[]); - - //! Merges the radix sorting blocks according to the 'left_smaller' array - void MergeRadix(const idx_t &count, const bool left_smaller[]); - //! Merges SortedData according to the 'left_smaller' array - void MergeData(SortedData &result_data, SortedData &l_data, SortedData &r_data, const idx_t &count, - const bool left_smaller[], idx_t next_entry_sizes[], bool reset_indices); - //! Merges constant size rows according to the 'left_smaller' array - void MergeRows(data_ptr_t &l_ptr, idx_t &l_entry_idx, const idx_t &l_count, data_ptr_t &r_ptr, idx_t &r_entry_idx, - const idx_t &r_count, RowDataBlock &target_block, data_ptr_t &target_ptr, const idx_t &entry_size, - const bool left_smaller[], idx_t &copied, const idx_t &count); - //! Flushes constant size rows into the result - void FlushRows(data_ptr_t &source_ptr, idx_t &source_entry_idx, const idx_t &source_count, - RowDataBlock &target_block, data_ptr_t &target_ptr, const idx_t &entry_size, idx_t &copied, - const idx_t &count); - //! 
Flushes blob rows and accompanying heap - void FlushBlobs(const RowLayout &layout, const idx_t &source_count, data_ptr_t &source_data_ptr, - idx_t &source_entry_idx, data_ptr_t &source_heap_ptr, RowDataBlock &target_data_block, - data_ptr_t &target_data_ptr, RowDataBlock &target_heap_block, BufferHandle &target_heap_handle, - data_ptr_t &target_heap_ptr, idx_t &copied, const idx_t &count); -}; - -struct SBIterator { - static int ComparisonValue(ExpressionType comparison); - - SBIterator(GlobalSortState &gss, ExpressionType comparison, idx_t entry_idx_p = 0); - - inline idx_t GetIndex() const { - return entry_idx; - } - - inline void SetIndex(idx_t entry_idx_p) { - const auto new_block_idx = entry_idx_p / block_capacity; - if (new_block_idx != scan.block_idx) { - scan.SetIndices(new_block_idx, 0); - if (new_block_idx < block_count) { - scan.PinRadix(scan.block_idx); - block_ptr = scan.RadixPtr(); - if (!all_constant) { - scan.PinData(*scan.sb->blob_sorting_data); - } - } - } - - scan.entry_idx = entry_idx_p % block_capacity; - entry_ptr = block_ptr + scan.entry_idx * entry_size; - entry_idx = entry_idx_p; - } - - inline SBIterator &operator++() { - if (++scan.entry_idx < block_capacity) { - entry_ptr += entry_size; - ++entry_idx; - } else { - SetIndex(entry_idx + 1); - } - - return *this; - } - - inline SBIterator &operator--() { - if (scan.entry_idx) { - --scan.entry_idx; - --entry_idx; - entry_ptr -= entry_size; - } else { - SetIndex(entry_idx - 1); - } - - return *this; - } - - inline bool Compare(const SBIterator &other, const SortLayout &prefix) const { - int comp_res; - if (all_constant) { - comp_res = FastMemcmp(entry_ptr, other.entry_ptr, prefix.comparison_size); - } else { - comp_res = Comparators::CompareTuple(scan, other.scan, entry_ptr, other.entry_ptr, prefix, external); - } - - return comp_res <= cmp; - } - - inline bool Compare(const SBIterator &other) const { - return Compare(other, sort_layout); - } - - // Fixed comparison parameters - const SortLayout &sort_layout; - const idx_t block_count; - const idx_t block_capacity; - const size_t entry_size; - const bool all_constant; - const bool external; - const int cmp; - - // Iteration state - SBScanState scan; - idx_t entry_idx; - data_ptr_t block_ptr; - data_ptr_t entry_ptr; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sort/sorted_block.hpp b/src/duckdb/src/include/duckdb/common/sort/sorted_block.hpp deleted file mode 100644 index b6941bda2..000000000 --- a/src/duckdb/src/include/duckdb/common/sort/sorted_block.hpp +++ /dev/null @@ -1,165 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/sort/sorted_block.hpp -// -// -//===----------------------------------------------------------------------===// -#pragma once - -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/types/row/row_data_collection_scanner.hpp" -#include "duckdb/common/types/row/row_layout.hpp" -#include "duckdb/storage/buffer/buffer_handle.hpp" - -namespace duckdb { - -class BufferManager; -struct RowDataBlock; -struct SortLayout; -struct GlobalSortState; - -enum class SortedDataType { BLOB, PAYLOAD }; - -//! Object that holds sorted rows, and an accompanying heap if there are blobs -struct SortedData { -public: - SortedData(SortedDataType type, const RowLayout &layout, BufferManager &buffer_manager, GlobalSortState &state); - //! Number of rows that this object holds - idx_t Count(); - //! 
Initialize new block to write to - void CreateBlock(); - //! Create a slice that holds the rows between the start and end indices - unique_ptr CreateSlice(idx_t start_block_index, idx_t end_block_index, idx_t end_entry_index); - //! Unswizzles all - void Unswizzle(); - -public: - const SortedDataType type; - //! Layout of this data - const RowLayout layout; - //! Data and heap blocks - vector> data_blocks; - vector> heap_blocks; - //! Whether the pointers in this sorted data are swizzled - bool swizzled; - -private: - //! The buffer manager - BufferManager &buffer_manager; - //! The global state - GlobalSortState &state; -}; - -//! Block that holds sorted rows: radix, blob and payload data -struct SortedBlock { -public: - SortedBlock(BufferManager &buffer_manager, GlobalSortState &gstate); - //! Number of rows that this object holds - idx_t Count() const; - //! Initialize this block to write data to - void InitializeWrite(); - //! Init new block to write to - void CreateBlock(); - //! Fill this sorted block by appending the blocks held by a vector of sorted blocks - void AppendSortedBlocks(vector> &sorted_blocks); - //! Locate the block and entry index of a row in this block, - //! given an index between 0 and the total number of rows in this block - void GlobalToLocalIndex(const idx_t &global_idx, idx_t &local_block_index, idx_t &local_entry_index); - //! Create a slice that holds the rows between the start and end indices - unique_ptr CreateSlice(const idx_t start, const idx_t end, idx_t &entry_idx); - - //! Size (in bytes) of the heap of this block - idx_t HeapSize() const; - //! Total size (in bytes) of this block - idx_t SizeInBytes() const; - -public: - //! Radix/memcmp sortable data - vector> radix_sorting_data; - //! Variable sized sorting data - unique_ptr blob_sorting_data; - //! Payload data - unique_ptr payload_data; - -private: - //! Buffer manager, global state, and sorting layout constants - BufferManager &buffer_manager; - GlobalSortState &state; - const SortLayout &sort_layout; - const RowLayout &payload_layout; -}; - -//! State used to scan a SortedBlock e.g. during merge sort -struct SBScanState { -public: - SBScanState(BufferManager &buffer_manager, GlobalSortState &state); - - void PinRadix(idx_t block_idx_to); - void PinData(SortedData &sd); - - data_ptr_t RadixPtr() const; - data_ptr_t DataPtr(SortedData &sd) const; - data_ptr_t HeapPtr(SortedData &sd) const; - data_ptr_t BaseHeapPtr(SortedData &sd) const; - - idx_t Remaining() const; - - void SetIndices(idx_t block_idx_to, idx_t entry_idx_to); - -public: - BufferManager &buffer_manager; - const SortLayout &sort_layout; - GlobalSortState &state; - - SortedBlock *sb; - - idx_t block_idx; - idx_t entry_idx; - - BufferHandle radix_handle; - - BufferHandle blob_sorting_data_handle; - BufferHandle blob_sorting_heap_handle; - - BufferHandle payload_data_handle; - BufferHandle payload_heap_handle; -}; - -//! Used to scan the data into DataChunks after sorting -struct PayloadScanner { -public: - PayloadScanner(SortedData &sorted_data, GlobalSortState &global_sort_state, bool flush = true); - explicit PayloadScanner(GlobalSortState &global_sort_state, bool flush = true); - - //! Scan a single block - PayloadScanner(GlobalSortState &global_sort_state, idx_t block_idx, bool flush = false); - - //! The type layout of the payload - inline const vector &GetPayloadTypes() const { - return scanner->GetTypes(); - } - - //! The number of rows scanned so far - inline idx_t Scanned() const { - return scanner->Scanned(); - } - - //! 
The number of remaining rows - inline idx_t Remaining() const { - return scanner->Remaining(); - } - - //! Scans the next data chunk from the sorted data - void Scan(DataChunk &chunk); - -private: - //! The sorted data being scanned - unique_ptr rows; - unique_ptr heap; - //! The actual scanner - unique_ptr scanner; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sorting/hashed_sort.hpp b/src/duckdb/src/include/duckdb/common/sorting/hashed_sort.hpp index 374133692..50aeeb55d 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/hashed_sort.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/hashed_sort.hpp @@ -17,6 +17,7 @@ class HashedSort { using Orders = vector; using Types = vector; using HashGroupPtr = unique_ptr; + using SortedRunPtr = unique_ptr; static void GenerateOrderings(Orders &partitions, Orders &orders, const vector> &partition_bys, const Orders &order_bys, @@ -24,7 +25,8 @@ class HashedSort { HashedSort(ClientContext &context, const vector> &partition_bys, const vector &order_bys, const Types &payload_types, - const vector> &partitions_stats, idx_t estimated_cardinality); + const vector> &partitions_stats, idx_t estimated_cardinality, + bool require_payload = false); public: //===--------------------------------------------------------------------===// @@ -37,6 +39,7 @@ class HashedSort { SinkFinalizeType Finalize(ClientContext &client, OperatorSinkFinalizeInput &finalize) const; ProgressData GetSinkProgress(ClientContext &context, GlobalSinkState &gstate, const ProgressData source_progress) const; + void Synchronize(const GlobalSinkState &source, GlobalSinkState &target) const; public: //===--------------------------------------------------------------------===// @@ -49,9 +52,22 @@ class HashedSort { //===--------------------------------------------------------------------===// // Non-Standard Interface //===--------------------------------------------------------------------===// - SinkFinalizeType MaterializeHashGroups(Pipeline &pipeline, Event &event, const PhysicalOperator &op, - OperatorSinkFinalizeInput &finalize) const; - vector &GetHashGroups(GlobalSourceState &global_state) const; + void SortColumnData(ExecutionContext &context, hash_t hash_bin, OperatorSinkFinalizeInput &finalize); + + SourceResultType MaterializeColumnData(ExecutionContext &context, idx_t hash_bin, + OperatorSourceInput &source) const; + HashGroupPtr GetColumnData(idx_t hash_bin, OperatorSourceInput &source) const; + + SourceResultType MaterializeSortedRun(ExecutionContext &context, idx_t hash_bin, OperatorSourceInput &source) const; + SortedRunPtr GetSortedRun(ClientContext &client, idx_t hash_bin, OperatorSourceInput &source) const; + + // The chunk and row counts of the hash groups. + struct ChunkRow { + idx_t chunks = 0; + idx_t count = 0; + }; + using ChunkRows = vector; + const ChunkRows &GetHashGroups(GlobalSourceState &global_state) const; public: ClientContext &client; @@ -63,6 +79,8 @@ class HashedSort { Orders orders; idx_t sort_col_count; Types payload_types; + //! Are we creating a dummy payload column? 
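// Usage sketch (illustrative only; the variable names other than the constructor
// parameters shown above are placeholders, and the dummy-payload behaviour is an
// assumption read off this flag's comment rather than documented semantics):
//
//   // Ask the hashed sort to keep a payload even though we pass no payload
//   // columns of our own; presumably this is what sets the flag below.
//   HashedSort hashed_sort(context, partition_bys, order_bys,
//                          /* payload_types  */ {}, partitions_stats,
//                          estimated_cardinality,
//                          /* require_payload */ true);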
+ bool force_payload = false; // Input columns in the sorted output vector scan_ids; // Key columns in the sorted output diff --git a/src/duckdb/src/include/duckdb/common/sorting/sort.hpp b/src/duckdb/src/include/duckdb/common/sorting/sort.hpp index 597b8261b..de1e33f3b 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/sort.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/sort.hpp @@ -8,25 +8,44 @@ #pragma once -#include "duckdb/common/sorting/sorted_run.hpp" -#include "duckdb/common/types/row/tuple_data_layout.hpp" #include "duckdb/execution/physical_operator_states.hpp" +#include "duckdb/execution/progress_data.hpp" #include "duckdb/common/sorting/sort_projection_column.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" namespace duckdb { class SortLocalSinkState; class SortGlobalSinkState; + class SortLocalSourceState; class SortGlobalSourceState; +class SortedRun; +class SortedRunScanState; + +class SortedRunMerger; +class SortedRunMergerLocalState; +class SortedRunMergerGlobalState; + +class TupleDataLayout; +class ColumnDataCollection; + //! Class that sorts the data, follows the PhysicalOperator interface class Sort { friend class SortLocalSinkState; friend class SortGlobalSinkState; + friend class SortLocalSourceState; friend class SortGlobalSourceState; + friend class SortedRun; + friend class SortedRunScanState; + + friend class SortedRunMerger; + friend class SortedRunMergerLocalState; + friend class SortedRunMergerGlobalState; + public: Sort(ClientContext &context, const vector &orders, const vector &input_types, vector projection_map, bool is_index_sort = false); @@ -45,7 +64,7 @@ class Sort { vector input_projection_map; vector output_projection_columns; - //! Whether to force an external sort + //! Whether to force an approximate sort bool is_index_sort; public: diff --git a/src/duckdb/src/include/duckdb/common/sorting/sort_key.hpp b/src/duckdb/src/include/duckdb/common/sorting/sort_key.hpp index 8d8d86aca..ccabcf21b 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/sort_key.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/sort_key.hpp @@ -45,7 +45,7 @@ struct SortKey; template struct SortKeyNoPayload { protected: - SortKeyNoPayload() = default; + SortKeyNoPayload() = default; // NOLINT friend SORT_KEY; public: @@ -63,7 +63,7 @@ struct SortKeyNoPayload { template struct SortKeyPayload { protected: - SortKeyPayload() = default; + SortKeyPayload() = default; // NOLINT friend SORT_KEY; public: @@ -93,7 +93,7 @@ inline bool SortKeyLessThan<1>(const uint64_t *const &lhs, const uint64_t *const template struct FixedSortKey : std::conditional, SortKeyNoPayload>::type { protected: - FixedSortKey() = default; + FixedSortKey() = default; // NOLINT friend SORT_KEY; public: @@ -163,7 +163,7 @@ struct FixedSortKey : std::conditional, So template struct VariableSortKey : std::conditional, SortKeyNoPayload>::type { protected: - VariableSortKey() = default; + VariableSortKey() = default; // NOLINT friend SORT_KEY; public: diff --git a/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp b/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp index fe0d67e32..a5714cf8f 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp @@ -9,18 +9,41 @@ #pragma once #include "duckdb/common/types/row/tuple_data_states.hpp" +#include "duckdb/execution/expression_executor.hpp" namespace duckdb { +class Sort; +class SortedRun; class BufferManager; class DataChunk; class 
TupleDataCollection; class TupleDataLayout; +class SortedRunScanState { +public: + SortedRunScanState(ClientContext &context, const Sort &sort); + +public: + void Scan(const SortedRun &sorted_run, const Vector &sort_key_pointers, const idx_t &count, DataChunk &chunk); + +private: + template + void TemplatedScan(const SortedRun &sorted_run, const Vector &sort_key_pointers, const idx_t &count, + DataChunk &chunk); + +private: + const Sort &sort; + ExpressionExecutor key_executor; + DataChunk key; + DataChunk decoded_key; + TupleDataScanState payload_state; + vector key_buffer; +}; + class SortedRun { public: - SortedRun(ClientContext &context, shared_ptr key_layout, - shared_ptr payload_layout, bool is_index_sort); + SortedRun(ClientContext &context, const Sort &sort, bool is_index_sort); unique_ptr CreateRunForMaterialization() const; ~SortedRun(); @@ -36,8 +59,13 @@ class SortedRun { //! Size of this sorted run idx_t SizeInBytes() const; +private: + mutex merger_global_state_lock; + unique_ptr merge_global_state; + public: ClientContext &context; + const Sort &sort; //! Key and payload collections (and associated append states) unique_ptr key_data; diff --git a/src/duckdb/src/include/duckdb/common/sorting/sorted_run_merger.hpp b/src/duckdb/src/include/duckdb/common/sorting/sorted_run_merger.hpp index 21a56df83..fd894d698 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/sorted_run_merger.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/sorted_run_merger.hpp @@ -9,10 +9,10 @@ #pragma once #include "duckdb/execution/physical_operator_states.hpp" -#include "duckdb/common/sorting/sort_projection_column.hpp" namespace duckdb { +class Sort; class TupleDataLayout; struct BoundOrderByNode; struct ProgressData; @@ -24,9 +24,7 @@ class SortedRunMerger { friend class SortedRunMergerGlobalState; public: - SortedRunMerger(const Expression &decode_sort_key, shared_ptr key_layout, - vector> &&sorted_runs, - const vector &output_projection_columns, idx_t partition_size, bool external, + SortedRunMerger(const Sort &sort, vector> &&sorted_runs, idx_t partition_size, bool external, bool is_index_sort); public: @@ -44,14 +42,12 @@ class SortedRunMerger { //===--------------------------------------------------------------------===// // Non-Standard Interface //===--------------------------------------------------------------------===// - SourceResultType MaterializeMerge(ExecutionContext &context, OperatorSourceInput &input) const; - unique_ptr GetMaterialized(GlobalSourceState &global_state); + SourceResultType MaterializeSortedRun(ExecutionContext &context, OperatorSourceInput &input) const; + unique_ptr GetSortedRun(GlobalSourceState &global_state); public: - const Expression &decode_sort_key; - shared_ptr key_layout; + const Sort &sort; vector> sorted_runs; - const vector &output_projection_columns; const idx_t total_count; const idx_t partition_size; diff --git a/src/duckdb/src/include/duckdb/common/string_map_set.hpp b/src/duckdb/src/include/duckdb/common/string_map_set.hpp index 00600c421..40bd51171 100644 --- a/src/duckdb/src/include/duckdb/common/string_map_set.hpp +++ b/src/duckdb/src/include/duckdb/common/string_map_set.hpp @@ -28,9 +28,26 @@ struct StringEquality { } }; +struct StringCIHash { + std::size_t operator()(const string_t &k) const { + return StringUtil::CIHash(k.GetData(), k.GetSize()); + } +}; + +struct StringCIEquality { + bool operator()(const string_t &a, const string_t &b) const { + return StringUtil::CIEquals(a.GetData(), a.GetSize(), b.GetData(), b.GetSize()); + } 
+}; + template using string_map_t = unordered_map; using string_set_t = unordered_set; +template +using case_insensitive_string_map_t = unordered_map; + +using case_insensitive_string_set_t = unordered_set; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/string_util.hpp b/src/duckdb/src/include/duckdb/common/string_util.hpp index 8c0c19bef..87cd19f3f 100644 --- a/src/duckdb/src/include/duckdb/common/string_util.hpp +++ b/src/duckdb/src/include/duckdb/common/string_util.hpp @@ -217,6 +217,7 @@ class StringUtil { //! Case insensitive hash DUCKDB_API static uint64_t CIHash(const string &str); + DUCKDB_API static uint64_t CIHash(const char *str, idx_t size); //! Case insensitive equals DUCKDB_API static bool CIEquals(const string &l1, const string &l2); @@ -299,6 +300,15 @@ class StringUtil { } return strcmp(s1, s2) == 0; } + static bool Equals(const string &s1, const char *s2) { + return Equals(s1.c_str(), s2); + } + static bool Equals(const char *s1, const string &s2) { + return Equals(s1, s2.c_str()); + } + static bool Equals(const string &s1, const string &s2) { + return s1 == s2; + } //! JSON method that parses a { string: value } JSON blob //! NOTE: this method is not efficient @@ -318,6 +328,8 @@ class StringUtil { //! Transforms an complex JSON to a JSON string DUCKDB_API static string ToComplexJSONMap(const ComplexJSON &complex_json); + DUCKDB_API static string ValidateJSON(const char *data, const idx_t &len); + DUCKDB_API static string GetFileName(const string &file_path); DUCKDB_API static string GetFileExtension(const string &file_name); DUCKDB_API static string GetFileStem(const string &file_name); diff --git a/src/duckdb/src/include/duckdb/common/tree_renderer/mermaid_tree_renderer.hpp b/src/duckdb/src/include/duckdb/common/tree_renderer/mermaid_tree_renderer.hpp new file mode 100644 index 000000000..63d87e77e --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/tree_renderer/mermaid_tree_renderer.hpp @@ -0,0 +1,44 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/mermaid_tree_renderer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/main/profiling_node.hpp" +#include "duckdb/common/tree_renderer.hpp" +#include "duckdb/common/render_tree.hpp" + +namespace duckdb { +class LogicalOperator; +class PhysicalOperator; +class Pipeline; +struct PipelineRenderNode; + +class MermaidTreeRenderer : public TreeRenderer { +public: + explicit MermaidTreeRenderer() { + } + ~MermaidTreeRenderer() override { + } + +public: + string ToString(const LogicalOperator &op); + string ToString(const PhysicalOperator &op); + string ToString(const ProfilingNode &op); + string ToString(const Pipeline &op); + + void Render(const LogicalOperator &op, std::ostream &ss); + void Render(const PhysicalOperator &op, std::ostream &ss); + void Render(const ProfilingNode &op, std::ostream &ss) override; + void Render(const Pipeline &op, std::ostream &ss); + + void ToStreamInternal(RenderTree &root, std::ostream &ss) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/type_util.hpp b/src/duckdb/src/include/duckdb/common/type_util.hpp index 40a3eb872..8c0e7ddc9 100644 --- a/src/duckdb/src/include/duckdb/common/type_util.hpp +++ b/src/duckdb/src/include/duckdb/common/type_util.hpp @@ -22,60 +22,62 @@ struct bignum_t; //! 
Returns the PhysicalType for the given type template PhysicalType GetTypeId() { - if (std::is_same()) { + using TYPE = typename std::remove_cv::type; + + if (std::is_same()) { return PhysicalType::BOOL; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT8; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT16; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT32; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::UINT8; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::UINT16; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::UINT32; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::UINT64; - } else if (std::is_same() || std::is_same()) { + } else if (std::is_same() || std::is_same()) { return PhysicalType::UINT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT128; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::UINT128; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT32; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same() || std::is_same()) { + } else if (std::is_same() || std::is_same()) { return PhysicalType::FLOAT; - } else if (std::is_same() || std::is_same()) { + } else if (std::is_same() || std::is_same()) { return PhysicalType::DOUBLE; - } else if (std::is_same() || std::is_same() || std::is_same() || - std::is_same()) { + } else if (std::is_same() || std::is_same() || std::is_same() || + std::is_same()) { return PhysicalType::VARCHAR; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INTERVAL; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::LIST; - } else if (std::is_pointer() || std::is_same()) { + } else if (std::is_pointer() || std::is_same()) { if (sizeof(uintptr_t) == sizeof(uint32_t)) { return PhysicalType::UINT32; } else if (sizeof(uintptr_t) == sizeof(uint64_t)) { @@ -90,10 +92,12 @@ PhysicalType GetTypeId() { template bool StorageTypeCompatible(PhysicalType type) { - if (std::is_same()) { + using TYPE = typename std::remove_cv::type; + + if (std::is_same()) { return type == PhysicalType::INT8 || type == PhysicalType::BOOL; } - if (std::is_same()) { + if (std::is_same()) { return type == PhysicalType::UINT8 || type == PhysicalType::BOOL; } return type == GetTypeId(); @@ -101,8 +105,10 @@ bool StorageTypeCompatible(PhysicalType type) { template bool TypeIsNumber() { - return std::is_integral() || std::is_floating_point() || std::is_same() || - std::is_same(); + using TYPE = typename 
std::remove_cv::type; + + return std::is_integral() || std::is_floating_point() || std::is_same() || + std::is_same(); } template diff --git a/src/duckdb/src/include/duckdb/common/types.hpp b/src/duckdb/src/include/duckdb/common/types.hpp index 0f7ddbb2d..6d85ce2de 100644 --- a/src/duckdb/src/include/duckdb/common/types.hpp +++ b/src/duckdb/src/include/duckdb/common/types.hpp @@ -230,6 +230,8 @@ enum class LogicalTypeId : uint8_t { VALIDITY = 53, UUID = 54, + GEOMETRY = 60, + STRUCT = 100, LIST = 101, MAP = 102, @@ -430,6 +432,7 @@ struct LogicalType { DUCKDB_API static LogicalType UNION(child_list_t members); // NOLINT DUCKDB_API static LogicalType ARRAY(const LogicalType &child, optional_idx index); // NOLINT DUCKDB_API static LogicalType ENUM(Vector &ordered_data, idx_t size); // NOLINT + DUCKDB_API static LogicalType GEOMETRY(); // NOLINT // ANY but with special rules (default is LogicalType::ANY, 5) DUCKDB_API static LogicalType ANY_PARAMS(LogicalType target, idx_t cast_score = 5); // NOLINT DUCKDB_API static LogicalType TEMPLATE(const string &name); // NOLINT diff --git a/src/duckdb/src/include/duckdb/common/types/batched_data_collection.hpp b/src/duckdb/src/include/duckdb/common/types/batched_data_collection.hpp index 8a5cc9e19..4f1d8f2b0 100644 --- a/src/duckdb/src/include/duckdb/common/types/batched_data_collection.hpp +++ b/src/duckdb/src/include/duckdb/common/types/batched_data_collection.hpp @@ -1,7 +1,7 @@ //===----------------------------------------------------------------------===// // DuckDB // -// duckdb/common/types/batched_chunk_collection.hpp +// duckdb/common/types/batched_data_collection.hpp // // //===----------------------------------------------------------------------===// @@ -10,8 +10,10 @@ #include "duckdb/common/map.hpp" #include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/main/query_parameters.hpp" namespace duckdb { + class BufferManager; class ClientContext; @@ -32,9 +34,16 @@ struct BatchedChunkScanState { //! Scans over a BatchedDataCollection are ordered by batch index class BatchedDataCollection { public: - DUCKDB_API BatchedDataCollection(ClientContext &context, vector types, bool buffer_managed = false); - DUCKDB_API BatchedDataCollection(ClientContext &context, vector types, batch_map_t batches, - bool buffer_managed = false); + DUCKDB_API + BatchedDataCollection(ClientContext &context, vector types, + ColumnDataAllocatorType allocator_type = ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR, + ColumnDataCollectionLifetime lifetime = ColumnDataCollectionLifetime::REGULAR); + DUCKDB_API + BatchedDataCollection(ClientContext &context, vector types, QueryResultMemoryType memory_type); + DUCKDB_API + BatchedDataCollection(ClientContext &context, vector types, batch_map_t batches, + ColumnDataAllocatorType allocator_type = ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR, + ColumnDataCollectionLifetime lifetime = ColumnDataCollectionLifetime::REGULAR); //! Appends a datachunk with the given batch index to the batched collection DUCKDB_API void Append(DataChunk &input, idx_t batch_index); @@ -79,6 +88,8 @@ class BatchedDataCollection { DUCKDB_API void Print() const; private: + unique_ptr CreateCollection() const; + struct CachedCollection { idx_t batch_index = DConstants::INVALID_INDEX; ColumnDataCollection *collection = nullptr; @@ -87,7 +98,8 @@ class BatchedDataCollection { ClientContext &context; vector types; - bool buffer_managed; + ColumnDataAllocatorType allocator_type; + ColumnDataCollectionLifetime lifetime; //! 
The data of the batched chunk collection - a set of batch_index -> ColumnDataCollection pointers map> data; //! The last batch collection that was inserted into diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_allocator.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_allocator.hpp index 564ca5c09..6f2c8a1c1 100644 --- a/src/duckdb/src/include/duckdb/common/types/column/column_data_allocator.hpp +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_allocator.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/main/result_set_manager.hpp" namespace duckdb { @@ -17,21 +18,31 @@ struct VectorMetaData; struct SwizzleMetaData; struct BlockMetaData { - //! The underlying block handle - shared_ptr handle; +public: //! How much space is currently used within the block uint32_t size; //! How much space is available in the block uint32_t capacity; +private: + //! The underlying block handle + shared_ptr handle; + //! Weak pointer to underlying block handle (if ColumnDataCollectionLifetime::DATABASE_INSTANCE) + weak_ptr weak_handle; + +public: + shared_ptr GetHandle() const; + void SetHandle(ManagedResultSet &managed_result_set, shared_ptr handle); uint32_t Capacity(); }; class ColumnDataAllocator { public: explicit ColumnDataAllocator(Allocator &allocator); - explicit ColumnDataAllocator(BufferManager &buffer_manager); - ColumnDataAllocator(ClientContext &context, ColumnDataAllocatorType allocator_type); + explicit ColumnDataAllocator(BufferManager &buffer_manager, + ColumnDataCollectionLifetime lifetime = ColumnDataCollectionLifetime::REGULAR); + ColumnDataAllocator(ClientContext &context, ColumnDataAllocatorType allocator_type, + ColumnDataCollectionLifetime lifetime = ColumnDataCollectionLifetime::REGULAR); ColumnDataAllocator(ColumnDataAllocator &allocator); ~ColumnDataAllocator(); @@ -81,6 +92,8 @@ class ColumnDataAllocator { //! Prevents the block with the given id from being added to the eviction queue void SetDestroyBufferUponUnpin(uint32_t block_id); + //! Gets a shared pointer to the database instance if ColumnDataCollectionLifetime::DATABASE_INSTANCE + shared_ptr GetDatabase() const; private: void AllocateEmptyBlock(idx_t size); @@ -116,6 +129,8 @@ class ColumnDataAllocator { idx_t allocated_size = 0; //! Partition index (optional, if partitioned) optional_idx partition_index; + //! Lifetime management for this allocator + ManagedResultSet managed_result_set; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_collection.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection.hpp index f02d49001..6cb1d7bdd 100644 --- a/src/duckdb/src/include/duckdb/common/types/column/column_data_collection.hpp +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection.hpp @@ -8,10 +8,10 @@ #pragma once -#include "duckdb/common/pair.hpp" #include "duckdb/common/types/column/column_data_collection_iterators.hpp" namespace duckdb { + class BufferManager; class BlockHandle; class ClientContext; @@ -30,10 +30,14 @@ class ColumnDataCollection { //! Constructs an empty (but valid) in-memory column data collection from an allocator DUCKDB_API explicit ColumnDataCollection(Allocator &allocator); //! 
Constructs a buffer-managed column data collection - DUCKDB_API ColumnDataCollection(BufferManager &buffer_manager, vector types); + DUCKDB_API + ColumnDataCollection(BufferManager &buffer_manager, vector types, + ColumnDataCollectionLifetime lifetime = ColumnDataCollectionLifetime::REGULAR); //! Constructs either an in-memory or a buffer-managed column data collection - DUCKDB_API ColumnDataCollection(ClientContext &context, vector types, - ColumnDataAllocatorType type = ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR); + DUCKDB_API + ColumnDataCollection(ClientContext &context, vector types, + ColumnDataAllocatorType type = ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR, + ColumnDataCollectionLifetime lifetime = ColumnDataCollectionLifetime::REGULAR); //! Creates a column data collection that inherits the blocks to write to. This allows blocks to be shared //! between multiple column data collections and prevents wasting space. //! Note that after one CDC inherits blocks from another, the other @@ -78,6 +82,7 @@ class ColumnDataCollection { //! Initializes a chunk with the correct types that can be used to call Scan DUCKDB_API void InitializeScanChunk(DataChunk &chunk) const; + DUCKDB_API void InitializeScanChunk(Allocator &allocator, DataChunk &chunk) const; //! Initializes a chunk with the correct types for a given scan state DUCKDB_API void InitializeScanChunk(ColumnDataScanState &state, DataChunk &chunk) const; //! Initializes a Scan state for scanning all columns @@ -161,6 +166,8 @@ class ColumnDataCollection { vector> GetHeapReferences(); //! Get the allocator type of this ColumnDataCollection ColumnDataAllocatorType GetAllocatorType() const; + //! Get the buffer manager of the allocator + BufferManager &GetBufferManager() const; //! Get a vector of the segments in this ColumnDataCollection const vector> &GetSegments() const; @@ -194,7 +201,9 @@ class ColumnDataCollection { //! 
The ColumnDataRowCollection represents a set of materialized rows, as obtained from the ColumnDataCollection class ColumnDataRowCollection { public: - DUCKDB_API explicit ColumnDataRowCollection(const ColumnDataCollection &collection); + DUCKDB_API explicit ColumnDataRowCollection( + const ColumnDataCollection &collection, + ColumnDataScanProperties properties = ColumnDataScanProperties::DISALLOW_ZERO_COPY); public: DUCKDB_API Value GetValue(idx_t column, idx_t index) const; diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_iterators.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_iterators.hpp index b84b81d47..ff42eadf2 100644 --- a/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_iterators.hpp +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_iterators.hpp @@ -63,7 +63,9 @@ class ColumnDataRowIterationHelper { class ColumnDataRowIterator { public: - DUCKDB_API explicit ColumnDataRowIterator(const ColumnDataCollection *collection_p); + DUCKDB_API explicit ColumnDataRowIterator( + const ColumnDataCollection *collection_p, + ColumnDataScanProperties properties = ColumnDataScanProperties::DISALLOW_ZERO_COPY); const ColumnDataCollection *collection; ColumnDataScanState scan_state; diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_scan_states.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_scan_states.hpp index c809520c6..d544db851 100644 --- a/src/duckdb/src/include/duckdb/common/types/column/column_data_scan_states.hpp +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_scan_states.hpp @@ -35,6 +35,14 @@ enum class ColumnDataScanProperties : uint8_t { DISALLOW_ZERO_COPY }; +enum class ColumnDataCollectionLifetime { + //! Regular lifetime management + REGULAR, + //! Accessing will throw an error after the DB closes + //! Optional for ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR only + THROW_ERROR_AFTER_DATABASE_CLOSES, +}; + struct ChunkManagementState { unordered_map handles; ColumnDataScanProperties properties = ColumnDataScanProperties::INVALID; @@ -46,6 +54,9 @@ struct ColumnDataAppendState { }; struct ColumnDataScanState { + //! 
Database instance if scanning ColumnDataCollectionLifetime::DATABASE_INSTANCE + shared_ptr db; + ChunkManagementState current_chunk_state; idx_t segment_index; idx_t chunk_index; diff --git a/src/duckdb/src/include/duckdb/common/types/geometry.hpp b/src/duckdb/src/include/duckdb/common/types/geometry.hpp new file mode 100644 index 000000000..ea0e2492d --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/geometry.hpp @@ -0,0 +1,224 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/geometry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/pair.hpp" +#include +#include + +namespace duckdb { + +struct GeometryStatsData; + +enum class GeometryType : uint8_t { + INVALID = 0, + POINT = 1, + LINESTRING = 2, + POLYGON = 3, + MULTIPOINT = 4, + MULTILINESTRING = 5, + MULTIPOLYGON = 6, + GEOMETRYCOLLECTION = 7, +}; + +enum class VertexType : uint8_t { XY = 0, XYZ = 1, XYM = 2, XYZM = 3 }; + +struct VertexXY { + static constexpr auto TYPE = VertexType::XY; + static constexpr auto HAS_Z = false; + static constexpr auto HAS_M = false; + + double x; + double y; + + bool AllNan() const { + return std::isnan(x) && std::isnan(y); + } +}; + +struct VertexXYZ { + static constexpr auto TYPE = VertexType::XYZ; + static constexpr auto HAS_Z = true; + static constexpr auto HAS_M = false; + + double x; + double y; + double z; + + bool AllNan() const { + return std::isnan(x) && std::isnan(y) && std::isnan(z); + } +}; +struct VertexXYM { + static constexpr auto TYPE = VertexType::XYM; + static constexpr auto HAS_M = true; + static constexpr auto HAS_Z = false; + + double x; + double y; + double m; + + bool AllNan() const { + return std::isnan(x) && std::isnan(y) && std::isnan(m); + } +}; + +struct VertexXYZM { + static constexpr auto TYPE = VertexType::XYZM; + static constexpr auto HAS_Z = true; + static constexpr auto HAS_M = true; + + double x; + double y; + double z; + double m; + + bool AllNan() const { + return std::isnan(x) && std::isnan(y) && std::isnan(z) && std::isnan(m); + } +}; + +class GeometryExtent { +public: + static constexpr auto UNKNOWN_MIN = -std::numeric_limits::infinity(); + static constexpr auto UNKNOWN_MAX = +std::numeric_limits::infinity(); + + static constexpr auto EMPTY_MIN = +std::numeric_limits::infinity(); + static constexpr auto EMPTY_MAX = -std::numeric_limits::infinity(); + + // "Unknown" extent means we don't know the bounding box. + // Merging with an unknown extent results in an unknown extent. + // Everything intersects with an unknown extent. + static GeometryExtent Unknown() { + return GeometryExtent {UNKNOWN_MIN, UNKNOWN_MIN, UNKNOWN_MIN, UNKNOWN_MIN, + UNKNOWN_MAX, UNKNOWN_MAX, UNKNOWN_MAX, UNKNOWN_MAX}; + } + + // "Empty" extent means the smallest possible bounding box. + // Merging with an empty extent has no effect. + // Nothing intersects with an empty extent. + static GeometryExtent Empty() { + return GeometryExtent {EMPTY_MIN, EMPTY_MIN, EMPTY_MIN, EMPTY_MIN, EMPTY_MAX, EMPTY_MAX, EMPTY_MAX, EMPTY_MAX}; + } + + // Does this extent have any X/Y values set? + // In other words, is the range of the x/y axes not empty and not unknown? + bool HasXY() const { + return std::isfinite(x_min) && std::isfinite(y_min) && std::isfinite(x_max) && std::isfinite(y_max); + } + // Does this extent have any Z values set? 
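// Usage sketch (illustrative only, not declared in this header): how the
// Unknown/Empty sentinels are intended to compose with Extend/Merge/IntersectsXY,
// going by the definitions in this class.
//
//   GeometryExtent extent = GeometryExtent::Empty();    // nothing seen yet
//   extent.Extend(VertexXY {1.0, 2.0});                 // box becomes [1,1] x [2,2]
//   extent.Extend(VertexXY {4.0, 6.0});                 // grows to [1,4] x [2,6]
//   D_ASSERT(extent.HasXY());
//
//   extent.Merge(GeometryExtent::Empty());                     // merging Empty is a no-op
//   D_ASSERT(!GeometryExtent::Empty().IntersectsXY(extent));   // nothing intersects Empty
//   D_ASSERT(GeometryExtent::Unknown().IntersectsXY(extent));  // everything intersects Unknown
//
//   extent.Merge(GeometryExtent::Unknown());             // every axis becomes -inf..+inf
//   D_ASSERT(!extent.HasXY());                            // the X/Y range is no longer finite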
+ // In other words, is the range of the Z-axis not empty and not unknown? + bool HasZ() const { + return std::isfinite(z_min) && std::isfinite(z_max); + } + // Does this extent have any M values set? + // In other words, is the range of the M-axis not empty and not unknown? + bool HasM() const { + return std::isfinite(m_min) && std::isfinite(m_max); + } + + void Extend(const VertexXY &vertex) { + x_min = MinValue(x_min, vertex.x); + x_max = MaxValue(x_max, vertex.x); + y_min = MinValue(y_min, vertex.y); + y_max = MaxValue(y_max, vertex.y); + } + + void Extend(const VertexXYZ &vertex) { + x_min = MinValue(x_min, vertex.x); + x_max = MaxValue(x_max, vertex.x); + y_min = MinValue(y_min, vertex.y); + y_max = MaxValue(y_max, vertex.y); + z_min = MinValue(z_min, vertex.z); + z_max = MaxValue(z_max, vertex.z); + } + + void Extend(const VertexXYM &vertex) { + x_min = MinValue(x_min, vertex.x); + x_max = MaxValue(x_max, vertex.x); + y_min = MinValue(y_min, vertex.y); + y_max = MaxValue(y_max, vertex.y); + m_min = MinValue(m_min, vertex.m); + m_max = MaxValue(m_max, vertex.m); + } + + void Extend(const VertexXYZM &vertex) { + x_min = MinValue(x_min, vertex.x); + x_max = MaxValue(x_max, vertex.x); + y_min = MinValue(y_min, vertex.y); + y_max = MaxValue(y_max, vertex.y); + z_min = MinValue(z_min, vertex.z); + z_max = MaxValue(z_max, vertex.z); + m_min = MinValue(m_min, vertex.m); + m_max = MaxValue(m_max, vertex.m); + } + + void Merge(const GeometryExtent &other) { + x_min = MinValue(x_min, other.x_min); + y_min = MinValue(y_min, other.y_min); + z_min = MinValue(z_min, other.z_min); + m_min = MinValue(m_min, other.m_min); + + x_max = MaxValue(x_max, other.x_max); + y_max = MaxValue(y_max, other.y_max); + z_max = MaxValue(z_max, other.z_max); + m_max = MaxValue(m_max, other.m_max); + } + + bool IntersectsXY(const GeometryExtent &other) const { + return !(x_min > other.x_max || x_max < other.x_min || y_min > other.y_max || y_max < other.y_min); + } + + bool IntersectsXYZM(const GeometryExtent &other) const { + return !(x_min > other.x_max || x_max < other.x_min || y_min > other.y_max || y_max < other.y_min || + z_min > other.z_max || z_max < other.z_min || m_min > other.m_max || m_max < other.m_min); + } + + bool ContainsXY(const GeometryExtent &other) const { + return x_min <= other.x_min && x_max >= other.x_max && y_min <= other.y_min && y_max >= other.y_max; + } + + double x_min; + double y_min; + double z_min; + double m_min; + + double x_max; + double y_max; + double z_max; + double m_max; +}; + +class Geometry { +public: + static constexpr idx_t MAX_RECURSION_DEPTH = 16; + + //! Convert from WKT + DUCKDB_API static bool FromString(const string_t &wkt_text, string_t &result, Vector &result_vector, bool strict); + + //! Convert to WKT + DUCKDB_API static string_t ToString(Vector &result, const string_t &geom); + + //! Convert from WKB + DUCKDB_API static bool FromBinary(const string_t &wkb, string_t &result, Vector &result_vector, bool strict); + DUCKDB_API static void FromBinary(Vector &source, Vector &result, idx_t count, bool strict); + + //! Convert to WKB + DUCKDB_API static void ToBinary(Vector &source, Vector &result, idx_t count); + + //! Get the geometry type and vertex type from the WKB + DUCKDB_API static pair GetType(const string_t &wkb); + + //! 
Update the bounding box, return number of vertices processed + DUCKDB_API static uint32_t GetExtent(const string_t &wkb, GeometryExtent &extent); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/hugeint.hpp b/src/duckdb/src/include/duckdb/common/types/hugeint.hpp index 3720bf844..9fa5d447b 100644 --- a/src/duckdb/src/include/duckdb/common/types/hugeint.hpp +++ b/src/duckdb/src/include/duckdb/common/types/hugeint.hpp @@ -129,38 +129,38 @@ class Hugeint { static int Sign(hugeint_t n); static hugeint_t Abs(hugeint_t n); // comparison operators - static bool Equals(hugeint_t lhs, hugeint_t rhs) { + static bool Equals(const hugeint_t &lhs, const hugeint_t &rhs) { bool lower_equals = lhs.lower == rhs.lower; bool upper_equals = lhs.upper == rhs.upper; return lower_equals && upper_equals; } - static bool NotEquals(hugeint_t lhs, hugeint_t rhs) { + static bool NotEquals(const hugeint_t &lhs, const hugeint_t &rhs) { return !Equals(lhs, rhs); } - static bool GreaterThan(hugeint_t lhs, hugeint_t rhs) { + static bool GreaterThan(const hugeint_t &lhs, const hugeint_t &rhs) { bool upper_bigger = lhs.upper > rhs.upper; bool upper_equal = lhs.upper == rhs.upper; bool lower_bigger = lhs.lower > rhs.lower; return upper_bigger || (upper_equal && lower_bigger); } - static bool GreaterThanEquals(hugeint_t lhs, hugeint_t rhs) { + static bool GreaterThanEquals(const hugeint_t &lhs, const hugeint_t &rhs) { bool upper_bigger = lhs.upper > rhs.upper; bool upper_equal = lhs.upper == rhs.upper; bool lower_bigger_equals = lhs.lower >= rhs.lower; return upper_bigger || (upper_equal && lower_bigger_equals); } - static bool LessThan(hugeint_t lhs, hugeint_t rhs) { + static bool LessThan(const hugeint_t &lhs, const hugeint_t &rhs) { bool upper_smaller = lhs.upper < rhs.upper; bool upper_equal = lhs.upper == rhs.upper; bool lower_smaller = lhs.lower < rhs.lower; return upper_smaller || (upper_equal && lower_smaller); } - static bool LessThanEquals(hugeint_t lhs, hugeint_t rhs) { + static bool LessThanEquals(const hugeint_t &lhs, const hugeint_t &rhs) { bool upper_smaller = lhs.upper < rhs.upper; bool upper_equal = lhs.upper == rhs.upper; bool lower_smaller_equals = lhs.lower <= rhs.lower; diff --git a/src/duckdb/src/include/duckdb/common/types/row/block_iterator.hpp b/src/duckdb/src/include/duckdb/common/types/row/block_iterator.hpp index c29b094a8..ed291df39 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/block_iterator.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/block_iterator.hpp @@ -23,64 +23,94 @@ enum class BlockIteratorStateType : int8_t { EXTERNAL, }; -BlockIteratorStateType GetBlockIteratorStateType(const bool &external); - -//! State for iterating over blocks of an in-memory TupleDataCollection -//! 
Multiple iterators can share the same state, everything is const -class InMemoryBlockIteratorState { -public: - explicit InMemoryBlockIteratorState(const TupleDataCollection &key_data); - -public: - template - T &GetValueAtIndex(const idx_t &block_idx, const idx_t &tuple_idx) const { - D_ASSERT(GetIndex(block_idx, tuple_idx) < tuple_count); - return reinterpret_cast(block_ptrs[block_idx])[tuple_idx]; +template +class BlockIteratorStateBase { +protected: + friend BLOCK_ITERATOR_STATE; + explicit BlockIteratorStateBase(const idx_t tuple_count_p) : tuple_count(tuple_count_p) { } - template - T &GetValueAtIndex(const idx_t &n) const { - const auto quotient = fast_mod.Div(n); - return GetValueAtIndex(quotient, fast_mod.Mod(n, quotient)); +public: + idx_t GetDivisor() const { + const auto &state = static_cast(*this); + return state.GetDivisor(); } - void RandomAccess(idx_t &block_idx, idx_t &tuple_idx, const idx_t &index) const { - block_idx = fast_mod.Div(index); - tuple_idx = fast_mod.Mod(index, block_idx); + void RandomAccess(idx_t &block_or_chunk_idx, idx_t &tuple_idx, const idx_t &index) const { + const auto &state = static_cast(*this); + state.RandomAccessInternal(block_or_chunk_idx, tuple_idx, index); } - void Add(idx_t &block_idx, idx_t &tuple_idx, const idx_t &value) const { + void Add(idx_t &block_or_chunk_idx, idx_t &tuple_idx, const idx_t &value) const { tuple_idx += value; - if (tuple_idx >= fast_mod.GetDivisor()) { - const auto div = fast_mod.Div(tuple_idx); - tuple_idx -= div * fast_mod.GetDivisor(); - block_idx += div; + if (tuple_idx >= GetDivisor()) { + RandomAccess(block_or_chunk_idx, tuple_idx, GetIndex(block_or_chunk_idx, tuple_idx)); } } - void Subtract(idx_t &block_idx, idx_t &tuple_idx, const idx_t &value) const { + void Subtract(idx_t &block_or_chunk_idx, idx_t &tuple_idx, const idx_t &value) const { tuple_idx -= value; - if (tuple_idx >= fast_mod.GetDivisor()) { - const auto div = fast_mod.Div(-tuple_idx); - tuple_idx += (div + 1) * fast_mod.GetDivisor(); - block_idx -= div + 1; + if (tuple_idx >= GetDivisor()) { + RandomAccess(block_or_chunk_idx, tuple_idx, GetIndex(block_or_chunk_idx, tuple_idx)); } } - void Increment(idx_t &block_idx, idx_t &tuple_idx) const { - const auto passed_boundary = ++tuple_idx == fast_mod.GetDivisor(); - block_idx += passed_boundary; - tuple_idx *= !passed_boundary; + void Increment(idx_t &block_or_chunk_idx, idx_t &tuple_idx) const { + const auto crossed_boundary = ++tuple_idx == GetDivisor(); + block_or_chunk_idx += crossed_boundary; + tuple_idx *= !crossed_boundary; } - void Decrement(idx_t &block_idx, idx_t &tuple_idx) const { + void Decrement(idx_t &block_or_chunk_idx, idx_t &tuple_idx) const { const auto crossed_boundary = tuple_idx-- == 0; - block_idx -= crossed_boundary; - tuple_idx += crossed_boundary * fast_mod.GetDivisor(); + block_or_chunk_idx -= crossed_boundary; + tuple_idx += crossed_boundary * GetDivisor(); } - idx_t GetIndex(const idx_t &block_idx, const idx_t &tuple_idx) const { - return block_idx * fast_mod.GetDivisor() + tuple_idx; + idx_t GetIndex(const idx_t &block_or_chunk_idx, const idx_t &tuple_idx) const { + return block_or_chunk_idx * GetDivisor() + tuple_idx; + } + +protected: + const idx_t tuple_count; +}; + +template +class BlockIteratorState; + +//! State for iterating over blocks of an in-memory TupleDataCollection +//! 
Multiple iterators can share the same state, everything is const +template <> +class BlockIteratorState + : public BlockIteratorStateBase> { +public: + explicit BlockIteratorState(const TupleDataCollection &key_data) + : BlockIteratorStateBase(key_data.Count()), block_ptrs(ConvertBlockPointers(key_data.GetRowBlockPointers())), + fast_mod(key_data.TuplesPerBlock()) { + } + +public: + idx_t GetDivisor() const { + return fast_mod.GetDivisor(); + } + + void RandomAccessInternal(idx_t &block_idx, idx_t &tuple_idx, const idx_t &index) const { + block_idx = fast_mod.Div(index); + tuple_idx = fast_mod.Mod(index, block_idx); + } + + template + T &GetValueAtIndex(const idx_t &block_idx, const idx_t &tuple_idx) const { + D_ASSERT(GetIndex(block_idx, tuple_idx) < tuple_count); + return reinterpret_cast(block_ptrs[block_idx])[tuple_idx]; + } + + template + T &GetValueAtIndex(const idx_t &index) const { + idx_t block_idx; + idx_t tuple_idx; + RandomAccess(block_idx, tuple_idx, index); + return GetValueAtIndex(block_idx, tuple_idx); } void SetKeepPinned(const bool &) { @@ -92,72 +122,63 @@ class InMemoryBlockIteratorState { } private: - static unsafe_vector ConvertBlockPointers(const vector &block_ptrs); + static unsafe_vector ConvertBlockPointers(const vector &block_ptrs) { + unsafe_vector converted_block_ptrs; + converted_block_ptrs.reserve(block_ptrs.size()); + for (const auto &block_ptr : block_ptrs) { + converted_block_ptrs.emplace_back(block_ptr); + } + return converted_block_ptrs; + } private: const unsafe_vector block_ptrs; const FastMod fast_mod; - const idx_t tuple_count; }; +using InMemoryBlockIteratorState = BlockIteratorState; + //! State for iterating over blocks of an external (larger-than-memory) TupleDataCollection //! This state cannot be shared by multiple iterators, it is stateful -class ExternalBlockIteratorState { +template <> +class BlockIteratorState + : public BlockIteratorStateBase> { public: - explicit ExternalBlockIteratorState(TupleDataCollection &key_data, optional_ptr payload_data); - -public: - template - T &GetValueAtIndex(const idx_t &chunk_idx, const idx_t &tuple_idx) { - if (chunk_idx != current_chunk_idx) { - InitializeChunk(chunk_idx); + explicit BlockIteratorState(TupleDataCollection &key_data_p, optional_ptr payload_data_p) + : BlockIteratorStateBase(key_data_p.Count()), current_chunk_idx(DConstants::INVALID_INDEX), + key_data(key_data_p), key_ptrs(FlatVector::GetData(key_scan_state.chunk_state.row_locations)), + payload_data(payload_data_p), keep_pinned(false), pin_payload(false) { + key_data.InitializeScan(key_scan_state); + if (payload_data) { + payload_data->InitializeScan(payload_scan_state); } - return *reinterpret_cast(key_ptrs)[tuple_idx]; } - template - T &GetValueAtIndex(const idx_t &n) { - D_ASSERT(n < tuple_count); - return GetValueAtIndex(n / STANDARD_VECTOR_SIZE, n % STANDARD_VECTOR_SIZE); +public: + static constexpr idx_t GetDivisor() { + return STANDARD_VECTOR_SIZE; } - static void RandomAccess(idx_t &chunk_idx, idx_t &tuple_idx, const idx_t &index) { + static void RandomAccessInternal(idx_t &chunk_idx, idx_t &tuple_idx, const idx_t &index) { chunk_idx = index / STANDARD_VECTOR_SIZE; tuple_idx = index % STANDARD_VECTOR_SIZE; } - static void Add(idx_t &chunk_idx, idx_t &tuple_idx, const idx_t &value) { - tuple_idx += value; - if (tuple_idx >= STANDARD_VECTOR_SIZE) { - const auto div = tuple_idx / STANDARD_VECTOR_SIZE; - tuple_idx -= div * STANDARD_VECTOR_SIZE; - chunk_idx += div; - } - } - - static void Subtract(idx_t &chunk_idx, idx_t &tuple_idx, 
const idx_t &value) { - tuple_idx -= value; - if (tuple_idx >= STANDARD_VECTOR_SIZE) { - const auto div = -tuple_idx / STANDARD_VECTOR_SIZE; - tuple_idx += (div + 1) * STANDARD_VECTOR_SIZE; - chunk_idx -= div + 1; + template + T &GetValueAtIndex(const idx_t &chunk_idx, const idx_t &tuple_idx) { + D_ASSERT(GetIndex(chunk_idx, tuple_idx) < tuple_count); + if (chunk_idx != current_chunk_idx) { + InitializeChunk(chunk_idx); } + return *reinterpret_cast(key_ptrs)[tuple_idx]; } - static void Increment(idx_t &chunk_idx, idx_t &tuple_idx) { - const auto passed_boundary = ++tuple_idx == STANDARD_VECTOR_SIZE; - chunk_idx += passed_boundary; - tuple_idx *= !passed_boundary; - } - - static void Decrement(idx_t &chunk_idx, idx_t &tuple_idx) { - const auto crossed_boundary = tuple_idx-- == 0; - chunk_idx -= crossed_boundary; - tuple_idx += crossed_boundary * static_cast(STANDARD_VECTOR_SIZE); - } - - static idx_t GetIndex(const idx_t &chunk_idx, const idx_t &tuple_idx) { - return chunk_idx * STANDARD_VECTOR_SIZE + tuple_idx; + template + T &GetValueAtIndex(const idx_t &index) { + idx_t chunk_idx; + idx_t tuple_idx; + RandomAccess(chunk_idx, tuple_idx, index); + return GetValueAtIndex(chunk_idx, tuple_idx); } void SetKeepPinned(const bool &enable) { @@ -183,15 +204,14 @@ class ExternalBlockIteratorState { key_scan_state.pin_state.row_handles.acquire_handles(pins); key_scan_state.pin_state.heap_handles.acquire_handles(pins); } - key_data.FetchChunk(key_scan_state, 0, chunk_idx, false); + key_data.FetchChunk(key_scan_state, chunk_idx, false); if (pin_payload && payload_data) { if (keep_pinned) { payload_scan_state.pin_state.row_handles.acquire_handles(pins); payload_scan_state.pin_state.heap_handles.acquire_handles(pins); } - const auto chunk_count = payload_data->FetchChunk(payload_scan_state, 0, chunk_idx, false); + const auto chunk_count = payload_data->FetchChunk(payload_scan_state, chunk_idx, false); const auto sort_keys = reinterpret_cast(key_ptrs); - payload_data->FetchChunk(payload_scan_state, 0, chunk_idx, false); const auto payload_ptrs = FlatVector::GetData(payload_scan_state.chunk_state.row_locations); for (idx_t i = 0; i < chunk_count; i++) { sort_keys[i]->SetPayload(payload_ptrs[i]); @@ -201,7 +221,6 @@ class ExternalBlockIteratorState { } private: - const idx_t tuple_count; idx_t current_chunk_idx; TupleDataCollection &key_data; @@ -216,13 +235,7 @@ class ExternalBlockIteratorState { vector pins; }; -//! Utility so we can get the state using the type -template -using BlockIteratorState = typename std::conditional< - T == BlockIteratorStateType::IN_MEMORY, InMemoryBlockIteratorState, - typename std::conditional::type>::type; +using ExternalBlockIteratorState = BlockIteratorState; //! 
Iterator for data spread out over multiple blocks template @@ -305,16 +318,16 @@ class block_iterator_t { // NOLINT: match stl case return *this; } block_iterator_t operator+(const difference_type &n) const { - idx_t new_block_idx = block_or_chunk_idx; + idx_t new_block_or_chunk_idx = block_or_chunk_idx; idx_t new_tuple_idx = tuple_idx; - state->Add(new_block_idx, new_tuple_idx, n); - return block_iterator_t(*state, new_block_idx, new_tuple_idx); + state->Add(new_block_or_chunk_idx, new_tuple_idx, n); + return block_iterator_t(*state, new_block_or_chunk_idx, new_tuple_idx); } block_iterator_t operator-(const difference_type &n) const { - idx_t new_block_idx = block_or_chunk_idx; + idx_t new_block_or_chunk_idx = block_or_chunk_idx; idx_t new_tuple_idx = tuple_idx; - state->Subtract(new_block_idx, new_tuple_idx, n); - return block_iterator_t(*state, new_block_idx, new_tuple_idx); + state->Subtract(new_block_or_chunk_idx, new_tuple_idx, n); + return block_iterator_t(*state, new_block_or_chunk_idx, new_tuple_idx); } reference operator[](const difference_type &n) const { diff --git a/src/duckdb/src/include/duckdb/common/types/row/partitioned_tuple_data.hpp b/src/duckdb/src/include/duckdb/common/types/row/partitioned_tuple_data.hpp index 42e68e9ef..6eba10145 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/partitioned_tuple_data.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/partitioned_tuple_data.hpp @@ -162,7 +162,7 @@ class PartitionedTupleData { //! PartitionedTupleData can only be instantiated by derived classes PartitionedTupleData(PartitionedTupleDataType type, BufferManager &buffer_manager, shared_ptr &layout_ptr); - PartitionedTupleData(const PartitionedTupleData &other); + PartitionedTupleData(PartitionedTupleData &other); //! Whether to use fixed size map or regular map bool UseFixedSizeMap() const; @@ -178,17 +178,21 @@ class PartitionedTupleData { template void BuildBufferSpace(PartitionedTupleDataAppendState &state); //! Create a collection for a specific a partition - unique_ptr CreatePartitionCollection(idx_t partition_index) { - return make_uniq(buffer_manager, layout_ptr); + unique_ptr CreatePartitionCollection() { + return make_uniq(buffer_manager, layout_ptr, stl_allocator); } //! Verify count/data size of this PartitionedTupleData void Verify() const; protected: PartitionedTupleDataType type; + BufferManager &buffer_manager; + shared_ptr stl_allocator; + shared_ptr layout_ptr; const TupleDataLayout &layout; + idx_t count; idx_t data_size; diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_allocator.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_allocator.hpp index c603baac0..b6d388215 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_allocator.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_allocator.hpp @@ -10,6 +10,7 @@ #include "duckdb/common/types/row/tuple_data_layout.hpp" #include "duckdb/common/types/row/tuple_data_states.hpp" +#include "duckdb/common/arena_containers/arena_vector.hpp" namespace duckdb { @@ -53,7 +54,8 @@ struct TupleDataBlock { class TupleDataAllocator { public: - TupleDataAllocator(BufferManager &buffer_manager, shared_ptr &layout_ptr); + TupleDataAllocator(BufferManager &buffer_manager, shared_ptr layout_ptr, + shared_ptr stl_allocator); TupleDataAllocator(TupleDataAllocator &allocator); ~TupleDataAllocator(); @@ -62,6 +64,8 @@ class TupleDataAllocator { BufferManager &GetBufferManager(); //! Get the buffer allocator Allocator &GetAllocator(); + //! 
Get the STL allocator + ArenaAllocator &GetStlAllocator(); //! Get the layout shared_ptr GetLayoutPtr() const; const TupleDataLayout &GetLayout() const; @@ -99,17 +103,22 @@ class TupleDataAllocator { private: //! Builds out a single part (grabs the lock) - TupleDataChunkPart BuildChunkPart(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - const idx_t append_offset, const idx_t append_count, TupleDataChunk &chunk); + unsafe_arena_ptr BuildChunkPart(TupleDataSegment &segment, TupleDataPinState &pin_state, + TupleDataChunkState &chunk_state, const idx_t append_offset, + const idx_t append_count, TupleDataChunk &chunk); //! Internal function for InitializeChunkState void InitializeChunkStateInternal(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, idx_t offset, bool recompute, bool init_heap_pointers, bool init_heap_sizes, unsafe_vector> &parts); //! Internal function for ReleaseOrStoreHandles static void ReleaseOrStoreHandlesInternal(TupleDataSegment &segment, - unsafe_vector &pinned_row_handles, + unsafe_arena_vector &pinned_row_handles, buffer_handle_map_t &handles, const ContinuousIdSet &block_ids, - unsafe_vector &blocks, TupleDataPinProperties properties); + unsafe_arena_vector &blocks, + TupleDataPinProperties properties); + //! Create a row/heap block, extend the pinned handles in the segment accordingly + void CreateRowBlock(TupleDataSegment &segment); + void CreateHeapBlock(TupleDataSegment &segment, idx_t size); //! Pins the given row block BufferHandle &PinRowBlock(TupleDataPinState &state, const TupleDataChunkPart &part); //! Pins the given heap block @@ -120,6 +129,8 @@ class TupleDataAllocator { data_ptr_t GetBaseHeapPointer(TupleDataPinState &state, const TupleDataChunkPart &part); private: + //! Shared allocator for STL allocations + shared_ptr stl_allocator; //! The buffer manager BufferManager &buffer_manager; //! The layout of the data @@ -128,13 +139,9 @@ class TupleDataAllocator { //! Partition index (optional, if partitioned) optional_idx partition_index; //! Blocks storing the fixed-size rows - unsafe_vector row_blocks; + unsafe_arena_vector row_blocks; //! Blocks storing the variable-size data of the fixed-size rows (e.g., string, list) - unsafe_vector heap_blocks; - - //! Re-usable arrays used while building buffer space - unsafe_vector> chunk_parts; - unsafe_vector> chunk_part_indices; + unsafe_arena_vector heap_blocks; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp index d759341ee..1f98a3fe7 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp @@ -49,7 +49,10 @@ class TupleDataCollection { public: //! Constructs a TupleDataCollection with the specified layout - TupleDataCollection(BufferManager &buffer_manager, shared_ptr layout_ptr); + TupleDataCollection(BufferManager &buffer_manager, shared_ptr layout_ptr, + shared_ptr stl_allocator = nullptr); + TupleDataCollection(ClientContext &context, shared_ptr layout_ptr, + shared_ptr stl_allocator = nullptr); ~TupleDataCollection(); @@ -172,7 +175,7 @@ class TupleDataCollection { //! Initializes a chunk with the correct types that can be used to call Append/Scan for the given columns void InitializeChunk(DataChunk &chunk, const vector &columns) const; //! 
Initializes a chunk with the correct types for a given scan state - void InitializeScanChunk(TupleDataScanState &state, DataChunk &chunk) const; + void InitializeScanChunk(const TupleDataScanState &state, DataChunk &chunk) const; //! Initializes a Scan state for scanning all columns void InitializeScan(TupleDataScanState &state, TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; @@ -185,8 +188,8 @@ class TupleDataCollection { //! Initialize a parallel scan over the tuple data collection over a subset of the columns void InitializeScan(TupleDataParallelScanState &gstate, vector column_ids, TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; - //! Grab the chunk state for the given segment and chunk index, returns the count of the chunk - idx_t FetchChunk(TupleDataScanState &state, idx_t segment_idx, idx_t chunk_idx, bool init_heap); + //! Grab the chunk state for the given chunk index, returns the count of the chunk + idx_t FetchChunk(TupleDataScanState &state, idx_t chunk_idx, bool init_heap); //! Scans a DataChunk from the TupleDataCollection bool Scan(TupleDataScanState &state, DataChunk &result); //! Scans a DataChunk from the TupleDataCollection @@ -221,7 +224,7 @@ class TupleDataCollection { //! Gets all column ids void GetAllColumnIDs(vector &column_ids); //! Adds a segment to this TupleDataCollection - void AddSegment(unsafe_unique_ptr segment); + void AddSegment(unsafe_arena_ptr segment); //! Computes the heap sizes for the specific Vector that will be appended static void ComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, TupleDataVectorFormat &source, @@ -262,6 +265,8 @@ class TupleDataCollection { void Verify() const; private: + //! Shared allocator for STL allocations + shared_ptr stl_allocator; //! The layout of the TupleDataCollection shared_ptr layout_ptr; const TupleDataLayout &layout; @@ -272,11 +277,11 @@ class TupleDataCollection { //! The size (in bytes) of this TupleDataCollection idx_t data_size; //! The data segments of the TupleDataCollection - unsafe_vector> segments; + unsafe_arena_vector> segments; //! The set of scatter functions - vector scatter_functions; + unsafe_arena_vector scatter_functions; //! The set of gather functions - vector gather_functions; + unsafe_arena_vector gather_functions; //! Partition index (optional, if partitioned) optional_idx partition_index; }; diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_segment.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_segment.hpp index 22afdb156..93558050a 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_segment.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_segment.hpp @@ -14,6 +14,7 @@ #include "duckdb/common/unordered_set.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/common/arena_containers/arena_vector.hpp" namespace duckdb { @@ -49,7 +50,7 @@ struct TupleDataChunkPart { idx_t total_heap_size; //! Tuple count for this chunk part uint32_t count; - //! Lock for recomputing heap pointers (owned by TupleDataChunk) + //! Lock for recomputing heap pointers reference lock; private: @@ -113,7 +114,7 @@ class ContinuousIdSet { struct TupleDataChunk { public: - TupleDataChunk(); + explicit TupleDataChunk(mutex &lock_p); //! 
Disable copy constructors TupleDataChunk(const TupleDataChunk &other) = delete; @@ -124,7 +125,7 @@ struct TupleDataChunk { TupleDataChunk &operator=(TupleDataChunk &&) noexcept; //! Add a part to this chunk - TupleDataChunkPart &AddPart(TupleDataSegment &segment, TupleDataChunkPart &&part); + TupleDataChunkPart &AddPart(TupleDataSegment &segment, unsafe_arena_ptr part_ptr); //! Tries to merge the last chunk part into the second-to-last one void MergeLastChunkPart(TupleDataSegment &segment); //! Verify counts of the parts in this chunk @@ -141,7 +142,7 @@ struct TupleDataChunk { //! Tuple count for this chunk idx_t count; //! Lock for recomputing heap pointers - unsafe_unique_ptr lock; + reference lock; }; struct TupleDataSegment { @@ -171,9 +172,9 @@ struct TupleDataSegment { shared_ptr allocator; const TupleDataLayout &layout; //! The chunks of this segment - unsafe_vector chunks; + unsafe_vector> chunks; //! The chunk parts of this segment - unsafe_vector chunk_parts; + unsafe_vector> chunk_parts; //! The tuple count of this segment idx_t count; //! The data size of this segment @@ -182,9 +183,9 @@ struct TupleDataSegment { //! Lock for modifying pinned_handles mutex pinned_handles_lock; //! Where handles to row blocks will be stored with TupleDataPinProperties::KEEP_EVERYTHING_PINNED - unsafe_vector pinned_row_handles; + unsafe_arena_vector pinned_row_handles; //! Where handles to heap blocks will be stored with TupleDataPinProperties::KEEP_EVERYTHING_PINNED - unsafe_vector pinned_heap_handles; + unsafe_arena_vector pinned_heap_handles; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_states.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_states.hpp index bf22cac33..188438527 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_states.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_states.hpp @@ -124,8 +124,9 @@ struct TupleDataChunkState { vector> cached_cast_vectors; vector> cached_cast_vector_cache; - //! Cached vector (for InitializeChunkState) - unsafe_vector> parts; + //! Re-usable arrays used while building buffer space + unsafe_vector> chunk_parts; + unsafe_vector> chunk_part_indices; }; struct TupleDataAppendState { diff --git a/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp b/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp index ceb5637ac..5575e5a08 100644 --- a/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp +++ b/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp @@ -108,6 +108,7 @@ struct SelectionVector { return selection_data; } buffer_ptr Slice(const SelectionVector &sel, idx_t count) const; + idx_t SliceInPlace(const SelectionVector &sel, idx_t count); string ToString(idx_t count = 0) const; void Print(idx_t count = 0) const; diff --git a/src/duckdb/src/include/duckdb/common/types/value.hpp b/src/duckdb/src/include/duckdb/common/types/value.hpp index 1993d0295..bba9a7297 100644 --- a/src/duckdb/src/include/duckdb/common/types/value.hpp +++ b/src/duckdb/src/include/duckdb/common/types/value.hpp @@ -201,6 +201,8 @@ class Value { DUCKDB_API static Value BIGNUM(const_data_ptr_t data, idx_t len); DUCKDB_API static Value BIGNUM(const string &data); + DUCKDB_API static Value GEOMETRY(const_data_ptr_t data, idx_t len); + //! 
Creates an aggregate state DUCKDB_API static Value AGGREGATE_STATE(const LogicalType &type, const_data_ptr_t data, idx_t len); // NOLINT diff --git a/src/duckdb/src/include/duckdb/common/types/variant.hpp b/src/duckdb/src/include/duckdb/common/types/variant.hpp index cc8a9ffa6..bef2f2353 100644 --- a/src/duckdb/src/include/duckdb/common/types/variant.hpp +++ b/src/duckdb/src/include/duckdb/common/types/variant.hpp @@ -29,8 +29,18 @@ struct VariantNestedData { }; struct VariantDecimalData { +public: + VariantDecimalData(uint32_t width, uint32_t scale, const_data_ptr_t value_ptr) + : width(width), scale(scale), value_ptr(value_ptr) { + } + +public: + PhysicalType GetPhysicalType() const; + +public: uint32_t width; uint32_t scale; + const_data_ptr_t value_ptr = nullptr; }; struct VariantVectorData { @@ -105,6 +115,7 @@ enum class VariantLogicalType : uint8_t { ARRAY = 30, BIGNUM = 31, BITSTRING = 32, + GEOMETRY = 33, ENUM_SIZE /* always kept as last item of the enum */ }; diff --git a/src/duckdb/src/include/duckdb/common/types/variant_visitor.hpp b/src/duckdb/src/include/duckdb/common/types/variant_visitor.hpp new file mode 100644 index 000000000..950980aef --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/variant_visitor.hpp @@ -0,0 +1,232 @@ +#pragma once + +#include "duckdb/common/types/variant.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/enum_util.hpp" + +#include + +namespace duckdb { + +template +class VariantVisitor { + // Detects if T has a static VisitMetadata with signature + // void VisitMetadata(VariantLogicalType, Args...) + template + class has_visit_metadata { + private: + template + static auto test(int) -> decltype(U::VisitMetadata(std::declval(), std::declval()...), + std::true_type {}); + + template + static std::false_type test(...); + + public: + static constexpr bool value = decltype(test(0))::value; + }; + +public: + template + static ReturnType Visit(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx, Args &&...args) { + if (!variant.RowIsValid(row)) { + return Visitor::VisitNull(std::forward(args)...); + } + + auto type_id = variant.GetTypeId(row, values_idx); + auto byte_offset = variant.GetByteOffset(row, values_idx); + auto blob_data = const_data_ptr_cast(variant.GetData(row).GetData()); + auto ptr = const_data_ptr_cast(blob_data + byte_offset); + + VisitMetadata(type_id, std::forward(args)...); + + switch (type_id) { + case VariantLogicalType::VARIANT_NULL: + return Visitor::VisitNull(std::forward(args)...); + case VariantLogicalType::BOOL_TRUE: + return Visitor::VisitBoolean(true, std::forward(args)...); + case VariantLogicalType::BOOL_FALSE: + return Visitor::VisitBoolean(false, std::forward(args)...); + case VariantLogicalType::INT8: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::INT16: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::INT32: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::INT64: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::INT128: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::UINT8: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case 
VariantLogicalType::UINT16: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::UINT32: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::UINT64: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::UINT128: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::FLOAT: + return Visitor::VisitFloat(Load(ptr), std::forward(args)...); + case VariantLogicalType::DOUBLE: + return Visitor::VisitDouble(Load(ptr), std::forward(args)...); + case VariantLogicalType::UUID: + return Visitor::VisitUUID(Load(ptr), std::forward(args)...); + case VariantLogicalType::DATE: + return Visitor::VisitDate(date_t(Load(ptr)), std::forward(args)...); + case VariantLogicalType::INTERVAL: + return Visitor::VisitInterval(Load(ptr), std::forward(args)...); + case VariantLogicalType::VARCHAR: + case VariantLogicalType::BLOB: + case VariantLogicalType::BITSTRING: + case VariantLogicalType::BIGNUM: + case VariantLogicalType::GEOMETRY: + return VisitString(type_id, variant, row, values_idx, std::forward(args)...); + case VariantLogicalType::DECIMAL: + return VisitDecimal(variant, row, values_idx, std::forward(args)...); + case VariantLogicalType::ARRAY: + return VisitArray(variant, row, values_idx, std::forward(args)...); + case VariantLogicalType::OBJECT: + return VisitObject(variant, row, values_idx, std::forward(args)...); + case VariantLogicalType::TIME_MICROS: + return Visitor::VisitTime(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIME_NANOS: + return Visitor::VisitTimeNanos(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIME_MICROS_TZ: + return Visitor::VisitTimeTZ(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIMESTAMP_SEC: + return Visitor::VisitTimestampSec(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIMESTAMP_MILIS: + return Visitor::VisitTimestampMs(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIMESTAMP_MICROS: + return Visitor::VisitTimestamp(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIMESTAMP_NANOS: + return Visitor::VisitTimestampNanos(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIMESTAMP_MICROS_TZ: + return Visitor::VisitTimestampTZ(Load(ptr), std::forward(args)...); + default: + return Visitor::VisitDefault(type_id, ptr, std::forward(args)...); + } + } + + // Non-void version + template + static typename std::enable_if::value, vector>::type + VisitArrayItems(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &array_data, + Args &&...args) { + vector array_items; + array_items.reserve(array_data.child_count); + for (idx_t i = 0; i < array_data.child_count; i++) { + auto values_index = variant.GetValuesIndex(row, array_data.children_idx + i); + array_items.emplace_back(Visit(variant, row, values_index, std::forward(args)...)); + } + return array_items; + } + + // Void version + template + static typename std::enable_if::value, void>::type + VisitArrayItems(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &array_data, + Args &&...args) { + for (idx_t i = 0; i < array_data.child_count; i++) { + auto values_index = variant.GetValuesIndex(row, array_data.children_idx + i); + Visit(variant, row, values_index, std::forward(args)...); + } + } + + template + static child_list_t VisitObjectItems(const UnifiedVariantVectorData &variant, idx_t row, + 
const VariantNestedData &object_data, Args &&...args) { + child_list_t object_items; + for (idx_t i = 0; i < object_data.child_count; i++) { + auto values_index = variant.GetValuesIndex(row, object_data.children_idx + i); + auto val = Visit(variant, row, values_index, std::forward(args)...); + + auto keys_index = variant.GetKeysIndex(row, object_data.children_idx + i); + auto &key = variant.GetKey(row, keys_index); + + object_items.emplace_back(key.GetString(), std::move(val)); + } + return object_items; + } + +private: + template + static typename std::enable_if::value, void>::type + VisitMetadata(VariantLogicalType type_id, Args &&...args) { + Visitor::VisitMetadata(type_id, std::forward(args)...); + } + + // Fallback if the method does not exist + template + static typename std::enable_if::value, void>::type VisitMetadata(VariantLogicalType, + Args &&...) { + // do nothing + } + + template + static ReturnType VisitArray(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx, + Args &&...args) { + auto decoded_nested_data = VariantUtils::DecodeNestedData(variant, row, values_idx); + return Visitor::VisitArray(variant, row, decoded_nested_data, std::forward(args)...); + } + + template + static ReturnType VisitObject(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx, + Args &&...args) { + auto decoded_nested_data = VariantUtils::DecodeNestedData(variant, row, values_idx); + return Visitor::VisitObject(variant, row, decoded_nested_data, std::forward(args)...); + } + + template + static ReturnType VisitString(VariantLogicalType type_id, const UnifiedVariantVectorData &variant, idx_t row, + uint32_t values_idx, Args &&...args) { + auto decoded_string = VariantUtils::DecodeStringData(variant, row, values_idx); + if (type_id == VariantLogicalType::VARCHAR) { + return Visitor::VisitString(decoded_string, std::forward(args)...); + } + if (type_id == VariantLogicalType::BLOB) { + return Visitor::VisitBlob(decoded_string, std::forward(args)...); + } + if (type_id == VariantLogicalType::BIGNUM) { + return Visitor::VisitBignum(decoded_string, std::forward(args)...); + } + if (type_id == VariantLogicalType::GEOMETRY) { + return Visitor::VisitGeometry(decoded_string, std::forward(args)...); + } + if (type_id == VariantLogicalType::BITSTRING) { + return Visitor::VisitBitstring(decoded_string, std::forward(args)...); + } + throw InternalException("String-backed variant type (%s) not handled", EnumUtil::ToString(type_id)); + } + + template + static ReturnType VisitDecimal(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx, + Args &&...args) { + auto decoded_decimal = VariantUtils::DecodeDecimalData(variant, row, values_idx); + auto &width = decoded_decimal.width; + auto &scale = decoded_decimal.scale; + auto &ptr = decoded_decimal.value_ptr; + if (width > DecimalWidth::max) { + throw InternalException("Can't handle decimal of width: %d", width); + } else if (width > DecimalWidth::max) { + return Visitor::template VisitDecimal(Load(ptr), width, scale, + std::forward(args)...); + } else if (width > DecimalWidth::max) { + return Visitor::template VisitDecimal(Load(ptr), width, scale, + std::forward(args)...); + } else if (width > DecimalWidth::max) { + return Visitor::template VisitDecimal(Load(ptr), width, scale, + std::forward(args)...); + } else { + return Visitor::template VisitDecimal(Load(ptr), width, scale, + std::forward(args)...); + } + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/vector.hpp 
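The VariantVisitor above dispatches on VariantLogicalType and forwards every value to a static hook on the Visitor class; a visitor additionally opts into per-value metadata callbacks by defining a static VisitMetadata, which the has_visit_metadata SFINAE check detects. Below is a hedged sketch of what such a visitor might look like; it is illustrative only and not part of this patch. The template parameter list of VariantVisitor is elided in the flattened header above, so the sketch assumes it is the visitor class followed by the return type, and the argument types of the string/nested hooks are inferred from the dispatch code.

// Hypothetical sketch -- not part of this patch.
// Classifies a single VARIANT value into a coarse category string.
#include "duckdb/common/types/variant_visitor.hpp"

namespace duckdb {

struct VariantKindVisitor {
	static string VisitNull() {
		return "null";
	}
	static string VisitBoolean(bool) {
		return "boolean";
	}
	template <class T>
	static string VisitInteger(T) {
		return "integer";
	}
	static string VisitFloat(float) {
		return "float";
	}
	static string VisitDouble(double) {
		return "double";
	}
	static string VisitString(const string_t &) {
		return "string";
	}
	static string VisitArray(const UnifiedVariantVectorData &, idx_t, const VariantNestedData &) {
		// a real visitor could recurse here via VariantVisitor<...>::VisitArrayItems
		return "array";
	}
	static string VisitObject(const UnifiedVariantVectorData &, idx_t, const VariantNestedData &) {
		// ... or via VariantVisitor<...>::VisitObjectItems for key/value pairs
		return "object";
	}
	// Optional: detected through has_visit_metadata and invoked before every value.
	static void VisitMetadata(VariantLogicalType) {
	}
	// A complete visitor also provides the remaining hooks used by the dispatch above
	// (VisitUUID, VisitDate, VisitInterval, VisitBlob, VisitBignum, VisitGeometry, VisitBitstring,
	// VisitDecimal<T>, the time/timestamp hooks and VisitDefault), all with the same shape.
};

// Usage, assuming 'variant' is a UnifiedVariantVectorData and 'row'/'values_idx' are valid:
// auto kind = VariantVisitor<VariantKindVisitor, string>::Visit(variant, row, values_idx);

} // namespace duckdb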
b/src/duckdb/src/include/duckdb/common/types/vector.hpp index 1ab48c056..890118013 100644 --- a/src/duckdb/src/include/duckdb/common/types/vector.hpp +++ b/src/duckdb/src/include/duckdb/common/types/vector.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/bitset.hpp" #include "duckdb/common/common.hpp" #include "duckdb/common/enums/vector_type.hpp" +#include "duckdb/common/mutex.hpp" #include "duckdb/common/types/selection_vector.hpp" #include "duckdb/common/types/validity_mask.hpp" #include "duckdb/common/types/value.hpp" @@ -21,6 +22,7 @@ namespace duckdb { class VectorCache; +class VectorChildBuffer; class VectorStringBuffer; class VectorStructBuffer; class VectorListBuffer; @@ -195,6 +197,8 @@ class Vector { DUCKDB_API void Dictionary(idx_t dictionary_size, const SelectionVector &sel, idx_t count); //! Creates a reference to a dictionary of the other vector DUCKDB_API void Dictionary(Vector &dict, idx_t dictionary_size, const SelectionVector &sel, idx_t count); + //! Creates a dictionary on the reusable dict + DUCKDB_API void Dictionary(buffer_ptr reusable_dict, const SelectionVector &sel); //! Creates the data of this vector with the specified type. Any data that //! is currently in the vector is destroyed. @@ -306,20 +310,24 @@ class Vector { //! The buffer holding auxiliary data of the vector //! e.g. a string vector uses this to store strings buffer_ptr auxiliary; - //! The buffer holding precomputed hashes of the data in the vector - //! used for caching hashes of string dictionaries - buffer_ptr cached_hashes; }; -//! The DictionaryBuffer holds a selection vector +//! The VectorChildBuffer holds a child Vector class VectorChildBuffer : public VectorBuffer { public: explicit VectorChildBuffer(Vector vector) - : VectorBuffer(VectorBufferType::VECTOR_CHILD_BUFFER), data(std::move(vector)) { + : VectorBuffer(VectorBufferType::VECTOR_CHILD_BUFFER), data(std::move(vector)), + cached_hashes(LogicalType::HASH, nullptr) { } public: Vector data; + //! Optional size/id to uniquely identify re-occurring dictionaries + optional_idx size; + string id; + //! 
For caching the hashes of a child buffer + mutex cached_hashes_lock; + Vector cached_hashes; }; struct ConstantVector { @@ -409,22 +417,27 @@ struct DictionaryVector { } static inline optional_idx DictionarySize(const Vector &vector) { VerifyDictionary(vector); + const auto &child_buffer = vector.auxiliary->Cast(); + if (child_buffer.size.IsValid()) { + return child_buffer.size; + } return vector.buffer->Cast().GetDictionarySize(); } static inline const string &DictionaryId(const Vector &vector) { VerifyDictionary(vector); + const auto &child_buffer = vector.auxiliary->Cast(); + if (!child_buffer.id.empty()) { + return child_buffer.id; + } return vector.buffer->Cast().GetDictionaryId(); } - static inline void SetDictionaryId(Vector &vector, string new_id) { - VerifyDictionary(vector); - vector.buffer->Cast().SetDictionaryId(std::move(new_id)); - } static inline bool CanCacheHashes(const LogicalType &type) { return type.InternalType() == PhysicalType::VARCHAR; } static inline bool CanCacheHashes(const Vector &vector) { return DictionarySize(vector).IsValid() && CanCacheHashes(vector.GetType()); } + static buffer_ptr CreateReusableDictionary(const LogicalType &type, const idx_t &size); static const Vector &GetCachedHashes(Vector &input); }; @@ -488,6 +501,13 @@ struct FlatVector { }; struct ListVector { + static inline const list_entry_t *GetData(const Vector &v) { + if (v.GetVectorType() == VectorType::DICTIONARY_VECTOR) { + auto &child = DictionaryVector::Child(v); + return GetData(child); + } + return FlatVector::GetData(v); + } static inline list_entry_t *GetData(Vector &v) { if (v.GetVectorType() == VectorType::DICTIONARY_VECTOR) { auto &child = DictionaryVector::Child(v); diff --git a/src/duckdb/src/include/duckdb/common/vector.hpp b/src/duckdb/src/include/duckdb/common/vector.hpp index 676adac20..3035edcd5 100644 --- a/src/duckdb/src/include/duckdb/common/vector.hpp +++ b/src/duckdb/src/include/duckdb/common/vector.hpp @@ -17,14 +17,23 @@ namespace duckdb { -template -class vector : public std::vector> { // NOLINT: matching name of std +template > +class vector : public std::vector { // NOLINT: matching name of std public: - using original = std::vector>; + using original = std::vector; using original::original; + using value_type = typename original::value_type; + using allocator_type = typename original::allocator_type; using size_type = typename original::size_type; - using const_reference = typename original::const_reference; + using difference_type = typename original::difference_type; using reference = typename original::reference; + using const_reference = typename original::const_reference; + using pointer = typename original::pointer; + using const_pointer = typename original::const_pointer; + using iterator = typename original::iterator; + using const_iterator = typename original::const_iterator; + using reverse_iterator = typename original::reverse_iterator; + using const_reverse_iterator = typename original::const_reverse_iterator; private: static inline void AssertIndexInBounds(idx_t index, idx_t size) { diff --git a/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp b/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp index 6a0f0346a..07c49b541 100644 --- a/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp +++ b/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/file_system.hpp" #include "duckdb/common/map.hpp" #include "duckdb/common/unordered_set.hpp" +#include 
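The VectorChildBuffer and DictionaryVector changes above allow a dictionary, together with its size, id and cached hashes, to be shared and re-used across vectors instead of being rebuilt per chunk. A hedged sketch of that usage follows; it is illustrative only and not part of this patch, the template arguments of buffer_ptr are elided in the flattened header, and CreateReusableDictionary returning a buffer_ptr to a VectorChildBuffer as well as the SetValue-based fill are assumptions.

// Hypothetical sketch -- not part of this patch.
#include "duckdb/common/types/vector.hpp"

namespace duckdb {

// Builds one shared dictionary and points a result vector at it through the new overloads.
static void UseReusableDictionary(Vector &result) {
	auto dict = DictionaryVector::CreateReusableDictionary(LogicalType::VARCHAR, 3);
	dict->data.SetValue(0, Value("low"));
	dict->data.SetValue(1, Value("medium"));
	dict->data.SetValue(2, Value("high"));

	SelectionVector sel(STANDARD_VECTOR_SIZE);
	for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) {
		sel.set_index(i, i % 3); // every row maps to one of the three dictionary entries
	}

	// 'result' now references the shared dictionary through its VectorChildBuffer, so
	// DictionarySize/DictionaryId (and the cached hashes guarded by cached_hashes_lock)
	// can be reused across chunks.
	result.Dictionary(dict, sel);
}

} // namespace duckdb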
"duckdb/main/extension_helper.hpp" namespace duckdb { @@ -82,8 +83,10 @@ class VirtualFileSystem : public FileSystem { } private: + FileSystem &FindFileSystem(const string &path, optional_ptr file_opener); + FileSystem &FindFileSystem(const string &path, optional_ptr database_instance); FileSystem &FindFileSystem(const string &path); - FileSystem &FindFileSystemInternal(const string &path); + optional_ptr FindFileSystemInternal(const string &path); private: vector> sub_systems; diff --git a/src/duckdb/src/include/duckdb/execution/expression_executor_state.hpp b/src/duckdb/src/include/duckdb/execution/expression_executor_state.hpp index 8f0b77ccf..112e76109 100644 --- a/src/duckdb/src/include/duckdb/execution/expression_executor_state.hpp +++ b/src/duckdb/src/include/duckdb/execution/expression_executor_state.hpp @@ -75,12 +75,10 @@ struct ExecuteFunctionState : public ExpressionState { //! Only valid when the expression is eligible for the dictionary expression optimization //! This is the case when the input is "practically unary", i.e., only one non-const input column optional_idx input_col_idx; - //! Storage ID of the input dictionary vector - string current_input_dictionary_id; //! Vector holding the expression executed on the entire dictionary - unique_ptr output_dictionary; - //! ID of the output dictionary_vector - string output_dictionary_id; + buffer_ptr output_dictionary; + //! ID of the input dictionary Vector + string current_input_dictionary_id; }; struct ExpressionExecutorState { diff --git a/src/duckdb/src/include/duckdb/execution/index/art/art.hpp b/src/duckdb/src/include/duckdb/execution/index/art/art.hpp index 71e64cfe7..2c4706174 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/art.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/art.hpp @@ -116,13 +116,16 @@ class ART : public BoundIndex { void GenerateKeyVectors(ArenaAllocator &allocator, DataChunk &input, Vector &row_ids, unsafe_vector &keys, unsafe_vector &row_id_keys); - //! Verifies the nodes and optionally returns a string of the ART. - string VerifyAndToString(IndexLock &l, const bool only_verify) override; + //! Verifies the nodes. + void Verify(IndexLock &l) override; //! Verifies that the node allocations match the node counts. void VerifyAllocations(IndexLock &l) override; //! Verifies the index buffers. void VerifyBuffers(IndexLock &l) override; + //! Returns string representation of the ART. + string ToString(IndexLock &l, bool display_ascii = false) override; + private: bool SearchEqual(ARTKey &key, idx_t max_count, set &row_ids); bool SearchGreater(ARTKey &key, bool equal, idx_t max_count, set &row_ids); @@ -151,7 +154,8 @@ class ART : public BoundIndex { void WritePartialBlocks(QueryContext context, const bool v1_0_0_storage); void SetPrefixCount(const IndexStorageInfo &info); - string VerifyAndToStringInternal(const bool only_verify); + string ToStringInternal(bool display_ascii); + void VerifyInternal(); void VerifyAllocationsInternal(); }; diff --git a/src/duckdb/src/include/duckdb/execution/index/art/art_operator.hpp b/src/duckdb/src/include/duckdb/execution/index/art/art_operator.hpp index 62903b198..e1309546b 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/art_operator.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/art_operator.hpp @@ -62,6 +62,60 @@ class ARTOperator { return nullptr; } + //! LookupInLeaf returns true if the rowid is in the leaf: + //! 1) If the leaf is an inlined leaf, check if the rowid matches. + //! 
2) If the leaf is a gate node, perform a search in the nested ART for the rowid. + static bool LookupInLeaf(ART &art, const Node &node, const ARTKey &rowid) { + reference ref(node); + idx_t depth = 0; + + while (ref.get().HasMetadata()) { + const auto type = ref.get().GetType(); + switch (type) { + case NType::LEAF_INLINED: { + return ref.get().GetRowId() == rowid.GetRowId(); + } + case NType::LEAF: { + throw InternalException("Invalid node type (LEAF) for ARTOperator::NestedLookup."); + } + case NType::NODE_7_LEAF: + case NType::NODE_15_LEAF: + case NType::NODE_256_LEAF: { + D_ASSERT(depth + 1 == Prefix::ROW_ID_SIZE); + const auto byte = rowid[Prefix::ROW_ID_COUNT]; + return ref.get().HasByte(art, byte); + } + case NType::NODE_4: + case NType::NODE_16: + case NType::NODE_48: + case NType::NODE_256: { + D_ASSERT(depth < Prefix::ROW_ID_SIZE); + auto child = ref.get().GetChild(art, rowid[depth]); + if (child) { + // Continue in the child. + ref = *child; + depth++; + D_ASSERT(ref.get().HasMetadata()); + continue; + } + return false; + } + case NType::PREFIX: { + Prefix prefix(art, ref.get()); + for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { + if (prefix.data[i] != rowid[depth]) { + // The key and the prefix don't match. + return false; + } + depth++; + } + ref = *prefix.ptr; + } + } + } + return false; + } + //! Insert a key and its row ID into the node. //! Starts at depth (in the key). //! status indicates if the insert happens inside a gate or not. @@ -336,7 +390,6 @@ class ARTOperator { static void InsertIntoPrefix(ART &art, reference &node_ref, const ARTKey &key, const ARTKey &row_id, const idx_t pos, const idx_t depth, const GateStatus status) { - const auto cast_pos = UnsafeNumericCast(pos); const auto byte = Prefix::GetByte(art, node_ref, cast_pos); diff --git a/src/duckdb/src/include/duckdb/execution/index/art/base_leaf.hpp b/src/duckdb/src/include/duckdb/execution/index/art/base_leaf.hpp index 209d022dc..797c18469 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/base_leaf.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/base_leaf.hpp @@ -31,13 +31,15 @@ class BaseLeaf { public: //! Get a new BaseLeaf and initialize it. - static BaseLeaf &New(ART &art, Node &node) { + static NodeHandle New(ART &art, Node &node) { node = Node::GetAllocator(art, TYPE).New(); node.SetMetadata(static_cast(TYPE)); - auto &n = Node::Ref(art, node, TYPE); + NodeHandle handle(art, node); + auto &n = handle.Get(); + n.count = 0; - return n; + return handle; } //! Returns true, if the byte exists, else false. @@ -70,7 +72,7 @@ class BaseLeaf { private: static void InsertByteInternal(BaseLeaf &n, const uint8_t byte); - static BaseLeaf &DeleteByteInternal(ART &art, Node &node, const uint8_t byte); + static NodeHandle DeleteByteInternal(ART &art, Node &node, const uint8_t byte); }; //! Node7Leaf holds up to seven sorted bytes. diff --git a/src/duckdb/src/include/duckdb/execution/index/art/iterator.hpp b/src/duckdb/src/include/duckdb/execution/index/art/iterator.hpp index c5907f820..793dcf40b 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/iterator.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/iterator.hpp @@ -70,7 +70,7 @@ class Iterator { void FindMinimum(const Node &node); //! Finds the lower bound of the ART and adds the nodes to the stack. Returns false, if the lower //! bound exceeds the maximum value of the ART. 
- bool LowerBound(const Node &node, const ARTKey &key, const bool equal, idx_t depth); + bool LowerBound(const Node &node, const ARTKey &key, const bool equal); //! Returns the nested depth. uint8_t GetNestedDepth() const { diff --git a/src/duckdb/src/include/duckdb/execution/index/art/leaf.hpp b/src/duckdb/src/include/duckdb/execution/index/art/leaf.hpp index 30efdba0a..42826d241 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/leaf.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/leaf.hpp @@ -57,11 +57,14 @@ class Leaf { static bool DeprecatedGetRowIds(ART &art, const Node &node, set &row_ids, const idx_t max_count); //! Vacuums the linked list of leaves. static void DeprecatedVacuum(ART &art, Node &node); - //! Returns the string representation of the linked list of leaves, if only_verify is true. - //! Else, it traverses and verifies the linked list of leaves. - static string DeprecatedVerifyAndToString(ART &art, const Node &node, const bool only_verify); + + //! Traverses and verifies the linked list of leaves. + static void DeprecatedVerify(ART &art, const Node &node); //! Count the number of leaves. void DeprecatedVerifyAllocations(ART &art, unordered_map &node_counts) const; + + //! Return string representation of the linked list of leaves. + static string DeprecatedToString(ART &art, const Node &node); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/art/node.hpp b/src/duckdb/src/include/duckdb/execution/index/art/node.hpp index ae00d3e6d..4964e1119 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/node.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/node.hpp @@ -94,9 +94,8 @@ class Node : public IndexPointer { //! Get the first byte greater than or equal to the byte. bool GetNextByte(ART &art, uint8_t &byte) const; - //! Returns the string representation of the node, if only_verify is false. - //! Else, it traverses and verifies the node. - string VerifyAndToString(ART &art, const bool only_verify) const; + //! Traverses and verifies the node. + void Verify(ART &art) const; //! Counts each node type. void VerifyAllocations(ART &art, unordered_map &node_counts) const; @@ -107,6 +106,9 @@ class Node : public IndexPointer { static void TransformToDeprecated(ART &art, Node &node, unsafe_unique_ptr &deprecated_prefix_allocator); + //! Returns the string representation of the node at indentation level. + string ToString(ART &art, idx_t indent_level, bool inside_gate = false, bool display_ascii = false) const; + //! Returns the node type. inline NType GetType() const { return NType(GetMetadata() & ~AND_GATE); diff --git a/src/duckdb/src/include/duckdb/execution/index/art/prefix.hpp b/src/duckdb/src/include/duckdb/execution/index/art/prefix.hpp index 835e32c0f..4709ebfc2 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/prefix.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/prefix.hpp @@ -61,11 +61,15 @@ class Prefix { //! after its creation. static GateStatus Split(ART &art, reference &node, Node &child, const uint8_t pos); - //! Returns the string representation of the node, or only traverses and verifies the node and its subtree - static string VerifyAndToString(ART &art, const Node &node, const bool only_verify); + //! Traverses and verifies the node and its subtree + static void Verify(ART &art, const Node &node); //! Transform the child of the node. static void TransformToDeprecated(ART &art, Node &node, unsafe_unique_ptr &allocator); + //! 
Returns the string representation of the node at indentation level. + static string ToString(ART &art, const Node &node, idx_t indent_level, bool inside_gate = false, + bool display_ascii = false); + private: static Prefix NewInternal(ART &art, Node &node, const data_ptr_t data, const uint8_t count, const idx_t offset); diff --git a/src/duckdb/src/include/duckdb/execution/index/bound_index.hpp b/src/duckdb/src/include/duckdb/execution/index/bound_index.hpp index 914288bfa..e6c21062f 100644 --- a/src/duckdb/src/include/duckdb/execution/index/bound_index.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/bound_index.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/execution/index/unbound_index.hpp" #include "duckdb/common/enums/index_constraint_type.hpp" #include "duckdb/common/types/constraint_conflict_info.hpp" #include "duckdb/common/types/data_chunk.hpp" @@ -60,6 +61,16 @@ class BoundIndex : public Index { //! The index constraint type IndexConstraintType index_constraint_type; + //! The vector of unbound expressions, which are later turned into bound expressions. + //! We need to store the unbound expressions, as we might not always have the context + //! available to bind directly. + //! The leaves of these unbound expressions are BoundColumnRefExpressions. + //! These BoundColumnRefExpressions contain a binding (ColumnBinding), + //! and that contains a table_index and a column_index. + //! The table_index is a dummy placeholder. + //! The column_index indexes the column_ids vector in the Index base class. + //! Those column_ids store the physical table indexes of the Index, + //! and we use them when binding the unbound expressions. vector> unbound_expressions; public: @@ -125,9 +136,14 @@ class BoundIndex : public Index { idx_t GetInMemorySize(); //! Returns the string representation of an index, or only traverses and verifies the index. - virtual string VerifyAndToString(IndexLock &l, const bool only_verify) = 0; + virtual void Verify(IndexLock &l) = 0; //! Obtains a lock and calls VerifyAndToString. - string VerifyAndToString(const bool only_verify); + void Verify(); + + //! Returns the string representation of an index. + virtual string ToString(IndexLock &l, bool display_ascii = false) = 0; + //! Obtains a lock and calls ToString. + string ToString(bool display_ascii = false); //! Ensures that the node allocation counts match the node counts. virtual void VerifyAllocations(IndexLock &l) = 0; @@ -155,14 +171,22 @@ class BoundIndex : public Index { virtual string GetConstraintViolationMessage(VerifyExistenceType verify_type, idx_t failed_index, DataChunk &input) = 0; - void ApplyBufferedAppends(const vector &table_types, ColumnDataCollection &buffered_appends, + //! Replay index insert and delete operations buffered during WAL replay. + //! table_types has the physical types of the table in the order they appear, not logical (no generated columns). + //! mapped_column_ids contains the sorted order of Indexed physical column ID's (see unbound_index.hpp comments). + void ApplyBufferedReplays(const vector &table_types, vector &buffered_replays, const vector &mapped_column_ids); protected: //! Lock used for any changes to the index mutex lock; - //! Bound expressions used during expression execution + //! The vector of bound expressions to generate the Index keys based on a data chunk. + //! The leaves of the bound expressions are BoundReferenceExpressions. + //! These BoundReferenceExpressions contain offsets into the DataChunk to retrieve the columns + //! for the expression. 
+ //! With these offsets into the DataChunk, the expression executor can now evaluate the expression + //! on incoming data chunks to generate the keys. vector> bound_expressions; private: diff --git a/src/duckdb/src/include/duckdb/execution/index/fixed_size_allocator.hpp b/src/duckdb/src/include/duckdb/execution/index/fixed_size_allocator.hpp index 691a4aac6..65ffd167f 100644 --- a/src/duckdb/src/include/duckdb/execution/index/fixed_size_allocator.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/fixed_size_allocator.hpp @@ -30,7 +30,8 @@ class FixedSizeAllocator { public: //! Construct a new fixed-size allocator - FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager); + FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager, + MemoryTag memory_tag = MemoryTag::ART_INDEX); //! Block manager of the database instance BlockManager &block_manager; @@ -152,6 +153,8 @@ class FixedSizeAllocator { void VerifyBuffers(); private: + //! Memory tag of memory that is allocated through the allocator + MemoryTag memory_tag; //! Allocation size of one segment in a buffer //! We only need this value to calculate bitmask_count, bitmask_offset, and //! available_segments_per_buffer diff --git a/src/duckdb/src/include/duckdb/execution/index/fixed_size_buffer.hpp b/src/duckdb/src/include/duckdb/execution/index/fixed_size_buffer.hpp index e7c5b6aa9..6ca7dc1aa 100644 --- a/src/duckdb/src/include/duckdb/execution/index/fixed_size_buffer.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/fixed_size_buffer.hpp @@ -43,7 +43,7 @@ class FixedSizeBuffer { public: //! Constructor for a new in-memory buffer - explicit FixedSizeBuffer(BlockManager &block_manager); + explicit FixedSizeBuffer(BlockManager &block_manager, MemoryTag memory_tag); //! Constructor for deserializing buffer metadata from disk FixedSizeBuffer(BlockManager &block_manager, const idx_t segment_count, const idx_t allocation_size, const BlockPointer &block_pointer); diff --git a/src/duckdb/src/include/duckdb/execution/index/unbound_index.hpp b/src/duckdb/src/include/duckdb/execution/index/unbound_index.hpp index ec2fc3cfd..30c5917f4 100644 --- a/src/duckdb/src/include/duckdb/execution/index/unbound_index.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/unbound_index.hpp @@ -16,15 +16,28 @@ namespace duckdb { class ColumnDataCollection; +enum class BufferedIndexReplay : uint8_t { INSERT_ENTRY = 0, DEL_ENTRY = 1 }; + +struct BufferedIndexData { + BufferedIndexReplay type; + unique_ptr data; + + BufferedIndexData(BufferedIndexReplay replay_type, unique_ptr data_p); +}; + class UnboundIndex final : public Index { private: //! The CreateInfo of the index. unique_ptr create_info; //! The serialized storage information of the index. IndexStorageInfo storage_info; - //! Buffer for WAL replay appends. - unique_ptr buffered_appends; - //! Maps the column IDs in the buffered appends to the table columns. + //! Buffer for WAL replays. + vector buffered_replays; + + //! Maps the column IDs in the buffered replays to a physical table offset. + //! For example, column [i] in a buffered ColumnDataCollection is the data for an Indexed column with + //! physical table index mapped_column_ids[i]. + //! This is in sorted order of physical column IDs. 
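The mapped_column_ids contract described above can be made concrete with a small, purely hypothetical example (the table layout, the index, and the element type column_t are assumptions; the element type is elided in the flattened header).

// Hypothetical illustration -- not part of this patch.
#include "duckdb/common/types.hpp"
#include "duckdb/common/vector.hpp"

// Table physical columns: 0 = id, 1 = name, 2 = value; the index covers {value, id}.
duckdb::vector<duckdb::column_t> mapped_column_ids = {0, 2}; // sorted physical ids of the indexed columns
// Each buffered replay chunk then carries one column per entry, in the same order:
//   chunk column [0] -> physical table column mapped_column_ids[0] == 0 (id)
//   chunk column [1] -> physical table column mapped_column_ids[1] == 2 (value)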
vector mapped_column_ids; public: @@ -59,12 +72,17 @@ class UnboundIndex final : public Index { void CommitDrop() override; - void BufferChunk(DataChunk &chunk, Vector &row_ids, const vector &mapped_column_ids_p); - bool HasBufferedAppends() const { - return buffered_appends != nullptr; + //! Buffer Index delete or insert (replay_type) data chunk. + //! See note above on mapped_column_ids, this function assumes that index_column_chunk maps into + //! mapped_column_ids_p to get the physical column index for each Indexed column in the chunk. + void BufferChunk(DataChunk &index_column_chunk, Vector &row_ids, const vector &mapped_column_ids_p, + BufferedIndexReplay replay_type); + bool HasBufferedReplays() const { + return !buffered_replays.empty(); } - ColumnDataCollection &GetBufferedAppends() const { - return *buffered_appends; + + vector &GetBufferedReplays() { + return buffered_replays; } const vector &GetMappedColumnIds() const { return mapped_column_ids; diff --git a/src/duckdb/src/include/duckdb/execution/merge_sort_tree.hpp b/src/duckdb/src/include/duckdb/execution/merge_sort_tree.hpp index 8c04ecde0..d17e6944f 100644 --- a/src/duckdb/src/include/duckdb/execution/merge_sort_tree.hpp +++ b/src/duckdb/src/include/duckdb/execution/merge_sort_tree.hpp @@ -574,7 +574,6 @@ template template void MergeSortTree::AggregateLowerBound(const idx_t lower, const idx_t upper, const E needle, L aggregate) const { - if (lower >= upper) { return; } diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp index 30ba0abc5..aad90df94 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp @@ -52,21 +52,21 @@ class CSVError { CSVError() {}; CSVError(string error_message, CSVErrorType type, idx_t column_idx, string csv_row, LinesPerBoundary error_info, idx_t row_byte_position, optional_idx byte_position, const CSVReaderOptions &reader_options, - const string &fixes, const string ¤t_path); + const string &fixes, const String ¤t_path); CSVError(string error_message, CSVErrorType type, LinesPerBoundary error_info); //! Produces error messages for column name -> type mismatch. static CSVError ColumnTypesError(case_insensitive_map_t sql_types_per_column, const vector &names); //! Produces error messages for casting errors static CSVError CastError(const CSVReaderOptions &options, const string &column_name, string &cast_error, idx_t column_idx, string &csv_row, LinesPerBoundary error_info, idx_t row_byte_position, - optional_idx byte_position, LogicalTypeId type, const string ¤t_path); + optional_idx byte_position, LogicalTypeId type, const String ¤t_path); //! Produces error for when the line size exceeds the maximum line size option static CSVError LineSizeError(const CSVReaderOptions &options, LinesPerBoundary error_info, string &csv_row, - idx_t byte_position, const string ¤t_path); + idx_t byte_position, const String ¤t_path); //! Produces error for when the state machine reaches an invalid state static CSVError InvalidState(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, optional_idx byte_position, - const string ¤t_path); + const String ¤t_path); //! Produces an error message for a dialect sniffing error. 
static CSVError SniffingError(const CSVReaderOptions &options, const string &search_space, idx_t max_columns_found, SetColumns &set_columns, bool type_detection); @@ -76,17 +76,17 @@ class CSVError { //! Produces error messages for unterminated quoted values static CSVError UnterminatedQuotesError(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, - optional_idx byte_position, const string ¤t_path); + optional_idx byte_position, const String ¤t_path); //! Produces error messages for null_padding option is set, and we have quoted new values in parallel static CSVError NullPaddingFail(const CSVReaderOptions &options, LinesPerBoundary error_info, - const string ¤t_path); + const String ¤t_path); //! Produces error for incorrect (e.g., smaller and lower than the predefined) number of columns in a CSV Line static CSVError IncorrectColumnAmountError(const CSVReaderOptions &state_machine, idx_t actual_columns, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, - optional_idx byte_position, const string ¤t_path); + optional_idx byte_position, const String ¤t_path); static CSVError InvalidUTF8(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, optional_idx byte_position, - const string ¤t_path); + const String ¤t_path); idx_t GetBoundaryIndex() const { return error_info.boundary_idx; diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp index 324501a3d..5446739ad 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp @@ -59,8 +59,8 @@ class CSVFileScan : public BaseFileReader { void PrepareReader(ClientContext &context, GlobalTableFunctionState &) override; bool TryInitializeScan(ClientContext &context, GlobalTableFunctionState &gstate, LocalTableFunctionState &lstate) override; - void Scan(ClientContext &context, GlobalTableFunctionState &global_state, LocalTableFunctionState &local_state, - DataChunk &chunk) override; + AsyncResult Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &chunk) override; void FinishFile(ClientContext &context, GlobalTableFunctionState &gstate_p) override; double GetProgressInFile(ClientContext &context) override; diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_reader_options.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_reader_options.hpp index b3d4f9dd5..744710bef 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_reader_options.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_reader_options.hpp @@ -196,7 +196,7 @@ struct CSVReaderOptions { //! Verify options are not conflicting void Verify(MultiFileOptions &file_options); - string ToString(const string ¤t_file_path) const; + string ToString(const String ¤t_file_path) const; //! 
If the type for column with idx i was manually set bool WasTypeManuallySet(idx_t i) const; diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state_machine.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state_machine.hpp index 45aeaad9b..f04bbd814 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state_machine.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state_machine.hpp @@ -11,6 +11,7 @@ #include "duckdb/execution/operator/csv_scanner/csv_reader_options.hpp" #include "duckdb/execution/operator/csv_scanner/csv_buffer_manager.hpp" #include "duckdb/execution/operator/csv_scanner/csv_state_machine_cache.hpp" +#include "duckdb/common/printer.hpp" namespace duckdb { @@ -129,12 +130,12 @@ class CSVStateMachine { } void Print() const { - std::cout << "State Machine Options" << '\n'; - std::cout << "Delim: " << state_machine_options.delimiter.GetValue() << '\n'; - std::cout << "Quote: " << state_machine_options.quote.GetValue() << '\n'; - std::cout << "Escape: " << state_machine_options.escape.GetValue() << '\n'; - std::cout << "Comment: " << state_machine_options.comment.GetValue() << '\n'; - std::cout << "---------------------" << '\n'; + Printer::Print(OutputStream::STREAM_STDOUT, string("State Machine Options")); + Printer::Print(OutputStream::STREAM_STDOUT, string("Delim: ") + state_machine_options.delimiter.FormatValue()); + Printer::Print(OutputStream::STREAM_STDOUT, string("Quote: ") + state_machine_options.quote.FormatValue()); + Printer::Print(OutputStream::STREAM_STDOUT, string("Escape: ") + state_machine_options.escape.FormatValue()); + Printer::Print(OutputStream::STREAM_STDOUT, string("Comment: ") + state_machine_options.comment.FormatValue()); + Printer::Print(OutputStream::STREAM_STDOUT, string("---------------------")); } //! The Transition Array is a Finite State Machine //! It holds the transitions of all states, on all 256 possible different characters diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp index bacabfc4f..331f9d669 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp @@ -41,13 +41,15 @@ class FullLinePosition { return {}; } string result; - if (end.buffer_idx == begin.buffer_idx) { - if (buffer_handles.find(end.buffer_idx) == buffer_handles.end()) { + if (end.buffer_idx == begin.buffer_idx || begin.buffer_pos == begin.buffer_size) { + idx_t buffer_idx = end.buffer_idx; + if (buffer_handles.find(buffer_idx) == buffer_handles.end()) { return {}; } - auto buffer = buffer_handles[begin.buffer_idx]->Ptr(); - first_char_nl = buffer[begin.buffer_pos] == '\n' || buffer[begin.buffer_pos] == '\r'; - for (idx_t i = begin.buffer_pos + first_char_nl; i < end.buffer_pos; i++) { + idx_t start_pos = begin.buffer_pos == begin.buffer_size ? 
0 : begin.buffer_pos; + auto buffer = buffer_handles[buffer_idx]->Ptr(); + first_char_nl = buffer[start_pos] == '\n' || buffer[start_pos] == '\r'; + for (idx_t i = start_pos + first_char_nl; i < end.buffer_pos; i++) { result += buffer[i]; } } else { @@ -55,6 +57,9 @@ class FullLinePosition { buffer_handles.find(end.buffer_idx) == buffer_handles.end()) { return {}; } + if (begin.buffer_pos >= begin.buffer_size) { + throw InternalException("CSV reader: buffer pos out of range for buffer"); + } auto first_buffer = buffer_handles[begin.buffer_idx]->Ptr(); auto first_buffer_size = buffer_handles[begin.buffer_idx]->actual_size; auto second_buffer = buffer_handles[end.buffer_idx]->Ptr(); @@ -248,7 +253,7 @@ class StringValueResult : public ScannerResult { //! We store borked rows so we can generate multiple errors during flushing unordered_set borked_rows; - const string path; + String path; //! Variable used when trying to figure out where a new segment starts, we must always start from a Valid //! (i.e., non-comment) line. diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_batch_collector.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_batch_collector.hpp index 2a7425279..ff6365f6b 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_batch_collector.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_batch_collector.hpp @@ -18,7 +18,7 @@ class PhysicalBatchCollector : public PhysicalResultCollector { PhysicalBatchCollector(PhysicalPlan &physical_plan, PreparedStatementData &data); public: - unique_ptr GetResult(GlobalSinkState &state) override; + unique_ptr GetResult(GlobalSinkState &state) const override; public: // Sink interface @@ -44,7 +44,8 @@ class PhysicalBatchCollector : public PhysicalResultCollector { //===--------------------------------------------------------------------===// class BatchCollectorGlobalState : public GlobalSinkState { public: - BatchCollectorGlobalState(ClientContext &context, const PhysicalBatchCollector &op) : data(context, op.types) { + BatchCollectorGlobalState(ClientContext &context, const PhysicalBatchCollector &op) + : data(context, op.types, op.memory_type) { } mutex glock; @@ -54,7 +55,8 @@ class BatchCollectorGlobalState : public GlobalSinkState { class BatchCollectorLocalState : public LocalSinkState { public: - BatchCollectorLocalState(ClientContext &context, const PhysicalBatchCollector &op) : data(context, op.types) { + BatchCollectorLocalState(ClientContext &context, const PhysicalBatchCollector &op) + : data(context, op.types, op.memory_type) { } BatchedDataCollection data; diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_buffered_batch_collector.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_buffered_batch_collector.hpp index 74865c5d5..cb5e8892f 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_buffered_batch_collector.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_buffered_batch_collector.hpp @@ -26,7 +26,7 @@ class PhysicalBufferedBatchCollector : public PhysicalResultCollector { PhysicalBufferedBatchCollector(PhysicalPlan &physical_plan, PreparedStatementData &data); public: - unique_ptr GetResult(GlobalSinkState &state) override; + unique_ptr GetResult(GlobalSinkState &state) const override; public: // Sink interface diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_buffered_collector.hpp 
b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_buffered_collector.hpp index 16dfdafb7..fadd37bf0 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_buffered_collector.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_buffered_collector.hpp @@ -20,7 +20,7 @@ class PhysicalBufferedCollector : public PhysicalResultCollector { bool parallel; public: - unique_ptr GetResult(GlobalSinkState &state) override; + unique_ptr GetResult(GlobalSinkState &state) const override; public: // Sink interface diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_materialized_collector.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_materialized_collector.hpp index 2c73eee8d..5ed71542f 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_materialized_collector.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_materialized_collector.hpp @@ -21,7 +21,7 @@ class PhysicalMaterializedCollector : public PhysicalResultCollector { bool parallel; public: - unique_ptr GetResult(GlobalSinkState &state) override; + unique_ptr GetResult(GlobalSinkState &state) const override; public: // Sink interface @@ -35,20 +35,4 @@ class PhysicalMaterializedCollector : public PhysicalResultCollector { bool SinkOrderDependent() const override; }; -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class MaterializedCollectorGlobalState : public GlobalSinkState { -public: - mutex glock; - unique_ptr collection; - shared_ptr context; -}; - -class MaterializedCollectorLocalState : public LocalSinkState { -public: - unique_ptr collection; - ColumnDataAppendState append_state; -}; - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_prepare.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_prepare.hpp index 373b82e5d..e8d17dc37 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_prepare.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_prepare.hpp @@ -11,6 +11,7 @@ #include "duckdb/execution/physical_operator.hpp" #include "duckdb/common/enums/physical_operator_type.hpp" #include "duckdb/main/prepared_statement_data.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" namespace duckdb { @@ -19,13 +20,13 @@ class PhysicalPrepare : public PhysicalOperator { static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::PREPARE; public: - PhysicalPrepare(PhysicalPlan &physical_plan, string name_p, shared_ptr prepared, + PhysicalPrepare(PhysicalPlan &physical_plan, const std::string &name_p, shared_ptr prepared, idx_t estimated_cardinality) : PhysicalOperator(physical_plan, PhysicalOperatorType::PREPARE, {LogicalType::BOOLEAN}, estimated_cardinality), - name(std::move(name_p)), prepared(std::move(prepared)) { + name(physical_plan.ArenaRef().MakeString(name_p)), prepared(std::move(prepared)) { } - string name; + String name; shared_ptr prepared; public: diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_result_collector.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_result_collector.hpp index c141dad66..e654a8e9d 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_result_collector.hpp +++ 
b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_result_collector.hpp @@ -13,7 +13,9 @@ #include "duckdb/common/enums/statement_type.hpp" namespace duckdb { + class PreparedStatementData; +class ColumnDataCollection; //! PhysicalResultCollector is an abstract class that is used to generate the final result of a query class PhysicalResultCollector : public PhysicalOperator { @@ -25,6 +27,7 @@ class PhysicalResultCollector : public PhysicalOperator { StatementType statement_type; StatementProperties properties; + QueryResultMemoryType memory_type; PhysicalOperator &plan; vector names; @@ -33,7 +36,7 @@ class PhysicalResultCollector : public PhysicalOperator { public: //! The final method used to fetch the query result from this operator - virtual unique_ptr GetResult(GlobalSinkState &state) = 0; + virtual unique_ptr GetResult(GlobalSinkState &state) const = 0; bool IsSink() const override { return true; @@ -52,6 +55,9 @@ class PhysicalResultCollector : public PhysicalOperator { virtual bool IsStreaming() const { return false; } + +protected: + unique_ptr CreateCollection(ClientContext &context) const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp index c7f2fb038..0f4cc2a80 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/enums/set_scope.hpp" #include "duckdb/execution/physical_operator.hpp" #include "duckdb/parser/parsed_data/vacuum_info.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" namespace duckdb { @@ -26,7 +27,7 @@ class PhysicalSet : public PhysicalOperator { PhysicalSet(PhysicalPlan &physical_plan, const string &name_p, Value value_p, SetScope scope_p, idx_t estimated_cardinality) : PhysicalOperator(physical_plan, PhysicalOperatorType::SET, {LogicalType::BOOLEAN}, estimated_cardinality), - name(name_p), value(std::move(value_p)), scope(scope_p) { + name(physical_plan.ArenaRef().MakeString(name_p)), value(std::move(value_p)), scope(scope_p) { } public: @@ -37,13 +38,13 @@ class PhysicalSet : public PhysicalOperator { return true; } - static void SetExtensionVariable(ClientContext &context, ExtensionOption &extension_option, const string &name, + static void SetExtensionVariable(ClientContext &context, ExtensionOption &extension_option, const String &name, SetScope scope, const Value &value); - static void SetGenericVariable(ClientContext &context, const string &name, SetScope scope, Value target_value); + static void SetGenericVariable(ClientContext &context, const String &name, SetScope scope, Value target_value); public: - const string name; + String name; const Value value; const SetScope scope; }; diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set_variable.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set_variable.hpp index 4574cd868..7c5378855 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set_variable.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set_variable.hpp @@ -18,7 +18,7 @@ class PhysicalSetVariable : public PhysicalOperator { static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::SET_VARIABLE; public: - PhysicalSetVariable(PhysicalPlan &physical_plan, string name, idx_t estimated_cardinality); + 
PhysicalSetVariable(PhysicalPlan &physical_plan, const string &name_p, idx_t estimated_cardinality); public: // Source interface @@ -37,7 +37,7 @@ class PhysicalSetVariable : public PhysicalOperator { } public: - const string name; + String name; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/outer_join_marker.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/outer_join_marker.hpp index b90a26767..6d853e6f8 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/outer_join_marker.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/outer_join_marker.hpp @@ -31,7 +31,7 @@ class OuterJoinMarker { public: explicit OuterJoinMarker(bool enabled); - bool Enabled() { + bool Enabled() const { return enabled; } //! Initializes the outer join counter diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/perfect_hash_join_executor.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/perfect_hash_join_executor.hpp index 0affaf4cc..9e5134481 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/perfect_hash_join_executor.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/perfect_hash_join_executor.hpp @@ -29,7 +29,7 @@ struct PerfectHashJoinStats { //! PhysicalHashJoin represents a hash loop join between two tables class PerfectHashJoinExecutor { - using PerfectHashTable = vector; + using PerfectHashTable = vector>; public: PerfectHashJoinExecutor(const PhysicalHashJoin &join, JoinHashTable &ht); @@ -64,7 +64,7 @@ class PerfectHashJoinExecutor { //! Build statistics PerfectHashJoinStats perfect_join_statistics; //! Stores the occurrences of each value in the build side - unsafe_unique_array bitmap_build_idx; + ValidityMask bitmap_build_idx; //! 
Stores the number of unique keys in the build side idx_t unique_keys = 0; }; diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp index 6089a728b..24382e8d9 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp @@ -37,18 +37,6 @@ class PhysicalAsOfJoin : public PhysicalComparisonJoin { // Projection mappings vector right_projection_map; - // Predicate (join conditions that don't reference both sides) - unique_ptr predicate; - -public: - // Operator Interface - unique_ptr GetGlobalOperatorState(ClientContext &context) const override; - unique_ptr GetOperatorState(ExecutionContext &context) const override; - - bool ParallelOperator() const override { - return true; - } - protected: // CachingOperator Interface OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, @@ -83,6 +71,9 @@ class PhysicalAsOfJoin : public PhysicalComparisonJoin { bool ParallelSink() const override { return true; } + +public: + void BuildPipelines(Pipeline &current, MetaPipeline &meta_pipeline) override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_iejoin.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_iejoin.hpp index b57fe772d..a93109a0e 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_iejoin.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_iejoin.hpp @@ -70,10 +70,6 @@ class PhysicalIEJoin : public PhysicalRangeJoin { public: void BuildPipelines(Pipeline &current, MetaPipeline &meta_pipeline) override; - -private: - // resolve joins that can potentially output N*M elements (INNER, LEFT, FULL) - void ResolveComplexJoin(ExecutionContext &context, DataChunk &result, LocalSourceState &state) const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_nested_loop_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_nested_loop_join.hpp index 25ed9ed06..2cdff1374 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_nested_loop_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_nested_loop_join.hpp @@ -18,13 +18,16 @@ class PhysicalNestedLoopJoin : public PhysicalComparisonJoin { static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::NESTED_LOOP_JOIN; public: - PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalOperator &op, PhysicalOperator &left, + PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, PhysicalOperator &left, PhysicalOperator &right, vector cond, JoinType join_type, idx_t estimated_cardinality, unique_ptr pushdown_info); - PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalOperator &op, PhysicalOperator &left, + PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, PhysicalOperator &left, PhysicalOperator &right, vector cond, JoinType join_type, idx_t estimated_cardinality); + // Predicate (join conditions that don't reference both sides) + unique_ptr predicate; + public: // Operator Interface unique_ptr GetOperatorState(ExecutionContext &context) const override; diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_piecewise_merge_join.hpp 
b/src/duckdb/src/include/duckdb/execution/operator/join/physical_piecewise_merge_join.hpp index 4da01aff3..12d974ddd 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_piecewise_merge_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_piecewise_merge_join.hpp @@ -46,6 +46,8 @@ class PhysicalPiecewiseMergeJoin : public PhysicalRangeJoin { public: // Source interface unique_ptr GetGlobalSourceState(ClientContext &context) const override; + unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const override; SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; bool IsSource() const override { diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp index 4ee6ef557..1edb36ed4 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp @@ -1,43 +1,42 @@ //===----------------------------------------------------------------------===// // DuckDB // -// duckdb/execution/operator/join/physical_piecewise_merge_join.hpp +// duckdb/execution/operator/join/physical_range_join.hpp // // //===----------------------------------------------------------------------===// #pragma once +#include "duckdb/common/types/row/block_iterator.hpp" #include "duckdb/execution/operator/join/physical_comparison_join.hpp" -#include "duckdb/planner/bound_result_modifier.hpp" -#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/sorting/sort.hpp" +#include "duckdb/common/sorting/sorted_run.hpp" namespace duckdb { -struct GlobalSortState; - //! PhysicalRangeJoin represents one or more inequality range join predicates between //! two tables class PhysicalRangeJoin : public PhysicalComparisonJoin { public: + class GlobalSortedTable; + class LocalSortedTable { public: - LocalSortedTable(ClientContext &context, const PhysicalRangeJoin &op, const idx_t child); + LocalSortedTable(ExecutionContext &context, GlobalSortedTable &global_table, const idx_t child); - void Sink(DataChunk &input, GlobalSortState &global_sort_state); + void Sink(ExecutionContext &context, DataChunk &input); - inline void Sort(GlobalSortState &global_sort_state) { - local_sort_state.Sort(global_sort_state, true); - } - - //! The hosting operator - const PhysicalRangeJoin &op; + //! The global table we are connected to + GlobalSortedTable &global_table; //! The local sort state - LocalSortState local_sort_state; + unique_ptr local_sink; //! Local copy of the sorting expression executor ExpressionExecutor executor; //! Holds a vector of incoming sorting columns DataChunk keys; + //! The sort data + DataChunk sort_chunk; //! The number of NULL values idx_t has_null; //! 
The total number of rows @@ -50,45 +49,89 @@ class PhysicalRangeJoin : public PhysicalComparisonJoin { class GlobalSortedTable { public: - GlobalSortedTable(ClientContext &context, const vector &orders, RowLayout &payload_layout, - const PhysicalOperator &op); + GlobalSortedTable(ClientContext &client, const vector &orders, + const vector &payload_layout, const PhysicalRangeJoin &op); inline idx_t Count() const { return count; } inline idx_t BlockCount() const { - if (global_sort_state.sorted_blocks.empty()) { - return 0; - } - D_ASSERT(global_sort_state.sorted_blocks.size() == 1); - return global_sort_state.sorted_blocks[0]->radix_sorting_data.size(); + return sorted->key_data->ChunkCount(); + } + + inline idx_t BlockStart(idx_t i) const { + return MinValue(i * STANDARD_VECTOR_SIZE, count); + } + + inline idx_t BlockEnd(idx_t i) const { + return BlockStart(i + 1) - 1; } inline idx_t BlockSize(idx_t i) const { - return global_sort_state.sorted_blocks[0]->radix_sorting_data[i]->count; + return i < BlockCount() ? MinValue(STANDARD_VECTOR_SIZE, count - BlockStart(i)) : 0; + } + + inline SortKeyType GetSortKeyType() const { + return sorted->key_data->GetLayout().GetSortKeyType(); } - void Combine(LocalSortedTable &ltable); void IntializeMatches(); + + //! Combine local states + void Combine(ExecutionContext &context, LocalSortedTable &ltable); + //! Prepare for sorting. + void Finalize(ClientContext &client, InterruptState &interrupt); + //! Schedules the materialisation process. + void Materialize(Pipeline &pipeline, Event &event); + //! Single-threaded materialisation. + void Materialize(ExecutionContext &context, InterruptState &interrupt); + //! Materialize an empty sorted run. + void MaterializeEmpty(ClientContext &client); + //! Print the table to the console void Print(); - //! Starts the sorting process. - void Finalize(Pipeline &pipeline, Event &event); - //! Schedules tasks to merge sort the current child's data during a Finalize phase - void ScheduleMergeTasks(Pipeline &pipeline, Event &event); + //! Create an iteration state + unique_ptr CreateIteratorState() { + auto state = make_uniq(*sorted->key_data, sorted->payload_data.get()); + + // Unless we do this, we will only get values from the first chunk + Repin(*state); + + return state; + } + //! Reset the pins for an iterator so we release memory in a timely manner + static void Repin(ExternalBlockIteratorState &iter) { + iter.SetKeepPinned(true); + iter.SetPinPayload(true); + } + //! Create an iteration state + unique_ptr CreateScanState(ClientContext &client) { + return make_uniq(client, *sort); + } + //! Initialize a payload scanning state + void InitializePayloadState(TupleDataChunkState &state) { + sorted->payload_data->InitializeChunkState(state); + } //! The hosting operator - const PhysicalOperator &op; - GlobalSortState global_sort_state; + const PhysicalRangeJoin &op; + //! The sort description + unique_ptr sort; + //! The shared sort state + unique_ptr global_sink; //! Whether or not the RHS has NULL values atomic has_null; //! The total number of rows in the RHS atomic count; + //! The number of materialisation tasks completed in parallel + atomic tasks_completed; + //! The shared materialisation state + unique_ptr global_source; + //! The materialized data + unique_ptr sorted; //! A bool indicating for each tuple in the RHS if they found a match (only used in FULL OUTER JOIN) unsafe_unique_array found_match; - //! 
Memory usage per thread - idx_t memory_per_thread; }; public: @@ -106,10 +149,9 @@ class PhysicalRangeJoin : public PhysicalComparisonJoin { public: // Gather the result values and slice the payload columns to those values. - // Returns a buffer handle to the pinned heap block (if any) - static BufferHandle SliceSortedPayload(DataChunk &payload, GlobalSortState &state, const idx_t block_idx, - const SelectionVector &result, const idx_t result_count, - const idx_t left_cols = 0); + static void SliceSortedPayload(DataChunk &chunk, GlobalSortedTable &table, ExternalBlockIteratorState &state, + TupleDataChunkState &chunk_state, const idx_t chunk_idx, SelectionVector &result, + const idx_t result_count, SortedRunScanState &scan_state); // Apply a tail condition to the current selection static idx_t SelectJoinTail(const ExpressionType &condition, Vector &left, Vector &right, const SelectionVector *sel, idx_t count, SelectionVector *true_sel); diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp index 811a1abda..41b84c564 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp @@ -14,6 +14,7 @@ #include "duckdb/storage/data_table.hpp" #include "duckdb/common/extra_operator_info.hpp" #include "duckdb/common/column_index.hpp" +#include "duckdb/execution/physical_table_scan_enum.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/execution/physical_table_scan_enum.hpp b/src/duckdb/src/include/duckdb/execution/physical_table_scan_enum.hpp new file mode 100644 index 000000000..eb7886651 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/physical_table_scan_enum.hpp @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/physical_table_scan_enum.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { + +enum class PhysicalTableScanExecutionStrategy : uint8_t { + DEFAULT, + TASK_EXECUTOR, + SYNCHRONOUS, + TASK_EXECUTOR_BUT_FORCE_SYNC_CHECKS +}; + +}; // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp b/src/duckdb/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp index a26772819..666369cc7 100644 --- a/src/duckdb/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp @@ -242,6 +242,30 @@ class BinaryAggregateHeap { idx_t size; }; +enum class ArgMinMaxNullHandling { IGNORE_ANY_NULL, HANDLE_ARG_NULL, HANDLE_ANY_NULL }; + +struct ArgMinMaxFunctionData : FunctionData { + explicit ArgMinMaxFunctionData(ArgMinMaxNullHandling null_handling_p = ArgMinMaxNullHandling::IGNORE_ANY_NULL, + bool nulls_last_p = true) + : null_handling(null_handling_p), nulls_last(nulls_last_p) { + } + + unique_ptr Copy() const override { + auto copy = make_uniq(); + copy->null_handling = null_handling; + copy->nulls_last = nulls_last; + return std::move(copy); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return other.null_handling == null_handling && other.nulls_last == nulls_last; + } + + ArgMinMaxNullHandling null_handling; + bool nulls_last; +}; + 
//------------------------------------------------------------------------------ // Specializations for fixed size types, strings, and anything else (using sortkey) //------------------------------------------------------------------------------ @@ -254,7 +278,7 @@ struct MinMaxFixedValue { return UnifiedVectorFormat::GetData(format)[idx]; } - static void Assign(Vector &vector, const idx_t idx, const TYPE &value) { + static void Assign(Vector &vector, const idx_t idx, const TYPE &value, const bool nulls_last) { FlatVector::GetData(vector)[idx] = value; } @@ -263,7 +287,8 @@ struct MinMaxFixedValue { return false; } - static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &, UnifiedVectorFormat &format) { + static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &, UnifiedVectorFormat &format, + const bool nulls_last) { input.ToUnifiedFormat(count, format); } }; @@ -276,7 +301,7 @@ struct MinMaxStringValue { return UnifiedVectorFormat::GetData(format)[idx]; } - static void Assign(Vector &vector, const idx_t idx, const TYPE &value) { + static void Assign(Vector &vector, const idx_t idx, const TYPE &value, const bool nulls_last) { FlatVector::GetData(vector)[idx] = StringVector::AddStringOrBlob(vector, value); } @@ -285,7 +310,8 @@ struct MinMaxStringValue { return false; } - static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &, UnifiedVectorFormat &format) { + static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &, UnifiedVectorFormat &format, + const bool nulls_last) { input.ToUnifiedFormat(count, format); } }; @@ -299,8 +325,9 @@ struct MinMaxFallbackValue { return UnifiedVectorFormat::GetData(format)[idx]; } - static void Assign(Vector &vector, const idx_t idx, const TYPE &value) { - OrderModifiers modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); + static void Assign(Vector &vector, const idx_t idx, const TYPE &value, const bool nulls_last) { + auto order_by_null_type = nulls_last ? OrderByNullType::NULLS_LAST : OrderByNullType::NULLS_FIRST; + OrderModifiers modifiers(OrderType::ASCENDING, order_by_null_type); CreateSortKeyHelpers::DecodeSortKey(value, vector, idx, modifiers); } @@ -308,14 +335,61 @@ struct MinMaxFallbackValue { return Vector(LogicalTypeId::BLOB); } - static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &extra_state, UnifiedVectorFormat &format) { - const OrderModifiers modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); + static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &extra_state, UnifiedVectorFormat &format, + const bool nulls_last) { + auto order_by_null_type = nulls_last ? 
OrderByNullType::NULLS_LAST : OrderByNullType::NULLS_FIRST; + const OrderModifiers modifiers(OrderType::ASCENDING, order_by_null_type); CreateSortKeyHelpers::CreateSortKeyWithValidity(input, extra_state, modifiers, count); input.Flatten(count); extra_state.ToUnifiedFormat(count, format); } }; +template +struct ValueOrNull { + T value; + bool is_valid; + + bool operator==(const ValueOrNull &other) const { + return is_valid == other.is_valid && value == other.value; + } + + bool operator>(const ValueOrNull &other) const { + if (is_valid && other.is_valid) { + return value > other.value; + } + if (!is_valid && !other.is_valid) { + return false; + } + + return is_valid ^ NULLS_LAST; + } +}; + +template +struct MinMaxFixedValueOrNull { + using TYPE = ValueOrNull; + using EXTRA_STATE = bool; + + static TYPE Create(const UnifiedVectorFormat &format, const idx_t idx) { + return TYPE {UnifiedVectorFormat::GetData(format)[idx], format.validity.RowIsValid(idx)}; + } + + static void Assign(Vector &vector, const idx_t idx, const TYPE &value, const bool nulls_last) { + FlatVector::Validity(vector).Set(idx, value.is_valid); + FlatVector::GetData(vector)[idx] = value.value; + } + + static EXTRA_STATE CreateExtraState(Vector &input, idx_t count) { + return false; + } + + static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &extra_state, UnifiedVectorFormat &format, + const bool nulls_last) { + input.ToUnifiedFormat(count, format); + } +}; + //------------------------------------------------------------------------------ // MinMaxN Operation (common for both ArgMinMaxN and MinMaxN) //------------------------------------------------------------------------------ @@ -343,7 +417,11 @@ struct MinMaxNOperation { } template - static void Finalize(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, idx_t offset) { + static void Finalize(Vector &state_vector, AggregateInputData &input_data, Vector &result, idx_t count, + idx_t offset) { + // We only expect bind data from arg_max, otherwise nulls last is the default + const bool nulls_last = + input_data.bind_data ? 
input_data.bind_data->Cast().nulls_last : true; UnifiedVectorFormat state_format; state_vector.ToUnifiedFormat(count, state_format); @@ -387,7 +465,7 @@ struct MinMaxNOperation { auto heap = state.heap.SortAndGetHeap(); for (idx_t slot = 0; slot < state.heap.Size(); slot++) { - STATE::VAL_TYPE::Assign(child_data, current_offset++, state.heap.GetValue(heap[slot])); + STATE::VAL_TYPE::Assign(child_data, current_offset++, state.heap.GetValue(heap[slot]), nulls_last); } } diff --git a/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp b/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp index d87a3a976..107b482b7 100644 --- a/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp @@ -170,6 +170,7 @@ struct DefaultCasts { static BoundCastInfo UnionCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); static BoundCastInfo VariantCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); static BoundCastInfo UUIDCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo GeoCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); static BoundCastInfo BignumCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); static BoundCastInfo ImplicitToUnionCast(BindCastInput &input, const LogicalType &source, const LogicalType &target); diff --git a/src/duckdb/src/include/duckdb/function/cast/variant/primitive_to_variant.hpp b/src/duckdb/src/include/duckdb/function/cast/variant/primitive_to_variant.hpp index 2e7fbf68e..9aa105bd3 100644 --- a/src/duckdb/src/include/duckdb/function/cast/variant/primitive_to_variant.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/variant/primitive_to_variant.hpp @@ -357,6 +357,9 @@ bool ConvertPrimitiveToVariant(ToVariantSourceData &source, ToVariantGlobalResul case LogicalTypeId::CHAR: return ConvertPrimitiveTemplated( source, result, count, selvec, values_index_selvec, empty_payload, is_root); + case LogicalTypeId::GEOMETRY: + return ConvertPrimitiveTemplated( + source, result, count, selvec, values_index_selvec, empty_payload, is_root); case LogicalTypeId::BLOB: return ConvertPrimitiveTemplated( source, result, count, selvec, values_index_selvec, empty_payload, is_root); diff --git a/src/duckdb/src/include/duckdb/function/cast/variant/struct_to_variant.hpp b/src/duckdb/src/include/duckdb/function/cast/variant/struct_to_variant.hpp index 5a8b088ae..209598a74 100644 --- a/src/duckdb/src/include/duckdb/function/cast/variant/struct_to_variant.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/variant/struct_to_variant.hpp @@ -98,7 +98,7 @@ bool ConvertStructToVariant(ToVariantSourceData &source, ToVariantGlobalResultDa } } if (WRITE_DATA) { - //! Now forward the selection to point to the next index in the children.values_index + //! 
Now move the selection forward to write the value_id for the next struct child, for each row for (idx_t i = 0; i < sel.count; i++) { sel.children_selection[i]++; } diff --git a/src/duckdb/src/include/duckdb/function/cast/variant/to_variant_fwd.hpp b/src/duckdb/src/include/duckdb/function/cast/variant/to_variant_fwd.hpp index 00edc6459..78dad70fc 100644 --- a/src/duckdb/src/include/duckdb/function/cast/variant/to_variant_fwd.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/variant/to_variant_fwd.hpp @@ -110,7 +110,6 @@ template void WriteVariantMetadata(ToVariantGlobalResultData &result, idx_t result_index, uint32_t *values_offsets, uint32_t blob_offset, optional_ptr value_index_selvec, idx_t i, VariantLogicalType type_id) { - auto &values_offset_data = values_offsets[result_index]; if (WRITE_DATA) { auto &variant = result.variant; diff --git a/src/duckdb/src/include/duckdb/function/cast/variant/variant_to_variant.hpp b/src/duckdb/src/include/duckdb/function/cast/variant/variant_to_variant.hpp index 28d9db96b..482c3dcd5 100644 --- a/src/duckdb/src/include/duckdb/function/cast/variant/variant_to_variant.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/variant/variant_to_variant.hpp @@ -1,99 +1,251 @@ #pragma once #include "duckdb/function/cast/variant/to_variant_fwd.hpp" +#include "duckdb/common/types/variant_visitor.hpp" namespace duckdb { namespace variant { -static bool VariantIsTrivialPrimitive(VariantLogicalType type) { - switch (type) { - case VariantLogicalType::INT8: - case VariantLogicalType::INT16: - case VariantLogicalType::INT32: - case VariantLogicalType::INT64: - case VariantLogicalType::INT128: - case VariantLogicalType::UINT8: - case VariantLogicalType::UINT16: - case VariantLogicalType::UINT32: - case VariantLogicalType::UINT64: - case VariantLogicalType::UINT128: - case VariantLogicalType::FLOAT: - case VariantLogicalType::DOUBLE: - case VariantLogicalType::UUID: - case VariantLogicalType::DATE: - case VariantLogicalType::TIME_MICROS: - case VariantLogicalType::TIME_NANOS: - case VariantLogicalType::TIMESTAMP_SEC: - case VariantLogicalType::TIMESTAMP_MILIS: - case VariantLogicalType::TIMESTAMP_MICROS: - case VariantLogicalType::TIMESTAMP_NANOS: - case VariantLogicalType::TIME_MICROS_TZ: - case VariantLogicalType::TIMESTAMP_MICROS_TZ: - case VariantLogicalType::INTERVAL: - return true; - default: - return false; +namespace { + +struct AnalyzeState { +public: + explicit AnalyzeState(uint32_t &children_offset) : children_offset(children_offset) { } -} -static uint32_t VariantTrivialPrimitiveSize(VariantLogicalType type) { - switch (type) { - case VariantLogicalType::INT8: - return sizeof(int8_t); - case VariantLogicalType::INT16: - return sizeof(int16_t); - case VariantLogicalType::INT32: - return sizeof(int32_t); - case VariantLogicalType::INT64: - return sizeof(int64_t); - case VariantLogicalType::INT128: - return sizeof(hugeint_t); - case VariantLogicalType::UINT8: - return sizeof(uint8_t); - case VariantLogicalType::UINT16: - return sizeof(uint16_t); - case VariantLogicalType::UINT32: - return sizeof(uint32_t); - case VariantLogicalType::UINT64: - return sizeof(uint64_t); - case VariantLogicalType::UINT128: - return sizeof(uhugeint_t); - case VariantLogicalType::FLOAT: +public: + uint32_t &children_offset; +}; + +struct WriteState { +public: + WriteState(uint32_t &keys_offset, uint32_t &children_offset, uint32_t &blob_offset, data_ptr_t blob_data, + uint32_t &blob_size) + : keys_offset(keys_offset), children_offset(children_offset), blob_offset(blob_offset), 
blob_data(blob_data), + blob_size(blob_size) { + } + +public: + inline data_ptr_t GetDestination() { + return blob_data + blob_offset + blob_size; + } + +public: + uint32_t &keys_offset; + uint32_t &children_offset; + uint32_t &blob_offset; + data_ptr_t blob_data; + uint32_t &blob_size; +}; + +struct VariantToVariantSizeAnalyzer { + using result_type = uint32_t; + + static uint32_t VisitNull(AnalyzeState &state) { + return 0; + } + static uint32_t VisitBoolean(bool, AnalyzeState &state) { + return 0; + } + + template + static uint32_t VisitInteger(T, AnalyzeState &state) { + return sizeof(T); + } + + static uint32_t VisitFloat(float, AnalyzeState &state) { return sizeof(float); - case VariantLogicalType::DOUBLE: + } + static uint32_t VisitDouble(double, AnalyzeState &state) { return sizeof(double); - case VariantLogicalType::UUID: + } + static uint32_t VisitUUID(hugeint_t, AnalyzeState &state) { return sizeof(hugeint_t); - case VariantLogicalType::DATE: + } + static uint32_t VisitDate(date_t, AnalyzeState &state) { return sizeof(int32_t); - case VariantLogicalType::TIME_MICROS: + } + static uint32_t VisitInterval(interval_t, AnalyzeState &state) { + return sizeof(interval_t); + } + + static uint32_t VisitTime(dtime_t, AnalyzeState &state) { return sizeof(dtime_t); - case VariantLogicalType::TIME_NANOS: + } + static uint32_t VisitTimeNanos(dtime_ns_t, AnalyzeState &state) { return sizeof(dtime_ns_t); - case VariantLogicalType::TIMESTAMP_SEC: + } + static uint32_t VisitTimeTZ(dtime_tz_t, AnalyzeState &state) { + return sizeof(dtime_tz_t); + } + static uint32_t VisitTimestampSec(timestamp_sec_t, AnalyzeState &state) { return sizeof(timestamp_sec_t); - case VariantLogicalType::TIMESTAMP_MILIS: + } + static uint32_t VisitTimestampMs(timestamp_ms_t, AnalyzeState &state) { return sizeof(timestamp_ms_t); - case VariantLogicalType::TIMESTAMP_MICROS: + } + static uint32_t VisitTimestamp(timestamp_t, AnalyzeState &state) { return sizeof(timestamp_t); - case VariantLogicalType::TIMESTAMP_NANOS: + } + static uint32_t VisitTimestampNanos(timestamp_ns_t, AnalyzeState &state) { return sizeof(timestamp_ns_t); - case VariantLogicalType::TIME_MICROS_TZ: - return sizeof(dtime_tz_t); - case VariantLogicalType::TIMESTAMP_MICROS_TZ: + } + static uint32_t VisitTimestampTZ(timestamp_tz_t, AnalyzeState &state) { return sizeof(timestamp_tz_t); - case VariantLogicalType::INTERVAL: - return sizeof(interval_t); - default: - throw InternalException("VariantLogicalType '%s' is not a trivial primitive", EnumUtil::ToString(type)); } -} + + static uint32_t VisitString(const string_t &str, AnalyzeState &state) { + auto length = static_cast(str.GetSize()); + return GetVarintSize(length) + length; + } + + static uint32_t VisitBlob(const string_t &blob, AnalyzeState &state) { + return VisitString(blob, state); + } + static uint32_t VisitBignum(const string_t &bignum, AnalyzeState &state) { + return VisitString(bignum, state); + } + static uint32_t VisitGeometry(const string_t &geom, AnalyzeState &state) { + return VisitString(geom, state); + } + static uint32_t VisitBitstring(const string_t &bits, AnalyzeState &state) { + return VisitString(bits, state); + } + + template + static uint32_t VisitDecimal(T, uint32_t width, uint32_t scale, AnalyzeState &state) { + uint32_t size = GetVarintSize(width) + GetVarintSize(scale); + size += sizeof(T); + return size; + } + + static uint32_t VisitArray(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + AnalyzeState &state) { + uint32_t size = 
GetVarintSize(nested_data.child_count); + if (nested_data.child_count) { + size += GetVarintSize(nested_data.children_idx + state.children_offset); + } + return size; + } + + static uint32_t VisitObject(const UnifiedVariantVectorData &variant, idx_t row, + const VariantNestedData &nested_data, AnalyzeState &state) { + return VisitArray(variant, row, nested_data, state); + } + + static uint32_t VisitDefault(VariantLogicalType type_id, const_data_ptr_t, AnalyzeState &) { + throw InternalException("Unrecognized VariantLogicalType: %s", EnumUtil::ToString(type_id)); + } +}; + +struct VariantToVariantDataWriter { + using result_type = void; + + static void VisitNull(WriteState &state) { + return; + } + static void VisitBoolean(bool, WriteState &state) { + return; + } + + template + static void VisitInteger(T val, WriteState &state) { + Store(val, state.GetDestination()); + state.blob_size += sizeof(T); + } + static void VisitFloat(float val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitDouble(double val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitUUID(hugeint_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitDate(date_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitInterval(interval_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTime(dtime_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimeNanos(dtime_ns_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimeTZ(dtime_tz_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimestampSec(timestamp_sec_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimestampMs(timestamp_ms_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimestamp(timestamp_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimestampNanos(timestamp_ns_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimestampTZ(timestamp_tz_t val, WriteState &state) { + VisitInteger(val, state); + } + + static void VisitString(const string_t &str, WriteState &state) { + auto length = str.GetSize(); + state.blob_size += VarintEncode(length, state.GetDestination()); + memcpy(state.GetDestination(), str.GetData(), length); + state.blob_size += length; + } + static void VisitBlob(const string_t &blob, WriteState &state) { + return VisitString(blob, state); + } + static void VisitBignum(const string_t &bignum, WriteState &state) { + return VisitString(bignum, state); + } + static void VisitGeometry(const string_t &geom, WriteState &state) { + return VisitString(geom, state); + } + static void VisitBitstring(const string_t &bits, WriteState &state) { + return VisitString(bits, state); + } + + template + static void VisitDecimal(T val, uint32_t width, uint32_t scale, WriteState &state) { + state.blob_size += VarintEncode(width, state.GetDestination()); + state.blob_size += VarintEncode(scale, state.GetDestination()); + Store(val, state.GetDestination()); + state.blob_size += sizeof(T); + } + + static void VisitArray(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + WriteState &state) { + state.blob_size += VarintEncode(nested_data.child_count, state.GetDestination()); + if (nested_data.child_count) { + //! NOTE: The 'child_index' stored in the OBJECT/ARRAY data could require more bits + //! 
That's the reason we have to rewrite the data in VARIANT->VARIANT cast + state.blob_size += VarintEncode(nested_data.children_idx + state.children_offset, state.GetDestination()); + } + } + + static void VisitObject(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + WriteState &state) { + return VisitArray(variant, row, nested_data, state); + } + + static void VisitDefault(VariantLogicalType type_id, const_data_ptr_t, WriteState &) { + throw InternalException("Unrecognized VariantLogicalType: %s", EnumUtil::ToString(type_id)); + } +}; + +} // namespace template bool ConvertVariantToVariant(ToVariantSourceData &source_data, ToVariantGlobalResultData &result_data, idx_t count, optional_ptr selvec, optional_ptr values_index_selvec, const bool is_root) { - auto keys_offset_data = OffsetData::GetKeys(result_data.offsets); auto children_offset_data = OffsetData::GetChildren(result_data.offsets); auto values_offset_data = OffsetData::GetValues(result_data.offsets); @@ -168,99 +320,26 @@ bool ConvertVariantToVariant(ToVariantSourceData &source_data, ToVariantGlobalRe } } - auto source_blob_data = const_data_ptr_cast(source.GetData(source_index).GetData()); - - //! Then write all values auto source_values_list_entry = source.GetValuesListEntry(source_index); - for (uint32_t source_value_index = 0; source_value_index < source_values_list_entry.length; - source_value_index++) { - auto source_type_id = source.GetTypeId(source_index, source_value_index); - auto source_byte_offset = source.GetByteOffset(source_index, source_value_index); - - //! NOTE: we have to deserialize these in both passes - //! because to figure out the size of the 'data' that is added by the VARIANT, we have to traverse the - //! VARIANT solely because the 'child_index' stored in the OBJECT/ARRAY data could require more bits - WriteVariantMetadata(result_data, result_index, values_offset_data, blob_offset + blob_size, - nullptr, 0, source_type_id); - - if (source_type_id == VariantLogicalType::ARRAY || source_type_id == VariantLogicalType::OBJECT) { - auto source_nested_data = VariantUtils::DecodeNestedData(source, source_index, source_value_index); - if (WRITE_DATA) { - VarintEncode(source_nested_data.child_count, blob_data + blob_offset + blob_size); - } - blob_size += GetVarintSize(source_nested_data.child_count); - if (source_nested_data.child_count) { - auto new_child_index = source_nested_data.children_idx + children_offset; - if (WRITE_DATA) { - VarintEncode(new_child_index, blob_data + blob_offset + blob_size); - } - blob_size += GetVarintSize(new_child_index); - } - } else if (source_type_id == VariantLogicalType::VARIANT_NULL || - source_type_id == VariantLogicalType::BOOL_FALSE || - source_type_id == VariantLogicalType::BOOL_TRUE) { - // no-op - } else if (source_type_id == VariantLogicalType::DECIMAL) { - auto decimal_blob_data = source_blob_data + source_byte_offset; - auto width = static_cast(VarintDecode(decimal_blob_data)); - auto width_varint_size = GetVarintSize(width); - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data - width_varint_size, - width_varint_size); - } - blob_size += width_varint_size; - auto scale = static_cast(VarintDecode(decimal_blob_data)); - auto scale_varint_size = GetVarintSize(scale); - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data - scale_varint_size, - scale_varint_size); - } - blob_size += scale_varint_size; - - if (width > DecimalWidth::max) { - if (WRITE_DATA) { - 
memcpy(blob_data + blob_offset + blob_size, decimal_blob_data, sizeof(hugeint_t)); - } - blob_size += sizeof(hugeint_t); - } else if (width > DecimalWidth::max) { - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data, sizeof(int64_t)); - } - blob_size += sizeof(int64_t); - } else if (width > DecimalWidth::max) { - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data, sizeof(int32_t)); - } - blob_size += sizeof(int32_t); - } else { - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data, sizeof(int16_t)); - } - blob_size += sizeof(int16_t); - } - } else if (source_type_id == VariantLogicalType::BITSTRING || - source_type_id == VariantLogicalType::BIGNUM || source_type_id == VariantLogicalType::VARCHAR || - source_type_id == VariantLogicalType::BLOB) { - auto str_blob_data = source_blob_data + source_byte_offset; - auto str_length = VarintDecode(str_blob_data); - auto str_length_varint_size = GetVarintSize(str_length); - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, str_blob_data - str_length_varint_size, - str_length_varint_size); - } - blob_size += str_length_varint_size; - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, str_blob_data, str_length); - } - blob_size += str_length; - } else if (VariantIsTrivialPrimitive(source_type_id)) { - auto size = VariantTrivialPrimitiveSize(source_type_id); - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, source_blob_data + source_byte_offset, size); - } - blob_size += size; - } else { - throw InternalException("Unrecognized VariantLogicalType: %s", EnumUtil::ToString(source_type_id)); + + if (WRITE_DATA) { + WriteState write_state(keys_offset, children_offset, blob_offset, blob_data, blob_size); + for (uint32_t source_value_index = 0; source_value_index < source_values_list_entry.length; + source_value_index++) { + auto source_type_id = source.GetTypeId(source_index, source_value_index); + WriteVariantMetadata(result_data, result_index, values_offset_data, blob_offset + blob_size, + nullptr, 0, source_type_id); + + VariantVisitor::Visit(source, source_index, source_value_index, + write_state); + } + } else { + AnalyzeState analyze_state(children_offset); + for (uint32_t source_value_index = 0; source_value_index < source_values_list_entry.length; + source_value_index++) { + values_offset_data[result_index]++; + blob_size += VariantVisitor::Visit(source, source_index, + source_value_index, analyze_state); } } diff --git a/src/duckdb/src/include/duckdb/function/compression_function.hpp b/src/duckdb/src/include/duckdb/function/compression_function.hpp index 64b1c2a58..97fff72b1 100644 --- a/src/duckdb/src/include/duckdb/function/compression_function.hpp +++ b/src/duckdb/src/include/duckdb/function/compression_function.hpp @@ -17,6 +17,7 @@ #include "duckdb/storage/data_pointer.hpp" #include "duckdb/storage/storage_info.hpp" #include "duckdb/storage/block_manager.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/storage/storage_lock.hpp" namespace duckdb { @@ -28,7 +29,6 @@ class SegmentStatistics; class TableFilter; struct TableFilterState; struct ColumnSegmentState; - struct ColumnFetchState; struct ColumnScanState; struct PrefetchState; @@ -174,7 +174,8 @@ typedef void (*compression_compress_finalize_t)(CompressionState &state); // Uncompress / Scan //===--------------------------------------------------------------------===// typedef void (*compression_init_prefetch_t)(ColumnSegment &segment, 
PrefetchState &prefetch_state); -typedef unique_ptr (*compression_init_segment_scan_t)(ColumnSegment &segment); +typedef unique_ptr (*compression_init_segment_scan_t)(const QueryContext &context, + ColumnSegment &segment); //! Function prototype used for reading an entire vector (STANDARD_VECTOR_SIZE) typedef void (*compression_scan_vector_t)(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, @@ -221,7 +222,8 @@ typedef void (*compression_cleanup_state_t)(ColumnSegment &segment); // GetSegmentInfo (optional) //===--------------------------------------------------------------------===// //! Function prototype for retrieving segment information straight from the column segment -typedef InsertionOrderPreservingMap (*compression_get_segment_info_t)(ColumnSegment &segment); +typedef InsertionOrderPreservingMap (*compression_get_segment_info_t)(QueryContext context, + ColumnSegment &segment); enum class CompressionValidity : uint8_t { REQUIRES_VALIDITY, NO_VALIDITY_REQUIRED }; diff --git a/src/duckdb/src/include/duckdb/function/copy_function.hpp b/src/duckdb/src/include/duckdb/function/copy_function.hpp index cfd379c0a..1b035ebd1 100644 --- a/src/duckdb/src/include/duckdb/function/copy_function.hpp +++ b/src/duckdb/src/include/duckdb/function/copy_function.hpp @@ -23,6 +23,21 @@ class ColumnDataCollection; class ExecutionContext; class PhysicalOperatorLogger; +struct CopyFunctionInfo { + virtual ~CopyFunctionInfo() = default; + + template + TARGET &Cast() { + DynamicCastCheck(this); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + DynamicCastCheck(this); + return reinterpret_cast(*this); + } +}; + struct LocalFunctionData { virtual ~LocalFunctionData() = default; @@ -69,11 +84,12 @@ struct PreparedBatchData { }; struct CopyFunctionBindInput { - explicit CopyFunctionBindInput(const CopyInfo &info_p) : info(info_p) { + explicit CopyFunctionBindInput(const CopyInfo &info_p, shared_ptr function_info = nullptr) + : info(info_p), function_info(std::move(function_info)) { } const CopyInfo &info; - + shared_ptr function_info; string file_extension; }; @@ -199,6 +215,9 @@ class CopyFunction : public Function { // NOLINT: work-around bug in clang-tidy TableFunction copy_from_function; string extension; + + //! 
Additional function info, passed to the bind + shared_ptr function_info; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/create_sort_key.hpp b/src/duckdb/src/include/duckdb/function/create_sort_key.hpp index 0ce926c1f..b2b5c08c3 100644 --- a/src/duckdb/src/include/duckdb/function/create_sort_key.hpp +++ b/src/duckdb/src/include/duckdb/function/create_sort_key.hpp @@ -48,7 +48,7 @@ struct OrderModifiers { struct CreateSortKeyHelpers { static void CreateSortKey(DataChunk &input, const vector &modifiers, Vector &result); static void CreateSortKey(Vector &input, idx_t input_count, OrderModifiers modifiers, Vector &result); - static void DecodeSortKey(string_t sort_key, Vector &result, idx_t result_idx, OrderModifiers modifiers); + static idx_t DecodeSortKey(string_t sort_key, Vector &result, idx_t result_idx, OrderModifiers modifiers); static void DecodeSortKey(string_t sort_key, DataChunk &result, idx_t result_idx, const vector &modifiers); static void CreateSortKeyWithValidity(Vector &input, Vector &result, const OrderModifiers &modifiers, diff --git a/src/duckdb/src/include/duckdb/function/function.hpp b/src/duckdb/src/include/duckdb/function/function.hpp index 587216421..bd9960319 100644 --- a/src/duckdb/src/include/duckdb/function/function.hpp +++ b/src/duckdb/src/include/duckdb/function/function.hpp @@ -175,6 +175,55 @@ class BaseScalarFunction : public SimpleFunction { FunctionErrors errors = FunctionErrors::CANNOT_ERROR); DUCKDB_API ~BaseScalarFunction() override; +public: + void SetReturnType(LogicalType return_type_p) { + return_type = std::move(return_type_p); + } + const LogicalType &GetReturnType() const { + return return_type; + } + LogicalType &GetReturnType() { + return return_type; + } + + FunctionStability GetStability() const { + return stability; + } + void SetStability(FunctionStability stability_p) { + stability = stability_p; + } + + FunctionNullHandling GetNullHandling() const { + return null_handling; + } + void SetNullHandling(FunctionNullHandling null_handling_p) { + null_handling = null_handling_p; + } + + FunctionErrors GetErrorMode() const { + return errors; + } + void SetErrorMode(FunctionErrors errors_p) { + errors = errors_p; + } + + //! Set this functions error-mode as fallible (can throw runtime errors) + void SetFallible() { + errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; + } + //! Set this functions stability as volatile (can not be cached per row) + void SetVolatile() { + stability = FunctionStability::VOLATILE; + } + + void SetCollationHandling(FunctionCollationHandling collation_handling_p) { + collation_handling = collation_handling_p; + } + FunctionCollationHandling GetCollationHandling() const { + return collation_handling; + } + +public: //! Return type of the function LogicalType return_type; //! 
The stability of the function (see FunctionStability enum for more info) diff --git a/src/duckdb/src/include/duckdb/function/function_serialization.hpp b/src/duckdb/src/include/duckdb/function/function_serialization.hpp index d7d3480ac..55a7218c6 100644 --- a/src/duckdb/src/include/duckdb/function/function_serialization.hpp +++ b/src/duckdb/src/include/duckdb/function/function_serialization.hpp @@ -156,7 +156,6 @@ class FunctionSerializer { bind_data = FunctionDeserialize(deserializer, function); deserializer.Unset(); } else { - FunctionBinder binder(context); // Resolve templates @@ -178,8 +177,8 @@ class FunctionSerializer { binder.CastToFunctionArguments(function, children); } - if (TypeRequiresAssignment(function.return_type)) { - function.return_type = std::move(return_type); + if (TypeRequiresAssignment(function.GetReturnType())) { + function.SetReturnType(std::move(return_type)); } return make_pair(std::move(function), std::move(bind_data)); } diff --git a/src/duckdb/src/include/duckdb/function/scalar/geometry_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/geometry_functions.hpp new file mode 100644 index 000000000..7a15ba00c --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/scalar/geometry_functions.hpp @@ -0,0 +1,76 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// function/scalar/geometry_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct StGeomfromwkbFun { + static constexpr const char *Name = "st_geomfromwkb"; + static constexpr const char *Parameters = "wkb"; + static constexpr const char *Description = "Creates a geometry from Well-Known Binary (WKB) representation"; + static constexpr const char *Example = "ST_GeomFromWKB(X'01010000000000000000000000000000000000000000000000')"; + static constexpr const char *Categories = "geometry"; + + static ScalarFunction GetFunction(); +}; + +struct StAswkbFun { + static constexpr const char *Name = "st_aswkb"; + static constexpr const char *Parameters = "geom"; + static constexpr const char *Description = "Returns the Well-Known Binary (WKB) representation of the geometry"; + static constexpr const char *Example = "st_aswkb(ST_GeomFromWKB(X'01010000000000000000000000000000000000000000000000000'))"; + static constexpr const char *Categories = "geometry"; + + static ScalarFunction GetFunction(); +}; + +struct StAsbinaryFun { + using ALIAS = StAswkbFun; + + static constexpr const char *Name = "st_asbinary"; +}; + +struct StAstextFun { + static constexpr const char *Name = "st_astext"; + static constexpr const char *Parameters = "geom"; + static constexpr const char *Description = "Returns the Well-Known Text (WKT) representation of the geometry"; + static constexpr const char *Example = "ST_AsText(ST_GeomFromWKB(X'01010000000000000000000000000000000000000000000000'))"; + static constexpr const char *Categories = "geometry"; + + static ScalarFunction GetFunction(); +}; + +struct StAswktFun { + using ALIAS = StAstextFun; + + static constexpr const char *Name = "st_aswkt"; +}; + +struct StIntersectsExtentFun { + static constexpr const char *Name = "st_intersects_extent"; + static constexpr const char 
*Parameters = "geom1,geom2"; + static constexpr const char *Description = "Returns true if the geometries bounding boxes intersect"; + static constexpr const char *Example = "'POINT(5 5)'::GEOMETRY && 'LINESTRING(0 0, 10 20)'::GEOMETRY;"; + static constexpr const char *Categories = "geometry"; + + static ScalarFunction GetFunction(); +}; + +struct StIntersectsExtentFunAlias { + using ALIAS = StIntersectsExtentFun; + + static constexpr const char *Name = "&&"; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/regexp.hpp b/src/duckdb/src/include/duckdb/function/scalar/regexp.hpp index 5ac80ab08..2ad1e694b 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/regexp.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/regexp.hpp @@ -20,6 +20,11 @@ namespace regexp_util { bool TryParseConstantPattern(ClientContext &context, Expression &expr, string &constant_string); void ParseRegexOptions(const string &options, duckdb_re2::RE2::Options &result, bool *global_replace = nullptr); void ParseRegexOptions(ClientContext &context, Expression &expr, RE2::Options &target, bool *global_replace = nullptr); +void ParseGroupNameList(ClientContext &context, const string &function_name, Expression &group_expr, + const string &pattern_string, RE2::Options &options, bool require_constant_pattern, + vector &out_names, child_list_t &out_struct_children); + +idx_t AdvanceOneUTF8Basic(const duckdb_re2::StringPiece &input, idx_t base); inline duckdb_re2::StringPiece CreateStringPiece(const string_t &input) { return duckdb_re2::StringPiece(input.GetData(), input.GetSize()); @@ -53,6 +58,33 @@ struct RegexpBaseBindData : public FunctionData { bool Equals(const FunctionData &other_p) const override; }; +struct RegexpExtractAllStructBindData : public RegexpBaseBindData { + RegexpExtractAllStructBindData(duckdb_re2::RE2::Options options, string constant_string, bool constant_pattern, + vector group_names) + : RegexpBaseBindData(options, std::move(constant_string), constant_pattern), + group_names(std::move(group_names)) { + } + + vector group_names; // order preserved + + unique_ptr Copy() const override { + return make_uniq(options, constant_string, constant_pattern, group_names); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return RegexpBaseBindData::Equals(other) && group_names == other.group_names; + } +}; + +struct RegexpExtractAllStruct { + static void Execute(DataChunk &args, ExpressionState &state, Vector &result); + static unique_ptr Bind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments); + static unique_ptr InitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, + FunctionData *bind_data); +}; + struct RegexpMatchesBindData : public RegexpBaseBindData { RegexpMatchesBindData(duckdb_re2::RE2::Options options, string constant_string, bool constant_pattern); RegexpMatchesBindData(duckdb_re2::RE2::Options options, string constant_string, bool constant_pattern, diff --git a/src/duckdb/src/include/duckdb/function/scalar/variant_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/variant_functions.hpp index 6408639ec..c318a9236 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/variant_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/variant_functions.hpp @@ -25,11 +25,21 @@ struct VariantExtractFun { static ScalarFunctionSet GetFunctions(); }; +struct VariantNormalizeFun { + static constexpr const char *Name = "variant_normalize"; + 
static constexpr const char *Parameters = "input_variant"; + static constexpr const char *Description = "Normalizes the `input_variant` to a canonical representation."; + static constexpr const char *Example = "variant_normalize({'b': [1,2,3], 'a': 42})::VARIANT)"; + static constexpr const char *Categories = "variant"; + + static ScalarFunction GetFunction(); +}; + struct VariantTypeofFun { static constexpr const char *Name = "variant_typeof"; static constexpr const char *Parameters = "input_variant"; static constexpr const char *Description = "Returns the internal type of the `input_variant`."; - static constexpr const char *Example = "variant_typeof({'a': 42, 'b': [1,2,3])::VARIANT)"; + static constexpr const char *Example = "variant_typeof({'a': 42, 'b': [1,2,3]})::VARIANT)"; static constexpr const char *Categories = "variant"; static ScalarFunction GetFunction(); diff --git a/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp b/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp index f0c4cb82b..1c20b19c0 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp @@ -66,20 +66,25 @@ struct VariantUtils { uint32_t value_index); DUCKDB_API static VariantNestedData DecodeNestedData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t value_index); + DUCKDB_API static string_t DecodeStringData(const UnifiedVariantVectorData &variant, idx_t row, + uint32_t value_index); DUCKDB_API static vector GetObjectKeys(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data); - DUCKDB_API static VariantChildDataCollectionResult FindChildValues(const UnifiedVariantVectorData &variant, - const VariantPathComponent &component, - optional_idx row, SelectionVector &res, - VariantNestedData *nested_data, idx_t count); + DUCKDB_API static void FindChildValues(const UnifiedVariantVectorData &variant, + const VariantPathComponent &component, + optional_ptr sel, SelectionVector &res, + ValidityMask &res_validity, VariantNestedData *nested_data, idx_t count); DUCKDB_API static VariantNestedDataCollectionResult CollectNestedData(const UnifiedVariantVectorData &variant, VariantLogicalType expected_type, const SelectionVector &sel, idx_t count, optional_idx row, idx_t offset, VariantNestedData *child_data, ValidityMask &validity); DUCKDB_API static vector ValueIsNull(const UnifiedVariantVectorData &variant, const SelectionVector &sel, idx_t count, optional_idx row); - DUCKDB_API static Value ConvertVariantToValue(const UnifiedVariantVectorData &variant, idx_t row, idx_t values_idx); + DUCKDB_API static Value ConvertVariantToValue(const UnifiedVariantVectorData &variant, idx_t row, + uint32_t values_idx); DUCKDB_API static bool Verify(Vector &variant, const SelectionVector &sel_p, idx_t count); + DUCKDB_API static void FinalizeVariantKeys(Vector &variant, OrderedOwningStringMap &dictionary, + SelectionVector &sel, idx_t sel_size); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/table/direct_file_reader.hpp b/src/duckdb/src/include/duckdb/function/table/direct_file_reader.hpp index 02d38ec76..e553a6d41 100644 --- a/src/duckdb/src/include/duckdb/function/table/direct_file_reader.hpp +++ b/src/duckdb/src/include/duckdb/function/table/direct_file_reader.hpp @@ -23,8 +23,8 @@ class DirectFileReader : public BaseFileReader { bool TryInitializeScan(ClientContext &context, GlobalTableFunctionState &gstate, LocalTableFunctionState &lstate) 
override; - void Scan(ClientContext &context, GlobalTableFunctionState &global_state, LocalTableFunctionState &local_state, - DataChunk &chunk) override; + AsyncResult Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &chunk) override; void FinishFile(ClientContext &context, GlobalTableFunctionState &gstate) override; string GetReaderType() const override { diff --git a/src/duckdb/src/include/duckdb/function/table/read_file.hpp b/src/duckdb/src/include/duckdb/function/table/read_file.hpp index 966fea5ef..a0ef222fe 100644 --- a/src/duckdb/src/include/duckdb/function/table/read_file.hpp +++ b/src/duckdb/src/include/duckdb/function/table/read_file.hpp @@ -31,53 +31,4 @@ struct ReadFileGlobalState : public GlobalTableFunctionState { bool requires_file_open = false; }; -struct ReadBlobOperation { - static constexpr const char *NAME = "read_blob"; - static constexpr const char *FILE_TYPE = "blob"; - - static inline LogicalType TYPE() { - return LogicalType::BLOB; - } -}; - -struct ReadTextOperation { - static constexpr const char *NAME = "read_text"; - static constexpr const char *FILE_TYPE = "text"; - - static inline LogicalType TYPE() { - return LogicalType::VARCHAR; - } -}; - -template -struct DirectMultiFileInfo : MultiFileReaderInterface { - static unique_ptr CreateInterface(ClientContext &context); - unique_ptr InitializeOptions(ClientContext &context, - optional_ptr info) override; - bool ParseCopyOption(ClientContext &context, const string &key, const vector &values, - BaseFileReaderOptions &options, vector &expected_names, - vector &expected_types) override; - bool ParseOption(ClientContext &context, const string &key, const Value &val, MultiFileOptions &file_options, - BaseFileReaderOptions &options) override; - unique_ptr InitializeBindData(MultiFileBindData &multi_file_data, - unique_ptr options) override; - void BindReader(ClientContext &context, vector &return_types, vector &names, - MultiFileBindData &bind_data) override; - optional_idx MaxThreads(const MultiFileBindData &bind_data_p, const MultiFileGlobalState &global_state, - FileExpandResult expand_result) override; - unique_ptr InitializeGlobalState(ClientContext &context, MultiFileBindData &bind_data, - MultiFileGlobalState &global_state) override; - unique_ptr InitializeLocalState(ExecutionContext &, GlobalTableFunctionState &) override; - shared_ptr CreateReader(ClientContext &context, GlobalTableFunctionState &gstate, - BaseUnionData &union_data, const MultiFileBindData &bind_data_p) override; - shared_ptr CreateReader(ClientContext &context, GlobalTableFunctionState &gstate, - const OpenFileInfo &file, idx_t file_idx, - const MultiFileBindData &bind_data) override; - shared_ptr CreateReader(ClientContext &context, const OpenFileInfo &file, - BaseFileReaderOptions &options, - const MultiFileOptions &file_options) override; - unique_ptr GetCardinality(const MultiFileBindData &bind_data, idx_t file_count) override; - FileGlobInput GetGlobInput() override; -}; - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/table/system_functions.hpp b/src/duckdb/src/include/duckdb/function/table/system_functions.hpp index e325b2f46..49c5e794c 100644 --- a/src/duckdb/src/include/duckdb/function/table/system_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/table/system_functions.hpp @@ -47,6 +47,10 @@ struct DuckDBSchemasFun { static void RegisterFunction(BuiltinFunctions &set); }; +struct DuckDBConnectionCountFun { + static void 
RegisterFunction(BuiltinFunctions &set); +}; + struct DuckDBApproxDatabaseCountFun { static void RegisterFunction(BuiltinFunctions &set); }; diff --git a/src/duckdb/src/include/duckdb/function/table/table_scan.hpp b/src/duckdb/src/include/duckdb/function/table/table_scan.hpp index df4c829da..22407fff5 100644 --- a/src/duckdb/src/include/duckdb/function/table/table_scan.hpp +++ b/src/duckdb/src/include/duckdb/function/table/table_scan.hpp @@ -28,6 +28,8 @@ struct TableScanBindData : public TableFunctionData { bool is_index_scan; //! Whether or not the table scan is for index creation. bool is_create_index; + //! In what order to scan the row groups + unique_ptr order_options; public: bool Equals(const FunctionData &other_p) const override { diff --git a/src/duckdb/src/include/duckdb/function/table_function.hpp b/src/duckdb/src/include/duckdb/function/table_function.hpp index f6c9cc55e..6bdd980c0 100644 --- a/src/duckdb/src/include/duckdb/function/table_function.hpp +++ b/src/duckdb/src/include/duckdb/function/table_function.hpp @@ -16,6 +16,7 @@ #include "duckdb/storage/statistics/node_statistics.hpp" #include "duckdb/common/column_index.hpp" #include "duckdb/common/table_column.hpp" +#include "duckdb/parallel/async_result.hpp" #include "duckdb/function/partition_stats.hpp" #include "duckdb/common/exception/binder_exception.hpp" @@ -34,6 +35,9 @@ class SampleOptions; struct MultiFileReader; struct OperatorPartitionData; struct OperatorPartitionInfo; +enum class OrderByColumnType; +enum class RowGroupOrderType; +enum class OrderByStatistics; struct TableFunctionInfo { DUCKDB_API virtual ~TableFunctionInfo(); @@ -108,6 +112,18 @@ struct TableFunctionBindInput { const TableFunctionRef &ref; }; +struct RowGroupOrderOptions { + RowGroupOrderOptions(column_t column_idx_p, OrderByStatistics order_by_p, RowGroupOrderType order_type_p, + OrderByColumnType column_type_p) + : column_idx(column_idx_p), order_by(order_by_p), order_type(order_type_p), column_type(column_type_p) { + } + + const column_t column_idx; + const OrderByStatistics order_by; + const RowGroupOrderType order_type; + const OrderByColumnType column_type; +}; + struct TableFunctionInitInput { TableFunctionInitInput(optional_ptr bind_data_p, vector column_ids_p, const vector &projection_ids_p, optional_ptr filters_p, @@ -158,13 +174,15 @@ struct TableFunctionInput { TableFunctionInput(optional_ptr bind_data_p, optional_ptr local_state_p, optional_ptr global_state_p) - : bind_data(bind_data_p), local_state(local_state_p), global_state(global_state_p) { + : bind_data(bind_data_p), local_state(local_state_p), global_state(global_state_p), async_result() { } public: optional_ptr bind_data; optional_ptr local_state; optional_ptr global_state; + AsyncResult async_result {}; + AsyncResultsExecutionMode results_execution_mode {AsyncResultsExecutionMode::SYNCHRONOUS}; }; struct TableFunctionPartitionInput { @@ -324,19 +342,31 @@ typedef virtual_column_map_t (*table_function_get_virtual_columns_t)(ClientConte typedef vector (*table_function_get_row_id_columns)(ClientContext &context, optional_ptr bind_data); +typedef void (*table_function_set_scan_order)(unique_ptr order_options, + optional_ptr bind_data); + //! 
When to call init_global to initialize the table function enum class TableFunctionInitialization { INITIALIZE_ON_EXECUTE, INITIALIZE_ON_SCHEDULE }; class TableFunction : public SimpleNamedParameterFunction { // NOLINT: work-around bug in clang-tidy public: + DUCKDB_API TableFunction(); + // Overloads taking table_function_t DUCKDB_API - TableFunction(string name, vector arguments, table_function_t function, + TableFunction(string name, const vector &arguments, table_function_t function, table_function_bind_t bind = nullptr, table_function_init_global_t init_global = nullptr, table_function_init_local_t init_local = nullptr); DUCKDB_API TableFunction(const vector &arguments, table_function_t function, table_function_bind_t bind = nullptr, table_function_init_global_t init_global = nullptr, table_function_init_local_t init_local = nullptr); - DUCKDB_API TableFunction(); + // Overloads taking std::nullptr + DUCKDB_API + TableFunction(string name, const vector &arguments, std::nullptr_t function, + table_function_bind_t bind = nullptr, table_function_init_global_t init_global = nullptr, + table_function_init_local_t init_local = nullptr); + DUCKDB_API + TableFunction(const vector &arguments, std::nullptr_t function, table_function_bind_t bind = nullptr, + table_function_init_global_t init_global = nullptr, table_function_init_local_t init_local = nullptr); //! Bind function //! This function is used for determining the return type of a table producing function and returning bind data @@ -404,6 +434,8 @@ class TableFunction : public SimpleNamedParameterFunction { // NOLINT: work-arou table_function_get_virtual_columns_t get_virtual_columns; //! (Optional) returns a list of row id columns table_function_get_row_id_columns get_row_id_columns; + //! (Optional) sets the order to scan the row groups in + table_function_set_scan_order set_scan_order; table_function_serialize_t serialize; table_function_deserialize_t deserialize; diff --git a/src/duckdb/src/include/duckdb/function/udf_function.hpp b/src/duckdb/src/include/duckdb/function/udf_function.hpp index 571a49af4..3b23445e6 100644 --- a/src/duckdb/src/include/duckdb/function/udf_function.hpp +++ b/src/duckdb/src/include/duckdb/function/udf_function.hpp @@ -123,10 +123,9 @@ struct UDFWrapper { aggregate_combine_t combine, aggregate_finalize_t finalize, aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr, aggregate_destructor_t destructor = nullptr) { - AggregateFunction aggr_function(name, arguments, return_type, state_size, initialize, update, combine, finalize, simple_update, bind, destructor); - aggr_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + aggr_function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return aggr_function; } diff --git a/src/duckdb/src/include/duckdb/function/window/window_boundaries_state.hpp b/src/duckdb/src/include/duckdb/function/window/window_boundaries_state.hpp index 11c724d9b..4f007d83b 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_boundaries_state.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_boundaries_state.hpp @@ -89,7 +89,6 @@ struct WindowInputExpression { }; struct WindowBoundariesState { - static bool HasPrecedingRange(const BoundWindowExpression &wexpr); static bool HasFollowingRange(const BoundWindowExpression &wexpr); static WindowBoundsSet GetWindowBounds(const BoundWindowExpression &wexpr); diff --git a/src/duckdb/src/include/duckdb/function/window/window_collection.hpp 
b/src/duckdb/src/include/duckdb/function/window/window_collection.hpp index 95cf0534f..2dae27c6a 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_collection.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_collection.hpp @@ -190,7 +190,6 @@ class WindowCollectionChunkScanner { template static void WindowDeltaScanner(ColumnDataCollection &collection, idx_t block_begin, idx_t block_end, const vector &scan_cols, const idx_t key_count, OP operation) { - // Stop if there is no work to do if (!collection.Count()) { return; diff --git a/src/duckdb/src/include/duckdb/logging/log_manager.hpp b/src/duckdb/src/include/duckdb/logging/log_manager.hpp index 6ee88aeda..54f623a55 100644 --- a/src/duckdb/src/include/duckdb/logging/log_manager.hpp +++ b/src/duckdb/src/include/duckdb/logging/log_manager.hpp @@ -21,7 +21,7 @@ class LogType; // - Creates Loggers with cached configuration // - Main sink for logs (either by logging directly into this, or by syncing a pre-cached set of log entries) // - Holds the log storage -class LogManager : public enable_shared_from_this { +class LogManager { friend class ThreadSafeLogger; friend class ThreadLocalLogger; friend class MutableLogger; diff --git a/src/duckdb/src/include/duckdb/logging/log_type.hpp b/src/duckdb/src/include/duckdb/logging/log_type.hpp index 23d901c4e..7ce97e5ab 100644 --- a/src/duckdb/src/include/duckdb/logging/log_type.hpp +++ b/src/duckdb/src/include/duckdb/logging/log_type.hpp @@ -20,6 +20,7 @@ class PhysicalOperator; class AttachedDatabase; class RowGroup; struct DataTableInfo; +enum class MetricsType : uint8_t; //! Log types provide some structure to the formats that the different log messages can have //! For now, this holds a type that the VARCHAR value will be auto-cast into. @@ -106,6 +107,19 @@ class PhysicalOperatorLogType : public LogType { const vector> &info); }; +class MetricsLogType : public LogType { +public: + static constexpr const char *NAME = "Metrics"; + static constexpr LogLevel LEVEL = LogLevel::LOG_INFO; + + //! Construct the log type + MetricsLogType(); + + static LogicalType GetLogType(); + + static string ConstructLogMessage(const MetricsType &type, const Value &value); +}; + class CheckpointLogType : public LogType { public: static constexpr const char *NAME = "Checkpoint"; diff --git a/src/duckdb/src/include/duckdb/main/appender.hpp b/src/duckdb/src/include/duckdb/main/appender.hpp index b32025cb0..fe8b3bc68 100644 --- a/src/duckdb/src/include/duckdb/main/appender.hpp +++ b/src/duckdb/src/include/duckdb/main/appender.hpp @@ -82,6 +82,8 @@ class BaseAppender { DUCKDB_API void Flush(); //! Flush the changes made by the appender and close it. The appender cannot be used after this point DUCKDB_API void Close(); + //! Clears any appended data (without flushing). + DUCKDB_API void Clear(); //! Returns the active types of the appender. const vector &GetActiveTypes() const; diff --git a/src/duckdb/src/include/duckdb/main/attached_database.hpp b/src/duckdb/src/include/duckdb/main/attached_database.hpp index 7333d9adb..d75a7b922 100644 --- a/src/duckdb/src/include/duckdb/main/attached_database.hpp +++ b/src/duckdb/src/include/duckdb/main/attached_database.hpp @@ -34,14 +34,23 @@ enum class AttachedDatabaseType { enum class AttachVisibility { SHOWN, HIDDEN }; +//! DEFAULT is the standard ACID crash recovery mode. +//! NO_WAL_WRITES disables the WAL for the attached database, i.e., disabling the D in ACID. +//! Use this mode with caution, as it disables recovery from crashes for the file. 
+enum class RecoveryMode : uint8_t { DEFAULT = 0, NO_WAL_WRITES = 1 }; + class DatabaseFilePathManager; struct StoredDatabasePath { - StoredDatabasePath(DatabaseFilePathManager &manager, string path, const string &name); + StoredDatabasePath(DatabaseManager &db_manager, DatabaseFilePathManager &manager, string path, const string &name); ~StoredDatabasePath(); + DatabaseManager &db_manager; DatabaseFilePathManager &manager; string path; + +public: + void OnDetach(); }; //! AttachOptions holds information about a database we plan to attach. These options are generalized, i.e., @@ -54,6 +63,8 @@ struct AttachOptions { //! Defaults to the access mode configured in the DBConfig, unless specified otherwise. AccessMode access_mode; + //! The recovery type of the database. + RecoveryMode recovery_mode = RecoveryMode::DEFAULT; //! The file format type. The default type is a duckdb database file, but other file formats are possible. string db_type; //! Set of remaining (key, value) options @@ -112,9 +123,13 @@ class AttachedDatabase : public CatalogEntry, public enable_shared_from_this parent_catalog; optional_ptr storage_extension; + RecoveryMode recovery_mode = RecoveryMode::DEFAULT; AttachVisibility visibility = AttachVisibility::SHOWN; bool is_initial_database = false; bool is_closed = false; diff --git a/src/duckdb/src/include/duckdb/main/buffered_data/batched_buffered_data.hpp b/src/duckdb/src/include/duckdb/main/buffered_data/batched_buffered_data.hpp index c3ffadbe2..57ea17d6b 100644 --- a/src/duckdb/src/include/duckdb/main/buffered_data/batched_buffered_data.hpp +++ b/src/duckdb/src/include/duckdb/main/buffered_data/batched_buffered_data.hpp @@ -32,7 +32,7 @@ class BatchedBufferedData : public BufferedData { static constexpr const BufferedData::Type TYPE = BufferedData::Type::BATCHED; public: - explicit BatchedBufferedData(weak_ptr context); + explicit BatchedBufferedData(ClientContext &context); public: void Append(const DataChunk &chunk, idx_t batch); diff --git a/src/duckdb/src/include/duckdb/main/buffered_data/buffered_data.hpp b/src/duckdb/src/include/duckdb/main/buffered_data/buffered_data.hpp index 0f32675ce..06a72b0f6 100644 --- a/src/duckdb/src/include/duckdb/main/buffered_data/buffered_data.hpp +++ b/src/duckdb/src/include/duckdb/main/buffered_data/buffered_data.hpp @@ -28,7 +28,7 @@ class BufferedData { enum class Type { SIMPLE, BATCHED }; public: - BufferedData(Type type, weak_ptr context_p); + BufferedData(Type type, ClientContext &context); virtual ~BufferedData(); public: diff --git a/src/duckdb/src/include/duckdb/main/buffered_data/simple_buffered_data.hpp b/src/duckdb/src/include/duckdb/main/buffered_data/simple_buffered_data.hpp index 967cc1ab7..40a5a6ede 100644 --- a/src/duckdb/src/include/duckdb/main/buffered_data/simple_buffered_data.hpp +++ b/src/duckdb/src/include/duckdb/main/buffered_data/simple_buffered_data.hpp @@ -24,7 +24,7 @@ class SimpleBufferedData : public BufferedData { static constexpr const BufferedData::Type TYPE = BufferedData::Type::SIMPLE; public: - explicit SimpleBufferedData(weak_ptr context); + explicit SimpleBufferedData(ClientContext &context); ~SimpleBufferedData() override; public: diff --git a/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp b/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp index 8307b70a3..3b736d88a 100644 --- a/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp +++ b/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp @@ -51,6 +51,8 @@ struct PreparedStatementWrapper { //! 
Map of name -> values case_insensitive_map_t values; unique_ptr statement; + bool success = true; + ErrorData error_data; }; struct ExtractStatementsWrapper { diff --git a/src/duckdb/src/include/duckdb/main/capi/capi_internal_table.hpp b/src/duckdb/src/include/duckdb/main/capi/capi_internal_table.hpp new file mode 100644 index 000000000..f51947cf5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/capi/capi_internal_table.hpp @@ -0,0 +1,74 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/capi/capi_internal_table.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb.h" +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/function/table_function.hpp" + +namespace duckdb { + +// These need to be shared by both the table function API and the copy function API + +struct CTableFunctionInfo : public TableFunctionInfo { + ~CTableFunctionInfo() override { + if (extra_info && delete_callback) { + delete_callback(extra_info); + } + extra_info = nullptr; + delete_callback = nullptr; + } + + duckdb_table_function_bind_t bind = nullptr; + duckdb_table_function_init_t init = nullptr; + duckdb_table_function_init_t local_init = nullptr; + duckdb_table_function_t function = nullptr; + void *extra_info = nullptr; + duckdb_delete_callback_t delete_callback = nullptr; +}; + +struct CTableBindData : public TableFunctionData { + explicit CTableBindData(CTableFunctionInfo &info) : info(info) { + } + ~CTableBindData() override { + if (bind_data && delete_callback) { + delete_callback(bind_data); + } + bind_data = nullptr; + delete_callback = nullptr; + } + + CTableFunctionInfo &info; + void *bind_data = nullptr; + duckdb_delete_callback_t delete_callback = nullptr; + unique_ptr stats; +}; + +struct CTableInternalBindInfo { + CTableInternalBindInfo(ClientContext &context, const vector ¶meters, + const named_parameter_map_t &named_parameters, vector &return_types, + vector &names, CTableBindData &bind_data, CTableFunctionInfo &function_info) + : context(context), parameters(parameters), named_parameters(named_parameters), return_types(return_types), + names(names), bind_data(bind_data), function_info(function_info), success(true) { + } + + ClientContext &context; + + vector parameters; + named_parameter_map_t named_parameters; + + vector &return_types; + vector &names; + CTableBindData &bind_data; + CTableFunctionInfo &function_info; + bool success; + string error; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp b/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp index 2ce10061a..34e8178a3 100644 --- a/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp +++ b/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp @@ -475,6 +475,7 @@ typedef struct { duckdb_state (*duckdb_appender_create_query)(duckdb_connection connection, const char *query, idx_t column_count, duckdb_logical_type *types, const char *table_name, const char **column_names, duckdb_appender *out_appender); + duckdb_state (*duckdb_appender_clear)(duckdb_appender appender); // New arrow interface functions duckdb_error_data (*duckdb_to_arrow_schema)(duckdb_arrow_options arrow_options, duckdb_logical_type *types, @@ -487,6 +488,65 @@ typedef struct { duckdb_arrow_converted_schema converted_schema, duckdb_data_chunk *out_chunk); void (*duckdb_destroy_arrow_converted_schema)(duckdb_arrow_converted_schema *arrow_converted_schema); + 
// New configuration options functions + + duckdb_config_option (*duckdb_create_config_option)(); + void (*duckdb_destroy_config_option)(duckdb_config_option *option); + void (*duckdb_config_option_set_name)(duckdb_config_option option, const char *name); + void (*duckdb_config_option_set_type)(duckdb_config_option option, duckdb_logical_type type); + void (*duckdb_config_option_set_default_value)(duckdb_config_option option, duckdb_value default_value); + void (*duckdb_config_option_set_default_scope)(duckdb_config_option option, + duckdb_config_option_scope default_scope); + void (*duckdb_config_option_set_description)(duckdb_config_option option, const char *description); + duckdb_state (*duckdb_register_config_option)(duckdb_connection connection, duckdb_config_option option); + duckdb_value (*duckdb_client_context_get_config_option)(duckdb_client_context context, const char *name, + duckdb_config_option_scope *out_scope); + // API to define custom copy functions + + duckdb_copy_function (*duckdb_create_copy_function)(); + void (*duckdb_copy_function_set_name)(duckdb_copy_function copy_function, const char *name); + void (*duckdb_copy_function_set_extra_info)(duckdb_copy_function copy_function, void *extra_info, + duckdb_delete_callback_t destructor); + duckdb_state (*duckdb_register_copy_function)(duckdb_connection connection, duckdb_copy_function copy_function); + void (*duckdb_destroy_copy_function)(duckdb_copy_function *copy_function); + void (*duckdb_copy_function_set_bind)(duckdb_copy_function copy_function, duckdb_copy_function_bind_t bind); + void (*duckdb_copy_function_bind_set_error)(duckdb_copy_function_bind_info info, const char *error); + void *(*duckdb_copy_function_bind_get_extra_info)(duckdb_copy_function_bind_info info); + duckdb_client_context (*duckdb_copy_function_bind_get_client_context)(duckdb_copy_function_bind_info info); + idx_t (*duckdb_copy_function_bind_get_column_count)(duckdb_copy_function_bind_info info); + duckdb_logical_type (*duckdb_copy_function_bind_get_column_type)(duckdb_copy_function_bind_info info, + idx_t col_idx); + duckdb_value (*duckdb_copy_function_bind_get_options)(duckdb_copy_function_bind_info info); + void (*duckdb_copy_function_bind_set_bind_data)(duckdb_copy_function_bind_info info, void *bind_data, + duckdb_delete_callback_t destructor); + void (*duckdb_copy_function_set_global_init)(duckdb_copy_function copy_function, + duckdb_copy_function_global_init_t init); + void (*duckdb_copy_function_global_init_set_error)(duckdb_copy_function_global_init_info info, const char *error); + void *(*duckdb_copy_function_global_init_get_extra_info)(duckdb_copy_function_global_init_info info); + duckdb_client_context (*duckdb_copy_function_global_init_get_client_context)( + duckdb_copy_function_global_init_info info); + void *(*duckdb_copy_function_global_init_get_bind_data)(duckdb_copy_function_global_init_info info); + void (*duckdb_copy_function_global_init_set_global_state)(duckdb_copy_function_global_init_info info, + void *global_state, duckdb_delete_callback_t destructor); + const char *(*duckdb_copy_function_global_init_get_file_path)(duckdb_copy_function_global_init_info info); + void (*duckdb_copy_function_set_sink)(duckdb_copy_function copy_function, duckdb_copy_function_sink_t function); + void (*duckdb_copy_function_sink_set_error)(duckdb_copy_function_sink_info info, const char *error); + void *(*duckdb_copy_function_sink_get_extra_info)(duckdb_copy_function_sink_info info); + duckdb_client_context 
(*duckdb_copy_function_sink_get_client_context)(duckdb_copy_function_sink_info info); + void *(*duckdb_copy_function_sink_get_bind_data)(duckdb_copy_function_sink_info info); + void *(*duckdb_copy_function_sink_get_global_state)(duckdb_copy_function_sink_info info); + void (*duckdb_copy_function_set_finalize)(duckdb_copy_function copy_function, + duckdb_copy_function_finalize_t finalize); + void (*duckdb_copy_function_finalize_set_error)(duckdb_copy_function_finalize_info info, const char *error); + void *(*duckdb_copy_function_finalize_get_extra_info)(duckdb_copy_function_finalize_info info); + duckdb_client_context (*duckdb_copy_function_finalize_get_client_context)(duckdb_copy_function_finalize_info info); + void *(*duckdb_copy_function_finalize_get_bind_data)(duckdb_copy_function_finalize_info info); + void *(*duckdb_copy_function_finalize_get_global_state)(duckdb_copy_function_finalize_info info); + void (*duckdb_copy_function_set_copy_from_function)(duckdb_copy_function copy_function, + duckdb_table_function table_function); + idx_t (*duckdb_table_function_bind_get_result_column_count)(duckdb_bind_info info); + const char *(*duckdb_table_function_bind_get_result_column_name)(duckdb_bind_info info, idx_t col_idx); + duckdb_logical_type (*duckdb_table_function_bind_get_result_column_type)(duckdb_bind_info info, idx_t col_idx); // New functions for duckdb error data duckdb_error_data (*duckdb_create_error_data)(duckdb_error_type type, const char *message); @@ -554,6 +614,11 @@ typedef struct { // New string functions that are added char *(*duckdb_value_to_string)(duckdb_value value); + // New functions around the table description + + idx_t (*duckdb_table_description_get_column_count)(duckdb_table_description table_description); + duckdb_logical_type (*duckdb_table_description_get_column_type)(duckdb_table_description table_description, + idx_t index); // New functions around table function binding void (*duckdb_table_function_get_client_context)(duckdb_bind_info info, duckdb_client_context *out_context); @@ -993,11 +1058,57 @@ inline duckdb_ext_api_v1 CreateAPIv1() { result.duckdb_append_default_to_chunk = duckdb_append_default_to_chunk; result.duckdb_appender_error_data = duckdb_appender_error_data; result.duckdb_appender_create_query = duckdb_appender_create_query; + result.duckdb_appender_clear = duckdb_appender_clear; result.duckdb_to_arrow_schema = duckdb_to_arrow_schema; result.duckdb_data_chunk_to_arrow = duckdb_data_chunk_to_arrow; result.duckdb_schema_from_arrow = duckdb_schema_from_arrow; result.duckdb_data_chunk_from_arrow = duckdb_data_chunk_from_arrow; result.duckdb_destroy_arrow_converted_schema = duckdb_destroy_arrow_converted_schema; + result.duckdb_create_config_option = duckdb_create_config_option; + result.duckdb_destroy_config_option = duckdb_destroy_config_option; + result.duckdb_config_option_set_name = duckdb_config_option_set_name; + result.duckdb_config_option_set_type = duckdb_config_option_set_type; + result.duckdb_config_option_set_default_value = duckdb_config_option_set_default_value; + result.duckdb_config_option_set_default_scope = duckdb_config_option_set_default_scope; + result.duckdb_config_option_set_description = duckdb_config_option_set_description; + result.duckdb_register_config_option = duckdb_register_config_option; + result.duckdb_client_context_get_config_option = duckdb_client_context_get_config_option; + result.duckdb_create_copy_function = duckdb_create_copy_function; + result.duckdb_copy_function_set_name = duckdb_copy_function_set_name; 
+ result.duckdb_copy_function_set_extra_info = duckdb_copy_function_set_extra_info; + result.duckdb_register_copy_function = duckdb_register_copy_function; + result.duckdb_destroy_copy_function = duckdb_destroy_copy_function; + result.duckdb_copy_function_set_bind = duckdb_copy_function_set_bind; + result.duckdb_copy_function_bind_set_error = duckdb_copy_function_bind_set_error; + result.duckdb_copy_function_bind_get_extra_info = duckdb_copy_function_bind_get_extra_info; + result.duckdb_copy_function_bind_get_client_context = duckdb_copy_function_bind_get_client_context; + result.duckdb_copy_function_bind_get_column_count = duckdb_copy_function_bind_get_column_count; + result.duckdb_copy_function_bind_get_column_type = duckdb_copy_function_bind_get_column_type; + result.duckdb_copy_function_bind_get_options = duckdb_copy_function_bind_get_options; + result.duckdb_copy_function_bind_set_bind_data = duckdb_copy_function_bind_set_bind_data; + result.duckdb_copy_function_set_global_init = duckdb_copy_function_set_global_init; + result.duckdb_copy_function_global_init_set_error = duckdb_copy_function_global_init_set_error; + result.duckdb_copy_function_global_init_get_extra_info = duckdb_copy_function_global_init_get_extra_info; + result.duckdb_copy_function_global_init_get_client_context = duckdb_copy_function_global_init_get_client_context; + result.duckdb_copy_function_global_init_get_bind_data = duckdb_copy_function_global_init_get_bind_data; + result.duckdb_copy_function_global_init_set_global_state = duckdb_copy_function_global_init_set_global_state; + result.duckdb_copy_function_global_init_get_file_path = duckdb_copy_function_global_init_get_file_path; + result.duckdb_copy_function_set_sink = duckdb_copy_function_set_sink; + result.duckdb_copy_function_sink_set_error = duckdb_copy_function_sink_set_error; + result.duckdb_copy_function_sink_get_extra_info = duckdb_copy_function_sink_get_extra_info; + result.duckdb_copy_function_sink_get_client_context = duckdb_copy_function_sink_get_client_context; + result.duckdb_copy_function_sink_get_bind_data = duckdb_copy_function_sink_get_bind_data; + result.duckdb_copy_function_sink_get_global_state = duckdb_copy_function_sink_get_global_state; + result.duckdb_copy_function_set_finalize = duckdb_copy_function_set_finalize; + result.duckdb_copy_function_finalize_set_error = duckdb_copy_function_finalize_set_error; + result.duckdb_copy_function_finalize_get_extra_info = duckdb_copy_function_finalize_get_extra_info; + result.duckdb_copy_function_finalize_get_client_context = duckdb_copy_function_finalize_get_client_context; + result.duckdb_copy_function_finalize_get_bind_data = duckdb_copy_function_finalize_get_bind_data; + result.duckdb_copy_function_finalize_get_global_state = duckdb_copy_function_finalize_get_global_state; + result.duckdb_copy_function_set_copy_from_function = duckdb_copy_function_set_copy_from_function; + result.duckdb_table_function_bind_get_result_column_count = duckdb_table_function_bind_get_result_column_count; + result.duckdb_table_function_bind_get_result_column_name = duckdb_table_function_bind_get_result_column_name; + result.duckdb_table_function_bind_get_result_column_type = duckdb_table_function_bind_get_result_column_type; result.duckdb_create_error_data = duckdb_create_error_data; result.duckdb_destroy_error_data = duckdb_destroy_error_data; result.duckdb_error_data_error_type = duckdb_error_data_error_type; @@ -1044,6 +1155,8 @@ inline duckdb_ext_api_v1 CreateAPIv1() { result.duckdb_scalar_function_bind_get_argument 
= duckdb_scalar_function_bind_get_argument; result.duckdb_scalar_function_set_bind_data_copy = duckdb_scalar_function_set_bind_data_copy; result.duckdb_value_to_string = duckdb_value_to_string; + result.duckdb_table_description_get_column_count = duckdb_table_description_get_column_count; + result.duckdb_table_description_get_column_type = duckdb_table_description_get_column_type; result.duckdb_table_function_get_client_context = duckdb_table_function_get_client_context; result.duckdb_create_map_value = duckdb_create_map_value; result.duckdb_create_union_value = duckdb_create_union_value; diff --git a/src/duckdb/src/include/duckdb/main/client_config.hpp b/src/duckdb/src/include/duckdb/main/client_config.hpp index f9e673b19..284e63ce9 100644 --- a/src/duckdb/src/include/duckdb/main/client_config.hpp +++ b/src/duckdb/src/include/duckdb/main/client_config.hpp @@ -121,7 +121,7 @@ struct ClientConfig { bool AnyVerification() const; - void SetUserVariable(const string &name, Value value); + void SetUserVariable(const String &name, Value value); bool GetUserVariable(const string &name, Value &result); void ResetUserVariable(const String &name); diff --git a/src/duckdb/src/include/duckdb/main/client_context.hpp b/src/duckdb/src/include/duckdb/main/client_context.hpp index ddb14518c..5348773c7 100644 --- a/src/duckdb/src/include/duckdb/main/client_context.hpp +++ b/src/duckdb/src/include/duckdb/main/client_context.hpp @@ -28,6 +28,7 @@ #include "duckdb/main/table_description.hpp" #include "duckdb/planner/expression/bound_parameter_data.hpp" #include "duckdb/transaction/transaction_context.hpp" +#include "duckdb/main/query_parameters.hpp" namespace duckdb { @@ -56,8 +57,8 @@ class RegisteredStateManager; struct PendingQueryParameters { //! Prepared statement parameters (if any) optional_ptr> parameters; - //! Whether a stream result should be allowed - bool allow_stream_result = false; + //! Whether a stream/buffer-managed result should be allowed + QueryParameters query_parameters; }; //! The ClientContext holds information relevant to the current client session @@ -106,22 +107,24 @@ class ClientContext : public enable_shared_from_this { //! Issue a query, returning a QueryResult. The QueryResult can be either a StreamQueryResult or a //! MaterializedQueryResult. The StreamQueryResult will only be returned in the case of a successful SELECT //! statement. - DUCKDB_API unique_ptr Query(const string &query, bool allow_stream_result); - DUCKDB_API unique_ptr Query(unique_ptr statement, bool allow_stream_result); + DUCKDB_API unique_ptr Query(const string &query, QueryParameters query_parameters); + DUCKDB_API unique_ptr Query(unique_ptr statement, QueryParameters query_parameters); //! Issues a query to the database and returns a Pending Query Result. Note that "query" may only contain //! a single statement. - DUCKDB_API unique_ptr PendingQuery(const string &query, bool allow_stream_result); + DUCKDB_API unique_ptr PendingQuery(const string &query, QueryParameters query_parameters); //! Issues a query to the database and returns a Pending Query Result DUCKDB_API unique_ptr PendingQuery(unique_ptr statement, - bool allow_stream_result); + QueryParameters query_parameters); //! 
Create a pending query with a list of parameters DUCKDB_API unique_ptr PendingQuery(unique_ptr statement, case_insensitive_map_t &values, - bool allow_stream_result); - DUCKDB_API unique_ptr - PendingQuery(const string &query, case_insensitive_map_t &values, bool allow_stream_result); + QueryParameters query_parameters); + DUCKDB_API unique_ptr PendingQuery(const string &query, + case_insensitive_map_t &values, + QueryParameters query_parameters); + DUCKDB_API unique_ptr PendingQuery(const string &query, PendingQueryParameters parameters); //! Destroy the client context DUCKDB_API void Destroy(); @@ -147,7 +150,7 @@ class ClientContext : public enable_shared_from_this { //! Execute a relation DUCKDB_API unique_ptr PendingQuery(const shared_ptr &relation, - bool allow_stream_result); + QueryParameters query_parameters); DUCKDB_API unique_ptr Execute(const shared_ptr &relation); //! Prepare a query @@ -165,9 +168,10 @@ class ClientContext : public enable_shared_from_this { //! Execute a prepared statement with the given name and set of parameters //! It is possible that the prepared statement will be re-bound. This will generally happen if the catalog is //! modified in between the prepared statement being bound and the prepared statement being run. - DUCKDB_API unique_ptr Execute(const string &query, shared_ptr &prepared, - case_insensitive_map_t &values, - bool allow_stream_result = true); + DUCKDB_API unique_ptr + Execute(const string &query, shared_ptr &prepared, + case_insensitive_map_t &values, + QueryParameters query_parameters = QueryResultOutputType::ALLOW_STREAMING); DUCKDB_API unique_ptr Execute(const string &query, shared_ptr &prepared, const PendingQueryParameters ¶meters); @@ -238,7 +242,7 @@ class ClientContext : public enable_shared_from_this { //! Perform aggressive query verification of a SELECT statement. Only called when query_verification_enabled is //! true. ErrorData VerifyQuery(ClientContextLock &lock, const string &query, unique_ptr statement, - optional_ptr> values = nullptr); + PendingQueryParameters parameters); void InitialCleanup(ClientContextLock &lock); //! Internal clean up, does not lock. Caller must hold the context_lock. @@ -259,15 +263,14 @@ class ClientContext : public enable_shared_from_this { //! Internally prepare a SQL statement. Caller must hold the context_lock. 
shared_ptr CreatePreparedStatement(ClientContextLock &lock, const string &query, unique_ptr statement, - optional_ptr> values = nullptr, + PendingQueryParameters parameters, PreparedStatementMode mode = PreparedStatementMode::PREPARE_ONLY); unique_ptr PendingStatementInternal(ClientContextLock &lock, const string &query, unique_ptr statement, const PendingQueryParameters &parameters); unique_ptr RunStatementInternal(ClientContextLock &lock, const string &query, - unique_ptr statement, bool allow_stream_result, - optional_ptr> params, - bool verify = true); + unique_ptr statement, + const PendingQueryParameters &parameters, bool verify = true); unique_ptr PrepareInternal(ClientContextLock &lock, unique_ptr statement); void LogQueryInternal(ClientContextLock &lock, const string &query); @@ -292,7 +295,7 @@ const PendingQueryParameters &parameters); unique_ptr PendingQueryInternal(ClientContextLock &, const shared_ptr &relation, - bool allow_stream_result); + QueryParameters query_parameters); void RebindPreparedStatement(ClientContextLock &lock, const string &query, shared_ptr &prepared, const PendingQueryParameters &parameters); @@ -300,9 +303,9 @@ template unique_ptr ErrorResult(ErrorData error, const string &query = string()); - shared_ptr - CreatePreparedStatementInternal(ClientContextLock &lock, const string &query, unique_ptr statement, - optional_ptr> values); + shared_ptr CreatePreparedStatementInternal(ClientContextLock &lock, const string &query, + unique_ptr statement, + PendingQueryParameters parameters); SettingLookupResult TryGetCurrentSettingInternal(const string &key, Value &result) const; @@ -337,6 +340,8 @@ class QueryContext { } QueryContext(optional_ptr context) : context(context) { // NOLINT: allow implicit construction } + QueryContext(ClientContext &context) : context(&context) { // NOLINT: allow implicit construction + } public: bool Valid() const { diff --git a/src/duckdb/src/include/duckdb/main/config.hpp b/src/duckdb/src/include/duckdb/main/config.hpp index 9a685f560..3cfd12dbb 100644 --- a/src/duckdb/src/include/duckdb/main/config.hpp +++ b/src/duckdb/src/include/duckdb/main/config.hpp @@ -110,6 +110,8 @@ struct DBConfigOptions { #else bool autoinstall_known_extensions = false; #endif + //! Setting for the parser override registered by extensions. Allowed options: "default", "fallback", "strict" + string allow_parser_override_extension = "default"; //! Override for the default extension repository string custom_extension_repo = ""; //! Override for the default autoload extension repository @@ -289,6 +291,7 @@ struct DBConfig { DUCKDB_API void AddExtensionOption(const string &name, string description, LogicalType parameter, const Value &default_value = Value(), set_option_callback_t function = nullptr, SetScope default_scope = SetScope::SESSION); + DUCKDB_API bool HasExtensionOption(const string &name); //! Fetch an option by index. Returns a pointer to the option, or nullptr if out of range DUCKDB_API static optional_ptr GetOptionByIndex(idx_t index); //!
Fetch an alias by index, or nullptr if out of range @@ -300,7 +303,7 @@ struct DBConfig { DUCKDB_API void SetOptionByName(const string &name, const Value &value); DUCKDB_API void SetOptionsByName(const case_insensitive_map_t &values); DUCKDB_API void ResetOption(optional_ptr db, const ConfigurationOption &option); - DUCKDB_API void SetOption(const string &name, Value value); + DUCKDB_API void SetOption(const String &name, Value value); DUCKDB_API void ResetOption(const String &name); DUCKDB_API void ResetGenericOption(const String &name); static LogicalType ParseLogicalType(const string &type); diff --git a/src/duckdb/src/include/duckdb/main/connection.hpp b/src/duckdb/src/include/duckdb/main/connection.hpp index c27d84d21..1c88757fc 100644 --- a/src/duckdb/src/include/duckdb/main/connection.hpp +++ b/src/duckdb/src/include/duckdb/main/connection.hpp @@ -50,7 +50,6 @@ class Connection { DUCKDB_API ~Connection(); shared_ptr context; - warning_callback_t warning_cb; public: //! Returns query profiling information for the current query @@ -80,13 +79,18 @@ class Connection { //! MaterializedQueryResult. The result can be stepped through with calls to Fetch(). Note that there can only be //! one active StreamQueryResult per Connection object. Calling SendQuery() will invalidate any previously existing //! StreamQueryResult. - DUCKDB_API unique_ptr SendQuery(const string &query); + DUCKDB_API unique_ptr + SendQuery(const string &query, QueryParameters query_parameters = QueryResultOutputType::ALLOW_STREAMING); + DUCKDB_API unique_ptr + SendQuery(unique_ptr statement, + QueryParameters query_parameters = QueryResultOutputType::ALLOW_STREAMING); //! Issues a query to the database and materializes the result (if necessary). Always returns a //! MaterializedQueryResult. DUCKDB_API unique_ptr Query(const string &query); //! Issues a query to the database and materializes the result (if necessary). Always returns a //! MaterializedQueryResult. - DUCKDB_API unique_ptr Query(unique_ptr statement); + DUCKDB_API unique_ptr + Query(unique_ptr statement, QueryResultMemoryType memory_type = QueryResultMemoryType::IN_MEMORY); // prepared statements template unique_ptr Query(const string &query, ARGS... args) { @@ -96,20 +100,25 @@ class Connection { //! Issues a query to the database and returns a Pending Query Result. Note that "query" may only contain //! a single statement. - DUCKDB_API unique_ptr PendingQuery(const string &query, bool allow_stream_result = false); + DUCKDB_API unique_ptr + PendingQuery(const string &query, QueryParameters query_parameters = QueryResultOutputType::FORCE_MATERIALIZED); //!
Issues a query to the database and returns a Pending Query Result - DUCKDB_API unique_ptr PendingQuery(unique_ptr statement, - bool allow_stream_result = false); - DUCKDB_API unique_ptr PendingQuery(unique_ptr statement, - case_insensitive_map_t &named_values, - bool allow_stream_result = false); - DUCKDB_API unique_ptr PendingQuery(const string &query, - case_insensitive_map_t &named_values, - bool allow_stream_result = false); - DUCKDB_API unique_ptr PendingQuery(const string &query, vector &values, - bool allow_stream_result = false); - DUCKDB_API unique_ptr PendingQuery(unique_ptr statement, vector &values, - bool allow_stream_result = false); + DUCKDB_API unique_ptr + PendingQuery(unique_ptr statement, + QueryParameters query_parameters = QueryResultOutputType::FORCE_MATERIALIZED); + DUCKDB_API unique_ptr + PendingQuery(unique_ptr statement, case_insensitive_map_t &named_values, + QueryParameters query_parameters = QueryResultOutputType::FORCE_MATERIALIZED); + DUCKDB_API unique_ptr + PendingQuery(const string &query, case_insensitive_map_t &named_values, + QueryParameters query_parameters = QueryResultOutputType::FORCE_MATERIALIZED); + DUCKDB_API unique_ptr + PendingQuery(const string &query, vector &values, + QueryParameters query_parameters = QueryResultOutputType::FORCE_MATERIALIZED); + DUCKDB_API unique_ptr PendingQuery(const string &query, PendingQueryParameters parameters); + DUCKDB_API unique_ptr + PendingQuery(unique_ptr statement, vector &values, + QueryParameters query_parameters = QueryResultOutputType::FORCE_MATERIALIZED); //! Prepare the specified query, returning a prepared statement object DUCKDB_API unique_ptr Prepare(const string &query); diff --git a/src/duckdb/src/include/duckdb/main/connection_manager.hpp b/src/duckdb/src/include/duckdb/main/connection_manager.hpp index 7fa5c66b5..1c647ce02 100644 --- a/src/duckdb/src/include/duckdb/main/connection_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/connection_manager.hpp @@ -40,7 +40,6 @@ class ConnectionManager { mutex connections_lock; reference_map_t> connections; atomic connection_count; - atomic current_connection_id; }; diff --git a/src/duckdb/src/include/duckdb/main/database.hpp b/src/duckdb/src/include/duckdb/main/database.hpp index 2486d1e0e..bf2a57cbe 100644 --- a/src/duckdb/src/include/duckdb/main/database.hpp +++ b/src/duckdb/src/include/duckdb/main/database.hpp @@ -17,6 +17,7 @@ #include "duckdb/main/extension_manager.hpp" namespace duckdb { + class BufferManager; class DatabaseManager; class StorageManager; @@ -33,6 +34,7 @@ class DatabaseFileSystem; struct DatabaseCacheEntry; class LogManager; class ExternalFileCache; +class ResultSetManager; class DatabaseInstance : public enable_shared_from_this { friend class DuckDB; @@ -51,6 +53,7 @@ class DatabaseInstance : public enable_shared_from_this { DUCKDB_API DatabaseManager &GetDatabaseManager(); DUCKDB_API FileSystem &GetFileSystem(); DUCKDB_API ExternalFileCache &GetExternalFileCache(); + DUCKDB_API ResultSetManager &GetResultSetManager(); DUCKDB_API TaskScheduler &GetScheduler(); DUCKDB_API ObjectCache &GetObjectCache(); DUCKDB_API ConnectionManager &GetConnectionManager(); @@ -90,8 +93,9 @@ class DatabaseInstance : public enable_shared_from_this { unique_ptr extension_manager; ValidChecker db_validity; unique_ptr db_file_system; - shared_ptr log_manager; + unique_ptr log_manager; unique_ptr external_file_cache; + unique_ptr result_set_manager; duckdb_ext_api_v1 (*create_api_v1)(); }; diff --git 
a/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp b/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp index 1912a90bf..3af2f1873 100644 --- a/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp @@ -12,33 +12,42 @@ #include "duckdb/common/mutex.hpp" #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/common/enums/on_create_conflict.hpp" +#include "duckdb/common/enums/access_mode.hpp" +#include "duckdb/common/reference_map.hpp" namespace duckdb { struct AttachInfo; struct AttachOptions; +class DatabaseManager; enum class InsertDatabasePathResult { SUCCESS, ALREADY_EXISTS }; struct DatabasePathInfo { - explicit DatabasePathInfo(string name_p) : name(std::move(name_p)) { - } + DatabasePathInfo(DatabaseManager &manager, string name_p, AccessMode access_mode); string name; + AccessMode access_mode; + reference_set_t attached_databases; + idx_t reference_count = 1; }; //! The DatabaseFilePathManager is used to ensure we only ever open a single database file once class DatabaseFilePathManager { public: idx_t ApproxDatabaseCount() const; - InsertDatabasePathResult InsertDatabasePath(const string &path, const string &name, OnCreateConflict on_conflict, - AttachOptions &options); + InsertDatabasePathResult InsertDatabasePath(DatabaseManager &manager, const string &path, const string &name, + OnCreateConflict on_conflict, AttachOptions &options); //! Erase a database path - indicating we are done with using it void EraseDatabasePath(const string &path); + //! Called when a database is detached, but before it is fully finished being used + void DetachDatabase(DatabaseManager &manager, const string &path); private: //! The lock to add entries to the db_paths map mutable mutex db_paths_lock; - //! A set containing all attached database paths mapped to their attached database name + //! A set containing all attached database paths + //! This allows attaching many databases efficiently, and avoids attaching the + //! same file path twice case_insensitive_map_t db_paths; }; diff --git a/src/duckdb/src/include/duckdb/main/error_manager.hpp b/src/duckdb/src/include/duckdb/main/error_manager.hpp index aaedffd4b..065f6399a 100644 --- a/src/duckdb/src/include/duckdb/main/error_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/error_manager.hpp @@ -34,38 +34,39 @@ enum class ErrorType : uint16_t { class ErrorManager { public: template - string FormatException(ErrorType error_type, ARGS... params) { + string FormatException(ErrorType error_type, ARGS &&...params) { vector values; - return FormatExceptionRecursive(error_type, values, params...); + return FormatExceptionRecursive(error_type, values, std::forward(params)...); } DUCKDB_API string FormatExceptionRecursive(ErrorType error_type, vector &values); template string FormatExceptionRecursive(ErrorType error_type, vector &values, T param, - ARGS... params) { + ARGS &&...params) { values.push_back(ExceptionFormatValue::CreateFormatValue(param)); - return FormatExceptionRecursive(error_type, values, params...); + return FormatExceptionRecursive(error_type, values, std::forward(params)...); } template - static string FormatException(ClientContext &context, ErrorType error_type, ARGS...
params) { - return Get(context).FormatException(error_type, params...); + static string FormatException(ClientContext &context, ErrorType error_type, ARGS &&...params) { + return Get(context).FormatException(error_type, std::forward(params)...); } DUCKDB_API static InvalidInputException InvalidUnicodeError(const String &input, const string &context); DUCKDB_API static FatalException InvalidatedDatabase(ClientContext &context, const string &invalidated_msg); + DUCKDB_API static TransactionException InvalidatedTransaction(ClientContext &context); //! Adds a custom error for a specific error type void AddCustomError(ErrorType type, string new_error); DUCKDB_API static ErrorManager &Get(ClientContext &context); + DUCKDB_API static ErrorManager &Get(DatabaseInstance &context); private: map custom_errors; }; - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/extension_entries.hpp b/src/duckdb/src/include/duckdb/main/extension_entries.hpp index a32331c9b..e08f291a0 100644 --- a/src/duckdb/src/include/duckdb/main/extension_entries.hpp +++ b/src/duckdb/src/include/duckdb/main/extension_entries.hpp @@ -42,7 +42,6 @@ struct ExtensionFunctionOverloadEntry { static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"!__postfix", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"&", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, - {"&&", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"**", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"->>", "json", CatalogType::SCALAR_FUNCTION_ENTRY}, {"<->", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, @@ -69,8 +68,10 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"approx_top_k", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"arg_max", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"arg_max_null", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, + {"arg_max_nulls_last", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"arg_min", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"arg_min_null", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, + {"arg_min_nulls_last", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"argmax", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"argmin", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"array_agg", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, @@ -475,6 +476,7 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"ord", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"parquet_bloom_probe", "parquet", CatalogType::TABLE_FUNCTION_ENTRY}, {"parquet_file_metadata", "parquet", CatalogType::TABLE_FUNCTION_ENTRY}, + {"parquet_full_metadata", "parquet", CatalogType::TABLE_FUNCTION_ENTRY}, {"parquet_kv_metadata", "parquet", CatalogType::TABLE_FUNCTION_ENTRY}, {"parquet_metadata", "parquet", CatalogType::TABLE_FUNCTION_ENTRY}, {"parquet_scan", "parquet", CatalogType::TABLE_FUNCTION_ENTRY}, @@ -599,6 +601,7 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"st_envelope", "spatial", CatalogType::SCALAR_FUNCTION_ENTRY}, {"st_envelope_agg", "spatial", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"st_equals", "spatial", CatalogType::SCALAR_FUNCTION_ENTRY}, + {"st_expand", "spatial", CatalogType::SCALAR_FUNCTION_ENTRY}, {"st_extent", "spatial", CatalogType::SCALAR_FUNCTION_ENTRY}, {"st_extent_agg", "spatial", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"st_extent_approx", "spatial", 
CatalogType::SCALAR_FUNCTION_ENTRY}, @@ -779,6 +782,7 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"var_pop", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"var_samp", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"variance", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, + {"variant_to_parquet_variant", "parquet", CatalogType::SCALAR_FUNCTION_ENTRY}, {"vector_type", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"version", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"vss_join", "vss", CatalogType::TABLE_MACRO_ENTRY}, @@ -1080,7 +1084,6 @@ static constexpr ExtensionEntry EXTENSION_SETTINGS[] = { {"ui_remote_url", "ui"}, {"unsafe_disable_etag_checks", "httpfs"}, {"unsafe_enable_version_guessing", "iceberg"}, - {"variant_legacy_encoding", "parquet"}, }; // END_OF_EXTENSION_SETTINGS static constexpr ExtensionEntry EXTENSION_SECRET_TYPES[] = { diff --git a/src/duckdb/src/include/duckdb/main/materialized_query_result.hpp b/src/duckdb/src/include/duckdb/main/materialized_query_result.hpp index 25c50d980..60fbb6e24 100644 --- a/src/duckdb/src/include/duckdb/main/materialized_query_result.hpp +++ b/src/duckdb/src/include/duckdb/main/materialized_query_result.hpp @@ -30,10 +30,6 @@ class MaterializedQueryResult : public QueryResult { DUCKDB_API explicit MaterializedQueryResult(ErrorData error); public: - //! Fetches a DataChunk from the query result. - //! This will consume the result (i.e. the result can only be scanned once with this function) - DUCKDB_API unique_ptr Fetch() override; - DUCKDB_API unique_ptr FetchRaw() override; //! Converts the QueryResult to a string DUCKDB_API string ToString() override; DUCKDB_API string ToBox(ClientContext &context, const BoxRendererConfig &config) override; @@ -48,6 +44,7 @@ class MaterializedQueryResult : public QueryResult { return (T)value.GetValue(); } + DUCKDB_API bool MoreRowsThan(idx_t row_count) override; DUCKDB_API idx_t RowCount() const; //! Returns a reference to the underlying column data collection @@ -56,6 +53,9 @@ class MaterializedQueryResult : public QueryResult { //! Takes ownership of the collection, 'collection' is null after this operation unique_ptr TakeCollection(); +protected: + DUCKDB_API unique_ptr FetchInternal() override; + private: unique_ptr collection; //! Row collection, only created if GetValue is called diff --git a/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp b/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp index b40a9addb..84f50d654 100644 --- a/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp +++ b/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp @@ -45,7 +45,9 @@ class PreparedStatementData { //! The map of parameter index to the actual value entry bound_parameter_map_t value_map; //! Whether we are creating a streaming result or not - bool is_streaming = false; + QueryResultOutputType output_type; + //! Whether we are creating a buffer-managed result or not + QueryResultMemoryType memory_type; public: void CheckParameterCount(idx_t parameter_count); diff --git a/src/duckdb/src/include/duckdb/main/profiling_info.hpp b/src/duckdb/src/include/duckdb/main/profiling_info.hpp index 904f0205d..554c6cafe 100644 --- a/src/duckdb/src/include/duckdb/main/profiling_info.hpp +++ b/src/duckdb/src/include/duckdb/main/profiling_info.hpp @@ -32,9 +32,6 @@ class ProfilingInfo { profiler_settings_t expanded_settings; //! Contains all enabled metrics. profiler_metrics_t metrics; - //! 
Additional metrics. - // FIXME: move to metrics. - InsertionOrderPreservingMap extra_info; public: ProfilingInfo() = default; @@ -44,8 +41,8 @@ class ProfilingInfo { public: static profiler_settings_t DefaultSettings(); - static profiler_settings_t DefaultRootSettings(); - static profiler_settings_t DefaultOperatorSettings(); + static profiler_settings_t RootScopeSettings(); + static profiler_settings_t OperatorScopeSettings(); public: void ResetMetrics(); @@ -56,6 +53,7 @@ class ProfilingInfo { public: string GetMetricAsString(const MetricsType metric) const; + void WriteMetricsToLog(ClientContext &context); void WriteMetricsToJSON(duckdb_yyjson::yyjson_mut_doc *doc, duckdb_yyjson::yyjson_mut_val *destination); public: @@ -102,6 +100,7 @@ class ProfilingInfo { return MaxValue(old_value, new_value); }); } + template void MetricMax(const MetricsType type, const METRIC_TYPE &value) { auto new_value = Value::CreateValue(value); @@ -109,4 +108,19 @@ class ProfilingInfo { } }; +// Specialization for InsertionOrderPreservingMap +template <> +inline InsertionOrderPreservingMap +ProfilingInfo::GetMetricValue>(const MetricsType type) const { + auto val = metrics.at(type); + InsertionOrderPreservingMap result; + auto children = MapValue::GetChildren(val); + for (auto &child : children) { + auto struct_children = StructValue::GetChildren(child); + auto key = struct_children[0].GetValue(); + auto value = struct_children[1].GetValue(); + result.insert(key, value); + } + return result; +} } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/query_parameters.hpp b/src/duckdb/src/include/duckdb/main/query_parameters.hpp new file mode 100644 index 000000000..d9bb42a3b --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/query_parameters.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/query_parameters.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { + +enum class QueryResultOutputType : uint8_t { FORCE_MATERIALIZED, ALLOW_STREAMING }; + +enum class QueryResultMemoryType : uint8_t { IN_MEMORY, BUFFER_MANAGED }; + +struct QueryParameters { + QueryParameters() { + } + QueryParameters(bool allow_streaming) // NOLINT: allow implicit conversion + : output_type(allow_streaming ? 
QueryResultOutputType::ALLOW_STREAMING + : QueryResultOutputType::FORCE_MATERIALIZED) { + } + QueryParameters(QueryResultOutputType output_type) // NOLINT: allow implicit conversion + : output_type(output_type) { + } + QueryResultOutputType output_type = QueryResultOutputType::FORCE_MATERIALIZED; + QueryResultMemoryType memory_type = QueryResultMemoryType::IN_MEMORY; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/query_profiler.hpp b/src/duckdb/src/include/duckdb/main/query_profiler.hpp index 0f7b8812d..b44ae331c 100644 --- a/src/duckdb/src/include/duckdb/main/query_profiler.hpp +++ b/src/duckdb/src/include/duckdb/main/query_profiler.hpp @@ -94,7 +94,6 @@ class OperatorProfiler { DUCKDB_API void Flush(const PhysicalOperator &phys_op); DUCKDB_API OperatorInformation &GetOperatorInfo(const PhysicalOperator &phys_op); DUCKDB_API bool OperatorInfoIsInitialized(const PhysicalOperator &phys_op); - DUCKDB_API void AddExtraInfo(InsertionOrderPreservingMap extra_info); public: ClientContext &context; @@ -117,15 +116,41 @@ class OperatorProfiler { struct QueryMetrics { QueryMetrics() : total_bytes_read(0), total_bytes_written(0) {}; + //! Reset the query metrics. + void Reset() { + query = ""; + latency.Reset(); + waiting_to_attach_latency.Reset(); + attach_load_storage_latency.Reset(); + attach_replay_wal_latency.Reset(); + checkpoint_latency.Reset(); + commit_write_wal_latency.Reset(); + wal_replay_entry_count = 0; + total_bytes_read = 0; + total_bytes_written = 0; + } + ProfilingInfo query_global_info; - //! The SQL string of the query + //! The SQL string of the query. string query; - //! The timer used to time the excution time of the entire query + //! The timer of the execution of the entire query. Profiler latency; - //! The total bytes read by the file system + //! The timer of the delay when waiting to ATTACH a file. + Profiler waiting_to_attach_latency; + //! The timer for loading from storage. + Profiler attach_load_storage_latency; + //! The timer for replaying the WAL file. + Profiler attach_replay_wal_latency; + //! The timer for running checkpoints. + Profiler checkpoint_latency; + //! The timer for the WAL writes during COMMIT. + Profiler commit_write_wal_latency; + //! The total number of entries to replay in the WAL. + atomic wal_replay_entry_count; + //! The total bytes read by the file system. atomic total_bytes_read; - //! The total bytes written by the file system + //! The total bytes written by the file system. atomic total_bytes_written; }; @@ -138,9 +163,6 @@ class QueryProfiler { DUCKDB_API explicit QueryProfiler(ClientContext &context); public: - //! Propagate save_location, enabled, detailed_enabled and automatic_print_format. - void Propagate(QueryProfiler &qp); - DUCKDB_API bool IsEnabled() const; DUCKDB_API bool IsDetailedEnabled() const; DUCKDB_API ProfilerPrintFormat GetPrintFormat(ExplainFormat format = ExplainFormat::DEFAULT) const; @@ -154,17 +176,19 @@ class QueryProfiler { DUCKDB_API void StartQuery(const string &query, bool is_explain_analyze = false, bool start_at_optimizer = false); DUCKDB_API void EndQuery(); - //! Adds nr_bytes bytes to the total bytes read. - DUCKDB_API void AddBytesRead(const idx_t nr_bytes); - //! Adds nr_bytes bytes to the total bytes written. - DUCKDB_API void AddBytesWritten(const idx_t nr_bytes); + //! Adds amount to a specific metric type. + DUCKDB_API void AddToCounter(MetricsType type, const idx_t amount); + + //! Start/End a timer for a specific metric type. 
+ DUCKDB_API void StartTimer(MetricsType type); + DUCKDB_API void EndTimer(MetricsType type); DUCKDB_API void StartExplainAnalyze(); //! Adds the timings gathered by an OperatorProfiler to this query profiler DUCKDB_API void Flush(OperatorProfiler &profiler); //! Adds the top level query information to the global profiler. - DUCKDB_API void SetInfo(const double &blocked_thread_time); + DUCKDB_API void SetBlockedTime(const double &blocked_thread_time); DUCKDB_API void StartPhase(MetricsType phase_metric); DUCKDB_API void EndPhase(); @@ -180,11 +204,15 @@ class QueryProfiler { DUCKDB_API string ToString(ExplainFormat format = ExplainFormat::DEFAULT) const; DUCKDB_API string ToString(ProfilerPrintFormat format) const; - static InsertionOrderPreservingMap JSONSanitize(const InsertionOrderPreservingMap &input); + // Sanitize a Value::MAP + static Value JSONSanitize(const Value &input); static string JSONSanitize(const string &text); static string DrawPadded(const string &str, idx_t width); + DUCKDB_API void ToLog() const; DUCKDB_API string ToJSON() const; DUCKDB_API void WriteToFile(const char *path, string &info) const; + DUCKDB_API idx_t GetBytesRead() const; + DUCKDB_API idx_t GetBytesWritten() const; idx_t OperatorSize() { return tree_map.size(); diff --git a/src/duckdb/src/include/duckdb/main/query_result.hpp b/src/duckdb/src/include/duckdb/main/query_result.hpp index f85629428..5468e0dc7 100644 --- a/src/duckdb/src/include/duckdb/main/query_result.hpp +++ b/src/duckdb/src/include/duckdb/main/query_result.hpp @@ -44,7 +44,7 @@ class BaseQueryResult { DUCKDB_API void SetError(ErrorData error); DUCKDB_API bool HasError() const; DUCKDB_API const ExceptionType &GetErrorType() const; - DUCKDB_API const std::string &GetError(); + DUCKDB_API const std::string &GetError() const; DUCKDB_API ErrorData &GetErrorObject(); DUCKDB_API idx_t ColumnCount(); @@ -98,10 +98,10 @@ class QueryResult : public BaseQueryResult { DUCKDB_API const string &ColumnName(idx_t index) const; //! Fetches a DataChunk of normalized (flat) vectors from the query result. //! Returns nullptr if there are no more results to fetch. - DUCKDB_API virtual unique_ptr Fetch(); + DUCKDB_API unique_ptr Fetch(); //! Fetches a DataChunk from the query result. The vectors are not normalized and hence any vector types can be //! returned. - DUCKDB_API virtual unique_ptr FetchRaw() = 0; + DUCKDB_API unique_ptr FetchRaw(); //! Converts the QueryResult to a string DUCKDB_API virtual string ToString() = 0; //! Converts the QueryResult to a box-rendered string @@ -111,6 +111,9 @@ class QueryResult : public BaseQueryResult { //! Returns true if the two results are identical; false otherwise. Note that this method is destructive; it calls //! Fetch() until both results are exhausted. The data in the results will be lost. DUCKDB_API bool Equals(QueryResult &other); + //! Returns true if the query result has more rows than the given amount. + //! 
This might involve fetching up to that many rows - but will not exhaust any + DUCKDB_API virtual bool MoreRowsThan(idx_t row_count); bool TryFetch(unique_ptr &result, ErrorData &error) { try { @@ -125,6 +128,9 @@ } } +protected: + DUCKDB_API virtual unique_ptr FetchInternal() = 0; + private: class QueryResultIterator; class QueryResultRow { @@ -200,6 +206,10 @@ class QueryResult : public BaseQueryResult { return QueryResultIterator(nullptr); } +protected: + vector> stored_chunks; + bool result_exhausted = false; + protected: DUCKDB_API string HeaderToString(); diff --git a/src/duckdb/src/include/duckdb/main/relation.hpp b/src/duckdb/src/include/duckdb/main/relation.hpp index 9d9e67686..bc383ffe0 100644 --- a/src/duckdb/src/include/duckdb/main/relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation.hpp @@ -78,7 +78,8 @@ class Relation : public enable_shared_from_this { public: DUCKDB_API virtual const vector &Columns() = 0; - DUCKDB_API virtual unique_ptr GetQueryNode(); + DUCKDB_API virtual unique_ptr GetQueryNode() = 0; + DUCKDB_API virtual string GetQuery(); DUCKDB_API virtual BoundStatement Bind(Binder &binder); DUCKDB_API virtual string GetAlias(); diff --git a/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp index 7d5462941..8df59b8d2 100644 --- a/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp @@ -26,6 +26,8 @@ class CreateTableRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp index cb826a86c..aa09b0def 100644 --- a/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp @@ -26,6 +26,8 @@ class CreateViewRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp index c07445ba4..0c25c6576 100644 --- a/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp @@ -26,6 +26,8 @@ class DeleteRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/explain_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/explain_relation.hpp index 888583b2b..96be08d8f 100644 --- a/src/duckdb/src/include/duckdb/main/relation/explain_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/explain_relation.hpp @@ -24,6 +24,8 @@ class ExplainRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; +
string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp index 3695cde7b..fccb0ae92 100644 --- a/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp @@ -23,6 +23,8 @@ class InsertRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp index bdb035652..b1be001b9 100644 --- a/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp @@ -28,6 +28,7 @@ class QueryRelation : public Relation { public: static unique_ptr ParseStatement(ClientContext &context, const string &query, const string &error); unique_ptr GetQueryNode() override; + string GetQuery() override; unique_ptr GetTableRef() override; BoundStatement Bind(Binder &binder) override; diff --git a/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp index 58ad203b2..91eac246e 100644 --- a/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp @@ -29,6 +29,8 @@ class UpdateRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/write_csv_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/write_csv_relation.hpp index 99d2ebe8e..cf0853ff3 100644 --- a/src/duckdb/src/include/duckdb/main/relation/write_csv_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/write_csv_relation.hpp @@ -23,6 +23,8 @@ class WriteCSVRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/write_parquet_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/write_parquet_relation.hpp index d32089212..138eee7c7 100644 --- a/src/duckdb/src/include/duckdb/main/relation/write_parquet_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/write_parquet_relation.hpp @@ -24,6 +24,8 @@ class WriteParquetRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/result_set_manager.hpp b/src/duckdb/src/include/duckdb/main/result_set_manager.hpp new file mode 100644 index 000000000..0be2a4b88 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/result_set_manager.hpp @@ -0,0 +1,55 @@ 
+//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/result_set_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/shared_ptr.hpp" +#include "duckdb/common/reference_map.hpp" +#include "duckdb/common/optional_ptr.hpp" + +namespace duckdb { + +class DatabaseInstance; +class ClientContext; +class BlockHandle; +class ColumnDataAllocator; + +class ManagedResultSet : public enable_shared_from_this { +public: + ManagedResultSet(); + ManagedResultSet(const weak_ptr &db, vector> &handles); + +public: + bool IsValid() const; + shared_ptr GetDatabase() const; + vector> &GetHandles(); + +private: + bool valid; + weak_ptr db; + optional_ptr>> handles; +}; + +class ResultSetManager { +public: + explicit ResultSetManager(DatabaseInstance &db); + +public: + static ResultSetManager &Get(ClientContext &context); + static ResultSetManager &Get(DatabaseInstance &db); + ManagedResultSet Add(ColumnDataAllocator &allocator); + void Remove(ColumnDataAllocator &allocator); + +private: + mutex lock; + weak_ptr db; + reference_map_t>>> open_results; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/secret/secret.hpp b/src/duckdb/src/include/duckdb/main/secret/secret.hpp index ed8034413..fd8a1b241 100644 --- a/src/duckdb/src/include/duckdb/main/secret/secret.hpp +++ b/src/duckdb/src/include/duckdb/main/secret/secret.hpp @@ -296,7 +296,9 @@ class KeyValueSecretReader { Value result; auto lookup_result = TryGetSecretKeyOrSetting(secret_key, setting_name, result); if (lookup_result) { - value_out = result.GetValue(); + if (!result.IsNull()) { + value_out = result.GetValue(); + } } return lookup_result; } diff --git a/src/duckdb/src/include/duckdb/main/settings.hpp b/src/duckdb/src/include/duckdb/main/settings.hpp index 383d5533b..217a4fb85 100644 --- a/src/duckdb/src/include/duckdb/main/settings.hpp +++ b/src/duckdb/src/include/duckdb/main/settings.hpp @@ -95,6 +95,18 @@ struct AllowExtensionsMetadataMismatchSetting { static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; +struct AllowParserOverrideExtensionSetting { + using RETURN_TYPE = string; + static constexpr const char *Name = "allow_parser_override_extension"; + static constexpr const char *Description = "Allow extensions to override the current parser"; + static constexpr const char *InputType = "VARCHAR"; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static bool OnGlobalSet(DatabaseInstance *db, DBConfig &config, const Value &input); + static bool OnGlobalReset(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(const ClientContext &context); +}; + struct AllowPersistentSecretsSetting { using RETURN_TYPE = bool; static constexpr const char *Name = "allow_persistent_secrets"; @@ -329,6 +341,17 @@ struct DebugForceNoCrossProductSetting { static constexpr SetScope DefaultScope = SetScope::SESSION; }; +struct DebugPhysicalTableScanExecutionStrategySetting { + using RETURN_TYPE = PhysicalTableScanExecutionStrategy; + static constexpr const char *Name = "debug_physical_table_scan_execution_strategy"; + static constexpr const char *Description = + "DEBUG SETTING: force use of given strategy for executing physical table scans"; + static constexpr const char *InputType = "VARCHAR"; + static constexpr const char *DefaultValue = 
"DEFAULT"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; + static void OnSet(SettingCallbackInfo &info, Value &input); +}; + struct DebugSkipCheckpointOnCommitSetting { using RETURN_TYPE = bool; static constexpr const char *Name = "debug_skip_checkpoint_on_commit"; @@ -338,6 +361,15 @@ struct DebugSkipCheckpointOnCommitSetting { static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; +struct DebugVerifyBlocksSetting { + using RETURN_TYPE = bool; + static constexpr const char *Name = "debug_verify_blocks"; + static constexpr const char *Description = "DEBUG SETTING: verify block metadata during checkpointing"; + static constexpr const char *InputType = "BOOLEAN"; + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; +}; + struct DebugVerifyVectorSetting { using RETURN_TYPE = DebugVectorVerification; static constexpr const char *Name = "debug_verify_vector"; @@ -645,7 +677,7 @@ struct ExperimentalMetadataReuseSetting { static constexpr const char *Name = "experimental_metadata_reuse"; static constexpr const char *Description = "EXPERIMENTAL: Re-use row group and table metadata when checkpointing."; static constexpr const char *InputType = "BOOLEAN"; - static constexpr const char *DefaultValue = "false"; + static constexpr const char *DefaultValue = "true"; static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; diff --git a/src/duckdb/src/include/duckdb/main/stream_query_result.hpp b/src/duckdb/src/include/duckdb/main/stream_query_result.hpp index 3c04a364c..775202ea7 100644 --- a/src/duckdb/src/include/duckdb/main/stream_query_result.hpp +++ b/src/duckdb/src/include/duckdb/main/stream_query_result.hpp @@ -44,8 +44,6 @@ class StreamQueryResult : public QueryResult { DUCKDB_API void WaitForTask(); //! Executes a single task within the final pipeline, returning whether or not a chunk is ready to be fetched DUCKDB_API StreamExecutionResult ExecuteTask(); - //! Fetches a DataChunk from the query result. - DUCKDB_API unique_ptr FetchRaw() override; //! Converts the QueryResult to a string DUCKDB_API string ToString() override; //! Materializes the query result and turns it into a materialized query result @@ -59,9 +57,12 @@ class StreamQueryResult : public QueryResult { //! The client context this StreamQueryResult belongs to shared_ptr context; +protected: + DUCKDB_API unique_ptr FetchInternal() override; + private: StreamExecutionResult ExecuteTaskInternal(ClientContextLock &lock); - unique_ptr FetchInternal(ClientContextLock &lock); + unique_ptr FetchNextInternal(ClientContextLock &lock); unique_ptr LockContext(); void CheckExecutableInternal(ClientContextLock &lock); bool IsOpenInternal(ClientContextLock &lock); diff --git a/src/duckdb/src/include/duckdb/optimizer/common_subplan_optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/common_subplan_optimizer.hpp new file mode 100644 index 000000000..8d8e35ea1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/common_subplan_optimizer.hpp @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/common_subplan_optimizer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +class Optimizer; +class LogicalOperator; + +//! 
The CommonSubplanOptimizer optimizer detects common subplans, and converts them to refs of a materialized CTE +class CommonSubplanOptimizer { +public: + explicit CommonSubplanOptimizer(Optimizer &optimizer); + +public: + unique_ptr Optimize(unique_ptr op); + +private: + //! The optimizer + Optimizer &optimizer; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/cte_inlining.hpp b/src/duckdb/src/include/duckdb/optimizer/cte_inlining.hpp index 90439b11e..97529f6ee 100644 --- a/src/duckdb/src/include/duckdb/optimizer/cte_inlining.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/cte_inlining.hpp @@ -25,6 +25,7 @@ class CTEInlining { public: explicit CTEInlining(Optimizer &optimizer); unique_ptr Optimize(unique_ptr op); + static bool EndsInAggregateOrDistinct(const LogicalOperator &op); private: void TryInlining(unique_ptr &op); diff --git a/src/duckdb/src/include/duckdb/optimizer/filter_pullup.hpp b/src/duckdb/src/include/duckdb/optimizer/filter_pullup.hpp index a35fbaab9..b6cb1e704 100644 --- a/src/duckdb/src/include/duckdb/optimizer/filter_pullup.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/filter_pullup.hpp @@ -30,7 +30,7 @@ class FilterPullup { // only pull up filters when there is a fork bool can_pullup = false; - // identifiy case the branch is a set operation (INTERSECT or EXCEPT) + // identify case the branch is a set operation (INTERSECT or EXCEPT) bool can_add_column = false; private: @@ -40,30 +40,26 @@ class FilterPullup { //! Pull up a LogicalFilter op unique_ptr PullupFilter(unique_ptr op); - //! Pull up filter in a LogicalProjection op unique_ptr PullupProjection(unique_ptr op); - //! Pull up filter in a LogicalCrossProduct op unique_ptr PullupCrossProduct(unique_ptr op); - + //! Pullup a filter in a LogicalJoin unique_ptr PullupJoin(unique_ptr op); - - // PPullup filter in a left join + //! Pullup filter in a left join unique_ptr PullupFromLeft(unique_ptr op); - - // Pullup filter in a inner join + //! Pullup filter in an inner join unique_ptr PullupInnerJoin(unique_ptr op); - - // Pullup filter in LogicalIntersect or LogicalExcept op + //! Pullup filter through a distinct + unique_ptr PullupDistinct(unique_ptr op); + //! Pullup filter in LogicalIntersect or LogicalExcept op unique_ptr PullupSetOperation(unique_ptr op); - + //! Pullup filter in both sides of a join unique_ptr PullupBothSide(unique_ptr op); - // Finish pull up at this operator + //! Finish pull up at this operator unique_ptr FinishPullup(unique_ptr op); - - // special treatment for SetOperations and projections + //! special treatment for SetOperations and projections void ProjectSetOperation(LogicalProjection &proj); }; // end FilterPullup diff --git a/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp b/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp index c2fc87a52..29c2f0ac4 100644 --- a/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp @@ -96,14 +96,17 @@ class FilterPushdown { unique_ptr FinishPushdown(unique_ptr op); //! Adds a filter to the set of filters. Returns FilterResult::UNSATISFIABLE if the subtree should be stripped, or //! FilterResult::SUCCESS otherwise + + unique_ptr PushFiltersIntoDelimJoin(unique_ptr op); FilterResult AddFilter(unique_ptr expr); //! Extract filter bindings to compare them with expressions in an operator and determine if the filter //! can be pushed down void ExtractFilterBindings(const Expression &expr, vector &bindings); //! 
Generate filters from the current set of filters stored in the FilterCombiner void GenerateFilters(); - //! if there are filters in this FilterPushdown node, push them into the combiner - void PushFilters(); + //! if there are filters in this FilterPushdown node, push them into the combiner. Returns + //! FilterResult::UNSATISFIABLE if the subtree should be stripped, or FilterResult::SUCCESS otherwise + FilterResult PushFilters(); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/relation_manager.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/relation_manager.hpp index 3b8fda1c6..2a687ad1b 100644 --- a/src/duckdb/src/include/duckdb/optimizer/join_order/relation_manager.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/relation_manager.hpp @@ -56,7 +56,11 @@ class RelationManager { //! Extract the set of relations referred to inside an expression bool ExtractBindings(Expression &expression, unordered_set &bindings); void AddRelation(LogicalOperator &op, optional_ptr parent, const RelationStats &stats); - + //! Add an unnest relation which can come from a logical unnest or a logical get which has an unnest function + void AddUnnestRelation(JoinOrderOptimizer &optimizer, LogicalOperator &op, LogicalOperator &input_op, + optional_ptr parent, RelationStats &child_stats, + optional_ptr limit_op, + vector> &datasource_filters); void AddAggregateOrWindowRelation(LogicalOperator &op, optional_ptr parent, const RelationStats &stats, LogicalOperatorType op_type); vector> GetRelations(); diff --git a/src/duckdb/src/include/duckdb/optimizer/late_materialization_helper.hpp b/src/duckdb/src/include/duckdb/optimizer/late_materialization_helper.hpp new file mode 100644 index 000000000..ca0589fb0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/late_materialization_helper.hpp @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/late_materialization_helper.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_get.hpp" + +namespace duckdb { + +struct LateMaterializationHelper { + static unique_ptr CreateLHSGet(const LogicalGet &rhs, Binder &binder); + static vector GetOrInsertRowIds(LogicalGet &get, const vector &row_id_column_ids, + const vector &row_id_columns); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/remove_unused_columns.hpp b/src/duckdb/src/include/duckdb/optimizer/remove_unused_columns.hpp index c2e5b1fc4..1d01129ec 100644 --- a/src/duckdb/src/include/duckdb/optimizer/remove_unused_columns.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/remove_unused_columns.hpp @@ -68,5 +68,6 @@ class RemoveUnusedColumns : public BaseColumnPruner { private: template void ClearUnusedExpressions(vector &list, idx_t table_idx, bool replace = true); + void RemoveColumnsFromLogicalGet(LogicalGet &get); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/constant_order_normalization.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/constant_order_normalization.hpp new file mode 100644 index 000000000..775ba01fb --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/constant_order_normalization.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// 
duckdb/optimizer/rule/constant_order_normalization.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +// Move constant expression parameters to the left in expression(i.e. x + 2 + y + 2 => 2 + 2 + x + y) +// for convenience of other rules(i.e. ConstantFoldingRule). +class ConstantOrderNormalizationRule : public Rule { +public: + explicit ConstantOrderNormalizationRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/list.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/list.hpp index 4b0099cd8..312fa7d93 100644 --- a/src/duckdb/src/include/duckdb/optimizer/rule/list.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/rule/list.hpp @@ -1,5 +1,6 @@ #include "duckdb/optimizer/rule/arithmetic_simplification.hpp" #include "duckdb/optimizer/rule/case_simplification.hpp" +#include "duckdb/optimizer/rule/constant_order_normalization.hpp" #include "duckdb/optimizer/rule/comparison_simplification.hpp" #include "duckdb/optimizer/rule/conjunction_simplification.hpp" #include "duckdb/optimizer/rule/constant_folding.hpp" diff --git a/src/duckdb/src/include/duckdb/optimizer/topn_window_elimination.hpp b/src/duckdb/src/include/duckdb/optimizer/topn_window_elimination.hpp new file mode 100644 index 000000000..0fb5ba00c --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/topn_window_elimination.hpp @@ -0,0 +1,83 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/topn_window_elimination.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/client_context.hpp" +#include "duckdb/optimizer/column_binding_replacer.hpp" +#include "duckdb/optimizer/remove_unused_columns.hpp" + +namespace duckdb { + +enum class TopNPayloadType { SINGLE_COLUMN, STRUCT_PACK }; + +struct TopNWindowEliminationParameters { + //! Whether the sort is ASCENDING or DESCENDING + OrderType order_type; + //! The number of values in the LIMIT clause + int64_t limit; + //! How we fetch the payload columns + TopNPayloadType payload_type; + //! Whether to include row numbers + bool include_row_number; + //! 
Whether the val or arg column contains null values + bool can_be_null = false; +}; + +class TopNWindowElimination : public BaseColumnPruner { +public: + explicit TopNWindowElimination(ClientContext &context, Optimizer &optimizer, + optional_ptr>> stats_p); + + unique_ptr Optimize(unique_ptr op); + +private: + bool CanOptimize(LogicalOperator &op); + unique_ptr OptimizeInternal(unique_ptr op, ColumnBindingReplacer &replacer); + + unique_ptr CreateAggregateOperator(LogicalWindow &window, vector> args, + const TopNWindowEliminationParameters ¶ms) const; + unique_ptr TryCreateUnnestOperator(unique_ptr op, + const TopNWindowEliminationParameters ¶ms) const; + unique_ptr CreateProjectionOperator(unique_ptr op, + const TopNWindowEliminationParameters ¶ms, + const map &group_idxs) const; + + vector> GenerateAggregatePayload(const vector &bindings, + const LogicalWindow &window, map &group_idxs); + vector TraverseProjectionBindings(const std::vector &old_bindings, + reference &op); + unique_ptr CreateAggregateExpression(vector> aggregate_params, bool requires_arg, + const TopNWindowEliminationParameters ¶ms) const; + unique_ptr CreateRowNumberGenerator(unique_ptr aggregate_column_ref) const; + void AddStructExtractExprs(vector> &exprs, const LogicalType &struct_type, + const unique_ptr &aggregate_column_ref) const; + static void UpdateTopmostBindings(idx_t window_idx, const unique_ptr &op, + const map &group_idxs, + const vector &topmost_bindings, + vector &new_bindings, ColumnBindingReplacer &replacer); + TopNWindowEliminationParameters ExtractOptimizerParameters(const LogicalWindow &window, const LogicalFilter &filter, + const vector &bindings, + vector> &aggregate_payload); + + // Semi-join reduction methods + unique_ptr TryPrepareLateMaterialization(const LogicalWindow &window, + vector> &args); + unique_ptr ConstructLHS(LogicalGet &rhs, vector &projections) const; + static unique_ptr ConstructJoin(unique_ptr lhs, unique_ptr rhs, + idx_t rhs_rowid_idx, + const TopNWindowEliminationParameters ¶ms); + bool CanUseLateMaterialization(const LogicalWindow &window, vector> &args, + vector &projections, vector> &stack); + +private: + ClientContext &context; + Optimizer &optimizer; + optional_ptr>> stats; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/async_result.hpp b/src/duckdb/src/include/duckdb/parallel/async_result.hpp new file mode 100644 index 000000000..97ede1cbc --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/async_result.hpp @@ -0,0 +1,72 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/async_result.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/unique_ptr.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/enums/operator_result_type.hpp" + +namespace duckdb { + +class InterruptState; +class TaskExecutor; +class Executor; + +enum class AsyncResultsExecutionMode : uint8_t { + SYNCHRONOUS, // BLOCKED should not bubble up, and they should be executed synchronously + TASK_EXECUTOR // BLOCKED is allowed +}; + +class AsyncTask { +public: + virtual ~AsyncTask() {}; + virtual void Execute() = 0; +}; + +class AsyncResult { + explicit AsyncResult(AsyncResultType t); + +public: + AsyncResult() = default; + AsyncResult(AsyncResult &&) = default; + AsyncResult(SourceResultType t); // NOLINT + explicit AsyncResult(vector> &&task); + AsyncResult 
&operator=(SourceResultType t); + AsyncResult &operator=(AsyncResultType t); + AsyncResult &operator=(AsyncResult &&) noexcept; + // Schedule held async_tasks into the Executor, eventually unblocking InterruptState + // needs to be called with non-empty async_tasks and from BLOCKED state, will empty the async_tasks and transform + // into INVALID + void ScheduleTasks(InterruptState &interrupt_state, Executor &executor); + // Execute tasks synchronously at callsite + // needs to be called with non-empty async_tasks and from BLOCKED state, will empty the async_tasks and transform + // into HAVE_MORE_OUTPUT + void ExecuteTasksSynchronously(); + + static AsyncResultType GetAsyncResultType(SourceResultType s); + + // Check whether there are tasks associated + bool HasTasks() const; + AsyncResultType GetResultType() const; + // Extract associated tasks, moving them away, will empty async_tasks and transform to INVALID + vector> &&ExtractAsyncTasks(); + +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE + static vector> GenerateTestTasks(); +#endif + + static AsyncResultsExecutionMode + ConvertToAsyncResultExecutionMode(const PhysicalTableScanExecutionStrategy &execution_mode); + +private: + AsyncResultType result_type {AsyncResultType::INVALID}; + vector> async_tasks {}; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/interrupt.hpp b/src/duckdb/src/include/duckdb/parallel/interrupt.hpp index ef0bf8139..b2db8497e 100644 --- a/src/duckdb/src/include/duckdb/parallel/interrupt.hpp +++ b/src/duckdb/src/include/duckdb/parallel/interrupt.hpp @@ -83,6 +83,11 @@ class StateWithBlockableTasks { return false; } + bool CanBlock(const unique_lock &guard) const { + VerifyLock(guard); + return can_block; + } + //! Unblock all tasks (must hold the lock) bool UnblockTasks(const unique_lock &guard) { VerifyLock(guard); diff --git a/src/duckdb/src/include/duckdb/parallel/pipeline_executor.hpp b/src/duckdb/src/include/duckdb/parallel/pipeline_executor.hpp index 9781e6fb8..ce03be210 100644 --- a/src/duckdb/src/include/duckdb/parallel/pipeline_executor.hpp +++ b/src/duckdb/src/include/duckdb/parallel/pipeline_executor.hpp @@ -152,7 +152,7 @@ class PipelineExecutor { OperatorResultType Execute(DataChunk &input, DataChunk &result, idx_t initial_index = 0); //! Notifies the sink that a new batch has started - SinkNextBatchType NextBatch(DataChunk &source_chunk); + SinkNextBatchType NextBatch(DataChunk &source_chunk, const bool have_more_output); //! Tries to flush all state from intermediate operators. Will return true if all state is flushed, false in the //! case of a blocked sink.
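The async_result.hpp and interrupt.hpp hunks above introduce a task-based blocking protocol: when a source cannot make progress it hands back AsyncTask objects inside an AsyncResult, and the caller either executes those tasks inline (AsyncResultsExecutionMode::SYNCHRONOUS, so BLOCKED never bubbles up) or schedules them on the executor and yields (TASK_EXECUTOR). The standalone sketch below only models that pattern; the names SleepTask and HandleBlockedSource are illustrative and are not part of the DuckDB API, which is declared in the hunks above and in the sleep_async_task.hpp hunk that follows.

// Minimal standalone model of the AsyncTask pattern described above (illustrative, not DuckDB code).
#include <chrono>
#include <memory>
#include <thread>
#include <vector>

// Mirrors the AsyncTask interface from async_result.hpp: one unit of blocking work.
struct AsyncTask {
	virtual ~AsyncTask() = default;
	virtual void Execute() = 0;
};

// Toy task comparable to the SleepAsyncTask declared in the next hunk.
struct SleepTask : AsyncTask {
	explicit SleepTask(int ms) : ms(ms) {
	}
	void Execute() override {
		std::this_thread::sleep_for(std::chrono::milliseconds(ms));
	}
	int ms;
};

enum class ExecutionMode { SYNCHRONOUS, TASK_EXECUTOR };

// SYNCHRONOUS: run the tasks at the callsite, so the caller never observes BLOCKED.
// TASK_EXECUTOR: a real implementation would schedule the tasks and let the pipeline yield.
static void HandleBlockedSource(std::vector<std::unique_ptr<AsyncTask>> tasks, ExecutionMode mode) {
	if (mode == ExecutionMode::SYNCHRONOUS) {
		for (auto &task : tasks) {
			task->Execute();
		}
	}
	// TASK_EXECUTOR path omitted: scheduling infrastructure lives in the executor.
}

int main() {
	std::vector<std::unique_ptr<AsyncTask>> tasks;
	tasks.emplace_back(std::make_unique<SleepTask>(5));
	HandleBlockedSource(std::move(tasks), ExecutionMode::SYNCHRONOUS);
	return 0;
}

In the actual headers, AsyncResult::ExecuteTasksSynchronously and AsyncResult::ScheduleTasks play the roles of the two branches modelled by HandleBlockedSource.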
diff --git a/src/duckdb/src/include/duckdb/parallel/sleep_async_task.hpp b/src/duckdb/src/include/duckdb/parallel/sleep_async_task.hpp new file mode 100644 index 000000000..f53fc1d4f --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/sleep_async_task.hpp @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/sleep_async_task.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parallel/async_result.hpp" + +#include +#include + +namespace duckdb { + +class SleepAsyncTask : public AsyncTask { +public: + explicit SleepAsyncTask(idx_t sleep_for) : sleep_for(sleep_for) { + } + void Execute() override { + std::this_thread::sleep_for(std::chrono::milliseconds(sleep_for)); + } + const idx_t sleep_for; +}; + +class ThrowAsyncTask : public AsyncTask { +public: + explicit ThrowAsyncTask(idx_t sleep_for) : sleep_for(sleep_for) { + } + void Execute() override { + std::this_thread::sleep_for(std::chrono::milliseconds(sleep_for)); + throw NotImplementedException("ThrowAsyncTask: Test error handling when throwing mid-task"); + } + const idx_t sleep_for; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp b/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp index 552753692..7a784eeb1 100644 --- a/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp @@ -16,16 +16,20 @@ namespace duckdb { class SelectStatement; struct CommonTableExpressionInfo { + ~CommonTableExpressionInfo(); + vector aliases; vector> key_targets; unique_ptr query; CTEMaterialize materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; +public: void Serialize(Serializer &serializer) const; static unique_ptr Deserialize(Deserializer &deserializer); unique_ptr Copy(); - ~CommonTableExpressionInfo(); +private: + CTEMaterialize GetMaterializedForSerialization(Serializer &serializer) const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/window_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/window_expression.hpp index cc5e9c61f..dec201238 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/window_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/window_expression.hpp @@ -13,6 +13,11 @@ namespace duckdb { +struct WindowFunctionDefinition { + const char *name; + ExpressionType expression_type; +}; + enum class WindowBoundary : uint8_t { INVALID = 0, UNBOUNDED_PRECEDING = 1, @@ -92,6 +97,7 @@ class WindowExpression : public ParsedExpression { void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); + static const WindowFunctionDefinition *WindowFunctions(); static ExpressionType WindowToExpressionType(string &fun_name); public: diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/sample_options.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/sample_options.hpp index dadbcfe92..88fc14831 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/sample_options.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/sample_options.hpp @@ -23,7 +23,6 @@ enum class SampleMethod : uint8_t { SYSTEM_SAMPLE = 0, BERNOULLI_SAMPLE = 1, RES string SampleMethodToString(SampleMethod method); class SampleOptions { - public: explicit SampleOptions(int64_t seed_ = -1); diff 
--git a/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp index 5540d38a2..12c4f77ca 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp @@ -10,7 +10,6 @@ #include "duckdb/parser/parsed_data/parse_info.hpp" #include "duckdb/parser/tableref.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" #include "duckdb/common/unordered_map.hpp" #include "duckdb/common/optional_ptr.hpp" #include "duckdb/catalog/dependency_list.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/parser.hpp b/src/duckdb/src/include/duckdb/parser/parser.hpp index ce373fe9c..154bb860f 100644 --- a/src/duckdb/src/include/duckdb/parser/parser.hpp +++ b/src/duckdb/src/include/duckdb/parser/parser.hpp @@ -14,6 +14,7 @@ #include "duckdb/parser/column_list.hpp" #include "duckdb/parser/simplified_token.hpp" #include "duckdb/parser/parser_options.hpp" +#include "duckdb/parser/parser_extension.hpp" namespace duckdb_libpgquery { struct PGNode; @@ -73,6 +74,9 @@ class Parser { static bool StripUnicodeSpaces(const string &query_str, string &new_query); + StatementType GetStatementType(const string &query); + void ThrowParserOverrideError(ParserOverrideResult &result); + private: ParserOptions options; }; diff --git a/src/duckdb/src/include/duckdb/parser/parser_extension.hpp b/src/duckdb/src/include/duckdb/parser/parser_extension.hpp index 61c071307..a3d2dcf64 100644 --- a/src/duckdb/src/include/duckdb/parser/parser_extension.hpp +++ b/src/duckdb/src/include/duckdb/parser/parser_extension.hpp @@ -86,12 +86,12 @@ struct ParserOverrideResult { explicit ParserOverrideResult(vector> statements_p) : type(ParserExtensionResultType::PARSE_SUCCESSFUL), statements(std::move(statements_p)) {}; - explicit ParserOverrideResult(const string &error_p) + explicit ParserOverrideResult(std::exception &error_p) : type(ParserExtensionResultType::DISPLAY_EXTENSION_ERROR), error(error_p) {}; ParserExtensionResultType type; vector> statements; - string error; + ErrorData error; }; typedef ParserOverrideResult (*parser_override_function_t)(ParserExtensionInfo *info, const string &query); @@ -103,14 +103,14 @@ class ParserExtension { public: //! The parse function of the parser extension. //! Takes a query string as input and returns ParserExtensionParseData (on success) or an error - parse_function_t parse_function; + parse_function_t parse_function = nullptr; //! The plan function of the parser extension //! Takes as input the result of the parse_function, and outputs various properties of the resulting plan - plan_function_t plan_function; + plan_function_t plan_function = nullptr; //! Override the current parser with a new parser and return a vector of SQL statements - parser_override_function_t parser_override; + parser_override_function_t parser_override = nullptr; //! 
Additional parser info passed to the parse function shared_ptr parser_info; diff --git a/src/duckdb/src/include/duckdb/parser/parser_options.hpp b/src/duckdb/src/include/duckdb/parser/parser_options.hpp index d388fb116..d9a42632a 100644 --- a/src/duckdb/src/include/duckdb/parser/parser_options.hpp +++ b/src/duckdb/src/include/duckdb/parser/parser_options.hpp @@ -18,6 +18,7 @@ struct ParserOptions { bool integer_division = false; idx_t max_expression_depth = 1000; const vector *extensions = nullptr; + string parser_override_setting = "default"; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/query_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node.hpp index ec03da095..5c091b259 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node.hpp @@ -25,7 +25,8 @@ enum class QueryNodeType : uint8_t { SET_OPERATION_NODE = 2, BOUND_SUBQUERY_NODE = 3, RECURSIVE_CTE_NODE = 4, - CTE_NODE = 5 + CTE_NODE = 5, + STATEMENT_NODE = 6 }; struct CommonTableExpressionInfo; @@ -59,8 +60,6 @@ class QueryNode { //! CTEs (used by SelectNode and SetOperationNode) CommonTableExpressionMap cte_map; - virtual const vector> &GetSelectList() const = 0; - public: //! Convert the query node to a string virtual string ToString() const = 0; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp index bc997a6c7..fd2589fd2 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp @@ -14,6 +14,7 @@ namespace duckdb { +//! DEPRECATED - CTENode is only preserved for backwards compatibility when serializing older databases class CTENode : public QueryNode { public: static constexpr const QueryNodeType TYPE = QueryNodeType::CTE_NODE; @@ -23,30 +24,18 @@ class CTENode : public QueryNode { } string ctename; - //! The query of the CTE unique_ptr query; - //! Child unique_ptr child; - //! Aliases of the CTE node vector aliases; CTEMaterialize materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; - const vector> &GetSelectList() const override { - return query->GetSelectList(); - } - public: - //! Convert the query node to a string string ToString() const override; bool Equals(const QueryNode *other) const override; - //! Create a copy of this SelectNode unique_ptr Copy() const override; - //! Serializes a QueryNode to a stand-alone binary blob - //! 
Deserializes a blob back into a QueryNode - void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &source); }; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/list.hpp b/src/duckdb/src/include/duckdb/parser/query_node/list.hpp index 94bfd3438..3a2894cc4 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/list.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/list.hpp @@ -2,3 +2,4 @@ #include "duckdb/parser/query_node/cte_node.hpp" #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/query_node/set_operation_node.hpp" +#include "duckdb/parser/query_node/statement_node.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp index 6d73fda4a..1f5f16ead 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp @@ -33,10 +33,6 @@ class RecursiveCTENode : public QueryNode { //! targets for key variants vector> key_targets; - const vector> &GetSelectList() const override { - return left->GetSelectList(); - } - public: //! Convert the query node to a string string ToString() const override; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/select_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/select_node.hpp index 62aa9c0b2..dfc474d14 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/select_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/select_node.hpp @@ -43,10 +43,6 @@ class SelectNode : public QueryNode { //! The SAMPLE clause unique_ptr sample; - const vector> &GetSelectList() const override { - return select_list; - } - public: //! Convert the query node to a string string ToString() const override; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/set_operation_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/set_operation_node.hpp index 960f6c2d6..3070e2245 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/set_operation_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/set_operation_node.hpp @@ -29,8 +29,6 @@ class SetOperationNode : public QueryNode { //! The children of the set operation vector> children; - const vector> &GetSelectList() const override; - public: //! Convert the query node to a string string ToString() const override; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/statement_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/statement_node.hpp new file mode 100644 index 000000000..26db46a58 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/query_node/statement_node.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/query_node/statement_node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class StatementNode : public QueryNode { +public: + static constexpr const QueryNodeType TYPE = QueryNodeType::STATEMENT_NODE; + +public: + explicit StatementNode(SQLStatement &stmt_p); + + SQLStatement &stmt; + +public: + //! Convert the query node to a string + string ToString() const override; + + bool Equals(const QueryNode *other) const override; + //! 
Create a copy of this SelectNode + unique_ptr Copy() const override; + + //! Serializes a QueryNode to a stand-alone binary blob + //! Deserializes a blob back into a QueryNode + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &source); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/tableref/bound_ref_wrapper.hpp b/src/duckdb/src/include/duckdb/parser/tableref/bound_ref_wrapper.hpp index be997eb5d..c8586a559 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/bound_ref_wrapper.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/bound_ref_wrapper.hpp @@ -9,7 +9,6 @@ #pragma once #include "duckdb/parser/tableref.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/binder.hpp" namespace duckdb { @@ -20,10 +19,10 @@ class BoundRefWrapper : public TableRef { static constexpr const TableReferenceType TYPE = TableReferenceType::BOUND_TABLE_REF; public: - BoundRefWrapper(unique_ptr bound_ref_p, shared_ptr binder_p); + BoundRefWrapper(BoundStatement bound_ref_p, shared_ptr binder_p); //! The bound reference object - unique_ptr bound_ref; + BoundStatement bound_ref; //! The binder that was used to bind this table ref shared_ptr binder; diff --git a/src/duckdb/src/include/duckdb/parser/tableref/delimgetref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/delimgetref.hpp index ea6889362..1cf31d1eb 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/delimgetref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/delimgetref.hpp @@ -13,7 +13,6 @@ namespace duckdb { class DelimGetRef : public TableRef { - public: explicit DelimGetRef(const vector &types_p) : TableRef(TableReferenceType::DELIM_GET), types(types_p) { for (idx_t i = 0; i < types.size(); i++) { diff --git a/src/duckdb/src/include/duckdb/parser/tokens.hpp b/src/duckdb/src/include/duckdb/parser/tokens.hpp index 6eeb8c5e2..d5646739c 100644 --- a/src/duckdb/src/include/duckdb/parser/tokens.hpp +++ b/src/duckdb/src/include/duckdb/parser/tokens.hpp @@ -53,6 +53,7 @@ class SelectNode; class SetOperationNode; class RecursiveCTENode; class CTENode; +class StatementNode; //===--------------------------------------------------------------------===// // Expressions diff --git a/src/duckdb/src/include/duckdb/parser/transformer.hpp b/src/duckdb/src/include/duckdb/parser/transformer.hpp index 59e4f0419..1945ebc5a 100644 --- a/src/duckdb/src/include/duckdb/parser/transformer.hpp +++ b/src/duckdb/src/include/duckdb/parser/transformer.hpp @@ -80,7 +80,7 @@ class Transformer { //! The set of pivot entries to create vector> pivot_entries; //! Sets of stored CTEs, if any - vector stored_cte_map; + vector> stored_cte_map; //! 
Whether or not we are currently binding a window definition bool in_window_definition = false; @@ -304,7 +304,6 @@ class Transformer { string TransformAlias(duckdb_libpgquery::PGAlias *root, vector &column_name_alias); vector TransformStringList(duckdb_libpgquery::PGList *list); void TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause, CommonTableExpressionMap &cte_map); - static unique_ptr TransformMaterializedCTE(unique_ptr root); unique_ptr TransformRecursiveCTE(duckdb_libpgquery::PGCommonTableExpr &node, CommonTableExpressionInfo &info); diff --git a/src/duckdb/src/include/duckdb/planner/bind_context.hpp b/src/duckdb/src/include/duckdb/planner/bind_context.hpp index d9c20dd1d..db5b52c78 100644 --- a/src/duckdb/src/include/duckdb/planner/bind_context.hpp +++ b/src/duckdb/src/include/duckdb/planner/bind_context.hpp @@ -23,7 +23,7 @@ namespace duckdb { class Binder; class LogicalGet; -class BoundQueryNode; +struct BoundStatement; class StarExpression; @@ -43,9 +43,6 @@ class BindContext { public: explicit BindContext(Binder &binder); - //! Keep track of recursive CTE references - case_insensitive_map_t> cte_references; - public: //! Given a column name, find the matching table it belongs to. Throws an //! exception if no table has a column of the given name. @@ -57,7 +54,7 @@ class BindContext { //! matching ones vector GetSimilarBindings(const string &column_name); - optional_ptr GetCTEBinding(const string &ctename); + optional_ptr GetCTEBinding(const BindingAlias &ctename); //! Binds a column expression to the base table. Returns the bound expression //! or throws an exception if the column could not be bound. BindResult BindColumn(ColumnRefExpression &colref, idx_t depth); @@ -105,11 +102,11 @@ class BindContext { const vector &types, vector &bound_column_ids, optional_ptr entry, virtual_column_map_t virtual_columns); //! Adds a table view with a given alias to the BindContext. - void AddView(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery, ViewCatalogEntry &view); + void AddView(idx_t index, const string &alias, SubqueryRef &ref, BoundStatement &subquery, ViewCatalogEntry &view); //! Adds a subquery with a given alias to the BindContext. - void AddSubquery(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery); + void AddSubquery(idx_t index, const string &alias, SubqueryRef &ref, BoundStatement &subquery); //! Adds a subquery with a given alias to the BindContext. - void AddSubquery(idx_t index, const string &alias, TableFunctionRef &ref, BoundQueryNode &subquery); + void AddSubquery(idx_t index, const string &alias, TableFunctionRef &ref, BoundStatement &subquery); //! Adds a binding to a catalog entry with a given alias to the BindContext. void AddEntryBinding(idx_t index, const string &alias, const vector &names, const vector &types, StandardEntry &entry); @@ -119,10 +116,9 @@ class BindContext { //! Adds a base table with the given alias to the CTE BindContext. //! We need this to correctly bind recursive CTEs with multiple references. - void AddCTEBinding(idx_t index, const string &alias, const vector &names, const vector &types, - bool using_key = false); - - void RemoveCTEBinding(const string &alias); + void AddCTEBinding(idx_t index, BindingAlias alias, const vector &names, const vector &types, + CTEType cte_type = CTEType::CAN_BE_REFERENCED); + void AddCTEBinding(unique_ptr binding); //! Add an implicit join condition (e.g. 
USING (x)) void AddUsingBinding(const string &column_name, UsingColumnSet &set); @@ -146,13 +142,6 @@ class BindContext { string GetActualColumnName(const BindingAlias &binding_alias, const string &column_name); string GetActualColumnName(Binding &binding, const string &column_name); - case_insensitive_map_t> GetCTEBindings() { - return cte_bindings; - } - void SetCTEBindings(case_insensitive_map_t> bindings) { - cte_bindings = std::move(bindings); - } - //! Alias a set of column names for the specified table, using the original names if there are not enough aliases //! specified. static vector AliasColumnNames(const string &table_name, const vector &names, @@ -184,10 +173,7 @@ class BindContext { vector> bindings_list; //! The set of columns used in USING join conditions case_insensitive_map_t> using_columns; - //! Using column sets - vector> using_column_sets; - //! The set of CTE bindings - case_insensitive_map_t> cte_bindings; + vector> cte_bindings; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/binder.hpp b/src/duckdb/src/include/duckdb/planner/binder.hpp index 5a664f2dc..cf01e74d6 100644 --- a/src/duckdb/src/include/duckdb/planner/binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/binder.hpp @@ -27,7 +27,6 @@ #include "duckdb/planner/joinside.hpp" #include "duckdb/planner/bound_constraint.hpp" #include "duckdb/planner/logical_operator.hpp" -#include "duckdb/planner/tableref/bound_delimgetref.hpp" #include "duckdb/common/enums/copy_option_mode.hpp" //! fwd declare @@ -69,6 +68,8 @@ struct PivotColumnEntry; struct UnpivotEntry; struct CopyInfo; struct CopyOption; +struct BoundSetOpChild; +struct BoundCTEData; template class IndexVector; @@ -100,6 +101,89 @@ struct CorrelatedColumnInfo { } }; +struct CorrelatedColumns { +private: + using container_type = vector; + +public: + CorrelatedColumns() : delim_index(1ULL << 63) { + } + + void AddColumn(container_type::value_type info) { + // Add to beginning + correlated_columns.insert(correlated_columns.begin(), std::move(info)); + delim_index++; + } + + void SetDelimIndexToZero() { + delim_index = 0; + } + + idx_t GetDelimIndex() const { + return delim_index; + } + + const container_type::value_type &operator[](const idx_t &index) const { + return correlated_columns.at(index); + } + + idx_t size() const { // NOLINT: match stl case + return correlated_columns.size(); + } + + bool empty() const { // NOLINT: match stl case + return correlated_columns.empty(); + } + + void clear() { // NOLINT: match stl case + correlated_columns.clear(); + } + + container_type::iterator begin() { // NOLINT: match stl case + return correlated_columns.begin(); + } + + container_type::iterator end() { // NOLINT: match stl case + return correlated_columns.end(); + } + + container_type::const_iterator begin() const { // NOLINT: match stl case + return correlated_columns.begin(); + } + + container_type::const_iterator end() const { // NOLINT: match stl case + return correlated_columns.end(); + } + +private: + container_type correlated_columns; + idx_t delim_index; +}; + +//! GlobalBinderState is state shared over the ENTIRE query, including subqueries, views, etc +struct GlobalBinderState { + //! The count of bound_tables + idx_t bound_tables = 0; + //! Statement properties + StatementProperties prop; + //! Binding mode + BindingMode mode = BindingMode::STANDARD_BINDING; + //! Table names extracted for BindingMode::EXTRACT_NAMES or BindingMode::EXTRACT_QUALIFIED_NAMES. + unordered_set table_names; + //! 
Replacement Scans extracted for BindingMode::EXTRACT_REPLACEMENT_SCANS + case_insensitive_map_t> replacement_scans; + //! Using column sets + vector> using_column_sets; + //! The set of parameter expressions bound by this binder + optional_ptr parameters; +}; + +// QueryBinderState is state shared WITHIN a query, a new query-binder state is created when binding inside e.g. a view +struct QueryBinderState { + //! The vector of active binders + vector> active_binders; +}; + //! Bind the parsed query tree to the actual columns present in the catalog. /*! The binder is responsible for binding tables and columns to actual physical @@ -116,15 +200,11 @@ class Binder : public enable_shared_from_this { //! The client context ClientContext &context; - //! A mapping of names to common table expressions - case_insensitive_set_t CTE_bindings; // NOLINT //! The bind context BindContext bind_context; //! The set of correlated columns bound by this binder (FIXME: this should probably be an unordered_set and not a //! vector) - vector correlated_columns; - //! The set of parameter expressions bound by this binder - optional_ptr parameters; + CorrelatedColumns correlated_columns; //! The alias for the currently processing subquery, if it exists string alias; //! Macro parameter bindings (if any) @@ -171,8 +251,7 @@ class Binder : public enable_shared_from_this { QueryErrorContext &error_context, string &func_name); unique_ptr BindPragma(PragmaInfo &info, QueryErrorContext error_context); - unique_ptr Bind(TableRef &ref); - unique_ptr CreatePlan(BoundTableRef &ref); + BoundStatement Bind(TableRef &ref); //! Generates an unused index for a table idx_t GenerateTableIndex(); @@ -180,12 +259,8 @@ class Binder : public enable_shared_from_this { optional_ptr GetCatalogEntry(const string &catalog, const string &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound on_entry_not_found); - //! Add a common table expression to the binder - void AddCTE(const string &name); //! Find all candidate common table expression by name; returns empty vector if none exists - vector> FindCTE(const string &name, bool skip = false); - - bool CTEExists(const string &name); + optional_ptr GetCTEBinding(const BindingAlias &name); //! Add the view to the set of currently bound views - used for detecting recursive view definitions void AddBoundView(ViewCatalogEntry &view); @@ -198,7 +273,7 @@ class Binder : public enable_shared_from_this { vector> &GetActiveBinders(); - void MergeCorrelatedColumns(vector &other); + void MergeCorrelatedColumns(CorrelatedColumns &other); //! Add a correlated column to this binder (if it does not exist) void AddCorrelatedColumn(const CorrelatedColumnInfo &info); @@ -228,12 +303,11 @@ class Binder : public enable_shared_from_this { void AddReplacementScan(const string &table_name, unique_ptr replacement); const unordered_set &GetTableNames(); case_insensitive_map_t> &GetReplacementScans(); - optional_ptr GetRootStatement() { - return root_statement; - } CatalogEntryRetriever &EntryRetriever() { return entry_retriever; } + optional_ptr GetParameters(); + void SetParameters(BoundParameterMap ¶meters); //! Returns a ColumnRefExpression after it was resolved (i.e. past the STAR expression/USING clauses) static optional_ptr GetResolvedColumnExpression(ParsedExpression &root_expr); @@ -250,42 +324,28 @@ class Binder : public enable_shared_from_this { private: //! The parent binder (if any) shared_ptr parent; - //! The vector of active binders - vector> active_binders; - //! 
The count of bound_tables - idx_t bound_tables; + //! What kind of node we are binding using this binder + BinderType binder_type = BinderType::REGULAR_BINDER; + //! Global binder state + shared_ptr global_binder_state; + //! Query binder state + shared_ptr query_binder_state; //! Whether or not the binder has any unplanned dependent joins that still need to be planned/flattened bool has_unplanned_dependent_joins = false; //! Whether or not outside dependent joins have been planned and flattened bool is_outside_flattened = true; - //! What kind of node we are binding using this binder - BinderType binder_type = BinderType::REGULAR_BINDER; //! Whether or not the binder can contain NULLs as the root of expressions bool can_contain_nulls = false; - //! The root statement of the query that is currently being parsed - optional_ptr root_statement; - //! Binding mode - BindingMode mode = BindingMode::STANDARD_BINDING; - //! Table names extracted for BindingMode::EXTRACT_NAMES or BindingMode::EXTRACT_QUALIFIED_NAMES. - unordered_set table_names; - //! Replacement Scans extracted for BindingMode::EXTRACT_REPLACEMENT_SCANS - case_insensitive_map_t> replacement_scans; //! The set of bound views reference_set_t bound_views; //! Used to retrieve CatalogEntry's CatalogEntryRetriever entry_retriever; //! Unnamed subquery index idx_t unnamed_subquery_index = 1; - //! Statement properties - StatementProperties prop; - //! Root binder - Binder &root_binder; //! Binder depth idx_t depth; private: - //! Get the root binder (binder with no parent) - Binder &GetRootBinder(); //! Determine the depth of the binder idx_t GetBinderDepth() const; //! Increase the depth of the binder @@ -303,7 +363,7 @@ class Binder : public enable_shared_from_this { void MoveCorrelatedExpressions(Binder &other); //! 
Tries to bind the table name with replacement scans - unique_ptr BindWithReplacementScan(ClientContext &context, BaseTableRef &ref); + BoundStatement BindWithReplacementScan(ClientContext &context, BaseTableRef &ref); template BoundStatement BindWithCTE(T &statement); @@ -344,41 +404,39 @@ class Binder : public enable_shared_from_this { unique_ptr BindTableMacro(FunctionExpression &function, TableMacroCatalogEntry ¯o_func, idx_t depth); - unique_ptr BindMaterializedCTE(CommonTableExpressionMap &cte_map); - unique_ptr BindCTE(CTENode &statement); + BoundStatement BindCTE(const string &ctename, CommonTableExpressionInfo &info); - unique_ptr BindNode(SelectNode &node); - unique_ptr BindNode(SetOperationNode &node); - unique_ptr BindNode(RecursiveCTENode &node); - unique_ptr BindNode(CTENode &node); - unique_ptr BindNode(QueryNode &node); + BoundStatement BindNode(SelectNode &node); + BoundStatement BindNode(SetOperationNode &node); + BoundStatement BindNode(RecursiveCTENode &node); + BoundStatement BindNode(QueryNode &node); + BoundStatement BindNode(StatementNode &node); unique_ptr VisitQueryNode(BoundQueryNode &node, unique_ptr root); - unique_ptr CreatePlan(BoundRecursiveCTENode &node); - unique_ptr CreatePlan(BoundCTENode &node); - unique_ptr CreatePlan(BoundCTENode &node, unique_ptr base); unique_ptr CreatePlan(BoundSelectNode &statement); unique_ptr CreatePlan(BoundSetOperationNode &node); unique_ptr CreatePlan(BoundQueryNode &node); - unique_ptr BindJoin(Binder &parent, TableRef &ref); - unique_ptr Bind(BaseTableRef &ref); - unique_ptr Bind(BoundRefWrapper &ref); - unique_ptr Bind(JoinRef &ref); - unique_ptr Bind(SubqueryRef &ref); - unique_ptr Bind(TableFunctionRef &ref); - unique_ptr Bind(EmptyTableRef &ref); - unique_ptr Bind(DelimGetRef &ref); - unique_ptr Bind(ExpressionListRef &ref); - unique_ptr Bind(ColumnDataRef &ref); - unique_ptr Bind(PivotRef &expr); - unique_ptr Bind(ShowRef &ref); + void BuildUnionByNameInfo(BoundSetOperationNode &result); + + BoundStatement BindJoin(Binder &parent, TableRef &ref); + BoundStatement Bind(BaseTableRef &ref); + BoundStatement Bind(BoundRefWrapper &ref); + BoundStatement Bind(JoinRef &ref); + BoundStatement Bind(SubqueryRef &ref); + BoundStatement Bind(TableFunctionRef &ref); + BoundStatement Bind(EmptyTableRef &ref); + BoundStatement Bind(DelimGetRef &ref); + BoundStatement Bind(ExpressionListRef &ref); + BoundStatement Bind(ColumnDataRef &ref); + BoundStatement Bind(PivotRef &expr); + BoundStatement Bind(ShowRef &ref); unique_ptr BindPivot(PivotRef &expr, vector> all_columns); unique_ptr BindUnpivot(Binder &child_binder, PivotRef &expr, vector> all_columns, unique_ptr &where_clause); - unique_ptr BindBoundPivot(PivotRef &expr); + BoundStatement BindBoundPivot(PivotRef &expr); void ExtractUnpivotEntries(Binder &child_binder, PivotColumnEntry &entry, vector &unpivot_entries); void ExtractUnpivotColumnName(ParsedExpression &expr, vector &result); @@ -387,26 +445,14 @@ class Binder : public enable_shared_from_this { bool BindTableFunctionParameters(TableFunctionCatalogEntry &table_function, vector> &expressions, vector &arguments, vector ¶meters, named_parameter_map_t &named_parameters, - unique_ptr &subquery, ErrorData &error); - void BindTableInTableOutFunction(vector> &expressions, - unique_ptr &subquery); - unique_ptr BindTableFunction(TableFunction &function, vector parameters); - unique_ptr BindTableFunctionInternal(TableFunction &table_function, const TableFunctionRef &ref, - vector parameters, - named_parameter_map_t named_parameters, - 
vector input_table_types, - vector input_table_names); - - unique_ptr CreatePlan(BoundBaseTableRef &ref); + BoundStatement &subquery, ErrorData &error); + void BindTableInTableOutFunction(vector> &expressions, BoundStatement &subquery); + BoundStatement BindTableFunction(TableFunction &function, vector parameters); + BoundStatement BindTableFunctionInternal(TableFunction &table_function, const TableFunctionRef &ref, + vector parameters, named_parameter_map_t named_parameters, + vector input_table_types, vector input_table_names); + unique_ptr CreatePlan(BoundJoinRef &ref); - unique_ptr CreatePlan(BoundSubqueryRef &ref); - unique_ptr CreatePlan(BoundTableFunction &ref); - unique_ptr CreatePlan(BoundEmptyTableRef &ref); - unique_ptr CreatePlan(BoundExpressionListRef &ref); - unique_ptr CreatePlan(BoundColumnDataRef &ref); - unique_ptr CreatePlan(BoundCTERef &ref); - unique_ptr CreatePlan(BoundPivotRef &ref); - unique_ptr CreatePlan(BoundDelimGetRef &ref); BoundStatement BindCopyTo(CopyStatement &stmt, const CopyFunction &function, CopyToType copy_to_type); BoundStatement BindCopyFrom(CopyStatement &stmt, const CopyFunction &function); @@ -426,12 +472,12 @@ class Binder : public enable_shared_from_this { void PlanSubqueries(unique_ptr &expr, unique_ptr &root); unique_ptr PlanSubquery(BoundSubqueryExpression &expr, unique_ptr &root); unique_ptr PlanLateralJoin(unique_ptr left, unique_ptr right, - vector &correlated_columns, + CorrelatedColumns &correlated_columns, JoinType join_type = JoinType::INNER, unique_ptr condition = nullptr); - unique_ptr CastLogicalOperatorToTypes(vector &source_types, - vector &target_types, + unique_ptr CastLogicalOperatorToTypes(const vector &source_types, + const vector &target_types, unique_ptr op); BindingAlias FindBinding(const string &using_column, const string &join_side); @@ -441,8 +487,6 @@ class Binder : public enable_shared_from_this { BindingAlias RetrieveUsingBinding(Binder ¤t_binder, optional_ptr current_set, const string &column_name, const string &join_side); - void AddCTEMap(CommonTableExpressionMap &cte_map); - void ExpandStarExpressions(vector> &select_list, vector> &new_select_list); void ExpandStarExpression(unique_ptr expr, vector> &new_select_list); @@ -463,14 +507,14 @@ class Binder : public enable_shared_from_this { LogicalType BindLogicalTypeInternal(const LogicalType &type, optional_ptr catalog, const string &schema); - unique_ptr BindSelectNode(SelectNode &statement, unique_ptr from_table); + BoundStatement BindSelectNode(SelectNode &statement, BoundStatement from_table); unique_ptr BindCopyDatabaseSchema(Catalog &source_catalog, const string &target_database_name); unique_ptr BindCopyDatabaseData(Catalog &source_catalog, const string &target_database_name); - unique_ptr BindShowQuery(ShowRef &ref); - unique_ptr BindShowTable(ShowRef &ref); - unique_ptr BindSummarize(ShowRef &ref); + BoundStatement BindShowQuery(ShowRef &ref); + BoundStatement BindShowTable(ShowRef &ref); + BoundStatement BindSummarize(ShowRef &ref); void BindInsertColumnList(TableCatalogEntry &table, vector &columns, bool default_values, vector &named_column_map, vector &expected_types, @@ -491,6 +535,9 @@ class Binder : public enable_shared_from_this { static void CheckInsertColumnCountMismatch(idx_t expected_columns, idx_t result_columns, bool columns_provided, const string &tname); + BoundCTEData PrepareCTE(const string &ctename, CommonTableExpressionInfo &statement); + BoundStatement FinishCTE(BoundCTEData &bound_cte, BoundStatement child_data); + private: 
Binder(ClientContext &context, shared_ptr parent, BinderType binder_type); }; diff --git a/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp b/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp index cd5a78b6a..76c461e78 100644 --- a/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp +++ b/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp @@ -17,13 +17,8 @@ namespace duckdb { //! Bound equivalent of QueryNode class BoundQueryNode { public: - explicit BoundQueryNode(QueryNodeType type) : type(type) { - } - virtual ~BoundQueryNode() { - } + virtual ~BoundQueryNode() = default; - //! The type of the query node, either SetOperation or Select - QueryNodeType type; //! The result modifiers that should be applied to this query node vector> modifiers; @@ -34,23 +29,6 @@ class BoundQueryNode { public: virtual idx_t GetRootIndex() = 0; - -public: - template - TARGET &Cast() { - if (type != TARGET::TYPE) { - throw InternalException("Failed to cast bound query node to type - query node type mismatch"); - } - return reinterpret_cast(*this); - } - - template - const TARGET &Cast() const { - if (type != TARGET::TYPE) { - throw InternalException("Failed to cast bound query node to type - query node type mismatch"); - } - return reinterpret_cast(*this); - } }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_statement.hpp b/src/duckdb/src/include/duckdb/planner/bound_statement.hpp index bb1f7bfec..23fae54d6 100644 --- a/src/duckdb/src/include/duckdb/planner/bound_statement.hpp +++ b/src/duckdb/src/include/duckdb/planner/bound_statement.hpp @@ -9,17 +9,31 @@ #pragma once #include "duckdb/common/string.hpp" +#include "duckdb/common/unique_ptr.hpp" #include "duckdb/common/vector.hpp" +#include "duckdb/common/enums/set_operation_type.hpp" +#include "duckdb/common/shared_ptr.hpp" namespace duckdb { class LogicalOperator; struct LogicalType; +struct BoundStatement; +class ParsedExpression; +class Binder; + +struct ExtraBoundInfo { + SetOperationType setop_type = SetOperationType::NONE; + vector> child_binders; + vector bound_children; + vector> original_expressions; +}; struct BoundStatement { unique_ptr plan; vector types; vector names; + ExtraBoundInfo extra_info; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_tableref.hpp b/src/duckdb/src/include/duckdb/planner/bound_tableref.hpp deleted file mode 100644 index 0a831c54a..000000000 --- a/src/duckdb/src/include/duckdb/planner/bound_tableref.hpp +++ /dev/null @@ -1,46 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/bound_tableref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/common.hpp" -#include "duckdb/common/enums/tableref_type.hpp" -#include "duckdb/parser/parsed_data/sample_options.hpp" - -namespace duckdb { - -class BoundTableRef { -public: - explicit BoundTableRef(TableReferenceType type) : type(type) { - } - virtual ~BoundTableRef() { - } - - //! The type of table reference - TableReferenceType type; - //! 
The sample options (if any) - unique_ptr sample; - -public: - template - TARGET &Cast() { - if (type != TARGET::TYPE) { - throw InternalException("Failed to cast bound table ref to type - table ref type mismatch"); - } - return reinterpret_cast(*this); - } - - template - const TARGET &Cast() const { - if (type != TARGET::TYPE) { - throw InternalException("Failed to cast bound table ref to type - table ref type mismatch"); - } - return reinterpret_cast(*this); - } -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp b/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp index bd75aac19..862ef5a11 100644 --- a/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp +++ b/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp @@ -16,8 +16,6 @@ namespace duckdb { class BoundQueryNode; class BoundSelectNode; class BoundSetOperationNode; -class BoundRecursiveCTENode; -class BoundCTENode; //===--------------------------------------------------------------------===// // Expressions @@ -45,18 +43,7 @@ class BoundWindowExpression; //===--------------------------------------------------------------------===// // TableRefs //===--------------------------------------------------------------------===// -class BoundTableRef; - -class BoundBaseTableRef; class BoundJoinRef; -class BoundSubqueryRef; -class BoundTableFunction; -class BoundEmptyTableRef; -class BoundExpressionListRef; -class BoundColumnDataRef; -class BoundCTERef; -class BoundPivotRef; - class BoundMergeIntoAction; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/constraints/bound_unique_constraint.hpp b/src/duckdb/src/include/duckdb/planner/constraints/bound_unique_constraint.hpp index 58d136372..9b729ebe5 100644 --- a/src/duckdb/src/include/duckdb/planner/constraints/bound_unique_constraint.hpp +++ b/src/duckdb/src/include/duckdb/planner/constraints/bound_unique_constraint.hpp @@ -22,7 +22,6 @@ class BoundUniqueConstraint : public BoundConstraint { BoundUniqueConstraint(vector keys_p, physical_index_set_t key_set_p, const bool is_primary_key) : BoundConstraint(ConstraintType::UNIQUE), keys(std::move(keys_p)), key_set(std::move(key_set_p)), is_primary_key(is_primary_key) { - #ifdef DEBUG D_ASSERT(keys.size() == key_set.size()); for (auto &key : keys) { diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp index aa07a67b9..35792c8d4 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp @@ -29,7 +29,7 @@ class BoundSubqueryExpression : public Expression { //! The binder used to bind the subquery node shared_ptr binder; //! The bound subquery node - unique_ptr subquery; + BoundStatement subquery; //! The subquery type SubqueryType subquery_type; //! 
the child expressions to compare with (in case of IN, ANY, ALL operators) diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp index eb68a0cdf..55f046cd7 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp @@ -24,7 +24,7 @@ class LateralBinder : public ExpressionBinder { return !correlated_columns.empty(); } - static void ReduceExpressionDepth(LogicalOperator &op, const vector &info); + static void ReduceExpressionDepth(LogicalOperator &op, const CorrelatedColumns &info); protected: BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, @@ -37,7 +37,7 @@ class LateralBinder : public ExpressionBinder { void ExtractCorrelatedColumns(Expression &expr); private: - vector correlated_columns; + CorrelatedColumns correlated_columns; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/select_bind_state.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/select_bind_state.hpp index d11d94731..956c66fab 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/select_bind_state.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/select_bind_state.hpp @@ -11,7 +11,6 @@ #include "duckdb/planner/bound_query_node.hpp" #include "duckdb/planner/logical_operator.hpp" #include "duckdb/parser/expression_map.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/parser/parsed_data/sample_options.hpp" #include "duckdb/parser/group_by_node.hpp" diff --git a/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp b/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp index 5b2e2e8a8..52599a1c8 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp @@ -14,8 +14,6 @@ #include namespace duckdb { -class BoundQueryNode; -class BoundTableRef; class ExpressionIterator { public: @@ -47,18 +45,4 @@ class ExpressionIterator { } }; -class BoundNodeVisitor { -public: - virtual ~BoundNodeVisitor() = default; - - virtual void VisitBoundQueryNode(BoundQueryNode &op); - virtual void VisitBoundTableRef(BoundTableRef &ref); - virtual void VisitExpression(unique_ptr &expression); - -protected: - // The VisitExpressionChildren method is called at the end of every call to VisitExpression to recursively visit all - // expressions in an expression tree. It can be overloaded to prevent automatically visiting the entire tree. 
- virtual void VisitExpressionChildren(Expression &expression); -}; - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/logical_operator.hpp b/src/duckdb/src/include/duckdb/planner/logical_operator.hpp index e7f533bdd..743a9153b 100644 --- a/src/duckdb/src/include/duckdb/planner/logical_operator.hpp +++ b/src/duckdb/src/include/duckdb/planner/logical_operator.hpp @@ -45,6 +45,7 @@ class LogicalOperator { public: virtual vector GetColumnBindings(); + virtual idx_t GetRootIndex(); static string ColumnBindingsToString(const vector &bindings); void PrintColumnBindings(); static vector GenerateColumnBindings(idx_t table_idx, idx_t column_count); diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_column_data_get.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_column_data_get.hpp index 6d27b679e..6b4ee004e 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_column_data_get.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_column_data_get.hpp @@ -14,6 +14,8 @@ namespace duckdb { +class ManagedResultSet; + //! LogicalColumnDataGet represents a scan operation from a ColumnDataCollection class LogicalColumnDataGet : public LogicalOperator { public: diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_cte.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_cte.hpp index cd2ed3c21..0548cd4e7 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_cte.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_cte.hpp @@ -35,6 +35,6 @@ class LogicalCTE : public LogicalOperator { string ctename; idx_t table_index; idx_t column_count; - vector correlated_columns; + CorrelatedColumns correlated_columns; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_dependent_join.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_dependent_join.hpp index 724f2bc57..5e4c83919 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_dependent_join.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_dependent_join.hpp @@ -27,7 +27,7 @@ class LogicalDependentJoin : public LogicalComparisonJoin { public: explicit LogicalDependentJoin(unique_ptr left, unique_ptr right, - vector correlated_columns, JoinType type, + CorrelatedColumns correlated_columns, JoinType type, unique_ptr condition); explicit LogicalDependentJoin(JoinType type); @@ -35,7 +35,7 @@ class LogicalDependentJoin : public LogicalComparisonJoin { //! The conditions of the join unique_ptr join_condition; //! 
The list of columns that have correlations with the right - vector correlated_columns; + CorrelatedColumns correlated_columns; SubqueryType subquery_type = SubqueryType::INVALID; bool perform_delim = true; @@ -51,7 +51,7 @@ class LogicalDependentJoin : public LogicalComparisonJoin { public: static unique_ptr Create(unique_ptr left, unique_ptr right, - vector correlated_columns, JoinType type, + CorrelatedColumns correlated_columns, JoinType type, unique_ptr condition); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp deleted file mode 100644 index cbfdecd1f..000000000 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp +++ /dev/null @@ -1,46 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/query_node/bound_cte_node.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_query_node.hpp" - -namespace duckdb { - -class BoundCTENode : public BoundQueryNode { -public: - static constexpr const QueryNodeType TYPE = QueryNodeType::CTE_NODE; - -public: - BoundCTENode() : BoundQueryNode(QueryNodeType::CTE_NODE) { - } - - //! Keep track of the CTE name this node represents - string ctename; - - //! The cte node - unique_ptr query; - //! The child node - unique_ptr child; - //! Index used by the set operation - idx_t setop_index; - //! The binder used by the query side of the CTE - shared_ptr query_binder; - //! The binder used by the child side of the CTE - shared_ptr child_binder; - - CTEMaterialize materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; - -public: - idx_t GetRootIndex() override { - return child->GetRootIndex(); - } -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp deleted file mode 100644 index 3da295e2a..000000000 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp +++ /dev/null @@ -1,49 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/query_node/bound_recursive_cte_node.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_query_node.hpp" - -namespace duckdb { - -//! Bound equivalent of SetOperationNode -class BoundRecursiveCTENode : public BoundQueryNode { -public: - static constexpr const QueryNodeType TYPE = QueryNodeType::RECURSIVE_CTE_NODE; - -public: - BoundRecursiveCTENode() : BoundQueryNode(QueryNodeType::RECURSIVE_CTE_NODE) { - } - - //! Keep track of the CTE name this node represents - string ctename; - - bool union_all; - //! The left side of the set operation - unique_ptr left; - //! The right side of the set operation - unique_ptr right; - //! Target columns for the recursive key variant - vector> key_targets; - - //! Index used by the set operation - idx_t setop_index; - //! The binder used by the left side of the set operation - shared_ptr left_binder; - //! 
The binder used by the right side of the set operation - shared_ptr right_binder; - -public: - idx_t GetRootIndex() override { - return setop_index; - } -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp index b3a22966a..3fdc186e9 100644 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp +++ b/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp @@ -11,7 +11,6 @@ #include "duckdb/planner/bound_query_node.hpp" #include "duckdb/planner/logical_operator.hpp" #include "duckdb/parser/expression_map.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/parser/parsed_data/sample_options.hpp" #include "duckdb/parser/group_by_node.hpp" #include "duckdb/planner/expression_binder/select_bind_state.hpp" @@ -36,18 +35,12 @@ struct BoundUnnestNode { //! Bound equivalent of SelectNode class BoundSelectNode : public BoundQueryNode { public: - static constexpr const QueryNodeType TYPE = QueryNodeType::SELECT_NODE; - -public: - BoundSelectNode() : BoundQueryNode(QueryNodeType::SELECT_NODE) { - } - //! Bind information SelectBindState bind_state; //! The projection list vector> select_list; //! The FROM clause - unique_ptr from_table; + BoundStatement from_table; //! The WHERE clause unique_ptr where_clause; //! list of groups diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp index 01fa37caf..675007b50 100644 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp +++ b/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp @@ -14,28 +14,17 @@ namespace duckdb { -struct BoundSetOpChild { - unique_ptr node; - shared_ptr binder; - //! Exprs used by the UNION BY NAME operations to add a new projection - vector> reorder_expressions; -}; - //! Bound equivalent of SetOperationNode class BoundSetOperationNode : public BoundQueryNode { public: - static constexpr const QueryNodeType TYPE = QueryNodeType::SET_OPERATION_NODE; - -public: - BoundSetOperationNode() : BoundQueryNode(QueryNodeType::SET_OPERATION_NODE) { - } - //! The type of set operation SetOperationType setop_type = SetOperationType::NONE; //! whether the ALL modifier was used or not bool setop_all = false; //! The bound children - vector bound_children; + vector bound_children; + //! Child binders + vector> child_binders; //! 
Index used by the set operation idx_t setop_index; diff --git a/src/duckdb/src/include/duckdb/planner/query_node/list.hpp b/src/duckdb/src/include/duckdb/planner/query_node/list.hpp index 5c7dbda94..dcac81248 100644 --- a/src/duckdb/src/include/duckdb/planner/query_node/list.hpp +++ b/src/duckdb/src/include/duckdb/planner/query_node/list.hpp @@ -1,4 +1,2 @@ -#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" -#include "duckdb/planner/query_node/bound_cte_node.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/planner/query_node/bound_set_operation_node.hpp" diff --git a/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp b/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp index 2f343e901..14ad4510c 100644 --- a/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp +++ b/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp @@ -18,7 +18,7 @@ namespace duckdb { //! The FlattenDependentJoins class is responsible for pushing the dependent join down into the plan to create a //! flattened subquery struct FlattenDependentJoins { - FlattenDependentJoins(Binder &binder, const vector &correlated, bool perform_delim = true, + FlattenDependentJoins(Binder &binder, const CorrelatedColumns &correlated, bool perform_delim = true, bool any_join = false, optional_ptr parent = nullptr); static unique_ptr DecorrelateIndependent(Binder &binder, unique_ptr plan); @@ -47,7 +47,7 @@ struct FlattenDependentJoins { reference_map_t has_correlated_expressions; column_binding_map_t correlated_map; column_binding_map_t replacement_map; - const vector &correlated_columns; + const CorrelatedColumns &correlated_columns; vector delim_types; bool perform_delim; diff --git a/src/duckdb/src/include/duckdb/planner/subquery/has_correlated_expressions.hpp b/src/duckdb/src/include/duckdb/planner/subquery/has_correlated_expressions.hpp index 6b238ffcc..81a097b49 100644 --- a/src/duckdb/src/include/duckdb/planner/subquery/has_correlated_expressions.hpp +++ b/src/duckdb/src/include/duckdb/planner/subquery/has_correlated_expressions.hpp @@ -16,7 +16,7 @@ namespace duckdb { //! Helper class to recursively detect correlated expressions inside a single LogicalOperator class HasCorrelatedExpressions : public LogicalOperatorVisitor { public: - explicit HasCorrelatedExpressions(const vector &correlated, bool lateral = false, + explicit HasCorrelatedExpressions(const CorrelatedColumns &correlated, bool lateral = false, idx_t lateral_depth = 0); void VisitOperator(LogicalOperator &op) override; @@ -28,7 +28,7 @@ class HasCorrelatedExpressions : public LogicalOperatorVisitor { unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override; unique_ptr VisitReplace(BoundSubqueryExpression &expr, unique_ptr *expr_ptr) override; - const vector &correlated_columns; + const CorrelatedColumns &correlated_columns; // Tracks number of nested laterals idx_t lateral_depth; }; diff --git a/src/duckdb/src/include/duckdb/planner/subquery/rewrite_cte_scan.hpp b/src/duckdb/src/include/duckdb/planner/subquery/rewrite_cte_scan.hpp index e2c507e73..72886f80e 100644 --- a/src/duckdb/src/include/duckdb/planner/subquery/rewrite_cte_scan.hpp +++ b/src/duckdb/src/include/duckdb/planner/subquery/rewrite_cte_scan.hpp @@ -17,13 +17,13 @@ namespace duckdb { //! 
Helper class to rewrite correlated cte scans within a single LogicalOperator class RewriteCTEScan : public LogicalOperatorVisitor { public: - RewriteCTEScan(idx_t table_index, const vector &correlated_columns); + RewriteCTEScan(idx_t table_index, const CorrelatedColumns &correlated_columns); void VisitOperator(LogicalOperator &op) override; private: idx_t table_index; - const vector &correlated_columns; + const CorrelatedColumns &correlated_columns; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/table_binding.hpp b/src/duckdb/src/include/duckdb/planner/table_binding.hpp index 9aedc7e70..836f52c41 100644 --- a/src/duckdb/src/include/duckdb/planner/table_binding.hpp +++ b/src/duckdb/src/include/duckdb/planner/table_binding.hpp @@ -17,6 +17,7 @@ #include "duckdb/planner/binding_alias.hpp" #include "duckdb/common/column_index.hpp" #include "duckdb/common/table_column.hpp" +#include "duckdb/planner/bound_statement.hpp" namespace duckdb { class BindContext; @@ -26,30 +27,16 @@ class SubqueryRef; class LogicalGet; class TableCatalogEntry; class TableFunctionCatalogEntry; -class BoundTableFunction; class StandardEntry; struct ColumnBinding; -enum class BindingType { BASE, TABLE, DUMMY, CATALOG_ENTRY }; +enum class BindingType { BASE, TABLE, DUMMY, CATALOG_ENTRY, CTE }; //! A Binding represents a binding to a table, table-producing function or subquery with a specified table index. struct Binding { Binding(BindingType binding_type, BindingAlias alias, vector types, vector names, idx_t index); virtual ~Binding() = default; - //! The type of Binding - BindingType binding_type; - //! The alias of the binding - BindingAlias alias; - //! The table index of the binding - idx_t index; - //! The types of the bound columns - vector types; - //! Column names of the subquery - vector names; - //! Name -> index for the names - case_insensitive_map_t name_map; - public: bool TryGetBindingIndex(const string &column_name, column_t &column_index); column_t GetBindingIndex(const string &column_name); @@ -59,6 +46,14 @@ struct Binding { virtual optional_ptr GetStandardEntry(); string GetAlias() const; + BindingType GetBindingType(); + const BindingAlias &GetBindingAlias(); + idx_t GetIndex(); + const vector &GetColumnTypes(); + const vector &GetColumnNames(); + idx_t GetColumnCount(); + void SetColumnType(idx_t col_idx, LogicalType type); + static BindingAlias GetAlias(const string &explicit_alias, const StandardEntry &entry); static BindingAlias GetAlias(const string &explicit_alias, optional_ptr entry); @@ -78,6 +73,23 @@ struct Binding { } return reinterpret_cast(*this); } + +protected: + void Initialize(); + +protected: + //! The type of Binding + BindingType binding_type; + //! The alias of the binding + BindingAlias alias; + //! The table index of the binding + idx_t index; + //! The types of the bound columns + vector types; + //! Column names of the subquery + vector names; + //! 
Name -> index for the names + case_insensitive_map_t name_map; }; struct EntryBinding : public Binding { @@ -149,4 +161,44 @@ struct DummyBinding : public Binding { unique_ptr ParamToArg(ColumnRefExpression &col_ref); }; +enum class CTEType { CAN_BE_REFERENCED, CANNOT_BE_REFERENCED }; +struct CTEBinding; + +struct CTEBindState { + CTEBindState(Binder &parent_binder, QueryNode &cte_def, const vector &aliases); + ~CTEBindState(); + + Binder &parent_binder; + QueryNode &cte_def; + const vector &aliases; + idx_t active_binder_count; + shared_ptr query_binder; + BoundStatement query; + vector names; + vector types; + +public: + bool IsBound() const; + void Bind(CTEBinding &binding); +}; + +struct CTEBinding : public Binding { +public: + static constexpr const BindingType TYPE = BindingType::CTE; + +public: + CTEBinding(BindingAlias alias, vector types, vector names, idx_t index, CTEType type); + CTEBinding(BindingAlias alias, shared_ptr bind_state, idx_t index); + +public: + bool CanBeReferenced() const; + bool IsReferenced() const; + void Reference(); + +private: + CTEType cte_type; + idx_t reference_count; + shared_ptr bind_state; +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_basetableref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_basetableref.hpp deleted file mode 100644 index b1f7f6f46..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_basetableref.hpp +++ /dev/null @@ -1,30 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_basetableref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/planner/logical_operator.hpp" - -namespace duckdb { -class TableCatalogEntry; - -//! Represents a TableReference to a base table in the schema -class BoundBaseTableRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::BASE_TABLE; - -public: - BoundBaseTableRef(TableCatalogEntry &table, unique_ptr get) - : BoundTableRef(TableReferenceType::BASE_TABLE), table(table), get(std::move(get)) { - } - - TableCatalogEntry &table; - unique_ptr get; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_column_data_ref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_column_data_ref.hpp deleted file mode 100644 index 025bc4712..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_column_data_ref.hpp +++ /dev/null @@ -1,30 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_column_data_ref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/common/optionally_owned_ptr.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" - -namespace duckdb { -//! Represents a TableReference to a base table in the schema -class BoundColumnDataRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::COLUMN_DATA; - -public: - explicit BoundColumnDataRef(optionally_owned_ptr collection) - : BoundTableRef(TableReferenceType::COLUMN_DATA), collection(std::move(collection)) { - } - //! 
The (optionally owned) materialized column data to scan - optionally_owned_ptr collection; - //! The index in the bind context - idx_t bind_index; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_cteref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_cteref.hpp deleted file mode 100644 index 781402fbe..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_cteref.hpp +++ /dev/null @@ -1,40 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_cteref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/common/enums/cte_materialize.hpp" - -namespace duckdb { - -class BoundCTERef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::CTE; - -public: - BoundCTERef(idx_t bind_index, idx_t cte_index) - : BoundTableRef(TableReferenceType::CTE), bind_index(bind_index), cte_index(cte_index) { - } - - BoundCTERef(idx_t bind_index, idx_t cte_index, bool is_recurring) - : BoundTableRef(TableReferenceType::CTE), bind_index(bind_index), cte_index(cte_index), - is_recurring(is_recurring) { - } - //! The set of columns bound to this base table reference - vector bound_columns; - //! The types of the values list - vector types; - //! The index in the bind context - idx_t bind_index; - //! The index of the cte - idx_t cte_index; - //! Is this a reference to the recurring table of a CTE - bool is_recurring = false; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_delimgetref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_delimgetref.hpp deleted file mode 100644 index 7b1022482..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_delimgetref.hpp +++ /dev/null @@ -1,26 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_delimgetref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" - -namespace duckdb { - -class BoundDelimGetRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::DELIM_GET; - -public: - BoundDelimGetRef(idx_t bind_index, const vector &column_types_p) - : BoundTableRef(TableReferenceType::DELIM_GET), bind_index(bind_index), column_types(column_types_p) { - } - idx_t bind_index; - vector column_types; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_dummytableref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_dummytableref.hpp deleted file mode 100644 index 3a68f5166..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_dummytableref.hpp +++ /dev/null @@ -1,26 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_dummytableref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" - -namespace duckdb { - -//! 
Represents a cross product -class BoundEmptyTableRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::EMPTY_FROM; - -public: - explicit BoundEmptyTableRef(idx_t bind_index) - : BoundTableRef(TableReferenceType::EMPTY_FROM), bind_index(bind_index) { - } - idx_t bind_index; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_expressionlistref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_expressionlistref.hpp deleted file mode 100644 index 7fc563dda..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_expressionlistref.hpp +++ /dev/null @@ -1,33 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_expressionlistref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/planner/expression.hpp" - -namespace duckdb { -//! Represents a TableReference to a base table in the schema -class BoundExpressionListRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::EXPRESSION_LIST; - -public: - BoundExpressionListRef() : BoundTableRef(TableReferenceType::EXPRESSION_LIST) { - } - - //! The bound VALUES list - vector>> values; - //! The generated names of the values list - vector names; - //! The types of the values list - vector types; - //! The index in the bind context - idx_t bind_index; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp index 38c83c95f..87976ba30 100644 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp @@ -11,19 +11,14 @@ #include "duckdb/planner/binder.hpp" #include "duckdb/common/enums/join_type.hpp" #include "duckdb/common/enums/joinref_type.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/expression.hpp" namespace duckdb { //! Represents a join -class BoundJoinRef : public BoundTableRef { +class BoundJoinRef { public: - static constexpr const TableReferenceType TYPE = TableReferenceType::JOIN; - -public: - explicit BoundJoinRef(JoinRefType ref_type) - : BoundTableRef(TableReferenceType::JOIN), type(JoinType::INNER), ref_type(ref_type), lateral(false) { + explicit BoundJoinRef(JoinRefType ref_type) : type(JoinType::INNER), ref_type(ref_type), lateral(false) { } //! The binder used to bind the LHS of the join @@ -31,9 +26,9 @@ class BoundJoinRef : public BoundTableRef { //! The binder used to bind the RHS of the join shared_ptr right_binder; //! The left hand side of the join - unique_ptr left; + BoundStatement left; //! The right hand side of the join - unique_ptr right; + BoundStatement right; //! The join condition unique_ptr condition; //! Duplicate Eliminated Columns (if any) @@ -47,7 +42,7 @@ class BoundJoinRef : public BoundTableRef { //! Whether or not this is a lateral join bool lateral; //! The correlated columns of the right-side with the left-side - vector correlated_columns; + CorrelatedColumns correlated_columns; //! 
The mark index, for mark joins generated by the relational API idx_t mark_index {}; }; diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_pivotref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_pivotref.hpp index 3219f6307..5a2d68aa1 100644 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_pivotref.hpp +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_pivotref.hpp @@ -9,7 +9,6 @@ #pragma once #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/expression.hpp" #include "duckdb/parser/tableref/pivotref.hpp" #include "duckdb/function/aggregate_function.hpp" @@ -30,19 +29,13 @@ struct BoundPivotInfo { static BoundPivotInfo Deserialize(Deserializer &deserializer); }; -class BoundPivotRef : public BoundTableRef { +class BoundPivotRef { public: - static constexpr const TableReferenceType TYPE = TableReferenceType::PIVOT; - -public: - explicit BoundPivotRef() : BoundTableRef(TableReferenceType::PIVOT) { - } - idx_t bind_index; //! The binder used to bind the child of the pivot shared_ptr child_binder; //! The child node of the pivot - unique_ptr child; + BoundStatement child; //! The bound pivot info BoundPivotInfo bound_pivot; }; diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_pos_join_ref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_pos_join_ref.hpp deleted file mode 100644 index 4cb057e41..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_pos_join_ref.hpp +++ /dev/null @@ -1,38 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_pos_join_ref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_tableref.hpp" - -namespace duckdb { - -//! Represents a positional join -class BoundPositionalJoinRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::POSITIONAL_JOIN; - -public: - BoundPositionalJoinRef() : BoundTableRef(TableReferenceType::POSITIONAL_JOIN), lateral(false) { - } - - //! The binder used to bind the LHS of the positional join - shared_ptr left_binder; - //! The binder used to bind the RHS of the positional join - shared_ptr right_binder; - //! The left hand side of the positional join - unique_ptr left; - //! The right hand side of the positional join - unique_ptr right; - //! Whether or not this is a lateral positional join - bool lateral; - //! The correlated columns of the right-side with the left-side - vector correlated_columns; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp deleted file mode 100644 index 2d1061c98..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp +++ /dev/null @@ -1,32 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_subqueryref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_query_node.hpp" -#include "duckdb/planner/bound_tableref.hpp" - -namespace duckdb { - -//! 
Represents a cross product -class BoundSubqueryRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::SUBQUERY; - -public: - BoundSubqueryRef(shared_ptr binder_p, unique_ptr subquery) - : BoundTableRef(TableReferenceType::SUBQUERY), binder(std::move(binder_p)), subquery(std::move(subquery)) { - } - - //! The binder used to bind the subquery - shared_ptr binder; - //! The bound subquery node (if any) - unique_ptr subquery; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_table_function.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_table_function.hpp deleted file mode 100644 index 6aafe2b36..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_table_function.hpp +++ /dev/null @@ -1,31 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_table_function.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/planner/logical_operator.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" - -namespace duckdb { - -//! Represents a reference to a table-producing function call -class BoundTableFunction : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::TABLE_FUNCTION; - -public: - explicit BoundTableFunction(unique_ptr get) - : BoundTableRef(TableReferenceType::TABLE_FUNCTION), get(std::move(get)) { - } - - unique_ptr get; - unique_ptr subquery; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/list.hpp b/src/duckdb/src/include/duckdb/planner/tableref/list.hpp index 79a00ce62..dbc8394df 100644 --- a/src/duckdb/src/include/duckdb/planner/tableref/list.hpp +++ b/src/duckdb/src/include/duckdb/planner/tableref/list.hpp @@ -1,11 +1,2 @@ -#include "duckdb/planner/tableref/bound_basetableref.hpp" -#include "duckdb/planner/tableref/bound_cteref.hpp" -#include "duckdb/planner/tableref/bound_dummytableref.hpp" -#include "duckdb/planner/tableref/bound_expressionlistref.hpp" #include "duckdb/planner/tableref/bound_joinref.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" -#include "duckdb/planner/tableref/bound_column_data_ref.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" #include "duckdb/planner/tableref/bound_pivotref.hpp" -#include "duckdb/parser/tableref/delimgetref.hpp" -#include "duckdb/planner/tableref/bound_delimgetref.hpp" diff --git a/src/duckdb/src/include/duckdb/storage/arena_allocator.hpp b/src/duckdb/src/include/duckdb/storage/arena_allocator.hpp index 687e4af9b..6cf150366 100644 --- a/src/duckdb/src/include/duckdb/storage/arena_allocator.hpp +++ b/src/duckdb/src/include/duckdb/storage/arena_allocator.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/allocator.hpp" #include "duckdb/common/common.hpp" #include "duckdb/common/types/string.hpp" +#include "duckdb/common/arena_containers/arena_ptr.hpp" namespace duckdb { @@ -84,6 +85,16 @@ class ArenaAllocator { return new (mem) T(std::forward(args)...); } + template + arena_ptr MakePtr(ARGS &&... args) { + return arena_ptr(Make(std::forward(args)...)); + } + + template + unsafe_arena_ptr MakeUnsafePtr(ARGS &&... 
args) { + return unsafe_arena_ptr(Make(std::forward(args)...)); + } + String MakeString(const char *data, const size_t len) { data_ptr_t mem = nullptr; diff --git a/src/duckdb/src/include/duckdb/storage/block.hpp b/src/duckdb/src/include/duckdb/storage/block.hpp index 3aa18a7bc..12fd7f818 100644 --- a/src/duckdb/src/include/duckdb/storage/block.hpp +++ b/src/duckdb/src/include/duckdb/storage/block.hpp @@ -61,6 +61,15 @@ struct MetaBlockPointer { block_id_t GetBlockId() const; uint32_t GetBlockIndex() const; + bool operator==(const MetaBlockPointer &rhs) const { + return block_pointer == rhs.block_pointer && offset == rhs.offset; + } + + friend std::ostream &operator<<(std::ostream &os, const MetaBlockPointer &obj) { + return os << "{block_id: " << obj.GetBlockId() << " index: " << obj.GetBlockIndex() << " offset: " << obj.offset + << "}"; + } + void Serialize(Serializer &serializer) const; static MetaBlockPointer Deserialize(Deserializer &source); }; diff --git a/src/duckdb/src/include/duckdb/storage/block_manager.hpp b/src/duckdb/src/include/duckdb/storage/block_manager.hpp index 0fd9df675..ef6f1941e 100644 --- a/src/duckdb/src/include/duckdb/storage/block_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/block_manager.hpp @@ -24,6 +24,8 @@ class ClientContext; class DatabaseInstance; class MetadataManager; +enum class ConvertToPersistentMode { DESTRUCTIVE, THREAD_SAFE }; + //! BlockManager is an abstract representation to manage blocks on DuckDB. When writing or reading blocks, the //! BlockManager creates and accesses blocks. The concrete types implement specific block storage strategies. class BlockManager { @@ -37,6 +39,9 @@ class BlockManager { BufferManager &buffer_manager; public: + BufferManager &GetBufferManager() const { + return buffer_manager; + } //! Creates a new block inside the block manager virtual unique_ptr ConvertBlock(block_id_t block_id, FileBuffer &source_buffer) = 0; virtual unique_ptr CreateBlock(block_id_t block_id, FileBuffer *source_buffer) = 0; @@ -95,10 +100,15 @@ class BlockManager { //! Register a block with the given block id in the base file shared_ptr RegisterBlock(block_id_t block_id); //! Convert an existing in-memory buffer into a persistent disk-backed block + //! If mode is set to destructive (default) - the old_block will be destroyed as part of this method + //! This can only be safely used when there is no other (lingering) usage of old_block + //! If there is concurrent usage of the block elsewhere - use the THREAD_SAFE mode which creates an extra copy shared_ptr ConvertToPersistent(QueryContext context, block_id_t block_id, - shared_ptr old_block, BufferHandle old_handle); + shared_ptr old_block, BufferHandle old_handle, + ConvertToPersistentMode mode = ConvertToPersistentMode::DESTRUCTIVE); shared_ptr ConvertToPersistent(QueryContext context, block_id_t block_id, - shared_ptr old_block); + shared_ptr old_block, + ConvertToPersistentMode mode = ConvertToPersistentMode::DESTRUCTIVE); void UnregisterBlock(BlockHandle &block); //! UnregisterBlock, only accepts non-temporary block ids diff --git a/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp b/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp index 619e89a5a..3d4f5e595 100644 --- a/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp @@ -58,6 +58,7 @@ class BufferManager { virtual void ReAllocate(shared_ptr &handle, idx_t block_size) = 0; //! Pin a block handle. 
virtual BufferHandle Pin(shared_ptr &handle) = 0; + virtual BufferHandle Pin(const QueryContext &context, shared_ptr &handle) = 0; //! Pre-fetch a series of blocks. //! Using this function is a performance suggestion. virtual void Prefetch(vector> &handles) = 0; @@ -100,6 +101,8 @@ class BufferManager { //! Set a new swap limit. virtual void SetSwapLimit(optional_idx limit = optional_idx()); + //! Get the block manager used for in-memory data + virtual BlockManager &GetTemporaryBlockManager() = 0; //! Get the temporary file information of each temporary file. virtual vector GetTemporaryFiles(); //! Get the path to the temporary file directory. diff --git a/src/duckdb/src/include/duckdb/storage/caching_file_system.hpp b/src/duckdb/src/include/duckdb/storage/caching_file_system.hpp index 99e949418..93cec496d 100644 --- a/src/duckdb/src/include/duckdb/storage/caching_file_system.hpp +++ b/src/duckdb/src/include/duckdb/storage/caching_file_system.hpp @@ -61,7 +61,8 @@ struct CachingFileHandle { //! Tries to read from the cache, filling "overlapping_ranges" with ranges that overlap with the request. //! Returns an invalid BufferHandle if it fails BufferHandle TryReadFromCache(data_ptr_t &buffer, idx_t nr_bytes, idx_t location, - vector> &overlapping_ranges); + vector> &overlapping_ranges, + optional_idx &start_location_of_next_range); //! Try to read from the specified range, return an invalid BufferHandle if it fails BufferHandle TryReadFromFileRange(const unique_ptr &guard, CachedFileRange &file_range, data_ptr_t &buffer, idx_t nr_bytes, idx_t location); diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_analyze.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_analyze.hpp index bac590d0e..13eecf42b 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_analyze.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_analyze.hpp @@ -82,7 +82,12 @@ unique_ptr AlpInitAnalyze(ColumnData &col_data, PhysicalType type) */ template bool AlpAnalyze(AnalyzeState &state, Vector &input, idx_t count) { - auto &analyze_state = (AlpAnalyzeState &)state; + if (state.info.GetBlockSize() + state.info.GetBlockHeaderSize() < DEFAULT_BLOCK_ALLOC_SIZE) { + return false; + } + + auto &analyze_state = state.Cast>(); + bool must_skip_current_vector = alp::AlpUtils::MustSkipSamplingFromCurrentVector( analyze_state.vectors_count, analyze_state.vectors_sampled_count, count); analyze_state.vectors_count += 1; diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_compress.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_compress.hpp index e08b8b5bb..38ed76769 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_compress.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_compress.hpp @@ -28,7 +28,6 @@ namespace duckdb { template struct AlpCompressionState : public CompressionState { - public: using EXACT_TYPE = typename FloatingToExact::TYPE; diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_constants.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_constants.hpp index cf9766177..6774ab47e 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_constants.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_constants.hpp @@ -66,7 +66,6 @@ struct AlpTypedConstants {}; template <> struct AlpTypedConstants { - static constexpr float MAGIC_NUMBER = 12582912.0; //! 
2^22 + 2^23 static constexpr uint8_t MAX_EXPONENT = 10; @@ -80,7 +79,6 @@ struct AlpTypedConstants { template <> struct AlpTypedConstants { - static constexpr double MAGIC_NUMBER = 6755399441055744.0; //! 2^51 + 2^52 static constexpr uint8_t MAX_EXPONENT = 18; //! 10^18 is the maximum int64 diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp index 28b52b848..8c7d12e67 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp @@ -201,7 +201,7 @@ struct AlpScanState : public SegmentScanState { }; template -unique_ptr AlpInitScan(ColumnSegment &segment) { +unique_ptr AlpInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq_base>(segment); return result; } diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_utils.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_utils.hpp index b9c0e6eab..1fa3b3664 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_utils.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_utils.hpp @@ -36,7 +36,6 @@ class AlpUtils { public: static AlpSamplingParameters GetSamplingParameters(idx_t current_vector_n_values) { - auto n_lookup_values = NumericCast(MinValue(current_vector_n_values, (idx_t)AlpConstants::ALP_VECTOR_SIZE)); //! We sample equidistant values within a vector; to do this we jump a fixed number of values diff --git a/src/duckdb/src/include/duckdb/storage/compression/alprd/algorithm/alprd.hpp b/src/duckdb/src/include/duckdb/storage/compression/alprd/algorithm/alprd.hpp index 1dac66f4a..47ebddc57 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alprd/algorithm/alprd.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alprd/algorithm/alprd.hpp @@ -152,7 +152,6 @@ struct AlpRDCompression { } static void Compress(const EXACT_TYPE *input_vector, idx_t n_values, State &state) { - uint64_t right_parts[AlpRDConstants::ALP_VECTOR_SIZE]; uint16_t left_parts[AlpRDConstants::ALP_VECTOR_SIZE]; @@ -207,7 +206,6 @@ struct AlpRDDecompression { EXACT_TYPE *output, idx_t values_count, uint16_t exceptions_count, const uint16_t *exceptions, const uint16_t *exceptions_positions, uint8_t left_bit_width, uint8_t right_bit_width) { - uint8_t left_decoded[AlpRDConstants::ALP_VECTOR_SIZE * 8] = {0}; uint8_t right_decoded[AlpRDConstants::ALP_VECTOR_SIZE * 8] = {0}; diff --git a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp index 25901667e..da7f8bda0 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp @@ -47,8 +47,12 @@ unique_ptr AlpRDInitAnalyze(ColumnData &col_data, PhysicalType typ */ template bool AlpRDAnalyze(AnalyzeState &state, Vector &input, idx_t count) { + if (state.info.GetBlockSize() + state.info.GetBlockHeaderSize() < DEFAULT_BLOCK_ALLOC_SIZE) { + return false; + } + using EXACT_TYPE = typename FloatingToExact::TYPE; - auto &analyze_state = (AlpRDAnalyzeState &)state; + auto &analyze_state = state.Cast>(); bool must_skip_current_vector = alp::AlpUtils::MustSkipSamplingFromCurrentVector( analyze_state.vectors_count, analyze_state.vectors_sampled_count, count); diff --git a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_compress.hpp 
b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_compress.hpp index 86559d604..2eff3f97d 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_compress.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_compress.hpp @@ -30,7 +30,6 @@ namespace duckdb { template struct AlpRDCompressionState : public CompressionState { - public: using EXACT_TYPE = typename FloatingToExact::TYPE; diff --git a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp index a3feb94b5..520d38fa2 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp @@ -208,7 +208,7 @@ struct AlpRDScanState : public SegmentScanState { }; template -unique_ptr AlpRDInitScan(ColumnSegment &segment) { +unique_ptr AlpRDInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq_base>(segment); return result; } diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/chimp128.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/chimp128.hpp index 277e30b6a..a2b0566e1 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/chimp128.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/chimp128.hpp @@ -31,7 +31,6 @@ namespace duckdb { template struct Chimp128CompressionState { - Chimp128CompressionState() : ring_buffer(), previous_leading_zeros(NumericLimits::Maximum()) { previous_value = 0; } @@ -104,7 +103,6 @@ class Chimp128Compression { } static void CompressValue(CHIMP_TYPE in, State &state) { - auto key = state.ring_buffer.Key(in); CHIMP_TYPE xor_result; uint8_t previous_index; diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp index f5c0d70ba..7a38c065e 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp @@ -34,7 +34,6 @@ struct FlagBufferConstants { // So we can just read/write from left to right template class FlagBuffer { - public: FlagBuffer() : counter(0), buffer(nullptr) { } diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/leading_zero_buffer.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/leading_zero_buffer.hpp index c4b23cfd9..23376dc1f 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/leading_zero_buffer.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/leading_zero_buffer.hpp @@ -40,7 +40,6 @@ struct LeadingZeroBufferConstants { template class LeadingZeroBuffer { - public: static constexpr uint32_t CHIMP_GROUP_SIZE = 1024; static constexpr uint32_t LEADING_ZERO_BITS_SIZE = 3; diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp index de11979cb..00b27f9e8 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp @@ -185,7 +185,6 @@ struct ChimpScanState : public SegmentScanState { } void LoadGroup(CHIMP_TYPE *value_buffer) { - //! 
FIXME: If we change the order of this to flag -> leading_zero_blocks -> packed_data //! We can leave out the leading zero block count as well, because it can be derived from //! Extracting all the flags and counting the 3's @@ -252,7 +251,7 @@ struct ChimpScanState : public SegmentScanState { }; template -unique_ptr ChimpInitScan(ColumnSegment &segment) { +unique_ptr ChimpInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq_base>(segment); return result; } diff --git a/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/compression.hpp b/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/compression.hpp index 1bc613a0a..d8718eb36 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/compression.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/compression.hpp @@ -1,5 +1,6 @@ #pragma once +#include "duckdb/common/primitive_dictionary.hpp" #include "duckdb/common/typedefs.hpp" #include "duckdb/storage/compression/dict_fsst/common.hpp" #include "duckdb/storage/compression/dict_fsst/analyze.hpp" @@ -75,7 +76,7 @@ struct DictFSSTCompressionState : public CompressionState { bitpacking_width_t dictionary_indices_width = 0; //! string -> dictionary_index (for lookups) - string_map_t current_string_map; + PrimitiveDictionary current_string_map; //! strings added to the dictionary waiting to be encoded vector dictionary_encoding_buffer; idx_t to_encode_string_sum = 0; diff --git a/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/decompression.hpp b/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/decompression.hpp index 032370f86..1cb377baf 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/decompression.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/decompression.hpp @@ -59,7 +59,7 @@ struct CompressedStringScanState : public SegmentScanState { data_ptr_t dictionary_indices_ptr; data_ptr_t string_lengths_ptr; - buffer_ptr dictionary; + buffer_ptr dictionary; void *decoder = nullptr; bool all_values_inlined = false; diff --git a/src/duckdb/src/include/duckdb/storage/compression/dictionary/analyze.hpp b/src/duckdb/src/include/duckdb/storage/compression/dictionary/analyze.hpp index 99eb72156..14565e476 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/dictionary/analyze.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/dictionary/analyze.hpp @@ -21,11 +21,14 @@ struct DictionaryAnalyzeState : public DictionaryCompressionState { bool CalculateSpaceRequirements(bool new_string, idx_t string_size) override; void Flush(bool final = false) override; void Verify() override; + void UpdateMaxUniqueCount(); public: idx_t segment_count; idx_t current_tuple_count; idx_t current_unique_count; + idx_t max_unique_count_across_segments = + 0; // Is used to allocate the dictionary optimally later on at the InitCompression step idx_t current_dict_size; StringHeap heap; string_set_t current_set; diff --git a/src/duckdb/src/include/duckdb/storage/compression/dictionary/compression.hpp b/src/duckdb/src/include/duckdb/storage/compression/dictionary/compression.hpp index 09f1f44bd..97f33bf24 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/dictionary/compression.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/dictionary/compression.hpp @@ -1,5 +1,6 @@ #pragma once +#include "duckdb/common/primitive_dictionary.hpp" #include "duckdb/common/typedefs.hpp" #include "duckdb/storage/compression/dictionary/common.hpp" #include 
"duckdb/function/compression_function.hpp" @@ -23,7 +24,8 @@ namespace duckdb { //===--------------------------------------------------------------------===// struct DictionaryCompressionCompressState : public DictionaryCompressionState { public: - DictionaryCompressionCompressState(ColumnDataCheckpointData &checkpoint_data_p, const CompressionInfo &info); + DictionaryCompressionCompressState(ColumnDataCheckpointData &checkpoint_data_p, const CompressionInfo &info, + idx_t max_unique_count_across_all_segments); public: void CreateEmptySegment(idx_t row_start); @@ -47,7 +49,7 @@ struct DictionaryCompressionCompressState : public DictionaryCompressionState { data_ptr_t current_end_ptr; // Buffers and map for current segment - string_map_t current_string_map; + PrimitiveDictionary current_string_map; vector index_buffer; vector selection_buffer; diff --git a/src/duckdb/src/include/duckdb/storage/compression/dictionary/decompression.hpp b/src/duckdb/src/include/duckdb/storage/compression/dictionary/decompression.hpp index 1656ec718..e7381f11a 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/dictionary/decompression.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/dictionary/decompression.hpp @@ -41,7 +41,7 @@ struct CompressedStringScanState : public StringScanState { uint32_t *index_buffer_ptr; uint32_t index_buffer_count; - buffer_ptr dictionary; + buffer_ptr dictionary; idx_t dictionary_size; StringDictionaryContainer dict; idx_t block_size; diff --git a/src/duckdb/src/include/duckdb/storage/compression/empty_validity.hpp b/src/duckdb/src/include/duckdb/storage/compression/empty_validity.hpp index 1118f77f2..476faf89f 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/empty_validity.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/empty_validity.hpp @@ -77,7 +77,7 @@ class EmptyValidityCompression { auto &checkpoint_state = checkpoint_data.GetCheckpointState(); checkpoint_state.FlushSegment(std::move(compressed_segment), std::move(handle), 0); } - static unique_ptr InitScan(ColumnSegment &segment) { + static unique_ptr InitScan(const QueryContext &context, ColumnSegment &segment) { return make_uniq(); } static void ScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, diff --git a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp index b523600e3..4261d2d23 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp @@ -204,7 +204,7 @@ struct PatasScanState : public SegmentScanState { }; template -unique_ptr PatasInitScan(ColumnSegment &segment) { +unique_ptr PatasInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq_base>(segment); return result; } diff --git a/src/duckdb/src/include/duckdb/storage/data_pointer.hpp b/src/duckdb/src/include/duckdb/storage/data_pointer.hpp index 0ae44d7f3..54f6f239f 100644 --- a/src/duckdb/src/include/duckdb/storage/data_pointer.hpp +++ b/src/duckdb/src/include/duckdb/storage/data_pointer.hpp @@ -19,6 +19,7 @@ namespace duckdb { class Serializer; class Deserializer; +class QueryContext; struct ColumnSegmentState { virtual ~ColumnSegmentState() { diff --git a/src/duckdb/src/include/duckdb/storage/data_table.hpp b/src/duckdb/src/include/duckdb/storage/data_table.hpp index bc8727a18..bf355ed05 100644 --- 
a/src/duckdb/src/include/duckdb/storage/data_table.hpp +++ b/src/duckdb/src/include/duckdb/storage/data_table.hpp @@ -196,7 +196,7 @@ class DataTable : public enable_shared_from_this { //! Remove the chunk with the specified set of row identifiers from all indexes of the table void RemoveFromIndexes(TableAppendState &state, DataChunk &chunk, Vector &row_identifiers); //! Remove the row identifiers from all the indexes of the table - void RemoveFromIndexes(Vector &row_identifiers, idx_t count); + void RemoveFromIndexes(const QueryContext &context, Vector &row_identifiers, idx_t count); void SetAsMainTable() { this->version = DataTableVersion::MAIN_TABLE; @@ -234,7 +234,7 @@ class DataTable : public enable_shared_from_this { idx_t ColumnCount() const; idx_t GetTotalRows() const; - vector GetColumnSegmentInfo(); + vector GetColumnSegmentInfo(const QueryContext &context); //! Scans the next chunk for the CREATE INDEX operator bool CreateIndexScan(TableScanState &state, DataChunk &result, TableScanType type); diff --git a/src/duckdb/src/include/duckdb/storage/index.hpp b/src/duckdb/src/include/duckdb/storage/index.hpp index 2b624c2c1..492f37e29 100644 --- a/src/duckdb/src/include/duckdb/storage/index.hpp +++ b/src/duckdb/src/include/duckdb/storage/index.hpp @@ -31,9 +31,15 @@ class Index { protected: Index(const vector &column_ids, TableIOManager &table_io_manager, AttachedDatabase &db); - //! The logical column ids of the indexed table + //! The physical column ids of the indexed columns. + //! For example, given a table with the following columns: + //! (a INT, gen AS (2 * a), b INT, c VARCHAR), an index on columns (a,c) would have physical + //! column_ids [0,2] (since the virtual column is skipped in the physical representation). + //! Also see comments in bound_index.hpp to see how these column IDs are used in the context of + //! bound/unbound expressions. + //! Note that these are the columns for this Index, not all Indexes on the table. vector column_ids; - //! Unordered set of column_ids used by the index + //! Unordered set of column_ids used by the Index unordered_set column_id_set; public: diff --git a/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp b/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp index cd63a96b8..c2320c29d 100644 --- a/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp @@ -62,10 +62,14 @@ class MetadataManager { MetadataManager(BlockManager &block_manager, BufferManager &buffer_manager); ~MetadataManager(); + BufferManager &GetBufferManager() const { + return buffer_manager; + } + MetadataHandle AllocateHandle(); MetadataHandle Pin(const MetadataPointer &pointer); - MetadataHandle Pin(QueryContext context, const MetadataPointer &pointer); + MetadataHandle Pin(const QueryContext &context, const MetadataPointer &pointer); MetaBlockPointer GetDiskPointer(const MetadataPointer &pointer, uint32_t offset = 0); MetadataPointer FromDiskPointer(MetaBlockPointer pointer); @@ -77,6 +81,8 @@ class MetadataManager { //! 
Flush all blocks to disk void Flush(); + bool BlockHasBeenCleared(const MetaBlockPointer &ptr); + void MarkBlocksAsModified(); void ClearModifiedBlocks(const vector &pointers); diff --git a/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp b/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp index 51894886a..ce8d01b41 100644 --- a/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp +++ b/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp @@ -52,7 +52,7 @@ class MetadataReader : public ReadStream { MetadataManager &manager; BlockReaderType type; MetadataHandle block; - MetadataPointer next_pointer; + MetaBlockPointer next_pointer; bool has_next_block; optional_ptr> read_pointers; idx_t index; diff --git a/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp b/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp index 1ded8bba6..98d371437 100644 --- a/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp +++ b/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp @@ -14,11 +14,16 @@ namespace duckdb { class PartialBlockManager; struct OptimisticWriteCollection { + ~OptimisticWriteCollection(); + shared_ptr collection; idx_t last_flushed = 0; idx_t complete_row_groups = 0; + vector> partial_block_managers; }; +enum class OptimisticWritePartialManagers { PER_COLUMN, GLOBAL }; + class OptimisticDataWriter { public: OptimisticDataWriter(ClientContext &context, DataTable &table); @@ -26,8 +31,9 @@ class OptimisticDataWriter { ~OptimisticDataWriter(); //! Creates a collection to write to - static unique_ptr CreateCollection(DataTable &storage, - const vector &insert_types); + unique_ptr + CreateCollection(DataTable &storage, const vector &insert_types, + OptimisticWritePartialManagers type = OptimisticWritePartialManagers::PER_COLUMN); //! Write a new row group to disk (if possible) void WriteNewRowGroup(OptimisticWriteCollection &row_groups); //! Write the last row group of a collection to disk @@ -35,9 +41,10 @@ class OptimisticDataWriter { //! Final flush of the optimistic writer - fully flushes the partial block manager void FinalFlush(); //! Flushes a specific row group to disk - void FlushToDisk(const vector> &row_groups); + void FlushToDisk(OptimisticWriteCollection &collection, const vector> &row_groups); //! Merge the partially written blocks from one optimistic writer into another void Merge(OptimisticDataWriter &other); + void Merge(unique_ptr &other_manager); //! Rollback void Rollback(); diff --git a/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp b/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp index d0a54c597..ff4ce4684 100644 --- a/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp @@ -71,7 +71,7 @@ class StandardBufferManager : public BufferManager { void ReAllocate(shared_ptr &handle, idx_t block_size) final; BufferHandle Pin(shared_ptr &handle) final; - BufferHandle Pin(QueryContext context, shared_ptr &handle); + BufferHandle Pin(const QueryContext &context, shared_ptr &handle) final; void Prefetch(vector> &handles) final; void Unpin(shared_ptr &handle) final; @@ -84,6 +84,8 @@ class StandardBufferManager : public BufferManager { //! Returns information about memory usage vector GetMemoryUsageInfo() const override; + BlockManager &GetTemporaryBlockManager() final; + //! 
Returns a list of all temporary files vector GetTemporaryFiles() final; diff --git a/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp b/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp index 2101bcb31..e37879f61 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp @@ -15,6 +15,7 @@ #include "duckdb/common/types/value.hpp" #include "duckdb/storage/statistics/numeric_stats.hpp" #include "duckdb/storage/statistics/string_stats.hpp" +#include "duckdb/storage/statistics/geometry_stats.hpp" namespace duckdb { struct SelectionVector; @@ -33,7 +34,15 @@ enum class StatsInfo : uint8_t { CAN_HAVE_NULL_AND_VALID_VALUES = 4 }; -enum class StatisticsType : uint8_t { NUMERIC_STATS, STRING_STATS, LIST_STATS, STRUCT_STATS, BASE_STATS, ARRAY_STATS }; +enum class StatisticsType : uint8_t { + NUMERIC_STATS, + STRING_STATS, + LIST_STATS, + STRUCT_STATS, + BASE_STATS, + ARRAY_STATS, + GEOMETRY_STATS +}; class BaseStatistics { friend struct NumericStats; @@ -41,6 +50,7 @@ class BaseStatistics { friend struct StructStats; friend struct ListStats; friend struct ArrayStats; + friend struct GeometryStats; public: DUCKDB_API ~BaseStatistics(); @@ -146,6 +156,8 @@ class BaseStatistics { NumericStatsData numeric_data; //! String stats data, for string stats StringStatsData string_data; + //! Geometry stats data, for geometry stats + GeometryStatsData geometry_data; } stats_union; //! Child stats (for LIST and STRUCT) unsafe_unique_array child_stats; diff --git a/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp new file mode 100644 index 000000000..6c6cfa35a --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp @@ -0,0 +1,165 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/statistics/geometry_stats.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/enums/filter_propagate_result.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/array_ptr.hpp" +#include "duckdb/common/types/geometry.hpp" + +namespace duckdb { +class BaseStatistics; +struct SelectionVector; + +class GeometryTypeSet { +public: + static constexpr auto VERT_TYPES = 4; + static constexpr auto PART_TYPES = 8; + + static GeometryTypeSet Unknown() { + GeometryTypeSet result; + for (idx_t i = 0; i < VERT_TYPES; i++) { + result.sets[i] = 0xFF; + } + return result; + } + static GeometryTypeSet Empty() { + GeometryTypeSet result; + for (idx_t i = 0; i < VERT_TYPES; i++) { + result.sets[i] = 0; + } + return result; + } + + bool IsEmpty() const { + for (idx_t i = 0; i < VERT_TYPES; i++) { + if (sets[i] != 0) { + return false; + } + } + return true; + } + + bool IsUnknown() const { + for (idx_t i = 0; i < VERT_TYPES; i++) { + if (sets[i] != 0xFF) { + return false; + } + } + return true; + } + + void Add(GeometryType geom_type, VertexType vert_type) { + const auto vert_idx = static_cast(vert_type); + const auto geom_idx = static_cast(geom_type); + D_ASSERT(vert_idx < VERT_TYPES); + D_ASSERT(geom_idx < PART_TYPES); + sets[vert_idx] |= (1 << 
geom_idx); + } + + void Merge(const GeometryTypeSet &other) { + for (idx_t i = 0; i < VERT_TYPES; i++) { + sets[i] |= other.sets[i]; + } + } + + void Clear() { + for (idx_t i = 0; i < VERT_TYPES; i++) { + sets[i] = 0; + } + } + + void AddWKBType(int32_t wkb_type) { + const auto vert_idx = static_cast((wkb_type / 1000) % 10); + const auto geom_idx = static_cast(wkb_type % 1000); + D_ASSERT(vert_idx < VERT_TYPES); + D_ASSERT(geom_idx < PART_TYPES); + sets[vert_idx] |= (1 << geom_idx); + } + + vector ToWKBList() const { + vector result; + for (uint8_t vert_idx = 0; vert_idx < VERT_TYPES; vert_idx++) { + for (uint8_t geom_idx = 1; geom_idx < PART_TYPES; geom_idx++) { + if (sets[vert_idx] & (1 << geom_idx)) { + result.push_back(geom_idx + vert_idx * 1000); + } + } + } + return result; + } + + vector ToString(bool snake_case) const; + + uint8_t sets[VERT_TYPES]; +}; + +struct GeometryStatsData { + GeometryTypeSet types; + GeometryExtent extent; + + void SetEmpty() { + types = GeometryTypeSet::Empty(); + extent = GeometryExtent::Empty(); + } + + void SetUnknown() { + types = GeometryTypeSet::Unknown(); + extent = GeometryExtent::Unknown(); + } + + void Merge(const GeometryStatsData &other) { + types.Merge(other.types); + extent.Merge(other.extent); + } + + void Update(const string_t &geom_blob) { + // Parse type + const auto type_info = Geometry::GetType(geom_blob); + types.Add(type_info.first, type_info.second); + + // Update extent + Geometry::GetExtent(geom_blob, extent); + } +}; + +struct GeometryStats { + //! Unknown statistics + DUCKDB_API static BaseStatistics CreateUnknown(LogicalType type); + //! Empty statistics + DUCKDB_API static BaseStatistics CreateEmpty(LogicalType type); + + DUCKDB_API static void Serialize(const BaseStatistics &stats, Serializer &serializer); + DUCKDB_API static void Deserialize(Deserializer &deserializer, BaseStatistics &base); + + DUCKDB_API static string ToString(const BaseStatistics &stats); + + DUCKDB_API static void Update(BaseStatistics &stats, const string_t &value); + DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other); + DUCKDB_API static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); + + //! 
Check if a spatial predicate check with a constant could possibly be satisfied by rows given the statistics + DUCKDB_API static FilterPropagateResult CheckZonemap(const BaseStatistics &stats, + const unique_ptr &expr); + + DUCKDB_API static GeometryExtent &GetExtent(BaseStatistics &stats); + DUCKDB_API static const GeometryExtent &GetExtent(const BaseStatistics &stats); + DUCKDB_API static GeometryTypeSet &GetTypes(BaseStatistics &stats); + DUCKDB_API static const GeometryTypeSet &GetTypes(const BaseStatistics &stats); + +private: + static GeometryStatsData &GetDataUnsafe(BaseStatistics &stats); + static const GeometryStatsData &GetDataUnsafe(const BaseStatistics &stats); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp index 0982f8905..6e5814a36 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp @@ -71,6 +71,8 @@ struct StringStats { ExpressionType comparison_type, const string &value); DUCKDB_API static void Update(BaseStatistics &stats, const string_t &value); + DUCKDB_API static void SetMin(BaseStatistics &stats, const string_t &value); + DUCKDB_API static void SetMax(BaseStatistics &stats, const string_t &value); DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other); DUCKDB_API static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); diff --git a/src/duckdb/src/include/duckdb/storage/storage_manager.hpp b/src/duckdb/src/include/duckdb/storage/storage_manager.hpp index c96a76ff7..692efd311 100644 --- a/src/duckdb/src/include/duckdb/storage/storage_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/storage_manager.hpp @@ -156,7 +156,8 @@ class StorageManager { bool load_complete = false; //! The serialization compatibility version when reading and writing from this database optional_idx storage_version; - //! Estimated size of changes for determining automatic checkpointing on in-memory databases + //! Estimated size of changes for determining automatic checkpointing on in-memory databases and databases without a + //! WAL. atomic in_memory_change_size; //! 
Storage options passed in through configuration StorageOptions storage_options; diff --git a/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp b/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp index b5342829c..755e99339 100644 --- a/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp +++ b/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp @@ -67,7 +67,7 @@ struct UncompressedStringStorage { static unique_ptr StringInitAnalyze(ColumnData &col_data, PhysicalType type); static bool StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count); static idx_t StringFinalAnalyze(AnalyzeState &state_p); - static unique_ptr StringInitScan(ColumnSegment &segment); + static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); static void StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result); @@ -201,7 +201,11 @@ struct UncompressedStringStorage { public: static inline void UpdateStringStats(SegmentStatistics &stats, const string_t &new_value) { - StringStats::Update(stats.statistics, new_value); + if (stats.statistics.GetStatsType() == StatisticsType::GEOMETRY_STATS) { + GeometryStats::Update(stats.statistics, new_value); + } else { + StringStats::Update(stats.statistics, new_value); + } } static void SetDictionary(ColumnSegment &segment, BufferHandle &handle, StringDictionaryContainer dict); diff --git a/src/duckdb/src/include/duckdb/storage/table/append_state.hpp b/src/duckdb/src/include/duckdb/storage/table/append_state.hpp index 0a5c7b170..203263f70 100644 --- a/src/duckdb/src/include/duckdb/storage/table/append_state.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/append_state.hpp @@ -24,12 +24,14 @@ class LocalTableStorage; class RowGroup; class UpdateSegment; class TableCatalogEntry; +template +struct SegmentNode; struct TableAppendState; struct ColumnAppendState { //! The current segment of the append - ColumnSegment *current; + optional_ptr> current; //! Child append states vector child_appends; //! The write lock that is held by the append @@ -67,7 +69,7 @@ struct TableAppendState { //! The total number of rows appended by the append operation idx_t total_append_count; //! The first row-group that has been appended to - RowGroup *start_row_group; + optional_ptr> start_row_group; //! The transaction data TransactionData transaction; //! 
Table statistics diff --git a/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp index abc9577a3..c246d68b6 100644 --- a/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp @@ -48,10 +48,10 @@ class ArrayColumnData : public ColumnData { idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) override; unique_ptr GetUpdateStatistics() override; void CommitDropColumn() override; @@ -65,8 +65,8 @@ class ArrayColumnData : public ColumnData { PersistentColumnData Serialize() override; void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; - void GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) override; + void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, + vector col_path, vector &result) override; void Verify(RowGroup &parent) override; }; diff --git a/src/duckdb/src/include/duckdb/storage/table/chunk_info.hpp b/src/duckdb/src/include/duckdb/storage/table/chunk_info.hpp index 44b92dd74..9e33b4201 100644 --- a/src/duckdb/src/include/duckdb/storage/table/chunk_info.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/chunk_info.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/vector_size.hpp" #include "duckdb/common/atomic.hpp" +#include "duckdb/execution/index/index_pointer.hpp" namespace duckdb { class RowGroup; @@ -20,6 +21,7 @@ struct TransactionData; struct DeleteInfo; class Serializer; class Deserializer; +class FixedSizeAllocator; enum class ChunkInfoType : uint8_t { CONSTANT_INFO, VECTOR_INFO, EMPTY_INFO }; @@ -38,19 +40,19 @@ class ChunkInfo { public: //! Gets up to max_count entries from the chunk info. If the ret is 0>ret>max_count, the selection vector is filled //! with the tuples - virtual idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) = 0; + virtual idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) const = 0; virtual idx_t GetCommittedSelVector(transaction_t min_start_id, transaction_t min_transaction_id, SelectionVector &sel_vector, idx_t max_count) = 0; //! 
Returns whether or not a single row in the ChunkInfo should be used or not for the given transaction virtual bool Fetch(TransactionData transaction, row_t row) = 0; virtual void CommitAppend(transaction_t commit_id, idx_t start, idx_t end) = 0; - virtual idx_t GetCommittedDeletedCount(idx_t max_count) = 0; + virtual idx_t GetCommittedDeletedCount(idx_t max_count) const = 0; virtual bool Cleanup(transaction_t lowest_transaction, unique_ptr &result) const; virtual bool HasDeletes() const = 0; virtual void Write(WriteStream &writer) const; - static unique_ptr Read(ReadStream &reader); + static unique_ptr Read(FixedSizeAllocator &allocator, ReadStream &reader); public: template @@ -81,12 +83,12 @@ class ChunkConstantInfo : public ChunkInfo { transaction_t delete_id; public: - idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) override; + idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) const override; idx_t GetCommittedSelVector(transaction_t min_start_id, transaction_t min_transaction_id, SelectionVector &sel_vector, idx_t max_count) override; bool Fetch(TransactionData transaction, row_t row) override; void CommitAppend(transaction_t commit_id, idx_t start, idx_t end) override; - idx_t GetCommittedDeletedCount(idx_t max_count) override; + idx_t GetCommittedDeletedCount(idx_t max_count) const override; bool Cleanup(transaction_t lowest_transaction, unique_ptr &result) const override; bool HasDeletes() const override; @@ -105,27 +107,19 @@ class ChunkVectorInfo : public ChunkInfo { static constexpr const ChunkInfoType TYPE = ChunkInfoType::VECTOR_INFO; public: - explicit ChunkVectorInfo(idx_t start); - - //! The transaction ids of the transactions that inserted the tuples (if any) - transaction_t inserted[STANDARD_VECTOR_SIZE]; - transaction_t insert_id; - bool same_inserted_id; - - //! 
The transaction ids of the transactions that deleted the tuples (if any) - transaction_t deleted[STANDARD_VECTOR_SIZE]; - bool any_deleted; + explicit ChunkVectorInfo(FixedSizeAllocator &allocator, idx_t start, transaction_t insert_id = 0); + ~ChunkVectorInfo() override; public: idx_t GetSelVector(transaction_t start_time, transaction_t transaction_id, SelectionVector &sel_vector, idx_t max_count) const; - idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) override; + idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) const override; idx_t GetCommittedSelVector(transaction_t min_start_id, transaction_t min_transaction_id, SelectionVector &sel_vector, idx_t max_count) override; bool Fetch(TransactionData transaction, row_t row) override; void CommitAppend(transaction_t commit_id, idx_t start, idx_t end) override; bool Cleanup(transaction_t lowest_transaction, unique_ptr &result) const override; - idx_t GetCommittedDeletedCount(idx_t max_count) override; + idx_t GetCommittedDeletedCount(idx_t max_count) const override; void Append(idx_t start, idx_t end, transaction_t commit_id); @@ -138,14 +132,32 @@ class ChunkVectorInfo : public ChunkInfo { void CommitDelete(transaction_t commit_id, const DeleteInfo &info); bool HasDeletes() const override; + bool AnyDeleted() const; + bool HasConstantInsertionId() const; + transaction_t ConstantInsertId() const; void Write(WriteStream &writer) const override; - static unique_ptr Read(ReadStream &reader); + static unique_ptr Read(FixedSizeAllocator &allocator, ReadStream &reader); private: template idx_t TemplatedGetSelVector(transaction_t start_time, transaction_t transaction_id, SelectionVector &sel_vector, idx_t max_count) const; + + IndexPointer GetInsertedPointer() const; + IndexPointer GetDeletedPointer() const; + IndexPointer GetInitializedInsertedPointer(); + IndexPointer GetInitializedDeletedPointer(); + +private: + FixedSizeAllocator &allocator; + //! The transaction ids of the transactions that inserted the tuples (if any) + IndexPointer inserted_data; + //! The constant insert id (if there is only one) + transaction_t constant_insert_id; + + //! 
The transaction ids of the transactions that deleted the tuples (if any) + IndexPointer deleted_data; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/column_data.hpp index 400daeaa6..ab8a5970e 100644 --- a/src/duckdb/src/include/duckdb/storage/table/column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/column_data.hpp @@ -39,14 +39,16 @@ struct PersistentColumnData; using column_segment_vector_t = vector>; struct ColumnCheckpointInfo { - ColumnCheckpointInfo(RowGroupWriteInfo &info, idx_t column_idx) : info(info), column_idx(column_idx) { - } + ColumnCheckpointInfo(RowGroupWriteInfo &info, idx_t column_idx); - RowGroupWriteInfo &info; idx_t column_idx; public: + PartialBlockManager &GetPartialBlockManager(); CompressionType GetCompressionType(); + +private: + RowGroupWriteInfo &info; }; class ColumnData { @@ -154,10 +156,10 @@ class ColumnData { virtual void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx); - virtual void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count); - virtual void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth); + virtual void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count); + virtual void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth); virtual unique_ptr GetUpdateStatistics(); virtual void CommitDropColumn(); @@ -178,7 +180,8 @@ class ColumnData { static shared_ptr Deserialize(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, ReadStream &source, const LogicalType &type); - virtual void GetColumnSegmentInfo(idx_t row_group_index, vector col_path, vector &result); + virtual void GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, + vector &result); virtual void Verify(RowGroup &parent); FilterPropagateResult CheckZonemap(TableFilter &filter); @@ -217,8 +220,8 @@ class ColumnData { void FetchUpdates(TransactionData transaction, idx_t vector_index, Vector &result, idx_t scan_count, bool allow_updates, bool scan_committed); void FetchUpdateRow(TransactionData transaction, row_t row_id, Vector &result, idx_t result_idx); - void UpdateInternal(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count, Vector &base_vector); + void UpdateInternal(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count, Vector &base_vector); idx_t FetchUpdateData(ColumnScanState &state, row_t *row_ids, Vector &base_vector); idx_t GetVectorCount(idx_t vector_index) const; diff --git a/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp b/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp index 61b2c0d4f..e99664958 100644 --- a/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp @@ -29,7 +29,6 @@ class DatabaseInstance; class TableFilter; class Transaction; class UpdateSegment; - struct ColumnAppendState; struct ColumnFetchState; struct ColumnScanState; diff 
--git a/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp index c8e75d136..621ece451 100644 --- a/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp @@ -46,10 +46,10 @@ class ListColumnData : public ColumnData { idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) override; unique_ptr GetUpdateStatistics() override; void CommitDropColumn() override; @@ -63,8 +63,8 @@ class ListColumnData : public ColumnData { PersistentColumnData Serialize() override; void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; - void GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) override; + void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, + vector col_path, vector &result) override; private: uint64_t FetchListOffset(idx_t row_idx); diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group.hpp index 242e19121..2498cca70 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_group.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_group.hpp @@ -50,22 +50,34 @@ class MetadataManager; class RowVersionManager; class ScanFilterInfo; class StorageCommitState; +template +struct SegmentNode; struct RowGroupWriteInfo { RowGroupWriteInfo(PartialBlockManager &manager, const vector &compression_types, - CheckpointType checkpoint_type = CheckpointType::FULL_CHECKPOINT) - : manager(manager), compression_types(compression_types), checkpoint_type(checkpoint_type) { - } + CheckpointType checkpoint_type = CheckpointType::FULL_CHECKPOINT); + RowGroupWriteInfo(PartialBlockManager &manager, const vector &compression_types, + vector> &column_partial_block_managers_p); +private: PartialBlockManager &manager; + +public: const vector &compression_types; CheckpointType checkpoint_type; + +public: + PartialBlockManager &GetPartialBlockManager(idx_t column_idx); + +private: + optional_ptr>> column_partial_block_managers; }; struct RowGroupWriteData { vector> states; vector statistics; - vector existing_pointers; + bool reuse_existing_metadata_blocks = false; + vector existing_extra_metadata_blocks; }; class RowGroup : public SegmentBase { @@ -94,7 +106,10 @@ class RowGroup : public SegmentBase { return collection.get(); } //! Returns the list of meta block pointers used by the columns - vector GetColumnPointers(); + vector GetOrComputeExtraMetadataBlocks(bool force_compute = false); + + const vector &GetColumnStartPointers() const; + //! 
Returns the list of meta block pointers used by the deletes const vector &GetDeletesPointers() const { return deletes_pointers; @@ -104,7 +119,7 @@ class RowGroup : public SegmentBase { unique_ptr AlterType(RowGroupCollection &collection, const LogicalType &target_type, idx_t changed_idx, ExpressionExecutor &executor, CollectionScanState &scan_state, - DataChunk &scan_chunk); + SegmentNode &node, DataChunk &scan_chunk); unique_ptr AddColumn(RowGroupCollection &collection, ColumnDefinition &new_column, ExpressionExecutor &executor, Vector &intermediate); unique_ptr RemoveColumn(RowGroupCollection &collection, idx_t removed_column); @@ -116,8 +131,8 @@ class RowGroup : public SegmentBase { bool HasChanges() const; //! Initialize a scan over this row_group - bool InitializeScan(CollectionScanState &state); - bool InitializeScanWithOffset(CollectionScanState &state, idx_t vector_offset); + bool InitializeScan(CollectionScanState &state, SegmentNode &node); + bool InitializeScanWithOffset(CollectionScanState &state, SegmentNode &node, idx_t vector_offset); //! Checks the given set of table filters against the row-group statistics. Returns false if the entire row group //! can be skipped. bool CheckZonemap(ScanFilterInfo &filters); @@ -162,19 +177,19 @@ class RowGroup : public SegmentBase { void InitializeAppend(RowGroupAppendState &append_state); void Append(RowGroupAppendState &append_state, DataChunk &chunk, idx_t append_count); - void Update(TransactionData transaction, DataChunk &updates, row_t *ids, idx_t offset, idx_t count, - const vector &column_ids); + void Update(TransactionData transaction, DataTable &data_table, DataChunk &updates, row_t *ids, idx_t offset, + idx_t count, const vector &column_ids); //! Update a single column; corresponds to DataTable::UpdateColumn //! 
This method should only be called from the WAL - void UpdateColumn(TransactionData transaction, DataChunk &updates, Vector &row_ids, idx_t offset, idx_t count, - const vector &column_path); + void UpdateColumn(TransactionData transaction, DataTable &data_table, DataChunk &updates, Vector &row_ids, + idx_t offset, idx_t count, const vector &column_path); void MergeStatistics(idx_t column_idx, const BaseStatistics &other); void MergeIntoStatistics(idx_t column_idx, BaseStatistics &other); void MergeIntoStatistics(TableStatistics &other); unique_ptr GetStatistics(idx_t column_idx); - void GetColumnSegmentInfo(idx_t row_group_index, vector &result); + void GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector &result); PartitionStatistics GetPartitionStats() const; idx_t GetAllocationSize() const { diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp index 32808ff4c..ddd22d29d 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp @@ -36,6 +36,7 @@ struct CollectionCheckpointState; struct PersistentCollectionData; class CheckpointTask; class TableIOManager; +class DataTable; class RowGroupCollection { public: @@ -61,13 +62,14 @@ class RowGroupCollection { void Verify(); void Destroy(); - void InitializeScan(CollectionScanState &state, const vector &column_ids, + void InitializeScan(const QueryContext &context, CollectionScanState &state, const vector &column_ids, optional_ptr table_filters); void InitializeCreateIndexScan(CreateIndexScanState &state); - void InitializeScanWithOffset(CollectionScanState &state, const vector &column_ids, idx_t start_row, - idx_t end_row); - static bool InitializeScanInRowGroup(CollectionScanState &state, RowGroupCollection &collection, - RowGroup &row_group, idx_t vector_index, idx_t max_row); + void InitializeScanWithOffset(const QueryContext &context, CollectionScanState &state, + const vector &column_ids, idx_t start_row, idx_t end_row); + static bool InitializeScanInRowGroup(const QueryContext &context, CollectionScanState &state, + RowGroupCollection &collection, SegmentNode &row_group, + idx_t vector_index, idx_t max_row); void InitializeParallelScan(ParallelCollectionScanState &state); bool NextParallelScan(ClientContext &context, ParallelCollectionScanState &state, CollectionScanState &scan_state); @@ -97,17 +99,18 @@ class RowGroupCollection { optional_ptr commit_state); bool IsPersistent() const; - void RemoveFromIndexes(TableIndexList &indexes, Vector &row_identifiers, idx_t count); + void RemoveFromIndexes(const QueryContext &context, TableIndexList &indexes, Vector &row_identifiers, idx_t count); idx_t Delete(TransactionData transaction, DataTable &table, row_t *ids, idx_t count); - void Update(TransactionData transaction, row_t *ids, const vector &column_ids, DataChunk &updates); - void UpdateColumn(TransactionData transaction, Vector &row_ids, const vector &column_path, - DataChunk &updates); + void Update(TransactionData transaction, DataTable &table, row_t *ids, const vector &column_ids, + DataChunk &updates); + void UpdateColumn(TransactionData transaction, DataTable &table, Vector &row_ids, + const vector &column_path, DataChunk &updates); void Checkpoint(TableDataWriter &writer, TableStatistics &global_stats); void InitializeVacuumState(CollectionCheckpointState &checkpoint_state, VacuumState &state, - vector> 
&segments); + vector>> &segments); bool ScheduleVacuumTasks(CollectionCheckpointState &checkpoint_state, VacuumState &state, idx_t segment_idx, bool schedule_vacuum); unique_ptr GetCheckpointTask(CollectionCheckpointState &checkpoint_state, idx_t segment_idx); @@ -116,7 +119,7 @@ class RowGroupCollection { void CommitDropTable(); vector GetPartitionStats() const; - vector GetColumnSegmentInfo(); + vector GetColumnSegmentInfo(const QueryContext &context); const vector &GetTypes() const; shared_ptr AddColumn(ClientContext &context, ColumnDefinition &new_column, @@ -124,7 +127,7 @@ class RowGroupCollection { shared_ptr RemoveColumn(idx_t col_idx); shared_ptr AlterType(ClientContext &context, idx_t changed_idx, const LogicalType &target_type, vector bound_columns, Expression &cast_expr); - void VerifyNewConstraint(DataTable &parent, const BoundConstraint &constraint); + void VerifyNewConstraint(const QueryContext &context, DataTable &parent, const BoundConstraint &constraint); void CopyStats(TableStatistics &stats); unique_ptr CopyStats(column_t column_id); @@ -152,7 +155,7 @@ class RowGroupCollection { private: bool IsEmpty(SegmentLock &) const; - optional_ptr NextUpdateRowGroup(row_t *ids, idx_t &pos, idx_t count) const; + optional_ptr> NextUpdateRowGroup(row_t *ids, idx_t &pos, idx_t count) const; private: //! BlockManager diff --git a/src/duckdb/src/include/duckdb/storage/table/row_id_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/row_id_column_data.hpp index 3bc6572ae..f839c6b24 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_id_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_id_column_data.hpp @@ -48,10 +48,10 @@ class RowIdColumnData : public ColumnData { void AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) override; void RevertAppend(row_t start_row) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) override; void CommitDropColumn() override; diff --git a/src/duckdb/src/include/duckdb/storage/table/row_version_manager.hpp b/src/duckdb/src/include/duckdb/storage/table/row_version_manager.hpp index bb0d0056b..49ab2f40b 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_version_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_version_manager.hpp @@ -12,21 +12,22 @@ #include "duckdb/storage/table/chunk_info.hpp" #include "duckdb/storage/storage_info.hpp" #include "duckdb/common/mutex.hpp" +#include "duckdb/execution/index/fixed_size_allocator.hpp" namespace duckdb { struct DeleteInfo; class MetadataManager; +class BufferManager; struct MetaBlockPointer; class RowVersionManager { public: - explicit RowVersionManager(idx_t start) noexcept; + explicit RowVersionManager(BufferManager &buffer_manager) noexcept; - idx_t GetStart() { - return start; + FixedSizeAllocator &GetAllocator() { + return allocator; } - void SetStart(idx_t start); idx_t GetCommittedDeletedCount(idx_t count); idx_t 
GetSelVector(TransactionData transaction, idx_t vector_idx, SelectionVector &sel_vector, idx_t max_count); @@ -43,12 +44,11 @@ class RowVersionManager { void CommitDelete(idx_t vector_idx, transaction_t commit_id, const DeleteInfo &info); vector Checkpoint(MetadataManager &manager); - static shared_ptr Deserialize(MetaBlockPointer delete_pointer, MetadataManager &manager, - idx_t start); + static shared_ptr Deserialize(MetaBlockPointer delete_pointer, MetadataManager &manager); private: mutex version_lock; - idx_t start; + FixedSizeAllocator allocator; vector> vector_info; bool has_changes; vector storage_pointers; diff --git a/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp b/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp index 97416cbf2..4d6f9448d 100644 --- a/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp @@ -42,6 +42,8 @@ struct AdaptiveFilterState; struct TableScanOptions; struct ScanSamplingInfo; struct TableFilterState; +template +struct SegmentNode; struct SegmentScanState { virtual ~SegmentScanState() { @@ -78,8 +80,10 @@ struct IndexScanState { typedef unordered_map buffer_handle_set_t; struct ColumnScanState { + //! The query context for this scan + QueryContext context; //! The column segment that is currently being scanned - ColumnSegment *current = nullptr; + optional_ptr> current; //! Column segment tree ColumnSegmentTree *segment_tree = nullptr; //! The current row index of the scan @@ -105,9 +109,9 @@ struct ColumnScanState { optional_ptr scan_options; public: - void Initialize(const LogicalType &type, const vector &children, + void Initialize(const QueryContext &context_p, const LogicalType &type, const vector &children, optional_ptr options); - void Initialize(const LogicalType &type, optional_ptr options); + void Initialize(const QueryContext &context_p, const LogicalType &type, optional_ptr options); //! Move the scan state forward by "count" rows (including all child states) void Next(idx_t count); //! Move ONLY this state forward by "count" rows (i.e. not the child states) @@ -115,6 +119,8 @@ struct ColumnScanState { }; struct ColumnFetchState { + //! The query context for this fetch + QueryContext context; //! The set of pinned block handles for this set of fetches buffer_handle_set_t handles; //! Any child states of the fetch @@ -178,12 +184,36 @@ class ScanFilterInfo { idx_t always_true_filters = 0; }; +enum class OrderByStatistics { MIN, MAX }; +enum class RowGroupOrderType { ASC, DESC }; +enum class OrderByColumnType { NUMERIC, STRING }; + +class RowGroupReorderer { +public: + explicit RowGroupReorderer(const RowGroupOrderOptions &options); + optional_ptr> GetRootSegment(RowGroupSegmentTree &row_groups); + optional_ptr> GetNextRowGroup(SegmentNode &row_group); + +private: + const column_t column_idx; + const OrderByStatistics order_by; + const RowGroupOrderType order_type; + const OrderByColumnType column_type; + + idx_t offset; + bool initialized; + vector>> ordered_row_groups; + +private: + static Value RetrieveStat(const BaseStatistics &stats, OrderByStatistics order_by, OrderByColumnType column_type); +}; + class CollectionScanState { public: explicit CollectionScanState(TableScanState &parent_p); //! The current row_group we are scanning - RowGroup *row_group; + optional_ptr> row_group; //! The vector index within the row_group idx_t vector_index; //! 
The maximum row within the row group @@ -201,12 +231,18 @@ class CollectionScanState { RandomEngine random; + //! Optional state for custom row group ordering + unique_ptr reorderer; + public: - void Initialize(const vector &types); + void Initialize(const QueryContext &context, const vector &types); const vector &GetColumnIds(); ScanFilterInfo &GetFilterInfo(); ScanSamplingInfo &GetSamplingInfo(); TableScanOptions &GetOptions(); + optional_ptr> GetNextRowGroup(SegmentNode &row_group) const; + optional_ptr> GetNextRowGroup(SegmentLock &l, SegmentNode &row_group) const; + optional_ptr> GetRootSegment() const; bool Scan(DuckTransaction &transaction, DataChunk &result); bool ScanCommitted(DataChunk &result, TableScanType type); bool ScanCommitted(DataChunk &result, SegmentLock &l, TableScanType type); @@ -272,15 +308,21 @@ class TableScanState { struct ParallelCollectionScanState { ParallelCollectionScanState(); + optional_ptr> GetRootSegment(RowGroupSegmentTree &row_groups) const; + optional_ptr> GetNextRowGroup(RowGroupSegmentTree &row_groups, + SegmentNode &row_group) const; //! The row group collection we are scanning RowGroupCollection *collection; - RowGroup *current_row_group; + optional_ptr> current_row_group; idx_t vector_index; idx_t max_row; idx_t batch_index; atomic processed_rows; mutex lock; + + //! Optional state for custom row group ordering + unique_ptr reorderer; }; struct ParallelTableScanState { diff --git a/src/duckdb/src/include/duckdb/storage/table/segment_base.hpp b/src/duckdb/src/include/duckdb/storage/table/segment_base.hpp index b71587bf8..57002ebb4 100644 --- a/src/duckdb/src/include/duckdb/storage/table/segment_base.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/segment_base.hpp @@ -16,28 +16,13 @@ namespace duckdb { template class SegmentBase { public: - SegmentBase(idx_t start, idx_t count) : start(start), count(count), next(nullptr) { - } - T *Next() { -#ifndef DUCKDB_R_BUILD - return next.load(); -#else - return next; -#endif + SegmentBase(idx_t start, idx_t count) : start(start), count(count) { } //! The start row id of this chunk idx_t start; //! The amount of entries in this storage chunk atomic count; - //! The next segment after this one -#ifndef DUCKDB_R_BUILD - atomic next; -#else - T *next; -#endif - //! The index within the segment tree - idx_t index; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp b/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp index f427a5275..9a0427391 100644 --- a/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp @@ -19,8 +19,28 @@ namespace duckdb { template struct SegmentNode { + SegmentNode() : next(nullptr) { + } + idx_t row_start; unique_ptr node; + //! The next segment after this one +#ifndef DUCKDB_R_BUILD + atomic *> next; +#else + SegmentNode *next; +#endif + //! The index within the segment tree + idx_t index; + +public: + optional_ptr> Next() { +#ifndef DUCKDB_R_BUILD + return next.load(); +#else + return next; +#endif + } }; //! The SegmentTree maintains a list of all segments of a specific column in a table, and allows searching for a segment @@ -29,6 +49,7 @@ template class SegmentTree { private: class SegmentIterationHelper; + class SegmentNodeIterationHelper; public: explicit SegmentTree() : finished_loading(true) { @@ -47,39 +68,39 @@ class SegmentTree { } //! Gets a pointer to the first segment. Useful for scans. 
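The hunks that follow move the intrusive bookkeeping (the atomic `next` pointer and the tree `index`) off `SegmentBase` and onto a heap-allocated `SegmentNode<T>` wrapper owned by the tree, so traversal now hands out `optional_ptr<SegmentNode<T>>` instead of raw segment pointers. A minimal sketch of that ownership pattern, using simplified stand-in types (`MockSegment`, `MockSegmentNode`, `MockSegmentTree` are illustrative names, not the real DuckDB classes), could look like this:

// segment_node_sketch.cpp -- simplified illustration of the SegmentNode<T>
// layout introduced by this patch; not the actual DuckDB implementation.
#include <atomic>
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

// Payload: only logical information remains on the segment itself.
struct MockSegment {
	uint64_t start; // first row id covered by this segment
	uint64_t count; // number of rows in this segment
};

// Tree-owned wrapper: linkage and position live here, not in the payload.
template <class T>
struct MockSegmentNode {
	uint64_t row_start = 0;
	std::unique_ptr<T> node;
	std::atomic<MockSegmentNode<T> *> next {nullptr}; // next node in the tree
	size_t index = 0;                                 // position within the tree

	MockSegmentNode<T> *Next() const {
		return next.load();
	}
};

template <class T>
class MockSegmentTree {
public:
	// Mirrors the shape of AppendSegmentInternal: wrap the segment in a node,
	// record its index, and link the previous tail node to the new one.
	void Append(std::unique_ptr<T> segment) {
		auto node = std::make_unique<MockSegmentNode<T>>();
		node->row_start = segment->start;
		node->node = std::move(segment);
		node->index = nodes.size();
		if (!nodes.empty()) {
			nodes.back()->next = node.get();
		}
		nodes.push_back(std::move(node));
	}
	MockSegmentNode<T> *Root() const {
		return nodes.empty() ? nullptr : nodes.front().get();
	}

private:
	std::vector<std::unique_ptr<MockSegmentNode<T>>> nodes;
};

int main() {
	MockSegmentTree<MockSegment> tree;
	tree.Append(std::make_unique<MockSegment>(MockSegment {0, 2048}));
	tree.Append(std::make_unique<MockSegment>(MockSegment {2048, 1024}));
	// Walk the tree through its nodes, as scans do via GetRootSegment()/Next().
	for (auto *n = tree.Root(); n; n = n->Next()) {
		std::cout << "segment " << n->index << " starts at row " << n->row_start << "\n";
	}
	return 0;
}

Keeping tree metadata out of the payload means a segment object can be handed around (for example during vacuum or alter) without carrying a stale `next`/`index`, which appears to be the motivation for the `optional_ptr<SegmentNode<T>>` signatures throughout the rest of this hunk.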
- T *GetRootSegment() { + optional_ptr> GetRootSegment() { auto l = Lock(); return GetRootSegment(l); } - T *GetRootSegment(SegmentLock &l) { + optional_ptr> GetRootSegment(SegmentLock &l) { if (nodes.empty()) { LoadNextSegment(l); } return GetRootSegmentInternal(); } //! Obtains ownership of the data of the segment tree - vector> MoveSegments(SegmentLock &l) { + vector>> MoveSegments(SegmentLock &l) { LoadAllSegments(l); return std::move(nodes); } - vector> MoveSegments() { + vector>> MoveSegments() { auto l = Lock(); return MoveSegments(l); } - const vector> &ReferenceSegments(SegmentLock &l) { + const vector>> &ReferenceSegments(SegmentLock &l) { LoadAllSegments(l); return nodes; } - const vector> &ReferenceSegments() { + const vector>> &ReferenceSegments() { auto l = Lock(); return ReferenceSegments(l); } - vector> &ReferenceLoadedSegmentsMutable(SegmentLock &l) { + vector>> &ReferenceLoadedSegmentsMutable(SegmentLock &l) { return nodes; } - const vector> &ReferenceLoadedSegments(SegmentLock &l) const { + const vector>> &ReferenceLoadedSegments(SegmentLock &l) const { return nodes; } @@ -91,11 +112,11 @@ class SegmentTree { return nodes.size(); } //! Gets a pointer to the nth segment. Negative numbers start from the back. - T *GetSegmentByIndex(int64_t index) { + optional_ptr> GetSegmentByIndex(int64_t index) { auto l = Lock(); return GetSegmentByIndex(l, index); } - T *GetSegmentByIndex(SegmentLock &l, int64_t index) { + optional_ptr> GetSegmentByIndex(SegmentLock &l, int64_t index) { if (index < 0) { // load all segments LoadAllSegments(l); @@ -103,7 +124,7 @@ class SegmentTree { if (index < 0) { return nullptr; } - return nodes[UnsafeNumericCast(index)].node.get(); + return nodes[UnsafeNumericCast(index)].get(); } else { // lazily load segments until we reach the specific segment while (idx_t(index) >= nodes.size() && LoadNextSegment(l)) { @@ -111,59 +132,56 @@ class SegmentTree { if (idx_t(index) >= nodes.size()) { return nullptr; } - return nodes[UnsafeNumericCast(index)].node.get(); + return nodes[UnsafeNumericCast(index)].get(); } } //! Gets the next segment - T *GetNextSegment(T *segment) { + optional_ptr> GetNextSegment(SegmentNode &node) { if (!SUPPORTS_LAZY_LOADING) { - return segment->Next(); + return node.Next(); } if (finished_loading) { - return segment->Next(); + return node.Next(); } auto l = Lock(); - return GetNextSegment(l, segment); + return GetNextSegment(l, node); } - T *GetNextSegment(SegmentLock &l, T *segment) { - if (!segment) { - return nullptr; - } + optional_ptr> GetNextSegment(SegmentLock &l, SegmentNode &node) { #ifdef DEBUG - D_ASSERT(nodes[segment->index].node.get() == segment); + D_ASSERT(RefersToSameObject(*nodes[node.index], node)); #endif - return GetSegmentByIndex(l, UnsafeNumericCast(segment->index + 1)); + return GetSegmentByIndex(l, UnsafeNumericCast(node.index + 1)); } //! Gets a pointer to the last segment. Useful for appends. - T *GetLastSegment(SegmentLock &l) { + optional_ptr> GetLastSegment(SegmentLock &l) { LoadAllSegments(l); if (nodes.empty()) { return nullptr; } - return nodes.back().node.get(); + return nodes.back().get(); } //! 
Gets a pointer to a specific column segment for the given row - T *GetSegment(idx_t row_number) { + optional_ptr> GetSegment(idx_t row_number) { auto l = Lock(); return GetSegment(l, row_number); } - T *GetSegment(SegmentLock &l, idx_t row_number) { - return nodes[GetSegmentIndex(l, row_number)].node.get(); + optional_ptr> GetSegment(SegmentLock &l, idx_t row_number) { + return nodes[GetSegmentIndex(l, row_number)].get(); } //! Append a column segment to the tree void AppendSegmentInternal(SegmentLock &l, unique_ptr segment) { D_ASSERT(segment); // add the node to the list of nodes + auto node = make_uniq>(); + node->row_start = segment->start; + node->node = std::move(segment); + node->index = nodes.size(); + node->next = nullptr; if (!nodes.empty()) { - nodes.back().node->next = segment.get(); + nodes.back()->next = node.get(); } - SegmentNode node; - segment->index = nodes.size(); - segment->next = nullptr; - node.row_start = segment->start; - node.node = std::move(segment); nodes.push_back(std::move(node)); } void AppendSegment(unique_ptr segment) { @@ -175,12 +193,12 @@ class SegmentTree { AppendSegmentInternal(l, std::move(segment)); } //! Debug method, check whether the segment is in the segment tree - bool HasSegment(T *segment) { + bool HasSegment(SegmentNode &segment) { auto l = Lock(); return HasSegment(l, segment); } - bool HasSegment(SegmentLock &, T *segment) { - return segment->index < nodes.size() && nodes[segment->index].node.get() == segment; + bool HasSegment(SegmentLock &, SegmentNode &segment) { + return segment.index < nodes.size() && RefersToSameObject(*nodes[segment.index], segment); } //! Erase all segments after a specific segment @@ -201,15 +219,15 @@ class SegmentTree { string error; error = StringUtil::Format("Attempting to find row number \"%lld\" in %lld nodes\n", row_number, nodes.size()); for (idx_t i = 0; i < nodes.size(); i++) { - error += StringUtil::Format("Node %lld: Start %lld, Count %lld", i, nodes[i].row_start, - nodes[i].node->count.load()); + error += StringUtil::Format("Node %lld: Start %lld, Count %lld", i, nodes[i]->row_start, + nodes[i]->node->count.load()); } throw InternalException("Could not find node in column segment tree!\n%s", error); } bool TryGetSegmentIndex(SegmentLock &l, idx_t row_number, idx_t &result) { // load segments until the row number is within bounds - while (nodes.empty() || (row_number >= (nodes.back().row_start + nodes.back().node->count))) { + while (nodes.empty() || (row_number >= (nodes.back()->row_start + nodes.back()->node->count))) { if (!LoadNextSegment(l)) { break; } @@ -225,12 +243,12 @@ class SegmentTree { if (index >= nodes.size()) { string segments; for (auto &entry : nodes) { - segments += StringUtil::Format("Start %d Count %d", entry.row_start, entry.node->count.load()); + segments += StringUtil::Format("Start %d Count %d", entry->row_start, entry->node->count.load()); } throw InternalException("Segment tree index not found for row number %d\nSegments:%s", row_number, segments); } - auto &entry = nodes[index]; + auto &entry = *nodes[index]; D_ASSERT(entry.row_start == entry.node->start); if (row_number < entry.row_start) { upper = index - 1; @@ -246,11 +264,11 @@ class SegmentTree { void Verify(SegmentLock &) { #ifdef DEBUG - idx_t base_start = nodes.empty() ? 0 : nodes[0].node->start; + idx_t base_start = nodes.empty() ? 
0 : nodes[0]->node->start; for (idx_t i = 0; i < nodes.size(); i++) { - D_ASSERT(nodes[i].row_start == nodes[i].node->start); - D_ASSERT(nodes[i].node->start == base_start); - base_start += nodes[i].node->count; + D_ASSERT(nodes[i]->row_start == nodes[i]->node->start); + D_ASSERT(nodes[i]->node->start == base_start); + base_start += nodes[i]->node->count; } #endif } @@ -269,17 +287,25 @@ class SegmentTree { return SegmentIterationHelper(*this, l); } + SegmentNodeIterationHelper SegmentNodes() { + return SegmentNodeIterationHelper(*this); + } + + SegmentNodeIterationHelper SegmentNodes(SegmentLock &l) { + return SegmentNodeIterationHelper(*this, l); + } + void Reinitialize() { if (nodes.empty()) { return; } - idx_t offset = nodes[0].node->start; + idx_t offset = nodes[0]->node->start; for (auto &entry : nodes) { - if (entry.node->start != offset) { + if (entry->node->start != offset) { throw InternalException("In SegmentTree::Reinitialize - gap found between nodes!"); } - entry.row_start = offset; - offset += entry.node->count; + entry->row_start = offset; + offset += entry->node->count; } } @@ -291,17 +317,40 @@ class SegmentTree { return nullptr; } - T *GetRootSegmentInternal() const { - return nodes.empty() ? nullptr : nodes[0].node.get(); + optional_ptr> GetRootSegmentInternal() const { + return nodes.empty() ? nullptr : nodes[0].get(); } private: //! The nodes in the tree, can be binary searched - vector> nodes; + vector>> nodes; //! Lock to access or modify the nodes mutable mutex node_lock; private: + class BaseSegmentIterator { + public: + BaseSegmentIterator(SegmentTree &tree_p, optional_ptr> current_p, optional_ptr lock) + : tree(tree_p), current(current_p), lock(lock) { + } + + SegmentTree &tree; + optional_ptr> current; + optional_ptr lock; + + public: + void Next() { + current = lock ? tree.GetNextSegment(*lock, *current) : tree.GetNextSegment(*current); + } + + BaseSegmentIterator &operator++() { + Next(); + return *this; + } + bool operator!=(const BaseSegmentIterator &other) const { + return current != other.current; + } + }; class SegmentIterationHelper { public: explicit SegmentIterationHelper(SegmentTree &tree) : tree(tree) { @@ -314,31 +363,46 @@ class SegmentTree { optional_ptr lock; private: - class SegmentIterator { + class SegmentIterator : public BaseSegmentIterator { public: - SegmentIterator(SegmentTree &tree_p, T *current_p, optional_ptr lock) - : tree(tree_p), current(current_p), lock(lock) { + SegmentIterator(SegmentTree &tree_p, optional_ptr> current_p, optional_ptr lock) + : BaseSegmentIterator(tree_p, current_p, lock) { } - SegmentTree &tree; - T *current; - optional_ptr lock; + T &operator*() const { + return *BaseSegmentIterator::current->node; + } + }; + + public: + SegmentIterator begin() { // NOLINT: match stl API + auto root = lock ? tree.GetRootSegment(*lock) : tree.GetRootSegment(); + return SegmentIterator(tree, root, lock); + } + SegmentIterator end() { // NOLINT: match stl API + return SegmentIterator(tree, nullptr, lock); + } + }; + class SegmentNodeIterationHelper { + public: + explicit SegmentNodeIterationHelper(SegmentTree &tree) : tree(tree) { + } + SegmentNodeIterationHelper(SegmentTree &tree, SegmentLock &l) : tree(tree), lock(l) { + } + + private: + SegmentTree &tree; + optional_ptr lock; + private: + class SegmentIterator : public BaseSegmentIterator { public: - void Next() { - current = lock ? 
tree.GetNextSegment(*lock, current) : tree.GetNextSegment(current); + SegmentIterator(SegmentTree &tree_p, optional_ptr> current_p, optional_ptr lock) + : BaseSegmentIterator(tree_p, current_p, lock) { } - SegmentIterator &operator++() { - Next(); - return *this; - } - bool operator!=(const SegmentIterator &other) const { - return current != other.current; - } - T &operator*() const { - D_ASSERT(current); - return *current; + SegmentNode &operator*() { + return *BaseSegmentIterator::current; } }; diff --git a/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp index 48ac6ccb7..ec06eb30a 100644 --- a/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp @@ -47,10 +47,10 @@ class StandardColumnData : public ColumnData { idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) override; unique_ptr GetUpdateStatistics() override; void CommitDropColumn() override; @@ -61,8 +61,8 @@ class StandardColumnData : public ColumnData { void CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t row_group_start, idx_t count, Vector &scan_vector) override; - void GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) override; + void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, + vector col_path, vector &result) override; bool IsPersistent() override; bool HasAnyChanges() const override; diff --git a/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp index d05436bfc..798a21326 100644 --- a/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp @@ -46,10 +46,10 @@ class StructColumnData : public ColumnData { idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) override; 
unique_ptr GetUpdateStatistics() override; void CommitDropColumn() override; @@ -63,8 +63,8 @@ class StructColumnData : public ColumnData { PersistentColumnData Serialize() override; void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; - void GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) override; + void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, + vector col_path, vector &result) override; void Verify(RowGroup &parent) override; }; diff --git a/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp b/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp index 75cf25ecf..3f5b9d211 100644 --- a/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp @@ -38,8 +38,8 @@ class UpdateSegment { void FetchUpdates(TransactionData transaction, idx_t vector_index, Vector &result); void FetchCommitted(idx_t vector_index, Vector &result); void FetchCommittedRange(idx_t start_row, idx_t count, Vector &result); - void Update(TransactionData transaction, idx_t column_index, Vector &update, row_t *ids, idx_t count, - Vector &base_data); + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update, row_t *ids, + idx_t count, Vector &base_data); void FetchRow(TransactionData transaction, idx_t row_id, Vector &result, idx_t result_idx); void RollbackUpdate(UpdateInfo &info); diff --git a/src/duckdb/src/include/duckdb/storage/write_ahead_log.hpp b/src/duckdb/src/include/duckdb/storage/write_ahead_log.hpp index c35820836..1986227ce 100644 --- a/src/duckdb/src/include/duckdb/storage/write_ahead_log.hpp +++ b/src/duckdb/src/include/duckdb/storage/write_ahead_log.hpp @@ -52,8 +52,9 @@ class WriteAheadLog { virtual ~WriteAheadLog(); public: - //! Replay and initialize the WAL - static unique_ptr Replay(FileSystem &fs, AttachedDatabase &database, const string &wal_path); + //! Replay and initialize the WAL. The QueryContext is passed for metric collection purposes only! + static unique_ptr Replay(QueryContext context, FileSystem &fs, AttachedDatabase &database, + const string &wal_path); AttachedDatabase &GetDatabase(); @@ -121,7 +122,9 @@ class WriteAheadLog { void WriteCheckpoint(MetaBlockPointer meta_block); protected: - static unique_ptr ReplayInternal(AttachedDatabase &database, unique_ptr handle); + //! Internally replay all WAL entries. The QueryContext is passed for metric collection purposes only! + static unique_ptr ReplayInternal(QueryContext context, AttachedDatabase &database, + unique_ptr handle); protected: AttachedDatabase &database; diff --git a/src/duckdb/src/include/duckdb/transaction/cleanup_state.hpp b/src/duckdb/src/include/duckdb/transaction/cleanup_state.hpp index 0de2faabd..21f117674 100644 --- a/src/duckdb/src/include/duckdb/transaction/cleanup_state.hpp +++ b/src/duckdb/src/include/duckdb/transaction/cleanup_state.hpp @@ -11,6 +11,7 @@ #include "duckdb/transaction/undo_buffer.hpp" #include "duckdb/common/types/data_chunk.hpp" #include "duckdb/common/unordered_map.hpp" +#include "duckdb/main/client_context.hpp" namespace duckdb { @@ -21,7 +22,7 @@ struct UpdateInfo; class CleanupState { public: - explicit CleanupState(transaction_t lowest_active_transaction); + explicit CleanupState(const QueryContext &context, transaction_t lowest_active_transaction); ~CleanupState(); // all tables with indexes that possibly need a vacuum (after e.g.
a delete) @@ -31,6 +32,7 @@ class CleanupState { void CleanupEntry(UndoFlags type, data_ptr_t data); private: + QueryContext context; //! Lowest active transaction transaction_t lowest_active_transaction; // data for index cleanup diff --git a/src/duckdb/src/include/duckdb/transaction/duck_transaction.hpp b/src/duckdb/src/include/duckdb/transaction/duck_transaction.hpp index 12c4d180c..b9080f192 100644 --- a/src/duckdb/src/include/duckdb/transaction/duck_transaction.hpp +++ b/src/duckdb/src/include/duckdb/transaction/duck_transaction.hpp @@ -35,14 +35,12 @@ class DuckTransaction : public Transaction { transaction_t transaction_id; //! The commit id of this transaction, if it has successfully been committed transaction_t commit_id; - //! Highest active query when the transaction finished, used for cleaning up - transaction_t highest_active_query; atomic catalog_version; //! Transactions undergo Cleanup, after (1) removing them directly in RemoveTransaction, - //! or (2) after they exist old_transactions. - //! Some (after rollback) enter old_transactions, but do not require Cleanup. + //! or (2) after they enter cleanup_queue. + //! Some (after rollback) enter cleanup_queue, but do not require Cleanup. bool awaiting_cleanup; public: @@ -76,7 +74,7 @@ class DuckTransaction : public Transaction { idx_t base_row); void PushSequenceUsage(SequenceCatalogEntry &entry, const SequenceData &data); void PushAppend(DataTable &table, idx_t row_start, idx_t row_count); - UndoBufferReference CreateUpdateInfo(idx_t type_size, idx_t entries); + UndoBufferReference CreateUpdateInfo(idx_t type_size, DataTable &data_table, idx_t entries); bool IsDuckTransaction() const override { return true; @@ -90,6 +88,7 @@ class DuckTransaction : public Transaction { //! Get a shared lock on a table shared_ptr SharedLockTable(DataTableInfo &info); + //! Hold an owning reference to the table, needed to safely reference it inside the transaction commit/undo logic void ModifyTable(DataTable &tbl); private: diff --git a/src/duckdb/src/include/duckdb/transaction/duck_transaction_manager.hpp b/src/duckdb/src/include/duckdb/transaction/duck_transaction_manager.hpp index 63531ae7d..a3bf3f47a 100644 --- a/src/duckdb/src/include/duckdb/transaction/duck_transaction_manager.hpp +++ b/src/duckdb/src/include/duckdb/transaction/duck_transaction_manager.hpp @@ -110,8 +110,6 @@ class DuckTransactionManager : public TransactionManager { vector> active_transactions; //! Set of recently committed transactions vector> recently_committed_transactions; - //! Transactions awaiting GC - vector> old_transactions; //! The lock used for transaction operations mutex transaction_lock; //!
The checkpoint lock diff --git a/src/duckdb/src/include/duckdb/transaction/local_storage.hpp b/src/duckdb/src/include/duckdb/transaction/local_storage.hpp index 5d29da46c..1cab839d0 100644 --- a/src/duckdb/src/include/duckdb/transaction/local_storage.hpp +++ b/src/duckdb/src/include/duckdb/transaction/local_storage.hpp @@ -40,6 +40,8 @@ class LocalTableStorage : public enable_shared_from_this { ExpressionExecutor &default_executor); ~LocalTableStorage(); + QueryContext context; + reference table_ref; Allocator &allocator; @@ -189,6 +191,10 @@ class LocalStorage { void VerifyNewConstraint(DataTable &parent, const BoundConstraint &constraint); + ClientContext &GetClientContext() const { + return context; + } + private: ClientContext &context; DuckTransaction &transaction; diff --git a/src/duckdb/src/include/duckdb/transaction/update_info.hpp b/src/duckdb/src/include/duckdb/transaction/update_info.hpp index 7cccd923e..5eb139261 100644 --- a/src/duckdb/src/include/duckdb/transaction/update_info.hpp +++ b/src/duckdb/src/include/duckdb/transaction/update_info.hpp @@ -17,6 +17,7 @@ namespace duckdb { class UpdateSegment; struct DataTableInfo; +class DataTable; //! UpdateInfo is a class that represents a set of updates applied to a single vector. //! The UpdateInfo struct contains metadata associated with the update. @@ -26,6 +27,8 @@ struct DataTableInfo; struct UpdateInfo { //! The update segment that this update info affects UpdateSegment *segment; + //! The table this update was made on + DataTable *table; //! The column index of which column we are updating idx_t column_index; //! The version number @@ -87,7 +90,7 @@ struct UpdateInfo { //! Returns the total allocation size for an UpdateInfo entry, together with space for the tuple data static idx_t GetAllocSize(idx_t type_size); //! Initialize an UpdateInfo struct that has been allocated using GetAllocSize (i.e.
has extra space after it) - static void Initialize(UpdateInfo &info, transaction_t transaction_id); + static void Initialize(UpdateInfo &info, DataTable &data_table, transaction_t transaction_id); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/wal_write_state.hpp b/src/duckdb/src/include/duckdb/transaction/wal_write_state.hpp index aad1a672c..4c68da487 100644 --- a/src/duckdb/src/include/duckdb/transaction/wal_write_state.hpp +++ b/src/duckdb/src/include/duckdb/transaction/wal_write_state.hpp @@ -31,7 +31,7 @@ class WALWriteState { void CommitEntry(UndoFlags type, data_ptr_t data); private: - void SwitchTable(DataTableInfo *table, UndoFlags new_op); + void SwitchTable(DataTableInfo &table, UndoFlags new_op); void WriteCatalogEntry(CatalogEntry &entry, data_ptr_t extra_data); void WriteDelete(DeleteInfo &info); diff --git a/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp b/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp index a60abf187..77fed9815 100644 --- a/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp +++ b/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp @@ -85,6 +85,8 @@ class StatementVerifier { private: const vector> empty_select_list = {}; + + const vector> &GetSelectList(QueryNode &node); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb_extension.h b/src/duckdb/src/include/duckdb_extension.h index 7c5136059..0bade5f3e 100644 --- a/src/duckdb/src/include/duckdb_extension.h +++ b/src/duckdb/src/include/duckdb_extension.h @@ -544,6 +544,7 @@ typedef struct { duckdb_state (*duckdb_appender_create_query)(duckdb_connection connection, const char *query, idx_t column_count, duckdb_logical_type *types, const char *table_name, const char **column_names, duckdb_appender *out_appender); + duckdb_state (*duckdb_appender_clear)(duckdb_appender appender); #endif // New arrow interface functions @@ -560,6 +561,69 @@ typedef struct { void (*duckdb_destroy_arrow_converted_schema)(duckdb_arrow_converted_schema *arrow_converted_schema); #endif +// New configuration options functions +#ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE + duckdb_config_option (*duckdb_create_config_option)(); + void (*duckdb_destroy_config_option)(duckdb_config_option *option); + void (*duckdb_config_option_set_name)(duckdb_config_option option, const char *name); + void (*duckdb_config_option_set_type)(duckdb_config_option option, duckdb_logical_type type); + void (*duckdb_config_option_set_default_value)(duckdb_config_option option, duckdb_value default_value); + void (*duckdb_config_option_set_default_scope)(duckdb_config_option option, + duckdb_config_option_scope default_scope); + void (*duckdb_config_option_set_description)(duckdb_config_option option, const char *description); + duckdb_state (*duckdb_register_config_option)(duckdb_connection connection, duckdb_config_option option); + duckdb_value (*duckdb_client_context_get_config_option)(duckdb_client_context context, const char *name, + duckdb_config_option_scope *out_scope); +#endif + +// API to define custom copy functions +#ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE + duckdb_copy_function (*duckdb_create_copy_function)(); + void (*duckdb_copy_function_set_name)(duckdb_copy_function copy_function, const char *name); + void (*duckdb_copy_function_set_extra_info)(duckdb_copy_function copy_function, void *extra_info, + duckdb_delete_callback_t destructor); + duckdb_state (*duckdb_register_copy_function)(duckdb_connection connection, 
duckdb_copy_function copy_function); + void (*duckdb_destroy_copy_function)(duckdb_copy_function *copy_function); + void (*duckdb_copy_function_set_bind)(duckdb_copy_function copy_function, duckdb_copy_function_bind_t bind); + void (*duckdb_copy_function_bind_set_error)(duckdb_copy_function_bind_info info, const char *error); + void *(*duckdb_copy_function_bind_get_extra_info)(duckdb_copy_function_bind_info info); + duckdb_client_context (*duckdb_copy_function_bind_get_client_context)(duckdb_copy_function_bind_info info); + idx_t (*duckdb_copy_function_bind_get_column_count)(duckdb_copy_function_bind_info info); + duckdb_logical_type (*duckdb_copy_function_bind_get_column_type)(duckdb_copy_function_bind_info info, + idx_t col_idx); + duckdb_value (*duckdb_copy_function_bind_get_options)(duckdb_copy_function_bind_info info); + void (*duckdb_copy_function_bind_set_bind_data)(duckdb_copy_function_bind_info info, void *bind_data, + duckdb_delete_callback_t destructor); + void (*duckdb_copy_function_set_global_init)(duckdb_copy_function copy_function, + duckdb_copy_function_global_init_t init); + void (*duckdb_copy_function_global_init_set_error)(duckdb_copy_function_global_init_info info, const char *error); + void *(*duckdb_copy_function_global_init_get_extra_info)(duckdb_copy_function_global_init_info info); + duckdb_client_context (*duckdb_copy_function_global_init_get_client_context)( + duckdb_copy_function_global_init_info info); + void *(*duckdb_copy_function_global_init_get_bind_data)(duckdb_copy_function_global_init_info info); + void (*duckdb_copy_function_global_init_set_global_state)(duckdb_copy_function_global_init_info info, + void *global_state, duckdb_delete_callback_t destructor); + const char *(*duckdb_copy_function_global_init_get_file_path)(duckdb_copy_function_global_init_info info); + void (*duckdb_copy_function_set_sink)(duckdb_copy_function copy_function, duckdb_copy_function_sink_t function); + void (*duckdb_copy_function_sink_set_error)(duckdb_copy_function_sink_info info, const char *error); + void *(*duckdb_copy_function_sink_get_extra_info)(duckdb_copy_function_sink_info info); + duckdb_client_context (*duckdb_copy_function_sink_get_client_context)(duckdb_copy_function_sink_info info); + void *(*duckdb_copy_function_sink_get_bind_data)(duckdb_copy_function_sink_info info); + void *(*duckdb_copy_function_sink_get_global_state)(duckdb_copy_function_sink_info info); + void (*duckdb_copy_function_set_finalize)(duckdb_copy_function copy_function, + duckdb_copy_function_finalize_t finalize); + void (*duckdb_copy_function_finalize_set_error)(duckdb_copy_function_finalize_info info, const char *error); + void *(*duckdb_copy_function_finalize_get_extra_info)(duckdb_copy_function_finalize_info info); + duckdb_client_context (*duckdb_copy_function_finalize_get_client_context)(duckdb_copy_function_finalize_info info); + void *(*duckdb_copy_function_finalize_get_bind_data)(duckdb_copy_function_finalize_info info); + void *(*duckdb_copy_function_finalize_get_global_state)(duckdb_copy_function_finalize_info info); + void (*duckdb_copy_function_set_copy_from_function)(duckdb_copy_function copy_function, + duckdb_table_function table_function); + idx_t (*duckdb_table_function_bind_get_result_column_count)(duckdb_bind_info info); + const char *(*duckdb_table_function_bind_get_result_column_name)(duckdb_bind_info info, idx_t col_idx); + duckdb_logical_type (*duckdb_table_function_bind_get_result_column_type)(duckdb_bind_info info, idx_t col_idx); +#endif + // New functions for duckdb 
error data #ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE duckdb_error_data (*duckdb_create_error_data)(duckdb_error_type type, const char *message); @@ -643,6 +707,13 @@ typedef struct { char *(*duckdb_value_to_string)(duckdb_value value); #endif +// New functions around the table description +#ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE + idx_t (*duckdb_table_description_get_column_count)(duckdb_table_description table_description); + duckdb_logical_type (*duckdb_table_description_get_column_type)(duckdb_table_description table_description, + idx_t index); +#endif + // New functions around table function binding #ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE void (*duckdb_table_function_get_client_context)(duckdb_bind_info info, duckdb_client_context *out_context); @@ -1093,6 +1164,7 @@ typedef struct { // Version unstable_new_append_functions #define duckdb_appender_create_query duckdb_ext_api.duckdb_appender_create_query #define duckdb_appender_error_data duckdb_ext_api.duckdb_appender_error_data +#define duckdb_appender_clear duckdb_ext_api.duckdb_appender_clear #define duckdb_append_default_to_chunk duckdb_ext_api.duckdb_append_default_to_chunk // Version unstable_new_arrow_functions @@ -1102,6 +1174,60 @@ typedef struct { #define duckdb_data_chunk_from_arrow duckdb_ext_api.duckdb_data_chunk_from_arrow #define duckdb_destroy_arrow_converted_schema duckdb_ext_api.duckdb_destroy_arrow_converted_schema +// Version unstable_new_config_options_functions +#define duckdb_create_config_option duckdb_ext_api.duckdb_create_config_option +#define duckdb_destroy_config_option duckdb_ext_api.duckdb_destroy_config_option +#define duckdb_config_option_set_name duckdb_ext_api.duckdb_config_option_set_name +#define duckdb_config_option_set_type duckdb_ext_api.duckdb_config_option_set_type +#define duckdb_config_option_set_default_value duckdb_ext_api.duckdb_config_option_set_default_value +#define duckdb_config_option_set_default_scope duckdb_ext_api.duckdb_config_option_set_default_scope +#define duckdb_config_option_set_description duckdb_ext_api.duckdb_config_option_set_description +#define duckdb_register_config_option duckdb_ext_api.duckdb_register_config_option +#define duckdb_client_context_get_config_option duckdb_ext_api.duckdb_client_context_get_config_option + +// Version unstable_new_copy_functions_api +#define duckdb_create_copy_function duckdb_ext_api.duckdb_create_copy_function +#define duckdb_copy_function_set_name duckdb_ext_api.duckdb_copy_function_set_name +#define duckdb_copy_function_set_extra_info duckdb_ext_api.duckdb_copy_function_set_extra_info +#define duckdb_register_copy_function duckdb_ext_api.duckdb_register_copy_function +#define duckdb_destroy_copy_function duckdb_ext_api.duckdb_destroy_copy_function +#define duckdb_copy_function_set_bind duckdb_ext_api.duckdb_copy_function_set_bind +#define duckdb_copy_function_bind_set_error duckdb_ext_api.duckdb_copy_function_bind_set_error +#define duckdb_copy_function_bind_get_extra_info duckdb_ext_api.duckdb_copy_function_bind_get_extra_info +#define duckdb_copy_function_bind_get_client_context duckdb_ext_api.duckdb_copy_function_bind_get_client_context +#define duckdb_copy_function_bind_get_column_count duckdb_ext_api.duckdb_copy_function_bind_get_column_count +#define duckdb_copy_function_bind_get_column_type duckdb_ext_api.duckdb_copy_function_bind_get_column_type +#define duckdb_copy_function_bind_get_options duckdb_ext_api.duckdb_copy_function_bind_get_options +#define duckdb_copy_function_bind_set_bind_data 
duckdb_ext_api.duckdb_copy_function_bind_set_bind_data +#define duckdb_copy_function_set_global_init duckdb_ext_api.duckdb_copy_function_set_global_init +#define duckdb_copy_function_global_init_set_error duckdb_ext_api.duckdb_copy_function_global_init_set_error +#define duckdb_copy_function_global_init_get_extra_info duckdb_ext_api.duckdb_copy_function_global_init_get_extra_info +#define duckdb_copy_function_global_init_get_client_context \ + duckdb_ext_api.duckdb_copy_function_global_init_get_client_context +#define duckdb_copy_function_global_init_get_bind_data duckdb_ext_api.duckdb_copy_function_global_init_get_bind_data +#define duckdb_copy_function_global_init_get_file_path duckdb_ext_api.duckdb_copy_function_global_init_get_file_path +#define duckdb_copy_function_global_init_set_global_state \ + duckdb_ext_api.duckdb_copy_function_global_init_set_global_state +#define duckdb_copy_function_set_sink duckdb_ext_api.duckdb_copy_function_set_sink +#define duckdb_copy_function_sink_set_error duckdb_ext_api.duckdb_copy_function_sink_set_error +#define duckdb_copy_function_sink_get_extra_info duckdb_ext_api.duckdb_copy_function_sink_get_extra_info +#define duckdb_copy_function_sink_get_client_context duckdb_ext_api.duckdb_copy_function_sink_get_client_context +#define duckdb_copy_function_sink_get_bind_data duckdb_ext_api.duckdb_copy_function_sink_get_bind_data +#define duckdb_copy_function_sink_get_global_state duckdb_ext_api.duckdb_copy_function_sink_get_global_state +#define duckdb_copy_function_set_finalize duckdb_ext_api.duckdb_copy_function_set_finalize +#define duckdb_copy_function_finalize_set_error duckdb_ext_api.duckdb_copy_function_finalize_set_error +#define duckdb_copy_function_finalize_get_extra_info duckdb_ext_api.duckdb_copy_function_finalize_get_extra_info +#define duckdb_copy_function_finalize_get_client_context duckdb_ext_api.duckdb_copy_function_finalize_get_client_context +#define duckdb_copy_function_finalize_get_bind_data duckdb_ext_api.duckdb_copy_function_finalize_get_bind_data +#define duckdb_copy_function_finalize_get_global_state duckdb_ext_api.duckdb_copy_function_finalize_get_global_state +#define duckdb_copy_function_set_copy_from_function duckdb_ext_api.duckdb_copy_function_set_copy_from_function +#define duckdb_table_function_bind_get_result_column_count \ + duckdb_ext_api.duckdb_table_function_bind_get_result_column_count +#define duckdb_table_function_bind_get_result_column_name \ + duckdb_ext_api.duckdb_table_function_bind_get_result_column_name +#define duckdb_table_function_bind_get_result_column_type \ + duckdb_ext_api.duckdb_table_function_bind_get_result_column_type + // Version unstable_new_error_data_functions #define duckdb_create_error_data duckdb_ext_api.duckdb_create_error_data #define duckdb_destroy_error_data duckdb_ext_api.duckdb_destroy_error_data @@ -1164,6 +1290,10 @@ typedef struct { // Version unstable_new_string_functions #define duckdb_value_to_string duckdb_ext_api.duckdb_value_to_string +// Version unstable_new_table_description_functions +#define duckdb_table_description_get_column_count duckdb_ext_api.duckdb_table_description_get_column_count +#define duckdb_table_description_get_column_type duckdb_ext_api.duckdb_table_description_get_column_type + // Version unstable_new_table_function_functions #define duckdb_table_function_get_client_context duckdb_ext_api.duckdb_table_function_get_client_context diff --git a/src/duckdb/src/logging/log_manager.cpp b/src/duckdb/src/logging/log_manager.cpp index 2785386ab..9a55dd280 100644 --- 
a/src/duckdb/src/logging/log_manager.cpp +++ b/src/duckdb/src/logging/log_manager.cpp @@ -266,6 +266,7 @@ void LogManager::RegisterDefaultLogTypes() { RegisterLogType(make_uniq()); RegisterLogType(make_uniq()); RegisterLogType(make_uniq()); + RegisterLogType(make_uniq()); } } // namespace duckdb diff --git a/src/duckdb/src/logging/log_storage.cpp b/src/duckdb/src/logging/log_storage.cpp index c6733d968..141ac1f15 100644 --- a/src/duckdb/src/logging/log_storage.cpp +++ b/src/duckdb/src/logging/log_storage.cpp @@ -14,9 +14,9 @@ #include "duckdb/function/cast/vector_cast_helpers.hpp" #include "duckdb/common/operator/string_cast.hpp" #include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp" +#include "duckdb/common/printer.hpp" #include -#include namespace duckdb { @@ -258,8 +258,9 @@ void BufferingLogStorage::UpdateConfigInternal(DatabaseInstance &db, case_insens } void StdOutLogStorage::StdOutWriteStream::WriteData(const_data_ptr_t buffer, idx_t write_size) { - std::cout.write(const_char_ptr_cast(buffer), NumericCast(write_size)); - std::cout.flush(); + string data(const_char_ptr_cast(buffer), NumericCast(write_size)); + Printer::RawPrint(OutputStream::STREAM_STDOUT, data); + Printer::Flush(OutputStream::STREAM_STDOUT); } StdOutLogStorage::StdOutLogStorage(DatabaseInstance &db) : CSVLogStorage(db, false, 1) { @@ -599,7 +600,6 @@ BufferingLogStorage::~BufferingLogStorage() { } static void WriteLoggingContextsToChunk(DataChunk &chunk, const RegisteredLoggingContext &context, idx_t &col) { - auto size = chunk.size(); auto context_id_data = FlatVector::GetData(chunk.data[col++]); diff --git a/src/duckdb/src/logging/log_types.cpp b/src/duckdb/src/logging/log_types.cpp index f78abae59..4d2e0dea2 100644 --- a/src/duckdb/src/logging/log_types.cpp +++ b/src/duckdb/src/logging/log_types.cpp @@ -14,6 +14,7 @@ constexpr LogLevel FileSystemLogType::LEVEL; constexpr LogLevel QueryLogType::LEVEL; constexpr LogLevel HTTPLogType::LEVEL; constexpr LogLevel PhysicalOperatorLogType::LEVEL; +constexpr LogLevel MetricsLogType::LEVEL; constexpr LogLevel CheckpointLogType::LEVEL; //===--------------------------------------------------------------------===// @@ -147,6 +148,29 @@ string PhysicalOperatorLogType::ConstructLogMessage(const PhysicalOperator &phys return Value::STRUCT(std::move(child_list)).ToString(); } + +//===--------------------------------------------------------------------===// +// MetricsLogType +//===--------------------------------------------------------------------===// +MetricsLogType::MetricsLogType() : LogType(NAME, LEVEL, GetLogType()) { +} + +LogicalType MetricsLogType::GetLogType() { + child_list_t child_list = { + {"metric", LogicalType::VARCHAR}, + {"value", LogicalType::VARCHAR}, + }; + return LogicalType::STRUCT(child_list); +} + +string MetricsLogType::ConstructLogMessage(const MetricsType &metric, const Value &value) { + child_list_t child_list = { + {"metric", EnumUtil::ToString(metric)}, + {"value", value.ToString()}, + }; + return Value::STRUCT(std::move(child_list)).ToString(); +} + //===--------------------------------------------------------------------===// // CheckpointLogType //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/main/appender.cpp b/src/duckdb/src/main/appender.cpp index bac1c06f1..3ff69a74e 100644 --- a/src/duckdb/src/main/appender.cpp +++ b/src/duckdb/src/main/appender.cpp @@ -424,7 +424,6 @@ void BaseAppender::ClearColumns() { 
//===--------------------------------------------------------------------===// Appender::Appender(Connection &con, const string &database_name, const string &schema_name, const string &table_name) : BaseAppender(Allocator::DefaultAllocator(), AppenderType::LOGICAL), context(con.context) { - description = con.TableInfo(database_name, schema_name, table_name); if (!description) { throw CatalogException( @@ -620,4 +619,14 @@ void BaseAppender::Close() { } } +void BaseAppender::Clear() { + chunk.Reset(); + + if (collection) { + collection->Reset(); + } + + column = 0; +} + } // namespace duckdb diff --git a/src/duckdb/src/main/attached_database.cpp b/src/duckdb/src/main/attached_database.cpp index e98070c18..2ce7140ac 100644 --- a/src/duckdb/src/main/attached_database.cpp +++ b/src/duckdb/src/main/attached_database.cpp @@ -14,14 +14,19 @@ namespace duckdb { -StoredDatabasePath::StoredDatabasePath(DatabaseFilePathManager &manager, string path_p, const string &name) - : manager(manager), path(std::move(path_p)) { +StoredDatabasePath::StoredDatabasePath(DatabaseManager &db_manager, DatabaseFilePathManager &manager, string path_p, + const string &name) + : db_manager(db_manager), manager(manager), path(std::move(path_p)) { } StoredDatabasePath::~StoredDatabasePath() { manager.EraseDatabasePath(path); } +void StoredDatabasePath::OnDetach() { + manager.DetachDatabase(db_manager, path); +} + //===--------------------------------------------------------------------===// // Attach Options //===--------------------------------------------------------------------===// @@ -31,11 +36,9 @@ AttachOptions::AttachOptions(const DBConfigOptions &options) AttachOptions::AttachOptions(const unordered_map &attach_options, const AccessMode default_access_mode) : access_mode(default_access_mode) { - for (auto &entry : attach_options) { if (entry.first == "readonly" || entry.first == "read_only") { // Extract the read access mode. - auto read_only = BooleanValue::Get(entry.second.DefaultCastAs(LogicalType::BOOLEAN)); if (read_only) { access_mode = AccessMode::READ_ONLY; @@ -45,6 +48,13 @@ AttachOptions::AttachOptions(const unordered_map &attach_options, continue; } + if (entry.first == "recovery_mode") { + // Extract the recovery mode. + auto mode_str = StringValue::Get(entry.second.DefaultCastAs(LogicalType::VARCHAR)); + recovery_mode = EnumUtil::FromString(mode_str); + continue; + } + if (entry.first == "readwrite" || entry.first == "read_write") { // Extract the write access mode. auto read_write = BooleanValue::Get(entry.second.DefaultCastAs(LogicalType::BOOLEAN)); @@ -77,7 +87,6 @@ AttachedDatabase::AttachedDatabase(DatabaseInstance &db, AttachedDatabaseType ty : CatalogEntry(CatalogType::DATABASE_ENTRY, type == AttachedDatabaseType::SYSTEM_DATABASE ? SYSTEM_CATALOG : TEMP_CATALOG, 0), db(db), type(type) { - // This database does not have storage, or uses temporary_objects for in-memory storage. D_ASSERT(type == AttachedDatabaseType::TEMP_DATABASE || type == AttachedDatabaseType::SYSTEM_DATABASE); if (type == AttachedDatabaseType::TEMP_DATABASE) { @@ -99,7 +108,9 @@ AttachedDatabase::AttachedDatabase(DatabaseInstance &db, Catalog &catalog_p, str } else { type = AttachedDatabaseType::READ_WRITE_DATABASE; } + recovery_mode = options.recovery_mode; visibility = options.visibility; + // We create the storage after the catalog to guarantee we allow extensions to instantiate the DuckCatalog. 
catalog = make_uniq(*this); stored_database_path = std::move(options.stored_database_path); @@ -117,6 +128,7 @@ AttachedDatabase::AttachedDatabase(DatabaseInstance &db, Catalog &catalog_p, Sto } else { type = AttachedDatabaseType::READ_WRITE_DATABASE; } + recovery_mode = options.recovery_mode; visibility = options.visibility; optional_ptr storage_info = storage_extension->storage_info.get(); @@ -157,6 +169,13 @@ bool AttachedDatabase::NameIsReserved(const string &name) { return name == DEFAULT_SCHEMA || name == TEMP_CATALOG || name == SYSTEM_CATALOG; } +string AttachedDatabase::StoredPath() const { + if (stored_database_path) { + return stored_database_path->path; + } + return string(); +} + static string RemoveQueryParams(const string &name) { auto vec = StringUtil::Split(name, "?"); D_ASSERT(!vec.empty()); @@ -181,7 +200,7 @@ void AttachedDatabase::Initialize(optional_ptr context) { catalog->Initialize(context, false); } if (storage) { - storage->Initialize(QueryContext(context)); + storage->Initialize(context); } } @@ -232,6 +251,9 @@ void AttachedDatabase::OnDetach(ClientContext &context) { if (catalog) { catalog->OnDetach(context); } + if (stored_database_path && visibility != AttachVisibility::HIDDEN) { + stored_database_path->OnDetach(); + } } void AttachedDatabase::Close() { @@ -266,10 +288,6 @@ void AttachedDatabase::Close() { catalog.reset(); storage.reset(); stored_database_path.reset(); - - if (Allocator::SupportsFlush()) { - Allocator::FlushAll(); - } } } // namespace duckdb diff --git a/src/duckdb/src/main/buffered_data/batched_buffered_data.cpp b/src/duckdb/src/main/buffered_data/batched_buffered_data.cpp index 3c593374c..e9f949098 100644 --- a/src/duckdb/src/main/buffered_data/batched_buffered_data.cpp +++ b/src/duckdb/src/main/buffered_data/batched_buffered_data.cpp @@ -14,9 +14,8 @@ void BatchedBufferedData::BlockSink(const InterruptState &blocked_sink, idx_t ba blocked_sinks.emplace(batch, blocked_sink); } -BatchedBufferedData::BatchedBufferedData(weak_ptr context) - : BufferedData(BufferedData::Type::BATCHED, std::move(context)), buffer_byte_count(0), read_queue_byte_count(0), - min_batch(0) { +BatchedBufferedData::BatchedBufferedData(ClientContext &context) + : BufferedData(BufferedData::Type::BATCHED, context), buffer_byte_count(0), read_queue_byte_count(0), min_batch(0) { read_queue_capacity = (idx_t)(static_cast(total_buffer_size) * 0.6); buffer_capacity = (idx_t)(static_cast(total_buffer_size) * 0.4); } diff --git a/src/duckdb/src/main/buffered_data/buffered_data.cpp b/src/duckdb/src/main/buffered_data/buffered_data.cpp index 156539815..0e01df8dc 100644 --- a/src/duckdb/src/main/buffered_data/buffered_data.cpp +++ b/src/duckdb/src/main/buffered_data/buffered_data.cpp @@ -4,9 +4,8 @@ namespace duckdb { -BufferedData::BufferedData(Type type, weak_ptr context_p) : type(type), context(std::move(context_p)) { - auto client_context = context.lock(); - auto &config = ClientConfig::GetConfig(*client_context); +BufferedData::BufferedData(Type type, ClientContext &context_p) : type(type), context(context_p.shared_from_this()) { + auto &config = ClientConfig::GetConfig(context_p); total_buffer_size = config.streaming_buffer_size; } diff --git a/src/duckdb/src/main/buffered_data/simple_buffered_data.cpp b/src/duckdb/src/main/buffered_data/simple_buffered_data.cpp index 4b6a3a534..59cde1f43 100644 --- a/src/duckdb/src/main/buffered_data/simple_buffered_data.cpp +++ b/src/duckdb/src/main/buffered_data/simple_buffered_data.cpp @@ -6,8 +6,7 @@ namespace duckdb { 
-SimpleBufferedData::SimpleBufferedData(weak_ptr context) - : BufferedData(BufferedData::Type::SIMPLE, std::move(context)) { +SimpleBufferedData::SimpleBufferedData(ClientContext &context) : BufferedData(BufferedData::Type::SIMPLE, context) { buffered_count = 0; buffer_size = total_buffer_size; } diff --git a/src/duckdb/src/main/capi/aggregate_function-c.cpp b/src/duckdb/src/main/capi/aggregate_function-c.cpp index 4eb461123..b6c895aa6 100644 --- a/src/duckdb/src/main/capi/aggregate_function-c.cpp +++ b/src/duckdb/src/main/capi/aggregate_function-c.cpp @@ -193,7 +193,7 @@ void duckdb_aggregate_function_set_return_type(duckdb_aggregate_function functio } auto &aggregate_function = GetCAggregateFunction(function); auto logical_type = reinterpret_cast(type); - aggregate_function.return_type = *logical_type; + aggregate_function.SetReturnType(*logical_type); } void duckdb_aggregate_function_set_functions(duckdb_aggregate_function function, duckdb_aggregate_state_size state_size, @@ -237,7 +237,7 @@ void duckdb_aggregate_function_set_special_handling(duckdb_aggregate_function fu return; } auto &aggregate_function = GetCAggregateFunction(function); - aggregate_function.null_handling = duckdb::FunctionNullHandling::SPECIAL_HANDLING; + aggregate_function.SetNullHandling(duckdb::FunctionNullHandling::SPECIAL_HANDLING); } void duckdb_aggregate_function_set_extra_info(duckdb_aggregate_function function, void *extra_info, @@ -311,8 +311,8 @@ duckdb_state duckdb_register_aggregate_function_set(duckdb_connection connection if (aggregate_function.name.empty() || !info.update || !info.combine || !info.finalize) { return DuckDBError; } - if (duckdb::TypeVisitor::Contains(aggregate_function.return_type, duckdb::LogicalTypeId::INVALID) || - duckdb::TypeVisitor::Contains(aggregate_function.return_type, duckdb::LogicalTypeId::ANY)) { + if (duckdb::TypeVisitor::Contains(aggregate_function.GetReturnType(), duckdb::LogicalTypeId::INVALID) || + duckdb::TypeVisitor::Contains(aggregate_function.GetReturnType(), duckdb::LogicalTypeId::ANY)) { return DuckDBError; } for (const auto &argument : aggregate_function.arguments) { diff --git a/src/duckdb/src/main/capi/appender-c.cpp b/src/duckdb/src/main/capi/appender-c.cpp index a54536b20..959db5098 100644 --- a/src/duckdb/src/main/capi/appender-c.cpp +++ b/src/duckdb/src/main/capi/appender-c.cpp @@ -318,6 +318,10 @@ duckdb_state duckdb_appender_flush(duckdb_appender appender_p) { return duckdb_appender_run_function(appender_p, [&](BaseAppender &appender) { appender.Flush(); }); } +duckdb_state duckdb_appender_clear(duckdb_appender appender_p) { + return duckdb_appender_run_function(appender_p, [&](BaseAppender &appender) { appender.Clear(); }); +} + duckdb_state duckdb_appender_close(duckdb_appender appender_p) { return duckdb_appender_run_function(appender_p, [&](BaseAppender &appender) { appender.Close(); }); } diff --git a/src/duckdb/src/main/capi/arrow-c.cpp b/src/duckdb/src/main/capi/arrow-c.cpp index a1bc5391f..f3f2f7fe5 100644 --- a/src/duckdb/src/main/capi/arrow-c.cpp +++ b/src/duckdb/src/main/capi/arrow-c.cpp @@ -18,7 +18,6 @@ using duckdb::QueryResultType; duckdb_error_data duckdb_to_arrow_schema(duckdb_arrow_options arrow_options, duckdb_logical_type *types, const char **names, idx_t column_count, struct ArrowSchema *out_schema) { - if (!types || !names || !arrow_options || !out_schema) { return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, "Invalid argument(s) to duckdb_to_arrow_schema"); } @@ -298,7 +297,6 @@ void duckdb_destroy_arrow(duckdb_arrow 
*result) { } void duckdb_destroy_arrow_stream(duckdb_arrow_stream *stream_p) { - auto stream = reinterpret_cast(*stream_p); if (!stream) { return; diff --git a/src/duckdb/src/main/capi/cast_function-c.cpp b/src/duckdb/src/main/capi/cast_function-c.cpp index 39a5d90a7..a0b5f243e 100644 --- a/src/duckdb/src/main/capi/cast_function-c.cpp +++ b/src/duckdb/src/main/capi/cast_function-c.cpp @@ -25,7 +25,6 @@ struct CCastFunction { }; struct CCastFunctionUserData { - duckdb_function_info data_ptr = nullptr; duckdb_delete_callback_t delete_callback = nullptr; @@ -56,7 +55,6 @@ struct CCastFunctionData final : public BoundCastData { }; static bool CAPICastFunction(Vector &input, Vector &output, idx_t count, CastParameters ¶meters) { - const auto is_const = input.GetVectorType() == VectorType::CONSTANT_VECTOR; input.Flatten(count); diff --git a/src/duckdb/src/main/capi/config_options-c.cpp b/src/duckdb/src/main/capi/config_options-c.cpp new file mode 100644 index 000000000..b895245fc --- /dev/null +++ b/src/duckdb/src/main/capi/config_options-c.cpp @@ -0,0 +1,159 @@ +#include "duckdb/main/capi/capi_internal.hpp" + +namespace duckdb { +namespace { + +struct CConfigOption { + string name; + LogicalType type; + Value default_value; + SetScope default_scope = SetScope::SESSION; + string description; +}; + +} // namespace +} // namespace duckdb + +duckdb_config_option duckdb_create_config_option() { + auto coption = new duckdb::CConfigOption(); + return reinterpret_cast(coption); +} + +void duckdb_destroy_config_option(duckdb_config_option *option) { + if (!option || !*option) { + return; + } + auto coption = *reinterpret_cast(option); + delete coption; + + *option = nullptr; +} + +void duckdb_config_option_set_name(duckdb_config_option option, const char *name) { + if (!option || !name) { + return; + } + auto coption = reinterpret_cast(option); + coption->name = name; +} + +void duckdb_config_option_set_type(duckdb_config_option option, duckdb_logical_type type) { + if (!option || !type) { + return; + } + auto coption = reinterpret_cast(option); + coption->type = *reinterpret_cast(type); +} + +void duckdb_config_option_set_default_value(duckdb_config_option option, duckdb_value default_value) { + if (!option || !default_value) { + return; + } + auto coption = reinterpret_cast(option); + auto cvalue = reinterpret_cast(default_value); + + if (coption->type.id() == duckdb::LogicalTypeId::INVALID) { + coption->type = cvalue->type(); + coption->default_value = *cvalue; + return; + } + + if (coption->type != cvalue->type()) { + coption->default_value = cvalue->DefaultCastAs(coption->type, false); + return; + } + + coption->default_value = *cvalue; +} + +void duckdb_config_option_set_default_scope(duckdb_config_option option, duckdb_config_option_scope scope) { + if (!option) { + return; + } + auto coption = reinterpret_cast(option); + switch (scope) { + case DUCKDB_CONFIG_OPTION_SCOPE_LOCAL: + coption->default_scope = duckdb::SetScope::LOCAL; + break; + // Set the option for the current session/connection only. + case DUCKDB_CONFIG_OPTION_SCOPE_SESSION: + coption->default_scope = duckdb::SetScope::SESSION; + break; + // Set the option globally for all sessions/connections. 
+ case DUCKDB_CONFIG_OPTION_SCOPE_GLOBAL: + coption->default_scope = duckdb::SetScope::GLOBAL; + break; + default: + return; + } +} + +void duckdb_config_option_set_description(duckdb_config_option option, const char *description) { + if (!option || !description) { + return; + } + auto coption = reinterpret_cast(option); + coption->description = description; +} + +duckdb_state duckdb_register_config_option(duckdb_connection connection, duckdb_config_option option) { + if (!connection || !option) { + return DuckDBError; + } + + auto conn = reinterpret_cast(connection); + auto coption = reinterpret_cast(option); + + if (coption->name.empty() || coption->type.id() == duckdb::LogicalTypeId::INVALID) { + return DuckDBError; + } + + // TODO: This is not transactional... but there's no easy way to make it so currently. + try { + if (conn->context->db->config.HasExtensionOption(coption->name)) { + // Option already exists + return DuckDBError; + } + conn->context->db->config.AddExtensionOption(coption->name, coption->description, coption->type, + coption->default_value, nullptr, coption->default_scope); + } catch (...) { + return DuckDBError; + } + + return DuckDBSuccess; +} + +duckdb_value duckdb_client_context_get_config_option(duckdb_client_context context, const char *option_name, + duckdb_config_option_scope *out_scope) { + if (!context || !option_name) { + return nullptr; + } + + auto wrapper = reinterpret_cast(context); + auto &ctx = wrapper->context; + + duckdb_config_option_scope res_scope = DUCKDB_CONFIG_OPTION_SCOPE_INVALID; + duckdb::Value *res_value = nullptr; + + duckdb::Value result; + switch (ctx.TryGetCurrentSetting(option_name, result).GetScope()) { + case duckdb::SettingScope::LOCAL: + // This is a bit messy, but "session" is presented as LOCAL on the "settings" side of the API.
+ res_value = new duckdb::Value(std::move(result)); + res_scope = DUCKDB_CONFIG_OPTION_SCOPE_SESSION; + break; + case duckdb::SettingScope::GLOBAL: + res_value = new duckdb::Value(std::move(result)); + res_scope = DUCKDB_CONFIG_OPTION_SCOPE_GLOBAL; + break; + default: + res_value = nullptr; + res_scope = DUCKDB_CONFIG_OPTION_SCOPE_INVALID; + break; + } + + if (out_scope) { + *out_scope = res_scope; + } + return reinterpret_cast(res_value); +} diff --git a/src/duckdb/src/main/capi/copy_function-c.cpp b/src/duckdb/src/main/capi/copy_function-c.cpp new file mode 100644 index 000000000..b1bb4394b --- /dev/null +++ b/src/duckdb/src/main/capi/copy_function-c.cpp @@ -0,0 +1,821 @@ +#include "duckdb/common/type_visitor.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/function/copy_function.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/main/capi/capi_internal_table.hpp" +#include "duckdb/parser/parsed_data/create_copy_function_info.hpp" + +//---------------------------------------------------------------------------------------------------------------------- +// Common Copy Function Info +//---------------------------------------------------------------------------------------------------------------------- + +namespace duckdb { +namespace { + +struct CCopyFunctionInfo : public CopyFunctionInfo { + ~CCopyFunctionInfo() override { + if (extra_info && delete_callback) { + delete_callback(extra_info); + } + extra_info = nullptr; + delete_callback = nullptr; + } + + duckdb_copy_function_bind_t bind_to = nullptr; + duckdb_copy_function_global_init_t global_init = nullptr; + duckdb_copy_function_sink_t sink = nullptr; + duckdb_copy_function_finalize_t finalize = nullptr; + + void *extra_info = nullptr; + duckdb_delete_callback_t delete_callback = nullptr; +}; + +Value MakeValueFromCopyOptions(const case_insensitive_map_t> &options) { + child_list_t option_list; + for (auto &entry : options) { + // Uppercase the option name, to make it simpler for users + auto name = StringUtil::Upper(entry.first); + auto &values = entry.second; + + if (values.empty()) { + // Null! 
+ option_list.emplace_back(std::move(name), Value()); + continue; + } + if (values.size() == 1) { + // Single value + option_list.emplace_back(std::move(name), values[0]); + continue; + } + + auto is_same_type = true; + auto first_type = values[0].type(); + for (auto &val : values) { + if (val.type() != first_type) { + // Different types, cannot unify + is_same_type = false; + break; + } + } + + // Is same type: create a list of that type + if (is_same_type) { + option_list.emplace_back(std::move(name), Value::LIST(first_type, values)); + continue; + } + + // Different types: create an unnamed struct + child_list_t children; + for (auto &val : values) { + children.emplace_back("", val); + } + option_list.emplace_back(std::move(name), Value::STRUCT(children)); + } + + if (option_list.empty()) { + // No options + return Value(); + } + + // Return a struct of all options + return Value::STRUCT(std::move(option_list)); +} + +} // namespace +} // namespace duckdb + +duckdb_copy_function duckdb_create_copy_function() { + auto function = new duckdb::CopyFunction(""); + + function->function_info = duckdb::make_shared_ptr(); + + return reinterpret_cast(function); +} + +void duckdb_copy_function_set_name(duckdb_copy_function copy_function, const char *name) { + if (!copy_function || !name) { + return; + } + auto ©_function_ref = *reinterpret_cast(copy_function); + copy_function_ref.name = name; +} + +void duckdb_destroy_copy_function(duckdb_copy_function *copy_function) { + if (copy_function && *copy_function) { + auto function = reinterpret_cast(*copy_function); + delete function; + *copy_function = nullptr; + } +} + +void duckdb_copy_function_set_extra_info(duckdb_copy_function function, void *extra_info, + duckdb_delete_callback_t destroy) { + if (!function) { + return; + } + auto ©_function_ref = *reinterpret_cast(function); + auto &info = copy_function_ref.function_info->Cast(); + info.extra_info = extra_info; + info.delete_callback = destroy; +} + +//---------------------------------------------------------------------------------------------------------------------- +// Copy To Bind +//---------------------------------------------------------------------------------------------------------------------- +namespace duckdb { +namespace { +struct CCopyToBindInfo : FunctionData { + shared_ptr function_info; + void *bind_data = nullptr; + duckdb_delete_callback_t delete_callback = nullptr; + + unique_ptr Copy() const override { + throw InternalException("CCopyToBindInfo cannot be copied"); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return bind_data == other.bind_data && delete_callback == other.delete_callback; + } + + ~CCopyToBindInfo() override { + if (bind_data && delete_callback) { + delete_callback(bind_data); + } + bind_data = nullptr; + delete_callback = nullptr; + } +}; + +struct CCopyFunctionToInternalBindInfo { + CCopyFunctionToInternalBindInfo(ClientContext &context, CopyFunctionBindInput &input, + const vector &sql_types, const vector &names, + const CCopyFunctionInfo &function_info) + : context(context), input(input), sql_types(sql_types), names(names), function_info(function_info), + success(true) { + } + + ClientContext &context; + CopyFunctionBindInput &input; + const vector &sql_types; + const vector &names; + const CCopyFunctionInfo &function_info; + bool success; + string error; + + // Supplied by the user + void *bind_data = nullptr; + duckdb_delete_callback_t delete_callback = nullptr; +}; + +unique_ptr 
CCopyToBind(ClientContext &context, CopyFunctionBindInput &input, const vector &names, + const vector &sql_types) { + auto &info = input.function_info->Cast(); + + auto result = make_uniq(); + result->function_info = input.function_info; + + if (info.bind_to) { + // Call the user-defined bind function + CCopyFunctionToInternalBindInfo bind_info(context, input, sql_types, names, info); + info.bind_to(reinterpret_cast(&bind_info)); + + // Pass on user bind data to the result + result->bind_data = bind_info.bind_data; + result->delete_callback = bind_info.delete_callback; + + if (!bind_info.success) { + throw BinderException(bind_info.error); + } + } + return std::move(result); +} + +} // namespace +} // namespace duckdb + +void duckdb_copy_function_set_bind(duckdb_copy_function copy_function, duckdb_copy_function_bind_t bind) { + if (!copy_function || !bind) { + return; + } + + auto ©_function_ref = *reinterpret_cast(copy_function); + auto &info = copy_function_ref.function_info->Cast(); + + // Set C bind callback + info.bind_to = bind; +} + +void duckdb_copy_function_bind_set_error(duckdb_copy_function_bind_info info, const char *error) { + if (!info || !error) { + return; + } + auto &info_ref = *reinterpret_cast(info); + + // Set the error message + info_ref.error = error; + info_ref.success = false; +} + +void *duckdb_copy_function_bind_get_extra_info(duckdb_copy_function_bind_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + return info_ref.function_info.extra_info; +} + +duckdb_client_context duckdb_copy_function_bind_get_client_context(duckdb_copy_function_bind_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + auto wrapper = new duckdb::CClientContextWrapper(info_ref.context); + return reinterpret_cast(wrapper); +} + +idx_t duckdb_copy_function_bind_get_column_count(duckdb_copy_function_bind_info info) { + if (!info) { + return 0; + } + auto &info_ref = *reinterpret_cast(info); + return info_ref.sql_types.size(); +} + +duckdb_logical_type duckdb_copy_function_bind_get_column_type(duckdb_copy_function_bind_info info, idx_t col_idx) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + if (col_idx >= info_ref.sql_types.size()) { + return nullptr; + } + return reinterpret_cast(new duckdb::LogicalType(info_ref.sql_types[col_idx])); +} + +duckdb_value duckdb_copy_function_bind_get_options(duckdb_copy_function_bind_info info) { + if (!info) { + return nullptr; + } + + auto &info_ref = *reinterpret_cast(info); + auto &options = info_ref.input.info.options; + + // return as struct of options + auto options_value = duckdb::MakeValueFromCopyOptions(options); + return reinterpret_cast(new duckdb::Value(options_value)); +} + +void duckdb_copy_function_bind_set_bind_data(duckdb_copy_function_bind_info info, void *bind_data, + duckdb_delete_callback_t destructor) { + if (!info) { + return; + } + auto &info_ref = *reinterpret_cast(info); + + // Store the bind data and destructor + info_ref.bind_data = bind_data; + info_ref.delete_callback = destructor; +} + +//---------------------------------------------------------------------------------------------------------------------- +// Copy To Global Initialize +//---------------------------------------------------------------------------------------------------------------------- +namespace duckdb { +namespace { + +struct CCopyToGlobalState : GlobalFunctionData { + void *global_state = nullptr; + duckdb_delete_callback_t 
delete_callback = nullptr; + + ~CCopyToGlobalState() override { + if (global_state && delete_callback) { + delete_callback(global_state); + } + global_state = nullptr; + delete_callback = nullptr; + } +}; + +struct CCopyToGlobalInitInfo { + CCopyToGlobalInitInfo(ClientContext &context, FunctionData &bind_data, const string &file_path) + : context(context), bind_data(bind_data), file_path(file_path) { + } + + ClientContext &context; + FunctionData &bind_data; + const string &file_path; + + string error; + bool success = true; + + void *global_state = nullptr; + duckdb_delete_callback_t delete_callback = nullptr; +}; + +unique_ptr CCopyToGlobalInit(ClientContext &context, FunctionData &bind_data, + const string &file_path) { + auto &bind_info = bind_data.Cast(); + auto &function_info = bind_info.function_info->Cast(); + + auto result = make_uniq(); + + if (function_info.global_init) { + // Call the user-defined global init function + CCopyToGlobalInitInfo global_init_info(context, bind_data, file_path); + function_info.global_init(reinterpret_cast(&global_init_info)); + + // Pass on user global state to the result + result->global_state = global_init_info.global_state; + result->delete_callback = global_init_info.delete_callback; + + if (!global_init_info.success) { + throw InvalidInputException(global_init_info.error); + } + } + + return std::move(result); +} + +} // namespace +} // namespace duckdb + +void duckdb_copy_function_set_global_init(duckdb_copy_function copy_function, duckdb_copy_function_global_init_t init) { + if (!copy_function || !init) { + return; + } + auto ©_function_ref = *reinterpret_cast(copy_function); + auto &info = copy_function_ref.function_info->Cast(); + + // Set C global init callback + info.global_init = init; +} + +void duckdb_copy_function_global_init_set_error(duckdb_copy_function_global_init_info info, const char *error) { + if (!info || !error) { + return; + } + auto &info_ref = *reinterpret_cast(info); + + // Set the error message + info_ref.error = error; + info_ref.success = false; +} + +void *duckdb_copy_function_global_init_get_extra_info(duckdb_copy_function_global_init_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + return info_ref.bind_data.Cast() + .function_info->Cast() + .extra_info; +} + +duckdb_client_context duckdb_copy_function_global_init_get_client_context(duckdb_copy_function_global_init_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + auto wrapper = new duckdb::CClientContextWrapper(info_ref.context); + return reinterpret_cast(wrapper); +} + +void *duckdb_copy_function_global_init_get_bind_data(duckdb_copy_function_global_init_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + auto &bind_info = info_ref.bind_data.Cast(); + + return bind_info.bind_data; +} + +void duckdb_copy_function_global_init_set_global_state(duckdb_copy_function_global_init_info info, void *global_state, + duckdb_delete_callback_t destructor) { + if (!info) { + return; + } + auto &info_ref = *reinterpret_cast(info); + info_ref.global_state = global_state; + info_ref.delete_callback = destructor; +} + +const char *duckdb_copy_function_global_init_get_file_path(duckdb_copy_function_global_init_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + return info_ref.file_path.c_str(); +} + 
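For illustration, a minimal extension-side sketch of the global-init hooks introduced above. It assumes the duckdb_copy_function_global_init_t callback receives a single duckdb_copy_function_global_init_info handle (the typedef itself is not shown in this patch), and my_copy_global_init / my_close_file are hypothetical names:

#include <stdio.h>
#include "duckdb.h"

/* Destructor for the global state; duckdb_delete_callback_t is void (*)(void *). */
static void my_close_file(void *state) {
    fclose((FILE *)state);
}

/* Hypothetical global-init callback: open the COPY target file and hand it to DuckDB as global state. */
static void my_copy_global_init(duckdb_copy_function_global_init_info info) {
    const char *path = duckdb_copy_function_global_init_get_file_path(info);
    FILE *fp = fopen(path, "w");
    if (!fp) {
        duckdb_copy_function_global_init_set_error(info, "failed to open output file for writing");
        return;
    }
    /* The stored pointer is what duckdb_copy_function_sink_get_global_state later hands to the sink. */
    duckdb_copy_function_global_init_set_global_state(info, fp, my_close_file);
}

/* Registered on the copy function with: duckdb_copy_function_set_global_init(copy_function, my_copy_global_init); */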
+//----------------------------------------------------------------------------------------------------------------------
+// Copy To Local Initialize
+//----------------------------------------------------------------------------------------------------------------------
+namespace duckdb {
+namespace {
+
+unique_ptr CCopyToLocalInit(ExecutionContext &context, FunctionData &bind_data) {
+    // This isn't exposed to the C-API yet, so we just return empty local function data
+    return make_uniq();
+}
+
+} // namespace
+} // namespace duckdb
+//----------------------------------------------------------------------------------------------------------------------
+// Copy To Sink
+//----------------------------------------------------------------------------------------------------------------------
+namespace duckdb {
+namespace {
+
+struct CCopyToSinkInfo {
+    CCopyToSinkInfo(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate)
+        : context(context), bind_data(bind_data), gstate(gstate) {
+    }
+
+    ClientContext &context;
+    FunctionData &bind_data;
+    GlobalFunctionData &gstate;
+    string error;
+    bool success = true;
+};
+
+void CCopyToSink(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate,
+                 LocalFunctionData &lstate, DataChunk &input) {
+    auto &bind_info = bind_data.Cast();
+    auto &function_info = bind_info.function_info->Cast();
+
+    // Flatten the input (we don't support compressed execution yet!)
+    // TODO: Don't flatten!
+    input.Flatten();
+
+    CCopyToSinkInfo copy_to_sink_info(context.client, bind_data, gstate);
+
+    // Sink is required!
+    function_info.sink(reinterpret_cast(&copy_to_sink_info),
+                       reinterpret_cast(&input));
+
+    if (!copy_to_sink_info.success) {
+        throw InvalidInputException(copy_to_sink_info.error);
+    }
+}
+
+} // namespace
+} // namespace duckdb
+
+void duckdb_copy_function_set_sink(duckdb_copy_function copy_function, duckdb_copy_function_sink_t function) {
+    if (!copy_function || !function) {
+        return;
+    }
+    auto &copy_function_ref = *reinterpret_cast(copy_function);
+    auto &info = copy_function_ref.function_info->Cast();
+
+    // Set C sink callback
+    info.sink = function;
+}
+
+void duckdb_copy_function_sink_set_error(duckdb_copy_function_sink_info info, const char *error) {
+    if (!info || !error) {
+        return;
+    }
+    auto &info_ref = *reinterpret_cast(info);
+    // Set the error message
+    info_ref.error = error;
+    info_ref.success = false;
+}
+
+void *duckdb_copy_function_sink_get_extra_info(duckdb_copy_function_sink_info info) {
+    if (!info) {
+        return nullptr;
+    }
+    auto &info_ref = *reinterpret_cast(info);
+    return info_ref.bind_data.Cast()
+        .function_info->Cast()
+        .extra_info;
+}
+
+duckdb_client_context duckdb_copy_function_sink_get_client_context(duckdb_copy_function_sink_info info) {
+    if (!info) {
+        return nullptr;
+    }
+    auto &info_ref = *reinterpret_cast(info);
+    auto wrapper = new duckdb::CClientContextWrapper(info_ref.context);
+    return reinterpret_cast(wrapper);
+}
+
+void *duckdb_copy_function_sink_get_bind_data(duckdb_copy_function_sink_info info) {
+    if (!info) {
+        return nullptr;
+    }
+    auto &info_ref = *reinterpret_cast(info);
+    auto &bind_info = info_ref.bind_data.Cast();
+
+    return bind_info.bind_data;
+}
+
+void *duckdb_copy_function_sink_get_global_state(duckdb_copy_function_sink_info info) {
+    if (!info) {
+        return nullptr;
+    }
+    auto &info_ref = *reinterpret_cast(info);
+    auto &gstate = info_ref.gstate.Cast();
+
+    return gstate.global_state;
+}
+
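A matching sink-callback sketch built on the accessors above; the two-argument shape (info handle plus input chunk) is inferred from how CCopyToSink invokes the user callback, my_copy_sink is a hypothetical name, and only the chunk's dimensions are written to keep the example short:

#include <stdio.h>
#include "duckdb.h"

/* Hypothetical sink callback: fetch the FILE* stored as global state and append one line per chunk. */
static void my_copy_sink(duckdb_copy_function_sink_info info, duckdb_data_chunk chunk) {
    FILE *fp = (FILE *)duckdb_copy_function_sink_get_global_state(info);
    if (!fp) {
        duckdb_copy_function_sink_set_error(info, "sink called without an open output file");
        return;
    }
    idx_t column_count = duckdb_data_chunk_get_column_count(chunk);
    idx_t row_count = duckdb_data_chunk_get_size(chunk);
    if (fprintf(fp, "chunk with %llu columns and %llu rows\n", (unsigned long long)column_count,
                (unsigned long long)row_count) < 0) {
        duckdb_copy_function_sink_set_error(info, "failed to write to output file");
    }
}

/* Registered with: duckdb_copy_function_set_sink(copy_function, my_copy_sink); */

A real writer would iterate the chunk's columns via duckdb_data_chunk_get_vector and the vector accessors instead of only counting rows.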
+//----------------------------------------------------------------------------------------------------------------------
+// Copy To Combine
+//----------------------------------------------------------------------------------------------------------------------
+namespace duckdb {
+namespace {
+
+void CCopyToCombine(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate,
+                    LocalFunctionData &lstate) {
+    // Do nothing for now (this isn't exposed to the C-API yet)
+}
+
+} // namespace
+} // namespace duckdb
+
+//----------------------------------------------------------------------------------------------------------------------
+// Copy To Finalize
+//----------------------------------------------------------------------------------------------------------------------
+namespace duckdb {
+namespace {
+
+struct CCopyToFinalizeInfo {
+    CCopyToFinalizeInfo(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate)
+        : context(context), bind_data(bind_data), gstate(gstate) {
+    }
+
+    ClientContext &context;
+    FunctionData &bind_data;
+    GlobalFunctionData &gstate;
+
+    string error;
+    bool success = true;
+};
+
+void CCopyToFinalize(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate) {
+    auto &bind_info = bind_data.Cast();
+    auto &function_info = bind_info.function_info->Cast();
+
+    // Finalize is optional
+    if (function_info.finalize) {
+        CCopyToFinalizeInfo copy_to_finalize_info(context, bind_data, gstate);
+        function_info.finalize(reinterpret_cast(&copy_to_finalize_info));
+
+        if (!copy_to_finalize_info.success) {
+            throw InvalidInputException(copy_to_finalize_info.error);
+        }
+    }
+}
+
+} // namespace
+} // namespace duckdb
+
+void duckdb_copy_function_set_finalize(duckdb_copy_function copy_function, duckdb_copy_function_finalize_t finalize) {
+    if (!copy_function || !finalize) {
+        return;
+    }
+
+    auto &copy_function_ref = *reinterpret_cast(copy_function);
+    auto &info = copy_function_ref.function_info->Cast();
+
+    // Set C finalize callback
+    info.finalize = finalize;
+}
+
+void duckdb_copy_function_finalize_set_error(duckdb_copy_function_finalize_info info, const char *error) {
+    if (!info || !error) {
+        return;
+    }
+
+    auto &info_ref = *reinterpret_cast(info);
+    // Set the error message
+    info_ref.error = error;
+    info_ref.success = false;
+}
+
+void *duckdb_copy_function_finalize_get_extra_info(duckdb_copy_function_finalize_info info) {
+    if (!info) {
+        return nullptr;
+    }
+
+    auto &info_ref = *reinterpret_cast(info);
+    return info_ref.bind_data.Cast()
+        .function_info->Cast()
+        .extra_info;
+}
+
+duckdb_client_context duckdb_copy_function_finalize_get_client_context(duckdb_copy_function_finalize_info info) {
+    if (!info) {
+        return nullptr;
+    }
+    auto &info_ref = *reinterpret_cast(info);
+    auto wrapper = new duckdb::CClientContextWrapper(info_ref.context);
+    return reinterpret_cast(wrapper);
+}
+
+void *duckdb_copy_function_finalize_get_bind_data(duckdb_copy_function_finalize_info info) {
+    if (!info) {
+        return nullptr;
+    }
+
+    auto &info_ref = *reinterpret_cast(info);
+    auto &bind_info = info_ref.bind_data.Cast();
+    return bind_info.bind_data;
+}
+
+void *duckdb_copy_function_finalize_get_global_state(duckdb_copy_function_finalize_info info) {
+    if (!info) {
+        return nullptr;
+    }
+
+    auto &info_ref = *reinterpret_cast(info);
+    auto &gstate = info_ref.gstate.Cast();
+    return gstate.global_state;
+}
+
+//----------------------------------------------------------------------------------------------------------------------
+// Copy FROM
+//----------------------------------------------------------------------------------------------------------------------
+namespace duckdb {
+namespace {
+
+unique_ptr CCopyFromBind(ClientContext &context, CopyFromFunctionBindInput &info,
+                         vector &expected_names, vector &expected_types) {
+    auto &tf_info = info.tf.function_info->Cast();
+    auto result = make_uniq(tf_info);
+
+    named_parameter_map_t named_parameters;
+
+    // Turn all options into named parameters
+    for (auto opt : info.info.options) {
+        auto param_it = info.tf.named_parameters.find(opt.first);
+        if (param_it == info.tf.named_parameters.end()) {
+            // Option not found in the table function's named parameters
+            throw BinderException("'%s' is not a supported option for copy function '%s'", opt.first.c_str(),
+                                  info.tf.name.c_str());
+        }
+
+        // Try to convert a list of values into a single Value, either by extracting or unifying into a list
+        Value param_value;
+        if (opt.second.empty()) {
+            continue;
+        }
+        if (opt.second.size() == 1) {
+            param_value = opt.second[0];
+        } else {
+            auto first_type = opt.second[0].type();
+            auto is_same_type = true;
+            for (auto &val : opt.second) {
+                if (val.type() != first_type) {
+                    is_same_type = false;
+                    break;
+                }
+            }
+            if (is_same_type) {
+                param_value = Value::LIST(first_type, opt.second);
+            } else {
+                throw BinderException("Cannot pass multiple values of different types for copy option '%s'",
+                                      opt.first.c_str());
+            }
+        }
+
+        // Assign the option as a named parameter
+        named_parameters[opt.first] = param_value;
+    }
+
+    // Also pass file path as a regular parameter
+    vector parameters;
+    parameters.push_back(Value(info.info.file_path));
+
+    // Now bind, using the normal table function bind mechanism
+    CTableInternalBindInfo bind_info(context, parameters, named_parameters, expected_types, expected_names, *result,
+                                     tf_info);
+    tf_info.bind(reinterpret_cast(&bind_info));
+    if (!bind_info.success) {
+        throw BinderException(bind_info.error);
+    }
+
+    return std::move(result);
+}
+
+} // namespace
+} // namespace duckdb
+
+void duckdb_copy_function_set_copy_from_function(duckdb_copy_function copy_function,
+                                                 duckdb_table_function table_function) {
+    if (!copy_function || !table_function) {
+        return;
+    }
+    auto &copy_function_ref = *reinterpret_cast(copy_function);
+    auto &tf = *reinterpret_cast(table_function);
+    auto &tf_info = tf.function_info->Cast();
+
+    if (tf.name.empty()) {
+        // Take the name from the copy function if not set
+        tf.name = copy_function_ref.name;
+    }
+
+    if (!tf_info.bind || !tf_info.init || !tf_info.function) {
+        return;
+    }
+    for (auto it = tf.named_parameters.begin(); it != tf.named_parameters.end(); it++) {
+        if (duckdb::TypeVisitor::Contains(it->second, duckdb::LogicalTypeId::INVALID)) {
+            return;
+        }
+    }
+    for (const auto &argument : tf.arguments) {
+        if (duckdb::TypeVisitor::Contains(argument, duckdb::LogicalTypeId::INVALID)) {
+            return;
+        }
+    }
+
+    // Set the bind callback to mark this as a "copy from" capable function
+    copy_function_ref.copy_from_bind = duckdb::CCopyFromBind;
+    copy_function_ref.copy_from_function = tf;
+}
+
+idx_t duckdb_table_function_bind_get_result_column_count(duckdb_bind_info bind_info) {
+    if (!bind_info) {
+        return 0;
+    }
+    auto &bind_info_ref = *reinterpret_cast(bind_info);
+    return bind_info_ref.return_types.size();
+}
+
+duckdb_logical_type duckdb_table_function_bind_get_result_column_type(duckdb_bind_info bind_info, idx_t col_idx) {
+    if (!bind_info) {
+        return nullptr;
+    }
+    auto &bind_info_ref = *reinterpret_cast(bind_info);
+    if 
(col_idx >= bind_info_ref.return_types.size()) { + return nullptr; + } + return reinterpret_cast(new duckdb::LogicalType(bind_info_ref.return_types[col_idx])); +} + +const char *duckdb_table_function_bind_get_result_column_name(duckdb_bind_info bind_info, idx_t col_idx) { + if (!bind_info) { + return nullptr; + } + auto &bind_info_ref = *reinterpret_cast(bind_info); + if (col_idx >= bind_info_ref.names.size()) { + return nullptr; + } + return bind_info_ref.names[col_idx].c_str(); +} + +//---------------------------------------------------------------------------------------------------------------------- +// Register +//---------------------------------------------------------------------------------------------------------------------- +duckdb_state duckdb_register_copy_function(duckdb_connection connection, duckdb_copy_function copy_function) { + if (!connection || !copy_function) { + return DuckDBError; + } + + auto ©_function_ref = *reinterpret_cast(copy_function); + + // Check that the copy function has a valid name + if (copy_function_ref.name.empty()) { + return DuckDBError; + } + + auto &info = copy_function_ref.function_info->Cast(); + + auto is_copy_to = false; + auto is_copy_from = copy_function_ref.copy_from_bind != nullptr; + + if (info.sink) { + // Set the copy function callbacks + is_copy_to = true; + copy_function_ref.copy_to_bind = duckdb::CCopyToBind; + copy_function_ref.copy_to_initialize_global = duckdb::CCopyToGlobalInit; + copy_function_ref.copy_to_initialize_local = duckdb::CCopyToLocalInit; + copy_function_ref.copy_to_sink = duckdb::CCopyToSink; + copy_function_ref.copy_to_combine = duckdb::CCopyToCombine; + copy_function_ref.copy_to_finalize = duckdb::CCopyToFinalize; + } + + if (!is_copy_to && !is_copy_from) { + // At least one of copy to or copy from must be implemented + return DuckDBError; + } + + auto &conn = *reinterpret_cast(connection); + try { + conn.context->RunFunctionInTransaction([&]() { + auto &catalog = duckdb::Catalog::GetSystemCatalog(*conn.context); + duckdb::CreateCopyFunctionInfo cp_info(copy_function_ref); + cp_info.on_conflict = duckdb::OnCreateConflict::ALTER_ON_CONFLICT; + catalog.CreateCopyFunction(*conn.context, cp_info); + }); + } catch (...) 
{ // LCOV_EXCL_START + return DuckDBError; + } // LCOV_EXCL_STOP + return DuckDBSuccess; +} diff --git a/src/duckdb/src/main/capi/data_chunk-c.cpp b/src/duckdb/src/main/capi/data_chunk-c.cpp index 7274852c4..77f6482ab 100644 --- a/src/duckdb/src/main/capi/data_chunk-c.cpp +++ b/src/duckdb/src/main/capi/data_chunk-c.cpp @@ -167,20 +167,20 @@ idx_t duckdb_list_vector_get_size(duckdb_vector vector) { duckdb_state duckdb_list_vector_set_size(duckdb_vector vector, idx_t size) { if (!vector) { - return duckdb_state::DuckDBError; + return DuckDBError; } auto v = reinterpret_cast(vector); duckdb::ListVector::SetListSize(*v, size); - return duckdb_state::DuckDBSuccess; + return DuckDBSuccess; } duckdb_state duckdb_list_vector_reserve(duckdb_vector vector, idx_t required_capacity) { if (!vector) { - return duckdb_state::DuckDBError; + return DuckDBError; } auto v = reinterpret_cast(vector); duckdb::ListVector::Reserve(*v, required_capacity); - return duckdb_state::DuckDBSuccess; + return DuckDBSuccess; } duckdb_vector duckdb_struct_vector_get_child(duckdb_vector vector, idx_t index) { diff --git a/src/duckdb/src/main/capi/file_system-c.cpp b/src/duckdb/src/main/capi/file_system-c.cpp index af82daa6c..e697c363d 100644 --- a/src/duckdb/src/main/capi/file_system-c.cpp +++ b/src/duckdb/src/main/capi/file_system-c.cpp @@ -3,7 +3,6 @@ namespace duckdb { namespace { struct CFileSystem { - FileSystem &fs; ErrorData error_data; diff --git a/src/duckdb/src/main/capi/prepared-c.cpp b/src/duckdb/src/main/capi/prepared-c.cpp index 28b2f011f..ac5b638f8 100644 --- a/src/duckdb/src/main/capi/prepared-c.cpp +++ b/src/duckdb/src/main/capi/prepared-c.cpp @@ -88,7 +88,13 @@ duckdb_state duckdb_prepare(duckdb_connection connection, const char *query, const char *duckdb_prepare_error(duckdb_prepared_statement prepared_statement) { auto wrapper = reinterpret_cast(prepared_statement); - if (!wrapper || !wrapper->statement || !wrapper->statement->HasError()) { + if (!wrapper) { + return nullptr; + } + if (!wrapper->success) { + return wrapper->error_data.Message().c_str(); + } + if (!wrapper->statement || !wrapper->statement->HasError()) { return nullptr; } return wrapper->statement->error.Message().c_str(); @@ -191,7 +197,7 @@ const char *duckdb_prepared_statement_column_name(duckdb_prepared_statement prep } auto &names = wrapper->statement->GetNames(); - if (col_idx < 0 || col_idx >= names.size()) { + if (col_idx >= names.size()) { return nullptr; } return strdup(names[col_idx].c_str()); @@ -204,7 +210,7 @@ duckdb_logical_type duckdb_prepared_statement_column_logical_type(duckdb_prepare return nullptr; } auto types = wrapper->statement->GetTypes(); - if (col_idx < 0 || col_idx >= types.size()) { + if (col_idx >= types.size()) { return nullptr; } return reinterpret_cast(new LogicalType(types[col_idx])); @@ -229,9 +235,10 @@ duckdb_state duckdb_bind_value(duckdb_prepared_statement prepared_statement, idx return DuckDBError; } if (param_idx <= 0 || param_idx > wrapper->statement->named_param_map.size()) { - wrapper->statement->error = + wrapper->error_data = duckdb::InvalidInputException("Can not bind to parameter number %d, statement only has %d parameter(s)", param_idx, wrapper->statement->named_param_map.size()); + wrapper->success = false; return DuckDBError; } auto identifier = duckdb_parameter_name_internal(prepared_statement, param_idx); diff --git a/src/duckdb/src/main/capi/scalar_function-c.cpp b/src/duckdb/src/main/capi/scalar_function-c.cpp index 7233b2c20..f33759199 100644 --- 
a/src/duckdb/src/main/capi/scalar_function-c.cpp +++ b/src/duckdb/src/main/capi/scalar_function-c.cpp @@ -28,6 +28,7 @@ struct CScalarFunctionInfo : public ScalarFunctionInfo { struct CScalarFunctionBindData : public FunctionData { explicit CScalarFunctionBindData(CScalarFunctionInfo &info) : info(info) { } + ~CScalarFunctionBindData() override { if (bind_data && delete_callback) { delete_callback(bind_data); @@ -45,6 +46,7 @@ struct CScalarFunctionBindData : public FunctionData { } return std::move(copy); } + bool Equals(const FunctionData &other_p) const override { auto &other = other_p.Cast(); return info.extra_info == other.info.extra_info && info.function == other.info.function; @@ -148,7 +150,7 @@ void CAPIScalarFunction(DataChunk &input, ExpressionState &state, Vector &result if (!function_info.success) { throw InvalidInputException(function_info.error); } - if (all_const && (input.size() == 1 || function.function.stability != FunctionStability::VOLATILE)) { + if (all_const && (input.size() == 1 || function.function.GetStability() != FunctionStability::VOLATILE)) { result.SetVectorType(VectorType::CONSTANT_VECTOR); } } @@ -198,7 +200,7 @@ void duckdb_scalar_function_set_special_handling(duckdb_scalar_function function return; } auto &scalar_function = GetCScalarFunction(function); - scalar_function.null_handling = duckdb::FunctionNullHandling::SPECIAL_HANDLING; + scalar_function.SetNullHandling(duckdb::FunctionNullHandling::SPECIAL_HANDLING); } void duckdb_scalar_function_set_volatile(duckdb_scalar_function function) { @@ -206,7 +208,7 @@ void duckdb_scalar_function_set_volatile(duckdb_scalar_function function) { return; } auto &scalar_function = GetCScalarFunction(function); - scalar_function.stability = duckdb::FunctionStability::VOLATILE; + scalar_function.SetVolatile(); } void duckdb_scalar_function_add_parameter(duckdb_scalar_function function, duckdb_logical_type type) { @@ -224,7 +226,7 @@ void duckdb_scalar_function_set_return_type(duckdb_scalar_function function, duc } auto &scalar_function = GetCScalarFunction(function); auto logical_type = reinterpret_cast(type); - scalar_function.return_type = *logical_type; + scalar_function.SetReturnType(*logical_type); } void *duckdb_scalar_function_get_extra_info(duckdb_function_info info) { @@ -390,8 +392,8 @@ duckdb_state duckdb_register_scalar_function_set(duckdb_connection connection, d if (scalar_function.name.empty() || !info.function) { return DuckDBError; } - if (duckdb::TypeVisitor::Contains(scalar_function.return_type, duckdb::LogicalTypeId::INVALID) || - duckdb::TypeVisitor::Contains(scalar_function.return_type, duckdb::LogicalTypeId::ANY)) { + if (duckdb::TypeVisitor::Contains(scalar_function.GetReturnType(), duckdb::LogicalTypeId::INVALID) || + duckdb::TypeVisitor::Contains(scalar_function.GetReturnType(), duckdb::LogicalTypeId::ANY)) { return DuckDBError; } for (const auto &argument : scalar_function.arguments) { diff --git a/src/duckdb/src/main/capi/table_description-c.cpp b/src/duckdb/src/main/capi/table_description-c.cpp index 26624bbfc..cfcd01c43 100644 --- a/src/duckdb/src/main/capi/table_description-c.cpp +++ b/src/duckdb/src/main/capi/table_description-c.cpp @@ -1,5 +1,5 @@ -#include "duckdb/main/capi/capi_internal.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/main/capi/capi_internal.hpp" using duckdb::Connection; using duckdb::ErrorData; @@ -68,14 +68,14 @@ const char *duckdb_table_description_error(duckdb_table_description table) { return wrapper->error.c_str(); } -duckdb_state 
GetTableDescription(TableDescriptionWrapper *wrapper, idx_t index) { +duckdb_state GetTableDescription(TableDescriptionWrapper *wrapper, duckdb::optional_idx index) { if (!wrapper) { return DuckDBError; } auto &table = wrapper->description; - if (index >= table->columns.size()) { - wrapper->error = duckdb::StringUtil::Format("Column index %d is out of range, table only has %d columns", index, - table->columns.size()); + if (index.IsValid() && index.GetIndex() >= table->columns.size()) { + wrapper->error = duckdb::StringUtil::Format("Column index %d is out of range, table only has %d columns", + index.GetIndex(), table->columns.size()); return DuckDBError; } return DuckDBSuccess; @@ -97,6 +97,16 @@ duckdb_state duckdb_column_has_default(duckdb_table_description table_descriptio return DuckDBSuccess; } +idx_t duckdb_table_description_get_column_count(duckdb_table_description table_description) { + auto wrapper = reinterpret_cast(table_description); + if (GetTableDescription(wrapper, duckdb::optional_idx()) == DuckDBError) { + return 0; + } + + auto &table = wrapper->description; + return table->columns.size(); +} + char *duckdb_table_description_get_column_name(duckdb_table_description table_description, idx_t index) { auto wrapper = reinterpret_cast(table_description); if (GetTableDescription(wrapper, index) == DuckDBError) { @@ -113,3 +123,16 @@ char *duckdb_table_description_get_column_name(duckdb_table_description table_de return result; } + +duckdb_logical_type duckdb_table_description_get_column_type(duckdb_table_description table_description, idx_t index) { + auto wrapper = reinterpret_cast(table_description); + if (GetTableDescription(wrapper, index) == DuckDBError) { + return nullptr; + } + + auto &table = wrapper->description; + auto &column = table->columns[index]; + + auto logical_type = new duckdb::LogicalType(column.Type()); + return reinterpret_cast(logical_type); +} diff --git a/src/duckdb/src/main/capi/table_function-c.cpp b/src/duckdb/src/main/capi/table_function-c.cpp index deb382ebd..7a6ab6459 100644 --- a/src/duckdb/src/main/capi/table_function-c.cpp +++ b/src/duckdb/src/main/capi/table_function-c.cpp @@ -3,65 +3,16 @@ #include "duckdb/common/types.hpp" #include "duckdb/function/table_function.hpp" #include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/main/capi/capi_internal_table.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/parser/parsed_data/create_table_function_info.hpp" #include "duckdb/storage/statistics/node_statistics.hpp" namespace duckdb { - +namespace { //===--------------------------------------------------------------------===// // Structures //===--------------------------------------------------------------------===// -struct CTableFunctionInfo : public TableFunctionInfo { - ~CTableFunctionInfo() override { - if (extra_info && delete_callback) { - delete_callback(extra_info); - } - extra_info = nullptr; - delete_callback = nullptr; - } - - duckdb_table_function_bind_t bind = nullptr; - duckdb_table_function_init_t init = nullptr; - duckdb_table_function_init_t local_init = nullptr; - duckdb_table_function_t function = nullptr; - void *extra_info = nullptr; - duckdb_delete_callback_t delete_callback = nullptr; -}; - -struct CTableBindData : public TableFunctionData { - explicit CTableBindData(CTableFunctionInfo &info) : info(info) { - } - ~CTableBindData() override { - if (bind_data && delete_callback) { - delete_callback(bind_data); - } - bind_data = nullptr; - delete_callback = nullptr; - } - - CTableFunctionInfo &info; - void 
*bind_data = nullptr; - duckdb_delete_callback_t delete_callback = nullptr; - unique_ptr stats; -}; - -struct CTableInternalBindInfo { - CTableInternalBindInfo(ClientContext &context, TableFunctionBindInput &input, vector &return_types, - vector &names, CTableBindData &bind_data, CTableFunctionInfo &function_info) - : context(context), input(input), return_types(return_types), names(names), bind_data(bind_data), - function_info(function_info), success(true) { - } - - ClientContext &context; - TableFunctionBindInput &input; - vector &return_types; - vector &names; - CTableBindData &bind_data; - CTableFunctionInfo &function_info; - bool success; - string error; -}; struct CTableInitData { ~CTableInitData() { @@ -160,7 +111,7 @@ unique_ptr CTableFunctionBind(ClientContext &context, TableFunctio D_ASSERT(info.bind && info.function && info.init); auto result = make_uniq(info); - CTableInternalBindInfo bind_info(context, input, return_types, names, *result, info); + CTableInternalBindInfo bind_info(context, input.inputs, input.named_parameters, return_types, names, *result, info); info.bind(ToCTableFunctionBindInfo(bind_info)); if (!bind_info.success) { throw BinderException(bind_info.error); @@ -216,6 +167,7 @@ void CTableFunction(ClientContext &context, TableFunctionInput &data_p, DataChun } } +} // namespace } // namespace duckdb //===--------------------------------------------------------------------===// @@ -398,7 +350,7 @@ idx_t duckdb_bind_get_parameter_count(duckdb_bind_info info) { return 0; } auto &bind_info = GetCTableFunctionBindInfo(info); - return bind_info.input.inputs.size(); + return bind_info.parameters.size(); } duckdb_value duckdb_bind_get_parameter(duckdb_bind_info info, idx_t index) { @@ -406,7 +358,7 @@ duckdb_value duckdb_bind_get_parameter(duckdb_bind_info info, idx_t index) { return nullptr; } auto &bind_info = GetCTableFunctionBindInfo(info); - return reinterpret_cast(new duckdb::Value(bind_info.input.inputs[index])); + return reinterpret_cast(new duckdb::Value(bind_info.parameters[index])); } duckdb_value duckdb_bind_get_named_parameter(duckdb_bind_info info, const char *name) { @@ -414,8 +366,8 @@ duckdb_value duckdb_bind_get_named_parameter(duckdb_bind_info info, const char * return nullptr; } auto &bind_info = GetCTableFunctionBindInfo(info); - auto t = bind_info.input.named_parameters.find(name); - if (t == bind_info.input.named_parameters.end()) { + auto t = bind_info.named_parameters.find(name); + if (t == bind_info.named_parameters.end()) { return nullptr; } else { return reinterpret_cast(new duckdb::Value(t->second)); diff --git a/src/duckdb/src/main/client_config.cpp b/src/duckdb/src/main/client_config.cpp index 868c80730..e8e7d8d86 100644 --- a/src/duckdb/src/main/client_config.cpp +++ b/src/duckdb/src/main/client_config.cpp @@ -8,8 +8,8 @@ bool ClientConfig::AnyVerification() const { return query_verification_enabled || verify_external || verify_serializer || verify_fetch_row; } -void ClientConfig::SetUserVariable(const string &name, Value value) { - user_variables[name] = std::move(value); +void ClientConfig::SetUserVariable(const String &name, Value value) { + user_variables[name.ToStdString()] = std::move(value); } bool ClientConfig::GetUserVariable(const string &name, Value &result) { diff --git a/src/duckdb/src/main/client_context.cpp b/src/duckdb/src/main/client_context.cpp index f52fbabdd..b7433074f 100644 --- a/src/duckdb/src/main/client_context.cpp +++ b/src/duckdb/src/main/client_context.cpp @@ -51,6 +51,7 @@ #include 
"duckdb/logging/log_type.hpp" #include "duckdb/logging/log_manager.hpp" #include "duckdb/main/settings.hpp" +#include "duckdb/main/result_set_manager.hpp" namespace duckdb { @@ -333,7 +334,8 @@ unique_ptr ClientContext::FetchResultInternal(ClientContextLock &lo D_ASSERT(active_query->prepared); auto &executor = GetExecutor(); auto &prepared = *active_query->prepared; - bool create_stream_result = prepared.properties.allow_stream_result && pending.allow_stream_result; + bool create_stream_result = + prepared.properties.output_type == QueryResultOutputType::ALLOW_STREAMING && pending.allow_stream_result; unique_ptr result; D_ASSERT(executor.HasResultCollector()); // we have a result collector - fetch the result directly from the result collector @@ -357,10 +359,10 @@ static bool IsExplainAnalyze(SQLStatement *statement) { return explain.explain_type == ExplainType::EXPLAIN_ANALYZE; } -shared_ptr -ClientContext::CreatePreparedStatementInternal(ClientContextLock &lock, const string &query, - unique_ptr statement, - optional_ptr> values) { +shared_ptr ClientContext::CreatePreparedStatementInternal(ClientContextLock &lock, + const string &query, + unique_ptr statement, + PendingQueryParameters parameters) { StatementType statement_type = statement->type; auto result = make_shared_ptr(statement_type); @@ -368,8 +370,8 @@ ClientContext::CreatePreparedStatementInternal(ClientContextLock &lock, const st profiler.StartQuery(query, IsExplainAnalyze(statement.get()), true); profiler.StartPhase(MetricsType::PLANNER); Planner logical_planner(*this); - if (values) { - auto ¶meter_values = *values; + if (parameters.parameters) { + auto ¶meter_values = *parameters.parameters; for (auto &value : parameter_values) { logical_planner.parameter_data.emplace(value.first, BoundParameterData(value.second)); } @@ -412,10 +414,10 @@ ClientContext::CreatePreparedStatementInternal(ClientContextLock &lock, const st return result; } -shared_ptr -ClientContext::CreatePreparedStatement(ClientContextLock &lock, const string &query, unique_ptr statement, - optional_ptr> values, - PreparedStatementMode mode) { +shared_ptr ClientContext::CreatePreparedStatement(ClientContextLock &lock, const string &query, + unique_ptr statement, + PendingQueryParameters parameters, + PreparedStatementMode mode) { // check if any client context state could request a rebind bool can_request_rebind = false; for (auto &state : registered_state->States()) { @@ -428,7 +430,7 @@ ClientContext::CreatePreparedStatement(ClientContextLock &lock, const string &qu // if any registered state can request a rebind we do the binding on a copy first shared_ptr result; try { - result = CreatePreparedStatementInternal(lock, query, statement->Copy(), values); + result = CreatePreparedStatementInternal(lock, query, statement->Copy(), parameters); } catch (std::exception &ex) { ErrorData error(ex); // check if any registered client context state wants to try a rebind @@ -457,7 +459,7 @@ ClientContext::CreatePreparedStatement(ClientContextLock &lock, const string &qu // an extension wants to do a rebind - do it once } - return CreatePreparedStatementInternal(lock, query, std::move(statement), values); + return CreatePreparedStatementInternal(lock, query, std::move(statement), parameters); } QueryProgress ClientContext::GetQueryProgress() { @@ -483,8 +485,7 @@ void ClientContext::RebindPreparedStatement(ClientContextLock &lock, const strin "an unbound statement so rebinding cannot be done"); } // catalog was modified: rebind the statement before execution - auto 
new_prepared = - CreatePreparedStatement(lock, query, prepared->unbound_statement->Copy(), parameters.parameters); + auto new_prepared = CreatePreparedStatement(lock, query, prepared->unbound_statement->Copy(), parameters); D_ASSERT(new_prepared->properties.bound_all_parameters); new_prepared->properties.parameter_count = prepared->properties.parameter_count; prepared = std::move(new_prepared); @@ -539,7 +540,8 @@ ClientContext::PendingPreparedStatementInternal(ClientContextLock &lock, query_progress.Restart(); } - auto stream_result = parameters.allow_stream_result && statement_data.properties.allow_stream_result; + const auto stream_result = parameters.query_parameters.output_type == QueryResultOutputType::ALLOW_STREAMING && + statement_data.properties.output_type == QueryResultOutputType::ALLOW_STREAMING; // Decide how to get the result collector. get_result_collector_t get_collector = PhysicalResultCollector::GetResultCollector; @@ -547,7 +549,9 @@ ClientContext::PendingPreparedStatementInternal(ClientContextLock &lock, if (!stream_result && client_config.get_result_collector) { get_collector = client_config.get_result_collector; } - statement_data.is_streaming = stream_result; + statement_data.output_type = + stream_result ? QueryResultOutputType::ALLOW_STREAMING : QueryResultOutputType::FORCE_MATERIALIZED; + statement_data.memory_type = parameters.query_parameters.memory_type; // Get the result collector and initialize the executor. auto &collector = get_collector(*this, statement_data); @@ -707,7 +711,8 @@ unique_ptr ClientContext::PrepareInternal(ClientContextLock & shared_ptr prepared_data; auto unbound_statement = statement->Copy(); RunFunctionInTransactionInternal( - lock, [&]() { prepared_data = CreatePreparedStatement(lock, statement_query, std::move(statement)); }, false); + lock, [&]() { prepared_data = CreatePreparedStatement(lock, statement_query, std::move(statement), {}); }, + false); prepared_data->unbound_statement = std::move(unbound_statement); return make_uniq(shared_from_this(), std::move(prepared_data), std::move(statement_query), std::move(named_param_map)); @@ -775,10 +780,10 @@ unique_ptr ClientContext::Execute(const string &query, shared_ptr

ClientContext::Execute(const string &query, shared_ptr &prepared, case_insensitive_map_t &values, - bool allow_stream_result) { + QueryParameters query_parameters) { PendingQueryParameters parameters; parameters.parameters = &values; - parameters.allow_stream_result = allow_stream_result; + parameters.query_parameters = query_parameters; return Execute(query, prepared, parameters); } @@ -790,7 +795,7 @@ unique_ptr ClientContext::PendingStatementInternal(ClientCon PreparedStatement::VerifyParameters(*parameters.parameters, statement->named_param_map); } - auto prepared = CreatePreparedStatement(lock, query, std::move(statement), parameters.parameters, + auto prepared = CreatePreparedStatement(lock, query, std::move(statement), parameters, PreparedStatementMode::PREPARE_AND_EXECUTE); idx_t parameter_count = !parameters.parameters ? 0 : parameters.parameters->size(); @@ -807,13 +812,9 @@ unique_ptr ClientContext::PendingStatementInternal(ClientCon return PendingPreparedStatementInternal(lock, std::move(prepared), parameters); } -unique_ptr -ClientContext::RunStatementInternal(ClientContextLock &lock, const string &query, unique_ptr statement, - bool allow_stream_result, - optional_ptr> params, bool verify) { - PendingQueryParameters parameters; - parameters.allow_stream_result = allow_stream_result; - parameters.parameters = params; +unique_ptr ClientContext::RunStatementInternal(ClientContextLock &lock, const string &query, + unique_ptr statement, + const PendingQueryParameters ¶meters, bool verify) { auto pending = PendingQueryInternal(lock, std::move(statement), parameters, verify); if (pending->HasError()) { return ErrorResult(pending->GetErrorObject()); @@ -846,7 +847,7 @@ unique_ptr ClientContext::PendingStatementOrPreparedStatemen // in case this is a select query, we verify the original statement ErrorData error; try { - error = VerifyQuery(lock, query, std::move(statement), parameters.parameters); + error = VerifyQuery(lock, query, std::move(statement), parameters); } catch (std::exception &ex) { error = ErrorData(ex); } @@ -958,15 +959,15 @@ void ClientContext::LogQueryInternal(ClientContextLock &, const string &query) { client_data->log_query_writer->Sync(); } -unique_ptr ClientContext::Query(unique_ptr statement, bool allow_stream_result) { - auto pending_query = PendingQuery(std::move(statement), allow_stream_result); +unique_ptr ClientContext::Query(unique_ptr statement, QueryParameters parameters) { + auto pending_query = PendingQuery(std::move(statement), parameters); if (pending_query->HasError()) { return ErrorResult(pending_query->GetErrorObject()); } return pending_query->Execute(); } -unique_ptr ClientContext::Query(const string &query, bool allow_stream_result) { +unique_ptr ClientContext::Query(const string &query, QueryParameters query_parameters) { auto lock = LockContext(); vector> statements; @@ -991,7 +992,10 @@ unique_ptr ClientContext::Query(const string &query, bool allow_str auto &statement = statements[i]; bool is_last_statement = i + 1 == statements.size(); PendingQueryParameters parameters; - parameters.allow_stream_result = allow_stream_result && is_last_statement; + parameters.query_parameters = query_parameters; + if (!is_last_statement) { + parameters.query_parameters.output_type = QueryResultOutputType::FORCE_MATERIALIZED; + } auto pending_query = PendingQueryInternal(*lock, std::move(statement), parameters); auto has_result = pending_query->properties.return_type == StatementReturnType::QUERY_RESULT; unique_ptr current_result; @@ -1032,20 +1036,27 @@ 
vector> ClientContext::ParseStatements(ClientContextLoc return ParseStatementsInternal(lock, query); } -unique_ptr ClientContext::PendingQuery(const string &query, bool allow_stream_result) { +unique_ptr ClientContext::PendingQuery(const string &query, QueryParameters parameters) { case_insensitive_map_t empty_param_list; - return PendingQuery(query, empty_param_list, allow_stream_result); + return PendingQuery(query, empty_param_list, parameters); } unique_ptr ClientContext::PendingQuery(unique_ptr statement, - bool allow_stream_result) { + QueryParameters parameters) { case_insensitive_map_t empty_param_list; - return PendingQuery(std::move(statement), empty_param_list, allow_stream_result); + return PendingQuery(std::move(statement), empty_param_list, parameters); } unique_ptr ClientContext::PendingQuery(const string &query, case_insensitive_map_t &values, - bool allow_stream_result) { + QueryParameters parameters) { + PendingQueryParameters params; + params.parameters = values; + params.query_parameters = parameters; + return PendingQuery(query, params); +} + +unique_ptr ClientContext::PendingQuery(const string &query, PendingQueryParameters parameters) { auto lock = LockContext(); try { InitialCleanup(*lock); @@ -1058,11 +1069,7 @@ unique_ptr ClientContext::PendingQuery(const string &query, throw InvalidInputException("Cannot prepare multiple statements at once!"); } - PendingQueryParameters params; - params.allow_stream_result = allow_stream_result; - params.parameters = values; - - return PendingQueryInternal(*lock, std::move(statements[0]), params, true); + return PendingQueryInternal(*lock, std::move(statements[0]), parameters, true); } catch (std::exception &ex) { ErrorData error(ex); ProcessError(error, query); @@ -1072,14 +1079,14 @@ unique_ptr ClientContext::PendingQuery(const string &query, unique_ptr ClientContext::PendingQuery(unique_ptr statement, case_insensitive_map_t &values, - bool allow_stream_result) { + QueryParameters parameters) { auto lock = LockContext(); auto query = statement->query; try { InitialCleanup(*lock); PendingQueryParameters params; - params.allow_stream_result = allow_stream_result; + params.query_parameters = parameters; params.parameters = values; return PendingQueryInternal(*lock, std::move(statement), params, true); @@ -1335,7 +1342,7 @@ unordered_set ClientContext::GetTableNames(const string &query, const bo unique_ptr ClientContext::PendingQueryInternal(ClientContextLock &lock, const shared_ptr &relation, - bool allow_stream_result) { + QueryParameters query_parameters) { InitialCleanup(lock); string query; @@ -1347,20 +1354,23 @@ unique_ptr ClientContext::PendingQueryInternal(ClientContext // verify read only statements by running a select statement auto select = make_uniq(); select->node = relation->GetQueryNode(); - RunStatementInternal(lock, query, std::move(select), false, nullptr); + PendingQueryParameters parameters; + parameters.query_parameters = query_parameters; + parameters.query_parameters.output_type = QueryResultOutputType::FORCE_MATERIALIZED; + RunStatementInternal(lock, query, std::move(select), parameters); } } auto relation_stmt = make_uniq(relation); PendingQueryParameters parameters; - parameters.allow_stream_result = allow_stream_result; + parameters.query_parameters = query_parameters; return PendingQueryInternal(lock, std::move(relation_stmt), parameters); } unique_ptr ClientContext::PendingQuery(const shared_ptr &relation, - bool allow_stream_result) { + QueryParameters query_parameters) { auto lock = LockContext(); - 
return PendingQueryInternal(*lock, relation, allow_stream_result); + return PendingQueryInternal(*lock, relation, query_parameters); } unique_ptr ClientContext::Execute(const shared_ptr &relation) { @@ -1443,6 +1453,7 @@ ParserOptions ClientContext::GetParserOptions() const { options.integer_division = DBConfig::GetSetting(*this); options.max_expression_depth = client_config.max_expression_depth; options.extensions = &DBConfig::GetConfig(*this).parser_extensions; + options.parser_override_setting = DBConfig::GetConfig(*this).options.allow_parser_override_extension; return options; } diff --git a/src/duckdb/src/main/client_data.cpp b/src/duckdb/src/main/client_data.cpp index 1348c0b09..0c63cca16 100644 --- a/src/duckdb/src/main/client_data.cpp +++ b/src/duckdb/src/main/client_data.cpp @@ -56,6 +56,9 @@ class ClientBufferManager : public BufferManager { return buffer_manager.ReAllocate(handle, block_size); } BufferHandle Pin(shared_ptr &handle) override { + return Pin(QueryContext(), handle); + } + BufferHandle Pin(const QueryContext &context, shared_ptr &handle) override { return buffer_manager.Pin(handle); } void Prefetch(vector> &handles) override { @@ -116,6 +119,9 @@ class ClientBufferManager : public BufferManager { return buffer_manager.SetSwapLimit(limit); } + BlockManager &GetTemporaryBlockManager() override { + return buffer_manager.GetTemporaryBlockManager(); + } vector GetTemporaryFiles() override { return buffer_manager.GetTemporaryFiles(); } diff --git a/src/duckdb/src/main/client_verify.cpp b/src/duckdb/src/main/client_verify.cpp index 05b190b07..2287c8b9e 100644 --- a/src/duckdb/src/main/client_verify.cpp +++ b/src/duckdb/src/main/client_verify.cpp @@ -22,7 +22,7 @@ static void ThrowIfExceptionIsInternal(StatementVerifier &verifier) { } ErrorData ClientContext::VerifyQuery(ClientContextLock &lock, const string &query, unique_ptr statement, - optional_ptr> parameters) { + PendingQueryParameters query_parameters) { D_ASSERT(statement->type == StatementType::SELECT_STATEMENT); // Aggressive query verification @@ -32,6 +32,10 @@ ErrorData ClientContext::VerifyQuery(ClientContextLock &lock, const string &quer bool run_slow_verifiers = false; #endif + auto parameters = query_parameters.parameters; + query_parameters.query_parameters.output_type = QueryResultOutputType::FORCE_MATERIALIZED; + query_parameters.query_parameters.memory_type = QueryResultMemoryType::IN_MEMORY; + // The purpose of this function is to test correctness of otherwise hard to test features: // Copy() of statements and expressions // Serialize()/Deserialize() of expressions @@ -98,7 +102,7 @@ ErrorData ClientContext::VerifyQuery(ClientContextLock &lock, const string &quer bool any_failed = original->Run(*this, query, [&](const string &q, unique_ptr s, optional_ptr> params) { - return RunStatementInternal(lock, q, std::move(s), false, params, false); + return RunStatementInternal(lock, q, std::move(s), query_parameters, false); }); if (!any_failed) { statement_verifiers.emplace_back( @@ -109,7 +113,7 @@ ErrorData ClientContext::VerifyQuery(ClientContextLock &lock, const string &quer bool failed = verifier->Run(*this, query, [&](const string &q, unique_ptr s, optional_ptr> params) { - return RunStatementInternal(lock, q, std::move(s), false, params, false); + return RunStatementInternal(lock, q, std::move(s), query_parameters, false); }); any_failed = any_failed || failed; } @@ -120,7 +124,7 @@ ErrorData ClientContext::VerifyQuery(ClientContextLock &lock, const string &quer *this, query, [&](const string &q, 
unique_ptr s, optional_ptr> params) { - return RunStatementInternal(lock, q, std::move(s), false, params, false); + return RunStatementInternal(lock, q, std::move(s), query_parameters, false); }); if (!failed) { // PreparedStatementVerifier fails if it runs into a ParameterNotAllowedException, which is OK @@ -155,7 +159,7 @@ ErrorData ClientContext::VerifyQuery(ClientContextLock &lock, const string &quer *this, explain_q, [&](const string &q, unique_ptr s, optional_ptr> params) { - return RunStatementInternal(lock, q, std::move(s), false, params, false); + return RunStatementInternal(lock, q, std::move(s), query_parameters, false); }); if (explain_failed) { // LCOV_EXCL_START @@ -173,7 +177,8 @@ ErrorData ClientContext::VerifyQuery(ClientContextLock &lock, const string &quer // test with a random width config.max_width = random.NextRandomInteger() % 500; BoxRenderer renderer(config); - renderer.ToString(*this, original->materialized_result->names, original->materialized_result->Collection()); + auto pinned_result_set = original->materialized_result->Pin(); + renderer.ToString(*this, original->materialized_result->names, pinned_result_set->collection); #endif } diff --git a/src/duckdb/src/main/config.cpp b/src/duckdb/src/main/config.cpp index 78b174902..bd34a1cfa 100644 --- a/src/duckdb/src/main/config.cpp +++ b/src/duckdb/src/main/config.cpp @@ -63,6 +63,7 @@ static const ConfigurationOption internal_options[] = { DUCKDB_GLOBAL(AllocatorFlushThresholdSetting), DUCKDB_GLOBAL(AllowCommunityExtensionsSetting), DUCKDB_SETTING(AllowExtensionsMetadataMismatchSetting), + DUCKDB_GLOBAL(AllowParserOverrideExtensionSetting), DUCKDB_GLOBAL(AllowPersistentSecretsSetting), DUCKDB_GLOBAL(AllowUnredactedSecretsSetting), DUCKDB_GLOBAL(AllowUnsignedExtensionsSetting), @@ -85,7 +86,9 @@ static const ConfigurationOption internal_options[] = { DUCKDB_SETTING_CALLBACK(DebugCheckpointAbortSetting), DUCKDB_LOCAL(DebugForceExternalSetting), DUCKDB_SETTING(DebugForceNoCrossProductSetting), + DUCKDB_SETTING_CALLBACK(DebugPhysicalTableScanExecutionStrategySetting), DUCKDB_SETTING(DebugSkipCheckpointOnCommitSetting), + DUCKDB_SETTING(DebugVerifyBlocksSetting), DUCKDB_SETTING_CALLBACK(DebugVerifyVectorSetting), DUCKDB_SETTING_CALLBACK(DebugWindowModeSetting), DUCKDB_GLOBAL(DefaultBlockSizeSetting), @@ -179,12 +182,12 @@ static const ConfigurationOption internal_options[] = { DUCKDB_GLOBAL(ZstdMinStringLengthSetting), FINAL_SETTING}; -static const ConfigurationAlias setting_aliases[] = {DUCKDB_SETTING_ALIAS("memory_limit", 83), - DUCKDB_SETTING_ALIAS("null_order", 33), - DUCKDB_SETTING_ALIAS("profiling_output", 102), - DUCKDB_SETTING_ALIAS("user", 117), - DUCKDB_SETTING_ALIAS("wal_autocheckpoint", 20), - DUCKDB_SETTING_ALIAS("worker_threads", 116), +static const ConfigurationAlias setting_aliases[] = {DUCKDB_SETTING_ALIAS("memory_limit", 86), + DUCKDB_SETTING_ALIAS("null_order", 36), + DUCKDB_SETTING_ALIAS("profiling_output", 105), + DUCKDB_SETTING_ALIAS("user", 120), + DUCKDB_SETTING_ALIAS("wal_autocheckpoint", 21), + DUCKDB_SETTING_ALIAS("worker_threads", 119), FINAL_ALIAS}; vector DBConfig::GetOptions() { @@ -326,9 +329,9 @@ void DBConfig::ResetOption(optional_ptr db, const Configuratio option.reset_global(db.get(), *this); } -void DBConfig::SetOption(const string &name, Value value) { +void DBConfig::SetOption(const String &name, Value value) { lock_guard l(config_lock); - options.set_variables[name] = std::move(value); + options.set_variables[name.ToStdString()] = std::move(value); } void 
DBConfig::ResetOption(const String &name) { @@ -440,8 +443,14 @@ LogicalType DBConfig::ParseLogicalType(const string &type) { return type_id; } +bool DBConfig::HasExtensionOption(const string &name) { + lock_guard l(config_lock); + return extension_parameters.find(name) != extension_parameters.end(); +} + void DBConfig::AddExtensionOption(const string &name, string description, LogicalType parameter, const Value &default_value, set_option_callback_t function, SetScope default_scope) { + lock_guard l(config_lock); extension_parameters.insert(make_pair( name, ExtensionOption(std::move(description), std::move(parameter), function, default_value, default_scope))); // copy over unrecognized options, if they match the new extension option @@ -517,8 +526,7 @@ void DBConfig::CheckLock(const String &name) { return; } // not allowed! - throw InvalidInputException("Cannot change configuration option \"%s\" - the configuration has been locked", - name.ToStdString()); + throw InvalidInputException("Cannot change configuration option \"%s\" - the configuration has been locked", name); } idx_t DBConfig::GetSystemMaxThreads(FileSystem &fs) { diff --git a/src/duckdb/src/main/connection.cpp b/src/duckdb/src/main/connection.cpp index e561a3cb9..b458721c1 100644 --- a/src/duckdb/src/main/connection.cpp +++ b/src/duckdb/src/main/connection.cpp @@ -19,7 +19,7 @@ namespace duckdb { Connection::Connection(DatabaseInstance &database) - : context(make_shared_ptr(database.shared_from_this())), warning_cb(nullptr) { + : context(make_shared_ptr(database.shared_from_this())) { auto &connection_manager = ConnectionManager::Get(database); connection_manager.AddConnection(*context); connection_manager.AssignConnectionId(*this); @@ -31,18 +31,15 @@ Connection::Connection(DatabaseInstance &database) } Connection::Connection(DuckDB &database) : Connection(*database.instance) { - // Initialization of warning_cb happens in the other constructor } -Connection::Connection(Connection &&other) noexcept : warning_cb(nullptr) { +Connection::Connection(Connection &&other) noexcept { std::swap(context, other.context); - std::swap(warning_cb, other.warning_cb); std::swap(connection_id, other.connection_id); } Connection &Connection::operator=(Connection &&other) noexcept { std::swap(context, other.context); - std::swap(warning_cb, other.warning_cb); std::swap(connection_id, other.connection_id); return *this; } @@ -98,40 +95,51 @@ void Connection::ForceParallelism() { ClientConfig::GetConfig(*context).verify_parallelism = true; } -unique_ptr Connection::SendQuery(const string &query) { - return context->Query(query, true); +unique_ptr Connection::SendQuery(const string &query, QueryParameters query_parameters) { + return context->Query(query, query_parameters); +} + +unique_ptr Connection::SendQuery(unique_ptr statement, QueryParameters query_parameters) { + return context->Query(std::move(statement), query_parameters); } unique_ptr Connection::Query(const string &query) { - auto result = context->Query(query, false); + QueryParameters query_parameters; + query_parameters.output_type = QueryResultOutputType::FORCE_MATERIALIZED; + auto result = context->Query(query, query_parameters); D_ASSERT(result->type == QueryResultType::MATERIALIZED_RESULT); return unique_ptr_cast(std::move(result)); } -unique_ptr Connection::Query(unique_ptr statement) { - auto result = context->Query(std::move(statement), false); +unique_ptr Connection::Query(unique_ptr statement, + QueryResultMemoryType memory_type) { + QueryParameters query_parameters; + 
query_parameters.output_type = QueryResultOutputType::FORCE_MATERIALIZED; + query_parameters.memory_type = memory_type; + auto result = context->Query(std::move(statement), query_parameters); D_ASSERT(result->type == QueryResultType::MATERIALIZED_RESULT); return unique_ptr_cast(std::move(result)); } -unique_ptr Connection::PendingQuery(const string &query, bool allow_stream_result) { - return context->PendingQuery(query, allow_stream_result); +unique_ptr Connection::PendingQuery(const string &query, QueryParameters query_parameters) { + return context->PendingQuery(query, query_parameters); } -unique_ptr Connection::PendingQuery(unique_ptr statement, bool allow_stream_result) { - return context->PendingQuery(std::move(statement), allow_stream_result); +unique_ptr Connection::PendingQuery(unique_ptr statement, + QueryParameters query_parameters) { + return context->PendingQuery(std::move(statement), query_parameters); } unique_ptr Connection::PendingQuery(const string &query, case_insensitive_map_t &named_values, - bool allow_stream_result) { - return context->PendingQuery(query, named_values, allow_stream_result); + QueryParameters query_parameters) { + return context->PendingQuery(query, named_values, query_parameters); } unique_ptr Connection::PendingQuery(unique_ptr statement, case_insensitive_map_t &named_values, - bool allow_stream_result) { - return context->PendingQuery(std::move(statement), named_values, allow_stream_result); + QueryParameters query_parameters) { + return context->PendingQuery(std::move(statement), named_values, query_parameters); } static case_insensitive_map_t ConvertParamListToMap(vector ¶m_list) { @@ -144,15 +152,19 @@ static case_insensitive_map_t ConvertParamListToMap(vector Connection::PendingQuery(const string &query, vector &values, - bool allow_stream_result) { + QueryParameters query_parameters) { auto named_params = ConvertParamListToMap(values); - return context->PendingQuery(query, named_params, allow_stream_result); + return context->PendingQuery(query, named_params, query_parameters); } unique_ptr Connection::PendingQuery(unique_ptr statement, vector &values, - bool allow_stream_result) { + QueryParameters query_parameters) { auto named_params = ConvertParamListToMap(values); - return context->PendingQuery(std::move(statement), named_params, allow_stream_result); + return context->PendingQuery(std::move(statement), named_params, query_parameters); +} + +unique_ptr Connection::PendingQuery(const string &query, PendingQueryParameters parameters) { + return context->PendingQuery(query, parameters); } unique_ptr Connection::Prepare(const string &query) { @@ -165,7 +177,11 @@ unique_ptr Connection::Prepare(unique_ptr state unique_ptr Connection::QueryParamsRecursive(const string &query, vector &values) { auto named_params = ConvertParamListToMap(values); - auto pending = PendingQuery(query, named_params, false); + PendingQueryParameters parameters; + parameters.parameters = &named_params; + parameters.query_parameters.output_type = QueryResultOutputType::FORCE_MATERIALIZED; + parameters.query_parameters.memory_type = QueryResultMemoryType::BUFFER_MANAGED; + auto pending = PendingQuery(query, parameters); if (pending->HasError()) { return make_uniq(pending->GetErrorObject()); } diff --git a/src/duckdb/src/main/database.cpp b/src/duckdb/src/main/database.cpp index 3d644d408..bb36e25de 100644 --- a/src/duckdb/src/main/database.cpp +++ b/src/duckdb/src/main/database.cpp @@ -32,6 +32,7 @@ #include "duckdb/common/http_util.hpp" #include "mbedtls_wrapper.hpp" 
#include "duckdb/main/database_file_path_manager.hpp" +#include "duckdb/main/result_set_manager.hpp" #ifndef DUCKDB_NO_THREADS #include "duckdb/common/thread.hpp" @@ -87,6 +88,7 @@ DatabaseInstance::~DatabaseInstance() { log_manager.reset(); external_file_cache.reset(); + result_set_manager.reset(); buffer_manager.reset(); @@ -283,10 +285,11 @@ void DatabaseInstance::Initialize(const char *database_path, DBConfig *user_conf buffer_manager = make_uniq(*this, config.options.temporary_directory); } - log_manager = make_shared_ptr(*this, LogConfig()); + log_manager = make_uniq(*this, LogConfig()); log_manager->Initialize(); external_file_cache = make_uniq(*this, config.options.enable_external_file_cache); + result_set_manager = make_uniq(*this); scheduler = make_uniq(*this); object_cache = make_uniq(); @@ -382,6 +385,10 @@ ExternalFileCache &DatabaseInstance::GetExternalFileCache() { return *external_file_cache; } +ResultSetManager &DatabaseInstance::GetResultSetManager() { + return *result_set_manager; +} + ConnectionManager &DatabaseInstance::GetConnectionManager() { return *connection_manager; } diff --git a/src/duckdb/src/main/database_file_path_manager.cpp b/src/duckdb/src/main/database_file_path_manager.cpp index 05adeadfe..2e107210a 100644 --- a/src/duckdb/src/main/database_file_path_manager.cpp +++ b/src/duckdb/src/main/database_file_path_manager.cpp @@ -5,30 +5,57 @@ namespace duckdb { +DatabasePathInfo::DatabasePathInfo(DatabaseManager &manager, string name_p, AccessMode access_mode) + : name(std::move(name_p)), access_mode(access_mode) { + attached_databases.insert(manager); +} + idx_t DatabaseFilePathManager::ApproxDatabaseCount() const { lock_guard path_lock(db_paths_lock); return db_paths.size(); } -InsertDatabasePathResult DatabaseFilePathManager::InsertDatabasePath(const string &path, const string &name, - OnCreateConflict on_conflict, +InsertDatabasePathResult DatabaseFilePathManager::InsertDatabasePath(DatabaseManager &manager, const string &path, + const string &name, OnCreateConflict on_conflict, AttachOptions &options) { if (path.empty() || path == IN_MEMORY_PATH) { return InsertDatabasePathResult::SUCCESS; } lock_guard path_lock(db_paths_lock); - auto entry = db_paths.emplace(path, DatabasePathInfo(name)); + auto entry = db_paths.emplace(path, DatabasePathInfo(manager, name, options.access_mode)); if (!entry.second) { auto &existing = entry.first->second; + bool already_exists = false; + bool attached_in_this_system = false; if (on_conflict == OnCreateConflict::IGNORE_ON_CONFLICT && existing.name == name) { - return InsertDatabasePathResult::ALREADY_EXISTS; + already_exists = true; + attached_in_this_system = existing.attached_databases.find(manager) != existing.attached_databases.end(); + } + if (options.access_mode == AccessMode::READ_ONLY && existing.access_mode == AccessMode::READ_ONLY) { + if (attached_in_this_system) { + return InsertDatabasePathResult::ALREADY_EXISTS; + } + // all attaches are in read-only mode - there is no conflict, just increase the reference count + existing.attached_databases.insert(manager); + existing.reference_count++; + } else { + if (already_exists) { + if (attached_in_this_system) { + return InsertDatabasePathResult::ALREADY_EXISTS; + } + throw BinderException( + "Unique file handle conflict: Cannot attach \"%s\" - the database file \"%s\" is in " + "the process of being detached", + name, path); + } + throw BinderException( + "Unique file handle conflict: Cannot attach \"%s\" - the database file \"%s\" is already " + "attached by 
database \"%s\"", + name, path, existing.name); } - throw BinderException("Unique file handle conflict: Cannot attach \"%s\" - the database file \"%s\" is already " - "attached by database \"%s\"", - name, path, existing.name); } - options.stored_database_path = make_uniq(*this, path, name); + options.stored_database_path = make_uniq(manager, *this, path, name); return InsertDatabasePathResult::SUCCESS; } @@ -37,7 +64,25 @@ void DatabaseFilePathManager::EraseDatabasePath(const string &path) { return; } lock_guard path_lock(db_paths_lock); - db_paths.erase(path); + auto entry = db_paths.find(path); + if (entry != db_paths.end()) { + if (entry->second.reference_count <= 1) { + db_paths.erase(entry); + } else { + entry->second.reference_count--; + } + } +} + +void DatabaseFilePathManager::DetachDatabase(DatabaseManager &manager, const string &path) { + if (path.empty() || path == IN_MEMORY_PATH) { + return; + } + lock_guard path_lock(db_paths_lock); + auto entry = db_paths.find(path); + if (entry != db_paths.end()) { + entry->second.attached_databases.erase(manager); + } } } // namespace duckdb diff --git a/src/duckdb/src/main/database_manager.cpp b/src/duckdb/src/main/database_manager.cpp index ae0a6447d..c5ac5ac6f 100644 --- a/src/duckdb/src/main/database_manager.cpp +++ b/src/duckdb/src/main/database_manager.cpp @@ -85,28 +85,38 @@ shared_ptr DatabaseManager::GetDatabaseInternal(const lock_gua shared_ptr DatabaseManager::AttachDatabase(ClientContext &context, AttachInfo &info, AttachOptions &options) { if (options.db_type.empty() || StringUtil::CIEquals(options.db_type, "duckdb")) { + // Start timing the ATTACH-delay step. + auto profiler = context.client_data->profiler; + profiler->StartTimer(MetricsType::WAITING_TO_ATTACH_LATENCY); + while (InsertDatabasePath(info, options) == InsertDatabasePathResult::ALREADY_EXISTS) { // database with this name and path already exists // first check if it exists within this transaction auto &meta_transaction = MetaTransaction::Get(context); auto existing_db = meta_transaction.GetReferencedDatabaseOwning(info.name); if (existing_db) { + profiler->EndTimer(MetricsType::WAITING_TO_ATTACH_LATENCY); // it does! return it return existing_db; } + // ... but it might not be done attaching yet! // verify the database has actually finished attaching prior to returning lock_guard guard(databases_lock); auto entry = databases.find(info.name); if (entry != databases.end()) { - // database ACTUALLY exists - return it + // The database ACTUALLY exists, so we return it. 
+ profiler->EndTimer(MetricsType::WAITING_TO_ATTACH_LATENCY); return entry->second; } if (context.interrupted) { + profiler->EndTimer(MetricsType::WAITING_TO_ATTACH_LATENCY); throw InterruptException(); } } + profiler->EndTimer(MetricsType::WAITING_TO_ATTACH_LATENCY); } + auto &config = DBConfig::GetConfig(context); GetDatabaseType(context, info, config, options); if (!options.db_type.empty()) { @@ -270,7 +280,7 @@ idx_t DatabaseManager::ApproxDatabaseCount() { } InsertDatabasePathResult DatabaseManager::InsertDatabasePath(const AttachInfo &info, AttachOptions &options) { - return path_manager->InsertDatabasePath(info.path, info.name, info.on_conflict, options); + return path_manager->InsertDatabasePath(*this, info.path, info.name, info.on_conflict, options); } vector DatabaseManager::GetAttachedDatabasePaths() { @@ -293,7 +303,6 @@ vector DatabaseManager::GetAttachedDatabasePaths() { void DatabaseManager::GetDatabaseType(ClientContext &context, AttachInfo &info, const DBConfig &config, AttachOptions &options) { - // Test if the database is a DuckDB database file. if (StringUtil::CIEquals(options.db_type, "duckdb")) { options.db_type = ""; @@ -303,7 +312,7 @@ void DatabaseManager::GetDatabaseType(ClientContext &context, AttachInfo &info, // Try to extract the database type from the path. if (options.db_type.empty()) { auto &fs = FileSystem::GetFileSystem(context); - DBPathAndType::CheckMagicBytes(QueryContext(context), fs, info.path, options.db_type); + DBPathAndType::CheckMagicBytes(context, fs, info.path, options.db_type); } if (options.db_type.empty()) { diff --git a/src/duckdb/src/main/db_instance_cache.cpp b/src/duckdb/src/main/db_instance_cache.cpp index 57f4ee457..e8faeca69 100644 --- a/src/duckdb/src/main/db_instance_cache.cpp +++ b/src/duckdb/src/main/db_instance_cache.cpp @@ -139,7 +139,6 @@ shared_ptr DBInstanceCache::GetOrCreateInstance(const string &database, const std::function &on_create) { unique_lock lock(cache_lock, std::defer_lock); if (cache_instance) { - // While we do not own the lock, we cannot definitively say that the database instance does not exist. 
while (!lock.owns_lock()) { // The problem is, that we have to unlock the mutex in GetInstanceInternal, so we can non-blockingly wait diff --git a/src/duckdb/src/main/extension.cpp b/src/duckdb/src/main/extension.cpp index cd786d863..c982a4bc3 100644 --- a/src/duckdb/src/main/extension.cpp +++ b/src/duckdb/src/main/extension.cpp @@ -7,6 +7,8 @@ namespace duckdb { +constexpr const idx_t ParsedExtensionMetaData::FOOTER_SIZE; + Extension::~Extension() { } diff --git a/src/duckdb/src/main/extension/extension_helper.cpp b/src/duckdb/src/main/extension/extension_helper.cpp index 74add5379..5f64dc166 100644 --- a/src/duckdb/src/main/extension/extension_helper.cpp +++ b/src/duckdb/src/main/extension/extension_helper.cpp @@ -175,7 +175,6 @@ bool ExtensionHelper::CanAutoloadExtension(const string &ext_name) { string ExtensionHelper::AddExtensionInstallHintToErrorMsg(ClientContext &context, const string &base_error, const string &extension_name) { - return AddExtensionInstallHintToErrorMsg(DatabaseInstance::GetDatabase(context), base_error, extension_name); } string ExtensionHelper::AddExtensionInstallHintToErrorMsg(DatabaseInstance &db, const string &base_error, diff --git a/src/duckdb/src/main/extension/extension_load.cpp b/src/duckdb/src/main/extension/extension_load.cpp index 96e559ec0..eb995e5b0 100644 --- a/src/duckdb/src/main/extension/extension_load.cpp +++ b/src/duckdb/src/main/extension/extension_load.cpp @@ -76,7 +76,7 @@ struct ExtensionAccess { load_state.has_error = true; load_state.error_data = error ? ErrorData(error) - : ErrorData(ExceptionType::UNKNOWN_TYPE, "Extension has indicated an error occured during " + : ErrorData(ExceptionType::UNKNOWN_TYPE, "Extension has indicated an error occurred during " "initialization, but did not set an error message."); } @@ -591,7 +591,7 @@ void ExtensionHelper::LoadExternalExtensionInternal(DatabaseInstance &db, FileSy if (result == false) { throw FatalException( "Extension '%s' failed to initialize but did not return an error. This indicates an " - "error in the extension: C API extensions should return a boolean `true` to indicate succesful " + "error in the extension: C API extensions should return a boolean `true` to indicate successful " "initialization. " "This means that the Extension may be partially initialized resulting in an inconsistent state of " "DuckDB.", diff --git a/src/duckdb/src/main/http/http_util.cpp b/src/duckdb/src/main/http/http_util.cpp index a51fb3e7f..554346489 100644 --- a/src/duckdb/src/main/http/http_util.cpp +++ b/src/duckdb/src/main/http/http_util.cpp @@ -367,7 +367,9 @@ HTTPUtil::RunRequestWithRetry(const std::function(void) try { response = on_request(); - response->url = request.url; + if (response) { + response->url = request.url; + } } catch (IOException &e) { exception_error = e.what(); caught_e = std::current_exception(); diff --git a/src/duckdb/src/main/materialized_query_result.cpp b/src/duckdb/src/main/materialized_query_result.cpp index d319d5686..572f4f3bf 100644 --- a/src/duckdb/src/main/materialized_query_result.cpp +++ b/src/duckdb/src/main/materialized_query_result.cpp @@ -62,6 +62,10 @@ idx_t MaterializedQueryResult::RowCount() const { return collection ? 
collection->Count() : 0; } +bool MaterializedQueryResult::MoreRowsThan(idx_t row_count) { + return RowCount() >= row_count; +} + ColumnDataCollection &MaterializedQueryResult::Collection() { if (HasError()) { throw InvalidInputException("Attempting to get collection from an unsuccessful query result\n: Error %s", @@ -84,11 +88,7 @@ unique_ptr MaterializedQueryResult::TakeCollection() { return std::move(collection); } -unique_ptr MaterializedQueryResult::Fetch() { - return FetchRaw(); -} - -unique_ptr MaterializedQueryResult::FetchRaw() { +unique_ptr MaterializedQueryResult::FetchInternal() { if (HasError()) { throw InvalidInputException("Attempting to fetch from an unsuccessful query result\nError: %s", GetError()); } diff --git a/src/duckdb/src/main/prepared_statement.cpp b/src/duckdb/src/main/prepared_statement.cpp index 34d12cadc..49ff9ac94 100644 --- a/src/duckdb/src/main/prepared_statement.cpp +++ b/src/duckdb/src/main/prepared_statement.cpp @@ -110,7 +110,10 @@ unique_ptr PreparedStatement::PendingQuery(case_insensitive_ } D_ASSERT(data); - parameters.allow_stream_result = allow_stream_result && data->properties.allow_stream_result; + parameters.query_parameters.output_type = + allow_stream_result && data->properties.output_type == QueryResultOutputType::ALLOW_STREAMING + ? QueryResultOutputType::ALLOW_STREAMING + : QueryResultOutputType::FORCE_MATERIALIZED; auto result = context->PendingQuery(query, data, parameters); // The result should not contain any reference to the 'vector parameters.parameters' return result; diff --git a/src/duckdb/src/main/profiling_info.cpp b/src/duckdb/src/main/profiling_info.cpp index 8f744d51b..a1a2ffa10 100644 --- a/src/duckdb/src/main/profiling_info.cpp +++ b/src/duckdb/src/main/profiling_info.cpp @@ -2,6 +2,7 @@ #include "duckdb/common/enum_util.hpp" #include "duckdb/main/query_profiler.hpp" +#include "duckdb/logging/log_manager.hpp" #include "yyjson.hpp" @@ -23,12 +24,12 @@ ProfilingInfo::ProfilingInfo(const profiler_settings_t &n_settings, const idx_t // Reduce. 
if (depth == 0) { - auto op_metrics = DefaultOperatorSettings(); + auto op_metrics = OperatorScopeSettings(); for (const auto metric : op_metrics) { settings.erase(metric); } } else { - auto root_metrics = DefaultRootSettings(); + auto root_metrics = RootScopeSettings(); for (const auto metric : root_metrics) { settings.erase(metric); } @@ -37,32 +38,48 @@ ProfilingInfo::ProfilingInfo(const profiler_settings_t &n_settings, const idx_t } profiler_settings_t ProfilingInfo::DefaultSettings() { - return {MetricsType::QUERY_NAME, + return {MetricsType::ATTACH_LOAD_STORAGE_LATENCY, + MetricsType::ATTACH_REPLAY_WAL_LATENCY, MetricsType::BLOCKED_THREAD_TIME, - MetricsType::SYSTEM_PEAK_BUFFER_MEMORY, - MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE, + MetricsType::CHECKPOINT_LATENCY, MetricsType::CPU_TIME, - MetricsType::EXTRA_INFO, + MetricsType::COMMIT_WRITE_WAL_LATENCY, MetricsType::CUMULATIVE_CARDINALITY, - MetricsType::OPERATOR_NAME, - MetricsType::OPERATOR_TYPE, - MetricsType::OPERATOR_CARDINALITY, MetricsType::CUMULATIVE_ROWS_SCANNED, + MetricsType::EXTRA_INFO, + MetricsType::LATENCY, + MetricsType::OPERATOR_CARDINALITY, + MetricsType::OPERATOR_NAME, MetricsType::OPERATOR_ROWS_SCANNED, MetricsType::OPERATOR_TIMING, + MetricsType::OPERATOR_TYPE, MetricsType::RESULT_SET_SIZE, - MetricsType::LATENCY, MetricsType::ROWS_RETURNED, + MetricsType::SYSTEM_PEAK_BUFFER_MEMORY, + MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE, MetricsType::TOTAL_BYTES_READ, - MetricsType::TOTAL_BYTES_WRITTEN}; + MetricsType::TOTAL_BYTES_WRITTEN, + MetricsType::WAITING_TO_ATTACH_LATENCY, + MetricsType::WAL_REPLAY_ENTRY_COUNT, + MetricsType::QUERY_NAME}; } -profiler_settings_t ProfilingInfo::DefaultRootSettings() { - return {MetricsType::QUERY_NAME, MetricsType::BLOCKED_THREAD_TIME, MetricsType::LATENCY, - MetricsType::ROWS_RETURNED}; +profiler_settings_t ProfilingInfo::RootScopeSettings() { + return {MetricsType::ATTACH_LOAD_STORAGE_LATENCY, + MetricsType::ATTACH_REPLAY_WAL_LATENCY, + MetricsType::BLOCKED_THREAD_TIME, + MetricsType::CHECKPOINT_LATENCY, + MetricsType::COMMIT_WRITE_WAL_LATENCY, + MetricsType::LATENCY, + MetricsType::ROWS_RETURNED, + MetricsType::TOTAL_BYTES_READ, + MetricsType::TOTAL_BYTES_WRITTEN, + MetricsType::WAITING_TO_ATTACH_LATENCY, + MetricsType::WAL_REPLAY_ENTRY_COUNT, + MetricsType::QUERY_NAME}; } -profiler_settings_t ProfilingInfo::DefaultOperatorSettings() { +profiler_settings_t ProfilingInfo::OperatorScopeSettings() { return {MetricsType::OPERATOR_CARDINALITY, MetricsType::OPERATOR_ROWS_SCANNED, MetricsType::OPERATOR_TIMING, MetricsType::OPERATOR_NAME, MetricsType::OPERATOR_TYPE}; } @@ -83,6 +100,11 @@ void ProfilingInfo::ResetMetrics() { case MetricsType::BLOCKED_THREAD_TIME: case MetricsType::CPU_TIME: case MetricsType::OPERATOR_TIMING: + case MetricsType::WAITING_TO_ATTACH_LATENCY: + case MetricsType::ATTACH_LOAD_STORAGE_LATENCY: + case MetricsType::ATTACH_REPLAY_WAL_LATENCY: + case MetricsType::CHECKPOINT_LATENCY: + case MetricsType::COMMIT_WRITE_WAL_LATENCY: metrics[metric] = Value::CreateValue(0.0); break; case MetricsType::OPERATOR_NAME: @@ -101,9 +123,11 @@ void ProfilingInfo::ResetMetrics() { case MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE: case MetricsType::TOTAL_BYTES_READ: case MetricsType::TOTAL_BYTES_WRITTEN: + case MetricsType::WAL_REPLAY_ENTRY_COUNT: metrics[metric] = Value::CreateValue(0); break; case MetricsType::EXTRA_INFO: + metrics[metric] = Value::MAP(InsertionOrderPreservingMap()); break; default: throw InternalException("MetricsType" + EnumUtil::ToString(metric) + "not implemented"); @@ 
-149,26 +173,25 @@ string ProfilingInfo::GetMetricAsString(const MetricsType metric) const { throw InternalException("Metric %s not enabled", EnumUtil::ToString(metric)); } - if (metric == MetricsType::EXTRA_INFO) { - string result; - for (auto &it : extra_info) { - if (!result.empty()) { - result += ", "; - } - result += StringUtil::Format("%s: %s", it.first, it.second); - } - return "\"" + result + "\""; - } - // The metric cannot be NULL and must be initialized. D_ASSERT(!metrics.at(metric).IsNull()); if (metric == MetricsType::OPERATOR_TYPE) { - auto type = PhysicalOperatorType(metrics.at(metric).GetValue()); + const auto type = PhysicalOperatorType(metrics.at(metric).GetValue()); return EnumUtil::ToString(type); } return metrics.at(metric).ToString(); } +void ProfilingInfo::WriteMetricsToLog(ClientContext &context) { + auto &logger = Logger::Get(context); + if (logger.ShouldLog(MetricsLogType::NAME, MetricsLogType::LEVEL)) { + for (auto &metric : settings) { + logger.WriteLog(MetricsLogType::NAME, MetricsLogType::LEVEL, + MetricsLogType::ConstructLogMessage(metric, metrics[metric])); + } + } +} + void ProfilingInfo::WriteMetricsToJSON(yyjson_mut_doc *doc, yyjson_mut_val *dest) { for (auto &metric : settings) { auto metric_str = StringUtil::Lower(EnumUtil::ToString(metric)); @@ -178,18 +201,25 @@ void ProfilingInfo::WriteMetricsToJSON(yyjson_mut_doc *doc, yyjson_mut_val *dest if (metric == MetricsType::EXTRA_INFO) { auto extra_info_obj = yyjson_mut_obj(doc); - for (auto &it : extra_info) { - auto &key = it.first; - auto &value = it.second; - auto splits = StringUtil::Split(value, "\n"); + auto extra_info = metrics.at(metric); + auto children = MapValue::GetChildren(extra_info); + for (auto &child : children) { + auto struct_children = StructValue::GetChildren(child); + auto key = struct_children[0].GetValue(); + auto value = struct_children[1].GetValue(); + + auto key_mut = unsafe_yyjson_mut_strncpy(doc, key.c_str(), key.size()); + auto value_mut = unsafe_yyjson_mut_strncpy(doc, value.c_str(), value.size()); + + auto splits = StringUtil::Split(value_mut, "\n"); if (splits.size() > 1) { auto list_items = yyjson_mut_arr(doc); for (auto &split : splits) { yyjson_mut_arr_add_strcpy(doc, list_items, split.c_str()); } - yyjson_mut_obj_add_val(doc, extra_info_obj, key.c_str(), list_items); + yyjson_mut_obj_add_val(doc, extra_info_obj, key_mut, list_items); } else { - yyjson_mut_obj_add_strcpy(doc, extra_info_obj, key.c_str(), value.c_str()); + yyjson_mut_obj_add_strcpy(doc, extra_info_obj, key_mut, value_mut); } } yyjson_mut_obj_add_val(doc, dest, key_ptr, extra_info_obj); @@ -212,7 +242,12 @@ void ProfilingInfo::WriteMetricsToJSON(yyjson_mut_doc *doc, yyjson_mut_val *dest case MetricsType::LATENCY: case MetricsType::BLOCKED_THREAD_TIME: case MetricsType::CPU_TIME: - case MetricsType::OPERATOR_TIMING: { + case MetricsType::OPERATOR_TIMING: + case MetricsType::WAITING_TO_ATTACH_LATENCY: + case MetricsType::ATTACH_LOAD_STORAGE_LATENCY: + case MetricsType::ATTACH_REPLAY_WAL_LATENCY: + case MetricsType::COMMIT_WRITE_WAL_LATENCY: + case MetricsType::CHECKPOINT_LATENCY: { yyjson_mut_obj_add_real(doc, dest, key_ptr, metrics[metric].GetValue()); break; } @@ -228,6 +263,7 @@ void ProfilingInfo::WriteMetricsToJSON(yyjson_mut_doc *doc, yyjson_mut_val *dest case MetricsType::OPERATOR_ROWS_SCANNED: case MetricsType::SYSTEM_PEAK_BUFFER_MEMORY: case MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE: + case MetricsType::WAL_REPLAY_ENTRY_COUNT: case MetricsType::TOTAL_BYTES_READ: case MetricsType::TOTAL_BYTES_WRITTEN: { 
yyjson_mut_obj_add_uint(doc, dest, key_ptr, metrics[metric].GetValue()); diff --git a/src/duckdb/src/main/query_profiler.cpp b/src/duckdb/src/main/query_profiler.cpp index 4c9c9328a..77df96938 100644 --- a/src/duckdb/src/main/query_profiler.cpp +++ b/src/duckdb/src/main/query_profiler.cpp @@ -16,6 +16,7 @@ #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/storage/buffer/buffer_pool.hpp" #include "yyjson.hpp" +#include "yyjson_utils.hpp" #include #include @@ -52,6 +53,8 @@ ProfilerPrintFormat QueryProfiler::GetPrintFormat(ExplainFormat format) const { return ProfilerPrintFormat::HTML; case ExplainFormat::GRAPHVIZ: return ProfilerPrintFormat::GRAPHVIZ; + case ExplainFormat::MERMAID: + return ProfilerPrintFormat::MERMAID; default: throw NotImplementedException("No mapping from ExplainFormat::%s to ProfilerPrintFormat", EnumUtil::ToString(format)); @@ -69,6 +72,8 @@ ExplainFormat QueryProfiler::GetExplainFormat(ProfilerPrintFormat format) const return ExplainFormat::HTML; case ProfilerPrintFormat::GRAPHVIZ: return ExplainFormat::GRAPHVIZ; + case ProfilerPrintFormat::MERMAID: + return ExplainFormat::MERMAID; case ProfilerPrintFormat::NO_OUTPUT: throw InternalException("Should not attempt to get ExplainFormat for ProfilerPrintFormat::NO_OUTPUT"); default: @@ -102,9 +107,7 @@ void QueryProfiler::Reset() { phase_timings.clear(); phase_stack.clear(); running = false; - query_metrics.query = ""; - query_metrics.total_bytes_read = 0; - query_metrics.total_bytes_written = 0; + query_metrics.Reset(); } void QueryProfiler::StartQuery(const string &query, bool is_explain_analyze_p, bool start_at_optimizer) { @@ -180,7 +183,6 @@ void QueryProfiler::Finalize(ProfilingNode &node) { auto type = PhysicalOperatorType(info.GetMetricValue(MetricsType::OPERATOR_TYPE)); if (type == PhysicalOperatorType::UNION && info.Enabled(info.expanded_settings, MetricsType::OPERATOR_CARDINALITY)) { - auto &child_info = child->GetProfilingInfo(); auto value = child_info.metrics[MetricsType::OPERATOR_CARDINALITY].GetValue(); info.MetricSum(MetricsType::OPERATOR_CARDINALITY, value); @@ -281,6 +283,28 @@ void QueryProfiler::EndQuery() { if (info.Enabled(settings, MetricsType::RESULT_SET_SIZE)) { info.metrics[MetricsType::RESULT_SET_SIZE] = child_info.metrics[MetricsType::RESULT_SET_SIZE]; } + if (info.Enabled(settings, MetricsType::WAITING_TO_ATTACH_LATENCY)) { + info.metrics[MetricsType::WAITING_TO_ATTACH_LATENCY] = + query_metrics.waiting_to_attach_latency.Elapsed(); + } + if (info.Enabled(settings, MetricsType::ATTACH_LOAD_STORAGE_LATENCY)) { + info.metrics[MetricsType::ATTACH_LOAD_STORAGE_LATENCY] = + query_metrics.attach_load_storage_latency.Elapsed(); + } + if (info.Enabled(settings, MetricsType::ATTACH_REPLAY_WAL_LATENCY)) { + info.metrics[MetricsType::ATTACH_REPLAY_WAL_LATENCY] = + query_metrics.attach_replay_wal_latency.Elapsed(); + } + if (info.Enabled(settings, MetricsType::COMMIT_WRITE_WAL_LATENCY)) { + info.metrics[MetricsType::COMMIT_WRITE_WAL_LATENCY] = query_metrics.commit_write_wal_latency.Elapsed(); + } + if (info.Enabled(settings, MetricsType::WAL_REPLAY_ENTRY_COUNT)) { + info.metrics[MetricsType::WAL_REPLAY_ENTRY_COUNT] = + Value::UBIGINT(query_metrics.wal_replay_entry_count); + } + if (info.Enabled(settings, MetricsType::CHECKPOINT_LATENCY)) { + info.metrics[MetricsType::CHECKPOINT_LATENCY] = query_metrics.checkpoint_latency.Elapsed(); + } MoveOptimizerPhasesToRoot(); if (info.Enabled(settings, MetricsType::CUMULATIVE_OPTIMIZER_TIMING)) { @@ -297,6 +321,9 @@ void 
QueryProfiler::EndQuery() { guard.unlock(); + // To log is inexpensive, whether to log or not depends on whether logging is active + ToLog(); + if (emit_output) { string tree = ToString(); auto save_location = GetSaveLocation(); @@ -310,15 +337,80 @@ void QueryProfiler::EndQuery() { } } -void QueryProfiler::AddBytesRead(const idx_t nr_bytes) { - if (IsEnabled()) { - query_metrics.total_bytes_read += nr_bytes; +void QueryProfiler::AddToCounter(const MetricsType type, const idx_t amount) { + if (!IsEnabled()) { + return; + } + + switch (type) { + case MetricsType::TOTAL_BYTES_READ: + query_metrics.total_bytes_read += amount; + return; + case MetricsType::TOTAL_BYTES_WRITTEN: + query_metrics.total_bytes_written += amount; + return; + case MetricsType::WAL_REPLAY_ENTRY_COUNT: + query_metrics.wal_replay_entry_count += amount; + return; + default: + return; } } -void QueryProfiler::AddBytesWritten(const idx_t nr_bytes) { - if (IsEnabled()) { - query_metrics.total_bytes_written += nr_bytes; +idx_t QueryProfiler::GetBytesRead() const { + return query_metrics.total_bytes_read; +} + +idx_t QueryProfiler::GetBytesWritten() const { + return query_metrics.total_bytes_written; +} + +void QueryProfiler::StartTimer(const MetricsType type) { + if (!IsEnabled()) { + return; + } + + switch (type) { + case MetricsType::WAITING_TO_ATTACH_LATENCY: + query_metrics.waiting_to_attach_latency.Start(); + return; + case MetricsType::ATTACH_LOAD_STORAGE_LATENCY: + query_metrics.attach_load_storage_latency.Start(); + return; + case MetricsType::ATTACH_REPLAY_WAL_LATENCY: + query_metrics.attach_replay_wal_latency.Start(); + return; + case MetricsType::CHECKPOINT_LATENCY: + query_metrics.checkpoint_latency.Start(); + return; + case MetricsType::COMMIT_WRITE_WAL_LATENCY: + query_metrics.commit_write_wal_latency.Start(); + return; + default: + return; + } +} + +void QueryProfiler::EndTimer(MetricsType type) { + if (!IsEnabled()) { + return; + } + + switch (type) { + case MetricsType::WAITING_TO_ATTACH_LATENCY: + query_metrics.waiting_to_attach_latency.End(); + return; + case MetricsType::ATTACH_LOAD_STORAGE_LATENCY: + query_metrics.attach_load_storage_latency.End(); + return; + case MetricsType::ATTACH_REPLAY_WAL_LATENCY: + query_metrics.attach_replay_wal_latency.End(); + return; + case MetricsType::CHECKPOINT_LATENCY: + query_metrics.checkpoint_latency.End(); + return; + default: + return; } } @@ -339,7 +431,8 @@ string QueryProfiler::ToString(ProfilerPrintFormat format) const { case ProfilerPrintFormat::NO_OUTPUT: return ""; case ProfilerPrintFormat::HTML: - case ProfilerPrintFormat::GRAPHVIZ: { + case ProfilerPrintFormat::GRAPHVIZ: + case ProfilerPrintFormat::MERMAID: { lock_guard guard(lock); // checking the tree to ensure the query is really empty // the query string is empty when a logical plan is deserialized @@ -404,7 +497,7 @@ OperatorProfiler::OperatorProfiler(ClientContext &context) : context(context) { } // Reduce. 
- auto root_metrics = ProfilingInfo::DefaultRootSettings(); + auto root_metrics = ProfilingInfo::RootScopeSettings(); for (const auto metric : root_metrics) { settings.erase(metric); } @@ -557,7 +650,7 @@ void QueryProfiler::Flush(OperatorProfiler &profiler) { info.MetricSum(MetricsType::RESULT_SET_SIZE, node.second.result_set_size); } if (ProfilingInfo::Enabled(profiler.settings, MetricsType::EXTRA_INFO)) { - info.extra_info = node.second.extra_info; + info.metrics[MetricsType::EXTRA_INFO] = Value::MAP(node.second.extra_info); } if (ProfilingInfo::Enabled(profiler.settings, MetricsType::SYSTEM_PEAK_BUFFER_MEMORY)) { query_metrics.query_global_info.MetricMax(MetricsType::SYSTEM_PEAK_BUFFER_MEMORY, @@ -571,7 +664,7 @@ void QueryProfiler::Flush(OperatorProfiler &profiler) { profiler.operator_infos.clear(); } -void QueryProfiler::SetInfo(const double &blocked_thread_time) { +void QueryProfiler::SetBlockedTime(const double &blocked_thread_time) { lock_guard guard(lock); if (!IsEnabled() || !running) { return; @@ -721,18 +814,24 @@ void QueryProfiler::QueryTreeToStream(std::ostream &ss) const { } } -InsertionOrderPreservingMap QueryProfiler::JSONSanitize(const InsertionOrderPreservingMap &input) { +Value QueryProfiler::JSONSanitize(const Value &input) { + D_ASSERT(input.type().id() == LogicalTypeId::MAP); + InsertionOrderPreservingMap result; - for (auto &it : input) { - auto key = it.first; + auto children = MapValue::GetChildren(input); + for (auto &child : children) { + auto struct_children = StructValue::GetChildren(child); + auto key = struct_children[0].GetValue(); + auto value = struct_children[1].GetValue(); + if (StringUtil::StartsWith(key, "__")) { key = StringUtil::Replace(key, "__", ""); key = StringUtil::Replace(key, "_", " "); key = StringUtil::Title(key); } - result[key] = it.second; + result[key] = value; } - return result; + return Value::MAP(result); } string QueryProfiler::JSONSanitize(const std::string &text) { @@ -772,7 +871,12 @@ string QueryProfiler::JSONSanitize(const std::string &text) { static yyjson_mut_val *ToJSONRecursive(yyjson_mut_doc *doc, ProfilingNode &node) { auto result_obj = yyjson_mut_obj(doc); auto &profiling_info = node.GetProfilingInfo(); - profiling_info.extra_info = QueryProfiler::JSONSanitize(profiling_info.extra_info); + + if (profiling_info.Enabled(profiling_info.settings, MetricsType::EXTRA_INFO)) { + profiling_info.metrics[MetricsType::EXTRA_INFO] = + QueryProfiler::JSONSanitize(profiling_info.metrics.at(MetricsType::EXTRA_INFO)); + } + profiling_info.WriteMetricsToJSON(doc, result_obj); auto children_list = yyjson_mut_arr(doc); @@ -784,44 +888,56 @@ static yyjson_mut_val *ToJSONRecursive(yyjson_mut_doc *doc, ProfilingNode &node) return result_obj; } -static string StringifyAndFree(yyjson_mut_doc *doc, yyjson_mut_val *object) { - auto data = yyjson_mut_val_write_opts(object, YYJSON_WRITE_ALLOW_INF_AND_NAN | YYJSON_WRITE_PRETTY, nullptr, - nullptr, nullptr); - if (!data) { - yyjson_mut_doc_free(doc); +static string StringifyAndFree(ConvertedJSONHolder &json_holder, yyjson_mut_val *object) { + json_holder.stringified_json = yyjson_mut_val_write_opts( + object, YYJSON_WRITE_ALLOW_INF_AND_NAN | YYJSON_WRITE_PRETTY, nullptr, nullptr, nullptr); + if (!json_holder.stringified_json) { throw InternalException("The plan could not be rendered as JSON, yyjson failed"); } - auto result = string(data); - free(data); - yyjson_mut_doc_free(doc); + auto result = string(json_holder.stringified_json); return result; } +void QueryProfiler::ToLog() const { + 
lock_guard guard(lock); + + if (!root) { + // No root, not much to do + return; + } + + auto &settings = root->GetProfilingInfo(); + + settings.WriteMetricsToLog(context); +} + string QueryProfiler::ToJSON() const { lock_guard guard(lock); - auto doc = yyjson_mut_doc_new(nullptr); - auto result_obj = yyjson_mut_obj(doc); - yyjson_mut_doc_set_root(doc, result_obj); + ConvertedJSONHolder json_holder; + + json_holder.doc = yyjson_mut_doc_new(nullptr); + auto result_obj = yyjson_mut_obj(json_holder.doc); + yyjson_mut_doc_set_root(json_holder.doc, result_obj); if (query_metrics.query.empty() && !root) { - yyjson_mut_obj_add_str(doc, result_obj, "result", "empty"); - return StringifyAndFree(doc, result_obj); + yyjson_mut_obj_add_str(json_holder.doc, result_obj, "result", "empty"); + return StringifyAndFree(json_holder, result_obj); } if (!root) { - yyjson_mut_obj_add_str(doc, result_obj, "result", "error"); - return StringifyAndFree(doc, result_obj); + yyjson_mut_obj_add_str(json_holder.doc, result_obj, "result", "error"); + return StringifyAndFree(json_holder, result_obj); } auto &settings = root->GetProfilingInfo(); - settings.WriteMetricsToJSON(doc, result_obj); + settings.WriteMetricsToJSON(json_holder.doc, result_obj); // recursively print the physical operator tree - auto children_list = yyjson_mut_arr(doc); - yyjson_mut_obj_add_val(doc, result_obj, "children", children_list); - auto child = ToJSONRecursive(doc, *root->GetChild(0)); + auto children_list = yyjson_mut_arr(json_holder.doc); + yyjson_mut_obj_add_val(json_holder.doc, result_obj, "children", children_list); + auto child = ToJSONRecursive(json_holder.doc, *root->GetChild(0)); yyjson_mut_arr_add_val(children_list, child); - return StringifyAndFree(doc, result_obj); + return StringifyAndFree(json_holder, result_obj); } void QueryProfiler::WriteToFile(const char *path, string &info) const { @@ -871,7 +987,7 @@ unique_ptr QueryProfiler::CreateTree(const PhysicalOperator &root info.MetricSum(MetricsType::OPERATOR_TYPE, static_cast(root_p.type)); } if (info.Enabled(info.settings, MetricsType::EXTRA_INFO)) { - info.extra_info = root_p.ParamsToString(); + info.metrics[MetricsType::EXTRA_INFO] = Value::MAP(root_p.ParamsToString()); } tree_map.insert(make_pair(reference(root_p), reference(*node))); @@ -904,13 +1020,20 @@ string QueryProfiler::RenderDisabledMessage(ProfilerPrintFormat format) const { node_0_0 [label="Query profiling is disabled. Use 'PRAGMA enable_profiling;' to enable profiling!"]; } )"; + case ProfilerPrintFormat::MERMAID: + return R"(flowchart TD + node_0_0["`**DISABLED** +Query profiling is disabled. 
+Use 'PRAGMA enable_profiling;' to enable profiling!`"] +)"; case ProfilerPrintFormat::JSON: { - auto doc = yyjson_mut_doc_new(nullptr); - auto result_obj = yyjson_mut_obj(doc); - yyjson_mut_doc_set_root(doc, result_obj); + ConvertedJSONHolder json_holder; + json_holder.doc = yyjson_mut_doc_new(nullptr); + auto result_obj = yyjson_mut_obj(json_holder.doc); + yyjson_mut_doc_set_root(json_holder.doc, result_obj); - yyjson_mut_obj_add_str(doc, result_obj, "result", "disabled"); - return StringifyAndFree(doc, result_obj); + yyjson_mut_obj_add_str(json_holder.doc, result_obj, "result", "disabled"); + return StringifyAndFree(json_holder, result_obj); } default: throw InternalException("Unknown ProfilerPrintFormat \"%s\"", EnumUtil::ToString(format)); @@ -962,7 +1085,4 @@ void QueryProfiler::MoveOptimizerPhasesToRoot() { } } -void QueryProfiler::Propagate(QueryProfiler &) { -} - } // namespace duckdb diff --git a/src/duckdb/src/main/query_result.cpp b/src/duckdb/src/main/query_result.cpp index a20f9a87a..b3220e412 100644 --- a/src/duckdb/src/main/query_result.cpp +++ b/src/duckdb/src/main/query_result.cpp @@ -41,7 +41,7 @@ const ExceptionType &BaseQueryResult::GetErrorType() const { return error.Type(); } -const std::string &BaseQueryResult::GetError() { +const std::string &BaseQueryResult::GetError() const { D_ASSERT(HasError()); return error.Message(); } @@ -110,6 +110,42 @@ unique_ptr QueryResult::Fetch() { return chunk; } +unique_ptr QueryResult::FetchRaw() { + if (!stored_chunks.empty()) { + auto result = std::move(stored_chunks.back()); + stored_chunks.pop_back(); + return result; + } + if (result_exhausted) { + return nullptr; + } + return FetchInternal(); +} + +bool QueryResult::MoreRowsThan(idx_t row_count) { + // fetch chunks until we have seen more than "row_count" - OR the result is exhausted + // store any fetched chunks in "stored_chunks" - we return these again in "Fetch" upon request + idx_t result_row_count = 0; + if (!stored_chunks.empty()) { + std::reverse(stored_chunks.begin(), stored_chunks.end()); + for (auto &chunk : stored_chunks) { + result_row_count += chunk->size(); + } + } + while (result_row_count < row_count) { + auto chunk = FetchInternal(); + if (!chunk) { + // exhausted result + result_exhausted = true; + break; + } + result_row_count += chunk->size(); + stored_chunks.push_back(std::move(chunk)); + } + std::reverse(stored_chunks.begin(), stored_chunks.end()); + return result_row_count >= row_count; +} + bool QueryResult::Equals(QueryResult &other) { // LCOV_EXCL_START // first compare the success state of the results if (success != other.success) { diff --git a/src/duckdb/src/main/relation.cpp b/src/duckdb/src/main/relation.cpp index 9a28349e7..b9e4d50ff 100644 --- a/src/duckdb/src/main/relation.cpp +++ b/src/duckdb/src/main/relation.cpp @@ -394,8 +394,8 @@ string Relation::ToString() { } // LCOV_EXCL_START -unique_ptr Relation::GetQueryNode() { - throw InternalException("Cannot create a query node from this node type"); +string Relation::GetQuery() { + return GetQueryNode()->ToString(); } void Relation::Head(idx_t limit) { diff --git a/src/duckdb/src/main/relation/create_table_relation.cpp b/src/duckdb/src/main/relation/create_table_relation.cpp index 2492f244b..39aa65e36 100644 --- a/src/duckdb/src/main/relation/create_table_relation.cpp +++ b/src/duckdb/src/main/relation/create_table_relation.cpp @@ -29,6 +29,14 @@ BoundStatement CreateTableRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr CreateTableRelation::GetQueryNode() { + 
throw InternalException("Cannot create a query node from a create table relation"); +} + +string CreateTableRelation::GetQuery() { + return string(); +} + const vector &CreateTableRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/create_view_relation.cpp b/src/duckdb/src/main/relation/create_view_relation.cpp index c00deef38..6f77f013f 100644 --- a/src/duckdb/src/main/relation/create_view_relation.cpp +++ b/src/duckdb/src/main/relation/create_view_relation.cpp @@ -35,6 +35,14 @@ BoundStatement CreateViewRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr CreateViewRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from an update relation"); +} + +string CreateViewRelation::GetQuery() { + return string(); +} + const vector &CreateViewRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/delete_relation.cpp b/src/duckdb/src/main/relation/delete_relation.cpp index 64b3f231e..2ec60f664 100644 --- a/src/duckdb/src/main/relation/delete_relation.cpp +++ b/src/duckdb/src/main/relation/delete_relation.cpp @@ -26,6 +26,14 @@ BoundStatement DeleteRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr DeleteRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from a delete relation"); +} + +string DeleteRelation::GetQuery() { + return string(); +} + const vector &DeleteRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/explain_relation.cpp b/src/duckdb/src/main/relation/explain_relation.cpp index f91e1d29f..9f2976c9d 100644 --- a/src/duckdb/src/main/relation/explain_relation.cpp +++ b/src/duckdb/src/main/relation/explain_relation.cpp @@ -20,6 +20,14 @@ BoundStatement ExplainRelation::Bind(Binder &binder) { return binder.Bind(explain.Cast()); } +unique_ptr ExplainRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from an explain relation"); +} + +string ExplainRelation::GetQuery() { + return string(); +} + const vector &ExplainRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/insert_relation.cpp b/src/duckdb/src/main/relation/insert_relation.cpp index 9728570a0..84ef16ec6 100644 --- a/src/duckdb/src/main/relation/insert_relation.cpp +++ b/src/duckdb/src/main/relation/insert_relation.cpp @@ -24,6 +24,14 @@ BoundStatement InsertRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr InsertRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from an insert relation"); +} + +string InsertRelation::GetQuery() { + return string(); +} + const vector &InsertRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/query_relation.cpp b/src/duckdb/src/main/relation/query_relation.cpp index e0cf2e280..79aa1f981 100644 --- a/src/duckdb/src/main/relation/query_relation.cpp +++ b/src/duckdb/src/main/relation/query_relation.cpp @@ -49,6 +49,10 @@ unique_ptr QueryRelation::GetQueryNode() { return std::move(select->node); } +string QueryRelation::GetQuery() { + return query; +} + unique_ptr QueryRelation::GetTableRef() { auto subquery_ref = make_uniq(GetSelectStatement(), GetAlias()); return std::move(subquery_ref); @@ -61,9 +65,6 @@ BoundStatement QueryRelation::Bind(Binder &binder) { auto result = Relation::Bind(binder); auto &replacements = binder.GetReplacementScans(); if (first_bind) { - auto &query_node = *select_stmt->node; - auto &cte_map = query_node.cte_map; - vector> 
materialized_ctes; for (auto &kv : replacements) { auto &name = kv.first; auto &tableref = kv.second; @@ -83,29 +84,16 @@ BoundStatement QueryRelation::Bind(Binder &binder) { auto cte_info = make_uniq(); cte_info->query = std::move(select); + auto subquery = make_uniq(std::move(select_stmt), "query_relation"); + auto top_level_select = make_uniq(); + auto top_level_select_node = make_uniq(); + top_level_select_node->select_list.push_back(make_uniq()); + top_level_select_node->from_table = std::move(subquery); + auto &cte_map = top_level_select_node->cte_map; + top_level_select->node = std::move(top_level_select_node); cte_map.map[name] = std::move(cte_info); - - // We can not rely on CTE inlining anymore, so we need to add a materialized CTE node - // to the query node to ensure that the CTE exists - auto &cte_entry = cte_map.map[name]; - auto mat_cte = make_uniq(); - mat_cte->ctename = name; - mat_cte->query = cte_entry->query->node->Copy(); - mat_cte->aliases = cte_entry->aliases; - mat_cte->materialized = cte_entry->materialized; - materialized_ctes.push_back(std::move(mat_cte)); - } - - auto root = std::move(select_stmt->node); - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->cte_map = root->cte_map.Copy(); - node_result->child = std::move(root); - root = std::move(node_result); - materialized_ctes.pop_back(); + select_stmt = std::move(top_level_select); } - select_stmt->node = std::move(root); } replacements.clear(); binder.SetBindingMode(saved_binding_mode); diff --git a/src/duckdb/src/main/relation/read_json_relation.cpp b/src/duckdb/src/main/relation/read_json_relation.cpp index 2f849597d..6fe7e4a7c 100644 --- a/src/duckdb/src/main/relation/read_json_relation.cpp +++ b/src/duckdb/src/main/relation/read_json_relation.cpp @@ -15,7 +15,6 @@ ReadJSONRelation::ReadJSONRelation(const shared_ptr &context, vec : TableFunctionRelation(context, auto_detect ? "read_json_auto" : "read_json", {MultiFileReader::CreateValueFromFileList(input)}, std::move(options)), alias(std::move(alias_p)) { - InitializeAlias(input); } @@ -24,7 +23,6 @@ ReadJSONRelation::ReadJSONRelation(const shared_ptr &context, str : TableFunctionRelation(context, auto_detect ? 
"read_json_auto" : "read_json", {Value(json_file_p)}, std::move(options)), json_file(std::move(json_file_p)), alias(std::move(alias_p)) { - if (alias.empty()) { alias = StringUtil::Split(json_file, ".")[0]; } diff --git a/src/duckdb/src/main/relation/update_relation.cpp b/src/duckdb/src/main/relation/update_relation.cpp index 9176cf2f2..81d85ca89 100644 --- a/src/duckdb/src/main/relation/update_relation.cpp +++ b/src/duckdb/src/main/relation/update_relation.cpp @@ -35,6 +35,14 @@ BoundStatement UpdateRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr UpdateRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from an update relation"); +} + +string UpdateRelation::GetQuery() { + return string(); +} + const vector &UpdateRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/write_csv_relation.cpp b/src/duckdb/src/main/relation/write_csv_relation.cpp index 4795c7a51..f77d6f1ee 100644 --- a/src/duckdb/src/main/relation/write_csv_relation.cpp +++ b/src/duckdb/src/main/relation/write_csv_relation.cpp @@ -25,6 +25,14 @@ BoundStatement WriteCSVRelation::Bind(Binder &binder) { return binder.Bind(copy.Cast()); } +unique_ptr WriteCSVRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from a write CSV relation"); +} + +string WriteCSVRelation::GetQuery() { + return string(); +} + const vector &WriteCSVRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/write_parquet_relation.cpp b/src/duckdb/src/main/relation/write_parquet_relation.cpp index d6e403618..b1dfdb29f 100644 --- a/src/duckdb/src/main/relation/write_parquet_relation.cpp +++ b/src/duckdb/src/main/relation/write_parquet_relation.cpp @@ -25,6 +25,14 @@ BoundStatement WriteParquetRelation::Bind(Binder &binder) { return binder.Bind(copy.Cast()); } +unique_ptr WriteParquetRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from a write parquet relation"); +} + +string WriteParquetRelation::GetQuery() { + return string(); +} + const vector &WriteParquetRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/result_set_manager.cpp b/src/duckdb/src/main/result_set_manager.cpp new file mode 100644 index 000000000..d8913b8e8 --- /dev/null +++ b/src/duckdb/src/main/result_set_manager.cpp @@ -0,0 +1,51 @@ +#include "duckdb/main/result_set_manager.hpp" + +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" + +namespace duckdb { + +ManagedResultSet::ManagedResultSet() : valid(false) { +} + +ManagedResultSet::ManagedResultSet(const weak_ptr &db_p, vector> &handles_p) + : valid(true), db(db_p), handles(handles_p) { +} + +bool ManagedResultSet::IsValid() const { + return valid; +} + +shared_ptr ManagedResultSet::GetDatabase() const { + D_ASSERT(IsValid()); + return db.lock(); +} + +vector> &ManagedResultSet::GetHandles() { + D_ASSERT(IsValid()); + return *handles; +} + +ResultSetManager::ResultSetManager(DatabaseInstance &db_p) : db(db_p.shared_from_this()) { +} + +ResultSetManager &ResultSetManager::Get(ClientContext &context) { + return Get(*context.db); +} + +ResultSetManager &ResultSetManager::Get(DatabaseInstance &db_p) { + return db_p.GetResultSetManager(); +} + +ManagedResultSet ResultSetManager::Add(ColumnDataAllocator &allocator) { + lock_guard guard(lock); + auto &handles = *open_results.emplace(allocator, make_uniq>>()).first->second; + return ManagedResultSet(db, handles); +} + +void ResultSetManager::Remove(ColumnDataAllocator 
&allocator) { + lock_guard guard(lock); + open_results.erase(allocator); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/settings/autogenerated_settings.cpp b/src/duckdb/src/main/settings/autogenerated_settings.cpp index 96c3065f2..9774ee3b8 100644 --- a/src/duckdb/src/main/settings/autogenerated_settings.cpp +++ b/src/duckdb/src/main/settings/autogenerated_settings.cpp @@ -78,6 +78,28 @@ Value AllowCommunityExtensionsSetting::GetSetting(const ClientContext &context) return Value::BOOLEAN(config.options.allow_community_extensions); } +//===----------------------------------------------------------------------===// +// Allow Parser Override Extension +//===----------------------------------------------------------------------===// +void AllowParserOverrideExtensionSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + if (!OnGlobalSet(db, config, input)) { + return; + } + config.options.allow_parser_override_extension = input.GetValue(); +} + +void AllowParserOverrideExtensionSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + if (!OnGlobalReset(db, config)) { + return; + } + config.options.allow_parser_override_extension = DBConfigOptions().allow_parser_override_extension; +} + +Value AllowParserOverrideExtensionSetting::GetSetting(const ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value(config.options.allow_parser_override_extension); +} + //===----------------------------------------------------------------------===// // Allow Unredacted Secrets //===----------------------------------------------------------------------===// @@ -232,6 +254,13 @@ Value DebugForceExternalSetting::GetSetting(const ClientContext &context) { return Value::BOOLEAN(config.force_external); } +//===----------------------------------------------------------------------===// +// Debug Physical Table Scan Execution Strategy +//===----------------------------------------------------------------------===// +void DebugPhysicalTableScanExecutionStrategySetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + EnumUtil::FromString(StringValue::Get(parameter)); +} + //===----------------------------------------------------------------------===// // Debug Verify Vector //===----------------------------------------------------------------------===// diff --git a/src/duckdb/src/main/settings/custom_settings.cpp b/src/duckdb/src/main/settings/custom_settings.cpp index 8e9b491e3..377a1a5b5 100644 --- a/src/duckdb/src/main/settings/custom_settings.cpp +++ b/src/duckdb/src/main/settings/custom_settings.cpp @@ -34,6 +34,14 @@ namespace duckdb { +constexpr const char *LoggingMode::Name; +constexpr const char *LoggingLevel::Name; +constexpr const char *EnableLogging::Name; +constexpr const char *LoggingStorage::Name; +constexpr const char *EnabledLogTypes::Name; +constexpr const char *DisabledLogTypes::Name; +constexpr const char *DisabledFilesystemsSetting::Name; + const string GetDefaultUserAgent() { return StringUtil::Format("duckdb/%s(%s)", DuckDB::LibraryVersion(), DuckDB::Platform()); } @@ -150,6 +158,27 @@ bool AllowCommunityExtensionsSetting::OnGlobalReset(DatabaseInstance *db, DBConf return true; } +//===----------------------------------------------------------------------===// +// Allow Parser Override +//===----------------------------------------------------------------------===// +bool AllowParserOverrideExtensionSetting::OnGlobalSet(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto new_value = 
input.GetValue(); + vector supported_options = {"default", "fallback", "strict", "strict_when_supported"}; + string supported_option_string; + for (const auto &option : supported_options) { + if (StringUtil::CIEquals(new_value, option)) { + return true; + } + } + throw InvalidInputException("Unrecognized value for parser override setting. Valid options are: %s", + StringUtil::Join(supported_options, ", ")); +} + +bool AllowParserOverrideExtensionSetting::OnGlobalReset(DatabaseInstance *db, DBConfig &config) { + config.options.allow_parser_override_extension = "default"; + return true; +} + //===----------------------------------------------------------------------===// // Allow Persistent Secrets //===----------------------------------------------------------------------===// @@ -962,9 +991,15 @@ void ForceCompressionSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, } else { auto compression_type = CompressionTypeFromString(compression); //! FIXME: do we want to try to retrieve the AttachedDatabase here to get the StorageManager ?? - if (CompressionTypeIsDeprecated(compression_type)) { - throw ParserException("Attempted to force a deprecated compression type (%s)", - CompressionTypeToString(compression_type)); + auto compression_availability_result = CompressionTypeIsAvailable(compression_type); + if (!compression_availability_result.IsAvailable()) { + if (compression_availability_result.IsDeprecated()) { + throw ParserException("Attempted to force a deprecated compression type (%s)", + CompressionTypeToString(compression_type)); + } else { + throw ParserException("Attempted to force a compression type that isn't available yet (%s)", + CompressionTypeToString(compression_type)); + } } if (compression_type == CompressionType::COMPRESSION_AUTO) { auto compression_types = StringUtil::Join(ListCompressionTypes(), ", "); diff --git a/src/duckdb/src/main/stream_query_result.cpp b/src/duckdb/src/main/stream_query_result.cpp index 9e4b06caa..676791e2f 100644 --- a/src/duckdb/src/main/stream_query_result.cpp +++ b/src/duckdb/src/main/stream_query_result.cpp @@ -69,7 +69,7 @@ static bool ExecutionErrorOccurred(StreamExecutionResult result) { return false; } -unique_ptr StreamQueryResult::FetchInternal(ClientContextLock &lock) { +unique_ptr StreamQueryResult::FetchNextInternal(ClientContextLock &lock) { bool invalidate_query = true; unique_ptr chunk; try { @@ -106,12 +106,12 @@ unique_ptr StreamQueryResult::FetchInternal(ClientContextLock &lock) return nullptr; } -unique_ptr StreamQueryResult::FetchRaw() { +unique_ptr StreamQueryResult::FetchInternal() { unique_ptr chunk; { auto lock = LockContext(); CheckExecutableInternal(*lock); - chunk = FetchInternal(*lock); + chunk = FetchNextInternal(*lock); } if (!chunk || chunk->ColumnCount() == 0 || chunk->size() == 0) { Close(); diff --git a/src/duckdb/src/optimizer/common_subplan_optimizer.cpp b/src/duckdb/src/optimizer/common_subplan_optimizer.cpp new file mode 100644 index 000000000..0c3c9cb35 --- /dev/null +++ b/src/duckdb/src/optimizer/common_subplan_optimizer.cpp @@ -0,0 +1,575 @@ +#include "duckdb/optimizer/common_subplan_optimizer.hpp" + +#include "duckdb/optimizer/optimizer.hpp" +#include "duckdb/optimizer/cte_inlining.hpp" +#include "duckdb/optimizer/column_binding_replacer.hpp" +#include "duckdb/planner/operator/list.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// 
+// Subplan Signature/Info +//===--------------------------------------------------------------------===// +struct PlanSignatureCreateState { + PlanSignatureCreateState() : stream(DEFAULT_BLOCK_ALLOC_SIZE), serializer(stream) { + } + + void Reset() { + to_canonical.clear(); + from_canonical.clear(); + table_indices.clear(); + expression_info.clear(); + } + + MemoryStream stream; + BinarySerializer serializer; + + unordered_map to_canonical; + unordered_map from_canonical; + + vector table_indices; + vector> expression_info; +}; + +class PlanSignature { +private: + PlanSignature(const MemoryStream &stream_p, idx_t offset_p, idx_t length_p, + vector> &&child_signatures_p, idx_t operator_count_p) + : stream(stream_p), offset(offset_p), length(length_p), + signature_hash(Hash(stream_p.GetData() + offset, length)), child_signatures(std::move(child_signatures_p)), + operator_count(operator_count_p) { + } + +public: + static unique_ptr Create(PlanSignatureCreateState &state, LogicalOperator &op, + vector> &&child_signatures, + const idx_t operator_count) { + state.Reset(); + if (!OperatorIsSupported(op)) { + return nullptr; + } + + if (op.type == LogicalOperatorType::LOGICAL_CHUNK_GET && + op.Cast().collection->Count() > 1000) { + // Avoid serializing massive amounts of data (this is here because of the "Test TPCH arrow roundtrip" test) + return nullptr; + } + + // Construct maps for converting column bindings to canonical representation and back + static constexpr idx_t CANONICAL_TABLE_INDEX_OFFSET = 10000000000000; + for (const auto &child_op : op.children) { + for (const auto &child_cb : child_op->GetColumnBindings()) { + const auto &original = child_cb.table_index; + auto it = state.to_canonical.find(original); + if (it != state.to_canonical.end()) { + continue; // We've seen this table index before + } + const auto canonical = CANONICAL_TABLE_INDEX_OFFSET + state.to_canonical.size(); + state.to_canonical[original] = canonical; + state.from_canonical[canonical] = original; + } + } + + // Convert operators to canonical table indices + ConvertTableIndices(op, state.table_indices); + + // Convert expressions to canonical (table indices, aliases, query locations) + bool can_materialize = ConvertExpressions(op, state.to_canonical, state.expression_info); + + // Temporarily move children here as we don't want to serialize them + auto children = std::move(op.children); + op.children.clear(); + + // TODO: to allow for better detection of equivalent plans, we could: + // 1. Sort the children of operators + // 2. 
Sort the expressions of operators + + // Serialize canonical representation of operator + const auto offset = state.stream.GetPosition(); + state.serializer.Begin(); + try { // Operators will throw if they cannot serialize, so we need to try/catch here + op.Serialize(state.serializer); + } catch (std::exception &) { + can_materialize = false; + } + state.serializer.End(); + const auto length = state.stream.GetPosition() - offset; + + // Convert back from canonical + ConvertTableIndices(op, state.table_indices); + ConvertExpressions(op, state.from_canonical, state.expression_info); + + // Restore children + op.children = std::move(children); + + if (can_materialize) { + return unique_ptr( + new PlanSignature(state.stream, offset, length, std::move(child_signatures), operator_count)); + } + return nullptr; + } + + idx_t OperatorCount() const { + return operator_count; + } + + hash_t HashSignature() const { + auto res = signature_hash; + for (auto &child : child_signatures) { + res = CombineHash(res, child.get().HashSignature()); + } + return res; + } + + bool Equals(const PlanSignature &other) const { + if (this->GetSignature() != other.GetSignature()) { + return false; + } + if (this->child_signatures.size() != other.child_signatures.size()) { + return false; + } + for (idx_t child_idx = 0; child_idx < this->child_signatures.size(); ++child_idx) { + if (!this->child_signatures[child_idx].get().Equals(other.child_signatures[child_idx].get())) { + return false; + } + } + return true; + } + +private: + String GetSignature() const { + return String(char_ptr_cast(stream.GetData() + offset), NumericCast(length)); + } + + static bool OperatorIsSupported(const LogicalOperator &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_PROJECTION: + case LogicalOperatorType::LOGICAL_FILTER: + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + case LogicalOperatorType::LOGICAL_WINDOW: + case LogicalOperatorType::LOGICAL_UNNEST: + case LogicalOperatorType::LOGICAL_LIMIT: + case LogicalOperatorType::LOGICAL_ORDER_BY: + case LogicalOperatorType::LOGICAL_TOP_N: + case LogicalOperatorType::LOGICAL_DISTINCT: + case LogicalOperatorType::LOGICAL_PIVOT: + case LogicalOperatorType::LOGICAL_GET: + case LogicalOperatorType::LOGICAL_CHUNK_GET: + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: + case LogicalOperatorType::LOGICAL_DUMMY_SCAN: + case LogicalOperatorType::LOGICAL_EMPTY_RESULT: + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: + case LogicalOperatorType::LOGICAL_ANY_JOIN: + case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: + case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + case LogicalOperatorType::LOGICAL_UNION: + case LogicalOperatorType::LOGICAL_EXCEPT: + case LogicalOperatorType::LOGICAL_INTERSECT: + return true; + default: + // Unsupported: + // - case LogicalOperatorType::LOGICAL_COPY_TO_FILE: + // - case LogicalOperatorType::LOGICAL_SAMPLE: + // - case LogicalOperatorType::LOGICAL_COPY_DATABASE: + // - case LogicalOperatorType::LOGICAL_DELIM_GET: + // - case LogicalOperatorType::LOGICAL_CTE_REF: + // - case LogicalOperatorType::LOGICAL_JOIN: + // - case LogicalOperatorType::LOGICAL_DELIM_JOIN: + // - case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: + // - case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: + // - case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: + // - case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR + return false; + } + } + + template + static void ConvertTableIndices(LogicalOperator &op, vector 
&table_indices) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_GET: { + ConvertTableIndicesGeneric(op, table_indices); + break; + } + case LogicalOperatorType::LOGICAL_CHUNK_GET: { + ConvertTableIndicesGeneric(op, table_indices); + break; + } + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: { + ConvertTableIndicesGeneric(op, table_indices); + break; + } + case LogicalOperatorType::LOGICAL_DUMMY_SCAN: { + ConvertTableIndicesGeneric(op, table_indices); + break; + } + case LogicalOperatorType::LOGICAL_CTE_REF: { + ConvertTableIndicesGeneric(op, table_indices); + break; + } + case LogicalOperatorType::LOGICAL_PROJECTION: { + ConvertTableIndicesGeneric(op, table_indices); + break; + } + case LogicalOperatorType::LOGICAL_PIVOT: { + auto &pivot = op.Cast(); + if (TO_CANONICAL) { + table_indices.emplace_back(pivot.pivot_index); + } + pivot.pivot_index = TO_CANONICAL ? 0 : table_indices[0]; + break; + } + case LogicalOperatorType::LOGICAL_UNNEST: { + auto &unnest = op.Cast(); + if (TO_CANONICAL) { + table_indices.emplace_back(unnest.unnest_index); + } + unnest.unnest_index = TO_CANONICAL ? 0 : table_indices[0]; + break; + } + case LogicalOperatorType::LOGICAL_WINDOW: { + auto &window = op.Cast(); + if (TO_CANONICAL) { + table_indices.emplace_back(window.window_index); + } + window.window_index = TO_CANONICAL ? 0 : table_indices[0]; + break; + } + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { + auto &aggregate = op.Cast(); + if (TO_CANONICAL) { + table_indices.emplace_back(aggregate.group_index); + table_indices.emplace_back(aggregate.aggregate_index); + table_indices.emplace_back(aggregate.groupings_index); + } + aggregate.group_index = TO_CANONICAL ? 0 : table_indices[0]; + aggregate.aggregate_index = TO_CANONICAL ? 1 : table_indices[1]; + aggregate.groupings_index = TO_CANONICAL ? 2 : table_indices[2]; + break; + } + case LogicalOperatorType::LOGICAL_UNION: + case LogicalOperatorType::LOGICAL_EXCEPT: + case LogicalOperatorType::LOGICAL_INTERSECT: { + auto &setop = op.Cast(); + if (TO_CANONICAL) { + table_indices.emplace_back(setop.table_index); + } + setop.table_index = TO_CANONICAL ? 0 : table_indices[0]; + break; + } + default: + break; + } + } + + template + static void ConvertTableIndicesGeneric(LogicalOperator &op, vector &table_idxs) { + auto &generic = op.Cast(); + if (TO_CANONICAL) { + table_idxs.emplace_back(generic.table_index); + } + generic.table_index = TO_CANONICAL ? 
0 : table_idxs[0]; + } + + static bool ConvertExpressions(LogicalOperator &op, const unordered_map &table_index_mapping, + vector> &expression_info) { + bool can_materialize = true; + const auto to_canonical = expression_info.empty(); + idx_t info_idx = 0; + LogicalOperatorVisitor::EnumerateExpressions(op, [&](unique_ptr *expr) { + ExpressionIterator::EnumerateExpression(*expr, [&](unique_ptr &child) { + if (child->GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { + auto &col_ref = child->Cast(); + auto &table_index = col_ref.binding.table_index; + auto it = table_index_mapping.find(table_index); + D_ASSERT(it != table_index_mapping.end()); + table_index = it->second; + } + if (to_canonical) { + expression_info.emplace_back(std::move(child->alias), child->query_location); + child->alias.clear(); + child->query_location.SetInvalid(); + } else { + auto &info = expression_info[info_idx++]; + child->alias = std::move(info.first); + child->query_location = info.second; + } + if (child->IsVolatile()) { + can_materialize = false; + } + }); + }); + return can_materialize; + } + +private: + const MemoryStream &stream; + const idx_t offset; + const idx_t length; + + const hash_t signature_hash; + + const vector> child_signatures; + const idx_t operator_count; +}; + +struct PlanSignatureHash { + std::size_t operator()(const PlanSignature &k) const { + return k.HashSignature(); + } +}; + +struct PlanSignatureEquality { + bool operator()(const PlanSignature &a, const PlanSignature &b) const { + return a.Equals(b); + } +}; + +struct SubplanInfo { + explicit SubplanInfo(unique_ptr &op) : subplans({op}), lowest_common_ancestor(op) { + } + vector>> subplans; + reference> lowest_common_ancestor; +}; + +using subplan_map_t = unordered_map, SubplanInfo, PlanSignatureHash, PlanSignatureEquality>; + +//===--------------------------------------------------------------------===// +// CommonSubplanFinder +//===--------------------------------------------------------------------===// +class CommonSubplanFinder { +public: + CommonSubplanFinder() { + } + +private: + struct OperatorInfo { + OperatorInfo(unique_ptr &parent_p, const idx_t &depth_p) : parent(parent_p), depth(depth_p) { + } + + unique_ptr &parent; + const idx_t depth; + unique_ptr signature; + }; + + struct StackNode { + explicit StackNode(unique_ptr &op_p) : op(op_p), child_index(0) { + } + + bool HasMoreChildren() const { + return child_index < op->children.size(); + } + + unique_ptr &GetNextChild() { + D_ASSERT(child_index < op->children.size()); + return op->children[child_index++]; + }; + + unique_ptr &op; + idx_t child_index; + }; + +public: + subplan_map_t FindCommonSubplans(reference> root) { + // Find first operator with more than 1 child + while (root.get()->children.size() == 1) { + root = root.get()->children[0]; + } + + // Recurse through query plan using stack-based recursion + vector stack; + stack.emplace_back(root); + operator_infos.emplace(root, OperatorInfo(root, 0)); + + while (!stack.empty()) { + auto ¤t = stack.back(); + + // Depth-first + if (current.HasMoreChildren()) { + auto &child = current.GetNextChild(); + operator_infos.emplace(child, OperatorInfo(current.op, stack.size())); + stack.emplace_back(child); + continue; + } + + if (!RefersToSameObject(current.op, root.get())) { + // We have all child information for this operator now, compute signature + auto &signature = operator_infos.find(current.op)->second.signature; + signature = CreatePlanSignature(current.op); + + // Add to subplans (if we got actually got a 
signature) + if (signature) { + auto it = subplans.find(*signature); + if (it == subplans.end()) { + subplans.emplace(*signature, SubplanInfo(current.op)); + } else { + auto &info = it->second; + info.subplans.emplace_back(current.op); + info.lowest_common_ancestor = LowestCommonAncestor(info.lowest_common_ancestor, current.op); + } + } + } + + // Done with current + stack.pop_back(); + } + + // Filter out redundant or ineligible subplans before returning + for (auto it = subplans.begin(); it != subplans.end();) { + if (it->first.get().OperatorCount() == 1) { + it = subplans.erase(it); // Just one operator in this subplan + continue; + } + if (it->second.subplans.size() == 1) { + it = subplans.erase(it); // No other identical subplan + continue; + } + auto &subplan = it->second.subplans[0].get(); + auto &parent = operator_infos.find(subplan)->second.parent; + auto &parent_signature = operator_infos.find(parent)->second.signature; + if (parent_signature) { + auto parent_it = subplans.find(*parent_signature); + if (parent_it != subplans.end() && it->second.subplans.size() == parent_it->second.subplans.size()) { + it = subplans.erase(it); // Parent has exact same number of identical subplans + continue; + } + } + if (!CTEInlining::EndsInAggregateOrDistinct(*subplan)) { + it = subplans.erase(it); // Not eligible for materialization + continue; + } + it++; // This subplan might be useful + } + + return std::move(subplans); + } + +private: + unique_ptr CreatePlanSignature(const unique_ptr &op) { + vector> child_signatures; + idx_t operator_count = 1; + for (auto &child : op->children) { + auto it = operator_infos.find(child); + D_ASSERT(it != operator_infos.end()); + if (!it->second.signature) { + return nullptr; // Failed to create signature from one of the children + } + child_signatures.emplace_back(*it->second.signature); + operator_count += it->second.signature->OperatorCount(); + } + return PlanSignature::Create(state, *op, std::move(child_signatures), operator_count); + } + + unique_ptr &LowestCommonAncestor(reference> a, + reference> b) { + auto a_it = operator_infos.find(a); + auto b_it = operator_infos.find(b); + D_ASSERT(a_it != operator_infos.end() && b_it != operator_infos.end()); + + // Get parents of a and b until they're at the same depth + while (a_it->second.depth > b_it->second.depth) { + a = a_it->second.parent; + a_it = operator_infos.find(a); + D_ASSERT(a_it != operator_infos.end()); + } + while (b_it->second.depth > a_it->second.depth) { + b = b_it->second.parent; + b_it = operator_infos.find(b); + D_ASSERT(b_it != operator_infos.end()); + } + + // Move up one level at a time for both until ancestor is the same + while (!RefersToSameObject(a, b)) { + a_it = operator_infos.find(a); + b_it = operator_infos.find(b); + D_ASSERT(a_it != operator_infos.end() && b_it != operator_infos.end()); + a = a_it->second.parent; + b = b_it->second.parent; + } + + return a.get(); + } + +private: + //! Mapping from operator to info + reference_map_t, OperatorInfo> operator_infos; + //! Mapping from subplan signature to subplan information + subplan_map_t subplans; + //! 
State for creating PlanSignature with reusable data structures + PlanSignatureCreateState state; +}; + +//===--------------------------------------------------------------------===// +// CommonSubplanOptimizer +//===--------------------------------------------------------------------===// +CommonSubplanOptimizer::CommonSubplanOptimizer(Optimizer &optimizer_p) : optimizer(optimizer_p) { +} + +static void ConvertSubplansToCTE(Optimizer &optimizer, unique_ptr &op, SubplanInfo &subplan_info) { + const auto cte_index = optimizer.binder.GenerateTableIndex(); + const auto cte_name = StringUtil::Format("__common_subplan_1"); + + // Resolve types to be used for creating the materialized CTE and refs + op->ResolveOperatorTypes(); + + // Get types and names + const auto &types = subplan_info.subplans[0].get()->types; + vector col_names; + for (idx_t i = 0; i < types.size(); i++) { + col_names.emplace_back(StringUtil::Format("%s_col_%llu", cte_name, i)); + } + + // Create CTE refs and figure out column binding replacements + vector> cte_refs; + ColumnBindingReplacer replacer; + for (auto &subplan : subplan_info.subplans) { + cte_refs.emplace_back( + make_uniq(optimizer.binder.GenerateTableIndex(), cte_index, types, col_names)); + const auto old_bindings = subplan.get()->GetColumnBindings(); + const auto new_bindings = cte_refs.back()->GetColumnBindings(); + D_ASSERT(old_bindings.size() == new_bindings.size()); + for (idx_t i = 0; i < old_bindings.size(); i++) { + replacer.replacement_bindings.emplace_back(old_bindings[i], new_bindings[i]); + } + } + + // Create the materialized CTE and replace the common subplans with references to it + auto &lowest_common_ancestor = subplan_info.lowest_common_ancestor.get(); + auto cte = + make_uniq(cte_name, cte_index, types.size(), std::move(subplan_info.subplans[0].get()), + std::move(lowest_common_ancestor), CTEMaterialize::CTE_MATERIALIZE_DEFAULT); + for (idx_t i = 0; i < subplan_info.subplans.size(); i++) { + subplan_info.subplans[i].get() = std::move(cte_refs[i]); + } + lowest_common_ancestor = std::move(cte); + + // Replace bindings of subplans with those of the CTE refs + replacer.stop_operator = lowest_common_ancestor.get(); + replacer.VisitOperator(*op); // Replace from the root until CTE + replacer.VisitOperator(*lowest_common_ancestor->children[1]); // Replace in CTE child +} + +unique_ptr CommonSubplanOptimizer::Optimize(unique_ptr op) { + // Bottom-up identification of identical subplans + CommonSubplanFinder finder; + auto subplans = finder.FindCommonSubplans(op); + + // Identify the single best subplan (TODO: for now, in the future we should identify multiple) + if (subplans.empty()) { + return op; // No eligible subplans + } + auto best_it = subplans.begin(); + for (auto it = ++subplans.begin(); it != subplans.end(); it++) { + if (it->first.get().OperatorCount() > best_it->first.get().OperatorCount()) { + best_it = it; + } + } + + // Create a CTE! 
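+	// (ConvertSubplansToCTE wraps the chosen subplan in a materialized CTE and replaces every duplicate occurrence with a CTE reference)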
+ ConvertSubplansToCTE(optimizer, op, best_it->second); + return op; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/cte_inlining.cpp b/src/duckdb/src/optimizer/cte_inlining.cpp index 0b9e942ee..116d64768 100644 --- a/src/duckdb/src/optimizer/cte_inlining.cpp +++ b/src/duckdb/src/optimizer/cte_inlining.cpp @@ -55,10 +55,14 @@ static bool ContainsLimit(const LogicalOperator &op) { return false; } -static bool EndsInAggregateOrDistinct(const LogicalOperator &op) { - if (op.type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY || - op.type == LogicalOperatorType::LOGICAL_DISTINCT) { +bool CTEInlining::EndsInAggregateOrDistinct(const LogicalOperator &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + case LogicalOperatorType::LOGICAL_DISTINCT: + case LogicalOperatorType::LOGICAL_WINDOW: return true; + default: + break; } if (op.children.size() != 1) { return false; @@ -146,8 +150,7 @@ void CTEInlining::TryInlining(unique_ptr &op) { } } -bool CTEInlining::Inline(unique_ptr &op, LogicalOperator &materialized_cte, - bool requires_copy) { +bool CTEInlining::Inline(unique_ptr &op, LogicalOperator &materialized_cte, bool requires_copy) { if (op->type == LogicalOperatorType::LOGICAL_CTE_REF) { auto &cteref = op->Cast(); auto &cte = materialized_cte.Cast(); diff --git a/src/duckdb/src/optimizer/expression_rewriter.cpp b/src/duckdb/src/optimizer/expression_rewriter.cpp index c8836a380..e21b5bcfd 100644 --- a/src/duckdb/src/optimizer/expression_rewriter.cpp +++ b/src/duckdb/src/optimizer/expression_rewriter.cpp @@ -55,7 +55,7 @@ unique_ptr ExpressionRewriter::ConstantOrNull(vector(value)); return make_uniq(type, func, std::move(children), ConstantOrNull::Bind(std::move(value))); } diff --git a/src/duckdb/src/optimizer/filter_combiner.cpp b/src/duckdb/src/optimizer/filter_combiner.cpp index 8e4a295b4..ddbe82ab0 100644 --- a/src/duckdb/src/optimizer/filter_combiner.cpp +++ b/src/duckdb/src/optimizer/filter_combiner.cpp @@ -1,5 +1,6 @@ #include "duckdb/optimizer/filter_combiner.hpp" +#include "duckdb/common/enums/expression_type.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/optimizer/optimizer.hpp" #include "duckdb/planner/expression.hpp" @@ -907,6 +908,12 @@ FilterResult FilterCombiner::AddTransitiveFilters(BoundComparisonExpression &com idx_t left_equivalence_set = GetEquivalenceSet(left_node); idx_t right_equivalence_set = GetEquivalenceSet(right_node); if (left_equivalence_set == right_equivalence_set) { + if (comparison.GetExpressionType() == ExpressionType::COMPARE_GREATERTHAN || + comparison.GetExpressionType() == ExpressionType::COMPARE_LESSTHAN) { + // non equal comparison has equal equivalence set, then it is unsatisfiable + // e.g., j > i AND i < j is unsatisfiable + return FilterResult::UNSATISFIABLE; + } // this equality filter already exists, prune it return FilterResult::SUCCESS; } diff --git a/src/duckdb/src/optimizer/filter_pullup.cpp b/src/duckdb/src/optimizer/filter_pullup.cpp index 219611387..f9ebb63c3 100644 --- a/src/duckdb/src/optimizer/filter_pullup.cpp +++ b/src/duckdb/src/optimizer/filter_pullup.cpp @@ -6,6 +6,7 @@ #include "duckdb/planner/operator/logical_comparison_join.hpp" #include "duckdb/planner/operator/logical_cross_product.hpp" #include "duckdb/planner/operator/logical_join.hpp" +#include "duckdb/planner/operator/logical_distinct.hpp" namespace duckdb { @@ -26,6 +27,7 @@ unique_ptr FilterPullup::Rewrite(unique_ptr op case LogicalOperatorType::LOGICAL_EXCEPT: return 
PullupSetOperation(std::move(op)); case LogicalOperatorType::LOGICAL_DISTINCT: + return PullupDistinct(std::move(op)); case LogicalOperatorType::LOGICAL_ORDER_BY: { // we can just pull directly through these operations without any rewriting op->children[0] = Rewrite(std::move(op->children[0])); @@ -115,6 +117,18 @@ unique_ptr FilterPullup::PullupCrossProduct(unique_ptr FilterPullup::PullupDistinct(unique_ptr op) { + const auto &distinct = op->Cast(); + if (distinct.distinct_type == DistinctType::DISTINCT) { + // Can pull up through a DISTINCT + op->children[0] = Rewrite(std::move(op->children[0])); + return op; + } + // Cannot pull up through a DISTINCT ON (see #19327) + D_ASSERT(distinct.distinct_type == DistinctType::DISTINCT_ON); + return FinishPullup(std::move(op)); +} + unique_ptr FilterPullup::GeneratePullupFilter(unique_ptr child, vector> &expressions) { unique_ptr filter = make_uniq(); diff --git a/src/duckdb/src/optimizer/filter_pushdown.cpp b/src/duckdb/src/optimizer/filter_pushdown.cpp index c4f7bb04b..7c13386d9 100644 --- a/src/duckdb/src/optimizer/filter_pushdown.cpp +++ b/src/duckdb/src/optimizer/filter_pushdown.cpp @@ -208,17 +208,23 @@ unique_ptr FilterPushdown::PushdownJoin(unique_ptrfilter)); D_ASSERT(result != FilterResult::UNSUPPORTED); - (void)result; + if (result == FilterResult::UNSATISFIABLE) { + // one of the filters is unsatisfiable - abort filter pushdown + return FilterResult::UNSATISFIABLE; + } } filters.clear(); + return FilterResult::SUCCESS; } FilterResult FilterPushdown::AddFilter(unique_ptr expr) { - PushFilters(); + if (PushFilters() == FilterResult::UNSATISFIABLE) { + return FilterResult::UNSATISFIABLE; + } // split up the filters by AND predicate vector> expressions; expressions.push_back(std::move(expr)); @@ -276,51 +282,52 @@ unique_ptr FilterPushdown::PushFinalFilters(unique_ptr FilterPushdown::FinishPushdown(unique_ptr op) { - if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { - for (idx_t i = 0; i < filters.size(); i++) { - auto &f = *filters[i]; - for (auto &child : op->children) { - FilterPushdown pushdown(optimizer, convert_mark_joins); +unique_ptr FilterPushdown::PushFiltersIntoDelimJoin(unique_ptr op) { + for (idx_t i = 0; i < filters.size(); i++) { + auto &f = *filters[i]; + for (auto &child : op->children) { + FilterPushdown pushdown(optimizer, convert_mark_joins); - // check if filter bindings can be applied to the child bindings. - auto child_bindings = child->GetColumnBindings(); - unordered_set child_bindings_table; - for (auto &binding : child_bindings) { - child_bindings_table.insert(binding.table_index); - } + // check if filter bindings can be applied to the child bindings. 
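+			// (a filter can only be pushed into a child if every table index it references is produced by that child; otherwise it stays where it is)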
+ auto child_bindings = child->GetColumnBindings(); + unordered_set child_bindings_table; + for (auto &binding : child_bindings) { + child_bindings_table.insert(binding.table_index); + } - // Check if ALL bindings of the filter are present in the child - bool should_push = true; - for (auto &binding : f.bindings) { - if (child_bindings_table.find(binding) == child_bindings_table.end()) { - should_push = false; - break; - } + // Check if ALL bindings of the filter are present in the child + bool should_push = true; + for (auto &binding : f.bindings) { + if (child_bindings_table.find(binding) == child_bindings_table.end()) { + should_push = false; + break; } + } - if (!should_push) { - continue; - } + if (!should_push) { + continue; + } - // copy the filter - auto filter_copy = f.filter->Copy(); - if (pushdown.AddFilter(std::move(filter_copy)) == FilterResult::UNSATISFIABLE) { - return make_uniq(std::move(op)); - } + // copy the filter + auto filter_copy = f.filter->Copy(); + if (pushdown.AddFilter(std::move(filter_copy)) == FilterResult::UNSATISFIABLE) { + return make_uniq(std::move(op)); + } - // push the filter into the child. - pushdown.GenerateFilters(); - child = pushdown.Rewrite(std::move(child)); + // push the filter into the child. + pushdown.GenerateFilters(); + child = pushdown.Rewrite(std::move(child)); - // Don't push same filter again - filters.erase_at(i); - i--; - break; - } + // Don't push same filter again + filters.erase_at(i); + i--; + break; } } + return op; +} +unique_ptr FilterPushdown::FinishPushdown(unique_ptr op) { // unhandled type, first perform filter pushdown in its children for (auto &child : op->children) { FilterPushdown pushdown(optimizer, convert_mark_joins); diff --git a/src/duckdb/src/optimizer/join_order/join_order_optimizer.cpp b/src/duckdb/src/optimizer/join_order/join_order_optimizer.cpp index b90d22b0f..767b918c4 100644 --- a/src/duckdb/src/optimizer/join_order/join_order_optimizer.cpp +++ b/src/duckdb/src/optimizer/join_order/join_order_optimizer.cpp @@ -25,7 +25,6 @@ JoinOrderOptimizer JoinOrderOptimizer::CreateChildOptimizer() { unique_ptr JoinOrderOptimizer::Optimize(unique_ptr plan, optional_ptr stats) { - if (depth > query_graph_manager.context.config.max_expression_depth) { // Very deep plans will eventually consume quite some stack space // Returning the current plan is always a valid choice diff --git a/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp b/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp index fc282aba9..e79107606 100644 --- a/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp +++ b/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp @@ -102,7 +102,6 @@ const reference_map_t> &PlanEnumerator:: unique_ptr PlanEnumerator::CreateJoinTree(JoinRelationSet &set, const vector> &possible_connections, DPJoinNode &left, DPJoinNode &right) { - // FIXME: should consider different join algorithms, should we pick a join algorithm here as well? 
(probably) optional_ptr best_connection = possible_connections.back().get(); // cross products are technically still connections, but the filter expression is a null_ptr diff --git a/src/duckdb/src/optimizer/join_order/query_graph.cpp b/src/duckdb/src/optimizer/join_order/query_graph.cpp index beb9e1521..01e167fec 100644 --- a/src/duckdb/src/optimizer/join_order/query_graph.cpp +++ b/src/duckdb/src/optimizer/join_order/query_graph.cpp @@ -79,7 +79,6 @@ void QueryGraphEdges::CreateEdge(JoinRelationSet &left, JoinRelationSet &right, void QueryGraphEdges::EnumerateNeighborsDFS(JoinRelationSet &node, reference info, idx_t index, const std::function &callback) const { - for (auto &neighbor : info.get().neighbors) { if (callback(*neighbor)) { return; diff --git a/src/duckdb/src/optimizer/join_order/query_graph_manager.cpp b/src/duckdb/src/optimizer/join_order/query_graph_manager.cpp index 94ee1a2c8..28f6a9eb3 100644 --- a/src/duckdb/src/optimizer/join_order/query_graph_manager.cpp +++ b/src/duckdb/src/optimizer/join_order/query_graph_manager.cpp @@ -243,7 +243,6 @@ GenerateJoinRelation QueryGraphManager::GenerateJoins(vectorsecond; if (!dp_entry->second->is_leaf) { - // generate the left and right children auto left = GenerateJoins(extracted_relations, node->left_set); auto right = GenerateJoins(extracted_relations, node->right_set); diff --git a/src/duckdb/src/optimizer/join_order/relation_manager.cpp b/src/duckdb/src/optimizer/join_order/relation_manager.cpp index 4916f662e..7359f639b 100644 --- a/src/duckdb/src/optimizer/join_order/relation_manager.cpp +++ b/src/duckdb/src/optimizer/join_order/relation_manager.cpp @@ -46,7 +46,6 @@ void RelationManager::AddAggregateOrWindowRelation(LogicalOperator &op, optional void RelationManager::AddRelation(LogicalOperator &op, optional_ptr parent, const RelationStats &stats) { - // if parent is null, then this is a root relation // if parent is not null, it should have multiple children D_ASSERT(!parent || parent->children.size() >= 2); @@ -54,6 +53,13 @@ void RelationManager::AddRelation(LogicalOperator &op, optional_ptr(); + if (get.function.name == "unnest") { + is_unnest_or_get_with_unnest = true; + } + } if (table_indexes.empty()) { // relation represents a non-reorderable relation, most likely a join relation // Get the tables referenced in the non-reorderable relation and add them to the relation mapping @@ -65,7 +71,7 @@ void RelationManager::AddRelation(LogicalOperator &op, optional_ptr limit_op, RelationS } } +void RelationManager::AddUnnestRelation(JoinOrderOptimizer &optimizer, LogicalOperator &op, LogicalOperator &input_op, + optional_ptr parent, RelationStats &child_stats, + optional_ptr limit_op, + vector> &datasource_filters) { + D_ASSERT(!op.children.empty()); + auto child_optimizer = optimizer.CreateChildOptimizer(); + op.children[0] = child_optimizer.Optimize(std::move(op.children[0]), &child_stats); + if (!datasource_filters.empty()) { + child_stats.cardinality = LossyNumericCast(static_cast(child_stats.cardinality) * + RelationStatisticsHelper::DEFAULT_SELECTIVITY); + } + ModifyStatsIfLimit(limit_op.get(), child_stats); + AddRelation(input_op, parent, child_stats); +} + bool RelationManager::ExtractJoinRelations(JoinOrderOptimizer &optimizer, LogicalOperator &input_op, vector> &filter_operators, optional_ptr parent) { @@ -279,15 +300,7 @@ bool RelationManager::ExtractJoinRelations(JoinOrderOptimizer &optimizer, Logica case LogicalOperatorType::LOGICAL_UNNEST: { // optimize children of unnest RelationStats child_stats; - auto 
child_optimizer = optimizer.CreateChildOptimizer(); - op->children[0] = child_optimizer.Optimize(std::move(op->children[0]), &child_stats); - // the extracted cardinality should be set for window - if (!datasource_filters.empty()) { - child_stats.cardinality = LossyNumericCast(static_cast(child_stats.cardinality) * - RelationStatisticsHelper::DEFAULT_SELECTIVITY); - } - ModifyStatsIfLimit(limit_op.get(), child_stats); - AddRelation(input_op, parent, child_stats); + AddUnnestRelation(optimizer, *op, input_op, parent, child_stats, limit_op, datasource_filters); return true; } case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { @@ -345,6 +358,11 @@ bool RelationManager::ExtractJoinRelations(JoinOrderOptimizer &optimizer, Logica case LogicalOperatorType::LOGICAL_GET: { // TODO: Get stats from a logical GET auto &get = op->Cast(); + if (get.function.name == "unnest" && !op->children.empty()) { + RelationStats child_stats; + AddUnnestRelation(optimizer, *op, input_op, parent, child_stats, limit_op, datasource_filters); + return true; + } auto stats = RelationStatisticsHelper::ExtractGetStats(get, context); // if there is another logical filter that could not be pushed down into the // table scan, apply another selectivity. @@ -542,7 +560,6 @@ vector> RelationManager::ExtractEdges(LogicalOperator &op auto &join = f_op.Cast(); D_ASSERT(join.expressions.empty()); if (join.join_type == JoinType::SEMI || join.join_type == JoinType::ANTI) { - auto conjunction_expression = make_uniq(ExpressionType::CONJUNCTION_AND); // create a conjunction expression for the semi join. // It's possible multiple LHS relations have a condition in diff --git a/src/duckdb/src/optimizer/late_materialization.cpp b/src/duckdb/src/optimizer/late_materialization.cpp index 4e5b0f13e..fa2f50172 100644 --- a/src/duckdb/src/optimizer/late_materialization.cpp +++ b/src/duckdb/src/optimizer/late_materialization.cpp @@ -1,4 +1,6 @@ #include "duckdb/optimizer/late_materialization.hpp" + +#include "duckdb/optimizer/late_materialization_helper.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" #include "duckdb/planner/operator/logical_filter.hpp" #include "duckdb/planner/operator/logical_get.hpp" @@ -22,53 +24,6 @@ LateMaterialization::LateMaterialization(Optimizer &optimizer) : optimizer(optim max_row_count = DBConfig::GetSetting(optimizer.context); } -vector LateMaterialization::GetOrInsertRowIds(LogicalGet &get) { - auto &column_ids = get.GetMutableColumnIds(); - - vector result; - for (idx_t r_idx = 0; r_idx < row_id_column_ids.size(); ++r_idx) { - // check if it is already projected - auto row_id_column_id = row_id_column_ids[r_idx]; - auto &row_id_column = row_id_columns[r_idx]; - optional_idx row_id_index; - for (idx_t i = 0; i < column_ids.size(); ++i) { - if (column_ids[i].GetPrimaryIndex() == row_id_column_id) { - // already projected - return the id - row_id_index = i; - break; - } - } - if (row_id_index.IsValid()) { - result.push_back(row_id_index.GetIndex()); - continue; - } - // row id is not yet projected - push it and return the new index - column_ids.push_back(ColumnIndex(row_id_column_id)); - if (!get.projection_ids.empty()) { - get.projection_ids.push_back(column_ids.size() - 1); - } - if (!get.types.empty()) { - get.types.push_back(row_id_column.type); - } - result.push_back(column_ids.size() - 1); - } - return result; -} - -unique_ptr LateMaterialization::ConstructLHS(LogicalGet &get) { - // we need to construct a new scan of the same table - auto table_index = optimizer.binder.GenerateTableIndex(); 
- auto new_get = make_uniq(table_index, get.function, get.bind_data->Copy(), get.returned_types, - get.names, get.virtual_columns); - new_get->GetMutableColumnIds() = get.GetColumnIds(); - new_get->projection_ids = get.projection_ids; - new_get->parameters = get.parameters; - new_get->named_parameters = get.named_parameters; - new_get->input_table_types = get.input_table_types; - new_get->input_table_names = get.input_table_names; - return new_get; -} - vector LateMaterialization::ConstructRHS(unique_ptr &op) { // traverse down until we reach the LogicalGet vector> stack; @@ -80,7 +35,7 @@ vector LateMaterialization::ConstructRHS(unique_ptr(); - auto row_id_indexes = GetOrInsertRowIds(get); + auto row_id_indexes = LateMaterializationHelper::GetOrInsertRowIds(get, row_id_column_ids, row_id_columns); idx_t column_count = get.projection_ids.empty() ? get.GetColumnIds().size() : get.projection_ids.size(); D_ASSERT(column_count == get.GetColumnBindings().size()); @@ -281,12 +236,12 @@ bool LateMaterialization::TryLateMaterialization(unique_ptr &op // we need to ensure the operator returns exactly the same column bindings as before // construct the LHS from the LogicalGet - auto lhs = ConstructLHS(get); + auto lhs = LateMaterializationHelper::CreateLHSGet(get, optimizer.binder); // insert the row-id column on the left hand side auto &lhs_get = *lhs; auto lhs_index = lhs_get.table_index; auto lhs_columns = lhs_get.GetColumnIds().size(); - auto lhs_row_indexes = GetOrInsertRowIds(lhs_get); + auto lhs_row_indexes = LateMaterializationHelper::GetOrInsertRowIds(lhs_get, row_id_column_ids, row_id_columns); vector lhs_bindings; for (auto &lhs_row_index : lhs_row_indexes) { lhs_bindings.emplace_back(lhs_index, lhs_row_index); diff --git a/src/duckdb/src/optimizer/late_materialization_helper.cpp b/src/duckdb/src/optimizer/late_materialization_helper.cpp new file mode 100644 index 000000000..4b81ba3a1 --- /dev/null +++ b/src/duckdb/src/optimizer/late_materialization_helper.cpp @@ -0,0 +1,52 @@ +#include "duckdb/optimizer/late_materialization_helper.hpp" + +namespace duckdb { + +unique_ptr LateMaterializationHelper::CreateLHSGet(const LogicalGet &rhs, Binder &binder) { + // we need to construct a new scan of the same table + auto table_index = binder.GenerateTableIndex(); + auto new_get = make_uniq(table_index, rhs.function, rhs.bind_data->Copy(), rhs.returned_types, + rhs.names, rhs.virtual_columns); + new_get->GetMutableColumnIds() = rhs.GetColumnIds(); + new_get->projection_ids = rhs.projection_ids; + new_get->parameters = rhs.parameters; + new_get->named_parameters = rhs.named_parameters; + new_get->input_table_types = rhs.input_table_types; + new_get->input_table_names = rhs.input_table_names; + return new_get; +} + +vector LateMaterializationHelper::GetOrInsertRowIds(LogicalGet &get, const vector &row_id_column_ids, + const vector &row_id_columns) { + auto &column_ids = get.GetMutableColumnIds(); + + vector result; + for (idx_t r_idx = 0; r_idx < row_id_column_ids.size(); ++r_idx) { + // check if it is already projected + auto row_id_column_id = row_id_column_ids[r_idx]; + auto &row_id_column = row_id_columns[r_idx]; + optional_idx row_id_index; + for (idx_t i = 0; i < column_ids.size(); ++i) { + if (column_ids[i].GetPrimaryIndex() == row_id_column_id) { + // already projected - return the id + row_id_index = i; + break; + } + } + if (row_id_index.IsValid()) { + result.push_back(row_id_index.GetIndex()); + continue; + } + // row id is not yet projected - push it and return the new index + 
column_ids.push_back(ColumnIndex(row_id_column_id)); + if (!get.projection_ids.empty()) { + get.projection_ids.push_back(column_ids.size() - 1); + } + if (!get.types.empty()) { + get.types.push_back(row_id_column.type); + } + result.push_back(column_ids.size() - 1); + } + return result; +} +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/optimizer.cpp b/src/duckdb/src/optimizer/optimizer.cpp index ce6cb0045..42f1eba62 100644 --- a/src/duckdb/src/optimizer/optimizer.cpp +++ b/src/duckdb/src/optimizer/optimizer.cpp @@ -32,14 +32,17 @@ #include "duckdb/optimizer/statistics_propagator.hpp" #include "duckdb/optimizer/sum_rewriter.hpp" #include "duckdb/optimizer/topn_optimizer.hpp" +#include "duckdb/optimizer/topn_window_elimination.hpp" #include "duckdb/optimizer/unnest_rewriter.hpp" #include "duckdb/optimizer/late_materialization.hpp" +#include "duckdb/optimizer/common_subplan_optimizer.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/planner.hpp" namespace duckdb { Optimizer::Optimizer(Binder &binder, ClientContext &context) : context(context), binder(binder), rewriter(context) { + rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); @@ -127,6 +130,12 @@ void Optimizer::RunBuiltInOptimizers() { plan = cte_inlining.Optimize(std::move(plan)); }); + // convert common subplans into materialized CTEs + RunOptimizer(OptimizerType::COMMON_SUBPLAN, [&]() { + CommonSubplanOptimizer common_subplan_optimizer(*this); + plan = common_subplan_optimizer.Optimize(std::move(plan)); + }); + // Rewrites SUM(x + C) into SUM(x) + C * COUNT(x) RunOptimizer(OptimizerType::SUM_REWRITER, [&]() { SumRewriterOptimizer optimizer(*this); @@ -257,6 +266,12 @@ void Optimizer::RunBuiltInOptimizers() { statistics_map = propagator.GetStatisticsMap(); }); + // rewrite row_number window function + filter on row_number to aggregate + RunOptimizer(OptimizerType::TOP_N_WINDOW_ELIMINATION, [&]() { + TopNWindowElimination topn_window_elimination(context, *this, &statistics_map); + plan = topn_window_elimination.Optimize(std::move(plan)); + }); + // remove duplicate aggregates RunOptimizer(OptimizerType::COMMON_AGGREGATE, [&]() { CommonAggregateOptimizer common_aggregate; diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp index 90dbbb823..ac4b6532a 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp @@ -4,6 +4,7 @@ #include "duckdb/planner/expression/bound_parameter_expression.hpp" #include "duckdb/planner/operator/logical_filter.hpp" #include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_empty_result.hpp" namespace duckdb { unique_ptr FilterPushdown::PushdownGet(unique_ptr op) { @@ -48,7 +49,9 @@ unique_ptr FilterPushdown::PushdownGet(unique_ptr(std::move(op)); + } //! 
We generate the table filters that will be executed during the table scan vector pushdown_results; diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_inner_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_inner_join.cpp index 8370f4ca9..e2e4730d1 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_inner_join.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_inner_join.cpp @@ -14,6 +14,7 @@ unique_ptr FilterPushdown::PushdownInnerJoin(unique_ptrCast(); D_ASSERT(join.join_type == JoinType::INNER); if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + op = PushFiltersIntoDelimJoin(std::move(op)); return FinishPushdown(std::move(op)); } // inner join: gather all the conditions of the inner join and add to the filter list diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp index 9e56ed9d6..1ebf3cedd 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp @@ -78,6 +78,7 @@ unique_ptr FilterPushdown::PushdownLeftJoin(unique_ptr &right_bindings) { auto &join = op->Cast(); if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + op = PushFiltersIntoDelimJoin(std::move(op)); return FinishPushdown(std::move(op)); } FilterPushdown left_pushdown(optimizer, convert_mark_joins), right_pushdown(optimizer, convert_mark_joins); diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_outer_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_outer_join.cpp index 3d81c68c1..648c5e3d4 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_outer_join.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_outer_join.cpp @@ -174,7 +174,6 @@ PushDownFiltersOnCoalescedEqualJoinKeys(vector> &filters, unique_ptr FilterPushdown::PushdownOuterJoin(unique_ptr op, unordered_set &left_bindings, unordered_set &right_bindings) { - if (op->type != LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { return FinishPushdown(std::move(op)); } diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_semi_anti_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_semi_anti_join.cpp index 7d240e3f6..0b937fe25 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_semi_anti_join.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_semi_anti_join.cpp @@ -12,6 +12,7 @@ using Filter = FilterPushdown::Filter; unique_ptr FilterPushdown::PushdownSemiAntiJoin(unique_ptr op) { auto &join = op->Cast(); if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + op = PushFiltersIntoDelimJoin(std::move(op)); return FinishPushdown(std::move(op)); } diff --git a/src/duckdb/src/optimizer/regex_range_filter.cpp b/src/duckdb/src/optimizer/regex_range_filter.cpp index fd9f98fe6..987c579af 100644 --- a/src/duckdb/src/optimizer/regex_range_filter.cpp +++ b/src/duckdb/src/optimizer/regex_range_filter.cpp @@ -16,7 +16,6 @@ namespace duckdb { unique_ptr RegexRangeFilter::Rewrite(unique_ptr op) { - for (idx_t child_idx = 0; child_idx < op->children.size(); child_idx++) { op->children[child_idx] = Rewrite(std::move(op->children[child_idx])); } diff --git a/src/duckdb/src/optimizer/remove_unused_columns.cpp b/src/duckdb/src/optimizer/remove_unused_columns.cpp index 20817633a..ece345984 100644 --- a/src/duckdb/src/optimizer/remove_unused_columns.cpp +++ b/src/duckdb/src/optimizer/remove_unused_columns.cpp @@ -228,89 +228,13 @@ void RemoveUnusedColumns::VisitOperator(LogicalOperator &op) { } case LogicalOperatorType::LOGICAL_GET: { LogicalOperatorVisitor::VisitOperatorExpressions(op); 
- if (everything_referenced) { - return; - } auto &get = op.Cast(); - if (!get.function.projection_pushdown) { - return; - } - - auto final_column_ids = get.GetColumnIds(); - - // Create "selection vector" of all column ids - vector proj_sel; - for (idx_t col_idx = 0; col_idx < final_column_ids.size(); col_idx++) { - proj_sel.push_back(col_idx); - } - // Create a copy that we can use to match ids later - auto col_sel = proj_sel; - // Clear unused ids, exclude filter columns that are projected out immediately - ClearUnusedExpressions(proj_sel, get.table_index, false); - - vector> filter_expressions; - // for every table filter, push a column binding into the column references map to prevent the column from - // being projected out - for (auto &filter : get.table_filters.filters) { - optional_idx index; - for (idx_t i = 0; i < final_column_ids.size(); i++) { - if (final_column_ids[i].GetPrimaryIndex() == filter.first) { - index = i; - break; - } - } - if (!index.IsValid()) { - throw InternalException("Could not find column index for table filter"); - } - - auto column_type = get.GetColumnType(ColumnIndex(filter.first)); - - ColumnBinding filter_binding(get.table_index, index.GetIndex()); - auto column_ref = make_uniq(std::move(column_type), filter_binding); - auto filter_expr = filter.second->ToExpression(*column_ref); - if (filter_expr->IsScalar()) { - filter_expr = std::move(column_ref); - } - VisitExpression(&filter_expr); - filter_expressions.push_back(std::move(filter_expr)); - } - - // Clear unused ids, include filter columns that are projected out immediately - ClearUnusedExpressions(col_sel, get.table_index); - - // Now set the column ids in the LogicalGet using the "selection vector" - vector column_ids; - column_ids.reserve(col_sel.size()); - for (auto col_sel_idx : col_sel) { - auto entry = column_references.find(ColumnBinding(get.table_index, col_sel_idx)); - if (entry == column_references.end()) { - throw InternalException("RemoveUnusedColumns - could not find referenced column"); - } - ColumnIndex new_index(final_column_ids[col_sel_idx].GetPrimaryIndex(), entry->second.child_columns); - column_ids.emplace_back(new_index); - } - if (column_ids.empty()) { - // this generally means we are only interested in whether or not anything exists in the table (e.g. - // EXISTS(SELECT * FROM tbl)) in this case, we just scan the row identifier column as it means we do not - // need to read any of the columns - column_ids.emplace_back(get.GetAnyColumn()); - } - get.SetColumnIds(std::move(column_ids)); - - if (!get.function.filter_prune) { - return; - } - // Now set the projection cols by matching the "selection vector" that excludes filter columns - // with the "selection vector" that includes filter columns - idx_t col_idx = 0; - get.projection_ids.clear(); - for (auto proj_sel_idx : proj_sel) { - for (; col_idx < col_sel.size(); col_idx++) { - if (proj_sel_idx == col_sel[col_idx]) { - get.projection_ids.push_back(col_idx); - break; - } - } + RemoveColumnsFromLogicalGet(get); + if (!op.children.empty()) { + // Some LOGICAL_GET operators (e.g., table in out functions) may have a + // child operator. So we recurse into it if it exists. 
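+			// (the child is visited with everything_referenced set to true, so none of its columns are pruned away)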
+ RemoveUnusedColumns remove(binder, context, true); + remove.VisitOperator(*op.children[0]); } return; } @@ -363,6 +287,92 @@ void RemoveUnusedColumns::VisitOperator(LogicalOperator &op) { } } +void RemoveUnusedColumns::RemoveColumnsFromLogicalGet(LogicalGet &get) { + if (everything_referenced) { + return; + } + if (!get.function.projection_pushdown) { + return; + } + + auto final_column_ids = get.GetColumnIds(); + + // Create "selection vector" of all column ids + vector proj_sel; + for (idx_t col_idx = 0; col_idx < final_column_ids.size(); col_idx++) { + proj_sel.push_back(col_idx); + } + // Create a copy that we can use to match ids later + auto col_sel = proj_sel; + // Clear unused ids, exclude filter columns that are projected out immediately + ClearUnusedExpressions(proj_sel, get.table_index, false); + + vector> filter_expressions; + // for every table filter, push a column binding into the column references map to prevent the column from + // being projected out + for (auto &filter : get.table_filters.filters) { + optional_idx index; + for (idx_t i = 0; i < final_column_ids.size(); i++) { + if (final_column_ids[i].GetPrimaryIndex() == filter.first) { + index = i; + break; + } + } + if (!index.IsValid()) { + throw InternalException("Could not find column index for table filter"); + } + + auto column_type = get.GetColumnType(ColumnIndex(filter.first)); + + ColumnBinding filter_binding(get.table_index, index.GetIndex()); + auto column_ref = make_uniq(std::move(column_type), filter_binding); + auto filter_expr = filter.second->ToExpression(*column_ref); + if (filter_expr->IsScalar()) { + filter_expr = std::move(column_ref); + } + VisitExpression(&filter_expr); + filter_expressions.push_back(std::move(filter_expr)); + } + + // Clear unused ids, include filter columns that are projected out immediately + ClearUnusedExpressions(col_sel, get.table_index); + + // Now set the column ids in the LogicalGet using the "selection vector" + vector column_ids; + column_ids.reserve(col_sel.size()); + for (auto col_sel_idx : col_sel) { + auto entry = column_references.find(ColumnBinding(get.table_index, col_sel_idx)); + if (entry == column_references.end()) { + throw InternalException("RemoveUnusedColumns - could not find referenced column"); + } + ColumnIndex new_index(final_column_ids[col_sel_idx].GetPrimaryIndex(), entry->second.child_columns); + column_ids.emplace_back(new_index); + } + if (column_ids.empty()) { + // this generally means we are only interested in whether or not anything exists in the table (e.g. 
+ // EXISTS(SELECT * FROM tbl)) in this case, we just scan the row identifier column as it means we do not + // need to read any of the columns + column_ids.emplace_back(get.GetAnyColumn()); + } + get.SetColumnIds(std::move(column_ids)); + + if (!get.function.filter_prune) { + return; + } + // Now set the projection cols by matching the "selection vector" that excludes filter columns + // with the "selection vector" that includes filter columns + idx_t col_idx = 0; + get.projection_ids.clear(); + for (auto proj_sel_idx : proj_sel) { + for (; col_idx < col_sel.size(); col_idx++) { + if (proj_sel_idx == col_sel[col_idx]) { + get.projection_ids.push_back(col_idx); + break; + } + } + } +} + bool BaseColumnPruner::HandleStructExtractRecursive(Expression &expr, optional_ptr &colref, vector &indexes) { if (expr.GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { diff --git a/src/duckdb/src/optimizer/rule/comparison_simplification.cpp b/src/duckdb/src/optimizer/rule/comparison_simplification.cpp index dc778cfff..5fe22818e 100644 --- a/src/duckdb/src/optimizer/rule/comparison_simplification.cpp +++ b/src/duckdb/src/optimizer/rule/comparison_simplification.cpp @@ -17,7 +17,6 @@ ComparisonSimplificationRule::ComparisonSimplificationRule(ExpressionRewriter &r unique_ptr ComparisonSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, bool is_root) { - auto &expr = bindings[0].get().Cast(); auto &constant_expr = bindings[1].get(); bool column_ref_left = expr.left.get() != &constant_expr; diff --git a/src/duckdb/src/optimizer/rule/constant_order_normalization.cpp b/src/duckdb/src/optimizer/rule/constant_order_normalization.cpp new file mode 100644 index 000000000..09376c3d3 --- /dev/null +++ b/src/duckdb/src/optimizer/rule/constant_order_normalization.cpp @@ -0,0 +1,127 @@ +#include "duckdb/optimizer/rule/constant_order_normalization.hpp" + +#include "duckdb/optimizer/expression_rewriter.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/function/function_binder.hpp" + +namespace duckdb { + +class RecursiveFunctionExpressionMatcher : public ExpressionMatcher { +public: + explicit RecursiveFunctionExpressionMatcher(vector> func_matchers) + : func_matchers(std::move(func_matchers)) { + } + bool Match(Expression &expr, vector> &bindings) override { + FunctionExpressionMatcher *target_matcher = nullptr; + for (const auto &matcher : func_matchers) { + if (matcher->Match(expr, bindings)) { + target_matcher = matcher.get(); + break; + } + } + if (target_matcher == nullptr) { + return false; + } + bindings.clear(); + RecursiveMatch(target_matcher, expr, bindings); + bindings.push_back(expr); + return true; + } + +private: + void RecursiveMatch(FunctionExpressionMatcher *func_matcher, Expression &expr, + vector> &bindings) { + vector> curr_bindings; + if (func_matcher->Match(expr, curr_bindings)) { + auto &func_expr = expr.Cast(); + for (auto &child : func_expr.children) { + RecursiveMatch(func_matcher, *(child.get()), bindings); + } + } else { + bindings.push_back(expr); + } + } + + vector> func_matchers; +}; + +ConstantOrderNormalizationRule::ConstantOrderNormalizationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // '+' and '*' satisfy commutative law and associative law. 
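+	// e.g. (a + 1) + 2 is reordered into (1 + 2) + a so that the constant part can later be folded into a single value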
+ auto add_matcher = make_uniq(); + add_matcher->function = make_uniq("+"); + add_matcher->type = make_uniq(); + auto left_expression_matcher = make_uniq(); + auto right_expression_matcher = make_uniq(); + left_expression_matcher->type = make_uniq(); + right_expression_matcher->type = make_uniq(); + add_matcher->matchers.push_back(std::move(left_expression_matcher)); + add_matcher->matchers.push_back(std::move(right_expression_matcher)); + add_matcher->policy = SetMatcher::Policy::ORDERED; + + auto multiply_matcher = make_uniq(); + multiply_matcher->function = make_uniq("*"); + multiply_matcher->type = make_uniq(); + left_expression_matcher = make_uniq(); + right_expression_matcher = make_uniq(); + left_expression_matcher->type = make_uniq(); + right_expression_matcher->type = make_uniq(); + multiply_matcher->matchers.push_back(std::move(left_expression_matcher)); + multiply_matcher->matchers.push_back(std::move(right_expression_matcher)); + multiply_matcher->policy = SetMatcher::Policy::ORDERED; + + vector> func_matchers; + func_matchers.push_back(std::move(add_matcher)); + func_matchers.push_back(std::move(multiply_matcher)); + auto op = make_uniq(std::move(func_matchers)); + root = std::move(op); +} + +unique_ptr ConstantOrderNormalizationRule::Apply(LogicalOperator &op, + vector> &bindings, + bool &changes_made, bool is_root) { + auto &root = bindings.back().get().Cast(); + + // Put all constant expressions in front. + vector> ordered_bindings; + vector> remain_bindings; + idx_t last_constant_position = 0; + for (idx_t i = 0; i < bindings.size() - 1; ++i) { + if (bindings[i].get().IsFoldable()) { + ordered_bindings.push_back(bindings[i]); + last_constant_position = i; + } else { + remain_bindings.push_back(bindings[i]); + } + } + + if (ordered_bindings.size() <= 1 || last_constant_position == ordered_bindings.size() - 1) { + return nullptr; + } + ordered_bindings.insert(ordered_bindings.end(), remain_bindings.begin(), remain_bindings.end()); + + // Reconstruct the expression. + FunctionBinder binder(rewriter.context); + ErrorData error; + unique_ptr new_root = ordered_bindings[0].get().Copy(); + vector> children; + children.push_back(std::move(new_root)); + for (idx_t i = 1; i < ordered_bindings.size(); ++i) { + // Right child. + children.push_back(ordered_bindings[i].get().Copy()); + new_root = + binder.BindScalarFunction(DEFAULT_SCHEMA, root.function.name, std::move(children), error, root.is_operator); + if (!new_root) { + error.Throw(); + } + children.clear(); + // Left child. 
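+		// the freshly bound function becomes the left operand of the next iteration, building a left-deep chain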
+ children.push_back(std::move(new_root)); + } + + D_ASSERT(children.size() == 1); + D_ASSERT(children[0]->return_type == root.return_type); + + return std::move(children[0]); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/date_trunc_simplification.cpp b/src/duckdb/src/optimizer/rule/date_trunc_simplification.cpp index 392574de3..9e760a2f3 100644 --- a/src/duckdb/src/optimizer/rule/date_trunc_simplification.cpp +++ b/src/duckdb/src/optimizer/rule/date_trunc_simplification.cpp @@ -246,7 +246,7 @@ unique_ptr DateTruncSimplificationRule::Apply(LogicalOperator &op, v case ExpressionType::COMPARE_LESSTHANOREQUALTO: case ExpressionType::COMPARE_GREATERTHAN: - // date_trunc(part, column) <= constant_rhs --> column <= date_trunc(part, date_add(constant_rhs, + // date_trunc(part, column) <= constant_rhs --> column < date_trunc(part, date_add(constant_rhs, // INTERVAL 1 part)) // date_trunc(part, column) > constant_rhs --> column >= date_trunc(part, date_add(constant_rhs, // INTERVAL 1 part)) @@ -265,13 +265,19 @@ unique_ptr DateTruncSimplificationRule::Apply(LogicalOperator &op, v expr.left = std::move(trunc); } - // If this is a >, we need to change it to >= for correctness. + // > needs to become >=, and <= needs to become <. if (rhs_comparison_type == ExpressionType::COMPARE_GREATERTHAN) { if (col_is_lhs) { expr.SetExpressionTypeUnsafe(ExpressionType::COMPARE_GREATERTHANOREQUALTO); } else { expr.SetExpressionTypeUnsafe(ExpressionType::COMPARE_LESSTHANOREQUALTO); } + } else { + if (col_is_lhs) { + expr.SetExpressionTypeUnsafe(ExpressionType::COMPARE_LESSTHAN); + } else { + expr.SetExpressionTypeUnsafe(ExpressionType::COMPARE_GREATERTHAN); + } } changes_made = true; diff --git a/src/duckdb/src/optimizer/rule/enum_comparison.cpp b/src/duckdb/src/optimizer/rule/enum_comparison.cpp index 5dfcfaf30..0553285eb 100644 --- a/src/duckdb/src/optimizer/rule/enum_comparison.cpp +++ b/src/duckdb/src/optimizer/rule/enum_comparison.cpp @@ -47,7 +47,6 @@ bool AreMatchesPossible(LogicalType &left, LogicalType &right) { } unique_ptr EnumComparisonRule::Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, bool is_root) { - auto &root = bindings[0].get().Cast(); auto &left_child = bindings[1].get().Cast(); auto &right_child = bindings[3].get().Cast(); diff --git a/src/duckdb/src/optimizer/rule/regex_optimizations.cpp b/src/duckdb/src/optimizer/rule/regex_optimizations.cpp index 24786867b..3a0697e99 100644 --- a/src/duckdb/src/optimizer/rule/regex_optimizations.cpp +++ b/src/duckdb/src/optimizer/rule/regex_optimizations.cpp @@ -184,6 +184,13 @@ unique_ptr RegexOptimizationRule::Apply(LogicalOperator &op, vector< if (!escaped_like_string.exists) { return nullptr; } + + // if regexp had options, remove them so the new Contains Expression can be matched for other optimizers. 
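+	// (the optional regex options are passed as a third child expression; 'contains' only takes two arguments)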
+ if (root.children.size() == 3) { + root.children.pop_back(); + D_ASSERT(root.children.size() == 2); + } + auto parameter = make_uniq(Value(std::move(escaped_like_string.like_string))); auto contains = make_uniq(root.return_type, GetStringContains(), std::move(root.children), nullptr); diff --git a/src/duckdb/src/optimizer/statistics/operator/propagate_get.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_get.cpp index b01f7d704..a7aeec267 100644 --- a/src/duckdb/src/optimizer/statistics/operator/propagate_get.cpp +++ b/src/duckdb/src/optimizer/statistics/operator/propagate_get.cpp @@ -46,7 +46,12 @@ FilterPropagateResult StatisticsPropagator::PropagateTableFilter(ColumnBinding s // replace BoundColumnRefs with BoundRefs ExpressionFilter::ReplaceExpressionRecursive(filter_expr, *colref, ExpressionType::BOUND_COLUMN_REF); expr_filter.expr = std::move(filter_expr); - return propagate_result; + + // If we were able to prune solely based on the expression, return that result + if (propagate_result != FilterPropagateResult::NO_PRUNING_POSSIBLE) { + return propagate_result; + } + // Otherwise, check the statistics } return filter.CheckStatistics(stats); } diff --git a/src/duckdb/src/optimizer/topn_optimizer.cpp b/src/duckdb/src/optimizer/topn_optimizer.cpp index e42c748cb..a48faafbe 100644 --- a/src/duckdb/src/optimizer/topn_optimizer.cpp +++ b/src/duckdb/src/optimizer/topn_optimizer.cpp @@ -11,9 +11,29 @@ #include "duckdb/execution/operator/join/join_filter_pushdown.hpp" #include "duckdb/optimizer/join_filter_pushdown_optimizer.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/storage/table/scan_state.hpp" namespace duckdb { +namespace { + +bool CanReorderRowGroups(LogicalTopN &op) { + // Only reorder row groups if there are no additional limit operators since they could modify the order + reference current_op = op; + while (!current_op.get().children.empty()) { + if (current_op.get().children.size() > 1) { + return false; + } + if (current_op.get().type == LogicalOperatorType::LOGICAL_LIMIT) { + return false; + } + current_op = *current_op.get().children[0]; + } + return true; +} + +} // namespace + TopN::TopN(ClientContext &context_p) : context(context_p) { } @@ -39,6 +59,9 @@ bool TopN::CanOptimize(LogicalOperator &op, optional_ptr context) if (child_op->has_estimated_cardinality) { // only check if we should switch to full sorting if we have estimated cardinality auto constant_limit = static_cast(limit.limit_val.GetConstantValue()); + if (limit.offset_val.Type() == LimitNodeType::CONSTANT_VALUE) { + constant_limit += static_cast(limit.offset_val.GetConstantValue()); + } auto child_card = static_cast(child_op->estimated_cardinality); // if the limit is > 0.7% of the child cardinality, sorting the whole table is faster @@ -110,6 +133,9 @@ void TopN::PushdownDynamicFilters(LogicalTopN &op) { // put the filter into the Top-N clause op.dynamic_filter = filter_data; + bool use_custom_rowgroup_order = + CanReorderRowGroups(op) && (colref.return_type.IsNumeric() || colref.return_type.IsTemporal()); + for (auto &target : pushdown_targets) { auto &get = target.get; D_ASSERT(target.columns.size() == 1); @@ -122,12 +148,23 @@ void TopN::PushdownDynamicFilters(LogicalTopN &op) { // push the filter into the table scan auto &column_index = get.GetColumnIds()[col_idx]; get.table_filters.PushFilter(column_index, std::move(optional_filter)); + + // Scan row groups in custom order + if (get.function.set_scan_order && use_custom_rowgroup_order) { + auto 
column_type = + colref.return_type == LogicalType::VARCHAR ? OrderByColumnType::STRING : OrderByColumnType::NUMERIC; + auto order_type = + op.orders[0].type == OrderType::ASCENDING ? RowGroupOrderType::ASC : RowGroupOrderType::DESC; + auto order_by = order_type == RowGroupOrderType::ASC ? OrderByStatistics::MIN : OrderByStatistics::MAX; + auto order_options = + make_uniq(column_index.GetPrimaryIndex(), order_by, order_type, column_type); + get.function.set_scan_order(std::move(order_options), get.bind_data.get()); + } } } unique_ptr TopN::Optimize(unique_ptr op) { if (CanOptimize(*op, &context)) { - vector> projections; // traverse operator tree and collect all projection nodes until we reach diff --git a/src/duckdb/src/optimizer/topn_window_elimination.cpp b/src/duckdb/src/optimizer/topn_window_elimination.cpp new file mode 100644 index 000000000..c5544d21f --- /dev/null +++ b/src/duckdb/src/optimizer/topn_window_elimination.cpp @@ -0,0 +1,979 @@ +#include "duckdb/optimizer/topn_window_elimination.hpp" + +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/optimizer/late_materialization_helper.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_unnest.hpp" +#include "duckdb/planner/operator/logical_window.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/function/scalar/struct_functions.hpp" +#include "duckdb/optimizer/optimizer.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_unnest_expression.hpp" +#include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/function/function_binder.hpp" +#include "duckdb/main/database.hpp" + +namespace duckdb { + +namespace { + +idx_t GetGroupIdx(const unique_ptr &op) { + if (op->type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { + return op->Cast().group_index; + } + if (op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + return op->children[0]->GetTableIndex()[0]; + } + return op->GetTableIndex()[0]; +} + +idx_t GetAggregateIdx(const unique_ptr &op) { + if (op->type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { + return op->Cast().aggregate_index; + } + if (op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + return op->children[0]->GetTableIndex()[0]; + } + return op->GetTableIndex()[0]; +} + +LogicalType GetAggregateType(const unique_ptr &op) { + switch (op->type) { + case LogicalOperatorType::LOGICAL_UNNEST: { + const auto &logical_unnest = op->Cast(); + const idx_t unnest_offset = logical_unnest.children[0]->types.size(); + return logical_unnest.types[unnest_offset]; + } + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { + const auto &logical_aggregate = op->Cast(); + const idx_t aggregate_column_idx = logical_aggregate.groups.size(); + return 
logical_aggregate.types[aggregate_column_idx]; + } + default: { + throw InternalException("Unnest or aggregate expected to extract aggregate type."); + } + } +} + +vector ExtractReturnTypes(const vector> &exprs) { + vector types; + types.reserve(exprs.size()); + for (const auto &expr : exprs) { + types.push_back(expr->return_type); + } + return types; +} + +bool BindingsReferenceRowNumber(const vector &bindings, const LogicalWindow &window) { + for (const auto &binding : bindings) { + if (binding.table_index == window.window_index) { + return true; + } + } + return false; +} + +ColumnBinding GetRowNumberColumnBinding(const unique_ptr &op) { + switch (op->type) { + case LogicalOperatorType::LOGICAL_UNNEST: { + const auto column_bindings = op->GetColumnBindings(); + const idx_t row_number_offset = op->children[0]->types.size() + 1; + D_ASSERT(op->types.size() == row_number_offset + 1); + return column_bindings[row_number_offset]; + } + case LogicalOperatorType::LOGICAL_PROJECTION: { + const auto &projection = op->Cast(); + return {projection.table_index, projection.types.size() - 1}; + } + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { + const auto &join = op->Cast(); + D_ASSERT(!join.right_projection_map.empty()); + const auto child_bindings = op->GetColumnBindings(); + return child_bindings[child_bindings.size() - 1]; + } + default: { + throw InternalException("Operator type not supported."); + } + } +} + +idx_t TraverseAndFindAggregateOffset(const unique_ptr &op) { + reference current_op = *op; + while (current_op.get().type != LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { + D_ASSERT(!current_op.get().children.empty()); + current_op = *current_op.get().children[0]; + } + const auto &aggregate = current_op.get().Cast(); + return aggregate.groups.size(); +} + +} // namespace + +TopNWindowElimination::TopNWindowElimination(ClientContext &context_p, Optimizer &optimizer, + optional_ptr>> stats_p) + : context(context_p), optimizer(optimizer), stats(stats_p) { +} + +unique_ptr TopNWindowElimination::Optimize(unique_ptr op) { + auto &extension_manager = context.db->GetExtensionManager(); + if (!extension_manager.ExtensionIsLoaded("core_functions")) { + return op; + } + + ColumnBindingReplacer replacer; + op = OptimizeInternal(std::move(op), replacer); + if (!replacer.replacement_bindings.empty()) { + replacer.VisitOperator(*op); + } + return op; +} + +unique_ptr TopNWindowElimination::OptimizeInternal(unique_ptr op, + ColumnBindingReplacer &replacer) { + if (!CanOptimize(*op)) { + // Traverse through query plan to find grouped top-n pattern + if (op->children.size() > 1) { + // If an operator has multiple children, we do not want them to overwrite each other's stop operator. + // Thus, first update only the column binding in op, then set op as the new stop operator. 
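+ // (e.g. a join sitting above two grouped top-n subtrees: each child gets its own replacer, whose
+ // replacements are applied to this operator and then merged into the parent replacer.)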
+ for (auto &child : op->children) { + ColumnBindingReplacer r2; + child = OptimizeInternal(std::move(child), r2); + + if (!r2.replacement_bindings.empty()) { + r2.VisitOperator(*op); + replacer.replacement_bindings.insert(replacer.replacement_bindings.end(), + r2.replacement_bindings.begin(), + r2.replacement_bindings.end()); + replacer.stop_operator = op; + } + } + } else if (!op->children.empty()) { + op->children[0] = OptimizeInternal(std::move(op->children[0]), replacer); + } + + return op; + } + // We have made sure that this is an operator sequence of filter -> N optional projections -> window + auto &filter = op->Cast(); + reference child = *filter.children[0]; + + // Get bindings and types from filter to use in top-most operator later + const auto topmost_bindings = filter.GetColumnBindings(); + auto new_bindings = TraverseProjectionBindings(topmost_bindings, child); + + D_ASSERT(child.get().type == LogicalOperatorType::LOGICAL_WINDOW); + auto &window = child.get().Cast(); + const idx_t window_idx = window.window_index; + + // Map the input column offsets of the group columns to the output offset if there are projections on the group + // We use an ordered map here because we need to iterate over them in order later + map group_projection_idxs; + auto aggregate_payload = GenerateAggregatePayload(new_bindings, window, group_projection_idxs); + auto params = ExtractOptimizerParameters(window, filter, new_bindings, aggregate_payload); + + unique_ptr late_mat_lhs = nullptr; + if (params.payload_type == TopNPayloadType::STRUCT_PACK) { + // Try circumventing struct-packing with late materialization + late_mat_lhs = TryPrepareLateMaterialization(window, aggregate_payload); + if (late_mat_lhs) { + params.payload_type = TopNPayloadType::SINGLE_COLUMN; + } + } + + // Optimize window children + window.children[0] = Optimize(std::move(window.children[0])); + + op = CreateAggregateOperator(window, std::move(aggregate_payload), params); + op = TryCreateUnnestOperator(std::move(op), params); + op = CreateProjectionOperator(std::move(op), params, group_projection_idxs); + + D_ASSERT(op->type != LogicalOperatorType::LOGICAL_UNNEST); + + if (late_mat_lhs) { + op = ConstructJoin(std::move(late_mat_lhs), std::move(op), group_projection_idxs.size(), params); + } + + UpdateTopmostBindings(window_idx, op, group_projection_idxs, topmost_bindings, new_bindings, replacer); + replacer.stop_operator = op.get(); + + RemoveUnusedColumns unused_optimizer(optimizer.binder, optimizer.context, true); + unused_optimizer.VisitOperator(*op); + + return unique_ptr(std::move(op)); +} + +unique_ptr +TopNWindowElimination::CreateAggregateExpression(vector> aggregate_params, + const bool requires_arg, + const TopNWindowEliminationParameters ¶ms) const { + auto &catalog = Catalog::GetSystemCatalog(context); + FunctionBinder function_binder(context); + + // If the value column can be null, we must use the nulls_last function to follow null ordering semantics + const bool change_to_arg = !requires_arg && params.can_be_null && params.limit > 1; + if (change_to_arg) { + // Copy value as argument + aggregate_params.insert(aggregate_params.begin() + 1, aggregate_params[0]->Copy()); + } + + D_ASSERT(params.order_type == OrderType::ASCENDING || params.order_type == OrderType::DESCENDING); + string fun_name = requires_arg || change_to_arg ? "arg_" : ""; + fun_name += params.order_type == OrderType::ASCENDING ? "min" : "max"; + fun_name += params.can_be_null && (requires_arg || change_to_arg) ? 
"_nulls_last" : ""; + + auto &fun_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, fun_name); + const auto fun = fun_entry.functions.GetFunctionByArguments(context, ExtractReturnTypes(aggregate_params)); + return function_binder.BindAggregateFunction(fun, std::move(aggregate_params)); +} + +unique_ptr +TopNWindowElimination::CreateAggregateOperator(LogicalWindow &window, vector> args, + const TopNWindowEliminationParameters ¶ms) const { + auto &window_expr = window.expressions[0]->Cast(); + D_ASSERT(window_expr.orders.size() == 1); + + vector> aggregate_params; + aggregate_params.reserve(3); + + const bool use_arg = !args.empty(); + if (args.size() == 1) { + aggregate_params.push_back(std::move(args[0])); + } else if (args.size() > 1) { + // For more than one arg, we must use struct pack + auto &catalog = Catalog::GetSystemCatalog(context); + FunctionBinder function_binder(context); + auto &struct_pack_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, "struct_pack"); + const auto struct_pack_fun = + struct_pack_entry.functions.GetFunctionByArguments(context, ExtractReturnTypes(args)); + auto struct_pack_expr = function_binder.BindScalarFunction(struct_pack_fun, std::move(args)); + aggregate_params.push_back(std::move(struct_pack_expr)); + } + + aggregate_params.push_back(std::move(window_expr.orders[0].expression)); + if (params.limit > 1) { + aggregate_params.push_back(std::move(make_uniq(Value::BIGINT(params.limit)))); + } + + auto aggregate_expr = CreateAggregateExpression(std::move(aggregate_params), use_arg, params); + + vector> select_list; + select_list.push_back(std::move(aggregate_expr)); + + auto aggregate = make_uniq(optimizer.binder.GenerateTableIndex(), + optimizer.binder.GenerateTableIndex(), std::move(select_list)); + aggregate->groupings_index = optimizer.binder.GenerateTableIndex(); + aggregate->groups = std::move(window_expr.partitions); + aggregate->children.push_back(std::move(window.children[0])); + aggregate->ResolveOperatorTypes(); + + // Add group statistics to allow for perfect hash aggregation if applicable + aggregate->group_stats.resize(aggregate->groups.size()); + for (idx_t i = 0; i < aggregate->groups.size(); i++) { + auto &group = aggregate->groups[i]; + if (group->type == ExpressionType::BOUND_COLUMN_REF) { + auto &column_ref = group->Cast(); + if (stats) { + auto group_stats = stats->find(column_ref.binding); + if (group_stats == stats->end()) { + continue; + } + aggregate->group_stats[i] = group_stats->second->ToUnique(); + } + } + } + + return unique_ptr(std::move(aggregate)); +} + +unique_ptr +TopNWindowElimination::CreateRowNumberGenerator(unique_ptr aggregate_column_ref) const { + // Create unnest(generate_series(1, array_length(column_ref, 1))) function to generate row ids + FunctionBinder function_binder(context); + auto &catalog = Catalog::GetSystemCatalog(context); + + // array_length + auto &array_length_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, "array_length"); + vector> array_length_exprs; + array_length_exprs.push_back(std::move(aggregate_column_ref)); + array_length_exprs.push_back(make_uniq(1)); + + const auto array_length_fun = array_length_entry.functions.GetFunctionByArguments( + context, {array_length_exprs[0]->return_type, array_length_exprs[1]->return_type}); + auto bound_array_length_fun = function_binder.BindScalarFunction(array_length_fun, std::move(array_length_exprs)); + + // generate_series + auto &generate_series_entry = + catalog.GetEntry(context, DEFAULT_SCHEMA, "generate_series"); + + vector> generate_series_exprs; + 
generate_series_exprs.push_back(make_uniq(1)); + generate_series_exprs.push_back(std::move(bound_array_length_fun)); + + const auto generate_series_fun = generate_series_entry.functions.GetFunctionByArguments( + context, {generate_series_exprs[0]->return_type, generate_series_exprs[1]->return_type}); + auto bound_generate_series_fun = + function_binder.BindScalarFunction(generate_series_fun, std::move(generate_series_exprs)); + + // unnest + auto unnest_row_number_expr = make_uniq(LogicalType::BIGINT); + unnest_row_number_expr->alias = "row_number"; + unnest_row_number_expr->child = std::move(bound_generate_series_fun); + + return unique_ptr(std::move(unnest_row_number_expr)); +} + +unique_ptr +TopNWindowElimination::TryCreateUnnestOperator(unique_ptr op, + const TopNWindowEliminationParameters ¶ms) const { + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY); + + auto &logical_aggregate = op->Cast(); + const idx_t aggregate_column_idx = logical_aggregate.groups.size(); + LogicalType aggregate_type = logical_aggregate.types[aggregate_column_idx]; + + if (params.limit <= 1) { + // LIMIT 1 -> we do not need to unnest + return std::move(op); + } + + // Create unnest expression for aggregate args + const auto aggregate_bindings = logical_aggregate.GetColumnBindings(); + auto aggregate_column_ref = + make_uniq(aggregate_type, aggregate_bindings[aggregate_column_idx]); + + vector> unnest_exprs; + + auto unnest_aggregate = make_uniq(ListType::GetChildType(aggregate_type)); + unnest_aggregate->child = aggregate_column_ref->Copy(); + unnest_exprs.push_back(std::move(unnest_aggregate)); + + if (params.include_row_number) { + // Create row number expression + unnest_exprs.push_back(CreateRowNumberGenerator(std::move(aggregate_column_ref))); + } + + auto unnest = make_uniq(optimizer.binder.GenerateTableIndex()); + unnest->expressions = std::move(unnest_exprs); + unnest->children.push_back(std::move(op)); + unnest->ResolveOperatorTypes(); + + return unique_ptr(std::move(unnest)); +} + +void TopNWindowElimination::AddStructExtractExprs( + vector> &exprs, const LogicalType &struct_type, + const unique_ptr &aggregate_column_ref) const { + FunctionBinder function_binder(context); + auto &catalog = Catalog::GetSystemCatalog(context); + auto &struct_extract_entry = + catalog.GetEntry(context, DEFAULT_SCHEMA, "struct_extract"); + const auto struct_extract_fun = + struct_extract_entry.functions.GetFunctionByArguments(context, {struct_type, LogicalType::VARCHAR}); + + const auto &child_types = StructType::GetChildTypes(struct_type); + for (idx_t i = 0; i < child_types.size(); i++) { + const auto &alias = child_types[i].first; + + vector> fun_args(2); + fun_args[0] = aggregate_column_ref->Copy(); + fun_args[1] = make_uniq(alias); + + auto bound_function = function_binder.BindScalarFunction(struct_extract_fun, std::move(fun_args)); + bound_function->alias = alias; + exprs.push_back(std::move(bound_function)); + } +} + +unique_ptr +TopNWindowElimination::CreateProjectionOperator(unique_ptr op, + const TopNWindowEliminationParameters ¶ms, + const map &group_idxs) const { + const auto aggregate_type = GetAggregateType(op); + const idx_t aggregate_table_idx = GetAggregateIdx(op); + const auto op_column_bindings = op->GetColumnBindings(); + + vector> proj_exprs; + // Only project necessary group columns + for (const auto &group_idx : group_idxs) { + proj_exprs.push_back( + make_uniq(op->types[group_idx.second], op_column_bindings[group_idx.second])); + } + + auto aggregate_column_ref = + 
make_uniq(aggregate_type, ColumnBinding(aggregate_table_idx, 0)); + + if (params.payload_type == TopNPayloadType::STRUCT_PACK) { + AddStructExtractExprs(proj_exprs, aggregate_type, aggregate_column_ref); + } else { + // No need for struct_unpack! Just reference the aggregate column + proj_exprs.push_back(std::move(aggregate_column_ref)); + } + + if (params.include_row_number) { + // If aggregate (i.e., limit 1): constant, if unnest: expect there to be a second column + if (op->type == LogicalOperatorType::LOGICAL_UNNEST) { + auto row_number_column_binding = GetRowNumberColumnBinding(op); + proj_exprs.push_back( + make_uniq("row_number", LogicalType::BIGINT, row_number_column_binding)); + } else { + proj_exprs.push_back(make_uniq(Value::BIGINT(1))); + } + } + + auto logical_projection = + make_uniq(optimizer.binder.GenerateTableIndex(), std::move(proj_exprs)); + logical_projection->children.push_back(std::move(op)); + logical_projection->ResolveOperatorTypes(); + + return unique_ptr(std::move(logical_projection)); +} + +bool TopNWindowElimination::CanOptimize(LogicalOperator &op) { + if (op.type != LogicalOperatorType::LOGICAL_FILTER) { + return false; + } + + const auto &filter = op.Cast(); + if (filter.expressions.size() != 1) { + return false; + } + + if (filter.expressions[0]->type != ExpressionType::COMPARE_LESSTHANOREQUALTO) { + return false; + } + + auto &filter_comparison = filter.expressions[0]->Cast(); + if (filter_comparison.right->type != ExpressionType::VALUE_CONSTANT) { + return false; + } + auto &filter_value = filter_comparison.right->Cast(); + if (filter_value.value.type() != LogicalType::BIGINT) { + return false; + } + if (filter_value.value.GetValue() < 1) { + return false; + } + + if (filter_comparison.left->type != ExpressionType::BOUND_COLUMN_REF) { + return false; + } + VisitExpression(&filter_comparison.left); + + reference child = *filter.children[0]; + while (child.get().type == LogicalOperatorType::LOGICAL_PROJECTION) { + auto &projection = child.get().Cast(); + if (column_references.size() != 1) { + column_references.clear(); + return false; + } + + const auto current_column_ref = column_references.begin()->first; + column_references.clear(); + D_ASSERT(current_column_ref.table_index == projection.table_index); + VisitExpression(&projection.expressions[current_column_ref.column_index]); + + child = *child.get().children[0]; + } + + if (column_references.size() != 1) { + column_references.clear(); + return false; + } + const auto filter_col_idx = column_references.begin()->first.table_index; + column_references.clear(); + + if (child.get().type != LogicalOperatorType::LOGICAL_WINDOW) { + return false; + } + const auto &window = child.get().Cast(); + if (window.window_index != filter_col_idx) { + return false; + } + if (window.expressions.size() != 1) { + for (idx_t i = 1; i < window.expressions.size(); ++i) { + if (!window.expressions[i]->Equals(*window.expressions[0])) { + return false; + } + } + } + if (window.expressions[0]->type != ExpressionType::WINDOW_ROW_NUMBER) { + return false; + } + auto &window_expr = window.expressions[0]->Cast(); + + if (window_expr.orders.size() != 1) { + return false; + } + if (window_expr.orders[0].type != OrderType::DESCENDING && window_expr.orders[0].type != OrderType::ASCENDING) { + return false; + } + if (window_expr.orders[0].null_order != OrderByNullType::NULLS_LAST) { + return false; + } + + // We have found a grouped top-n window construct! 
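+ // (i.e. a filter of the form row_number() OVER (PARTITION BY ... ORDER BY ... NULLS LAST) <= N
+ // with a constant positive N, possibly behind a chain of projections.)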
+ return true; +} + +vector> TopNWindowElimination::GenerateAggregatePayload(const vector &bindings, + const LogicalWindow &window, + map &group_idxs) { + vector> aggregate_args; + aggregate_args.reserve(bindings.size()); + + window.children[0]->ResolveOperatorTypes(); + const auto &window_child_types = window.children[0]->types; + const auto window_child_bindings = window.children[0]->GetColumnBindings(); + auto &window_expr = window.expressions[0]->Cast(); + + // Remember order of group columns to recreate that order in new bindings later + column_binding_map_t group_bindings; + for (idx_t i = 0; i < window_expr.partitions.size(); i++) { + auto &expr = window_expr.partitions[i]; + VisitExpression(&expr); + group_bindings[column_references.begin()->first] = i; + column_references.clear(); + } + + for (idx_t i = 0; i < bindings.size(); i++) { + const auto &binding = bindings[i]; + const auto group_binding = group_bindings.find(binding); + if (group_binding != group_bindings.end()) { + group_idxs[i] = group_binding->second; + continue; + } + if (binding.table_index == window.window_index) { + continue; + } + + auto column_id = binding.ToString(); + if (window.children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION) { + // The column index points to the correct column binding + aggregate_args.push_back( + make_uniq(column_id, window_child_types[binding.column_index], binding)); + } else { + // The child operator could have multiple or no table indexes. Therefore, we must find the right type first + const auto child_column_idx = + static_cast(std::find(window_child_bindings.begin(), window_child_bindings.end(), binding) - + window_child_bindings.begin()); + aggregate_args.push_back( + make_uniq(column_id, window_child_types[child_column_idx], binding)); + } + } + + if (aggregate_args.size() == 1) { + // If we only project the aggregate value itself, we do not need it as an arg + VisitExpression(&window_expr.orders[0].expression); + const auto aggregate_value_binding = column_references.begin()->first; + column_references.clear(); + + if (window_expr.orders[0].expression->type == ExpressionType::BOUND_COLUMN_REF && + aggregate_args[0]->Cast().binding == aggregate_value_binding) { + return {}; + } + } + + return aggregate_args; +} + +vector TopNWindowElimination::TraverseProjectionBindings(const std::vector &old_bindings, + reference &op) { + auto new_bindings = old_bindings; + + // Traverse child projections to retrieve projections on window output + while (op.get().type == LogicalOperatorType::LOGICAL_PROJECTION) { + auto &projection = op.get().Cast(); + + for (idx_t i = 0; i < new_bindings.size(); i++) { + auto &new_binding = new_bindings[i]; + D_ASSERT(new_binding.table_index == projection.table_index); + VisitExpression(&projection.expressions[new_binding.column_index]); + new_binding = column_references.begin()->first; + column_references.clear(); + } + op = *op.get().children[0]; + } + + return new_bindings; +} + +void TopNWindowElimination::UpdateTopmostBindings(const idx_t window_idx, const unique_ptr &op, + const map &group_idxs, + const vector &topmost_bindings, + vector &new_bindings, + ColumnBindingReplacer &replacer) { + // The top-most operator's column order is [group][aggregate args][row number]. Now, set the new resulting bindings. 
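+ // (Group columns are resolved against the group table index, the aggregate args against the
+ // aggregate/unnest index, and the optional row number against the last output column.)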
+ D_ASSERT(topmost_bindings.size() == new_bindings.size()); + replacer.replacement_bindings.reserve(new_bindings.size()); + set row_id_binding_idxs; + + const idx_t group_table_idx = GetGroupIdx(op); + const idx_t aggregate_table_idx = GetAggregateIdx(op); + + // Project the group columns + idx_t current_column_idx = 0; + for (auto group_idx : group_idxs) { + const idx_t group_referencing_idx = group_idx.first; + new_bindings[group_referencing_idx].table_index = group_table_idx; + new_bindings[group_referencing_idx].column_index = group_idx.second; + replacer.replacement_bindings.emplace_back(topmost_bindings[group_referencing_idx], + new_bindings[group_referencing_idx]); + current_column_idx++; + } + + if (group_table_idx != aggregate_table_idx) { + // If the topmost operator is an aggregate, the table indexes are different, and we start back from 0 + current_column_idx = 0; + } + if (op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + // We do not have an aggregate index, so we need to set an offset to hit the correct columns + current_column_idx = TraverseAndFindAggregateOffset(op->children[1]); + } + + // Project the args/value + for (idx_t i = 0; i < new_bindings.size(); i++) { + auto &binding = new_bindings[i]; + if (group_idxs.find(i) != group_idxs.end()) { + continue; + } + if (binding.table_index == window_idx) { + row_id_binding_idxs.insert(i); + continue; + } + binding.column_index = current_column_idx++; + binding.table_index = aggregate_table_idx; + replacer.replacement_bindings.emplace_back(topmost_bindings[i], binding); + } + + // Project the row number + for (const auto row_id_binding_idx : row_id_binding_idxs) { + // Let all projections on row id point to the last output column + auto &binding = new_bindings[row_id_binding_idx]; + binding = GetRowNumberColumnBinding(op); + replacer.replacement_bindings.emplace_back(topmost_bindings[row_id_binding_idx], binding); + } +} + +TopNWindowEliminationParameters +TopNWindowElimination::ExtractOptimizerParameters(const LogicalWindow &window, const LogicalFilter &filter, + const vector &bindings, + vector> &aggregate_payload) { + TopNWindowEliminationParameters params; + + auto &limit_expr = filter.expressions[0]->Cast().right; + params.limit = limit_expr->Cast().value.GetValue(); + params.include_row_number = BindingsReferenceRowNumber(bindings, window); + params.payload_type = aggregate_payload.size() > 1 ? 
TopNPayloadType::STRUCT_PACK : TopNPayloadType::SINGLE_COLUMN; + auto &window_expr = window.expressions[0]->Cast(); + params.order_type = window_expr.orders[0].type; + + VisitExpression(&window_expr.orders[0].expression); + if (params.payload_type == TopNPayloadType::SINGLE_COLUMN && !aggregate_payload.empty()) { + VisitExpression(&aggregate_payload[0]); + } + for (const auto &column_ref : column_references) { + const auto &column_stats = stats->find(column_ref.first); + if (column_stats == stats->end() || column_stats->second->CanHaveNull()) { + params.can_be_null = true; + } + } + column_references.clear(); + + return params; +} + +bool TopNWindowElimination::CanUseLateMaterialization(const LogicalWindow &window, vector> &args, + vector &lhs_projections, + vector> &stack) { + auto &window_expr = window.expressions[0]->Cast(); + vector projections(window_expr.partitions.size() + args.size()); + + // Build a projection list for an LHS table scan to recreate the column order of an aggregate with struct packing + for (idx_t i = 0; i < window_expr.partitions.size(); i++) { + auto &partition = window_expr.partitions[i]; + VisitExpression(&partition); + projections[i] = column_references.begin()->first; + column_references.clear(); + } + for (idx_t i = 0; i < args.size(); i++) { + auto &arg = args[i]; + VisitExpression(&arg); + projections[window_expr.partitions.size() + i] = column_references.begin()->first; + column_references.clear(); + } + + reference op = *window.children[0]; + + // Traverse projections to a single table scan + while (!op.get().children.empty()) { + stack.push_back(op); + switch (op.get().type) { + case LogicalOperatorType::LOGICAL_PROJECTION: { + auto &projection = op.get().Cast(); + for (idx_t i = 0; i < projections.size(); i++) { + D_ASSERT(projection.table_index == projections[i].table_index); + const idx_t projection_idx = projections[i].column_index; + VisitExpression(&projection.expressions[projection_idx]); + projections[i] = column_references.begin()->first; + column_references.clear(); + } + op = *op.get().children[0]; + break; + } + case LogicalOperatorType::LOGICAL_FILTER: { + op = *op.get().children[0]; + break; + } + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { + auto &join = op.get().Cast(); + if (join.join_type != JoinType::INNER && join.join_type != JoinType::SEMI && + join.join_type != JoinType::ANTI) { + return false; + } + + // If there is a join, we only allow late materialization if the projected output stems from a single table. + // However, we allow replacing references to join columns as they are equal to the other side by condition. 
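+ // e.g. for "a JOIN b ON a.id = b.id", a projected reference to b.id may be rewritten to a.id
+ // (and vice versa), so the projected output can still be served from a single side of the join.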
+ column_binding_map_t replaceable_bindings; + for (auto &condition : join.conditions) { + if (condition.comparison != ExpressionType::COMPARE_EQUAL) { + return false; + } + VisitExpression(&condition.left); + auto left_binding = column_references.begin()->first; + column_references.clear(); + VisitExpression(&condition.right); + auto right_binding = column_references.begin()->first; + column_references.clear(); + + replaceable_bindings[left_binding] = right_binding; + replaceable_bindings[right_binding] = left_binding; + } + + auto left_column_bindings = join.children[0]->GetColumnBindings(); + auto right_column_bindings = join.children[1]->GetColumnBindings(); + auto lidxs = join.children[0]->GetTableIndex(); + auto ridxs = join.children[1]->GetTableIndex(); + if (lidxs.size() != 1 || ridxs.size() != 1) { + return false; + } + auto left_idx = lidxs[0]; + auto right_idx = ridxs[0]; + + bool all_left_replaceable = true; + bool all_right_replaceable = true; + for (idx_t i = 0; i < projections.size(); i++) { + const auto &projection = projections[i]; + auto &column_binding = projection.table_index == left_idx + ? left_column_bindings[projection.column_index] + : right_column_bindings[projection.column_index]; + if (replaceable_bindings.find(column_binding) == replaceable_bindings.end()) { + if (column_binding.table_index == left_idx) { + all_left_replaceable = false; + } else { + all_right_replaceable = false; + } + } + } + + if (!all_left_replaceable && !all_right_replaceable) { + // We cannot use late materialization by scanning a single table. + return false; + } + + idx_t replace_table_idx = all_right_replaceable ? right_idx : left_idx; + for (idx_t i = 0; i < projections.size(); i++) { + const auto projection_idx = projections[i]; + auto &column_binding = projection_idx.table_index == left_idx + ? 
left_column_bindings[projection_idx.column_index] + : right_column_bindings[projection_idx.column_index]; + if (column_binding.table_index == replace_table_idx) { + projections[i] = replaceable_bindings[column_binding]; + } + } + + if (all_right_replaceable) { + op = *op.get().children[0]; + } else { + op = *op.get().children[1]; + } + break; + } + default: { + return false; + } + } + } + stack.push_back(op); + + D_ASSERT(op.get().type == LogicalOperatorType::LOGICAL_GET); + auto &logical_get = op.get().Cast(); + if (!logical_get.function.late_materialization || !logical_get.function.get_row_id_columns) { + return false; + } + + const auto rowid_column_idxs = logical_get.function.get_row_id_columns(context, logical_get.bind_data.get()); + if (rowid_column_idxs.size() > 1) { + // TODO: support multi-column rowids for parquet + return false; + } + for (const auto &col_idx : rowid_column_idxs) { + auto entry = logical_get.virtual_columns.find(col_idx); + if (entry == logical_get.virtual_columns.end()) { + return false; + } + } + // Check if we need the projection map + for (idx_t i = 0; i < projections.size(); i++) { + if (projections[i].column_index != i) { + for (auto &proj : projections) { + lhs_projections.push_back(proj.column_index); + } + break; + } + } + return true; +} + +unique_ptr TopNWindowElimination::TryPrepareLateMaterialization(const LogicalWindow &window, + vector> &args) { + vector lhs_projections; + vector> stack; + bool use_late_materialization = CanUseLateMaterialization(window, args, lhs_projections, stack); + if (!use_late_materialization) { + return nullptr; + } + + D_ASSERT(stack.back().get().type == LogicalOperatorType::LOGICAL_GET); + auto &rhs_get = stack.back().get().Cast(); + auto lhs = ConstructLHS(rhs_get, lhs_projections); + + const auto rhs_rowid_column_idxs = rhs_get.function.get_row_id_columns(context, rhs_get.bind_data.get()); + vector rhs_rowid_columns; + for (const auto &col_idx : rhs_rowid_column_idxs) { + rhs_rowid_columns.push_back(rhs_get.virtual_columns[col_idx]); + } + const auto rhs_rowid_idxs = + LateMaterializationHelper::GetOrInsertRowIds(rhs_get, rhs_rowid_column_idxs, rhs_rowid_columns); + + // Add rowid column to the operators on the right-hand side + idx_t last_table_idx = rhs_get.table_index; + idx_t last_rowid_offset = rhs_rowid_idxs[0]; + + // Add rowid projections to the query tree on the right-hand side + for (auto stack_it = std::next(stack.rbegin()); stack_it != stack.rend(); ++stack_it) { + auto &op = stack_it->get(); + + switch (op.type) { + case LogicalOperatorType::LOGICAL_PROJECTION: { + auto &rowid_column = rhs_rowid_columns[0]; + op.expressions.push_back(make_uniq( + rowid_column.name, rowid_column.type, ColumnBinding {last_table_idx, last_rowid_offset})); + last_table_idx = op.GetTableIndex()[0]; + last_rowid_offset = op.expressions.size() - 1; + break; + } + case LogicalOperatorType::LOGICAL_FILTER: { + if (op.HasProjectionMap()) { + auto &filter = op.Cast(); + filter.projection_map.push_back(last_rowid_offset); + } + break; + } + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { + if (op.HasProjectionMap()) { + auto &join = op.Cast(); + auto &op_child = std::prev(stack_it)->get(); + if (&op_child == &*join.children[0]) { + join.left_projection_map.push_back(last_rowid_offset); + } else { + join.right_projection_map.push_back(last_rowid_offset); + } + } + break; + } + default: + throw InternalException("Unsupported operator in late materialization right-hand side."); + } + } + + // Change args to project rowid + 
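+ // (Instead of struct-packing the full payload, the aggregate now only carries the rowid; the
+ // payload columns are re-fetched later through the join built in ConstructJoin.)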
args.clear(); + args.push_back(make_uniq(rhs_rowid_columns[0].name, rhs_rowid_columns[0].type, + ColumnBinding {last_table_idx, last_rowid_offset})); + + return lhs; +} + +unique_ptr TopNWindowElimination::ConstructLHS(LogicalGet &rhs, vector &projections) const { + auto lhs_get = LateMaterializationHelper::CreateLHSGet(rhs, optimizer.binder); + const auto lhs_rowid_column_idxs = lhs_get->function.get_row_id_columns(context, lhs_get->bind_data.get()); + vector lhs_rowid_columns; + for (const auto &col_idx : lhs_rowid_column_idxs) { + lhs_rowid_columns.push_back(rhs.virtual_columns[col_idx]); + } + + const auto lhs_rowid_idxs = + LateMaterializationHelper::GetOrInsertRowIds(*lhs_get, lhs_rowid_column_idxs, lhs_rowid_columns); + + if (!projections.empty()) { + for (auto rowid_idx : lhs_rowid_idxs) { + projections.push_back(rowid_idx); + } + lhs_get->ResolveOperatorTypes(); + + vector> projs; + projs.reserve(projections.size()); + for (auto projection_id : projections) { + projs.push_back(make_uniq(lhs_get->types[projection_id], + ColumnBinding {lhs_get->table_index, projection_id})); + } + auto projection = make_uniq(optimizer.binder.GenerateTableIndex(), std::move(projs)); + projection->children.push_back(std::move(lhs_get)); + return unique_ptr(std::move(projection)); + } + return unique_ptr(std::move(lhs_get)); +} + +unique_ptr TopNWindowElimination::ConstructJoin(unique_ptr lhs, + unique_ptr rhs, + const idx_t aggregate_offset, + const TopNWindowEliminationParameters ¶ms) { + auto join = make_uniq(JoinType::SEMI); + + JoinCondition condition; + condition.comparison = ExpressionType::COMPARE_EQUAL; + + lhs->ResolveOperatorTypes(); + const auto lhs_rowid_idx = lhs->types.size() - 1; + const auto rhs_rowid_idx = rhs->type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY ? 
0 : aggregate_offset; + + condition.left = make_uniq("rowid", lhs->types[lhs_rowid_idx], + ColumnBinding {lhs->GetTableIndex()[0], lhs_rowid_idx}); + condition.right = make_uniq("rowid", rhs->types[aggregate_offset], + ColumnBinding {GetAggregateIdx(rhs), rhs_rowid_idx}); + + join->conditions.push_back(std::move(condition)); + if (params.include_row_number) { + // Add row_number to join result + join->join_type = JoinType::INNER; + join->right_projection_map.push_back(rhs->types.size() - 1); + } + + join->children.push_back(std::move(lhs)); + join->children.push_back(std::move(rhs)); + + return unique_ptr(std::move(join)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/unnest_rewriter.cpp b/src/duckdb/src/optimizer/unnest_rewriter.cpp index 4c4207e2a..a73810bd1 100644 --- a/src/duckdb/src/optimizer/unnest_rewriter.cpp +++ b/src/duckdb/src/optimizer/unnest_rewriter.cpp @@ -33,14 +33,12 @@ void UnnestRewriterPlanUpdater::VisitExpression(unique_ptr *expressi } unique_ptr UnnestRewriter::Optimize(unique_ptr op) { - UnnestRewriterPlanUpdater updater; vector>> candidates; FindCandidates(op, candidates); // rewrite the plan and update the bindings for (auto &candidate : candidates) { - // rearrange the logical operators if (RewriteCandidate(candidate)) { updater.overwritten_tbl_idx = overwritten_tbl_idx; @@ -106,7 +104,6 @@ void UnnestRewriter::FindCandidates(unique_ptr &op, } bool UnnestRewriter::RewriteCandidate(unique_ptr &candidate) { - auto &topmost_op = *candidate; if (topmost_op.type != LogicalOperatorType::LOGICAL_PROJECTION && topmost_op.type != LogicalOperatorType::LOGICAL_WINDOW && @@ -160,14 +157,12 @@ bool UnnestRewriter::RewriteCandidate(unique_ptr &candidate) { void UnnestRewriter::UpdateRHSBindings(unique_ptr &plan, unique_ptr &candidate, UnnestRewriterPlanUpdater &updater) { - auto &topmost_op = *candidate; idx_t shift = lhs_bindings.size(); vector *> path_to_unnest; auto curr_op = &topmost_op.children[0]; while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { - path_to_unnest.push_back(curr_op); D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION); auto &proj = curr_op->get()->Cast(); @@ -222,7 +217,6 @@ void UnnestRewriter::UpdateRHSBindings(unique_ptr &plan, unique // add the LHS expressions to each LOGICAL_PROJECTION for (idx_t i = path_to_unnest.size(); i > 0; i--) { - D_ASSERT(path_to_unnest[i - 1]->get()->type == LogicalOperatorType::LOGICAL_PROJECTION); auto &proj = path_to_unnest[i - 1]->get()->Cast(); @@ -254,7 +248,6 @@ void UnnestRewriter::UpdateRHSBindings(unique_ptr &plan, unique void UnnestRewriter::UpdateBoundUnnestBindings(UnnestRewriterPlanUpdater &updater, unique_ptr &candidate) { - auto &topmost_op = *candidate; // traverse LOGICAL_PROJECTION(s) @@ -296,7 +289,6 @@ void UnnestRewriter::UpdateBoundUnnestBindings(UnnestRewriterPlanUpdater &update } void UnnestRewriter::GetDelimColumns(LogicalOperator &op) { - D_ASSERT(op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN); auto &delim_join = op.Cast(); for (idx_t i = 0; i < delim_join.duplicate_eliminated_columns.size(); i++) { @@ -308,7 +300,6 @@ void UnnestRewriter::GetDelimColumns(LogicalOperator &op) { } void UnnestRewriter::GetLHSExpressions(LogicalOperator &op) { - op.ResolveOperatorTypes(); auto col_bindings = op.GetColumnBindings(); D_ASSERT(op.types.size() == col_bindings.size()); diff --git a/src/duckdb/src/parallel/async_result.cpp b/src/duckdb/src/parallel/async_result.cpp new file mode 100644 index 000000000..a32086b84 --- /dev/null +++ 
b/src/duckdb/src/parallel/async_result.cpp @@ -0,0 +1,192 @@ +#include "duckdb/parallel/executor_task.hpp" +#include "duckdb/parallel/async_result.hpp" +#include "duckdb/parallel/interrupt.hpp" +#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/execution/executor.hpp" +#include "duckdb/execution/physical_table_scan_enum.hpp" + +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE +#include "duckdb/parallel/sleep_async_task.hpp" +#endif + +namespace duckdb { + +struct Counter { + explicit Counter(idx_t size) : counter(size) { + } + bool IterateAndCheckCounter() { + D_ASSERT(counter.load() > 0); + idx_t post_decreast = --counter; + return (post_decreast == 0); + } + +private: + atomic counter; +}; + +class AsyncExecutionTask : public ExecutorTask { +public: + AsyncExecutionTask(Executor &executor, unique_ptr &&async_task, InterruptState &interrupt_state, + shared_ptr counter) + : ExecutorTask(executor, nullptr), async_task(std::move(async_task)), interrupt_state(interrupt_state), + counter(std::move(counter)) { + } + TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { + async_task->Execute(); + if (counter->IterateAndCheckCounter()) { + interrupt_state.Callback(); + } + return TaskExecutionResult::TASK_FINISHED; + } + + string TaskType() const override { + return "AsyncTask"; + } + +private: + unique_ptr async_task; + InterruptState interrupt_state; + shared_ptr counter; +}; + +AsyncResult::AsyncResult(SourceResultType t) : AsyncResult(GetAsyncResultType(t)) { +} + +AsyncResult::AsyncResult(AsyncResultType t) : result_type(t) { + if (result_type == AsyncResultType::BLOCKED) { + throw InternalException("AsyncResult constructed with a BLOCKED state, do provide AsyncTasks"); + } +} + +AsyncResult::AsyncResult(vector> &&tasks) + : result_type(AsyncResultType::BLOCKED), async_tasks(std::move(tasks)) { + if (async_tasks.empty()) { + throw InternalException("AsyncResult constructed from empty vector of tasks"); + } +} + +AsyncResult &AsyncResult::operator=(duckdb::SourceResultType t) { + return operator=(AsyncResult(t)); +} + +AsyncResult &AsyncResult::operator=(duckdb::AsyncResultType t) { + return operator=(AsyncResult(t)); +} + +AsyncResult &AsyncResult::operator=(AsyncResult &&other) noexcept { + result_type = other.result_type; + async_tasks = std::move(other.async_tasks); + return *this; +} + +void AsyncResult::ScheduleTasks(InterruptState &interrupt_state, Executor &executor) { + if (result_type != AsyncResultType::BLOCKED) { + throw InternalException("AsyncResult::ScheduleTasks called on non BLOCKED AsyncResult"); + } + + if (async_tasks.empty()) { + throw InternalException("AsyncResult::ScheduleTasks called with no available tasks"); + } + + shared_ptr counter = make_shared_ptr(async_tasks.size()); + + for (auto &async_task : async_tasks) { + auto task = make_uniq(executor, std::move(async_task), interrupt_state, counter); + TaskScheduler::GetScheduler(executor.context).ScheduleTask(executor.GetToken(), std::move(task)); + } +} + +void AsyncResult::ExecuteTasksSynchronously() { + if (result_type != AsyncResultType::BLOCKED) { + throw InternalException("AsyncResult::ExecuteTasksSynchronously called on non BLOCKED AsyncResult"); + } + + if (async_tasks.empty()) { + throw InternalException("AsyncResult::ExecuteTasksSynchronously called with no available tasks"); + } + + for (auto &async_task : async_tasks) { + async_task->Execute(); + } + + async_tasks.clear(); + + result_type = AsyncResultType::HAVE_MORE_OUTPUT; +} + +AsyncResultType 
AsyncResult::GetAsyncResultType(SourceResultType s) { + switch (s) { + case SourceResultType::HAVE_MORE_OUTPUT: + return AsyncResultType::HAVE_MORE_OUTPUT; + case SourceResultType::FINISHED: + return AsyncResultType::FINISHED; + case SourceResultType::BLOCKED: + return AsyncResultType::BLOCKED; + } + throw InternalException("GetAsyncResultType has an unexpected input"); +} + +bool AsyncResult::HasTasks() const { + D_ASSERT(result_type != AsyncResultType::INVALID); + if (async_tasks.empty()) { + D_ASSERT(result_type != AsyncResultType::BLOCKED); + return false; + } else { + D_ASSERT(result_type == AsyncResultType::BLOCKED); + return true; + } +} +AsyncResultType AsyncResult::GetResultType() const { + D_ASSERT(result_type != AsyncResultType::INVALID); + if (async_tasks.empty()) { + D_ASSERT(result_type != AsyncResultType::BLOCKED); + } else { + D_ASSERT(result_type == AsyncResultType::BLOCKED); + } + return result_type; +} +vector> &&AsyncResult::ExtractAsyncTasks() { + D_ASSERT(result_type != AsyncResultType::INVALID); + result_type = AsyncResultType::INVALID; + return std::move(async_tasks); +} + +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE +vector> AsyncResult::GenerateTestTasks() { + vector> tasks; + auto random_number = rand() % 16; + switch (random_number) { + case 0: + tasks.push_back(make_uniq(rand() % 32)); + tasks.push_back(make_uniq(rand() % 32)); + tasks.push_back(make_uniq(rand() % 32)); + tasks.push_back(make_uniq(rand() % 32)); + tasks.push_back(make_uniq(rand() % 32)); + tasks.push_back(make_uniq(rand() % 32)); + tasks.push_back(make_uniq(rand() % 32)); + tasks.push_back(make_uniq(rand() % 32)); +#ifndef AVOID_DUCKDB_DEBUG_ASYNC_THROW + case 1: + tasks.push_back(make_uniq(rand() % 32)); +#endif + default: + break; + } + return tasks; +} +#endif + +AsyncResultsExecutionMode +AsyncResult::ConvertToAsyncResultExecutionMode(const PhysicalTableScanExecutionStrategy &execution_mode) { + switch (execution_mode) { + case PhysicalTableScanExecutionStrategy::DEFAULT: + case PhysicalTableScanExecutionStrategy::TASK_EXECUTOR: + case PhysicalTableScanExecutionStrategy::TASK_EXECUTOR_BUT_FORCE_SYNC_CHECKS: + return AsyncResultsExecutionMode::TASK_EXECUTOR; + case PhysicalTableScanExecutionStrategy::SYNCHRONOUS: + return AsyncResultsExecutionMode::SYNCHRONOUS; + } + throw InternalException("ConvertToAsyncResultExecutionMode passed an unexpected execution_mode"); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parallel/executor.cpp b/src/duckdb/src/parallel/executor.cpp index d79fa1816..9a9cf4703 100644 --- a/src/duckdb/src/parallel/executor.cpp +++ b/src/duckdb/src/parallel/executor.cpp @@ -379,7 +379,6 @@ void Executor::Initialize(PhysicalOperator &plan) { } void Executor::InitializeInternal(PhysicalOperator &plan) { - auto &scheduler = TaskScheduler::GetScheduler(context); { lock_guard elock(executor_lock); @@ -423,7 +422,6 @@ void Executor::InitializeInternal(PhysicalOperator &plan) { void Executor::CancelTasks() { task.reset(); - { lock_guard elock(executor_lock); // mark the query as cancelled so tasks will early-out @@ -463,17 +461,23 @@ void Executor::SignalTaskRescheduled(lock_guard &) { void Executor::WaitForTask() { #ifndef DUCKDB_NO_THREADS - static constexpr std::chrono::milliseconds WAIT_TIME_MS = std::chrono::milliseconds(WAIT_TIME); + static constexpr std::chrono::microseconds WAIT_TIME_MS = std::chrono::microseconds(WAIT_TIME * 1000); + auto begin = std::chrono::high_resolution_clock::now(); std::unique_lock l(executor_lock); + auto end = 
std::chrono::high_resolution_clock::now(); + auto dur = end - begin; + auto ms = NumericCast(std::chrono::duration_cast(dur).count()); if (to_be_rescheduled_tasks.empty()) { + blocked_thread_time += ms; return; } if (ResultCollectorIsBlocked()) { // If the result collector is blocked, it won't get unblocked until the connection calls Fetch + blocked_thread_time += ms; return; } - blocked_thread_time++; + blocked_thread_time += ms + WAIT_TIME_MS.count(); task_reschedule.wait_for(l, WAIT_TIME_MS); #endif } @@ -578,6 +582,12 @@ PendingExecutionResult Executor::ExecuteTask(bool dry_run) { } else if (result == TaskExecutionResult::TASK_FINISHED) { // if the task is finished, clean it up task.reset(); + } else if (result == TaskExecutionResult::TASK_ERROR) { + if (!HasError()) { + // This is very much unexpected, TASK_ERROR means this executor should have an Error + throw InternalException("A task executed within Executor::ExecuteTask, from own producer, returned " + "TASK_ERROR without setting error on the Executor"); + } } } if (!HasError()) { @@ -672,13 +682,12 @@ void Executor::ThrowException() { } void Executor::Flush(ThreadContext &thread_context) { - static constexpr std::chrono::milliseconds WAIT_TIME_MS = std::chrono::milliseconds(WAIT_TIME); auto global_profiler = profiler; if (global_profiler) { global_profiler->Flush(thread_context.profiler); auto blocked_time = blocked_thread_time.load(); - global_profiler->SetInfo(double(blocked_time * WAIT_TIME_MS.count()) / 1000); + global_profiler->SetBlockedTime(double(blocked_time) / 1000.0 / 1000.0); } } diff --git a/src/duckdb/src/parallel/pipeline_executor.cpp b/src/duckdb/src/parallel/pipeline_executor.cpp index 9db69ac99..4b886961e 100644 --- a/src/duckdb/src/parallel/pipeline_executor.cpp +++ b/src/duckdb/src/parallel/pipeline_executor.cpp @@ -123,12 +123,15 @@ bool PipelineExecutor::TryFlushCachingOperators(ExecutionBudget &chunk_budget) { return true; } -SinkNextBatchType PipelineExecutor::NextBatch(DataChunk &source_chunk) { +SinkNextBatchType PipelineExecutor::NextBatch(DataChunk &source_chunk, const bool have_more_output) { D_ASSERT(required_partition_info.AnyRequired()); auto max_batch_index = pipeline.base_batch_index + PipelineBuildState::BATCH_INCREMENT - 1; // by default set it to the maximum valid batch index value for the current pipeline + auto &partition_info = local_sink_state->partition_info; OperatorPartitionData next_data(max_batch_index); - if (source_chunk.size() > 0) { + if ((source_chunk.size() > 0)) { + D_ASSERT(local_source_state); + D_ASSERT(pipeline.source_state); // if we retrieved data - initialize the next batch index auto partition_data = pipeline.source->GetPartitionData(context, source_chunk, *pipeline.source_state, *local_source_state, required_partition_info); @@ -140,8 +143,9 @@ SinkNextBatchType PipelineExecutor::NextBatch(DataChunk &source_chunk) { throw InternalException("Pipeline batch index - invalid batch index %llu returned by source operator", batch_index); } + } else if (have_more_output) { + next_data.batch_index = partition_info.batch_index.GetIndex(); } - auto &partition_info = local_sink_state->partition_info; if (next_data.batch_index == partition_info.batch_index.GetIndex()) { // no changes, return return SinkNextBatchType::READY; @@ -221,7 +225,7 @@ PipelineExecuteResult PipelineExecutor::Execute(idx_t max_chunks) { } } } else if (!exhausted_source || next_batch_blocked) { - SourceResultType source_result; + SourceResultType source_result = SourceResultType::BLOCKED; if 
(!next_batch_blocked) { // "Regular" path: fetch a chunk from the source and push it through the pipeline source_chunk.Reset(); @@ -235,7 +239,7 @@ PipelineExecuteResult PipelineExecutor::Execute(idx_t max_chunks) { } if (required_partition_info.AnyRequired()) { - auto next_batch_result = NextBatch(source_chunk); + auto next_batch_result = NextBatch(source_chunk, source_result == SourceResultType::HAVE_MORE_OUTPUT); next_batch_blocked = next_batch_result == SinkNextBatchType::BLOCKED; if (next_batch_blocked) { return PipelineExecuteResult::INTERRUPTED; @@ -243,7 +247,6 @@ PipelineExecuteResult PipelineExecutor::Execute(idx_t max_chunks) { } if (exhausted_source && source_chunk.size() == 0) { - // To ensure that we're not early-terminating the pipeline continue; } diff --git a/src/duckdb/src/parallel/task_executor.cpp b/src/duckdb/src/parallel/task_executor.cpp index fa2c0087c..9487a1427 100644 --- a/src/duckdb/src/parallel/task_executor.cpp +++ b/src/duckdb/src/parallel/task_executor.cpp @@ -69,8 +69,10 @@ TaskExecutionResult BaseExecutorTask::Execute(TaskExecutionMode mode) { return TaskExecutionResult::TASK_FINISHED; } try { - TaskNotifier task_notifier {executor.context}; - ExecuteTask(); + { + TaskNotifier task_notifier {executor.context}; + ExecuteTask(); + } executor.FinishTask(); return TaskExecutionResult::TASK_FINISHED; } catch (std::exception &ex) { diff --git a/src/duckdb/src/parser/expression/lambda_expression.cpp b/src/duckdb/src/parser/expression/lambda_expression.cpp index d8d4fe891..ec98300e9 100644 --- a/src/duckdb/src/parser/expression/lambda_expression.cpp +++ b/src/duckdb/src/parser/expression/lambda_expression.cpp @@ -34,7 +34,6 @@ LambdaExpression::LambdaExpression(unique_ptr lhs, unique_ptr< } vector> LambdaExpression::ExtractColumnRefExpressions(string &error_message) const { - // we return an error message because we can't throw a binder exception here, // since we can't distinguish between a lambda function and the JSON operator yet vector> column_refs; diff --git a/src/duckdb/src/parser/expression/lambdaref_expression.cpp b/src/duckdb/src/parser/expression/lambdaref_expression.cpp index fed844fea..f1e7e59bf 100644 --- a/src/duckdb/src/parser/expression/lambdaref_expression.cpp +++ b/src/duckdb/src/parser/expression/lambdaref_expression.cpp @@ -37,7 +37,6 @@ unique_ptr LambdaRefExpression::Copy() const { unique_ptr LambdaRefExpression::FindMatchingBinding(optional_ptr> &lambda_bindings, const string &column_name) { - // if this is a lambda parameter, then we temporarily add a BoundLambdaRef, // which we capture and remove later @@ -47,7 +46,7 @@ LambdaRefExpression::FindMatchingBinding(optional_ptr> &lam if (lambda_bindings) { for (idx_t i = lambda_bindings->size(); i > 0; i--) { if ((*lambda_bindings)[i - 1].HasMatchingBinding(column_name)) { - D_ASSERT((*lambda_bindings)[i - 1].alias.IsSet()); + D_ASSERT((*lambda_bindings)[i - 1].GetBindingAlias().IsSet()); return make_uniq(i - 1, column_name); } } diff --git a/src/duckdb/src/parser/expression/window_expression.cpp b/src/duckdb/src/parser/expression/window_expression.cpp index 9720d2abf..8a655332f 100644 --- a/src/duckdb/src/parser/expression/window_expression.cpp +++ b/src/duckdb/src/parser/expression/window_expression.cpp @@ -35,31 +35,35 @@ WindowExpression::WindowExpression(ExpressionType type, string catalog_name, str } } +static const WindowFunctionDefinition internal_window_functions[] = { + {"rank", ExpressionType::WINDOW_RANK}, + {"rank_dense", ExpressionType::WINDOW_RANK_DENSE}, + {"dense_rank", 
ExpressionType::WINDOW_RANK_DENSE}, + {"percent_rank", ExpressionType::WINDOW_PERCENT_RANK}, + {"row_number", ExpressionType::WINDOW_ROW_NUMBER}, + {"first_value", ExpressionType::WINDOW_FIRST_VALUE}, + {"first", ExpressionType::WINDOW_FIRST_VALUE}, + {"last_value", ExpressionType::WINDOW_LAST_VALUE}, + {"last", ExpressionType::WINDOW_LAST_VALUE}, + {"nth_value", ExpressionType::WINDOW_NTH_VALUE}, + {"cume_dist", ExpressionType::WINDOW_CUME_DIST}, + {"lead", ExpressionType::WINDOW_LEAD}, + {"lag", ExpressionType::WINDOW_LAG}, + {"ntile", ExpressionType::WINDOW_NTILE}, + {"fill", ExpressionType::WINDOW_FILL}, + {nullptr, ExpressionType::INVALID}}; + +const WindowFunctionDefinition *WindowExpression::WindowFunctions() { + return internal_window_functions; +} + ExpressionType WindowExpression::WindowToExpressionType(string &fun_name) { - if (fun_name == "rank") { - return ExpressionType::WINDOW_RANK; - } else if (fun_name == "rank_dense" || fun_name == "dense_rank") { - return ExpressionType::WINDOW_RANK_DENSE; - } else if (fun_name == "percent_rank") { - return ExpressionType::WINDOW_PERCENT_RANK; - } else if (fun_name == "row_number") { - return ExpressionType::WINDOW_ROW_NUMBER; - } else if (fun_name == "first_value" || fun_name == "first") { - return ExpressionType::WINDOW_FIRST_VALUE; - } else if (fun_name == "last_value" || fun_name == "last") { - return ExpressionType::WINDOW_LAST_VALUE; - } else if (fun_name == "nth_value") { - return ExpressionType::WINDOW_NTH_VALUE; - } else if (fun_name == "cume_dist") { - return ExpressionType::WINDOW_CUME_DIST; - } else if (fun_name == "lead") { - return ExpressionType::WINDOW_LEAD; - } else if (fun_name == "lag") { - return ExpressionType::WINDOW_LAG; - } else if (fun_name == "ntile") { - return ExpressionType::WINDOW_NTILE; - } else if (fun_name == "fill") { - return ExpressionType::WINDOW_FILL; + D_ASSERT(StringUtil::IsLower(fun_name)); + auto functions = WindowFunctions(); + for (idx_t i = 0; functions[i].name != nullptr; i++) { + if (fun_name == functions[i].name) { + return functions[i].expression_type; + } } return ExpressionType::WINDOW_AGGREGATE; } diff --git a/src/duckdb/src/parser/parsed_expression_iterator.cpp b/src/duckdb/src/parser/parsed_expression_iterator.cpp index 7ca38a10e..f5746f9f7 100644 --- a/src/duckdb/src/parser/parsed_expression_iterator.cpp +++ b/src/duckdb/src/parser/parsed_expression_iterator.cpp @@ -162,7 +162,6 @@ void ParsedExpressionIterator::EnumerateChildren( void ParsedExpressionIterator::EnumerateQueryNodeModifiers( QueryNode &node, const std::function &child)> &callback) { - for (auto &modifier : node.modifiers) { switch (modifier->type) { case ResultModifierType::LIMIT_MODIFIER: { @@ -271,12 +270,6 @@ void ParsedExpressionIterator::EnumerateQueryNodeChildren( EnumerateQueryNodeChildren(*rcte_node.right, expr_callback, ref_callback); break; } - case QueryNodeType::CTE_NODE: { - auto &cte_node = node.Cast(); - EnumerateQueryNodeChildren(*cte_node.query, expr_callback, ref_callback); - EnumerateQueryNodeChildren(*cte_node.child, expr_callback, ref_callback); - break; - } case QueryNodeType::SELECT_NODE: { auto &sel_node = node.Cast(); for (idx_t i = 0; i < sel_node.select_list.size(); i++) { diff --git a/src/duckdb/src/parser/parser.cpp b/src/duckdb/src/parser/parser.cpp index 552b6e180..e61c015f6 100644 --- a/src/duckdb/src/parser/parser.cpp +++ b/src/duckdb/src/parser/parser.cpp @@ -165,33 +165,75 @@ bool Parser::StripUnicodeSpaces(const string &query_str, string &new_query) { return 
ReplaceUnicodeSpaces(query_str, new_query, unicode_spaces); } -vector SplitQueryStringIntoStatements(const string &query) { - // Break sql string down into sql statements using the tokenizer - vector query_statements; - auto tokens = Parser::Tokenize(query); - idx_t next_statement_start = 0; - for (idx_t i = 1; i < tokens.size(); ++i) { - auto &t_prev = tokens[i - 1]; - auto &t = tokens[i]; - if (t_prev.type == SimplifiedTokenType::SIMPLIFIED_TOKEN_OPERATOR) { - // LCOV_EXCL_START - for (idx_t c = t_prev.start; c <= t.start; ++c) { - if (query.c_str()[c] == ';') { - query_statements.emplace_back(query.substr(next_statement_start, t.start - next_statement_start)); - next_statement_start = tokens[i].start; - } +vector SplitQueries(const string &input_query) { + vector queries; + auto tokenized_input = Parser::Tokenize(input_query); + size_t last_split = 0; + + for (const auto &token : tokenized_input) { + if (token.type == SimplifiedTokenType::SIMPLIFIED_TOKEN_OPERATOR && input_query[token.start] == ';') { + string segment = input_query.substr(last_split, token.start - last_split); + StringUtil::Trim(segment); + if (!segment.empty()) { + segment.append(";"); + queries.push_back(std::move(segment)); } - // LCOV_EXCL_STOP + last_split = token.start + 1; + } + } + string final_segment = input_query.substr(last_split); + StringUtil::Trim(final_segment); + if (!final_segment.empty()) { + final_segment.append(";"); + queries.push_back(std::move(final_segment)); + } + return queries; +} + +StatementType Parser::GetStatementType(const string &query) { + Transformer transformer(options); + vector> statements; + PostgresParser parser; + parser.Parse(query); + if (parser.success) { + if (!parser.parse_tree) { + // empty statement + return StatementType::INVALID_STATEMENT; + } + transformer.TransformParseTree(parser.parse_tree, statements); + return statements[0]->type; + } else { + return StatementType::INVALID_STATEMENT; + } +} + +void Parser::ThrowParserOverrideError(ParserOverrideResult &result) { + if (result.type == ParserExtensionResultType::DISPLAY_ORIGINAL_ERROR) { + throw ParserException("Parser override failed to return a valid statement: %s\n\nConsider restarting the " + "database and " + "using the setting \"set allow_parser_override_extension=fallback\" to fallback to the " + "default parser.", + result.error.RawMessage()); + } + if (result.type == ParserExtensionResultType::DISPLAY_EXTENSION_ERROR) { + if (result.error.Type() == ExceptionType::NOT_IMPLEMENTED) { + throw NotImplementedException("Parser override has not yet implemented this " + "transformer rule. (Original error: %s)", + result.error.RawMessage()); } + if (result.error.Type() == ExceptionType::PARSER) { + throw ParserException("Parser override could not parse this query. 
(Original error: %s)", + result.error.RawMessage()); + } + result.error.Throw(); } - query_statements.emplace_back(query.substr(next_statement_start, query.size() - next_statement_start)); - return query_statements; } void Parser::ParseQuery(const string &query) { Transformer transformer(options); string parser_error; optional_idx parser_error_location; + string parser_override_option = StringUtil::Lower(options.parser_override_setting); { // check if there are any unicode spaces in the string string new_query; @@ -207,12 +249,39 @@ void Parser::ParseQuery(const string &query) { if (!ext.parser_override) { continue; } + if (StringUtil::CIEquals(parser_override_option, "default")) { + continue; + } auto result = ext.parser_override(ext.parser_info.get(), query); if (result.type == ParserExtensionResultType::PARSE_SUCCESSFUL) { statements = std::move(result.statements); return; - } else if (result.type == ParserExtensionResultType::DISPLAY_EXTENSION_ERROR) { - throw ParserException(result.error); + } + if (StringUtil::CIEquals(parser_override_option, "strict")) { + ThrowParserOverrideError(result); + } + if (StringUtil::CIEquals(parser_override_option, "strict_when_supported")) { + auto statement_type = GetStatementType(query); + bool is_supported = false; + switch (statement_type) { + case StatementType::CALL_STATEMENT: + case StatementType::TRANSACTION_STATEMENT: + case StatementType::VARIABLE_SET_STATEMENT: + case StatementType::LOAD_STATEMENT: + case StatementType::ATTACH_STATEMENT: + case StatementType::DETACH_STATEMENT: + case StatementType::DELETE_STATEMENT: + is_supported = true; + break; + default: + is_supported = false; + break; + } + if (is_supported) { + ThrowParserOverrideError(result); + } + } else if (StringUtil::CIEquals(parser_override_option, "fallback")) { + continue; } } } @@ -250,9 +319,9 @@ void Parser::ParseQuery(const string &query) { throw ParserException::SyntaxError(query, parser_error, parser_error_location); } else { // split sql string into statements and re-parse using extension - auto query_statements = SplitQueryStringIntoStatements(query); + auto queries = SplitQueries(query); idx_t stmt_loc = 0; - for (auto const &query_statement : query_statements) { + for (auto const &query_statement : queries) { ErrorData another_parser_error; // Creating a new scope to allow extensions to use PostgresParser, which is not reentrant { @@ -284,7 +353,9 @@ void Parser::ParseQuery(const string &query) { bool parsed_single_statement = false; for (auto &ext : *options.extensions) { D_ASSERT(!parsed_single_statement); - D_ASSERT(ext.parse_function); + if (!ext.parse_function) { + continue; + } auto result = ext.parse_function(ext.parser_info.get(), query_statement); if (result.type == ParserExtensionResultType::PARSE_SUCCESSFUL) { auto statement = make_uniq(ext, std::move(result.parse_data)); diff --git a/src/duckdb/src/parser/query_node/cte_node.cpp b/src/duckdb/src/parser/query_node/cte_node.cpp index 1e1f0e199..29c0599a5 100644 --- a/src/duckdb/src/parser/query_node/cte_node.cpp +++ b/src/duckdb/src/parser/query_node/cte_node.cpp @@ -1,42 +1,20 @@ #include "duckdb/parser/query_node/cte_node.hpp" #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/parser/statement/select_statement.hpp" namespace duckdb { string CTENode::ToString() const { - string result; - result += child->ToString(); - return result; + throw InternalException("CTENode is a legacy type"); } bool CTENode::Equals(const QueryNode *other_p) 
const { - if (!QueryNode::Equals(other_p)) { - return false; - } - if (this == other_p) { - return true; - } - auto &other = other_p->Cast(); - - if (!query->Equals(other.query.get())) { - return false; - } - if (!child->Equals(other.child.get())) { - return false; - } - return true; + throw InternalException("CTENode is a legacy type"); } unique_ptr CTENode::Copy() const { - auto result = make_uniq(); - result->ctename = ctename; - result->query = query->Copy(); - result->child = child->Copy(); - result->aliases = aliases; - result->materialized = materialized; - this->CopyProperties(*result); - return std::move(result); + throw InternalException("CTENode is a legacy type"); } } // namespace duckdb diff --git a/src/duckdb/src/parser/query_node/set_operation_node.cpp b/src/duckdb/src/parser/query_node/set_operation_node.cpp index a8b624f21..cdc188820 100644 --- a/src/duckdb/src/parser/query_node/set_operation_node.cpp +++ b/src/duckdb/src/parser/query_node/set_operation_node.cpp @@ -8,10 +8,6 @@ namespace duckdb { SetOperationNode::SetOperationNode() : QueryNode(QueryNodeType::SET_OPERATION_NODE) { } -const vector> &SetOperationNode::GetSelectList() const { - return children[0]->GetSelectList(); -} - string SetOperationNode::ToString() const { string result; result = cte_map.ToString(); diff --git a/src/duckdb/src/parser/query_node/statement_node.cpp b/src/duckdb/src/parser/query_node/statement_node.cpp new file mode 100644 index 000000000..66e7b8e5a --- /dev/null +++ b/src/duckdb/src/parser/query_node/statement_node.cpp @@ -0,0 +1,40 @@ +#include "duckdb/parser/query_node/statement_node.hpp" + +namespace duckdb { + +StatementNode::StatementNode(SQLStatement &stmt_p) : QueryNode(QueryNodeType::STATEMENT_NODE), stmt(stmt_p) { +} + +//! Convert the query node to a string +string StatementNode::ToString() const { + return stmt.ToString(); +} + +bool StatementNode::Equals(const QueryNode *other_p) const { + if (!QueryNode::Equals(other_p)) { + return false; + } + if (this == other_p) { + return true; + } + auto &other = other_p->Cast(); + return RefersToSameObject(stmt, other.stmt); +} + +//! Create a copy of this SelectNode +unique_ptr StatementNode::Copy() const { + return make_uniq(stmt); +} + +//! Serializes a QueryNode to a stand-alone binary blob +//! 
Deserializes a blob back into a QueryNode + +void StatementNode::Serialize(Serializer &serializer) const { + throw InternalException("StatementNode cannot be serialized"); +} + +unique_ptr StatementNode::Deserialize(Deserializer &source) { + throw InternalException("StatementNode cannot be deserialized"); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/relation_statement.cpp b/src/duckdb/src/parser/statement/relation_statement.cpp index 9b3801495..023d3cac9 100644 --- a/src/duckdb/src/parser/statement/relation_statement.cpp +++ b/src/duckdb/src/parser/statement/relation_statement.cpp @@ -5,10 +5,7 @@ namespace duckdb { RelationStatement::RelationStatement(shared_ptr relation_p) : SQLStatement(StatementType::RELATION_STATEMENT), relation(std::move(relation_p)) { - if (relation->type == RelationType::QUERY_RELATION) { - auto &query_relation = relation->Cast(); - query = query_relation.query; - } + query = relation->GetQuery(); } unique_ptr RelationStatement::Copy() const { diff --git a/src/duckdb/src/parser/statement/update_statement.cpp b/src/duckdb/src/parser/statement/update_statement.cpp index b09fd1a0d..29ce2432e 100644 --- a/src/duckdb/src/parser/statement/update_statement.cpp +++ b/src/duckdb/src/parser/statement/update_statement.cpp @@ -49,7 +49,6 @@ UpdateStatement::UpdateStatement(const UpdateStatement &other) } string UpdateStatement::ToString() const { - string result; result = cte_map.ToString(); result += "UPDATE "; diff --git a/src/duckdb/src/parser/tableref/bound_ref_wrapper.cpp b/src/duckdb/src/parser/tableref/bound_ref_wrapper.cpp index a2ecb5086..6bf31269c 100644 --- a/src/duckdb/src/parser/tableref/bound_ref_wrapper.cpp +++ b/src/duckdb/src/parser/tableref/bound_ref_wrapper.cpp @@ -2,7 +2,7 @@ namespace duckdb { -BoundRefWrapper::BoundRefWrapper(unique_ptr bound_ref_p, shared_ptr binder_p) +BoundRefWrapper::BoundRefWrapper(BoundStatement bound_ref_p, shared_ptr binder_p) : TableRef(TableReferenceType::BOUND_TABLE_REF), bound_ref(std::move(bound_ref_p)), binder(std::move(binder_p)) { } diff --git a/src/duckdb/src/parser/transform/expression/transform_array_access.cpp b/src/duckdb/src/parser/transform/expression/transform_array_access.cpp index 447688c61..7e3529feb 100644 --- a/src/duckdb/src/parser/transform/expression/transform_array_access.cpp +++ b/src/duckdb/src/parser/transform/expression/transform_array_access.cpp @@ -7,7 +7,6 @@ namespace duckdb { unique_ptr Transformer::TransformArrayAccess(duckdb_libpgquery::PGAIndirection &indirection_node) { - // Transform the source expression. 
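The new StatementNode introduced above only borrows the SQLStatement it wraps: Equals compares object identity via RefersToSameObject and Copy re-wraps the same underlying statement. A minimal standalone sketch of that non-owning wrapper pattern, using stand-in types rather than DuckDB's classes:

#include <memory>

struct Statement {}; // stand-in for the wrapped SQLStatement

template <class T>
bool RefersToSameObject(const T &a, const T &b) {
	return &a == &b; // identity, not structural equality
}

struct WrapperNode {
	explicit WrapperNode(Statement &stmt_p) : stmt(stmt_p) {
	}
	Statement &stmt; // non-owning: the statement must outlive the node

	bool Equals(const WrapperNode &other) const {
		return RefersToSameObject(stmt, other.stmt);
	}
	std::unique_ptr<WrapperNode> Copy() const {
		// a "copy" still refers to the very same statement object
		return std::make_unique<WrapperNode>(stmt);
	}
};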
unique_ptr result; result = TransformExpression(indirection_node.arg); diff --git a/src/duckdb/src/parser/transform/expression/transform_expression.cpp b/src/duckdb/src/parser/transform/expression/transform_expression.cpp index 8cbf53b70..73b42c9dd 100644 --- a/src/duckdb/src/parser/transform/expression/transform_expression.cpp +++ b/src/duckdb/src/parser/transform/expression/transform_expression.cpp @@ -16,7 +16,6 @@ unique_ptr Transformer::TransformResTarget(duckdb_libpgquery:: } unique_ptr Transformer::TransformNamedArg(duckdb_libpgquery::PGNamedArgExpr &root) { - auto expr = TransformExpression(PGPointerCast(root.arg)); if (root.name) { expr->SetAlias(root.name); @@ -25,7 +24,6 @@ unique_ptr Transformer::TransformNamedArg(duckdb_libpgquery::P } unique_ptr Transformer::TransformExpression(duckdb_libpgquery::PGNode &node) { - auto stack_checker = StackCheck(); switch (node.type) { diff --git a/src/duckdb/src/parser/transform/expression/transform_function.cpp b/src/duckdb/src/parser/transform/expression/transform_function.cpp index b1993643c..574be3e2e 100644 --- a/src/duckdb/src/parser/transform/expression/transform_function.cpp +++ b/src/duckdb/src/parser/transform/expression/transform_function.cpp @@ -38,7 +38,6 @@ void Transformer::TransformWindowDef(duckdb_libpgquery::PGWindowDef &window_spec static inline WindowBoundary TransformFrameOption(const int frameOptions, const WindowBoundary rows, const WindowBoundary range, const WindowBoundary groups) { - if (frameOptions & FRAMEOPTION_RANGE) { return range; } else if (frameOptions & FRAMEOPTION_GROUPS) { diff --git a/src/duckdb/src/parser/transform/expression/transform_multi_assign_reference.cpp b/src/duckdb/src/parser/transform/expression/transform_multi_assign_reference.cpp index 28dc623f3..4ad2a0de3 100644 --- a/src/duckdb/src/parser/transform/expression/transform_multi_assign_reference.cpp +++ b/src/duckdb/src/parser/transform/expression/transform_multi_assign_reference.cpp @@ -4,7 +4,6 @@ namespace duckdb { unique_ptr Transformer::TransformMultiAssignRef(duckdb_libpgquery::PGMultiAssignRef &root) { - // Early-out, if the root is not a function call. 
if (root.source->type != duckdb_libpgquery::T_PGFuncCall) { return TransformExpression(root.source); diff --git a/src/duckdb/src/parser/transform/expression/transform_subquery.cpp b/src/duckdb/src/parser/transform/expression/transform_subquery.cpp index bc8a9762d..986e46e25 100644 --- a/src/duckdb/src/parser/transform/expression/transform_subquery.cpp +++ b/src/duckdb/src/parser/transform/expression/transform_subquery.cpp @@ -24,7 +24,6 @@ unique_ptr Transformer::TransformSubquery(duckdb_libpgquery::P subquery_expr->subquery = TransformSelectStmt(*root.subselect); SetQueryLocation(*subquery_expr, root.location); D_ASSERT(subquery_expr->subquery); - D_ASSERT(!subquery_expr->subquery->node->GetSelectList().empty()); switch (root.subLinkType) { case duckdb_libpgquery::PG_EXISTS_SUBLINK: { diff --git a/src/duckdb/src/parser/transform/helpers/transform_cte.cpp b/src/duckdb/src/parser/transform/helpers/transform_cte.cpp index 2de5d8334..f5f232fc3 100644 --- a/src/duckdb/src/parser/transform/helpers/transform_cte.cpp +++ b/src/duckdb/src/parser/transform/helpers/transform_cte.cpp @@ -23,9 +23,16 @@ unique_ptr CommonTableExpressionInfo::Copy() { CommonTableExpressionInfo::~CommonTableExpressionInfo() { } +CTEMaterialize CommonTableExpressionInfo::GetMaterializedForSerialization(Serializer &serializer) const { + if (serializer.ShouldSerialize(7)) { + return materialized; + } + return CTEMaterialize::CTE_MATERIALIZE_DEFAULT; +} + void Transformer::ExtractCTEsRecursive(CommonTableExpressionMap &cte_map) { for (auto &cte_entry : stored_cte_map) { - for (auto &entry : cte_entry->map) { + for (auto &entry : cte_entry.get().map) { auto found_entry = cte_map.map.find(entry.first); if (found_entry != cte_map.map.end()) { // entry already present - use top-most entry @@ -40,7 +47,7 @@ void Transformer::ExtractCTEsRecursive(CommonTableExpressionMap &cte_map) { } void Transformer::TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause, CommonTableExpressionMap &cte_map) { - stored_cte_map.push_back(&cte_map); + stored_cte_map.push_back(cte_map); // TODO: might need to update in case of future lawsuit D_ASSERT(de_with_clause.ctes); diff --git a/src/duckdb/src/parser/transform/helpers/transform_typename.cpp b/src/duckdb/src/parser/transform/helpers/transform_typename.cpp index c071af00b..86ce2a73c 100644 --- a/src/duckdb/src/parser/transform/helpers/transform_typename.cpp +++ b/src/duckdb/src/parser/transform/helpers/transform_typename.cpp @@ -17,7 +17,6 @@ struct SizeModifiers { }; static SizeModifiers GetSizeModifiers(duckdb_libpgquery::PGTypeName &type_name, LogicalTypeId base_type) { - SizeModifiers result; if (base_type == LogicalTypeId::DECIMAL) { @@ -97,6 +96,11 @@ LogicalType Transformer::TransformTypeNameInternal(duckdb_libpgquery::PGTypeName // transform it to the SQL type LogicalTypeId base_type = TransformStringToLogicalTypeId(name); + if (base_type == LogicalTypeId::GEOMETRY) { + // Always return a type with GeoTypeInfo + return LogicalType::GEOMETRY(); + } + if (base_type == LogicalTypeId::LIST) { throw ParserException("LIST is not valid as a stand-alone type"); } diff --git a/src/duckdb/src/parser/transform/statement/transform_alter_table.cpp b/src/duckdb/src/parser/transform/statement/transform_alter_table.cpp index 7fb15a90d..ea374c51e 100644 --- a/src/duckdb/src/parser/transform/statement/transform_alter_table.cpp +++ b/src/duckdb/src/parser/transform/statement/transform_alter_table.cpp @@ -19,7 +19,6 @@ vector Transformer::TransformNameList(duckdb_libpgquery::PGList &list) { } 
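GetMaterializedForSerialization above gates a newly added field on the serialization target: only targets that understand version 7 receive the explicit materialization hint, older targets fall back to the default. A small standalone sketch of that gating pattern (the version number 7 is taken from the hunk above; the serializer type here is a stand-in, not DuckDB's API):

#include <cstdint>

enum class Materialize { DEFAULT, ALWAYS, NEVER }; // stand-in for CTEMaterialize

struct VersionedSerializer {
	uint32_t target_version;
	// assumption for this sketch: a field introduced in version N is written
	// only when the target is at least that new
	bool ShouldSerialize(uint32_t introduced_in) const {
		return target_version >= introduced_in;
	}
};

Materialize MaterializedForSerialization(const VersionedSerializer &serializer, Materialize value) {
	if (serializer.ShouldSerialize(7)) {
		return value; // new enough target: keep the explicit hint
	}
	return Materialize::DEFAULT; // older target: backwards-compatible default
}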
unique_ptr Transformer::TransformAlter(duckdb_libpgquery::PGAlterTableStmt &stmt) { - D_ASSERT(stmt.relation); if (stmt.cmds->length != 1) { throw ParserException("Only one ALTER command per statement is supported"); @@ -30,7 +29,6 @@ unique_ptr Transformer::TransformAlter(duckdb_libpgquery::PGAlte // Check the ALTER type. for (auto c = stmt.cmds->head; c != nullptr; c = c->next) { - auto command = PGPointerCast(c->data.ptr_value); AlterEntryData data(qualified_name.catalog, qualified_name.schema, qualified_name.name, TransformOnEntryNotFound(stmt.missing_ok)); diff --git a/src/duckdb/src/parser/transform/statement/transform_explain.cpp b/src/duckdb/src/parser/transform/statement/transform_explain.cpp index 510395a99..969a8827b 100644 --- a/src/duckdb/src/parser/transform/statement/transform_explain.cpp +++ b/src/duckdb/src/parser/transform/statement/transform_explain.cpp @@ -11,7 +11,8 @@ ExplainFormat ParseFormat(const Value &val) { auto format_val = val.GetValue(); case_insensitive_map_t format_mapping { {"default", ExplainFormat::DEFAULT}, {"text", ExplainFormat::TEXT}, {"json", ExplainFormat::JSON}, - {"html", ExplainFormat::HTML}, {"graphviz", ExplainFormat::GRAPHVIZ}, {"yaml", ExplainFormat::YAML}}; + {"html", ExplainFormat::HTML}, {"graphviz", ExplainFormat::GRAPHVIZ}, {"yaml", ExplainFormat::YAML}, + {"mermaid", ExplainFormat::MERMAID}}; auto it = format_mapping.find(format_val); if (it != format_mapping.end()) { return it->second; diff --git a/src/duckdb/src/parser/transform/statement/transform_pivot_stmt.cpp b/src/duckdb/src/parser/transform/statement/transform_pivot_stmt.cpp index 07dfb420d..4572a3a36 100644 --- a/src/duckdb/src/parser/transform/statement/transform_pivot_stmt.cpp +++ b/src/duckdb/src/parser/transform/statement/transform_pivot_stmt.cpp @@ -95,7 +95,7 @@ unique_ptr Transformer::GenerateCreateEnumStmt(unique_ptr(); - select->node = TransformMaterializedCTE(std::move(subselect)); + select->node = std::move(subselect); info->query = std::move(select); info->type = LogicalType::INVALID; diff --git a/src/duckdb/src/parser/transform/statement/transform_select.cpp b/src/duckdb/src/parser/transform/statement/transform_select.cpp index 2e5135ef6..16cd1a490 100644 --- a/src/duckdb/src/parser/transform/statement/transform_select.cpp +++ b/src/duckdb/src/parser/transform/statement/transform_select.cpp @@ -26,13 +26,10 @@ unique_ptr Transformer::TransformSelectNodeInternal(duckdb_libpgquery throw ParserException("SELECT locking clause is not supported!"); } } - unique_ptr stmt = nullptr; if (select.pivot) { - stmt = TransformPivotStatement(select); - } else { - stmt = TransformSelectInternal(select); + return TransformPivotStatement(select); } - return TransformMaterializedCTE(std::move(stmt)); + return TransformSelectInternal(select); } unique_ptr Transformer::TransformSelectStmt(duckdb_libpgquery::PGSelectStmt &select, bool is_select) { diff --git a/src/duckdb/src/parser/transform/statement/transform_upsert.cpp b/src/duckdb/src/parser/transform/statement/transform_upsert.cpp index aa0130f3c..8d5fdaf35 100644 --- a/src/duckdb/src/parser/transform/statement/transform_upsert.cpp +++ b/src/duckdb/src/parser/transform/statement/transform_upsert.cpp @@ -67,7 +67,6 @@ unique_ptr Transformer::DummyOnConflictClause(duckdb_libpgquery: unique_ptr Transformer::TransformOnConflictClause(duckdb_libpgquery::PGOnConflictClause *node, const string &) { - auto stmt = PGPointerCast(node); D_ASSERT(stmt); diff --git a/src/duckdb/src/parser/transformer.cpp 
b/src/duckdb/src/parser/transformer.cpp index 4ab39fca7..32ddaa87a 100644 --- a/src/duckdb/src/parser/transformer.cpp +++ b/src/duckdb/src/parser/transformer.cpp @@ -232,31 +232,6 @@ unique_ptr Transformer::TransformStatementInternal(duckdb_libpgque } } -unique_ptr Transformer::TransformMaterializedCTE(unique_ptr root) { - // Extract materialized CTEs from cte_map - vector> materialized_ctes; - - for (auto &cte : root->cte_map.map) { - auto &cte_entry = cte.second; - auto mat_cte = make_uniq(); - mat_cte->ctename = cte.first; - mat_cte->query = TransformMaterializedCTE(cte_entry->query->node->Copy()); - mat_cte->aliases = cte_entry->aliases; - mat_cte->materialized = cte_entry->materialized; - materialized_ctes.push_back(std::move(mat_cte)); - } - - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->child = std::move(root); - root = std::move(node_result); - materialized_ctes.pop_back(); - } - - return root; -} - void Transformer::SetQueryLocation(ParsedExpression &expr, int query_location) { if (query_location < 0) { return; diff --git a/src/duckdb/src/planner/bind_context.cpp b/src/duckdb/src/planner/bind_context.cpp index b6e5df81f..cc1b3d25e 100644 --- a/src/duckdb/src/planner/bind_context.cpp +++ b/src/duckdb/src/planner/bind_context.cpp @@ -38,16 +38,17 @@ optional_ptr BindContext::GetMatchingBinding(const string &column_name) optional_ptr result; for (auto &binding_ptr : bindings_list) { auto &binding = *binding_ptr; - auto is_using_binding = GetUsingBinding(column_name, binding.alias); + auto is_using_binding = GetUsingBinding(column_name, binding.GetBindingAlias()); if (is_using_binding) { continue; } if (binding.HasMatchingBinding(column_name)) { if (result || is_using_binding) { - throw BinderException("Ambiguous reference to column name \"%s\" (use: \"%s.%s\" " - "or \"%s.%s\")", - column_name, MinimumUniqueAlias(result->alias, binding.alias), column_name, - MinimumUniqueAlias(binding.alias, result->alias), column_name); + throw BinderException( + "Ambiguous reference to column name \"%s\" (use: \"%s.%s\" " + "or \"%s.%s\")", + column_name, MinimumUniqueAlias(result->GetBindingAlias(), binding.GetBindingAlias()), column_name, + MinimumUniqueAlias(binding.GetBindingAlias(), result->GetBindingAlias()), column_name); } result = &binding; } @@ -58,8 +59,8 @@ optional_ptr BindContext::GetMatchingBinding(const string &column_name) vector BindContext::GetSimilarBindings(const string &column_name) { vector> scores; for (auto &binding_ptr : bindings_list) { - auto binding = *binding_ptr; - for (auto &name : binding.names) { + auto &binding = *binding_ptr; + for (auto &name : binding.GetColumnNames()) { double distance = StringUtil::SimilarityRating(name, column_name); // check if we need to qualify the column auto matching_bindings = GetMatchingBindings(name); @@ -77,10 +78,6 @@ void BindContext::AddUsingBinding(const string &column_name, UsingColumnSet &set using_columns[column_name].insert(set); } -void BindContext::AddUsingBindingSet(unique_ptr set) { - using_column_sets.push_back(std::move(set)); -} - optional_ptr BindContext::GetUsingBinding(const string &column_name) { auto entry = using_columns.find(column_name); if (entry == using_columns.end()) { @@ -161,7 +158,7 @@ string BindContext::GetActualColumnName(Binding &binding, const string &column_n throw InternalException("Binding with name \"%s\" does not have a column named \"%s\"", binding.GetAlias(), column_name); } // LCOV_EXCL_STOP - return 
binding.names[binding_index]; + return binding.GetColumnNames()[binding_index]; } string BindContext::GetActualColumnName(const BindingAlias &binding_alias, const string &column_name) { @@ -204,7 +201,7 @@ unique_ptr BindContext::CreateColumnReference(const string &ta } static bool ColumnIsGenerated(Binding &binding, column_t index) { - if (binding.binding_type != BindingType::TABLE) { + if (binding.GetBindingType() != BindingType::TABLE) { return false; } auto &table_binding = binding.Cast(); @@ -243,10 +240,12 @@ unique_ptr BindContext::CreateColumnReference(const string &ca auto column_index = binding->GetBindingIndex(column_name); if (bind_type == ColumnBindType::EXPAND_GENERATED_COLUMNS && ColumnIsGenerated(*binding, column_index)) { return ExpandGeneratedColumn(binding->Cast(), column_name); - } else if (column_index < binding->names.size() && binding->names[column_index] != column_name) { + } + auto &column_names = binding->GetColumnNames(); + if (column_index < column_names.size() && column_names[column_index] != column_name) { // because of case insensitivity in the binder we rename the column to the original name // as it appears in the binding itself - result->SetAlias(binding->names[column_index]); + result->SetAlias(column_names[column_index]); } return std::move(result); } @@ -257,14 +256,6 @@ unique_ptr BindContext::CreateColumnReference(const string &sc return CreateColumnReference(catalog_name, schema_name, table_name, column_name, bind_type); } -optional_ptr BindContext::GetCTEBinding(const string &ctename) { - auto match = cte_bindings.find(ctename); - if (match == cte_bindings.end()) { - return nullptr; - } - return match->second.get(); -} - string GetCandidateAlias(const BindingAlias &main_alias, const BindingAlias &new_alias) { string candidate; if (!main_alias.GetCatalog().empty() && !new_alias.GetCatalog().empty()) { @@ -283,7 +274,7 @@ vector> BindContext::GetBindings(const BindingAlias &alias, E } vector> matching_bindings; for (auto &binding : bindings_list) { - if (binding->alias.Matches(alias)) { + if (binding->GetBindingAlias().Matches(alias)) { matching_bindings.push_back(*binding); } } @@ -291,7 +282,7 @@ vector> BindContext::GetBindings(const BindingAlias &alias, E // alias not found in this BindContext vector candidates; for (auto &binding : bindings_list) { - candidates.push_back(GetCandidateAlias(alias, binding->alias)); + candidates.push_back(GetCandidateAlias(alias, binding->GetBindingAlias())); } auto main_alias = GetCandidateAlias(alias, alias); string candidate_str = @@ -315,14 +306,14 @@ string BindContext::AmbiguityException(const BindingAlias &alias, const vector handled_using_columns; for (auto &entry : bindings_list) { auto &binding = *entry; - for (auto &column_name : binding.names) { - QualifiedColumnName qualified_column(binding.alias, column_name); + auto &column_names = binding.GetColumnNames(); + auto &binding_alias = binding.GetBindingAlias(); + for (auto &column_name : column_names) { + QualifiedColumnName qualified_column(binding_alias, column_name); if (CheckExclusionList(expr, qualified_column, exclusion_info)) { continue; } // check if this column is a USING column - auto using_binding_ptr = GetUsingBinding(column_name, binding.alias); + auto using_binding_ptr = GetUsingBinding(column_name, binding_alias); if (using_binding_ptr) { auto &using_binding = *using_binding_ptr; // it is! 
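The star-expansion walk above must emit a column that is shared through USING exactly once, even though several bindings expose it; the handled-USING-columns set tracks which shared columns were already expanded. A simplified, standalone sketch of that de-duplication (illustrative names only, not the BindContext API):

#include <iostream>
#include <set>
#include <string>
#include <vector>

struct FakeBinding {
	std::string alias;
	std::vector<std::string> column_names;
};

int main() {
	std::vector<FakeBinding> bindings {{"t1", {"id", "a"}}, {"t2", {"id", "b"}}};
	std::set<std::string> using_columns {"id"};  // columns joined with USING
	std::set<std::string> handled_using_columns; // shared columns already emitted

	for (auto &binding : bindings) {
		for (auto &name : binding.column_names) {
			if (using_columns.count(name)) {
				if (handled_using_columns.count(name)) {
					continue; // already expanded for an earlier binding
				}
				handled_using_columns.insert(name);
				std::cout << name << "\n"; // emitted once, unqualified
				continue;
			}
			std::cout << binding.alias << "." << name << "\n";
		}
	}
	// prints: id, t1.a, t2.b
	return 0;
}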
@@ -530,7 +524,7 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, continue; } auto new_expr = - CreateColumnReference(binding.alias, column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS); + CreateColumnReference(binding_alias, column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS); HandleRename(expr, qualified_column, *new_expr); new_select_list.push_back(std::move(new_expr)); } @@ -548,17 +542,20 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, } is_struct_ref = true; } + auto &binding_alias = binding->GetBindingAlias(); + auto &column_names = binding->GetColumnNames(); + auto &column_types = binding->GetColumnTypes(); if (is_struct_ref) { auto col_idx = binding->GetBindingIndex(expr.relation_name); - auto col_type = binding->types[col_idx]; + auto col_type = column_types[col_idx]; if (col_type.id() != LogicalTypeId::STRUCT) { throw BinderException(StringUtil::Format( "Cannot extract field from expression \"%s\" because it is not a struct", expr.ToString())); } auto &struct_children = StructType::GetChildTypes(col_type); vector column_names(3); - column_names[0] = binding->alias.GetAlias(); + column_names[0] = binding->GetAlias(); column_names[1] = expr.relation_name; for (auto &child : struct_children) { QualifiedColumnName qualified_name(child.first); @@ -571,13 +568,13 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, new_select_list.push_back(std::move(new_expr)); } } else { - for (auto &column_name : binding->names) { - QualifiedColumnName qualified_name(binding->alias, column_name); + for (auto &column_name : column_names) { + QualifiedColumnName qualified_name(binding_alias, column_name); if (CheckExclusionList(expr, qualified_name, exclusion_info)) { continue; } auto new_expr = - CreateColumnReference(binding->alias, column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS); + CreateColumnReference(binding_alias, column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS); HandleRename(expr, qualified_name, *new_expr); new_select_list.push_back(std::move(new_expr)); } @@ -613,10 +610,12 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, void BindContext::GetTypesAndNames(vector &result_names, vector &result_types) { for (auto &binding_entry : bindings_list) { auto &binding = *binding_entry; - D_ASSERT(binding.names.size() == binding.types.size()); - for (idx_t i = 0; i < binding.names.size(); i++) { - result_names.push_back(binding.names[i]); - result_types.push_back(binding.types[i]); + auto &column_names = binding.GetColumnNames(); + auto &column_types = binding.GetColumnTypes(); + D_ASSERT(column_names.size() == column_types.size()); + for (idx_t i = 0; i < column_names.size(); i++) { + result_names.push_back(column_names[i]); + result_types.push_back(column_types[i]); } } } @@ -686,7 +685,7 @@ vector BindContext::AliasColumnNames(const string &table_name, const vec return result; } -void BindContext::AddSubquery(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery) { +void BindContext::AddSubquery(idx_t index, const string &alias, SubqueryRef &ref, BoundStatement &subquery) { auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); AddGenericBinding(index, alias, names, subquery.types); } @@ -696,13 +695,13 @@ void BindContext::AddEntryBinding(idx_t index, const string &alias, const vector AddBinding(make_uniq(alias, types, names, index, entry)); } -void BindContext::AddView(idx_t index, const string &alias, SubqueryRef &ref, 
BoundQueryNode &subquery, +void BindContext::AddView(idx_t index, const string &alias, SubqueryRef &ref, BoundStatement &subquery, ViewCatalogEntry &view) { auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); AddEntryBinding(index, alias, names, subquery.types, view.Cast()); } -void BindContext::AddSubquery(idx_t index, const string &alias, TableFunctionRef &ref, BoundQueryNode &subquery) { +void BindContext::AddSubquery(idx_t index, const string &alias, TableFunctionRef &ref, BoundStatement &subquery) { auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); AddGenericBinding(index, alias, names, subquery.types); } @@ -712,33 +711,28 @@ void BindContext::AddGenericBinding(idx_t index, const string &alias, const vect AddBinding(make_uniq(BindingType::BASE, BindingAlias(alias), types, names, index)); } -void BindContext::AddCTEBinding(idx_t index, const string &alias, const vector &names, - const vector &types, bool using_key) { - auto binding = make_shared_ptr(BindingType::BASE, BindingAlias(alias), types, names, index); - - if (cte_bindings.find(alias) != cte_bindings.end()) { - throw BinderException("Duplicate CTE binding \"%s\" in query!", alias); +void BindContext::AddCTEBinding(unique_ptr binding) { + for (auto &cte_binding : cte_bindings) { + if (cte_binding->GetBindingAlias() == binding->GetBindingAlias()) { + throw BinderException("Duplicate CTE binding \"%s\" in query!", binding->GetBindingAlias().ToString()); + } } - cte_bindings[alias] = std::move(binding); - cte_references[alias] = make_shared_ptr(0); + cte_bindings.push_back(std::move(binding)); +} - if (using_key) { - auto recurring_alias = "recurring." + alias; - cte_bindings[recurring_alias] = - make_shared_ptr(BindingType::BASE, BindingAlias(recurring_alias), types, names, index); - cte_references[recurring_alias] = make_shared_ptr(0); - } +void BindContext::AddCTEBinding(idx_t index, BindingAlias alias_p, const vector &names, + const vector &types, CTEType cte_type) { + auto binding = make_uniq(std::move(alias_p), types, names, index, cte_type); + AddCTEBinding(std::move(binding)); } -void BindContext::RemoveCTEBinding(const std::string &alias) { - auto it = cte_bindings.find(alias); - if (it != cte_bindings.end()) { - cte_bindings.erase(it); - } - auto it2 = cte_references.find(alias); - if (it2 != cte_references.end()) { - cte_references.erase(it2); +optional_ptr BindContext::GetCTEBinding(const BindingAlias &ctename) { + for (auto &binding : cte_bindings) { + if (binding->GetBindingAlias().Matches(ctename)) { + return binding.get(); + } } + return nullptr; } void BindContext::AddContext(BindContext other) { @@ -755,7 +749,7 @@ void BindContext::AddContext(BindContext other) { vector BindContext::GetBindingAliases() { vector result; for (auto &binding : bindings_list) { - result.push_back(BindingAlias(binding->alias)); + result.push_back(binding->GetBindingAlias()); } return result; } @@ -782,7 +776,7 @@ void BindContext::RemoveContext(const vector &aliases) { // remove the binding from the list of bindings auto it = std::remove_if(bindings_list.begin(), bindings_list.end(), - [&](unique_ptr &x) { return x->alias == alias; }); + [&](unique_ptr &x) { return x->GetBindingAlias() == alias; }); bindings_list.erase(it, bindings_list.end()); } } diff --git a/src/duckdb/src/planner/binder.cpp b/src/duckdb/src/planner/binder.cpp index 2ba52b64f..fe1b59cf3 100644 --- a/src/duckdb/src/planner/binder.cpp +++ b/src/duckdb/src/planner/binder.cpp @@ -28,10 +28,6 @@ namespace duckdb { 
-Binder &Binder::GetRootBinder() { - return root_binder; -} - idx_t Binder::GetBinderDepth() const { return depth; } @@ -50,9 +46,11 @@ shared_ptr Binder::CreateBinder(ClientContext &context, optional_ptr parent_p, BinderType binder_type) - : context(context), bind_context(*this), parent(std::move(parent_p)), bound_tables(0), binder_type(binder_type), - entry_retriever(context), root_binder(parent ? parent->GetRootBinder() : *this), - depth(parent ? parent->GetBinderDepth() : 1) { + : context(context), bind_context(*this), parent(std::move(parent_p)), binder_type(binder_type), + global_binder_state(parent ? parent->global_binder_state : make_shared_ptr()), + query_binder_state(parent && binder_type == BinderType::REGULAR_BINDER ? parent->query_binder_state + : make_shared_ptr()), + entry_retriever(context), depth(parent ? parent->GetBinderDepth() : 1) { IncreaseDepth(); if (parent) { entry_retriever.Inherit(parent->entry_retriever); @@ -60,85 +58,22 @@ Binder::Binder(ClientContext &context, shared_ptr parent_p, BinderType b // We have to inherit macro and lambda parameter bindings and from the parent binder, if there is a parent. macro_binding = parent->macro_binding; lambda_bindings = parent->lambda_bindings; - - if (binder_type == BinderType::REGULAR_BINDER) { - // We have to inherit CTE bindings from the parent bind_context, if there is a parent. - bind_context.SetCTEBindings(parent->bind_context.GetCTEBindings()); - bind_context.cte_references = parent->bind_context.cte_references; - parameters = parent->parameters; - } - } -} - -unique_ptr Binder::BindMaterializedCTE(CommonTableExpressionMap &cte_map) { - // Extract materialized CTEs from cte_map - vector> materialized_ctes; - for (auto &cte : cte_map.map) { - auto &cte_entry = cte.second; - auto mat_cte = make_uniq(); - mat_cte->ctename = cte.first; - mat_cte->query = cte_entry->query->node->Copy(); - mat_cte->aliases = cte_entry->aliases; - mat_cte->materialized = cte_entry->materialized; - materialized_ctes.push_back(std::move(mat_cte)); - } - - if (materialized_ctes.empty()) { - return nullptr; - } - - unique_ptr cte_root = nullptr; - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->cte_map = cte_map.Copy(); - if (cte_root) { - node_result->child = std::move(cte_root); - } else { - node_result->child = nullptr; - } - cte_root = std::move(node_result); - materialized_ctes.pop_back(); } - - AddCTEMap(cte_map); - auto bound_cte = BindCTE(cte_root->Cast()); - - return bound_cte; } template BoundStatement Binder::BindWithCTE(T &statement) { - BoundStatement bound_statement; - auto bound_cte = BindMaterializedCTE(statement.template Cast().cte_map); - if (bound_cte) { - reference tail_ref = *bound_cte; - - while (tail_ref.get().child && tail_ref.get().child->type == QueryNodeType::CTE_NODE) { - tail_ref = tail_ref.get().child->Cast(); - } - - auto &tail = tail_ref.get(); - bound_statement = tail.child_binder->Bind(statement.template Cast()); - - tail.types = bound_statement.types; - tail.names = bound_statement.names; - - for (auto &c : tail.query_binder->correlated_columns) { - tail.child_binder->AddCorrelatedColumn(c); - } - MoveCorrelatedExpressions(*tail.child_binder); - - auto plan = std::move(bound_statement.plan); - bound_statement.plan = CreatePlan(*bound_cte, std::move(plan)); - } else { - bound_statement = Bind(statement.template Cast()); + auto &cte_map = statement.cte_map; + if (cte_map.map.empty()) { + return Bind(statement); } - return 
bound_statement; + + auto stmt_node = make_uniq(statement); + stmt_node->cte_map = cte_map.Copy(); + return Bind(*stmt_node); } BoundStatement Binder::Bind(SQLStatement &statement) { - root_statement = &statement; switch (statement.type) { case StatementType::SELECT_STATEMENT: return Bind(statement.Cast()); @@ -198,64 +133,12 @@ BoundStatement Binder::Bind(SQLStatement &statement) { } // LCOV_EXCL_STOP } -void Binder::AddCTEMap(CommonTableExpressionMap &cte_map) { - for (auto &cte_it : cte_map.map) { - AddCTE(cte_it.first); - } -} - -unique_ptr Binder::BindNode(QueryNode &node) { - // first we visit the set of CTEs and add them to the bind context - AddCTEMap(node.cte_map); - // now we bind the node - unique_ptr result; - switch (node.type) { - case QueryNodeType::SELECT_NODE: - result = BindNode(node.Cast()); - break; - case QueryNodeType::RECURSIVE_CTE_NODE: - result = BindNode(node.Cast()); - break; - case QueryNodeType::CTE_NODE: - result = BindNode(node.Cast()); - break; - default: - D_ASSERT(node.type == QueryNodeType::SET_OPERATION_NODE); - result = BindNode(node.Cast()); - break; - } - return result; -} - BoundStatement Binder::Bind(QueryNode &node) { - BoundStatement result; - auto bound_node = BindNode(node); - - result.names = bound_node->names; - result.types = bound_node->types; - - // and plan it - result.plan = CreatePlan(*bound_node); - return result; + return BindNode(node); } -unique_ptr Binder::CreatePlan(BoundQueryNode &node) { - switch (node.type) { - case QueryNodeType::SELECT_NODE: - return CreatePlan(node.Cast()); - case QueryNodeType::SET_OPERATION_NODE: - return CreatePlan(node.Cast()); - case QueryNodeType::RECURSIVE_CTE_NODE: - return CreatePlan(node.Cast()); - case QueryNodeType::CTE_NODE: - return CreatePlan(node.Cast()); - default: - throw InternalException("Unsupported bound query node type"); - } -} - -unique_ptr Binder::Bind(TableRef &ref) { - unique_ptr result; +BoundStatement Binder::Bind(TableRef &ref) { + BoundStatement result; switch (ref.type) { case TableReferenceType::BASE_TABLE: result = Bind(ref.Cast()); @@ -295,80 +178,33 @@ unique_ptr Binder::Bind(TableRef &ref) { default: throw InternalException("Unknown table ref type (%s)", EnumUtil::ToString(ref.type)); } - result->sample = std::move(ref.sample); - return result; -} - -unique_ptr Binder::CreatePlan(BoundTableRef &ref) { - unique_ptr root; - switch (ref.type) { - case TableReferenceType::BASE_TABLE: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::SUBQUERY: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::JOIN: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::TABLE_FUNCTION: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::EMPTY_FROM: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::EXPRESSION_LIST: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::COLUMN_DATA: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::CTE: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::PIVOT: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::DELIM_GET: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::INVALID: - default: - throw InternalException("Unsupported bound table ref type (%s)", EnumUtil::ToString(ref.type)); - } - // plan the sample clause if (ref.sample) { - root = make_uniq(std::move(ref.sample), std::move(root)); - } - return root; -} - -void Binder::AddCTE(const string &name) { - D_ASSERT(!name.empty()); - 
CTE_bindings.insert(name); -} - -vector> Binder::FindCTE(const string &name, bool skip) { - auto entry = bind_context.GetCTEBinding(name); - vector> ctes; - if (entry) { - ctes.push_back(*entry.get()); - } - if (parent && binder_type == BinderType::REGULAR_BINDER) { - auto parent_ctes = parent->FindCTE(name, name == alias); - ctes.insert(ctes.end(), parent_ctes.begin(), parent_ctes.end()); + result.plan = make_uniq(std::move(ref.sample), std::move(result.plan)); } - return ctes; + return result; } -bool Binder::CTEExists(const string &name) { - if (CTE_bindings.find(name) != CTE_bindings.end()) { - return true; - } - if (parent && binder_type == BinderType::REGULAR_BINDER) { - return parent->CTEExists(name); +optional_ptr Binder::GetCTEBinding(const BindingAlias &name) { + reference current_binder(*this); + optional_ptr result; + while (true) { + auto ¤t = current_binder.get(); + auto entry = current.bind_context.GetCTEBinding(name); + if (entry) { + // we only directly return the CTE if it can be referenced + // if it cannot be referenced (circular reference) we keep going up the stack + // to look for a CTE that can be referenced + if (entry->CanBeReferenced()) { + return entry; + } + result = entry; + } + if (!current.parent || current.binder_type != BinderType::REGULAR_BINDER) { + break; + } + current_binder = *current.parent; } - return false; + return result; } void Binder::AddBoundView(ViewCatalogEntry &view) { @@ -384,13 +220,19 @@ void Binder::AddBoundView(ViewCatalogEntry &view) { } idx_t Binder::GenerateTableIndex() { - auto &root_binder = GetRootBinder(); - return root_binder.bound_tables++; + return global_binder_state->bound_tables++; } StatementProperties &Binder::GetStatementProperties() { - auto &root_binder = GetRootBinder(); - return root_binder.prop; + return global_binder_state->prop; +} + +optional_ptr Binder::GetParameters() { + return global_binder_state->parameters; +} + +void Binder::SetParameters(BoundParameterMap ¶meters) { + global_binder_state->parameters = parameters; } void Binder::PushExpressionBinder(ExpressionBinder &binder) { @@ -416,17 +258,11 @@ bool Binder::HasActiveBinder() { } vector> &Binder::GetActiveBinders() { - reference root = *this; - while (root.get().parent && root.get().binder_type == BinderType::REGULAR_BINDER) { - root = *root.get().parent; - } - auto &root_binder = root.get(); - return root_binder.active_binders; + return query_binder_state->active_binders; } void Binder::AddUsingBindingSet(unique_ptr set) { - auto &root_binder = GetRootBinder(); - root_binder.bind_context.AddUsingBindingSet(std::move(set)); + global_binder_state->using_column_sets.push_back(std::move(set)); } void Binder::MoveCorrelatedExpressions(Binder &other) { @@ -434,7 +270,7 @@ void Binder::MoveCorrelatedExpressions(Binder &other) { other.correlated_columns.clear(); } -void Binder::MergeCorrelatedColumns(vector &other) { +void Binder::MergeCorrelatedColumns(CorrelatedColumns &other) { for (idx_t i = 0; i < other.size(); i++) { AddCorrelatedColumn(other[i]); } @@ -443,7 +279,7 @@ void Binder::MergeCorrelatedColumns(vector &other) { void Binder::AddCorrelatedColumn(const CorrelatedColumnInfo &info) { // we only add correlated columns to the list if they are not already there if (std::find(correlated_columns.begin(), correlated_columns.end(), info) == correlated_columns.end()) { - correlated_columns.push_back(info); + correlated_columns.AddColumn(info); } } @@ -463,7 +299,6 @@ optional_ptr Binder::GetMatchingBinding(const string &catalog_name, con const string 
&table_name, const string &column_name, ErrorData &error) { optional_ptr binding; - D_ASSERT(!lambda_bindings); if (macro_binding && table_name == macro_binding->GetAlias()) { binding = optional_ptr(macro_binding.get()); } else { @@ -474,13 +309,11 @@ optional_ptr Binder::GetMatchingBinding(const string &catalog_name, con } void Binder::SetBindingMode(BindingMode mode) { - auto &root_binder = GetRootBinder(); - root_binder.mode = mode; + global_binder_state->mode = mode; } BindingMode Binder::GetBindingMode() { - auto &root_binder = GetRootBinder(); - return root_binder.mode; + return global_binder_state->mode; } void Binder::SetCanContainNulls(bool can_contain_nulls_p) { @@ -493,30 +326,26 @@ void Binder::SetAlwaysRequireRebind() { } void Binder::AddTableName(string table_name) { - auto &root_binder = GetRootBinder(); - root_binder.table_names.insert(std::move(table_name)); + global_binder_state->table_names.insert(std::move(table_name)); } void Binder::AddReplacementScan(const string &table_name, unique_ptr replacement) { - auto &root_binder = GetRootBinder(); - auto it = root_binder.replacement_scans.find(table_name); + auto it = global_binder_state->replacement_scans.find(table_name); replacement->column_name_alias.clear(); replacement->alias.clear(); - if (it == root_binder.replacement_scans.end()) { - root_binder.replacement_scans[table_name] = std::move(replacement); + if (it == global_binder_state->replacement_scans.end()) { + global_binder_state->replacement_scans[table_name] = std::move(replacement); } else { // A replacement scan by this name was previously registered, we can just use it } } const unordered_set &Binder::GetTableNames() { - auto &root_binder = GetRootBinder(); - return root_binder.table_names; + return global_binder_state->table_names; } case_insensitive_map_t> &Binder::GetReplacementScans() { - auto &root_binder = GetRootBinder(); - return root_binder.replacement_scans; + return global_binder_state->replacement_scans; } // FIXME: this is extremely naive @@ -537,7 +366,6 @@ void VerifyNotExcluded(const ParsedExpression &root_expr) { BoundStatement Binder::BindReturning(vector> returning_list, TableCatalogEntry &table, const string &alias, idx_t update_table_index, unique_ptr child_operator, virtual_column_map_t virtual_columns) { - vector types; vector names; @@ -582,7 +410,7 @@ BoundStatement Binder::BindReturning(vector> return // returned, it should be guaranteed that the row has been inserted. 
// see https://github.com/duckdb/duckdb/issues/8310 auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::QUERY_RESULT; return result; } diff --git a/src/duckdb/src/planner/binder/expression/bind_between_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_between_expression.cpp index 09c92dd48..fe7e34e53 100644 --- a/src/duckdb/src/planner/binder/expression/bind_between_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_between_expression.cpp @@ -32,7 +32,6 @@ BindResult ExpressionBinder::BindExpression(BetweenExpression &expr, idx_t depth LogicalType input_type; if (!BoundComparisonExpression::TryBindComparison(context, input_sql_type, lower_sql_type, input_type, expr.GetExpressionType())) { - throw BinderException(expr, "Cannot mix values of type %s and %s in BETWEEN clause - an explicit cast is required", input_sql_type.ToString(), lower_sql_type.ToString()); diff --git a/src/duckdb/src/planner/binder/expression/bind_columnref_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_columnref_expression.cpp index 886a1ff42..0c0d982d5 100644 --- a/src/duckdb/src/planner/binder/expression/bind_columnref_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_columnref_expression.cpp @@ -94,12 +94,12 @@ unique_ptr ExpressionBinder::QualifyColumnName(const string &c // bind as a macro column if (is_macro_column) { - return binder.bind_context.CreateColumnReference(binder.macro_binding->alias, column_name); + return binder.bind_context.CreateColumnReference(binder.macro_binding->GetBindingAlias(), column_name); } // bind as a regular column if (table_binding) { - return binder.bind_context.CreateColumnReference(table_binding->alias, column_name); + return binder.bind_context.CreateColumnReference(table_binding->GetBindingAlias(), column_name); } // it's not, find candidates and error @@ -111,7 +111,6 @@ unique_ptr ExpressionBinder::QualifyColumnName(const string &c void ExpressionBinder::QualifyColumnNames(unique_ptr &expr, vector> &lambda_params, const bool within_function_expression) { - bool next_within_function_expression = false; switch (expr->GetExpressionType()) { case ExpressionType::COLUMN_REF: { @@ -177,7 +176,6 @@ void ExpressionBinder::QualifyColumnNames(unique_ptr &expr, void ExpressionBinder::QualifyColumnNamesInLambda(FunctionExpression &function, vector> &lambda_params) { - for (auto &child : function.children) { if (child->GetExpressionClass() != ExpressionClass::LAMBDA) { // not a lambda expression @@ -228,7 +226,6 @@ void ExpressionBinder::QualifyColumnNames(ExpressionBinder &expression_binder, u unique_ptr ExpressionBinder::CreateStructExtract(unique_ptr base, const string &field_name) { - vector> children; children.push_back(std::move(base)); children.push_back(make_uniq_base(Value(field_name))); @@ -276,11 +273,12 @@ unique_ptr ExpressionBinder::CreateStructPack(ColumnRefExpress } // We found the table, now create the struct_pack expression + auto &column_names = binding->GetColumnNames(); vector> child_expressions; - child_expressions.reserve(binding->names.size()); - for (const auto &column_name : binding->names) { + child_expressions.reserve(column_names.size()); + for (const auto &column_name : column_names) { child_expressions.push_back(binder.bind_context.CreateColumnReference( - binding->alias, column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS)); + 
binding->GetBindingAlias(), column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS)); } return make_uniq("struct_pack", std::move(child_expressions)); } @@ -312,7 +310,7 @@ unique_ptr ExpressionBinder::QualifyColumnNameWithManyDotsInte if (binding) { // part1 is a catalog - the column reference is "catalog.schema.table.column" struct_extract_start = 4; - return binder.bind_context.CreateColumnReference(binding->alias, col_ref.column_names[3]); + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.column_names[3]); } } ErrorData catalog_table_error; @@ -321,7 +319,7 @@ unique_ptr ExpressionBinder::QualifyColumnNameWithManyDotsInte if (binding) { // part1 is a catalog - the column reference is "catalog.table.column" struct_extract_start = 3; - return binder.bind_context.CreateColumnReference(binding->alias, col_ref.column_names[2]); + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.column_names[2]); } ErrorData schema_table_error; binding = binder.GetMatchingBinding(col_ref.column_names[0], col_ref.column_names[1], col_ref.column_names[2], @@ -330,7 +328,7 @@ unique_ptr ExpressionBinder::QualifyColumnNameWithManyDotsInte // part1 is a schema - the column reference is "schema.table.column" // any additional fields are turned into struct_extract calls struct_extract_start = 3; - return binder.bind_context.CreateColumnReference(binding->alias, col_ref.column_names[2]); + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.column_names[2]); } ErrorData table_column_error; binding = binder.GetMatchingBinding(col_ref.column_names[0], col_ref.column_names[1], table_column_error); @@ -339,7 +337,7 @@ unique_ptr ExpressionBinder::QualifyColumnNameWithManyDotsInte // the column reference is "table.column" // any additional fields are turned into struct_extract calls struct_extract_start = 2; - return binder.bind_context.CreateColumnReference(binding->alias, col_ref.column_names[1]); + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.column_names[1]); } // part1 could be a column ErrorData unused_error; @@ -360,7 +358,7 @@ unique_ptr ExpressionBinder::QualifyColumnNameWithManyDotsInte optional_idx schema_pos; optional_idx table_pos; for (const auto &binding_entry : binder.bind_context.GetBindingsList()) { - auto &alias = binding_entry->alias; + auto &alias = binding_entry->GetBindingAlias(); string catalog = alias.GetCatalog(); string schema = alias.GetSchema(); string table = alias.GetAlias(); @@ -483,7 +481,7 @@ unique_ptr ExpressionBinder::QualifyColumnName(ColumnRefExpres auto binding = binder.GetMatchingBinding(col_ref.column_names[0], col_ref.column_names[1], error); if (binding) { // it is! 
return the column reference directly - return binder.bind_context.CreateColumnReference(binding->alias, col_ref.GetColumnName()); + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.GetColumnName()); } // otherwise check if we can turn this into a struct extract diff --git a/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp index e0d775db1..6befdf0ed 100644 --- a/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp @@ -195,7 +195,7 @@ BindResult ExpressionBinder::BindFunction(FunctionExpression &function, ScalarFu } if (result->GetExpressionType() == ExpressionType::BOUND_FUNCTION) { auto &bound_function = result->Cast(); - if (bound_function.function.stability == FunctionStability::CONSISTENT_WITHIN_QUERY) { + if (bound_function.function.GetStability() == FunctionStability::CONSISTENT_WITHIN_QUERY) { binder.SetAlwaysRequireRebind(); } } @@ -204,7 +204,6 @@ BindResult ExpressionBinder::BindFunction(FunctionExpression &function, ScalarFu BindResult ExpressionBinder::BindLambdaFunction(FunctionExpression &function, ScalarFunctionCatalogEntry &func, idx_t depth) { - // get the callback function for the lambda parameter types auto &scalar_function = func.functions.functions.front(); auto &bind_lambda_function = scalar_function.bind_lambda; @@ -302,13 +301,14 @@ BindResult ExpressionBinder::BindLambdaFunction(FunctionExpression &function, Sc idx_t offset = 0; if (lambda_bindings) { for (idx_t i = lambda_bindings->size(); i > 0; i--) { - auto &binding = (*lambda_bindings)[i - 1]; - D_ASSERT(binding.names.size() == binding.types.size()); + auto &column_names = binding.GetColumnNames(); + auto &column_types = binding.GetColumnTypes(); + D_ASSERT(column_names.size() == column_types.size()); - for (idx_t column_idx = binding.names.size(); column_idx > 0; column_idx--) { - auto bound_lambda_param = make_uniq(binding.names[column_idx - 1], - binding.types[column_idx - 1], offset); + for (idx_t column_idx = column_names.size(); column_idx > 0; column_idx--) { + auto bound_lambda_param = make_uniq(column_names[column_idx - 1], + column_types[column_idx - 1], offset); offset++; bound_function_expr.children.push_back(std::move(bound_lambda_param)); } diff --git a/src/duckdb/src/planner/binder/expression/bind_lambda.cpp b/src/duckdb/src/planner/binder/expression/bind_lambda.cpp index 592daa245..0d6334fc4 100644 --- a/src/duckdb/src/planner/binder/expression/bind_lambda.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_lambda.cpp @@ -12,31 +12,30 @@ namespace duckdb { -idx_t GetLambdaParamCount(const vector &lambda_bindings) { +idx_t GetLambdaParamCount(vector &lambda_bindings) { idx_t count = 0; for (auto &binding : lambda_bindings) { - count += binding.names.size(); + count += binding.GetColumnCount(); } return count; } -idx_t GetLambdaParamIndex(const vector &lambda_bindings, const BoundLambdaExpression &bound_lambda_expr, +idx_t GetLambdaParamIndex(vector &lambda_bindings, const BoundLambdaExpression &bound_lambda_expr, const BoundLambdaRefExpression &bound_lambda_ref_expr) { D_ASSERT(bound_lambda_ref_expr.lambda_idx < lambda_bindings.size()); idx_t offset = 0; // count the remaining lambda parameters BEFORE the current lambda parameter, // as these will be in front of the current lambda parameter in the input chunk for (idx_t i = bound_lambda_ref_expr.lambda_idx + 1; i < 
lambda_bindings.size(); i++) { - offset += lambda_bindings[i].names.size(); + offset += lambda_bindings[i].GetColumnCount(); } - offset += - lambda_bindings[bound_lambda_ref_expr.lambda_idx].names.size() - bound_lambda_ref_expr.binding.column_index - 1; + offset += lambda_bindings[bound_lambda_ref_expr.lambda_idx].GetColumnCount() - + bound_lambda_ref_expr.binding.column_index - 1; offset += bound_lambda_expr.parameter_count; return offset; } void ExtractParameter(const ParsedExpression &expr, vector &column_names, vector &column_aliases) { - auto &column_ref = expr.Cast(); if (column_ref.IsQualified()) { throw BinderException(LambdaExpression::InvalidParametersErrorMessage()); @@ -47,7 +46,6 @@ void ExtractParameter(const ParsedExpression &expr, vector &column_names } void ExtractParameters(LambdaExpression &expr, vector &column_names, vector &column_aliases) { - // extract the lambda parameters, which are a single column // reference, or a list of column references (ROW function) string error_message; @@ -136,28 +134,26 @@ void ExpressionBinder::TransformCapturedLambdaColumn(unique_ptr &ori BoundLambdaExpression &bound_lambda_expr, const optional_ptr bind_lambda_function, const vector &function_child_types) { - // check if the original expression is a lambda parameter if (original->GetExpressionClass() == ExpressionClass::BOUND_LAMBDA_REF) { - auto &bound_lambda_ref = original->Cast(); auto alias = bound_lambda_ref.GetAlias(); // refers to a lambda parameter outside the current lambda function // so the lambda parameter will be inside the lambda_bindings if (lambda_bindings && bound_lambda_ref.lambda_idx != lambda_bindings->size()) { - auto &binding = (*lambda_bindings)[bound_lambda_ref.lambda_idx]; - D_ASSERT(binding.names.size() == binding.types.size()); + auto &column_names = binding.GetColumnNames(); + auto &column_types = binding.GetColumnTypes(); + D_ASSERT(column_names.size() == column_types.size()); // find the matching dummy column in the lambda binding - for (idx_t column_idx = 0; column_idx < binding.names.size(); column_idx++) { + for (idx_t column_idx = 0; column_idx < binding.GetColumnCount(); column_idx++) { if (column_idx == bound_lambda_ref.binding.column_index) { - // now create the replacement auto index = GetLambdaParamIndex(*lambda_bindings, bound_lambda_expr, bound_lambda_ref); - replacement = make_uniq(binding.names[column_idx], - binding.types[column_idx], index); + replacement = + make_uniq(column_names[column_idx], column_types[column_idx], index); return; } } @@ -188,7 +184,6 @@ void ExpressionBinder::TransformCapturedLambdaColumn(unique_ptr &ori void ExpressionBinder::CaptureLambdaColumns(BoundLambdaExpression &bound_lambda_expr, unique_ptr &expr, const optional_ptr bind_lambda_function, const vector &function_child_types) { - if (expr->GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY) { throw BinderException("subqueries in lambda expressions are not supported"); } @@ -206,7 +201,6 @@ void ExpressionBinder::CaptureLambdaColumns(BoundLambdaExpression &bound_lambda_ if (expr->GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF || expr->GetExpressionClass() == ExpressionClass::BOUND_PARAMETER || expr->GetExpressionClass() == ExpressionClass::BOUND_LAMBDA_REF) { - if (expr->GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { // Search for UNNEST. 
auto &column_binding = expr->Cast().binding; diff --git a/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp index cce06d712..fccf527ff 100644 --- a/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp @@ -11,7 +11,6 @@ namespace duckdb { void ExpressionBinder::ReplaceMacroParametersInLambda(FunctionExpression &function, vector> &lambda_params) { - for (auto &child : function.children) { if (child->GetExpressionClass() != ExpressionClass::LAMBDA) { ReplaceMacroParameters(child, lambda_params); @@ -47,7 +46,6 @@ void ExpressionBinder::ReplaceMacroParametersInLambda(FunctionExpression &functi void ExpressionBinder::ReplaceMacroParameters(unique_ptr &expr, vector> &lambda_params) { - switch (expr->GetExpressionClass()) { case ExpressionClass::COLUMN_REF: { // If the expression is a column reference, we replace it with its argument. @@ -98,6 +96,7 @@ void ExpressionBinder::UnfoldMacroExpression(FunctionExpression &function, Scala // validate the arguments and separate positional and default arguments vector> positional_arguments; InsertionOrderPreservingMap> named_arguments; + binder.lambda_bindings = lambda_bindings; auto bind_result = MacroFunction::BindMacroFunction(binder, macro_func.macros, macro_func.name, function, positional_arguments, named_arguments, depth); if (!bind_result.error.empty()) { diff --git a/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp index 109c0ecbd..3fe02467e 100644 --- a/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp @@ -8,19 +8,19 @@ namespace duckdb { BindResult ExpressionBinder::BindExpression(ParameterExpression &expr, idx_t depth) { - if (!binder.parameters) { + auto parameters = binder.GetParameters(); + if (!parameters) { throw BinderException("Unexpected prepared parameter. This type of statement can't be prepared!"); } auto parameter_id = expr.identifier; - D_ASSERT(binder.parameters); // Check if a parameter value has already been supplied - auto ¶meter_data = binder.parameters->GetParameterData(); + auto ¶meter_data = parameters->GetParameterData(); auto param_data_it = parameter_data.find(parameter_id); if (param_data_it != parameter_data.end()) { // it has! 
emit a constant directly auto &data = param_data_it->second; - auto return_type = binder.parameters->GetReturnType(parameter_id); + auto return_type = parameters->GetReturnType(parameter_id); bool is_literal = return_type.id() == LogicalTypeId::INTEGER_LITERAL || return_type.id() == LogicalTypeId::STRING_LITERAL; auto constant = make_uniq(data.GetValue()); @@ -32,7 +32,7 @@ BindResult ExpressionBinder::BindExpression(ParameterExpression &expr, idx_t dep return BindResult(std::move(cast)); } - auto bound_parameter = binder.parameters->BindParameterExpression(expr); + auto bound_parameter = parameters->BindParameterExpression(expr); return BindResult(std::move(bound_parameter)); } diff --git a/src/duckdb/src/planner/binder/expression/bind_star_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_star_expression.cpp index f48fc14e6..4cc0e3a23 100644 --- a/src/duckdb/src/planner/binder/expression/bind_star_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_star_expression.cpp @@ -152,10 +152,15 @@ string Binder::ReplaceColumnsAlias(const string &alias, const string &column_nam void TryTransformStarLike(unique_ptr &root) { // detect "* LIKE [literal]" and similar expressions - if (root->GetExpressionClass() != ExpressionClass::FUNCTION) { + bool inverse = root->GetExpressionType() == ExpressionType::OPERATOR_NOT; + auto &expr = inverse ? root->Cast().children[0] : root; + if (!expr) { return; } - auto &function = root->Cast(); + if (expr->GetExpressionClass() != ExpressionClass::FUNCTION) { + return; + } + auto &function = expr->Cast(); if (function.children.size() < 2 || function.children.size() > 3) { return; } @@ -197,7 +202,7 @@ void TryTransformStarLike(unique_ptr &root) { auto original_alias = root->GetAlias(); auto star_expr = std::move(left); unique_ptr child_expr; - if (function.function_name == "regexp_full_match" && star.exclude_list.empty()) { + if (!inverse && function.function_name == "regexp_full_match" && star.exclude_list.empty()) { // * SIMILAR TO '[regex]' is equivalent to COLUMNS('[regex]') so we can just move the expression directly child_expr = std::move(right); } else { @@ -207,13 +212,20 @@ void TryTransformStarLike(unique_ptr &root) { vector named_parameters; named_parameters.push_back("__lambda_col"); function.children[0] = make_uniq("__lambda_col"); + function.children[1] = std::move(right); + + unique_ptr lambda_body = std::move(expr); + if (inverse) { + vector> root_children; + root_children.push_back(std::move(lambda_body)); + lambda_body = make_uniq(ExpressionType::OPERATOR_NOT, std::move(root_children)); + } + auto lambda = make_uniq(std::move(named_parameters), std::move(lambda_body)); - auto lambda = make_uniq(std::move(named_parameters), std::move(root)); vector> filter_children; filter_children.push_back(std::move(star_expr)); filter_children.push_back(std::move(lambda)); - auto list_filter = make_uniq("list_filter", std::move(filter_children)); - child_expr = std::move(list_filter); + child_expr = make_uniq("list_filter", std::move(filter_children)); } auto columns_expr = make_uniq(); diff --git a/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp index d413c88ed..7f03f0e32 100644 --- a/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp @@ -13,20 +13,16 @@ class BoundSubqueryNode : public QueryNode { static constexpr const QueryNodeType TYPE = 
QueryNodeType::BOUND_SUBQUERY_NODE; public: - BoundSubqueryNode(shared_ptr subquery_binder, unique_ptr bound_node, + BoundSubqueryNode(shared_ptr subquery_binder, BoundStatement bound_node, unique_ptr subquery) : QueryNode(QueryNodeType::BOUND_SUBQUERY_NODE), subquery_binder(std::move(subquery_binder)), bound_node(std::move(bound_node)), subquery(std::move(subquery)) { } shared_ptr subquery_binder; - unique_ptr bound_node; + BoundStatement bound_node; unique_ptr subquery; - const vector> &GetSelectList() const override { - throw InternalException("Cannot get select list of bound subquery node"); - } - string ToString() const override { throw InternalException("Cannot ToString bound subquery node"); } @@ -116,15 +112,15 @@ BindResult ExpressionBinder::BindExpression(SubqueryExpression &expr, idx_t dept idx_t expected_columns = 1; if (expr.child) { auto &child = BoundExpression::GetExpression(*expr.child); - ExtractSubqueryChildren(child, child_expressions, bound_subquery.bound_node->types); + ExtractSubqueryChildren(child, child_expressions, bound_subquery.bound_node.types); if (child_expressions.empty()) { child_expressions.push_back(std::move(child)); } expected_columns = child_expressions.size(); } - if (bound_subquery.bound_node->types.size() != expected_columns) { + if (bound_subquery.bound_node.types.size() != expected_columns) { throw BinderException(expr, "Subquery returns %zu columns - expected %d", - bound_subquery.bound_node->types.size(), expected_columns); + bound_subquery.bound_node.types.size(), expected_columns); } } // both binding the child and binding the subquery was successful @@ -132,7 +128,7 @@ BindResult ExpressionBinder::BindExpression(SubqueryExpression &expr, idx_t dept auto subquery_binder = std::move(bound_subquery.subquery_binder); auto bound_node = std::move(bound_subquery.bound_node); LogicalType return_type = - expr.subquery_type == SubqueryType::SCALAR ? bound_node->types[0] : LogicalType(LogicalTypeId::BOOLEAN); + expr.subquery_type == SubqueryType::SCALAR ? 
bound_node.types[0] : LogicalType(LogicalTypeId::BOOLEAN); if (return_type.id() == LogicalTypeId::UNKNOWN) { return_type = LogicalType::SQLNULL; } @@ -144,7 +140,7 @@ BindResult ExpressionBinder::BindExpression(SubqueryExpression &expr, idx_t dept for (idx_t child_idx = 0; child_idx < child_expressions.size(); child_idx++) { auto &child = child_expressions[child_idx]; auto child_type = ExpressionBinder::GetExpressionReturnType(*child); - auto &subquery_type = bound_node->types[child_idx]; + auto &subquery_type = bound_node.types[child_idx]; LogicalType compare_type; if (!LogicalType::TryGetMaxLogicalType(context, child_type, subquery_type, compare_type)) { throw BinderException( diff --git a/src/duckdb/src/planner/binder/expression/bind_unnest_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_unnest_expression.cpp index fedcb8257..8be35d798 100644 --- a/src/duckdb/src/planner/binder/expression/bind_unnest_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_unnest_expression.cpp @@ -15,19 +15,30 @@ #include "duckdb/function/scalar/nested_functions.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { -unique_ptr CreateBoundStructExtract(ClientContext &context, unique_ptr expr, string key) { +unique_ptr CreateBoundStructExtract(ClientContext &context, unique_ptr expr, + const vector &key_path, bool keep_parent_names) { vector> arguments; arguments.push_back(std::move(expr)); - arguments.push_back(make_uniq(Value(key))); + arguments.push_back(make_uniq(Value(key_path.back()))); auto extract_function = GetKeyExtractFunction(); auto bind_info = extract_function.bind(context, extract_function, arguments); - auto return_type = extract_function.return_type; + auto return_type = extract_function.GetReturnType(); auto result = make_uniq(return_type, std::move(extract_function), std::move(arguments), std::move(bind_info)); - result->SetAlias(std::move(key)); + + if (keep_parent_names) { + auto alias = StringUtil::Join(key_path, "."); + if (!alias.empty() && alias[0] == '.') { + alias = alias.substr(1); + } + result->SetAlias(alias); + } else { + result->SetAlias(key_path[0]); + } return std::move(result); } @@ -37,7 +48,7 @@ unique_ptr CreateBoundStructExtractIndex(ClientContext &context, uni arguments.push_back(make_uniq(Value::BIGINT(int64_t(key)))); auto extract_function = GetIndexExtractFunction(); auto bind_info = extract_function.bind(context, extract_function, arguments); - auto return_type = extract_function.return_type; + auto return_type = extract_function.GetReturnType(); auto result = make_uniq(return_type, std::move(extract_function), std::move(arguments), std::move(bind_info)); result->SetAlias("element" + to_string(key)); @@ -65,7 +76,7 @@ BindResult SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, b ErrorData error; if (function.children.empty()) { - return BindResult(BinderException(function, "UNNEST() requires a single argument")); + return BindResult(BinderException(function, "UNNEST() requires at least one argument")); } if (inside_window) { return BindResult(BinderException(function, UnsupportedUnnestMessage())); } @@ -77,13 +88,10 @@ BindResult SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, b } idx_t max_depth = 1; + bool keep_parent_names = false; if (function.children.size() != 1) { - bool has_parameter = false; bool supported_argument = false; for (idx_t i = 1; i < function.children.size();
i++) { - if (has_parameter) { - return BindResult(BinderException(function, "UNNEST() only supports a single additional argument")); - } if (function.children[i]->HasParameter()) { throw ParameterNotAllowedException("Parameter not allowed in unnest parameter"); } @@ -107,17 +115,19 @@ BindResult SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, b if (max_depth == 0) { throw BinderException("UNNEST cannot have a max depth of 0"); } + } else if (alias == "keep_parent_names") { + keep_parent_names = value.GetValue(); } else if (!alias.empty()) { throw BinderException("Unsupported parameter \"%s\" for unnest", alias); } else { break; } - has_parameter = true; supported_argument = true; } if (!supported_argument) { - return BindResult(BinderException(function, "UNNEST - unsupported extra argument, unnest only supports " - "recursive := [true/false] or max_depth := #")); + return BindResult(BinderException( + function, "UNNEST - unsupported extra argument, unnest only supports " + "recursive := [true/false], max_depth := # or keep_parent_names := [true/false]")); } } unnest_level++; @@ -216,7 +226,6 @@ BindResult SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, b if (struct_unnests > 0) { vector> struct_expressions; struct_expressions.push_back(std::move(unnest_expr)); - for (idx_t i = 0; i < struct_unnests; i++) { vector> new_expressions; // check if there are any structs left @@ -232,7 +241,14 @@ BindResult SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, b } } else { for (auto &entry : child_types) { - new_expressions.push_back(CreateBoundStructExtract(context, expr->Copy(), entry.first)); + vector current_key_path; + // During recursive expansion, not all expressions are BoundFunctionExpression + if (keep_parent_names && expr->type == ExpressionType::BOUND_FUNCTION) { + current_key_path.push_back(expr->alias); + } + current_key_path.push_back(entry.first); + new_expressions.push_back( + CreateBoundStructExtract(context, expr->Copy(), current_key_path, keep_parent_names)); } } has_structs = true; diff --git a/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp index 4b950a29c..00b7bcded 100644 --- a/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp @@ -17,7 +17,6 @@ namespace duckdb { static LogicalType ResolveWindowExpressionType(ExpressionType window_type, const vector &child_types) { - idx_t param_count; switch (window_type) { case ExpressionType::WINDOW_RANK: @@ -115,7 +114,6 @@ static bool IsFillType(const LogicalType &type) { static LogicalType BindRangeExpression(ClientContext &context, const string &name, unique_ptr &expr, unique_ptr &order_expr) { - vector> children; D_ASSERT(order_expr.get()); diff --git a/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp index 2a7cf8346..8cea37ebc 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp @@ -1,93 +1,169 @@ -#include "duckdb/parser/expression/constant_expression.hpp" -#include "duckdb/parser/expression_map.hpp" -#include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/query_node/cte_node.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/query_node/bound_cte_node.hpp" -#include 
"duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/operator/logical_materialized_cte.hpp" +#include "duckdb/parser/query_node/list.hpp" +#include "duckdb/parser/statement/select_statement.hpp" namespace duckdb { -unique_ptr Binder::BindNode(CTENode &statement) { - // first recursively visit the materialized CTE operations - // the left side is visited first and is added to the BindContext of the right side - D_ASSERT(statement.query); +struct BoundCTEData { + string ctename; + CTEMaterialize materialized; + idx_t setop_index; + shared_ptr child_binder; + shared_ptr cte_bind_state; +}; + +BoundStatement Binder::BindNode(QueryNode &node) { + reference current_binder(*this); + vector bound_ctes; + for (auto &cte : node.cte_map.map) { + bound_ctes.push_back(current_binder.get().PrepareCTE(cte.first, *cte.second)); + current_binder = *bound_ctes.back().child_binder; + } + BoundStatement result; + // now we bind the node + switch (node.type) { + case QueryNodeType::SELECT_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; + case QueryNodeType::RECURSIVE_CTE_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; + case QueryNodeType::SET_OPERATION_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; + case QueryNodeType::STATEMENT_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; + default: + throw InternalException("Unsupported query node type"); + } + for (idx_t i = bound_ctes.size(); i > 0; i--) { + auto &finish_binder = i == 1 ? *this : *bound_ctes[i - 2].child_binder; + result = finish_binder.FinishCTE(bound_ctes[i - 1], std::move(result)); + } + return result; +} - return BindCTE(statement); +CTEBindState::CTEBindState(Binder &parent_binder_p, QueryNode &cte_def_p, const vector &aliases_p) + : parent_binder(parent_binder_p), cte_def(cte_def_p), aliases(aliases_p), + active_binder_count(parent_binder.GetActiveBinders().size()) { } -unique_ptr Binder::BindCTE(CTENode &statement) { - auto result = make_uniq(); +CTEBindState::~CTEBindState() { +} - // first recursively visit the materialized CTE operations - // the left side is visited first and is added to the BindContext of the right side - D_ASSERT(statement.query); +bool CTEBindState::IsBound() const { + return query_binder.get() != nullptr; +} + +void CTEBindState::Bind(CTEBinding &binding) { + // we are lazily binding the CTE + // we need to bind it as if we were binding it during PrepareCTE + query_binder = Binder::CreateBinder(parent_binder.context, parent_binder); + + // we clear any expression binders that were added in the mean-time, to ensure we are not binding to any newly added + // correlated columns + auto &active_binders = parent_binder.GetActiveBinders(); + vector> stored_binders; + for (idx_t i = active_binder_count; i < active_binders.size(); i++) { + stored_binders.push_back(active_binders[i]); + } + active_binders.erase(active_binders.begin() + UnsafeNumericCast(active_binder_count), + active_binders.end()); - result->ctename = statement.ctename; - result->materialized = statement.materialized; - result->setop_index = GenerateTableIndex(); + // add this CTE to the query binder on the RHS with "CANNOT_BE_REFERENCED" to detect recursive references to + // ourselves + query_binder->bind_context.AddCTEBinding(binding.GetIndex(), binding.GetBindingAlias(), vector(), + vector(), CTEType::CANNOT_BE_REFERENCED); - AddCTE(result->ctename); + // bind the actual CTE + query = query_binder->Bind(cte_def); - result->query_binder = 
Binder::CreateBinder(context, this); - result->query = result->query_binder->BindNode(*statement.query); + // after binding - we add the active binders we removed back so we can leave the binder in its original state + for (auto &stored_binder : stored_binders) { + active_binders.push_back(stored_binder); + } // the result types of the CTE are the types of the LHS - result->types = result->query->types; + types = query.types; // names are picked from the LHS, unless aliases are explicitly specified - result->names = result->query->names; - for (idx_t i = 0; i < statement.aliases.size() && i < result->names.size(); i++) { - result->names[i] = statement.aliases[i]; + names = query.names; + for (idx_t i = 0; i < aliases.size() && i < names.size(); i++) { + names[i] = aliases[i]; } // Rename columns if duplicate names are detected idx_t index = 1; - vector names; + vector new_names; // Use a case-insensitive set to track names case_insensitive_set_t ci_names; - for (auto &n : result->names) { + for (auto &n : names) { string name = n; while (ci_names.find(name) != ci_names.end()) { name = n + "_" + std::to_string(index++); } - names.push_back(name); + new_names.push_back(name); ci_names.insert(name); } + names = std::move(new_names); +} + +BoundCTEData Binder::PrepareCTE(const string &ctename, CommonTableExpressionInfo &statement) { + BoundCTEData result; + + // first recursively visit the materialized CTE operations + // the left side is visited first and is added to the BindContext of the right side + D_ASSERT(statement.query); - // This allows the right side to reference the CTE - bind_context.AddGenericBinding(result->setop_index, statement.ctename, names, result->types); + result.ctename = ctename; + result.materialized = statement.materialized; + result.setop_index = GenerateTableIndex(); - result->child_binder = Binder::CreateBinder(context, this); + // instead of eagerly binding the CTE here we add the CTE bind state to the list of CTE bindings + // the CTE is bound lazily - when referenced for the first time we perform the binding + result.cte_bind_state = make_shared_ptr(*this, *statement.query->node, statement.aliases); + + result.child_binder = Binder::CreateBinder(context, this); // Add bindings of left side to temporary CTE bindings context - // If there is already a binding for the CTE, we need to remove it first // as we are binding a CTE currently, we take precendence over the existing binding. // This implements the CTE shadowing behavior. - result->child_binder->bind_context.RemoveCTEBinding(statement.ctename); - result->child_binder->bind_context.AddCTEBinding(result->setop_index, statement.ctename, names, result->types); - - if (statement.child) { - // Move all modifiers to the child node. 
- for (auto &modifier : statement.modifiers) { - statement.child->modifiers.push_back(std::move(modifier)); - } + auto cte_binding = make_uniq(BindingAlias(ctename), result.cte_bind_state, result.setop_index); + result.child_binder->bind_context.AddCTEBinding(std::move(cte_binding)); + return result; +} - statement.modifiers.clear(); +BoundStatement Binder::FinishCTE(BoundCTEData &bound_cte, BoundStatement child) { + if (!bound_cte.cte_bind_state->IsBound()) { + // CTE was not bound - just ignore it + return child; + } + auto &bind_state = *bound_cte.cte_bind_state; + for (auto &c : bind_state.query_binder->correlated_columns) { + bound_cte.child_binder->AddCorrelatedColumn(c); + } - result->child = result->child_binder->BindNode(*statement.child); - for (auto &c : result->query_binder->correlated_columns) { - result->child_binder->AddCorrelatedColumn(c); - } + BoundStatement result; + // the result types of the CTE are the types of the LHS + result.types = child.types; + result.names = child.names; - // the result types of the CTE are the types of the LHS - result->types = result->child->types; - result->names = result->child->names; + MoveCorrelatedExpressions(*bound_cte.child_binder); + MoveCorrelatedExpressions(*bind_state.query_binder); - MoveCorrelatedExpressions(*result->child_binder); - } + auto cte_query = std::move(bind_state.query.plan); + auto cte_child = std::move(child.plan); - MoveCorrelatedExpressions(*result->query_binder); + auto root = make_uniq(bound_cte.ctename, bound_cte.setop_index, result.types.size(), + std::move(cte_query), std::move(cte_child), bound_cte.materialized); + // check if there are any unplanned subqueries left in either child + has_unplanned_dependent_joins = has_unplanned_dependent_joins || + bound_cte.child_binder->has_unplanned_dependent_joins || + bind_state.query_binder->has_unplanned_dependent_joins; + result.plan = std::move(root); return result; } diff --git a/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp index 54e9e9fa5..efc9740de 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp @@ -3,14 +3,12 @@ #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/query_node/recursive_cte_node.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" -#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/operator/logical_set_operation.hpp" +#include "duckdb/planner/operator/logical_recursive_cte.hpp" namespace duckdb { -unique_ptr Binder::BindNode(RecursiveCTENode &statement) { - auto result = make_uniq(); - +BoundStatement Binder::BindNode(RecursiveCTENode &statement) { // first recursively visit the recursive CTE operations // the left side is visited first and is added to the BindContext of the right side D_ASSERT(statement.left); @@ -19,53 +17,55 @@ unique_ptr Binder::BindNode(RecursiveCTENode &statement) { throw BinderException("UNION ALL cannot be used with USING KEY in recursive CTE."); } - result->ctename = statement.ctename; - result->union_all = statement.union_all; - result->setop_index = GenerateTableIndex(); + auto ctename = statement.ctename; + auto union_all = statement.union_all; + auto setop_index = GenerateTableIndex(); - result->left_binder = Binder::CreateBinder(context, this); - result->left = 
result->left_binder->BindNode(*statement.left); + auto left_binder = Binder::CreateBinder(context, this); + auto left = left_binder->BindNode(*statement.left); + BoundStatement result; // the result types of the CTE are the types of the LHS - result->types = result->left->types; + result.types = left.types; // names are picked from the LHS, unless aliases are explicitly specified - result->names = result->left->names; - for (idx_t i = 0; i < statement.aliases.size() && i < result->names.size(); i++) { - result->names[i] = statement.aliases[i]; + result.names = left.names; + for (idx_t i = 0; i < statement.aliases.size() && i < result.names.size(); i++) { + result.names[i] = statement.aliases[i]; } // This allows the right side to reference the CTE recursively - bind_context.AddGenericBinding(result->setop_index, statement.ctename, result->names, result->types); + bind_context.AddGenericBinding(setop_index, statement.ctename, result.names, result.types); - result->right_binder = Binder::CreateBinder(context, this); + auto right_binder = Binder::CreateBinder(context, this); // Add bindings of left side to temporary CTE bindings context - // If there is already a binding for the CTE, we need to remove it first - // as we are binding a CTE currently, we take precendence over the existing binding. - // This implements the CTE shadowing behavior. - result->right_binder->bind_context.RemoveCTEBinding(statement.ctename); - result->right_binder->bind_context.AddCTEBinding(result->setop_index, statement.ctename, result->names, - result->types, !statement.key_targets.empty()); - - result->right = result->right_binder->BindNode(*statement.right); - for (auto &c : result->left_binder->correlated_columns) { - result->right_binder->AddCorrelatedColumn(c); + BindingAlias cte_alias(statement.ctename); + right_binder->bind_context.AddCTEBinding(setop_index, std::move(cte_alias), result.names, result.types); + if (!statement.key_targets.empty()) { + BindingAlias recurring_alias("recurring", statement.ctename); + right_binder->bind_context.AddCTEBinding(setop_index, std::move(recurring_alias), result.names, result.types); + } + + auto right = right_binder->BindNode(*statement.right); + for (auto &c : left_binder->correlated_columns) { + right_binder->AddCorrelatedColumn(c); } // move the correlated expressions from the child binders to this binder - MoveCorrelatedExpressions(*result->left_binder); - MoveCorrelatedExpressions(*result->right_binder); + MoveCorrelatedExpressions(*left_binder); + MoveCorrelatedExpressions(*right_binder); + vector> key_targets; // bind specified keys to the referenced column auto expression_binder = ExpressionBinder(*this, context); - for (unique_ptr &expr : statement.key_targets) { + for (auto &expr : statement.key_targets) { auto bound_expr = expression_binder.Bind(expr); D_ASSERT(bound_expr->type == ExpressionType::BOUND_COLUMN_REF); - result->key_targets.push_back(std::move(bound_expr)); + key_targets.push_back(std::move(bound_expr)); } // now both sides have been bound we can resolve types - if (result->left->types.size() != result->right->types.size()) { + if (left.types.size() != right.types.size()) { throw BinderException("Set operations can only apply to expressions with the " "same number of result columns"); } @@ -74,7 +74,42 @@ unique_ptr Binder::BindNode(RecursiveCTENode &statement) { throw NotImplementedException("FIXME: bind modifiers in recursive CTE"); } - return std::move(result); + // Generate the logical plan for the left and right sides of the set operation + 
left_binder->is_outside_flattened = is_outside_flattened; + right_binder->is_outside_flattened = is_outside_flattened; + + auto left_node = std::move(left.plan); + auto right_node = std::move(right.plan); + + // check if there are any unplanned subqueries left in either child + has_unplanned_dependent_joins = has_unplanned_dependent_joins || left_binder->has_unplanned_dependent_joins || + right_binder->has_unplanned_dependent_joins; + + // for both the left and right sides, cast them to the same types + left_node = CastLogicalOperatorToTypes(left.types, result.types, std::move(left_node)); + right_node = CastLogicalOperatorToTypes(right.types, result.types, std::move(right_node)); + + auto recurring_binding = right_binder->GetCTEBinding(BindingAlias("recurring", ctename)); + bool ref_recurring = recurring_binding && recurring_binding->IsReferenced(); + if (key_targets.empty() && ref_recurring) { + throw InvalidInputException("RECURRING can only be used with USING KEY in recursive CTE."); + } + + // Check if there is a reference to the recursive or recurring table, if not create a set operator. + auto cte_binding = right_binder->GetCTEBinding(BindingAlias(ctename)); + bool ref_cte = cte_binding && cte_binding->IsReferenced(); + if (!ref_cte && !ref_recurring) { + auto root = + make_uniq(setop_index, result.types.size(), std::move(left_node), + std::move(right_node), LogicalOperatorType::LOGICAL_UNION, union_all); + result.plan = std::move(root); + } else { + auto root = make_uniq(ctename, setop_index, result.types.size(), union_all, + std::move(key_targets), std::move(left_node), std::move(right_node)); + root->ref_recurring = ref_recurring; + result.plan = std::move(root); + } + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp index 4f52dfc4a..44173573a 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp @@ -141,12 +141,27 @@ void Binder::PrepareModifiers(OrderBinder &order_binder, QueryNode &statement, B } } order_binder.SetQueryComponent("DISTINCT ON"); + auto &order_binders = order_binder.GetBinders(); for (auto &distinct_on_target : distinct.distinct_on_targets) { - auto expr = BindOrderExpression(order_binder, std::move(distinct_on_target)); - if (!expr) { - continue; + vector> target_list; + order_binders[0].get().ExpandStarExpression(std::move(distinct_on_target), target_list); + for (auto &target : target_list) { + auto expr = BindOrderExpression(order_binder, std::move(target)); + if (!expr) { + continue; + } + // Skip duplicates + bool duplicate = false; + for (auto &existing : bound_distinct->target_distincts) { + if (expr->Equals(*existing)) { + duplicate = true; + break; + } + } + if (!duplicate) { + bound_distinct->target_distincts.push_back(std::move(expr)); + } } - bound_distinct->target_distincts.push_back(std::move(expr)); } order_binder.SetQueryComponent(); @@ -154,7 +169,6 @@ void Binder::PrepareModifiers(OrderBinder &order_binder, QueryNode &statement, B break; } case ResultModifierType::ORDER_MODIFIER: { - auto &order = mod->Cast(); auto bound_order = make_uniq(); auto &config = DBConfig::GetConfig(context); @@ -363,7 +377,7 @@ void Binder::BindModifiers(BoundQueryNode &result, idx_t table_index, const vect } } -unique_ptr Binder::BindNode(SelectNode &statement) { +BoundStatement Binder::BindNode(SelectNode &statement) { D_ASSERT(statement.from_table); 
// first bind the FROM table statement @@ -403,21 +417,22 @@ void Binder::BindWhereStarExpression(unique_ptr &expr) { } } -unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ptr from_table) { - D_ASSERT(from_table); +BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from_table) { + D_ASSERT(from_table.plan); D_ASSERT(!statement.from_table); - auto result = make_uniq(); - result->projection_index = GenerateTableIndex(); - result->group_index = GenerateTableIndex(); - result->aggregate_index = GenerateTableIndex(); - result->groupings_index = GenerateTableIndex(); - result->window_index = GenerateTableIndex(); - result->prune_index = GenerateTableIndex(); - - result->from_table = std::move(from_table); + auto result_ptr = make_uniq(); + auto &result = *result_ptr; + result.projection_index = GenerateTableIndex(); + result.group_index = GenerateTableIndex(); + result.aggregate_index = GenerateTableIndex(); + result.groupings_index = GenerateTableIndex(); + result.window_index = GenerateTableIndex(); + result.prune_index = GenerateTableIndex(); + + result.from_table = std::move(from_table); // bind the sample clause if (statement.sample) { - result->sample_options = std::move(statement.sample); + result.sample_options = std::move(statement.sample); } // visit the select list and expand any "*" statements @@ -429,19 +444,19 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ } statement.select_list = std::move(new_select_list); - auto &bind_state = result->bind_state; + auto &bind_state = result.bind_state; for (idx_t i = 0; i < statement.select_list.size(); i++) { auto &expr = statement.select_list[i]; - result->names.push_back(expr->GetName()); + result.names.push_back(expr->GetName()); ExpressionBinder::QualifyColumnNames(*this, expr); if (!expr->GetAlias().empty()) { bind_state.alias_map[expr->GetAlias()] = i; - result->names[i] = expr->GetAlias(); + result.names[i] = expr->GetAlias(); } bind_state.projection_map[*expr] = i; bind_state.original_expressions.push_back(expr->Copy()); } - result->column_count = statement.select_list.size(); + result.column_count = statement.select_list.size(); // first visit the WHERE clause // the WHERE clause happens before the GROUP BY, PROJECTION or HAVING clauses @@ -452,12 +467,12 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ ColumnAliasBinder alias_binder(bind_state); WhereBinder where_binder(*this, context, &alias_binder); unique_ptr condition = std::move(statement.where_clause); - result->where_clause = where_binder.Bind(condition); + result.where_clause = where_binder.Bind(condition); } // now bind all the result modifiers; including DISTINCT and ORDER BY targets OrderBinder order_binder({*this}, statement, bind_state); - PrepareModifiers(order_binder, statement, *result); + PrepareModifiers(order_binder, statement, result); vector> unbound_groups; BoundGroupInformation info; @@ -465,9 +480,8 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ if (!group_expressions.empty()) { // the statement has a GROUP BY clause, bind it unbound_groups.resize(group_expressions.size()); - GroupBinder group_binder(*this, context, statement, result->group_index, bind_state, info.alias_map); + GroupBinder group_binder(*this, context, statement, result.group_index, bind_state, info.alias_map); for (idx_t i = 0; i < group_expressions.size(); i++) { - // we keep a copy of the unbound expression; // we keep the unbound copy around to check for group references in the SELECT and 
HAVING clause // the reason we want the unbound copy is because we want to figure out whether an expression @@ -489,7 +503,7 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ if (!contains_subquery && requires_collation) { // if there is a collation on a group x, we should group by the collated expr, // but also push a first(x) aggregate in case x is selected (uncollated) - info.collated_groups[i] = result->aggregates.size(); + info.collated_groups[i] = result.aggregates.size(); auto first_fun = FirstFunctionGetter::GetFunction(bound_expr_ref.return_type); vector> first_children; @@ -499,9 +513,9 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ FunctionBinder function_binder(*this); auto function = function_binder.BindAggregateFunction(first_fun, std::move(first_children)); function->SetAlias("__collated_group"); - result->aggregates.push_back(std::move(function)); + result.aggregates.push_back(std::move(function)); } - result->groups.group_expressions.push_back(std::move(bound_expr)); + result.groups.group_expressions.push_back(std::move(bound_expr)); // in the unbound expression we DO bind the table names of any ColumnRefs // we do this to make sure that "table.a" and "a" are treated the same @@ -512,13 +526,13 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ info.map[*unbound_groups[i]] = i; } } - result->groups.grouping_sets = std::move(statement.groups.grouping_sets); + result.groups.grouping_sets = std::move(statement.groups.grouping_sets); // bind the HAVING clause, if any if (statement.having) { - HavingBinder having_binder(*this, context, *result, info, statement.aggregate_handling); + HavingBinder having_binder(*this, context, result, info, statement.aggregate_handling); ExpressionBinder::QualifyColumnNames(having_binder, statement.having); - result->having = having_binder.Bind(statement.having); + result.having = having_binder.Bind(statement.having); } // bind the QUALIFY clause, if any @@ -527,9 +541,9 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ if (statement.aggregate_handling == AggregateHandling::FORCE_AGGREGATES) { throw BinderException("Combining QUALIFY with GROUP BY ALL is not supported yet"); } - QualifyBinder qualify_binder(*this, context, *result, info); + QualifyBinder qualify_binder(*this, context, result, info); ExpressionBinder::QualifyColumnNames(*this, statement.qualify); - result->qualify = qualify_binder.Bind(statement.qualify); + result.qualify = qualify_binder.Bind(statement.qualify); if (qualify_binder.HasBoundColumns()) { if (qualify_binder.BoundAggregates()) { throw BinderException("Cannot mix aggregates with non-aggregated columns!"); @@ -539,7 +553,7 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ } // after that, we bind to the SELECT list - SelectBinder select_binder(*this, context, *result, info); + SelectBinder select_binder(*this, context, result, info); // if we expand select-list expressions, e.g., via UNNEST, then we need to possibly // adjust the column index of the already bound ORDER BY modifiers, and not only set their types @@ -549,13 +563,13 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ for (idx_t i = 0; i < statement.select_list.size(); i++) { bool is_window = statement.select_list[i]->IsWindow(); - idx_t unnest_count = result->unnests.size(); + idx_t unnest_count = result.unnests.size(); LogicalType result_type; auto expr = select_binder.Bind(statement.select_list[i], &result_type, true); - bool is_original_column = i 
< result->column_count; + bool is_original_column = i < result.column_count; bool can_group_by_all = statement.aggregate_handling == AggregateHandling::FORCE_AGGREGATES && is_original_column; - result->bound_column_count++; + result.bound_column_count++; if (expr->GetExpressionType() == ExpressionType::BOUND_EXPANDED) { if (!is_original_column) { @@ -571,9 +585,9 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ for (auto &struct_expr : struct_expressions) { new_names.push_back(struct_expr->GetName()); - result->types.push_back(struct_expr->return_type); + result.types.push_back(struct_expr->return_type); internal_sql_types.push_back(struct_expr->return_type); - result->select_list.push_back(std::move(struct_expr)); + result.select_list.push_back(std::move(struct_expr)); } bind_state.AddExpandedColumn(struct_expressions.size()); continue; @@ -594,7 +608,7 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ if (is_window) { throw BinderException("Cannot group on a window clause"); } - if (result->unnests.size() > unnest_count) { + if (result.unnests.size() > unnest_count) { throw BinderException("Cannot group on an UNNEST or UNLIST clause"); } // we are forcing aggregates, and the node has columns bound @@ -602,10 +616,10 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ group_by_all_indexes.push_back(i); } - result->select_list.push_back(std::move(expr)); + result.select_list.push_back(std::move(expr)); if (is_original_column) { - new_names.push_back(std::move(result->names[i])); - result->types.push_back(result_type); + new_names.push_back(std::move(result.names[i])); + result.types.push_back(result_type); } internal_sql_types.push_back(result_type); @@ -617,31 +631,31 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ // push the GROUP BY ALL expressions into the group set for (auto &group_by_all_index : group_by_all_indexes) { - auto &expr = result->select_list[group_by_all_index]; + auto &expr = result.select_list[group_by_all_index]; auto group_ref = make_uniq( - expr->return_type, ColumnBinding(result->group_index, result->groups.group_expressions.size())); - result->groups.group_expressions.push_back(std::move(expr)); + expr->return_type, ColumnBinding(result.group_index, result.groups.group_expressions.size())); + result.groups.group_expressions.push_back(std::move(expr)); expr = std::move(group_ref); } set group_by_all_indexes_set; if (!group_by_all_indexes.empty()) { - idx_t num_set_indexes = result->groups.group_expressions.size(); + idx_t num_set_indexes = result.groups.group_expressions.size(); for (idx_t i = 0; i < num_set_indexes; i++) { group_by_all_indexes_set.insert(i); } - D_ASSERT(result->groups.grouping_sets.empty()); - result->groups.grouping_sets.push_back(group_by_all_indexes_set); + D_ASSERT(result.groups.grouping_sets.empty()); + result.groups.grouping_sets.push_back(group_by_all_indexes_set); } - result->column_count = new_names.size(); - result->names = std::move(new_names); - result->need_prune = result->select_list.size() > result->column_count; + result.column_count = new_names.size(); + result.names = std::move(new_names); + result.need_prune = result.select_list.size() > result.column_count; // in the normal select binder, we bind columns as if there is no aggregation // i.e. 
in the query [SELECT i, SUM(i) FROM integers;] the "i" will be bound as a normal column // since we have an aggregation, we need to either (1) throw an error, or (2) wrap the column in a FIRST() aggregate // we choose the former one [CONTROVERSIAL: this is the PostgreSQL behavior] - if (!result->groups.group_expressions.empty() || !result->aggregates.empty() || statement.having || - !result->groups.grouping_sets.empty()) { + if (!result.groups.group_expressions.empty() || !result.aggregates.empty() || statement.having || + !result.groups.grouping_sets.empty()) { if (statement.aggregate_handling == AggregateHandling::NO_AGGREGATES_ALLOWED) { throw BinderException("Aggregates cannot be present in a Project relation!"); } else { @@ -672,13 +686,19 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ // QUALIFY clause requires at least one window function to be specified in at least one of the SELECT column list or // the filter predicate of the QUALIFY clause - if (statement.qualify && result->windows.empty()) { + if (statement.qualify && result.windows.empty()) { throw BinderException("at least one window function must appear in the SELECT column or QUALIFY clause"); } // now that the SELECT list is bound, we set the types of DISTINCT/ORDER BY expressions - BindModifiers(*result, result->projection_index, result->names, internal_sql_types, bind_state); - return std::move(result); + BindModifiers(result, result.projection_index, result.names, internal_sql_types, bind_state); + + BoundStatement result_statement; + result_statement.types = result.types; + result_statement.names = result.names; + result_statement.plan = CreatePlan(result); + result_statement.extra_info.original_expressions = std::move(result.bind_state.original_expressions); + return result_statement; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp index 50c6b3c06..91a501b2f 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp @@ -10,89 +10,109 @@ #include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/planner/query_node/bound_set_operation_node.hpp" #include "duckdb/planner/expression_binder/select_bind_state.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/common/enum_util.hpp" namespace duckdb { -static void GatherAliases(BoundQueryNode &node, SelectBindState &bind_state, const vector &reorder_idx) { - if (node.type == QueryNodeType::SET_OPERATION_NODE) { - // setop, recurse - auto &setop = node.Cast(); +struct SetOpAliasGatherer { +public: + explicit SetOpAliasGatherer(SelectBindState &bind_state_p) : bind_state(bind_state_p) { + } - // create new reorder index - if (setop.setop_type == SetOperationType::UNION_BY_NAME) { - // for UNION BY NAME - create a new re-order index - case_insensitive_map_t reorder_map; - for (idx_t col_idx = 0; col_idx < setop.names.size(); ++col_idx) { - reorder_map[setop.names[col_idx]] = reorder_idx[col_idx]; - } + void GatherAliases(BoundStatement &stmt, const vector &reorder_idx); + void GatherSetOpAliases(SetOperationType setop_type, const vector &names, + vector &bound_children, const vector &reorder_idx); - // use new reorder index - for (auto &child : setop.bound_children) { - vector new_reorder_idx; - for (idx_t col_idx = 0; col_idx < child.node->names.size(); col_idx++) { - auto &col_name = child.node->names[col_idx]; - auto 
entry = reorder_map.find(col_name); - if (entry == reorder_map.end()) { - throw InternalException("SetOp - Column name not found in reorder_map in UNION BY NAME"); - } - new_reorder_idx.push_back(entry->second); - } - GatherAliases(*child.node, bind_state, new_reorder_idx); - } - return; - } +private: + SelectBindState &bind_state; +}; - for (auto &child : setop.bound_children) { - GatherAliases(*child.node, bind_state, reorder_idx); - } - } else { - // query node - D_ASSERT(node.type == QueryNodeType::SELECT_NODE); - auto &select = node.Cast(); - // fill the alias lists with the names - D_ASSERT(reorder_idx.size() == select.names.size()); - for (idx_t i = 0; i < select.names.size(); i++) { - auto &name = select.names[i]; - // first check if the alias is already in there - auto entry = bind_state.alias_map.find(name); +void SetOpAliasGatherer::GatherAliases(BoundStatement &stmt, const vector &reorder_idx) { + if (stmt.extra_info.setop_type != SetOperationType::NONE) { + GatherSetOpAliases(stmt.extra_info.setop_type, stmt.names, stmt.extra_info.bound_children, reorder_idx); + return; + } + + // query node + auto &select_names = stmt.names; + // fill the alias lists with the names + D_ASSERT(reorder_idx.size() == select_names.size()); + for (idx_t i = 0; i < select_names.size(); i++) { + auto &name = select_names[i]; + // first check if the alias is already in there + auto entry = bind_state.alias_map.find(name); - idx_t index = reorder_idx[i]; + idx_t index = reorder_idx[i]; - if (entry == bind_state.alias_map.end()) { - // the alias is not in there yet, just assign it - bind_state.alias_map[name] = index; + if (entry == bind_state.alias_map.end()) { + // the alias is not in there yet, just assign it + bind_state.alias_map[name] = index; + } + } + // check if the expression matches one of the expressions in the original expression list + auto &select_list = stmt.extra_info.original_expressions; + for (idx_t i = 0; i < select_list.size(); i++) { + auto &expr = select_list[i]; + idx_t index = reorder_idx[i]; + // now check if the node is already in the set of expressions + auto expr_entry = bind_state.projection_map.find(*expr); + if (expr_entry != bind_state.projection_map.end()) { + // the node is in there + // repeat the same as with the alias: if there is an ambiguity we insert "-1" + if (expr_entry->second != index) { + bind_state.projection_map[*expr] = DConstants::INVALID_INDEX; } + } else { + // not in there yet, just place it in there + bind_state.projection_map[*expr] = index; + } + } +} + +void SetOpAliasGatherer::GatherSetOpAliases(SetOperationType setop_type, const vector &stmt_names, + vector &bound_children, const vector &reorder_idx) { + // create new reorder index + if (setop_type == SetOperationType::UNION_BY_NAME) { + auto &setop_names = stmt_names; + // for UNION BY NAME - create a new re-order index + case_insensitive_map_t reorder_map; + for (idx_t col_idx = 0; col_idx < setop_names.size(); ++col_idx) { + reorder_map[setop_names[col_idx]] = reorder_idx[col_idx]; } - // check if the expression matches one of the expressions in the original expression list - for (idx_t i = 0; i < select.bind_state.original_expressions.size(); i++) { - auto &expr = select.bind_state.original_expressions[i]; - idx_t index = reorder_idx[i]; - // now check if the node is already in the set of expressions - auto expr_entry = bind_state.projection_map.find(*expr); - if (expr_entry != bind_state.projection_map.end()) { - // the node is in there - // repeat the same as with the alias: if there is 
an ambiguity we insert "-1" - if (expr_entry->second != index) { - bind_state.projection_map[*expr] = DConstants::INVALID_INDEX; + + // use new reorder index + for (auto &child : bound_children) { + vector new_reorder_idx; + auto &child_names = child.names; + for (idx_t col_idx = 0; col_idx < child_names.size(); col_idx++) { + auto &col_name = child_names[col_idx]; + auto entry = reorder_map.find(col_name); + if (entry == reorder_map.end()) { + throw InternalException("SetOp - Column name not found in reorder_map in UNION BY NAME"); } - } else { - // not in there yet, just place it in there - bind_state.projection_map[*expr] = index; + new_reorder_idx.push_back(entry->second); } + GatherAliases(child, new_reorder_idx); + } + } else { + for (auto &child : bound_children) { + GatherAliases(child, reorder_idx); } } } -static void GatherAliases(BoundQueryNode &node, SelectBindState &bind_state) { +static void GatherAliases(BoundSetOperationNode &root, vector &child_statements, + SelectBindState &bind_state) { + SetOpAliasGatherer gatherer(bind_state); vector reorder_idx; - for (idx_t i = 0; i < node.names.size(); i++) { + for (idx_t i = 0; i < root.names.size(); i++) { reorder_idx.push_back(i); } - GatherAliases(node, bind_state, reorder_idx); + gatherer.GatherSetOpAliases(root.setop_type, root.names, child_statements, reorder_idx); } -static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode &result, bool can_contain_nulls) { +void Binder::BuildUnionByNameInfo(BoundSetOperationNode &result) { D_ASSERT(result.setop_type == SetOperationType::UNION_BY_NAME); vector> node_name_maps; case_insensitive_set_t global_name_set; @@ -101,10 +121,10 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & // We throw a binder exception if two same name in the SELECT list D_ASSERT(result.names.empty()); for (auto &child : result.bound_children) { - auto &child_node = *child.node; + auto &child_names = child.names; case_insensitive_map_t node_name_map; - for (idx_t i = 0; i < child_node.names.size(); ++i) { - auto &col_name = child_node.names[i]; + for (idx_t i = 0; i < child_names.size(); ++i) { + auto &col_name = child_names[i]; if (node_name_map.find(col_name) != node_name_map.end()) { throw BinderException( "UNION (ALL) BY NAME operation doesn't support duplicate names in the SELECT list - " @@ -129,7 +149,7 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & auto &col_name = result.names[i]; LogicalType result_type(LogicalTypeId::INVALID); for (idx_t child_idx = 0; child_idx < result.bound_children.size(); ++child_idx) { - auto &child = result.bound_children[child_idx]; + auto &child_types = result.bound_children[child_idx].types; auto &child_name_map = node_name_maps[child_idx]; // check if the column exists in this child node auto entry = child_name_map.find(col_name); @@ -137,7 +157,7 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & need_reorder = true; } else { auto col_idx_in_child = entry->second; - auto &child_col_type = child.node->types[col_idx_in_child]; + auto &child_col_type = child_types[col_idx_in_child]; // the child exists in this node - compute the type if (result_type.id() == LogicalTypeId::INVALID) { result_type = child_col_type; @@ -165,6 +185,8 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & return; } // If reorder is required, generate the expressions for each node + vector>> reorder_expressions; + 
reorder_expressions.resize(result.bound_children.size()); for (idx_t i = 0; i < new_size; ++i) { auto &col_name = result.names[i]; for (idx_t child_idx = 0; child_idx < result.bound_children.size(); ++child_idx) { @@ -179,34 +201,48 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & } else { // the column exists - reference it auto col_idx_in_child = entry->second; - auto &child_col_type = child.node->types[col_idx_in_child]; - expr = make_uniq(child_col_type, - ColumnBinding(child.node->GetRootIndex(), col_idx_in_child)); + auto &child_col_type = child.types[col_idx_in_child]; + auto root_idx = child.plan->GetRootIndex(); + expr = make_uniq(child_col_type, ColumnBinding(root_idx, col_idx_in_child)); } - child.reorder_expressions.push_back(std::move(expr)); + reorder_expressions[child_idx].push_back(std::move(expr)); + } + } + // now push projections for each node + for (idx_t child_idx = 0; child_idx < result.bound_children.size(); ++child_idx) { + auto &child = result.bound_children[child_idx]; + auto &child_reorder_expressions = reorder_expressions[child_idx]; + // if we have re-order expressions push a projection + vector child_types; + for (auto &expr : child_reorder_expressions) { + child_types.push_back(expr->return_type); } + auto child_projection = + make_uniq(GenerateTableIndex(), std::move(child_reorder_expressions)); + child_projection->children.push_back(std::move(child.plan)); + child.plan = std::move(child_projection); + child.types = std::move(child_types); } } -static void GatherSetOpBinders(BoundQueryNode &node, Binder &binder, vector> &binders) { - if (node.type != QueryNodeType::SET_OPERATION_NODE) { - binders.push_back(binder); - return; +static void GatherSetOpBinders(vector &children, vector> &binders, + vector> &result) { + for (auto &child_binder : binders) { + result.push_back(*child_binder); } - auto &setop_node = node.Cast(); - for (auto &child : setop_node.bound_children) { - GatherSetOpBinders(*child.node, *child.binder, binders); + for (auto &child_node : children) { + GatherSetOpBinders(child_node.extra_info.bound_children, child_node.extra_info.child_binders, result); } } -unique_ptr Binder::BindNode(SetOperationNode &statement) { - auto result = make_uniq(); - result->setop_type = statement.setop_type; - result->setop_all = statement.setop_all; +BoundStatement Binder::BindNode(SetOperationNode &statement) { + BoundSetOperationNode result; + result.setop_type = statement.setop_type; + result.setop_all = statement.setop_all; // first recursively visit the set operations // all children have an independent BindContext and Binder - result->setop_index = GenerateTableIndex(); + result.setop_index = GenerateTableIndex(); if (statement.children.size() < 2) { throw InternalException("Set Operations must have at least 2 children"); } @@ -215,27 +251,23 @@ unique_ptr Binder::BindNode(SetOperationNode &statement) { throw InternalException("Set Operation type must have exactly 2 children - except for UNION/UNION_BY_NAME"); } for (auto &child : statement.children) { - BoundSetOpChild bound_child; - bound_child.binder = Binder::CreateBinder(context, this); - bound_child.binder->can_contain_nulls = true; - bound_child.node = bound_child.binder->BindNode(*child); - result->bound_children.push_back(std::move(bound_child)); + auto child_binder = Binder::CreateBinder(context, this); + child_binder->can_contain_nulls = true; + auto child_node = child_binder->BindNode(*child); + MoveCorrelatedExpressions(*child_binder); + 
result.bound_children.push_back(std::move(child_node)); + result.child_binders.push_back(std::move(child_binder)); } - // move the correlated expressions from the child binders to this binder - for (auto &bound_child : result->bound_children) { - MoveCorrelatedExpressions(*bound_child.binder); - } - - if (result->setop_type == SetOperationType::UNION_BY_NAME) { + if (result.setop_type == SetOperationType::UNION_BY_NAME) { // UNION BY NAME - merge the columns from all sides - BuildUnionByNameInfo(context, *result, can_contain_nulls); + BuildUnionByNameInfo(result); } else { // UNION ALL BY POSITION - the columns of both sides must match exactly - result->names = result->bound_children[0].node->names; - auto result_columns = result->bound_children[0].node->types.size(); - for (idx_t i = 1; i < result->bound_children.size(); ++i) { - if (result->bound_children[i].node->types.size() != result_columns) { + result.names = result.bound_children[0].names; + auto result_columns = result.bound_children[0].types.size(); + for (idx_t i = 1; i < result.bound_children.size(); ++i) { + if (result.bound_children[i].types.size() != result_columns) { throw BinderException("Set operations can only apply to expressions with the " "same number of result columns"); } @@ -243,40 +275,43 @@ unique_ptr Binder::BindNode(SetOperationNode &statement) { // figure out the types of the setop result by picking the max of both for (idx_t i = 0; i < result_columns; i++) { - auto result_type = result->bound_children[0].node->types[i]; - for (idx_t child_idx = 1; child_idx < result->bound_children.size(); ++child_idx) { - auto &child_node = *result->bound_children[child_idx].node; - result_type = LogicalType::ForceMaxLogicalType(result_type, child_node.types[i]); + auto result_type = result.bound_children[0].types[i]; + for (idx_t child_idx = 1; child_idx < result.bound_children.size(); ++child_idx) { + auto &child_types = result.bound_children[child_idx].types; + result_type = LogicalType::ForceMaxLogicalType(result_type, child_types[i]); } if (!can_contain_nulls) { if (ExpressionBinder::ContainsNullType(result_type)) { result_type = ExpressionBinder::ExchangeNullType(result_type); } } - result->types.push_back(result_type); + result.types.push_back(result_type); } } SelectBindState bind_state; if (!statement.modifiers.empty()) { // handle the ORDER BY/DISTINCT clauses - - // we recursively visit the children of this node to extract aliases and expressions that can be referenced - // in the ORDER BYs - GatherAliases(*result, bind_state); + vector> binders; + GatherSetOpBinders(result.bound_children, result.child_binders, binders); + GatherAliases(result, result.bound_children, bind_state); // now we perform the actual resolution of the ORDER BY/DISTINCT expressions - vector> binders; - for (auto &child : result->bound_children) { - GatherSetOpBinders(*child.node, *child.binder, binders); - } OrderBinder order_binder(binders, bind_state); - PrepareModifiers(order_binder, statement, *result); + PrepareModifiers(order_binder, statement, result); } // finally bind the types of the ORDER/DISTINCT clause expressions - BindModifiers(*result, result->setop_index, result->names, result->types, bind_state); - return std::move(result); + BindModifiers(result, result.setop_index, result.names, result.types, bind_state); + + BoundStatement result_statement; + result_statement.types = result.types; + result_statement.names = result.names; + result_statement.plan = CreatePlan(result); + result_statement.extra_info.setop_type = 
statement.setop_type; + result_statement.extra_info.bound_children = std::move(result.bound_children); + result_statement.extra_info.child_binders = std::move(result.child_binders); + return result_statement; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_statement_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_statement_node.cpp new file mode 100644 index 000000000..6f6f9941a --- /dev/null +++ b/src/duckdb/src/planner/binder/query_node/bind_statement_node.cpp @@ -0,0 +1,26 @@ +#include "duckdb/parser/query_node/statement_node.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/statement/update_statement.hpp" +#include "duckdb/parser/statement/delete_statement.hpp" +#include "duckdb/parser/statement/merge_into_statement.hpp" +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +BoundStatement Binder::BindNode(StatementNode &statement) { + // switch on type here to ensure we bind WITHOUT ctes to prevent infinite recursion + switch (statement.stmt.type) { + case StatementType::INSERT_STATEMENT: + return Bind(statement.stmt.Cast()); + case StatementType::DELETE_STATEMENT: + return Bind(statement.stmt.Cast()); + case StatementType::UPDATE_STATEMENT: + return Bind(statement.stmt.Cast()); + case StatementType::MERGE_INTO_STATEMENT: + return Bind(statement.stmt.Cast()); + default: + return Bind(statement.stmt); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp deleted file mode 100644 index 5bd06c0e5..000000000 --- a/src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp +++ /dev/null @@ -1,60 +0,0 @@ -#include "duckdb/common/string_util.hpp" -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/operator/logical_materialized_cte.hpp" -#include "duckdb/planner/operator/logical_projection.hpp" -#include "duckdb/planner/operator/logical_set_operation.hpp" -#include "duckdb/planner/query_node/bound_cte_node.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundCTENode &node) { - // Generate the logical plan for the cte_query and child. - auto cte_query = CreatePlan(*node.query); - auto cte_child = CreatePlan(*node.child); - - auto root = make_uniq(node.ctename, node.setop_index, node.types.size(), - std::move(cte_query), std::move(cte_child), node.materialized); - - // check if there are any unplanned subqueries left in either child - has_unplanned_dependent_joins = has_unplanned_dependent_joins || node.child_binder->has_unplanned_dependent_joins || - node.query_binder->has_unplanned_dependent_joins; - - return VisitQueryNode(node, std::move(root)); -} - -unique_ptr Binder::CreatePlan(BoundCTENode &node, unique_ptr base) { - // Generate the logical plan for the cte_query and child. 
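Editor's note on the UNION BY NAME handling in bind_setop.cpp above: BuildUnionByNameInfo now materializes the reordering per child, so every child gets its own LogicalProjection whose expressions either reference the child's matching column or fill the slot for a merged name the child does not produce (a NULL constant in DuckDB's UNION BY NAME semantics). The standalone sketch below illustrates only the reorder-map step; the function and type names are illustrative stand-ins, not the DuckDB API.

#include <cstddef>
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified stand-in for the per-child reorder step of UNION BY NAME:
// for every merged column name, either reference the child's matching
// column by index or mark the slot as missing (planned as a NULL constant).
// DuckDB uses a case-insensitive name map (case_insensitive_map_t) here;
// a plain map keeps the sketch short.
std::vector<std::optional<std::size_t>>
BuildReorderIndex(const std::vector<std::string> &merged_names,
                  const std::vector<std::string> &child_names) {
	std::unordered_map<std::string, std::size_t> child_positions;
	for (std::size_t i = 0; i < child_names.size(); i++) {
		child_positions.emplace(child_names[i], i);
	}
	std::vector<std::optional<std::size_t>> reorder;
	reorder.reserve(merged_names.size());
	for (const auto &name : merged_names) {
		auto entry = child_positions.find(name);
		if (entry == child_positions.end()) {
			reorder.push_back(std::nullopt); // column missing in this child
		} else {
			reorder.push_back(entry->second); // reference the child's column
		}
	}
	return reorder;
}

int main() {
	// merged column order is (a, b, c); this child only produces (c, a)
	auto reorder = BuildReorderIndex({"a", "b", "c"}, {"c", "a"});
	for (const auto &slot : reorder) {
		if (slot) {
			std::cout << "child column " << *slot << "\n";
		} else {
			std::cout << "NULL\n";
		}
	}
	return 0;
}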
- auto cte_query = CreatePlan(*node.query); - unique_ptr root; - if (node.child && node.child->type == QueryNodeType::CTE_NODE) { - root = CreatePlan(node.child->Cast(), std::move(base)); - } else if (node.child) { - root = CreatePlan(*node.child); - } else { - root = std::move(base); - } - - // Only keep the materialized CTE, if it is used - if (node.child_binder->bind_context.cte_references[node.ctename] && - *node.child_binder->bind_context.cte_references[node.ctename] > 0) { - - // Push the CTE through single-child operators so query modifiers appear ABOVE the CTE (internal issue #2652) - // Otherwise, we may have a LIMIT on top of the CTE, and an ORDER BY in the query, and we can't make a TopN - reference> cte_child = root; - while (cte_child.get()->children.size() == 1 && cte_child.get()->type != LogicalOperatorType::LOGICAL_CTE_REF) { - cte_child = cte_child.get()->children[0]; - } - cte_child.get() = - make_uniq(node.ctename, node.setop_index, node.types.size(), std::move(cte_query), - std::move(cte_child.get()), node.materialized); - - // check if there are any unplanned subqueries left in either child - has_unplanned_dependent_joins = has_unplanned_dependent_joins || - node.child_binder->has_unplanned_dependent_joins || - node.query_binder->has_unplanned_dependent_joins; - } - return VisitQueryNode(node, std::move(root)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp deleted file mode 100644 index 4064136b6..000000000 --- a/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_projection.hpp" -#include "duckdb/planner/operator/logical_recursive_cte.hpp" -#include "duckdb/planner/operator/logical_set_operation.hpp" -#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" -#include "duckdb/common/string_util.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundRecursiveCTENode &node) { - // Generate the logical plan for the left and right sides of the set operation - node.left_binder->is_outside_flattened = is_outside_flattened; - node.right_binder->is_outside_flattened = is_outside_flattened; - - auto left_node = node.left_binder->CreatePlan(*node.left); - auto right_node = node.right_binder->CreatePlan(*node.right); - - // check if there are any unplanned subqueries left in either child - has_unplanned_dependent_joins = has_unplanned_dependent_joins || node.left_binder->has_unplanned_dependent_joins || - node.right_binder->has_unplanned_dependent_joins; - - // for both the left and right sides, cast them to the same types - left_node = CastLogicalOperatorToTypes(node.left->types, node.types, std::move(left_node)); - right_node = CastLogicalOperatorToTypes(node.right->types, node.types, std::move(right_node)); - - bool ref_recurring = node.right_binder->bind_context.cte_references["recurring." + node.ctename] && - *node.right_binder->bind_context.cte_references["recurring." + node.ctename] != 0; - - if (node.key_targets.empty() && ref_recurring) { - throw InvalidInputException("RECURRING can only be used with USING KEY in recursive CTE."); - } - - // Check if there is a reference to the recursive or recurring table, if not create a set operator. 
- if ((!node.right_binder->bind_context.cte_references[node.ctename] || - *node.right_binder->bind_context.cte_references[node.ctename] == 0) && - !ref_recurring) { - auto root = - make_uniq(node.setop_index, node.types.size(), std::move(left_node), - std::move(right_node), LogicalOperatorType::LOGICAL_UNION, node.union_all); - return VisitQueryNode(node, std::move(root)); - } - - auto root = - make_uniq(node.ctename, node.setop_index, node.types.size(), node.union_all, - std::move(node.key_targets), std::move(left_node), std::move(right_node)); - root->ref_recurring = ref_recurring; - return VisitQueryNode(node, std::move(root)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp index 46e5d2e12..10b206f24 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp @@ -16,10 +16,8 @@ unique_ptr Binder::PlanFilter(unique_ptr condition, } unique_ptr Binder::CreatePlan(BoundSelectNode &statement) { - unique_ptr root; - D_ASSERT(statement.from_table); - root = CreatePlan(*statement.from_table); - D_ASSERT(root); + D_ASSERT(statement.from_table.plan); + auto root = std::move(statement.from_table.plan); // plan the sample clause if (statement.sample_options) { diff --git a/src/duckdb/src/planner/binder/query_node/plan_setop.cpp b/src/duckdb/src/planner/binder/query_node/plan_setop.cpp index 9b0fa7c94..a1a7f60b0 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_setop.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_setop.cpp @@ -10,8 +10,8 @@ namespace duckdb { // Optionally push a PROJECTION operator -unique_ptr Binder::CastLogicalOperatorToTypes(vector &source_types, - vector &target_types, +unique_ptr Binder::CastLogicalOperatorToTypes(const vector &source_types, + const vector &target_types, unique_ptr op) { D_ASSERT(op); // first check if we even need to cast @@ -113,29 +113,16 @@ unique_ptr Binder::CreatePlan(BoundSetOperationNode &node) { D_ASSERT(node.bound_children.size() >= 2); vector> children; - for (auto &child : node.bound_children) { - child.binder->is_outside_flattened = is_outside_flattened; + for (idx_t child_idx = 0; child_idx < node.bound_children.size(); child_idx++) { + auto &child = node.bound_children[child_idx]; + auto &child_binder = *node.child_binders[child_idx]; // construct the logical plan for the child node - auto child_node = child.binder->CreatePlan(*child.node); - if (!child.reorder_expressions.empty()) { - // if we have re-order expressions push a projection - vector child_types; - for (auto &expr : child.reorder_expressions) { - child_types.push_back(expr->return_type); - } - auto child_projection = - make_uniq(GenerateTableIndex(), std::move(child.reorder_expressions)); - child_projection->children.push_back(std::move(child_node)); - child_node = std::move(child_projection); - - child_node = CastLogicalOperatorToTypes(child_types, node.types, std::move(child_node)); - } else { - // otherwise push only casts - child_node = CastLogicalOperatorToTypes(child.node->types, node.types, std::move(child_node)); - } + auto child_node = std::move(child.plan); + // push casts for the target types + child_node = CastLogicalOperatorToTypes(child.types, node.types, std::move(child_node)); // check if there are any unplanned subqueries left in any child - if (child.binder->has_unplanned_dependent_joins) { + if (child_binder.has_unplanned_dependent_joins) { 
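Editor's note on the statement-property change that recurs in the binder hunks below (bind_attach, bind_call, bind_copy, bind_create, and others): the boolean allow_stream_result is replaced by a QueryResultOutputType enum, with SELECT-style binders requesting ALLOW_STREAMING and DDL/DML binders requesting FORCE_MATERIALIZED. A minimal compilable sketch of that mapping, assuming only the two enumerators visible in this patch (the real enum lives elsewhere in DuckDB and may have more members):

#include <cassert>

// Only the two enumerators that appear in this patch; names beyond these
// would be speculation.
enum class QueryResultOutputType { ALLOW_STREAMING, FORCE_MATERIALIZED };

// Equivalent of the old boolean flag, kept only to illustrate the mapping.
inline QueryResultOutputType FromAllowStreamResult(bool allow_stream_result) {
	return allow_stream_result ? QueryResultOutputType::ALLOW_STREAMING
	                           : QueryResultOutputType::FORCE_MATERIALIZED;
}

int main() {
	// COPY/INSERT/DDL binders now request FORCE_MATERIALIZED,
	// SELECT binders request ALLOW_STREAMING.
	assert(FromAllowStreamResult(false) == QueryResultOutputType::FORCE_MATERIALIZED);
	assert(FromAllowStreamResult(true) == QueryResultOutputType::ALLOW_STREAMING);
	return 0;
}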
has_unplanned_dependent_joins = true; } children.push_back(std::move(child_node)); diff --git a/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp b/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp index 2664903d3..29a419ab7 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp @@ -186,9 +186,10 @@ static unique_ptr PlanUncorrelatedSubquery(Binder &binder, BoundSubq } } -static unique_ptr -CreateDuplicateEliminatedJoin(const vector &correlated_columns, JoinType join_type, - unique_ptr original_plan, bool perform_delim) { +static unique_ptr CreateDuplicateEliminatedJoin(const CorrelatedColumns &correlated_columns, + JoinType join_type, + unique_ptr original_plan, + bool perform_delim) { auto delim_join = make_uniq(join_type); delim_join->correlated_columns = correlated_columns; delim_join->perform_delim = perform_delim; @@ -216,7 +217,7 @@ static bool PerformDelimOnType(const LogicalType &type) { return true; } -static bool PerformDuplicateElimination(Binder &binder, vector &correlated_columns) { +static bool PerformDuplicateElimination(Binder &binder, CorrelatedColumns &correlated_columns) { if (!ClientConfig::GetConfig(binder.context).enable_optimizer) { // if optimizations are disabled we always do a delim join return true; @@ -235,7 +236,8 @@ static bool PerformDuplicateElimination(Binder &binder, vector Binder::PlanSubquery(BoundSubqueryExpression &expr, uniqu // first we translate the QueryNode of the subquery into a logical plan auto sub_binder = Binder::CreateBinder(context, this); sub_binder->is_outside_flattened = false; - auto subquery_root = sub_binder->CreatePlan(*expr.subquery); + auto subquery_root = std::move(expr.subquery.plan); D_ASSERT(subquery_root); // now we actually flatten the subquery @@ -403,7 +405,7 @@ void Binder::PlanSubqueries(unique_ptr &expr_ptr, unique_ptr Binder::PlanLateralJoin(unique_ptr left, unique_ptr right, - vector &correlated, JoinType join_type, + CorrelatedColumns &correlated, JoinType join_type, unique_ptr condition) { // scan the right operator for correlated columns // correlated LATERAL JOIN diff --git a/src/duckdb/src/planner/binder/statement/bind_attach.cpp b/src/duckdb/src/planner/binder/statement/bind_attach.cpp index 0e8655d2f..6da075e25 100644 --- a/src/duckdb/src/planner/binder/statement/bind_attach.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_attach.cpp @@ -1,7 +1,6 @@ #include "duckdb/planner/binder.hpp" #include "duckdb/parser/statement/attach_statement.hpp" #include "duckdb/parser/tableref/table_function_ref.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" #include "duckdb/planner/operator/logical_simple.hpp" #include "duckdb/planner/expression_binder/table_function_binder.hpp" #include "duckdb/execution/expression_executor.hpp" @@ -29,7 +28,7 @@ BoundStatement Binder::Bind(AttachStatement &stmt) { result.plan = make_uniq(LogicalOperatorType::LOGICAL_ATTACH, std::move(stmt.info)); auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::NOTHING; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_call.cpp b/src/duckdb/src/planner/binder/statement/bind_call.cpp index ba96927e8..a746e1689 100644 --- a/src/duckdb/src/planner/binder/statement/bind_call.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_call.cpp @@ -1,8 +1,6 @@ #include 
"duckdb/parser/statement/call_statement.hpp" #include "duckdb/parser/tableref/table_function_ref.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/expression/star_expression.hpp" @@ -19,7 +17,7 @@ BoundStatement Binder::Bind(CallStatement &stmt) { auto result = Bind(select_statement); auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_copy.cpp b/src/duckdb/src/planner/binder/statement/bind_copy.cpp index b7881a0a1..94a8e44cf 100644 --- a/src/duckdb/src/planner/binder/statement/bind_copy.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_copy.cpp @@ -36,7 +36,7 @@ void IsFormatExtensionKnown(const string &format) { // It's a match, we must throw throw CatalogException( "Copy Function with name \"%s\" is not in the catalog, but it exists in the %s extension.", format, - file_postfixes.extension); + std::string(file_postfixes.extension)); } } } @@ -115,7 +115,7 @@ BoundStatement Binder::BindCopyTo(CopyStatement &stmt, const CopyFunction &funct PreserveOrderType preserve_order = PreserveOrderType::AUTOMATIC; CopyFunctionReturnType return_type = CopyFunctionReturnType::CHANGED_ROWS; - CopyFunctionBindInput bind_input(*stmt.info); + CopyFunctionBindInput bind_input(*stmt.info, function.function_info); bind_input.file_extension = function.extension; @@ -251,7 +251,6 @@ BoundStatement Binder::BindCopyTo(CopyStatement &stmt, const CopyFunction &funct auto new_select_list = function.copy_to_select(input); if (!new_select_list.empty()) { - // We have a new select list, create a projection on top of the current plan auto projection = make_uniq(GenerateTableIndex(), std::move(new_select_list)); projection->children.push_back(std::move(select_node.plan)); @@ -551,8 +550,8 @@ BoundStatement Binder::Bind(CopyStatement &stmt, CopyToType copy_to_type) { // check if this matches the mode if (copy_option.mode != CopyOptionMode::READ_WRITE && copy_option.mode != copy_mode) { throw InvalidInputException("Option \"%s\" is not supported for %s - only for %s", provided_option, - stmt.info->is_from ? "reading" : "writing", - stmt.info->is_from ? "writing" : "reading"); + std::string(stmt.info->is_from ? "reading" : "writing"), + std::string(stmt.info->is_from ? 
"writing" : "reading")); } if (copy_option.type.id() != LogicalTypeId::ANY) { if (provided_entry.second.empty()) { @@ -599,7 +598,7 @@ BoundStatement Binder::Bind(CopyStatement &stmt, CopyToType copy_to_type) { } auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::CHANGED_ROWS; if (stmt.info->is_from) { return BindCopyFrom(stmt, function); diff --git a/src/duckdb/src/planner/binder/statement/bind_copy_database.cpp b/src/duckdb/src/planner/binder/statement/bind_copy_database.cpp index d2c0a03fb..9fd527ae2 100644 --- a/src/duckdb/src/planner/binder/statement/bind_copy_database.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_copy_database.cpp @@ -21,7 +21,6 @@ namespace duckdb { unique_ptr Binder::BindCopyDatabaseSchema(Catalog &from_database, const string &target_database_name) { - catalog_entry_vector_t catalog_entries; catalog_entries = PhysicalExport::GetNaiveExportOrder(context, from_database); @@ -125,7 +124,7 @@ BoundStatement Binder::Bind(CopyDatabaseStatement &stmt) { result.plan = std::move(plan); auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::NOTHING; properties.RegisterDBModify(target_catalog, context); return result; diff --git a/src/duckdb/src/planner/binder/statement/bind_create.cpp b/src/duckdb/src/planner/binder/statement/bind_create.cpp index 76b43f60a..4a29661a6 100644 --- a/src/duckdb/src/planner/binder/statement/bind_create.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_create.cpp @@ -39,7 +39,6 @@ #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/planner/parsed_data/bound_create_table_info.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" #include "duckdb/storage/storage_extension.hpp" #include "duckdb/common/extension_type_info.hpp" #include "duckdb/common/type_visitor.hpp" @@ -120,11 +119,11 @@ void Binder::SearchSchema(CreateInfo &info) { if (!info.temporary) { // non-temporary create: not read only if (info.catalog == TEMP_CATALOG) { - throw ParserException("Only TEMPORARY table names can use the \"%s\" catalog", TEMP_CATALOG); + throw ParserException("Only TEMPORARY table names can use the \"%s\" catalog", std::string(TEMP_CATALOG)); } } else { if (info.catalog != TEMP_CATALOG) { - throw ParserException("TEMPORARY table names can *only* use the \"%s\" catalog", TEMP_CATALOG); + throw ParserException("TEMPORARY table names can *only* use the \"%s\" catalog", std::string(TEMP_CATALOG)); } } } @@ -345,11 +344,7 @@ SchemaCatalogEntry &Binder::BindCreateFunctionInfo(CreateInfo &info) { try { dummy_binder->Bind(*query_node); } catch (const std::exception &ex) { - // TODO: we would like to do something like "error = ErrorData(ex);" here, - // but that breaks macro's like "create macro m(x) as table (from query_table(x));", - // because dummy-binding these always throws an error instead of a ParameterNotResolvedException. - // So, for now, we allow macro's with bind errors to be created. - // Binding is still useful because we can create the dependencies. 
+ error = ErrorData(ex); } } @@ -548,23 +543,21 @@ BoundStatement Binder::Bind(CreateStatement &stmt) { create_index_info.table); auto table_ref = make_uniq(table_description); auto bound_table = Bind(*table_ref); - if (bound_table->type != TableReferenceType::BASE_TABLE) { + auto plan = std::move(bound_table.plan); + if (plan->type != LogicalOperatorType::LOGICAL_GET) { + throw BinderException("can only create an index on a base table"); + } + auto &get = plan->Cast(); + auto table_ptr = get.GetTable(); + if (!table_ptr) { throw BinderException("can only create an index on a base table"); } - auto &table_binding = bound_table->Cast(); - auto &table = table_binding.table; + auto &table = *table_ptr; if (table.temporary) { stmt.info->temporary = true; } properties.RegisterDBModify(table.catalog, context); - - // create a plan over the bound table - auto plan = CreatePlan(*bound_table); - if (plan->type != LogicalOperatorType::LOGICAL_GET) { - throw BinderException("Cannot create index on a view!"); - } - result.plan = table.catalog.BindCreateIndex(*this, stmt, table, std::move(plan)); break; } @@ -718,7 +711,7 @@ BoundStatement Binder::Bind(CreateStatement &stmt) { throw InternalException("Unrecognized type!"); } properties.return_type = StatementReturnType::NOTHING; - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_create_table.cpp b/src/duckdb/src/planner/binder/statement/bind_create_table.cpp index ad70fe14a..22c35402d 100644 --- a/src/duckdb/src/planner/binder/statement/bind_create_table.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_create_table.cpp @@ -40,10 +40,18 @@ static void VerifyCompressionType(ClientContext &context, optional_ptrCast(); for (auto &col : base.columns.Logical()) { auto compression_type = col.CompressionType(); - if (CompressionTypeIsDeprecated(compression_type, storage_manager)) { - throw BinderException("Can't compress using user-provided compression type '%s', that type is deprecated " - "and only has decompress support", - CompressionTypeToString(compression_type)); + auto compression_availability_result = CompressionTypeIsAvailable(compression_type, storage_manager); + if (!compression_availability_result.IsAvailable()) { + if (compression_availability_result.IsDeprecated()) { + throw BinderException( + "Can't compress using user-provided compression type '%s', that type is deprecated " + "and only has decompress support", + CompressionTypeToString(compression_type)); + } else { + throw BinderException( + "Can't compress using user-provided compression type '%s', that type is not available yet", + CompressionTypeToString(compression_type)); + } } auto logical_type = col.GetType(); if (logical_type.id() == LogicalTypeId::USER && logical_type.HasAlias()) { @@ -289,7 +297,7 @@ void Binder::BindGeneratedColumns(BoundCreateTableInfo &info) { col.SetType(bound_expression->return_type); // Update the type in the binding, for future expansions - table_binding->types[i.index] = col.Type(); + table_binding->SetColumnType(i.index, col.Type()); } bound_indices.insert(i); } @@ -673,7 +681,7 @@ unique_ptr Binder::BindCreateTableInfo(unique_ptrdependencies.VerifyDependencies(schema.catalog, result->Base().table); auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; return result; } diff --git 
a/src/duckdb/src/planner/binder/statement/bind_delete.cpp b/src/duckdb/src/planner/binder/statement/bind_delete.cpp index e83a62ae3..0d4b8630d 100644 --- a/src/duckdb/src/planner/binder/statement/bind_delete.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_delete.cpp @@ -5,8 +5,6 @@ #include "duckdb/planner/operator/logical_delete.hpp" #include "duckdb/planner/operator/logical_filter.hpp" #include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" #include "duckdb/planner/operator/logical_cross_product.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" @@ -15,38 +13,34 @@ namespace duckdb { BoundStatement Binder::Bind(DeleteStatement &stmt) { // visit the table reference auto bound_table = Bind(*stmt.table); - if (bound_table->type != TableReferenceType::BASE_TABLE) { - throw BinderException("Can only delete from base table!"); + auto root = std::move(bound_table.plan); + if (root->type != LogicalOperatorType::LOGICAL_GET) { + throw BinderException("Can only delete from base table"); } - auto &table_binding = bound_table->Cast(); - auto &table = table_binding.table; - - auto root = CreatePlan(*bound_table); auto &get = root->Cast(); - D_ASSERT(root->type == LogicalOperatorType::LOGICAL_GET); - + auto table_ptr = get.GetTable(); + if (!table_ptr) { + throw BinderException("Can only delete from base table"); + } + auto &table = *table_ptr; if (!table.temporary) { // delete from persistent table: not read only! auto &properties = GetStatementProperties(); properties.RegisterDBModify(table.catalog, context); } - // Add CTEs as bindable - AddCTEMap(stmt.cte_map); - // plan any tables from the various using clauses if (!stmt.using_clauses.empty()) { unique_ptr child_operator; for (auto &using_clause : stmt.using_clauses) { // bind the using clause auto using_binder = Binder::CreateBinder(context, this); - auto bound_node = using_binder->Bind(*using_clause); - auto op = CreatePlan(*bound_node); + auto op = using_binder->Bind(*using_clause); if (child_operator) { // already bound a child: create a cross product to unify the two - child_operator = LogicalCrossProduct::Create(std::move(child_operator), std::move(op)); + child_operator = LogicalCrossProduct::Create(std::move(child_operator), std::move(op.plan)); } else { - child_operator = std::move(op); + child_operator = std::move(op.plan); } bind_context.AddContext(std::move(using_binder->bind_context)); } @@ -90,7 +84,7 @@ BoundStatement Binder::Bind(DeleteStatement &stmt) { result.types = {LogicalType::BIGINT}; auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::CHANGED_ROWS; return result; diff --git a/src/duckdb/src/planner/binder/statement/bind_detach.cpp b/src/duckdb/src/planner/binder/statement/bind_detach.cpp index 98db58055..b2d3313f5 100644 --- a/src/duckdb/src/planner/binder/statement/bind_detach.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_detach.cpp @@ -13,7 +13,7 @@ BoundStatement Binder::Bind(DetachStatement &stmt) { result.types = {LogicalType::BOOLEAN}; auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::NOTHING; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_drop.cpp 
b/src/duckdb/src/planner/binder/statement/bind_drop.cpp index f40a86c61..c6fd78845 100644 --- a/src/duckdb/src/planner/binder/statement/bind_drop.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_drop.cpp @@ -1,6 +1,5 @@ #include "duckdb/parser/statement/drop_statement.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/operator/logical_simple.hpp" #include "duckdb/catalog/catalog.hpp" #include "duckdb/catalog/standard_entry.hpp" @@ -94,7 +93,7 @@ BoundStatement Binder::Bind(DropStatement &stmt) { result.names = {"Success"}; result.types = {LogicalType::BOOLEAN}; - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::NOTHING; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_execute.cpp b/src/duckdb/src/planner/binder/statement/bind_execute.cpp index cceb6796c..1202b01fa 100644 --- a/src/duckdb/src/planner/binder/statement/bind_execute.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_execute.cpp @@ -79,7 +79,7 @@ BoundStatement Binder::Bind(ExecuteStatement &stmt) { prepared = prepared_planner.PrepareSQLStatement(entry->second->unbound_statement->Copy()); rebound_plan = std::move(prepared_planner.plan); D_ASSERT(prepared->properties.bound_all_parameters); - this->bound_tables = prepared_planner.binder->bound_tables; + global_binder_state->bound_tables = prepared_planner.binder->global_binder_state->bound_tables; } // copy the properties of the prepared statement into the planner auto &properties = GetStatementProperties(); diff --git a/src/duckdb/src/planner/binder/statement/bind_export.cpp b/src/duckdb/src/planner/binder/statement/bind_export.cpp index 20d2606fe..0e6f63020 100644 --- a/src/duckdb/src/planner/binder/statement/bind_export.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_export.cpp @@ -302,7 +302,7 @@ BoundStatement Binder::Bind(ExportStatement &stmt) { result.plan = std::move(export_node); auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::NOTHING; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_extension.cpp b/src/duckdb/src/planner/binder/statement/bind_extension.cpp index b4fc0e86b..6569315f7 100644 --- a/src/duckdb/src/planner/binder/statement/bind_extension.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_extension.cpp @@ -5,8 +5,6 @@ namespace duckdb { BoundStatement Binder::Bind(ExtensionStatement &stmt) { - BoundStatement result; - // perform the planning of the function D_ASSERT(stmt.extension.plan_function); auto parse_result = @@ -18,11 +16,9 @@ BoundStatement Binder::Bind(ExtensionStatement &stmt) { properties.return_type = parse_result.return_type; // create the plan as a scan of the given table function - result.plan = BindTableFunction(parse_result.function, std::move(parse_result.parameters)); + auto result = BindTableFunction(parse_result.function, std::move(parse_result.parameters)); D_ASSERT(result.plan->type == LogicalOperatorType::LOGICAL_GET); auto &get = result.plan->Cast(); - result.names = get.names; - result.types = get.returned_types; get.ClearColumnIds(); for (idx_t i = 0; i < get.returned_types.size(); i++) { get.AddColumnId(i); diff --git a/src/duckdb/src/planner/binder/statement/bind_insert.cpp b/src/duckdb/src/planner/binder/statement/bind_insert.cpp index 
f2c8db644..d9ba3cc81 100644 --- a/src/duckdb/src/planner/binder/statement/bind_insert.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_insert.cpp @@ -22,9 +22,6 @@ #include "duckdb/planner/expression/bound_default_expression.hpp" #include "duckdb/catalog/catalog_entry/index_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" -#include "duckdb/planner/tableref/bound_dummytableref.hpp" #include "duckdb/parser/parsed_expression_iterator.hpp" #include "duckdb/storage/table_storage_info.hpp" #include "duckdb/parser/tableref/basetableref.hpp" @@ -99,7 +96,6 @@ void DoUpdateSetQualify(unique_ptr &expr, const string &table_ void DoUpdateSetQualifyInLambda(FunctionExpression &function, const string &table_name, vector> &lambda_params) { - for (auto &child : function.children) { if (child->GetExpressionClass() != ExpressionClass::LAMBDA) { DoUpdateSetQualify(child, table_name, lambda_params); @@ -141,7 +137,6 @@ void DoUpdateSetQualifyInLambda(FunctionExpression &function, const string &tabl void DoUpdateSetQualify(unique_ptr &expr, const string &table_name, vector> &lambda_params) { - // We avoid ambiguity with EXCLUDED columns by qualifying all column references. switch (expr->GetExpressionClass()) { case ExpressionClass::COLUMN_REF: { @@ -277,7 +272,7 @@ unique_ptr Binder::GenerateMergeInto(InsertStatement &stmt, auto storage_info = table.GetStorageInfo(context); auto &columns = table.GetColumns(); // set up the columns on which to join - vector distinct_on_columns; + vector> all_distinct_on_columns; if (on_conflict_info.indexed_columns.empty()) { // When omitting the conflict target, we derive the join columns from the primary key/unique constraints // traverse the primary key/unique constraints @@ -292,6 +287,7 @@ unique_ptr Binder::GenerateMergeInto(InsertStatement &stmt, vector> and_children; auto &indexed_columns = index.column_set; + vector distinct_on_columns; for (auto &column : columns.Physical()) { if (!indexed_columns.count(column.Physical().index)) { continue; @@ -303,6 +299,7 @@ unique_ptr Binder::GenerateMergeInto(InsertStatement &stmt, and_children.push_back(std::move(new_condition)); distinct_on_columns.push_back(column.Name()); } + all_distinct_on_columns.push_back(std::move(distinct_on_columns)); if (and_children.empty()) { continue; } @@ -377,7 +374,7 @@ unique_ptr Binder::GenerateMergeInto(InsertStatement &stmt, throw BinderException("The specified columns as conflict target are not referenced by a UNIQUE/PRIMARY KEY " "CONSTRAINT or INDEX"); } - distinct_on_columns = on_conflict_info.indexed_columns; + all_distinct_on_columns.push_back(on_conflict_info.indexed_columns); merge_into->using_columns = std::move(on_conflict_info.indexed_columns); } @@ -445,17 +442,19 @@ unique_ptr Binder::GenerateMergeInto(InsertStatement &stmt, } } // push DISTINCT ON(unique_columns) - auto distinct_stmt = make_uniq(); - auto select_node = make_uniq(); - auto distinct = make_uniq(); - for (auto &col : distinct_on_columns) { - distinct->distinct_on_targets.push_back(make_uniq(col)); + for (auto &distinct_on_columns : all_distinct_on_columns) { + auto distinct_stmt = make_uniq(); + auto select_node = make_uniq(); + auto distinct = make_uniq(); + for (auto &col : distinct_on_columns) { + distinct->distinct_on_targets.push_back(make_uniq(col)); + } + select_node->modifiers.push_back(std::move(distinct)); + select_node->select_list.push_back(make_uniq()); + 
select_node->from_table = std::move(source); + distinct_stmt->node = std::move(select_node); + source = make_uniq(std::move(distinct_stmt), "excluded"); } - select_node->modifiers.push_back(std::move(distinct)); - select_node->select_list.push_back(make_uniq()); - select_node->from_table = std::move(source); - distinct_stmt->node = std::move(select_node); - source = make_uniq(std::move(distinct_stmt), "excluded"); merge_into->source = std::move(source); @@ -519,8 +518,6 @@ BoundStatement Binder::Bind(InsertStatement &stmt) { } auto insert = make_uniq(table, GenerateTableIndex()); - // Add CTEs as bindable - AddCTEMap(stmt.cte_map); auto values_list = stmt.GetValuesList(); @@ -593,7 +590,7 @@ BoundStatement Binder::Bind(InsertStatement &stmt) { result.plan = std::move(insert); auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::CHANGED_ROWS; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_load.cpp b/src/duckdb/src/planner/binder/statement/bind_load.cpp index 53d8f5792..a252716fe 100644 --- a/src/duckdb/src/planner/binder/statement/bind_load.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_load.cpp @@ -24,7 +24,7 @@ BoundStatement Binder::Bind(LoadStatement &stmt) { result.plan = make_uniq(LogicalOperatorType::LOGICAL_LOAD, std::move(stmt.info)); auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::NOTHING; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp b/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp index 5b187c8e3..1fc7a188f 100644 --- a/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp @@ -26,13 +26,13 @@ BoundStatement Binder::Bind(LogicalPlanStatement &stmt) { result.plan = std::move(stmt.plan); auto &properties = GetStatementProperties(); - properties.allow_stream_result = true; + properties.output_type = QueryResultOutputType::ALLOW_STREAMING; properties.return_type = StatementReturnType::QUERY_RESULT; // TODO could also be something else if (parent) { throw InternalException("LogicalPlanStatement should be bound in root binder"); } - bound_tables = GetMaxTableIndex(*result.plan) + 1; + global_binder_state->bound_tables = GetMaxTableIndex(*result.plan) + 1; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp b/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp index 87a9726ec..1280d492b 100644 --- a/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp @@ -1,6 +1,5 @@ #include "duckdb/planner/binder.hpp" #include "duckdb/parser/statement/merge_into_statement.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" #include "duckdb/planner/tableref/bound_joinref.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/expression_binder/where_binder.hpp" @@ -173,16 +172,45 @@ void RewriteMergeBindings(LogicalOperator &op, const vector &sour op, [&](unique_ptr *child) { RewriteMergeBindings(*child, source_bindings, new_table_index); }); } +LogicalGet &ExtractLogicalGet(LogicalOperator &op) { + reference current_op(op); + while (current_op.get().type == 
LogicalOperatorType::LOGICAL_FILTER) { + current_op = *current_op.get().children[0]; + } + if (current_op.get().type != LogicalOperatorType::LOGICAL_GET) { + throw InvalidInputException("BindMerge - expected to find an operator of type LOGICAL_GET but got %s", + op.ToString()); + } + return current_op.get().Cast(); +} + +void CheckMergeAction(MergeActionCondition condition, MergeActionType action_type) { + if (condition == MergeActionCondition::WHEN_NOT_MATCHED_BY_TARGET) { + switch (action_type) { + case MergeActionType::MERGE_UPDATE: + case MergeActionType::MERGE_DELETE: + throw ParserException("WHEN NOT MATCHED (BY TARGET) cannot be combined with UPDATE or DELETE actions - as " + "there is no corresponding row in the target to update or delete.\nDid you mean to " + "use WHEN MATCHED or WHEN NOT MATCHED BY SOURCE?"); + default: + break; + } + } +} + BoundStatement Binder::Bind(MergeIntoStatement &stmt) { // bind the target table auto target_binder = Binder::CreateBinder(context, this); string table_alias = stmt.target->alias; auto bound_table = target_binder->Bind(*stmt.target); - if (bound_table->type != TableReferenceType::BASE_TABLE) { + if (bound_table.plan->type != LogicalOperatorType::LOGICAL_GET) { + throw BinderException("Can only merge into base tables!"); + } + auto table_ptr = bound_table.plan->Cast().GetTable(); + if (!table_ptr) { throw BinderException("Can only merge into base tables!"); } - auto &table_binding = bound_table->Cast(); - auto &table = table_binding.table; + auto &table = *table_ptr; if (!table.temporary) { // update of persistent table: not read only! auto &properties = GetStatementProperties(); @@ -198,9 +226,10 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { vector source_names; for (auto &binding_entry : source_binder->bind_context.GetBindingsList()) { auto &binding = *binding_entry; - for (idx_t c = 0; c < binding.names.size(); c++) { - source_aliases.push_back(binding.alias); - source_names.push_back(binding.names[c]); + auto &column_names = binding.GetColumnNames(); + for (idx_t c = 0; c < column_names.size(); c++) { + source_aliases.push_back(binding.GetBindingAlias()); + source_names.push_back(column_names[c]); } } @@ -231,11 +260,19 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { } auto bound_join_node = Bind(join); - auto root = CreatePlan(*bound_join_node); + auto root = std::move(bound_join_node.plan); + auto join_ref = reference(*root); + while (join_ref.get().children.size() == 1) { + join_ref = *join_ref.get().children[0]; + } + if (join_ref.get().children.size() != 2) { + throw NotImplementedException("Expected a join after binding a join operator - but got a %s", + join_ref.get().type); + } // kind of hacky, CreatePlan turns a RIGHT join into a LEFT join so the children get reversed from what we need bool inverted = join.type == JoinType::RIGHT; - auto &source = root->children[inverted ? 1 : 0]; - auto &get = root->children[inverted ? 0 : 1]->Cast(); + auto &source = join_ref.get().children[inverted ? 1 : 0]; + auto &get = ExtractLogicalGet(*join_ref.get().children[inverted ? 
0 : 1]); auto merge_into = make_uniq(table); merge_into->table_index = GenerateTableIndex(); @@ -257,6 +294,7 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { for (auto &entry : stmt.actions) { vector> bound_actions; for (auto &action : entry.second) { + CheckMergeAction(entry.first, action->action_type); bound_actions.push_back(BindMergeAction(*merge_into, table, get, proj_index, projection_expressions, root, *action, source_aliases, source_names)); } @@ -327,7 +365,7 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { result.types = {LogicalType::BIGINT}; auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::CHANGED_ROWS; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_pragma.cpp b/src/duckdb/src/planner/binder/statement/bind_pragma.cpp index 3955cf897..b5fc04677 100644 --- a/src/duckdb/src/planner/binder/statement/bind_pragma.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_pragma.cpp @@ -2,6 +2,7 @@ #include "duckdb/parser/statement/pragma_statement.hpp" #include "duckdb/planner/operator/logical_pragma.hpp" #include "duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" #include "duckdb/catalog/catalog.hpp" #include "duckdb/function/function_binder.hpp" #include "duckdb/planner/expression_binder/constant_binder.hpp" @@ -28,16 +29,32 @@ unique_ptr Binder::BindPragma(PragmaInfo &info, QueryErrorConte } // bind the pragma function - auto &entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, info.name); + auto entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, info.name, + OnEntryNotFound::RETURN_NULL); + if (!entry) { + // check whether a table function entry with the same name exists + auto table_entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, + info.name, OnEntryNotFound::RETURN_NULL); + if (table_entry) { + // a table function with the same name exists, so throw a more explicit error message + throw CatalogException("Pragma Function with name %s does not exist, but a table function with the same " + "name exists, try `CALL %s(...)`", + info.name, info.name); + } + // rebind to throw exception + entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, info.name, + OnEntryNotFound::THROW_EXCEPTION); + } + FunctionBinder function_binder(*this); ErrorData error; - auto bound_idx = function_binder.BindFunction(entry.name, entry.functions, params, error); + auto bound_idx = function_binder.BindFunction(entry->name, entry->functions, params, error); if (!bound_idx.IsValid()) { D_ASSERT(error.HasError()); error.AddQueryLocation(error_context); error.Throw(); } - auto bound_function = entry.functions.GetFunctionByOffset(bound_idx.GetIndex()); + auto bound_function = entry->functions.GetFunctionByOffset(bound_idx.GetIndex()); // bind and check named params BindNamedParameters(bound_function.named_parameters, named_parameters, error_context, bound_function.name); return make_uniq(std::move(bound_function), std::move(params), std::move(named_parameters)); diff --git a/src/duckdb/src/planner/binder/statement/bind_prepare.cpp b/src/duckdb/src/planner/binder/statement/bind_prepare.cpp index cbb338dfc..4c0579726 100644 --- a/src/duckdb/src/planner/binder/statement/bind_prepare.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_prepare.cpp @@ -8,7 +8,7 @@ namespace duckdb { 
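Editor's note on the pattern shared by bind_create, bind_delete, and bind_merge_into above and by bind_simple, bind_update, and bind_vacuum below: instead of checking for a BoundBaseTableRef, the binders now inspect the bound plan itself, optionally walking through LOGICAL_FILTER wrappers (ExtractLogicalGet), and require a LOGICAL_GET whose GetTable() returns a real table. The sketch below shows only the traversal shape, with simplified stand-in types rather than DuckDB's LogicalOperator classes:

#include <memory>
#include <stdexcept>
#include <vector>

// Minimal stand-ins for the operator tree; the real code uses
// LogicalOperator / LogicalGet and throws BinderException.
enum class OpType { GET, FILTER, OTHER };

struct Op {
	OpType type = OpType::OTHER;
	bool has_table = false; // stands in for LogicalGet::GetTable() != nullptr
	std::vector<std::unique_ptr<Op>> children;
};

// Mirrors the idea of ExtractLogicalGet: walk through FILTER nodes,
// then demand a GET that is backed by a real base table.
Op &ExtractBaseTableScan(Op &root) {
	Op *current = &root;
	while (current->type == OpType::FILTER) {
		current = current->children.at(0).get();
	}
	if (current->type != OpType::GET || !current->has_table) {
		throw std::runtime_error("can only operate on a base table");
	}
	return *current;
}

int main() {
	// FILTER -> GET(table) is accepted; anything else throws.
	Op scan;
	scan.type = OpType::GET;
	scan.has_table = true;
	Op filter;
	filter.type = OpType::FILTER;
	filter.children.push_back(std::make_unique<Op>(std::move(scan)));
	Op &found = ExtractBaseTableScan(filter);
	return found.has_table ? 0 : 1;
}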
BoundStatement Binder::Bind(PrepareStatement &stmt) { Planner prepared_planner(context); auto prepared_data = prepared_planner.PrepareSQLStatement(std::move(stmt.statement)); - this->bound_tables = prepared_planner.binder->bound_tables; + global_binder_state->bound_tables = prepared_planner.binder->global_binder_state->bound_tables; if (prepared_planner.properties.always_require_rebind) { // we always need to rebind - don't keep the plan around @@ -20,7 +20,7 @@ BoundStatement Binder::Bind(PrepareStatement &stmt) { // this is required because most clients ALWAYS invoke prepared statements auto &properties = GetStatementProperties(); properties.requires_valid_transaction = false; - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.bound_all_parameters = true; properties.parameter_count = 0; properties.return_type = StatementReturnType::NOTHING; diff --git a/src/duckdb/src/planner/binder/statement/bind_select.cpp b/src/duckdb/src/planner/binder/statement/bind_select.cpp index ee68d0e25..a2656d076 100644 --- a/src/duckdb/src/planner/binder/statement/bind_select.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_select.cpp @@ -6,7 +6,7 @@ namespace duckdb { BoundStatement Binder::Bind(SelectStatement &stmt) { auto &properties = GetStatementProperties(); - properties.allow_stream_result = true; + properties.output_type = QueryResultOutputType::ALLOW_STREAMING; properties.return_type = StatementReturnType::QUERY_RESULT; return Bind(*stmt.node); } diff --git a/src/duckdb/src/planner/binder/statement/bind_simple.cpp b/src/duckdb/src/planner/binder/statement/bind_simple.cpp index 942f6784c..46758e416 100644 --- a/src/duckdb/src/planner/binder/statement/bind_simple.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_simple.cpp @@ -60,16 +60,15 @@ BoundStatement Binder::BindAlterAddIndex(BoundStatement &result, CatalogEntry &e TableDescription table_description(table_info.catalog, table_info.schema, table_info.name); auto table_ref = make_uniq(table_description); auto bound_table = Bind(*table_ref); - if (bound_table->type != TableReferenceType::BASE_TABLE) { + if (bound_table.plan->type != LogicalOperatorType::LOGICAL_GET) { throw BinderException("can only add an index to a base table"); } - auto plan = CreatePlan(*bound_table); - auto &get = plan->Cast(); + auto &get = bound_table.plan->Cast(); get.names = column_list.GetColumnNames(); auto alter_table_info = unique_ptr_cast(std::move(alter_info)); - result.plan = table.catalog.BindAlterAddIndex(*this, table, std::move(plan), std::move(create_index_info), - std::move(alter_table_info)); + result.plan = table.catalog.BindAlterAddIndex(*this, table, std::move(bound_table.plan), + std::move(create_index_info), std::move(alter_table_info)); return std::move(result); } diff --git a/src/duckdb/src/planner/binder/statement/bind_summarize.cpp b/src/duckdb/src/planner/binder/statement/bind_summarize.cpp index 45b2b2f25..f8a68ae4c 100644 --- a/src/duckdb/src/planner/binder/statement/bind_summarize.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_summarize.cpp @@ -9,7 +9,6 @@ #include "duckdb/parser/tableref/showref.hpp" #include "duckdb/parser/tableref/basetableref.hpp" #include "duckdb/parser/expression/star_expression.hpp" -#include "duckdb/planner/bound_tableref.hpp" namespace duckdb { @@ -78,7 +77,7 @@ static unique_ptr SummarizeCreateNullPercentage(string column_ return make_uniq(LogicalType::DECIMAL(9, 2), std::move(case_expr)); } -unique_ptr 
Binder::BindSummarize(ShowRef &ref) { +BoundStatement Binder::BindSummarize(ShowRef &ref) { unique_ptr query; if (ref.query) { query = std::move(ref.query); diff --git a/src/duckdb/src/planner/binder/statement/bind_update.cpp b/src/duckdb/src/planner/binder/statement/bind_update.cpp index 650b23b89..d660c5155 100644 --- a/src/duckdb/src/planner/binder/statement/bind_update.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_update.cpp @@ -2,7 +2,6 @@ #include "duckdb/parser/statement/update_statement.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/tableref/bound_joinref.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/constraints/bound_check_constraint.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_default_expression.hpp" @@ -12,7 +11,6 @@ #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/planner/operator/logical_update.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" #include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" #include "duckdb/storage/data_table.hpp" @@ -110,14 +108,15 @@ BoundStatement Binder::Bind(UpdateStatement &stmt) { // visit the table reference auto bound_table = Bind(*stmt.table); - if (bound_table->type != TableReferenceType::BASE_TABLE) { - throw BinderException("Can only update base table!"); + if (bound_table.plan->type != LogicalOperatorType::LOGICAL_GET) { + throw BinderException("Can only update base table"); } - auto &table_binding = bound_table->Cast(); - auto &table = table_binding.table; - - // Add CTEs as bindable - AddCTEMap(stmt.cte_map); + auto &bound_table_get = bound_table.plan->Cast(); + auto table_ptr = bound_table_get.GetTable(); + if (!table_ptr) { + throw BinderException("Can only update base table"); + } + auto &table = *table_ptr; optional_ptr get; if (stmt.from_table) { @@ -129,7 +128,7 @@ BoundStatement Binder::Bind(UpdateStatement &stmt) { get = &root->children[0]->Cast(); bind_context.AddContext(std::move(from_binder->bind_context)); } else { - root = CreatePlan(*bound_table); + root = std::move(bound_table.plan); get = &root->Cast(); } @@ -192,7 +191,7 @@ BoundStatement Binder::Bind(UpdateStatement &stmt) { result.plan = std::move(update); auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::CHANGED_ROWS; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp b/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp index 93e70fe5b..026f682b0 100644 --- a/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp @@ -15,12 +15,18 @@ void Binder::BindVacuumTable(LogicalVacuum &vacuum, unique_ptr } D_ASSERT(vacuum.column_id_map.empty()); + auto bound_table = Bind(*info.ref); - if (bound_table->type != TableReferenceType::BASE_TABLE) { - throw InvalidInputException("can only vacuum or analyze base tables"); + if (bound_table.plan->type != LogicalOperatorType::LOGICAL_GET) { + throw BinderException("Can only vacuum or analyze base tables"); + } + auto table_scan = std::move(bound_table.plan); + auto &get = table_scan->Cast(); + auto table_ptr = get.GetTable(); + if (!table_ptr) { + throw BinderException("Can only vacuum or analyze base tables"); } - auto ref = 
unique_ptr_cast(std::move(bound_table)); - auto &table = ref->table; + auto &table = *table_ptr; vacuum.SetTable(table); vector> select_list; @@ -60,11 +66,6 @@ void Binder::BindVacuumTable(LogicalVacuum &vacuum, unique_ptr } info.columns = std::move(non_generated_column_names); - auto table_scan = CreatePlan(*ref); - D_ASSERT(table_scan->type == LogicalOperatorType::LOGICAL_GET); - - auto &get = table_scan->Cast(); - auto &column_ids = get.GetColumnIds(); D_ASSERT(select_list.size() == column_ids.size()); D_ASSERT(info.columns.size() == column_ids.size()); diff --git a/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp b/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp index da1dacb15..1a005e5cc 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp @@ -11,15 +11,13 @@ #include "duckdb/parser/tableref/table_function_ref.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" -#include "duckdb/planner/tableref/bound_cteref.hpp" -#include "duckdb/planner/tableref/bound_dummytableref.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" +#include "duckdb/planner/operator/logical_cteref.hpp" #include "duckdb/planner/expression_binder/constant_binder.hpp" #include "duckdb/catalog/catalog_search_path.hpp" #include "duckdb/planner/tableref/bound_at_clause.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/parser/query_node/cte_node.hpp" +#include "duckdb/planner/operator/logical_dummy_scan.hpp" namespace duckdb { @@ -48,10 +46,10 @@ static bool TryLoadExtensionForReplacementScan(ClientContext &context, const str return false; } -unique_ptr Binder::BindWithReplacementScan(ClientContext &context, BaseTableRef &ref) { +BoundStatement Binder::BindWithReplacementScan(ClientContext &context, BaseTableRef &ref) { auto &config = DBConfig::GetConfig(context); if (!context.config.use_replacement_scans) { - return nullptr; + return BoundStatement(); } for (auto &scan : config.replacement_scans) { ReplacementScanInput input(ref.catalog_name, ref.schema_name, ref.table_name); @@ -73,14 +71,21 @@ unique_ptr Binder::BindWithReplacementScan(ClientContext &context auto &subquery = replacement_function->Cast(); subquery.column_name_alias = ref.column_name_alias; } else { - throw InternalException("Replacement scan should return either a table function or a subquery"); + auto select_node = make_uniq(); + select_node->select_list.push_back(make_uniq()); + select_node->from_table = std::move(replacement_function); + auto select_stmt = make_uniq(); + select_stmt->node = std::move(select_node); + auto subquery = make_uniq(std::move(select_stmt)); + subquery->column_name_alias = ref.column_name_alias; + replacement_function = std::move(subquery); } if (GetBindingMode() == BindingMode::EXTRACT_REPLACEMENT_SCANS) { AddReplacementScan(ref.table_name, replacement_function->Copy()); } return Bind(*replacement_function); } - return nullptr; + return BoundStatement(); } unique_ptr Binder::BindAtClause(optional_ptr at_clause) { @@ -116,62 +121,36 @@ static vector ExchangeAllNullTypes(const vector &types return result; } -unique_ptr Binder::Bind(BaseTableRef &ref) { +BoundStatement Binder::Bind(BaseTableRef &ref) { QueryErrorContext error_context(ref.query_location); // CTEs and views are also referred to using BaseTableRefs, hence need to distinguish here // check if the table name refers to a 
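BindWithReplacementScan above no longer throws when a replacement scan produces an arbitrary table ref; the ref is wrapped in a SELECT * subquery and bound like any other subquery. A sketch of that wrapping with the template arguments spelled out, assuming the standard parser classes named in the hunk:

#include "duckdb/parser/query_node/select_node.hpp"
#include "duckdb/parser/statement/select_statement.hpp"
#include "duckdb/parser/expression/star_expression.hpp"
#include "duckdb/parser/tableref/subqueryref.hpp"

namespace duckdb {

// Wrap an arbitrary replacement-scan result in "SELECT * FROM <ref>" so it can
// be bound as a subquery (sketch of the new fallback path).
static unique_ptr<TableRef> WrapInSelectStar(unique_ptr<TableRef> replacement, vector<string> column_name_alias) {
	auto select_node = make_uniq<SelectNode>();
	select_node->select_list.push_back(make_uniq<StarExpression>());
	select_node->from_table = std::move(replacement);

	auto select_stmt = make_uniq<SelectStatement>();
	select_stmt->node = std::move(select_node);

	auto subquery = make_uniq<SubqueryRef>(std::move(select_stmt));
	subquery->column_name_alias = std::move(column_name_alias);
	return std::move(subquery);
}

} // namespace duckdb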
CTE // CTE name should never be qualified (i.e. schema_name should be empty) // unless we want to refer to the recurring table of "using key". - vector> found_ctes; - if (ref.schema_name.empty() || ref.schema_name == "recurring") { - found_ctes = FindCTE(ref.table_name, false); - } - - if (!found_ctes.empty()) { - // Check if there is a CTE binding in the BindContext - auto ctebinding = bind_context.GetCTEBinding(ref.table_name); - if (ctebinding) { - // There is a CTE binding in the BindContext. - // This can only be the case if there is a recursive CTE, - // or a materialized CTE present. - auto index = GenerateTableIndex(); - - if (ref.schema_name == "recurring") { - auto recurring_bindings = FindCTE("recurring." + ref.table_name, false); - if (recurring_bindings.empty()) { - throw BinderException(error_context, - "There is a WITH item named \"%s\", but the recurring table cannot be " - "referenced from this part of the query." - " Hint: RECURRING can only be used with USING KEY in recursive CTE.", - ref.table_name); - } - } - - auto result = make_uniq(index, ctebinding->index, ref.schema_name == "recurring"); - auto alias = ref.alias.empty() ? ref.table_name : ref.alias; - auto names = BindContext::AliasColumnNames(alias, ctebinding->names, ref.column_name_alias); + BindingAlias binding_alias(ref.schema_name, ref.table_name); + auto ctebinding = GetCTEBinding(binding_alias); + if (ctebinding && ctebinding->CanBeReferenced()) { + ctebinding->Reference(); - bind_context.AddGenericBinding(index, alias, names, ctebinding->types); + // There is a CTE binding in the BindContext. + // This can only be the case if there is a recursive CTE, + // or a materialized CTE present. + auto index = GenerateTableIndex(); - auto cte_reference = ref.schema_name.empty() ? ref.table_name : ref.schema_name + "." + ref.table_name; + auto alias = ref.alias.empty() ? 
ref.table_name : ref.alias; + auto names = BindContext::AliasColumnNames(alias, ctebinding->GetColumnNames(), ref.column_name_alias); - // Update references to CTE - auto cteref = bind_context.cte_references[cte_reference]; - - if (cteref == nullptr && ref.schema_name == "recurring") { - throw BinderException(error_context, - "There is a WITH item named \"%s\", but the recurring table cannot be " - "referenced from this part of the query.", - ref.table_name); - } + bind_context.AddGenericBinding(index, alias, names, ctebinding->GetColumnTypes()); - (*cteref)++; + bool is_recurring = ref.schema_name == "recurring"; - result->types = ctebinding->types; - result->bound_columns = std::move(names); - return std::move(result); - } + BoundStatement result; + result.types = ctebinding->GetColumnTypes(); + result.names = names; + result.plan = + make_uniq(index, ctebinding->GetIndex(), result.types, std::move(names), is_recurring); + return result; } // not a CTE @@ -198,14 +177,19 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { vector types {LogicalType::INTEGER}; vector names {"__dummy_col" + to_string(table_index)}; bind_context.AddGenericBinding(table_index, ref_alias, names, types); - return make_uniq_base(table_index); + + BoundStatement result; + result.types = std::move(types); + result.names = std::move(names); + result.plan = make_uniq(table_index); + return result; } } if (!table_or_view) { // table could not be found: try to bind a replacement scan // Try replacement scan bind auto replacement_scan_bind_result = BindWithReplacementScan(context, ref); - if (replacement_scan_bind_result) { + if (replacement_scan_bind_result.plan) { return replacement_scan_bind_result; } @@ -214,7 +198,7 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { auto extension_loaded = TryLoadExtensionForReplacementScan(context, full_path); if (extension_loaded) { replacement_scan_bind_result = BindWithReplacementScan(context, ref); - if (replacement_scan_bind_result) { + if (replacement_scan_bind_result.plan) { return replacement_scan_bind_result; } } @@ -230,17 +214,13 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { } } - // remember that we did not find a CTE, but there is a CTE with the same name - // this means that there is a circular reference - // Otherwise, re-throw the original exception - if (found_ctes.empty() && ref.schema_name.empty() && CTEExists(ref.table_name)) { - throw BinderException( - error_context, - "Circular reference to CTE \"%s\", There are two possible solutions. \n1. use WITH RECURSIVE to " - "use recursive CTEs. \n2. If " - "you want to use the TABLE name \"%s\" the same as the CTE name, please explicitly add " - "\"SCHEMA\" before table name. 
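With the hunk above, referencing a CTE from a base-table ref yields a BoundStatement whose plan is a LogicalCTERef carrying the aliased column names and types, so the old BoundCTERef node (and its CreatePlan overload, deleted further down) is no longer needed. A sketch of that construction, assuming the LogicalCTERef constructor arguments used in this hunk:

#include "duckdb/planner/binder.hpp"
#include "duckdb/planner/operator/logical_cteref.hpp"

namespace duckdb {

// Sketch of the CTE-reference result: output schema plus a LogicalCTERef plan.
static BoundStatement BindCTEReference(idx_t table_index, idx_t cte_index, vector<LogicalType> types,
                                       vector<string> names, bool is_recurring) {
	BoundStatement result;
	result.types = types;
	result.names = names;
	result.plan = make_uniq<LogicalCTERef>(table_index, cte_index, std::move(types), std::move(names), is_recurring);
	return result;
}

} // namespace duckdb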
You can try \"main.%s\" (main is the duckdb default schema)", - ref.table_name, ref.table_name, ref.table_name); + // if we found a CTE that cannot be referenced that means that there is a circular reference + if (ctebinding) { + D_ASSERT(!ctebinding->CanBeReferenced()); + throw BinderException(error_context, + "Circular reference to CTE \"%s\", use WITH RECURSIVE to " + "use recursive CTEs.", + ref.table_name); } // could not find an alternative: bind again to get the error // note: this will always throw when using DuckDB as a catalog, but a second look-up might succeed @@ -251,7 +231,7 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { switch (table_or_view->type) { case CatalogType::TABLE_ENTRY: { - // base table: create the BoundBaseTableRef node + // base table auto table_index = GenerateTableIndex(); auto &table = table_or_view->Cast(); @@ -294,7 +274,11 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { } else { bind_context.AddBaseTable(table_index, ref.alias, table_names, table_types, col_ids, *table_entry); } - return make_uniq_base(table, std::move(logical_get)); + BoundStatement result; + result.types = table_types; + result.names = table_names; + result.plan = std::move(logical_get); + return result; } case CatalogType::VIEW_ENTRY: { // the node is a view: get the query that the view represents @@ -307,29 +291,6 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { // The view may contain CTEs, but maybe only in the cte_map, so we need create CTE nodes for them auto query = view_catalog_entry.GetQuery().Copy(); - auto &select_stmt = query->Cast(); - - vector> materialized_ctes; - for (auto &cte : select_stmt.node->cte_map.map) { - auto &cte_entry = cte.second; - auto mat_cte = make_uniq(); - mat_cte->ctename = cte.first; - mat_cte->query = cte_entry->query->node->Copy(); - mat_cte->aliases = cte_entry->aliases; - mat_cte->materialized = cte_entry->materialized; - materialized_ctes.push_back(std::move(mat_cte)); - } - - auto root = std::move(select_stmt.node); - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->child = std::move(root); - root = std::move(node_result); - materialized_ctes.pop_back(); - } - select_stmt.node = std::move(root); - SubqueryRef subquery(unique_ptr_cast(std::move(query))); subquery.alias = ref.alias; @@ -355,15 +316,13 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { throw BinderException("Contents of view were altered - view bound correlated columns"); } - D_ASSERT(bound_child->type == TableReferenceType::SUBQUERY); // verify that the types and names match up with the expected types and names if the view has type info defined - auto &bound_subquery = bound_child->Cast(); if (GetBindingMode() != BindingMode::EXTRACT_NAMES && GetBindingMode() != BindingMode::EXTRACT_QUALIFIED_NAMES && view_catalog_entry.HasTypes()) { // we bind the view subquery and the original view with different "can_contain_nulls", // but we don't want to throw an error when SQLNULL does not match up with INTEGER, // so we exchange all SQLNULL with INTEGER here before comparing - auto bound_types = ExchangeAllNullTypes(bound_subquery.subquery->types); + auto bound_types = ExchangeAllNullTypes(bound_child.types); auto view_types = ExchangeAllNullTypes(view_catalog_entry.types); if (bound_types != view_types) { auto actual_types = StringUtil::ToString(bound_types, ", "); @@ -372,17 +331,17 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { "Contents of view were altered: types don't match! 
Expected [%s], but found [%s] instead", expected_types, actual_types); } - if (bound_subquery.subquery->names.size() == view_catalog_entry.names.size() && - bound_subquery.subquery->names != view_catalog_entry.names) { - auto actual_names = StringUtil::Join(bound_subquery.subquery->names, ", "); + if (bound_child.names.size() == view_catalog_entry.names.size() && + bound_child.names != view_catalog_entry.names) { + auto actual_names = StringUtil::Join(bound_child.names, ", "); auto expected_names = StringUtil::Join(view_catalog_entry.names, ", "); throw BinderException( "Contents of view were altered: names don't match! Expected [%s], but found [%s] instead", expected_names, actual_names); } } - bind_context.AddView(bound_subquery.subquery->GetRootIndex(), subquery.alias, subquery, - *bound_subquery.subquery, view_catalog_entry); + bind_context.AddView(bound_child.plan->GetRootIndex(), subquery.alias, subquery, bound_child, + view_catalog_entry); return bound_child; } default: diff --git a/src/duckdb/src/planner/binder/tableref/bind_bound_table_ref.cpp b/src/duckdb/src/planner/binder/tableref/bind_bound_table_ref.cpp index e31c2e83c..ace531ccf 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_bound_table_ref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_bound_table_ref.cpp @@ -2,8 +2,8 @@ namespace duckdb { -unique_ptr Binder::Bind(BoundRefWrapper &ref) { - if (!ref.binder || !ref.bound_ref) { +BoundStatement Binder::Bind(BoundRefWrapper &ref) { + if (!ref.binder || !ref.bound_ref.plan) { throw InternalException("Rebinding bound ref that was already bound"); } bind_context.AddContext(std::move(ref.binder->bind_context)); diff --git a/src/duckdb/src/planner/binder/tableref/bind_column_data_ref.cpp b/src/duckdb/src/planner/binder/tableref/bind_column_data_ref.cpp index 635d23f71..d3c5ea4a2 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_column_data_ref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_column_data_ref.cpp @@ -1,20 +1,25 @@ #include "duckdb/planner/binder.hpp" #include "duckdb/parser/tableref/column_data_ref.hpp" -#include "duckdb/planner/tableref/bound_column_data_ref.hpp" #include "duckdb/planner/operator/logical_column_data_get.hpp" namespace duckdb { -unique_ptr Binder::Bind(ColumnDataRef &ref) { +BoundStatement Binder::Bind(ColumnDataRef &ref) { auto &collection = *ref.collection; auto types = collection.Types(); - auto result = make_uniq(std::move(ref.collection)); - result->bind_index = GenerateTableIndex(); - for (idx_t i = ref.expected_names.size(); i < types.size(); i++) { - ref.expected_names.push_back("col" + to_string(i + 1)); + + BoundStatement result; + result.names = std::move(ref.expected_names); + for (idx_t i = result.names.size(); i < types.size(); i++) { + result.names.push_back("col" + to_string(i + 1)); } - bind_context.AddGenericBinding(result->bind_index, ref.alias, ref.expected_names, types); - return unique_ptr_cast(std::move(result)); + result.types = types; + auto bind_index = GenerateTableIndex(); + bind_context.AddGenericBinding(bind_index, ref.alias, result.names, types); + + result.plan = + make_uniq_base(bind_index, std::move(types), std::move(ref.collection)); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_delimgetref.cpp b/src/duckdb/src/planner/binder/tableref/bind_delimgetref.cpp index f280404f9..18c27cccf 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_delimgetref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_delimgetref.cpp @@ -1,16 +1,21 @@ 
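The view re-binding check above now compares the BoundStatement's types and names against the view's stored signature, after normalizing SQLNULL columns so that a NULL-typed column does not spuriously mismatch an INTEGER one. A compact sketch of that comparison, assuming only public LogicalType helpers; CheckViewTypes is an illustrative name:

#include "duckdb/planner/binder.hpp"
#include "duckdb/common/exception/binder_exception.hpp"

namespace duckdb {

// Normalize SQLNULL to INTEGER on both sides before comparing, mirroring the
// ExchangeAllNullTypes-based check in Bind(BaseTableRef &).
static void CheckViewTypes(const vector<LogicalType> &bound_types, const vector<LogicalType> &view_types) {
	auto normalize = [](vector<LogicalType> types) {
		for (auto &type : types) {
			if (type.id() == LogicalTypeId::SQLNULL) {
				type = LogicalType::INTEGER;
			}
		}
		return types;
	};
	if (normalize(bound_types) != normalize(view_types)) {
		throw BinderException("Contents of view were altered: types don't match!");
	}
}

} // namespace duckdb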
#include "duckdb/parser/tableref/delimgetref.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_delimgetref.hpp" +#include "duckdb/planner/operator/logical_delim_get.hpp" namespace duckdb { -unique_ptr Binder::Bind(DelimGetRef &ref) { +BoundStatement Binder::Bind(DelimGetRef &ref) { // Have to add bindings idx_t tbl_idx = GenerateTableIndex(); string internal_name = "__internal_delim_get_ref_" + std::to_string(tbl_idx); - bind_context.AddGenericBinding(tbl_idx, internal_name, ref.internal_aliases, ref.types); - return make_uniq(tbl_idx, ref.types); + BoundStatement result; + result.types = std::move(ref.types); + result.names = std::move(ref.internal_aliases); + result.plan = make_uniq(tbl_idx, result.types); + + bind_context.AddGenericBinding(tbl_idx, internal_name, result.names, result.types); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_emptytableref.cpp b/src/duckdb/src/planner/binder/tableref/bind_emptytableref.cpp index fe0e96f3d..b6ea93ab8 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_emptytableref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_emptytableref.cpp @@ -1,11 +1,13 @@ #include "duckdb/parser/tableref/emptytableref.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_dummytableref.hpp" +#include "duckdb/planner/operator/logical_dummy_scan.hpp" namespace duckdb { -unique_ptr Binder::Bind(EmptyTableRef &ref) { - return make_uniq(GenerateTableIndex()); +BoundStatement Binder::Bind(EmptyTableRef &ref) { + BoundStatement result; + result.plan = make_uniq(GenerateTableIndex()); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp b/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp index 7176fb682..139f94670 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp @@ -1,72 +1,87 @@ #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_expressionlistref.hpp" #include "duckdb/parser/tableref/expressionlistref.hpp" #include "duckdb/planner/expression_binder/insert_binder.hpp" #include "duckdb/common/to_string.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/operator/logical_expression_get.hpp" +#include "duckdb/planner/operator/logical_dummy_scan.hpp" namespace duckdb { -unique_ptr Binder::Bind(ExpressionListRef &expr) { - auto result = make_uniq(); - result->types = expr.expected_types; - result->names = expr.expected_names; +BoundStatement Binder::Bind(ExpressionListRef &expr) { + BoundStatement result; + result.types = expr.expected_types; + result.names = expr.expected_names; + + vector>> values; auto prev_can_contain_nulls = this->can_contain_nulls; // bind value list InsertBinder binder(*this, context); binder.target_type = LogicalType(LogicalTypeId::INVALID); for (idx_t list_idx = 0; list_idx < expr.values.size(); list_idx++) { auto &expression_list = expr.values[list_idx]; - if (result->names.empty()) { + if (result.names.empty()) { // no names provided, generate them for (idx_t val_idx = 0; val_idx < expression_list.size(); val_idx++) { - result->names.push_back("col" + to_string(val_idx)); + result.names.push_back("col" + to_string(val_idx)); } } this->can_contain_nulls = true; vector> list; for (idx_t val_idx = 0; val_idx < expression_list.size(); val_idx++) { - if (!result->types.empty()) { - 
D_ASSERT(result->types.size() == expression_list.size()); - binder.target_type = result->types[val_idx]; + if (!result.types.empty()) { + D_ASSERT(result.types.size() == expression_list.size()); + binder.target_type = result.types[val_idx]; } auto bound_expr = binder.Bind(expression_list[val_idx]); list.push_back(std::move(bound_expr)); } - result->values.push_back(std::move(list)); + values.push_back(std::move(list)); this->can_contain_nulls = prev_can_contain_nulls; } - if (result->types.empty() && !expr.values.empty()) { + if (result.types.empty() && !expr.values.empty()) { // there are no types specified // we have to figure out the result types // for each column, we iterate over all of the expressions and select the max logical type // we initialize all types to SQLNULL - result->types.resize(expr.values[0].size(), LogicalType::SQLNULL); + result.types.resize(expr.values[0].size(), LogicalType::SQLNULL); // now loop over the lists and select the max logical type - for (idx_t list_idx = 0; list_idx < result->values.size(); list_idx++) { - auto &list = result->values[list_idx]; + for (idx_t list_idx = 0; list_idx < values.size(); list_idx++) { + auto &list = values[list_idx]; for (idx_t val_idx = 0; val_idx < list.size(); val_idx++) { - auto ¤t_type = result->types[val_idx]; + auto ¤t_type = result.types[val_idx]; auto next_type = ExpressionBinder::GetExpressionReturnType(*list[val_idx]); - result->types[val_idx] = LogicalType::MaxLogicalType(context, current_type, next_type); + result.types[val_idx] = LogicalType::MaxLogicalType(context, current_type, next_type); } } - for (auto &type : result->types) { + for (auto &type : result.types) { type = LogicalType::NormalizeType(type); } // finally do another loop over the expressions and add casts where required - for (idx_t list_idx = 0; list_idx < result->values.size(); list_idx++) { - auto &list = result->values[list_idx]; + for (idx_t list_idx = 0; list_idx < values.size(); list_idx++) { + auto &list = values[list_idx]; for (idx_t val_idx = 0; val_idx < list.size(); val_idx++) { list[val_idx] = - BoundCastExpression::AddCastToType(context, std::move(list[val_idx]), result->types[val_idx]); + BoundCastExpression::AddCastToType(context, std::move(list[val_idx]), result.types[val_idx]); } } } - result->bind_index = GenerateTableIndex(); - bind_context.AddGenericBinding(result->bind_index, expr.alias, result->names, result->types); - return std::move(result); + auto bind_index = GenerateTableIndex(); + bind_context.AddGenericBinding(bind_index, expr.alias, result.names, result.types); + + // values list, first plan any subqueries in the list + auto root = make_uniq_base(GenerateTableIndex()); + for (auto &expr_list : values) { + for (auto &expr : expr_list) { + PlanSubqueries(expr, root); + } + } + + auto expr_get = make_uniq(bind_index, result.types, std::move(values)); + expr_get->AddChild(std::move(root)); + result.plan = std::move(expr_get); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp b/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp index 257e275be..0a6420bfd 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp @@ -55,7 +55,7 @@ bool Binder::TryFindBinding(const string &using_column, const string &join_side, } throw BinderException(error); } else { - result = binding.get().alias; + result = binding.get().GetBindingAlias(); } } return true; @@ -122,14 +122,14 @@ static vector 
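Bind(ExpressionListRef &) above now collects the bound rows into a local values vector, resolves each column type as the maximum logical type across all rows, and then casts every expression to the resolved type before building the LogicalExpressionGet. A sketch of that type-resolution step in isolation, assuming at least one row and the helpers named in the hunk:

#include "duckdb/planner/expression_binder.hpp"
#include "duckdb/planner/expression/bound_cast_expression.hpp"

namespace duckdb {

// Resolve VALUES-list column types: widen with MaxLogicalType per column,
// normalize, then cast every expression to the final type.
static vector<LogicalType> ResolveValueListTypes(ClientContext &context,
                                                 vector<vector<unique_ptr<Expression>>> &values) {
	// assumes values is non-empty, as guarded by the caller
	vector<LogicalType> types(values[0].size(), LogicalType::SQLNULL);
	for (auto &row : values) {
		for (idx_t col = 0; col < row.size(); col++) {
			auto next_type = ExpressionBinder::GetExpressionReturnType(*row[col]);
			types[col] = LogicalType::MaxLogicalType(context, types[col], next_type);
		}
	}
	for (auto &type : types) {
		type = LogicalType::NormalizeType(type);
	}
	for (auto &row : values) {
		for (idx_t col = 0; col < row.size(); col++) {
			row[col] = BoundCastExpression::AddCastToType(context, std::move(row[col]), types[col]);
		}
	}
	return types;
}

} // namespace duckdb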
RemoveDuplicateUsingColumns(const vector &using_co return result; } -unique_ptr Binder::BindJoin(Binder &parent_binder, TableRef &ref) { +BoundStatement Binder::BindJoin(Binder &parent_binder, TableRef &ref) { unnamed_subquery_index = parent_binder.unnamed_subquery_index; auto result = Bind(ref); parent_binder.unnamed_subquery_index = unnamed_subquery_index; return result; } -unique_ptr Binder::Bind(JoinRef &ref) { +BoundStatement Binder::Bind(JoinRef &ref) { auto result = make_uniq(ref.ref_type); result->left_binder = Binder::CreateBinder(context, this); result->right_binder = Binder::CreateBinder(context, this); @@ -188,7 +188,7 @@ unique_ptr Binder::Bind(JoinRef &ref) { case_insensitive_set_t lhs_columns; auto &lhs_binding_list = left_binder.bind_context.GetBindingsList(); for (auto &binding : lhs_binding_list) { - for (auto &column_name : binding->names) { + for (auto &column_name : binding->GetColumnNames()) { lhs_columns.insert(column_name); } } @@ -215,7 +215,7 @@ unique_ptr Binder::Bind(JoinRef &ref) { auto &rhs_binding_list = right_binder.bind_context.GetBindingsList(); for (auto &binding_ref : lhs_binding_list) { auto &binding = *binding_ref; - for (auto &column_name : binding.names) { + for (auto &column_name : binding.GetColumnNames()) { if (!left_candidates.empty()) { left_candidates += ", "; } @@ -224,7 +224,7 @@ unique_ptr Binder::Bind(JoinRef &ref) { } for (auto &binding_ref : rhs_binding_list) { auto &binding = *binding_ref; - for (auto &column_name : binding.names) { + for (auto &column_name : binding.GetColumnNames()) { if (!right_candidates.empty()) { right_candidates += ", "; } @@ -351,7 +351,13 @@ unique_ptr Binder::Bind(JoinRef &ref) { bind_context.RemoveContext(left_bindings); } - return std::move(result); + BoundStatement result_stmt; + result_stmt.types.insert(result_stmt.types.end(), result->left.types.begin(), result->left.types.end()); + result_stmt.types.insert(result_stmt.types.end(), result->right.types.begin(), result->right.types.end()); + result_stmt.names.insert(result_stmt.names.end(), result->left.names.begin(), result->left.names.end()); + result_stmt.names.insert(result_stmt.names.end(), result->right.names.begin(), result->right.names.end()); + result_stmt.plan = CreatePlan(*result); + return result_stmt; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp index 2eb211530..869676a89 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp @@ -9,18 +9,18 @@ #include "duckdb/parser/expression/conjunction_expression.hpp" #include "duckdb/parser/expression/constant_expression.hpp" #include "duckdb/parser/expression/function_expression.hpp" -#include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/parser/expression/star_expression.hpp" #include "duckdb/common/types/value_map.hpp" #include "duckdb/parser/parsed_expression_iterator.hpp" #include "duckdb/parser/expression/operator_expression.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" #include "duckdb/planner/tableref/bound_pivotref.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "duckdb/main/client_config.hpp" #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" #include "duckdb/main/query_result.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" +#include 
"duckdb/planner/operator/logical_pivot.hpp" #include "duckdb/main/settings.hpp" namespace duckdb { @@ -58,10 +58,15 @@ static void ConstructPivots(PivotRef &ref, vector &pivot_valu } } -static void ExtractPivotExpressions(ParsedExpression &root_expr, case_insensitive_set_t &handled_columns) { +static void ExtractPivotExpressions(ParsedExpression &root_expr, case_insensitive_set_t &handled_columns, + optional_ptr macro_binding) { ParsedExpressionIterator::VisitExpression( root_expr, [&](const ColumnRefExpression &child_colref) { if (child_colref.IsQualified()) { + if (child_colref.column_names[0].find(DummyBinding::DUMMY_NAME) != string::npos && macro_binding && + macro_binding->HasMatchingBinding(child_colref.GetName())) { + throw ParameterNotResolvedException(); + } throw BinderException(child_colref, "PIVOT expression cannot contain qualified columns"); } handled_columns.insert(child_colref.GetColumnName()); @@ -378,24 +383,23 @@ static unique_ptr PivotFinalOperator(PivotBindState &bind_state, Piv return final_pivot_operator; } -void ExtractPivotAggregates(BoundTableRef &node, vector> &aggregates) { - if (node.type != TableReferenceType::SUBQUERY) { - throw InternalException("Pivot - Expected a subquery"); - } - auto &subq = node.Cast(); - if (subq.subquery->type != QueryNodeType::SELECT_NODE) { - throw InternalException("Pivot - Expected a select node"); - } - auto &select = subq.subquery->Cast(); - if (select.from_table->type != TableReferenceType::SUBQUERY) { - throw InternalException("Pivot - Expected another subquery"); - } - auto &subq2 = select.from_table->Cast(); - if (subq2.subquery->type != QueryNodeType::SELECT_NODE) { - throw InternalException("Pivot - Expected another select node"); +void ExtractPivotAggregates(BoundStatement &node, vector> &aggregates) { + reference op(*node.plan); + bool found_first_aggregate = false; + while (true) { + if (op.get().type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { + if (found_first_aggregate) { + break; + } + found_first_aggregate = true; + } + if (op.get().children.size() != 1) { + throw InternalException("Pivot - expected an aggregate"); + } + op = *op.get().children[0]; } - auto &select2 = subq2.subquery->Cast(); - for (auto &aggr : select2.aggregates) { + auto &aggr_op = op.get().Cast(); + for (auto &aggr : aggr_op.expressions) { if (aggr->GetAlias() == "__collated_group") { continue; } @@ -412,15 +416,15 @@ string GetPivotAggregateName(const PivotValueElement &pivot_value, const string return name; } -unique_ptr Binder::BindBoundPivot(PivotRef &ref) { +BoundStatement Binder::BindBoundPivot(PivotRef &ref) { // bind the child table in a child binder - auto result = make_uniq(); - result->bind_index = GenerateTableIndex(); - result->child_binder = Binder::CreateBinder(context, this); - result->child = result->child_binder->Bind(*ref.source); + BoundPivotRef result; + result.bind_index = GenerateTableIndex(); + result.child_binder = Binder::CreateBinder(context, this); + result.child = result.child_binder->Bind(*ref.source); - auto &aggregates = result->bound_pivot.aggregates; - ExtractPivotAggregates(*result->child, aggregates); + auto &aggregates = result.bound_pivot.aggregates; + ExtractPivotAggregates(result.child, aggregates); if (aggregates.size() != ref.bound_aggregate_names.size()) { throw InternalException("Pivot aggregate count mismatch (expected %llu, found %llu)", ref.bound_aggregate_names.size(), aggregates.size()); @@ -428,7 +432,7 @@ unique_ptr Binder::BindBoundPivot(PivotRef &ref) { vector child_names; vector 
child_types; - result->child_binder->bind_context.GetTypesAndNames(child_names, child_types); + result.child_binder->bind_context.GetTypesAndNames(child_names, child_types); vector names; vector types; @@ -453,19 +457,23 @@ unique_ptr Binder::BindBoundPivot(PivotRef &ref) { pivot_str += "_" + str; } } - result->bound_pivot.pivot_values.push_back(std::move(pivot_str)); + result.bound_pivot.pivot_values.push_back(std::move(pivot_str)); names.push_back(std::move(name)); types.push_back(aggr->return_type); } } - result->bound_pivot.group_count = ref.bound_group_names.size(); - result->bound_pivot.types = types; + result.bound_pivot.group_count = ref.bound_group_names.size(); + result.bound_pivot.types = types; auto subquery_alias = ref.alias.empty() ? "__unnamed_pivot" : ref.alias; QueryResult::DeduplicateColumns(names); - bind_context.AddGenericBinding(result->bind_index, subquery_alias, names, types); + bind_context.AddGenericBinding(result.bind_index, subquery_alias, names, types); - MoveCorrelatedExpressions(*result->child_binder); - return std::move(result); + MoveCorrelatedExpressions(*result.child_binder); + + BoundStatement result_statement; + result_statement.plan = + make_uniq(result.bind_index, std::move(result.child.plan), std::move(result.bound_pivot)); + return result_statement; } unique_ptr Binder::BindPivot(PivotRef &ref, vector> all_columns) { @@ -492,7 +500,7 @@ unique_ptr Binder::BindPivot(PivotRef &ref, vector Binder::BindPivot(PivotRef &ref, vector Binder::BindUnpivot(Binder &child_binder, PivotRef &ref, vector result; ExtractUnpivotColumnName(*unpivot_expr, result); if (result.empty()) { - throw BinderException( *unpivot_expr, "UNPIVOT clause must contain exactly one column - expression \"%s\" does not contain any", @@ -827,7 +834,7 @@ unique_ptr Binder::BindUnpivot(Binder &child_binder, PivotRef &ref, return result_node; } -unique_ptr Binder::Bind(PivotRef &ref) { +BoundStatement Binder::Bind(PivotRef &ref) { if (!ref.source) { throw InternalException("Pivot without a source!?"); } @@ -858,13 +865,10 @@ unique_ptr Binder::Bind(PivotRef &ref) { } // bind the generated select node auto child_binder = Binder::CreateBinder(context, this); - auto bound_select_node = child_binder->BindNode(*select_node); - auto root_index = bound_select_node->GetRootIndex(); - BoundQueryNode *bound_select_ptr = bound_select_node.get(); + auto result = child_binder->BindNode(*select_node); + auto root_index = result.plan->GetRootIndex(); - unique_ptr result; MoveCorrelatedExpressions(*child_binder); - result = make_uniq(std::move(child_binder), std::move(bound_select_node)); auto subquery_alias = ref.alias.empty() ? 
"__unnamed_pivot" : ref.alias; SubqueryRef subquery_ref(nullptr, subquery_alias); subquery_ref.column_name_alias = std::move(ref.column_name_alias); @@ -872,16 +876,14 @@ unique_ptr Binder::Bind(PivotRef &ref) { // if a WHERE clause was provided - bind a subquery holding the WHERE clause // we need to bind a new subquery here because the WHERE clause has to be applied AFTER the unnest child_binder = Binder::CreateBinder(context, this); - child_binder->bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, *bound_select_ptr); + child_binder->bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, result); auto where_query = make_uniq(); where_query->select_list.push_back(make_uniq()); where_query->where_clause = std::move(where_clause); - bound_select_node = child_binder->BindSelectNode(*where_query, std::move(result)); - bound_select_ptr = bound_select_node.get(); - root_index = bound_select_node->GetRootIndex(); - result = make_uniq(std::move(child_binder), std::move(bound_select_node)); + result = child_binder->BindSelectNode(*where_query, std::move(result)); + root_index = result.plan->GetRootIndex(); } - bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, *bound_select_ptr); + bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, result); return result; } diff --git a/src/duckdb/src/planner/binder/tableref/bind_showref.cpp b/src/duckdb/src/planner/binder/tableref/bind_showref.cpp index b23456cab..d2d91c3af 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_showref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_showref.cpp @@ -5,12 +5,10 @@ #include "duckdb/parser/tableref/subqueryref.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/operator/logical_column_data_get.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_search_path.hpp" #include "duckdb/main/client_data.hpp" #include "duckdb/main/client_context.hpp" @@ -89,7 +87,7 @@ BaseTableColumnInfo FindBaseTableColumn(LogicalOperator &op, idx_t column_index) return FindBaseTableColumn(op, bindings[column_index]); } -unique_ptr Binder::BindShowQuery(ShowRef &ref) { +BoundStatement Binder::BindShowQuery(ShowRef &ref) { // bind the child plan of the DESCRIBE statement auto child_binder = Binder::CreateBinder(context, this); auto plan = child_binder->Bind(*ref.query); @@ -142,12 +140,17 @@ unique_ptr Binder::BindShowQuery(ShowRef &ref) { } collection->Append(append_state, output); - auto show = make_uniq(GenerateTableIndex(), return_types, std::move(collection)); - bind_context.AddGenericBinding(show->table_index, "__show_select", return_names, return_types); - return make_uniq(std::move(show)); + auto table_index = GenerateTableIndex(); + + BoundStatement result; + result.names = return_names; + result.types = return_types; + result.plan = make_uniq(table_index, return_types, std::move(collection)); + bind_context.AddGenericBinding(table_index, "__show_select", return_names, return_types); + return result; } -unique_ptr Binder::BindShowTable(ShowRef &ref) { +BoundStatement Binder::BindShowTable(ShowRef &ref) { auto lname = StringUtil::Lower(ref.table_name); string sql; @@ -193,7 +196,7 @@ unique_ptr Binder::BindShowTable(ShowRef &ref) { return Bind(*subquery); } -unique_ptr 
Binder::Bind(ShowRef &ref) { +BoundStatement Binder::Bind(ShowRef &ref) { if (ref.show_type == ShowType::SUMMARY) { return BindSummarize(ref); } diff --git a/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp b/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp index 9eed0ea61..cfa727927 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp @@ -1,15 +1,14 @@ #include "duckdb/parser/tableref/subqueryref.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" namespace duckdb { -unique_ptr Binder::Bind(SubqueryRef &ref) { +BoundStatement Binder::Bind(SubqueryRef &ref) { auto binder = Binder::CreateBinder(context, this); binder->can_contain_nulls = true; auto subquery = binder->BindNode(*ref.subquery->node); binder->alias = ref.alias.empty() ? "unnamed_subquery" : ref.alias; - idx_t bind_index = subquery->GetRootIndex(); + idx_t bind_index = subquery.plan->GetRootIndex(); string subquery_alias; if (ref.alias.empty()) { auto index = unnamed_subquery_index++; @@ -21,10 +20,14 @@ unique_ptr Binder::Bind(SubqueryRef &ref) { } else { subquery_alias = ref.alias; } - auto result = make_uniq(std::move(binder), std::move(subquery)); - bind_context.AddSubquery(bind_index, subquery_alias, ref, *result->subquery); - MoveCorrelatedExpressions(*result->binder); - return std::move(result); + binder->is_outside_flattened = is_outside_flattened; + if (binder->has_unplanned_dependent_joins) { + has_unplanned_dependent_joins = true; + } + bind_context.AddSubquery(bind_index, subquery_alias, ref, subquery); + MoveCorrelatedExpressions(*binder); + + return subquery; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp b/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp index 0c6e1e0aa..528478c58 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp @@ -13,9 +13,6 @@ #include "duckdb/planner/expression_binder/table_function_binder.hpp" #include "duckdb/planner/expression_binder/select_binder.hpp" #include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/query_node/bound_select_node.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" #include "duckdb/function/function_binder.hpp" #include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" @@ -79,32 +76,28 @@ static TableFunctionBindType GetTableFunctionBindType(TableFunctionCatalogEntry : TableFunctionBindType::STANDARD_TABLE_FUNCTION; } -void Binder::BindTableInTableOutFunction(vector> &expressions, - unique_ptr &subquery) { +void Binder::BindTableInTableOutFunction(vector> &expressions, BoundStatement &subquery) { auto binder = Binder::CreateBinder(this->context, this); - unique_ptr subquery_node; // generate a subquery and bind that (i.e. 
UNNEST([1,2,3]) becomes UNNEST((SELECT [1,2,3])) auto select_node = make_uniq(); select_node->select_list = std::move(expressions); select_node->from_table = make_uniq(); - subquery_node = std::move(select_node); binder->can_contain_nulls = true; - auto node = binder->BindNode(*subquery_node); - subquery = make_uniq(std::move(binder), std::move(node)); - MoveCorrelatedExpressions(*subquery->binder); + subquery = binder->BindNode(*select_node); + MoveCorrelatedExpressions(*binder); } bool Binder::BindTableFunctionParameters(TableFunctionCatalogEntry &table_function, vector> &expressions, vector &arguments, vector ¶meters, - named_parameter_map_t &named_parameters, - unique_ptr &subquery, ErrorData &error) { + named_parameter_map_t &named_parameters, BoundStatement &subquery, + ErrorData &error) { auto bind_type = GetTableFunctionBindType(table_function, expressions); if (bind_type == TableFunctionBindType::TABLE_IN_OUT_FUNCTION) { // bind table in-out function BindTableInTableOutFunction(expressions, subquery); // fetch the arguments from the subquery - arguments = subquery->subquery->types; + arguments = subquery.types; return true; } bool seen_subquery = false; @@ -142,12 +135,11 @@ bool Binder::BindTableFunctionParameters(TableFunctionCatalogEntry &table_functi auto binder = Binder::CreateBinder(this->context, this); binder->can_contain_nulls = true; auto &se = child->Cast(); - auto node = binder->BindNode(*se.subquery->node); - subquery = make_uniq(std::move(binder), std::move(node)); - MoveCorrelatedExpressions(*subquery->binder); + subquery = binder->BindNode(*se.subquery->node); + MoveCorrelatedExpressions(*binder); seen_subquery = true; arguments.emplace_back(LogicalTypeId::TABLE); - parameters.emplace_back(Value()); + parameters.emplace_back(); continue; } @@ -188,11 +180,10 @@ static string GetAlias(const TableFunctionRef &ref) { return string(); } -unique_ptr Binder::BindTableFunctionInternal(TableFunction &table_function, - const TableFunctionRef &ref, vector parameters, - named_parameter_map_t named_parameters, - vector input_table_types, - vector input_table_names) { +BoundStatement Binder::BindTableFunctionInternal(TableFunction &table_function, const TableFunctionRef &ref, + vector parameters, named_parameter_map_t named_parameters, + vector input_table_types, + vector input_table_names) { auto function_name = GetAlias(ref); auto &column_name_alias = ref.column_name_alias; auto bind_index = GenerateTableIndex(); @@ -221,8 +212,12 @@ unique_ptr Binder::BindTableFunctionInternal(TableFunction &tab table_function.name); } } + BoundStatement result; bind_context.AddGenericBinding(bind_index, function_name, return_names, new_plan->types); - return new_plan; + result.names = return_names; + result.types = new_plan->types; + result.plan = std::move(new_plan); + return result; } } if (table_function.bind_replace) { @@ -234,7 +229,7 @@ unique_ptr Binder::BindTableFunctionInternal(TableFunction &tab if (!ref.column_name_alias.empty()) { new_plan->column_name_alias = ref.column_name_alias; } - return CreatePlan(*Bind(*new_plan)); + return Bind(*new_plan); } } if (!table_function.bind) { @@ -307,52 +302,46 @@ unique_ptr Binder::BindTableFunctionInternal(TableFunction &tab } if (ref.with_ordinality == OrdinalityType::WITH_ORDINALITY && correlated_columns.empty()) { + bind_context.AddTableFunction(bind_index, function_name, return_names, return_types, get->GetMutableColumnIds(), + get->GetTable().get(), std::move(virtual_columns)); + auto window_index = GenerateTableIndex(); auto window = 
make_uniq(window_index); auto row_number = make_uniq(ExpressionType::WINDOW_ROW_NUMBER, LogicalType::BIGINT, nullptr, nullptr); row_number->start = WindowBoundary::UNBOUNDED_PRECEDING; row_number->end = WindowBoundary::CURRENT_ROW_ROWS; + string ordinality_alias = ordinality_column_name; if (return_names.size() < column_name_alias.size()) { row_number->alias = column_name_alias[return_names.size()]; + ordinality_alias = column_name_alias[return_names.size()]; } else { row_number->alias = ordinality_column_name; } + return_names.push_back(ordinality_alias); + return_types.push_back(LogicalType::BIGINT); window->expressions.push_back(std::move(row_number)); - for (idx_t i = 0; i < return_types.size(); i++) { - get->AddColumnId(i); - } + window->types.push_back(LogicalType::BIGINT); window->children.push_back(std::move(get)); + bind_context.AddGenericBinding(window_index, function_name, {ordinality_alias}, {LogicalType::BIGINT}); - vector> select_list; - for (idx_t i = 0; i < return_types.size(); i++) { - auto expression = make_uniq(return_types[i], ColumnBinding(bind_index, i)); - select_list.push_back(std::move(expression)); - } - select_list.push_back(make_uniq(LogicalType::BIGINT, ColumnBinding(window_index, 0))); - - auto projection_index = GenerateTableIndex(); - auto projection = make_uniq(projection_index, std::move(select_list)); - - projection->children.push_back(std::move(window)); - if (return_names.size() < column_name_alias.size()) { - return_names.push_back(column_name_alias[return_names.size()]); - } else { - return_names.push_back(ordinality_column_name); - } - - return_types.push_back(LogicalType::BIGINT); - bind_context.AddGenericBinding(projection_index, function_name, return_names, return_types); - return std::move(projection); + BoundStatement result; + result.names = std::move(return_names); + result.types = std::move(return_types); + result.plan = std::move(window); + return result; } - // now add the table function to the bind context so its columns can be bound + BoundStatement result; bind_context.AddTableFunction(bind_index, function_name, return_names, return_types, get->GetMutableColumnIds(), get->GetTable().get(), std::move(virtual_columns)); - return std::move(get); + result.names = std::move(return_names); + result.types = std::move(return_types); + result.plan = std::move(get); + return result; } -unique_ptr Binder::BindTableFunction(TableFunction &function, vector parameters) { +BoundStatement Binder::BindTableFunction(TableFunction &function, vector parameters) { named_parameter_map_t named_parameters; vector input_table_types; vector input_table_names; @@ -364,7 +353,7 @@ unique_ptr Binder::BindTableFunction(TableFunction &function, v std::move(input_table_types), std::move(input_table_names)); } -unique_ptr Binder::Bind(TableFunctionRef &ref) { +BoundStatement Binder::Bind(TableFunctionRef &ref) { QueryErrorContext error_context(ref.query_location); D_ASSERT(ref.function->GetExpressionType() == ExpressionType::FUNCTION); @@ -388,7 +377,7 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { binder->can_contain_nulls = true; binder->alias = ref.alias.empty() ? "unnamed_query" : ref.alias; - unique_ptr query; + BoundStatement query; try { query = binder->BindNode(*query_node); } catch (std::exception &ex) { @@ -397,15 +386,14 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { error.Throw(); } - idx_t bind_index = query->GetRootIndex(); + idx_t bind_index = query.plan->GetRootIndex(); // string alias; string alias = (ref.alias.empty() ? 
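The WITH ORDINALITY hunk above stacks the row_number() window directly on top of the table-function LogicalGet instead of routing every column through an extra LogicalProjection. A sketch of that construction with the template arguments written out, assuming the window classes referenced in this hunk:

#include "duckdb/planner/operator/logical_window.hpp"
#include "duckdb/planner/expression/bound_window_expression.hpp"

namespace duckdb {

// Add the BIGINT ordinality column by wrapping the scan in a LogicalWindow
// that computes row_number() over the whole input.
static unique_ptr<LogicalOperator> AddOrdinalityColumn(unique_ptr<LogicalOperator> scan, idx_t window_index,
                                                       const string &ordinality_alias) {
	auto window = make_uniq<LogicalWindow>(window_index);
	auto row_number =
	    make_uniq<BoundWindowExpression>(ExpressionType::WINDOW_ROW_NUMBER, LogicalType::BIGINT, nullptr, nullptr);
	row_number->start = WindowBoundary::UNBOUNDED_PRECEDING;
	row_number->end = WindowBoundary::CURRENT_ROW_ROWS;
	row_number->alias = ordinality_alias;
	window->expressions.push_back(std::move(row_number));
	window->types.push_back(LogicalType::BIGINT);
	window->children.push_back(std::move(scan));
	return std::move(window);
}

} // namespace duckdb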
"unnamed_query" + to_string(bind_index) : ref.alias); - auto result = make_uniq(std::move(binder), std::move(query)); // remember ref here is TableFunctionRef and NOT base class - bind_context.AddSubquery(bind_index, alias, ref, *result->subquery); - MoveCorrelatedExpressions(*result->binder); - return std::move(result); + bind_context.AddSubquery(bind_index, alias, ref, query); + MoveCorrelatedExpressions(*binder); + return query; } D_ASSERT(func_catalog.type == CatalogType::TABLE_FUNCTION_ENTRY); auto &function = func_catalog.Cast(); @@ -414,7 +402,7 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { vector arguments; vector parameters; named_parameter_map_t named_parameters; - unique_ptr subquery; + BoundStatement subquery; ErrorData error; if (!BindTableFunctionParameters(function, fexpr.children, arguments, parameters, named_parameters, subquery, error)) { @@ -437,9 +425,9 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { vector input_table_types; vector input_table_names; - if (subquery) { - input_table_types = subquery->subquery->types; - input_table_names = subquery->subquery->names; + if (subquery.plan) { + input_table_types = subquery.types; + input_table_names = subquery.names; } else if (table_function.in_out_function) { for (auto ¶m : parameters) { input_table_types.push_back(param.type()); @@ -457,7 +445,7 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { parameters[i] = parameters[i].CastAs(context, target_type); } } - } else if (subquery) { + } else if (subquery.plan) { for (idx_t i = 0; i < arguments.size(); i++) { auto target_type = i < table_function.arguments.size() ? table_function.arguments[i] : table_function.varargs; @@ -469,11 +457,39 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { } } - auto get = BindTableFunctionInternal(table_function, ref, std::move(parameters), std::move(named_parameters), - std::move(input_table_types), std::move(input_table_names)); - auto table_function_ref = make_uniq(std::move(get)); - table_function_ref->subquery = std::move(subquery); - return std::move(table_function_ref); + BoundStatement get; + try { + get = BindTableFunctionInternal(table_function, ref, std::move(parameters), std::move(named_parameters), + std::move(input_table_types), std::move(input_table_names)); + } catch (std::exception &ex) { + error = ErrorData(ex); + error.AddQueryLocation(ref); + error.Throw(); + } + + if (subquery.plan) { + auto child_node = std::move(subquery.plan); + + reference node = *get.plan; + + while (!node.get().children.empty()) { + D_ASSERT(node.get().children.size() == 1); + if (node.get().children.size() != 1) { + throw InternalException( + "Binder::CreatePlan: linear path expected, but found node with %d children", + node.get().children.size()); + } + node = *node.get().children[0]; + } + + D_ASSERT(node.get().type == LogicalOperatorType::LOGICAL_GET); + node.get().children.push_back(std::move(child_node)); + } + BoundStatement result_statement; + result_statement.names = get.names; + result_statement.types = get.types; + result_statement.plan = std::move(get.plan); + return result_statement; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_basetableref.cpp b/src/duckdb/src/planner/binder/tableref/plan_basetableref.cpp deleted file mode 100644 index 085498fbb..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_basetableref.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_get.hpp" -#include 
"duckdb/planner/tableref/bound_basetableref.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundBaseTableRef &ref) { - return std::move(ref.get); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_column_data_ref.cpp b/src/duckdb/src/planner/binder/tableref/plan_column_data_ref.cpp deleted file mode 100644 index 83e965b5e..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_column_data_ref.cpp +++ /dev/null @@ -1,15 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_column_data_ref.hpp" -#include "duckdb/planner/operator/logical_column_data_get.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundColumnDataRef &ref) { - auto types = ref.collection->Types(); - // Create a (potentially owning) LogicalColumnDataGet - auto root = make_uniq_base(ref.bind_index, std::move(types), - std::move(ref.collection)); - return root; -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_cteref.cpp b/src/duckdb/src/planner/binder/tableref/plan_cteref.cpp deleted file mode 100644 index 4ee2b9a76..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_cteref.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_cteref.hpp" -#include "duckdb/planner/tableref/bound_cteref.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundCTERef &ref) { - return make_uniq(ref.bind_index, ref.cte_index, ref.types, ref.bound_columns, ref.is_recurring); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_delimgetref.cpp b/src/duckdb/src/planner/binder/tableref/plan_delimgetref.cpp deleted file mode 100644 index b674b43df..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_delimgetref.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" -#include "duckdb/planner/operator/logical_delim_get.hpp" -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundDelimGetRef &ref) { - return make_uniq(ref.bind_index, ref.column_types); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_dummytableref.cpp b/src/duckdb/src/planner/binder/tableref/plan_dummytableref.cpp deleted file mode 100644 index f31fc929b..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_dummytableref.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_dummy_scan.hpp" -#include "duckdb/planner/tableref/bound_dummytableref.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundEmptyTableRef &ref) { - return make_uniq(ref.bind_index); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_expressionlistref.cpp b/src/duckdb/src/planner/binder/tableref/plan_expressionlistref.cpp deleted file mode 100644 index ba6253bce..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_expressionlistref.cpp +++ /dev/null @@ -1,27 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_expressionlistref.hpp" -#include "duckdb/planner/operator/logical_expression_get.hpp" -#include "duckdb/planner/operator/logical_dummy_scan.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundExpressionListRef &ref) { - auto root = make_uniq_base(GenerateTableIndex()); - // values list, first plan any subqueries in the list - for (auto 
&expr_list : ref.values) { - for (auto &expr : expr_list) { - PlanSubqueries(expr, root); - } - } - // now create a LogicalExpressionGet from the set of expressions - // fetch the types - vector types; - for (auto &expr : ref.values[0]) { - types.push_back(expr->return_type); - } - auto expr_get = make_uniq(ref.bind_index, types, std::move(ref.values)); - expr_get->AddChild(std::move(root)); - return std::move(expr_get); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp b/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp index 9de5829f2..d15cfe9ea 100644 --- a/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp +++ b/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp @@ -70,7 +70,6 @@ void LogicalComparisonJoin::ExtractJoinConditions( unique_ptr &right_child, const unordered_set &left_bindings, const unordered_set &right_bindings, vector> &expressions, vector &conditions, vector> &arbitrary_expressions) { - for (auto &expr : expressions) { auto total_side = JoinSide::GetJoinSide(*expr, left_bindings, right_bindings); if (total_side != JoinSide::BOTH) { @@ -298,8 +297,8 @@ unique_ptr Binder::CreatePlan(BoundJoinRef &ref) { // Set the flag to ensure that children do not flatten before the root is_outside_flattened = false; } - auto left = CreatePlan(*ref.left); - auto right = CreatePlan(*ref.right); + auto left = std::move(ref.left.plan); + auto right = std::move(ref.right.plan); is_outside_flattened = old_is_outside_flattened; // For joins, depth of the bindings will be one higher on the right because of the lateral binder diff --git a/src/duckdb/src/planner/binder/tableref/plan_pivotref.cpp b/src/duckdb/src/planner/binder/tableref/plan_pivotref.cpp deleted file mode 100644 index 4d9482e5b..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_pivotref.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "duckdb/planner/tableref/bound_pivotref.hpp" -#include "duckdb/planner/operator/logical_pivot.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundPivotRef &ref) { - auto subquery = ref.child_binder->CreatePlan(*ref.child); - - auto result = make_uniq(ref.bind_index, std::move(subquery), std::move(ref.bound_pivot)); - return std::move(result); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp b/src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp deleted file mode 100644 index 821654460..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundSubqueryRef &ref) { - // generate the logical plan for the subquery - // this happens separately from the current LogicalPlan generation - ref.binder->is_outside_flattened = is_outside_flattened; - auto subquery = ref.binder->CreatePlan(*ref.subquery); - if (ref.binder->has_unplanned_dependent_joins) { - has_unplanned_dependent_joins = true; - } - return subquery; -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_table_function.cpp b/src/duckdb/src/planner/binder/tableref/plan_table_function.cpp deleted file mode 100644 index 6c2f9957a..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_table_function.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" - -namespace duckdb { - -unique_ptr 
Binder::CreatePlan(BoundTableFunction &ref) { - if (ref.subquery) { - auto child_node = CreatePlan(*ref.subquery); - - reference node = *ref.get; - - while (!node.get().children.empty()) { - D_ASSERT(node.get().children.size() == 1); - if (node.get().children.size() != 1) { - throw InternalException( - "Binder::CreatePlan: linear path expected, but found node with %d children", - node.get().children.size()); - } - node = *node.get().children[0]; - } - - D_ASSERT(node.get().type == LogicalOperatorType::LOGICAL_GET); - node.get().children.push_back(std::move(child_node)); - } - return std::move(ref.get); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/bound_parameter_map.cpp b/src/duckdb/src/planner/bound_parameter_map.cpp index 112a17934..4a906d188 100644 --- a/src/duckdb/src/planner/bound_parameter_map.cpp +++ b/src/duckdb/src/planner/bound_parameter_map.cpp @@ -43,7 +43,6 @@ shared_ptr BoundParameterMap::CreateOrGetData(const string & } unique_ptr BoundParameterMap::BindParameterExpression(ParameterExpression &expr) { - auto &identifier = expr.identifier; D_ASSERT(!parameter_data.count(identifier)); diff --git a/src/duckdb/src/planner/collation_binding.cpp b/src/duckdb/src/planner/collation_binding.cpp index 1ddefb9a8..dd371bbc4 100644 --- a/src/duckdb/src/planner/collation_binding.cpp +++ b/src/duckdb/src/planner/collation_binding.cpp @@ -8,6 +8,7 @@ #include "duckdb/function/function_binder.hpp" namespace duckdb { +constexpr const char *CollateCatalogEntry::Name; bool PushVarcharCollation(ClientContext &context, unique_ptr &source, const LogicalType &sql_type, CollationType type) { @@ -109,11 +110,34 @@ bool PushIntervalCollation(ClientContext &context, unique_ptr &sourc return true; } +bool PushVariantCollation(ClientContext &context, unique_ptr &source, const LogicalType &sql_type, + CollationType) { + if (sql_type.id() != LogicalTypeId::VARIANT) { + return false; + } + auto &catalog = Catalog::GetSystemCatalog(context); + auto &function_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, "variant_normalize"); + if (function_entry.functions.Size() != 1) { + throw InternalException("variant_normalize should only have a single overload"); + } + auto source_alias = source->GetAlias(); + auto &scalar_function = function_entry.functions.GetFunctionReferenceByOffset(0); + vector> children; + children.push_back(std::move(source)); + + FunctionBinder function_binder(context); + auto function = function_binder.BindScalarFunction(scalar_function, std::move(children)); + function->SetAlias(source_alias); + source = std::move(function); + return true; +} + // timetz_byte_comparable CollationBinding::CollationBinding() { RegisterCollation(CollationCallback(PushVarcharCollation)); RegisterCollation(CollationCallback(PushTimeTZCollation)); RegisterCollation(CollationCallback(PushIntervalCollation)); + RegisterCollation(CollationCallback(PushVariantCollation)); } void CollationBinding::RegisterCollation(CollationCallback callback) { diff --git a/src/duckdb/src/planner/expression/bound_aggregate_expression.cpp b/src/duckdb/src/planner/expression/bound_aggregate_expression.cpp index 68bc16b26..f0f30e030 100644 --- a/src/duckdb/src/planner/expression/bound_aggregate_expression.cpp +++ b/src/duckdb/src/planner/expression/bound_aggregate_expression.cpp @@ -11,7 +11,7 @@ namespace duckdb { BoundAggregateExpression::BoundAggregateExpression(AggregateFunction function, vector> children, unique_ptr filter, unique_ptr bind_info, AggregateType aggr_type) - : 
Expression(ExpressionType::BOUND_AGGREGATE, ExpressionClass::BOUND_AGGREGATE, function.return_type), + : Expression(ExpressionType::BOUND_AGGREGATE, ExpressionClass::BOUND_AGGREGATE, function.GetReturnType()), function(std::move(function)), children(std::move(children)), bind_info(std::move(bind_info)), aggr_type(aggr_type), filter(std::move(filter)) { D_ASSERT(!this->function.name.empty()); @@ -61,8 +61,8 @@ bool BoundAggregateExpression::Equals(const BaseExpression &other_p) const { } bool BoundAggregateExpression::PropagatesNullValues() const { - return function.null_handling == FunctionNullHandling::SPECIAL_HANDLING ? false - : Expression::PropagatesNullValues(); + return function.GetNullHandling() == FunctionNullHandling::SPECIAL_HANDLING ? false + : Expression::PropagatesNullValues(); } unique_ptr BoundAggregateExpression::Copy() const { diff --git a/src/duckdb/src/planner/expression/bound_function_expression.cpp b/src/duckdb/src/planner/expression/bound_function_expression.cpp index 5556dec21..152285dfe 100644 --- a/src/duckdb/src/planner/expression/bound_function_expression.cpp +++ b/src/duckdb/src/planner/expression/bound_function_expression.cpp @@ -19,11 +19,11 @@ BoundFunctionExpression::BoundFunctionExpression(LogicalType return_type, Scalar } bool BoundFunctionExpression::IsVolatile() const { - return function.stability == FunctionStability::VOLATILE ? true : Expression::IsVolatile(); + return function.GetStability() == FunctionStability::VOLATILE ? true : Expression::IsVolatile(); } bool BoundFunctionExpression::IsConsistent() const { - return function.stability != FunctionStability::CONSISTENT ? false : Expression::IsConsistent(); + return function.GetStability() != FunctionStability::CONSISTENT ? false : Expression::IsConsistent(); } bool BoundFunctionExpression::IsFoldable() const { @@ -39,11 +39,11 @@ bool BoundFunctionExpression::IsFoldable() const { } } } - return function.stability == FunctionStability::VOLATILE ? false : Expression::IsFoldable(); + return function.GetStability() == FunctionStability::VOLATILE ? false : Expression::IsFoldable(); } bool BoundFunctionExpression::CanThrow() const { - if (function.errors == FunctionErrors::CAN_THROW_RUNTIME_ERROR) { + if (function.GetErrorMode() == FunctionErrors::CAN_THROW_RUNTIME_ERROR) { return true; } return Expression::CanThrow(); @@ -54,8 +54,8 @@ string BoundFunctionExpression::ToString() const { is_operator); } bool BoundFunctionExpression::PropagatesNullValues() const { - return function.null_handling == FunctionNullHandling::SPECIAL_HANDLING ? false - : Expression::PropagatesNullValues(); + return function.GetNullHandling() == FunctionNullHandling::SPECIAL_HANDLING ? 
false + : Expression::PropagatesNullValues(); } hash_t BoundFunctionExpression::Hash() const { @@ -112,7 +112,7 @@ unique_ptr BoundFunctionExpression::Deserialize(Deserializer &deseri auto entry = FunctionSerializer::Deserialize( deserializer, CatalogType::SCALAR_FUNCTION_ENTRY, children, return_type); - auto function_return_type = entry.first.return_type; + auto function_return_type = entry.first.GetReturnType(); auto is_operator = deserializer.ReadProperty(202, "is_operator"); diff --git a/src/duckdb/src/planner/expression_binder.cpp b/src/duckdb/src/planner/expression_binder.cpp index 5141765bb..220714733 100644 --- a/src/duckdb/src/planner/expression_binder.cpp +++ b/src/duckdb/src/planner/expression_binder.cpp @@ -103,7 +103,9 @@ BindResult ExpressionBinder::BindExpression(unique_ptr &expr, case ExpressionClass::STAR: return BindResult(BinderException::Unsupported(expr_ref, "STAR expression is not supported here")); default: - throw NotImplementedException("Unimplemented expression class"); + return BindResult( + NotImplementedException("Unimplemented expression class in ExpressionBinder::BindExpression: %s", + EnumUtil::ToString(expr_ref.GetExpressionClass()))); } } diff --git a/src/duckdb/src/planner/expression_binder/check_binder.cpp b/src/duckdb/src/planner/expression_binder/check_binder.cpp index c89c96ded..c6f1abb5a 100644 --- a/src/duckdb/src/planner/expression_binder/check_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/check_binder.cpp @@ -43,7 +43,6 @@ BindResult ExpressionBinder::BindQualifiedColumnName(ColumnRefExpression &colref } BindResult CheckBinder::BindCheckColumn(ColumnRefExpression &colref) { - if (!colref.IsQualified()) { if (lambda_bindings) { for (idx_t i = lambda_bindings->size(); i > 0; i--) { diff --git a/src/duckdb/src/planner/expression_binder/column_alias_binder.cpp b/src/duckdb/src/planner/expression_binder/column_alias_binder.cpp index c4477f9e3..d4324e8eb 100644 --- a/src/duckdb/src/planner/expression_binder/column_alias_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/column_alias_binder.cpp @@ -13,7 +13,6 @@ ColumnAliasBinder::ColumnAliasBinder(SelectBindState &bind_state) : bind_state(b bool ColumnAliasBinder::BindAlias(ExpressionBinder &enclosing_binder, unique_ptr &expr_ptr, idx_t depth, bool root_expression, BindResult &result) { - D_ASSERT(expr_ptr->GetExpressionClass() == ExpressionClass::COLUMN_REF); auto &expr = expr_ptr->Cast(); diff --git a/src/duckdb/src/planner/expression_binder/having_binder.cpp b/src/duckdb/src/planner/expression_binder/having_binder.cpp index 902add5e2..ff64acf63 100644 --- a/src/duckdb/src/planner/expression_binder/having_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/having_binder.cpp @@ -3,7 +3,6 @@ #include "duckdb/parser/expression/columnref_expression.hpp" #include "duckdb/parser/expression/window_expression.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/expression_binder/aggregate_binder.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" @@ -39,7 +38,6 @@ unique_ptr HavingBinder::QualifyColumnName(ColumnRefExpression } BindResult HavingBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - // Keep the original column name to return a meaningful error message. 
auto col_ref = expr_ptr->Cast(); const auto &column_name = col_ref.GetColumnName(); @@ -91,7 +89,7 @@ BindResult HavingBinder::BindColumnRef(unique_ptr &expr_ptr, i } BindResult HavingBinder::BindWindow(WindowExpression &expr, idx_t depth) { - return BindResult(BinderException::Unsupported(expr, "HAVING clause cannot contain window functions!")); + throw BinderException::Unsupported(expr, "HAVING clause cannot contain window functions!"); } } // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/lateral_binder.cpp b/src/duckdb/src/planner/expression_binder/lateral_binder.cpp index 205b644e8..0b693558a 100644 --- a/src/duckdb/src/planner/expression_binder/lateral_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/lateral_binder.cpp @@ -3,7 +3,7 @@ #include "duckdb/planner/logical_operator_visitor.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_subquery_expression.hpp" -#include "duckdb/planner/tableref/bound_joinref.hpp" +#include "duckdb/planner/operator/logical_dependent_join.hpp" namespace duckdb { @@ -17,7 +17,7 @@ void LateralBinder::ExtractCorrelatedColumns(Expression &expr) { // add the correlated column info CorrelatedColumnInfo info(bound_colref); if (std::find(correlated_columns.begin(), correlated_columns.end(), info) == correlated_columns.end()) { - correlated_columns.push_back(std::move(info)); + correlated_columns.AddColumn(std::move(info)); // TODO is adding to the front OK here? } } } @@ -54,8 +54,7 @@ string LateralBinder::UnsupportedAggregateMessage() { return "LATERAL join cannot contain aggregates!"; } -static void ReduceColumnRefDepth(BoundColumnRefExpression &expr, - const vector &correlated_columns) { +static void ReduceColumnRefDepth(BoundColumnRefExpression &expr, const CorrelatedColumns &correlated_columns) { // don't need to reduce this if (expr.depth == 0) { return; @@ -69,8 +68,7 @@ static void ReduceColumnRefDepth(BoundColumnRefExpression &expr, } } -static void ReduceColumnDepth(vector &columns, - const vector &affected_columns) { +static void ReduceColumnDepth(CorrelatedColumns &columns, const CorrelatedColumns &affected_columns) { for (auto &s_correlated : columns) { for (auto &affected : affected_columns) { if (affected == s_correlated) { @@ -81,45 +79,44 @@ static void ReduceColumnDepth(vector &columns, } } -class ExpressionDepthReducerRecursive : public BoundNodeVisitor { +class ExpressionDepthReducerRecursive : public LogicalOperatorVisitor { public: - explicit ExpressionDepthReducerRecursive(const vector &correlated) - : correlated_columns(correlated) { + explicit ExpressionDepthReducerRecursive(const CorrelatedColumns &correlated) : correlated_columns(correlated) { } - void VisitExpression(unique_ptr &expression) override { - if (expression->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { - ReduceColumnRefDepth(expression->Cast(), correlated_columns); - } else if (expression->GetExpressionType() == ExpressionType::SUBQUERY) { - ReduceExpressionSubquery(expression->Cast(), correlated_columns); + void VisitExpression(unique_ptr *expression) override { + auto &expr = **expression; + if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { + ReduceColumnRefDepth(expr.Cast(), correlated_columns); + } else if (expr.GetExpressionType() == ExpressionType::SUBQUERY) { + ReduceExpressionSubquery(expr.Cast(), correlated_columns); } - BoundNodeVisitor::VisitExpression(expression); + LogicalOperatorVisitor::VisitExpression(expression); } - void 
VisitBoundTableRef(BoundTableRef &ref) override { - if (ref.type == TableReferenceType::JOIN) { + void VisitOperator(LogicalOperator &op) override { + if (op.type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN) { // rewrite correlated columns in child joins - auto &bound_join = ref.Cast(); + auto &bound_join = op.Cast(); ReduceColumnDepth(bound_join.correlated_columns, correlated_columns); } // visit the children of the table ref - BoundNodeVisitor::VisitBoundTableRef(ref); + LogicalOperatorVisitor::VisitOperator(op); } - static void ReduceExpressionSubquery(BoundSubqueryExpression &expr, - const vector &correlated_columns) { + static void ReduceExpressionSubquery(BoundSubqueryExpression &expr, const CorrelatedColumns &correlated_columns) { ReduceColumnDepth(expr.binder->correlated_columns, correlated_columns); ExpressionDepthReducerRecursive recursive(correlated_columns); - recursive.VisitBoundQueryNode(*expr.subquery); + recursive.VisitOperator(*expr.subquery.plan); } private: - const vector &correlated_columns; + const CorrelatedColumns &correlated_columns; }; class ExpressionDepthReducer : public LogicalOperatorVisitor { public: - explicit ExpressionDepthReducer(const vector &correlated) : correlated_columns(correlated) { + explicit ExpressionDepthReducer(const CorrelatedColumns &correlated) : correlated_columns(correlated) { } protected: @@ -133,10 +130,10 @@ class ExpressionDepthReducer : public LogicalOperatorVisitor { return nullptr; } - const vector &correlated_columns; + const CorrelatedColumns &correlated_columns; }; -void LateralBinder::ReduceExpressionDepth(LogicalOperator &op, const vector &correlated) { +void LateralBinder::ReduceExpressionDepth(LogicalOperator &op, const CorrelatedColumns &correlated) { ExpressionDepthReducer depth_reducer(correlated); depth_reducer.VisitOperator(op); } diff --git a/src/duckdb/src/planner/expression_binder/table_function_binder.cpp b/src/duckdb/src/planner/expression_binder/table_function_binder.cpp index 198bd072b..720dbe37d 100644 --- a/src/duckdb/src/planner/expression_binder/table_function_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/table_function_binder.cpp @@ -27,9 +27,13 @@ BindResult TableFunctionBinder::BindColumnReference(unique_ptr if (lambda_ref) { return BindLambdaReference(lambda_ref->Cast(), depth); } + if (binder.macro_binding && binder.macro_binding->HasMatchingBinding(col_ref.GetName())) { throw ParameterNotResolvedException(); } + } else if (col_ref.column_names[0].find(DummyBinding::DUMMY_NAME) != string::npos && binder.macro_binding && + binder.macro_binding->HasMatchingBinding(col_ref.GetName())) { + throw ParameterNotResolvedException(); } auto query_location = col_ref.GetQueryLocation(); diff --git a/src/duckdb/src/planner/expression_binder/where_binder.cpp b/src/duckdb/src/planner/expression_binder/where_binder.cpp index 9b25c7930..90a0fe5c4 100644 --- a/src/duckdb/src/planner/expression_binder/where_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/where_binder.cpp @@ -10,7 +10,6 @@ WhereBinder::WhereBinder(Binder &binder, ClientContext &context, optional_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto result = ExpressionBinder::BindExpression(expr_ptr, depth); if (!result.HasError() || !column_alias_binder) { return result; diff --git a/src/duckdb/src/planner/expression_iterator.cpp b/src/duckdb/src/planner/expression_iterator.cpp index 042712732..3d1407900 100644 --- a/src/duckdb/src/planner/expression_iterator.cpp +++ b/src/duckdb/src/planner/expression_iterator.cpp @@ -4,8 +4,6 
@@ #include "duckdb/planner/expression/list.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/planner/query_node/bound_set_operation_node.hpp" -#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" -#include "duckdb/planner/query_node/bound_cte_node.hpp" #include "duckdb/planner/tableref/list.hpp" #include "duckdb/common/enum_util.hpp" @@ -183,156 +181,4 @@ void ExpressionIterator::VisitExpressionClassMutable( *expr, [&](unique_ptr &child) { VisitExpressionClassMutable(child, expr_class, callback); }); } -void BoundNodeVisitor::VisitExpression(unique_ptr &expression) { - VisitExpressionChildren(*expression); -} - -void BoundNodeVisitor::VisitExpressionChildren(Expression &expr) { - ExpressionIterator::EnumerateChildren(expr, [&](unique_ptr &expr) { VisitExpression(expr); }); -} - -void BoundNodeVisitor::VisitBoundQueryNode(BoundQueryNode &node) { - switch (node.type) { - case QueryNodeType::SET_OPERATION_NODE: { - auto &bound_setop = node.Cast(); - for (auto &child : bound_setop.bound_children) { - VisitBoundQueryNode(*child.node); - } - break; - } - case QueryNodeType::RECURSIVE_CTE_NODE: { - auto &cte_node = node.Cast(); - VisitBoundQueryNode(*cte_node.left); - VisitBoundQueryNode(*cte_node.right); - break; - } - case QueryNodeType::CTE_NODE: { - auto &cte_node = node.Cast(); - VisitBoundQueryNode(*cte_node.child); - VisitBoundQueryNode(*cte_node.query); - break; - } - case QueryNodeType::SELECT_NODE: { - auto &bound_select = node.Cast(); - for (auto &expr : bound_select.select_list) { - VisitExpression(expr); - } - if (bound_select.where_clause) { - VisitExpression(bound_select.where_clause); - } - for (auto &expr : bound_select.groups.group_expressions) { - VisitExpression(expr); - } - if (bound_select.having) { - VisitExpression(bound_select.having); - } - for (auto &expr : bound_select.aggregates) { - VisitExpression(expr); - } - for (auto &entry : bound_select.unnests) { - for (auto &expr : entry.second.expressions) { - VisitExpression(expr); - } - } - for (auto &expr : bound_select.windows) { - VisitExpression(expr); - } - if (bound_select.from_table) { - VisitBoundTableRef(*bound_select.from_table); - } - break; - } - default: - throw NotImplementedException("Unimplemented query node in ExpressionIterator"); - } - for (idx_t i = 0; i < node.modifiers.size(); i++) { - switch (node.modifiers[i]->type) { - case ResultModifierType::DISTINCT_MODIFIER: - for (auto &expr : node.modifiers[i]->Cast().target_distincts) { - VisitExpression(expr); - } - break; - case ResultModifierType::ORDER_MODIFIER: - for (auto &order : node.modifiers[i]->Cast().orders) { - VisitExpression(order.expression); - } - break; - case ResultModifierType::LIMIT_MODIFIER: { - auto &limit_expr = node.modifiers[i]->Cast().limit_val.GetExpression(); - auto &offset_expr = node.modifiers[i]->Cast().offset_val.GetExpression(); - if (limit_expr) { - VisitExpression(limit_expr); - } - if (offset_expr) { - VisitExpression(offset_expr); - } - break; - } - default: - break; - } - } -} - -class LogicalBoundNodeVisitor : public LogicalOperatorVisitor { -public: - explicit LogicalBoundNodeVisitor(BoundNodeVisitor &parent) : parent(parent) { - } - - void VisitExpression(unique_ptr *expression) override { - auto &expr = **expression; - parent.VisitExpression(*expression); - VisitExpressionChildren(expr); - } - -protected: - BoundNodeVisitor &parent; -}; - -void BoundNodeVisitor::VisitBoundTableRef(BoundTableRef &ref) { - switch (ref.type) { - case TableReferenceType::EXPRESSION_LIST: { - 
auto &bound_expr_list = ref.Cast(); - for (auto &expr_list : bound_expr_list.values) { - for (auto &expr : expr_list) { - VisitExpression(expr); - } - } - break; - } - case TableReferenceType::JOIN: { - auto &bound_join = ref.Cast(); - if (bound_join.condition) { - VisitExpression(bound_join.condition); - } - VisitBoundTableRef(*bound_join.left); - VisitBoundTableRef(*bound_join.right); - break; - } - case TableReferenceType::SUBQUERY: { - auto &bound_subquery = ref.Cast(); - VisitBoundQueryNode(*bound_subquery.subquery); - break; - } - case TableReferenceType::TABLE_FUNCTION: { - auto &bound_table_function = ref.Cast(); - LogicalBoundNodeVisitor node_visitor(*this); - if (bound_table_function.get) { - node_visitor.VisitOperator(*bound_table_function.get); - } - if (bound_table_function.subquery) { - VisitBoundTableRef(*bound_table_function.subquery); - } - break; - } - case TableReferenceType::EMPTY_FROM: - case TableReferenceType::BASE_TABLE: - case TableReferenceType::CTE: - break; - default: - throw NotImplementedException("Unimplemented table reference type (%s) in ExpressionIterator", - EnumUtil::ToString(ref.type)); - } -} - } // namespace duckdb diff --git a/src/duckdb/src/planner/filter/constant_filter.cpp b/src/duckdb/src/planner/filter/constant_filter.cpp index 5e1f39991..be43a4b0c 100644 --- a/src/duckdb/src/planner/filter/constant_filter.cpp +++ b/src/duckdb/src/planner/filter/constant_filter.cpp @@ -57,7 +57,13 @@ FilterPropagateResult ConstantFilter::CheckStatistics(BaseStatistics &stats) con result = NumericStats::CheckZonemap(stats, comparison_type, array_ptr(&constant, 1)); break; case PhysicalType::VARCHAR: - result = StringStats::CheckZonemap(stats, comparison_type, array_ptr(&constant, 1)); + switch (stats.GetStatsType()) { + case StatisticsType::STRING_STATS: + result = StringStats::CheckZonemap(stats, comparison_type, array_ptr(&constant, 1)); + break; + default: + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } break; default: return FilterPropagateResult::NO_PRUNING_POSSIBLE; diff --git a/src/duckdb/src/planner/filter/expression_filter.cpp b/src/duckdb/src/planner/filter/expression_filter.cpp index 8e9b3299f..a86433f00 100644 --- a/src/duckdb/src/planner/filter/expression_filter.cpp +++ b/src/duckdb/src/planner/filter/expression_filter.cpp @@ -27,6 +27,11 @@ bool ExpressionFilter::EvaluateWithConstant(ExpressionExecutor &executor, const } FilterPropagateResult ExpressionFilter::CheckStatistics(BaseStatistics &stats) const { + if (stats.GetStatsType() == StatisticsType::GEOMETRY_STATS) { + // Delegate to GeometryStats for geometry types + return GeometryStats::CheckZonemap(stats, expr); + } + // we cannot prune based on arbitrary expressions currently return FilterPropagateResult::NO_PRUNING_POSSIBLE; } diff --git a/src/duckdb/src/planner/logical_operator.cpp b/src/duckdb/src/planner/logical_operator.cpp index e16062573..016b7d605 100644 --- a/src/duckdb/src/planner/logical_operator.cpp +++ b/src/duckdb/src/planner/logical_operator.cpp @@ -31,6 +31,19 @@ vector LogicalOperator::GetColumnBindings() { return {ColumnBinding(0, 0)}; } +idx_t LogicalOperator::GetRootIndex() { + auto bindings = GetColumnBindings(); + if (bindings.empty()) { + throw InternalException("Empty bindings in GetRootIndex"); + } + auto root_index = bindings[0].table_index; + for (idx_t i = 1; i < bindings.size(); i++) { + if (bindings[i].table_index != root_index) { + throw InternalException("GetRootIndex - multiple column bindings found"); + } + } + return root_index; +} void 
LogicalOperator::SetParamsEstimatedCardinality(InsertionOrderPreservingMap &result) const { if (has_estimated_cardinality) { result[RenderTreeNode::ESTIMATED_CARDINALITY] = StringUtil::Format("%llu", estimated_cardinality); diff --git a/src/duckdb/src/planner/logical_operator_visitor.cpp b/src/duckdb/src/planner/logical_operator_visitor.cpp index 5e96a5bbb..b7723d640 100644 --- a/src/duckdb/src/planner/logical_operator_visitor.cpp +++ b/src/duckdb/src/planner/logical_operator_visitor.cpp @@ -85,7 +85,6 @@ void LogicalOperatorVisitor::VisitChildOfOperatorWithProjectionMap(LogicalOperat void LogicalOperatorVisitor::EnumerateExpressions(LogicalOperator &op, const std::function *child)> &callback) { - switch (op.type) { case LogicalOperatorType::LOGICAL_EXPRESSION_GET: { auto &get = op.Cast(); diff --git a/src/duckdb/src/planner/operator/logical_copy_to_file.cpp b/src/duckdb/src/planner/operator/logical_copy_to_file.cpp index beee4f121..f15245c5d 100644 --- a/src/duckdb/src/planner/operator/logical_copy_to_file.cpp +++ b/src/duckdb/src/planner/operator/logical_copy_to_file.cpp @@ -126,7 +126,7 @@ unique_ptr LogicalCopyToFile::Deserialize(Deserializer &deseria throw InternalException("Copy function \"%s\" has neither bind nor (de)serialize", function.name); } - CopyFunctionBindInput function_bind_input(*copy_info); + CopyFunctionBindInput function_bind_input(*copy_info, function.function_info); auto names_to_write = GetNamesWithoutPartitions(names, partition_columns, write_partition_columns); auto types_to_write = GetTypesWithoutPartitions(expected_types, partition_columns, write_partition_columns); bind_data = function.copy_to_bind(context, function_bind_input, names_to_write, types_to_write); diff --git a/src/duckdb/src/planner/operator/logical_create_index.cpp b/src/duckdb/src/planner/operator/logical_create_index.cpp index e1bc0f0ee..44dcab583 100644 --- a/src/duckdb/src/planner/operator/logical_create_index.cpp +++ b/src/duckdb/src/planner/operator/logical_create_index.cpp @@ -10,7 +10,6 @@ LogicalCreateIndex::LogicalCreateIndex(unique_ptr info_p, vecto TableCatalogEntry &table_p, unique_ptr alter_table_info) : LogicalOperator(LogicalOperatorType::LOGICAL_CREATE_INDEX), info(std::move(info_p)), table(table_p), alter_table_info(std::move(alter_table_info)) { - for (auto &expr : expressions_p) { unbound_expressions.push_back(expr->Copy()); } @@ -27,7 +26,6 @@ LogicalCreateIndex::LogicalCreateIndex(ClientContext &context, unique_ptr(std::move(info_p))), table(BindTable(context, *info)), alter_table_info(unique_ptr_cast(std::move(alter_table_info))) { - for (auto &expr : expressions_p) { unbound_expressions.push_back(expr->Copy()); } diff --git a/src/duckdb/src/planner/operator/logical_dependent_join.cpp b/src/duckdb/src/planner/operator/logical_dependent_join.cpp index 2e46dbc78..70af8444a 100644 --- a/src/duckdb/src/planner/operator/logical_dependent_join.cpp +++ b/src/duckdb/src/planner/operator/logical_dependent_join.cpp @@ -3,7 +3,7 @@ namespace duckdb { LogicalDependentJoin::LogicalDependentJoin(unique_ptr left, unique_ptr right, - vector correlated_columns, JoinType type, + CorrelatedColumns correlated_columns, JoinType type, unique_ptr condition) : LogicalComparisonJoin(type, LogicalOperatorType::LOGICAL_DEPENDENT_JOIN), join_condition(std::move(condition)), correlated_columns(std::move(correlated_columns)) { @@ -17,7 +17,7 @@ LogicalDependentJoin::LogicalDependentJoin(JoinType join_type) unique_ptr LogicalDependentJoin::Create(unique_ptr left, unique_ptr right, - vector 
correlated_columns, JoinType type, + CorrelatedColumns correlated_columns, JoinType type, unique_ptr condition) { return make_uniq(std::move(left), std::move(right), std::move(correlated_columns), type, std::move(condition)); diff --git a/src/duckdb/src/planner/operator/logical_empty_result.cpp b/src/duckdb/src/planner/operator/logical_empty_result.cpp index 12c1653b3..b745228b5 100644 --- a/src/duckdb/src/planner/operator/logical_empty_result.cpp +++ b/src/duckdb/src/planner/operator/logical_empty_result.cpp @@ -4,7 +4,6 @@ namespace duckdb { LogicalEmptyResult::LogicalEmptyResult(unique_ptr op) : LogicalOperator(LogicalOperatorType::LOGICAL_EMPTY_RESULT) { - this->bindings = op->GetColumnBindings(); op->ResolveOperatorTypes(); diff --git a/src/duckdb/src/planner/operator/logical_vacuum.cpp b/src/duckdb/src/planner/operator/logical_vacuum.cpp index 36352a0ea..ce4a76951 100644 --- a/src/duckdb/src/planner/operator/logical_vacuum.cpp +++ b/src/duckdb/src/planner/operator/logical_vacuum.cpp @@ -1,5 +1,5 @@ #include "duckdb/planner/operator/logical_vacuum.hpp" - +#include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" @@ -46,11 +46,14 @@ unique_ptr LogicalVacuum::Deserialize(Deserializer &deserialize auto &context = deserializer.Get(); auto binder = Binder::CreateBinder(context); auto bound_table = binder->Bind(*info.ref); - if (bound_table->type != TableReferenceType::BASE_TABLE) { + if (bound_table.plan->type != LogicalOperatorType::LOGICAL_GET) { + throw InvalidInputException("can only vacuum or analyze base tables"); + } + auto table_ptr = bound_table.plan->Cast().GetTable(); + if (!table_ptr) { throw InvalidInputException("can only vacuum or analyze base tables"); } - auto ref = unique_ptr_cast(std::move(bound_table)); - auto &table = ref->table; + auto &table = *table_ptr; result->SetTable(table); // FIXME: we should probably verify that the 'column_id_map' and 'columns' are the same on the bound table after // deserialization? 
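The logical_vacuum.cpp hunk above swaps the old BoundTableRef / BASE_TABLE check for an inspection of the bound logical plan. Below is a minimal illustrative sketch of that pattern, assuming only the shapes visible in the hunk (Binder::Bind on a TableRef yielding a result whose plan member is the bound operator tree, and LogicalGet::GetTable() returning a possibly-null catalog table); the helper name ResolveVacuumTable and the chosen includes are assumptions of this sketch, not part of the patch.

#include "duckdb/planner/binder.hpp"
#include "duckdb/planner/operator/logical_get.hpp"
#include "duckdb/parser/tableref.hpp"
#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp"
#include "duckdb/common/exception.hpp"

namespace duckdb {

// Resolve the base table targeted by a VACUUM/ANALYZE ref, mirroring the new
// Deserialize logic: bind the ref, require a LOGICAL_GET at the plan root, and
// require that the get is backed by a catalog table (table-function scans are
// LOGICAL_GET operators without one).
static TableCatalogEntry &ResolveVacuumTable(Binder &binder, TableRef &ref) {
	auto bound = binder.Bind(ref);
	if (bound.plan->type != LogicalOperatorType::LOGICAL_GET) {
		throw InvalidInputException("can only vacuum or analyze base tables");
	}
	auto table = bound.plan->Cast<LogicalGet>().GetTable();
	if (!table) {
		throw InvalidInputException("can only vacuum or analyze base tables");
	}
	return *table;
}

} // namespace duckdb

Checking the plan root rather than the parse-time ref type keeps the validation intact now that binding a table reference produces a logical plan directly instead of a BoundTableRef.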
diff --git a/src/duckdb/src/planner/planner.cpp b/src/duckdb/src/planner/planner.cpp index 78bca8a02..a38bc2a6c 100644 --- a/src/duckdb/src/planner/planner.cpp +++ b/src/duckdb/src/planner/planner.cpp @@ -41,7 +41,7 @@ void Planner::CreatePlan(SQLStatement &statement) { bool parameters_resolved = true; try { profiler.StartPhase(MetricsType::PLANNER_BINDING); - binder->parameters = &bound_parameters; + binder->SetParameters(bound_parameters); auto bound_statement = binder->Bind(statement); profiler.EndPhase(); diff --git a/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp b/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp index 7b2909c6d..348d44997 100644 --- a/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp +++ b/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp @@ -18,9 +18,8 @@ namespace duckdb { -FlattenDependentJoins::FlattenDependentJoins(Binder &binder, const vector &correlated, - bool perform_delim, bool any_join, - optional_ptr parent) +FlattenDependentJoins::FlattenDependentJoins(Binder &binder, const CorrelatedColumns &correlated, bool perform_delim, + bool any_join, optional_ptr parent) : binder(binder), delim_offset(DConstants::INVALID_INDEX), correlated_columns(correlated), perform_delim(perform_delim), any_join(any_join), parent(parent) { for (idx_t i = 0; i < correlated_columns.size(); i++) { @@ -30,8 +29,7 @@ FlattenDependentJoins::FlattenDependentJoins(Binder &binder, const vector &correlated_columns, +static void CreateDelimJoinConditions(LogicalComparisonJoin &delim_join, const CorrelatedColumns &correlated_columns, vector bindings, idx_t base_offset, bool perform_delim) { auto col_count = perform_delim ? correlated_columns.size() : 1; for (idx_t i = 0; i < col_count; i++) { @@ -50,7 +48,7 @@ static void CreateDelimJoinConditions(LogicalComparisonJoin &delim_join, unique_ptr FlattenDependentJoins::DecorrelateIndependent(Binder &binder, unique_ptr plan) { - vector correlated; + CorrelatedColumns correlated; FlattenDependentJoins flatten(binder, correlated); return flatten.Decorrelate(std::move(plan)); } @@ -80,12 +78,12 @@ unique_ptr FlattenDependentJoins::Decorrelate(unique_ptrsecond = false; // rewrite - idx_t lateral_depth = 0; + idx_t next_lateral_depth = 0; - RewriteCorrelatedExpressions rewriter(base_binding, correlated_map, lateral_depth); + RewriteCorrelatedExpressions rewriter(base_binding, correlated_map, next_lateral_depth); rewriter.VisitOperator(*plan); - RewriteCorrelatedExpressions recursive_rewriter(base_binding, correlated_map, lateral_depth, true); + RewriteCorrelatedExpressions recursive_rewriter(base_binding, correlated_map, next_lateral_depth, true); recursive_rewriter.VisitOperator(*plan); } else { op.children[0] = Decorrelate(std::move(op.children[0])); @@ -94,8 +92,8 @@ unique_ptr FlattenDependentJoins::Decorrelate(unique_ptr(op.correlated_columns[0].binding.table_index); + const auto &op_col = op.correlated_columns[op.correlated_columns.GetDelimIndex()]; + auto window = make_uniq(op_col.binding.table_index); auto row_number = make_uniq(ExpressionType::WINDOW_ROW_NUMBER, LogicalType::BIGINT, nullptr, nullptr); row_number->start = WindowBoundary::UNBOUNDED_PRECEDING; @@ -114,9 +112,9 @@ unique_ptr FlattenDependentJoins::Decorrelate(unique_ptrchildren[1], op.is_lateral_join, lateral_depth); if (delim_join->children[1]->type == LogicalOperatorType::LOGICAL_MATERIALIZED_CTE) { - auto &cte = delim_join->children[1]->Cast(); + auto &cte_ref = delim_join->children[1]->Cast(); // check if the left side of the CTE has 
correlated expressions - auto entry = flatten.has_correlated_expressions.find(*cte.children[0]); + auto entry = flatten.has_correlated_expressions.find(*cte_ref.children[0]); if (entry != flatten.has_correlated_expressions.end()) { if (!entry->second) { // the left side of the CTE has no correlated expressions, we can push the DEPENDENT_JOIN down @@ -132,7 +130,7 @@ unique_ptr FlattenDependentJoins::Decorrelate(unique_ptrchildren[1] = flatten.PushDownDependentJoin(std::move(delim_join->children[1]), propagate_null_values, lateral_depth); data_offset = flatten.data_offset; - auto left_offset = delim_join->children[0]->GetColumnBindings().size(); + const auto left_offset = delim_join->children[0]->GetColumnBindings().size(); if (!parent) { delim_offset = left_offset + flatten.delim_offset; } @@ -214,7 +212,6 @@ unique_ptr FlattenDependentJoins::Decorrelate(unique_ptr FlattenDependentJoins::PushDownDependentJoinInternal } case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: { - #ifdef DEBUG plan->children[0]->ResolveOperatorTypes(); plan->children[1]->ResolveOperatorTypes(); diff --git a/src/duckdb/src/planner/subquery/has_correlated_expressions.cpp b/src/duckdb/src/planner/subquery/has_correlated_expressions.cpp index 9f1c679a1..8554f3f5b 100644 --- a/src/duckdb/src/planner/subquery/has_correlated_expressions.cpp +++ b/src/duckdb/src/planner/subquery/has_correlated_expressions.cpp @@ -7,7 +7,7 @@ namespace duckdb { -HasCorrelatedExpressions::HasCorrelatedExpressions(const vector &correlated, bool lateral, +HasCorrelatedExpressions::HasCorrelatedExpressions(const CorrelatedColumns &correlated, bool lateral, idx_t lateral_depth) : has_correlated_expressions(false), lateral(lateral), correlated_columns(correlated), lateral_depth(lateral_depth) { diff --git a/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp b/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp index 903840dda..10004d8f7 100644 --- a/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp +++ b/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp @@ -9,7 +9,6 @@ #include "duckdb/planner/expression_iterator.hpp" #include "duckdb/planner/tableref/bound_joinref.hpp" #include "duckdb/planner/operator/logical_dependent_join.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" namespace duckdb { @@ -71,14 +70,14 @@ unique_ptr RewriteCorrelatedExpressions::VisitReplace(BoundColumnRef } //! Helper class used to recursively rewrite correlated expressions within nested subqueries. 
-class RewriteCorrelatedRecursive : public BoundNodeVisitor { +class RewriteCorrelatedRecursive : public LogicalOperatorVisitor { public: RewriteCorrelatedRecursive(ColumnBinding base_binding, column_binding_map_t &correlated_map); - void VisitBoundTableRef(BoundTableRef &ref) override; - void VisitExpression(unique_ptr &expression) override; + void VisitOperator(LogicalOperator &op) override; + void VisitExpression(unique_ptr *expression) override; - void RewriteCorrelatedSubquery(Binder &binder, BoundQueryNode &subquery); + void RewriteCorrelatedSubquery(Binder &binder, LogicalOperator &subquery); ColumnBinding base_binding; column_binding_map_t &correlated_map; @@ -92,7 +91,7 @@ unique_ptr RewriteCorrelatedExpressions::VisitReplace(BoundSubqueryE // subquery detected within this subquery // recursively rewrite it using the RewriteCorrelatedRecursive class RewriteCorrelatedRecursive rewrite(base_binding, correlated_map); - rewrite.RewriteCorrelatedSubquery(*expr.binder, *expr.subquery); + rewrite.RewriteCorrelatedSubquery(*expr.binder, *expr.subquery.plan); return nullptr; } @@ -101,40 +100,30 @@ RewriteCorrelatedRecursive::RewriteCorrelatedRecursive(ColumnBinding base_bindin : base_binding(base_binding), correlated_map(correlated_map) { } -void RewriteCorrelatedRecursive::VisitBoundTableRef(BoundTableRef &ref) { - if (ref.type == TableReferenceType::JOIN) { +void RewriteCorrelatedRecursive::VisitOperator(LogicalOperator &op) { + if (op.type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN) { // rewrite correlated columns in child joins - auto &bound_join = ref.Cast(); - for (auto &corr : bound_join.correlated_columns) { + auto &dep_join = op.Cast(); + for (auto &corr : dep_join.correlated_columns) { auto entry = correlated_map.find(corr.binding); if (entry != correlated_map.end()) { corr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); } } - } else if (ref.type == TableReferenceType::SUBQUERY) { - auto &subquery = ref.Cast(); - RewriteCorrelatedSubquery(*subquery.binder, *subquery.subquery); - return; } // visit the children of the table ref - BoundNodeVisitor::VisitBoundTableRef(ref); + LogicalOperatorVisitor::VisitOperator(op); } -void RewriteCorrelatedRecursive::RewriteCorrelatedSubquery(Binder &binder, BoundQueryNode &subquery) { - // rewrite the binding in the correlated list of the subquery) - for (auto &corr : binder.correlated_columns) { - auto entry = correlated_map.find(corr.binding); - if (entry != correlated_map.end()) { - corr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); - } - } - VisitBoundQueryNode(subquery); +void RewriteCorrelatedRecursive::RewriteCorrelatedSubquery(Binder &binder, LogicalOperator &op) { + VisitOperator(op); } -void RewriteCorrelatedRecursive::VisitExpression(unique_ptr &expression) { - if (expression->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { +void RewriteCorrelatedRecursive::VisitExpression(unique_ptr *expression) { + auto &expr = **expression; + if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { // bound column reference - auto &bound_colref = expression->Cast(); + auto &bound_colref = expr.Cast(); if (bound_colref.depth == 0) { // not a correlated column, ignore return; @@ -148,13 +137,13 @@ void RewriteCorrelatedRecursive::VisitExpression(unique_ptr &express bound_colref.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); bound_colref.depth--; } - } else if 
(expression->GetExpressionType() == ExpressionType::SUBQUERY) { + } else if (expr.GetExpressionType() == ExpressionType::SUBQUERY) { // we encountered another subquery: rewrite recursively - auto &bound_subquery = expression->Cast(); - RewriteCorrelatedSubquery(*bound_subquery.binder, *bound_subquery.subquery); + auto &bound_subquery = expr.Cast(); + RewriteCorrelatedSubquery(*bound_subquery.binder, *bound_subquery.subquery.plan); } // recurse into the children of this subquery - BoundNodeVisitor::VisitExpression(expression); + LogicalOperatorVisitor::VisitExpression(expression); } RewriteCountAggregates::RewriteCountAggregates(column_binding_map_t &replacement_map) diff --git a/src/duckdb/src/planner/subquery/rewrite_cte_scan.cpp b/src/duckdb/src/planner/subquery/rewrite_cte_scan.cpp index 78b3b21ec..f846d9b36 100644 --- a/src/duckdb/src/planner/subquery/rewrite_cte_scan.cpp +++ b/src/duckdb/src/planner/subquery/rewrite_cte_scan.cpp @@ -14,7 +14,7 @@ namespace duckdb { -RewriteCTEScan::RewriteCTEScan(idx_t table_index, const vector &correlated_columns) +RewriteCTEScan::RewriteCTEScan(idx_t table_index, const CorrelatedColumns &correlated_columns) : table_index(table_index), correlated_columns(correlated_columns) { } @@ -49,7 +49,7 @@ void RewriteCTEScan::VisitOperator(LogicalOperator &op) { // The correlated columns must be placed at the beginning of the // correlated_columns list. Otherwise, further column accesses // and rewrites will fail. - join.correlated_columns.emplace(join.correlated_columns.begin(), corr); + join.correlated_columns.AddColumn(std::move(corr)); } } } diff --git a/src/duckdb/src/planner/table_binding.cpp b/src/duckdb/src/planner/table_binding.cpp index d9bdd71c7..c55d0be82 100644 --- a/src/duckdb/src/planner/table_binding.cpp +++ b/src/duckdb/src/planner/table_binding.cpp @@ -19,6 +19,10 @@ Binding::Binding(BindingType binding_type, BindingAlias alias_p, vector &Binding::GetColumnTypes() { + return types; +} + +const vector &Binding::GetColumnNames() { + return names; +} + +idx_t Binding::GetColumnCount() { + return GetColumnNames().size(); +} + +void Binding::SetColumnType(idx_t col_idx, LogicalType type_p) { + types[col_idx] = std::move(type_p); +} + string Binding::GetAlias() const { return alias.GetAlias(); } @@ -304,4 +336,42 @@ unique_ptr DummyBinding::ParamToArg(ColumnRefExpression &colre return arg; } +CTEBinding::CTEBinding(BindingAlias alias, vector types, vector names, idx_t index, + CTEType cte_type) + : Binding(BindingType::CTE, std::move(alias), std::move(types), std::move(names), index), cte_type(cte_type), + reference_count(0) { +} + +CTEBinding::CTEBinding(BindingAlias alias_p, shared_ptr bind_state_p, idx_t index) + : Binding(BindingType::CTE, std::move(alias_p), vector(), vector(), index), + cte_type(CTEType::CAN_BE_REFERENCED), reference_count(0), bind_state(std::move(bind_state_p)) { +} + +bool CTEBinding::CanBeReferenced() const { + return cte_type == CTEType::CAN_BE_REFERENCED; +} + +bool CTEBinding::IsReferenced() const { + return reference_count > 0; +} + +void CTEBinding::Reference() { + if (!CanBeReferenced()) { + throw InternalException("CTE cannot be referenced!"); + } + if (bind_state) { + // we have not bound the CTE yet - bind it + bind_state->Bind(*this); + + // copy over the names / types and initialize the binding + this->names = bind_state->names; + this->types = bind_state->types; + Initialize(); + + // finalize binding + bind_state.reset(); + } + reference_count++; +} + } // namespace duckdb diff --git 
a/src/duckdb/src/storage/buffer/block_handle.cpp b/src/duckdb/src/storage/buffer/block_handle.cpp index dd9671af2..91c98d4b4 100644 --- a/src/duckdb/src/storage/buffer/block_handle.cpp +++ b/src/duckdb/src/storage/buffer/block_handle.cpp @@ -60,7 +60,6 @@ BlockHandle::~BlockHandle() { // NOLINT: allow internal exceptions unique_ptr AllocateBlock(BlockManager &block_manager, unique_ptr reusable_buffer, block_id_t block_id) { - if (reusable_buffer && reusable_buffer->GetHeaderSize() == block_manager.GetBlockHeaderSize()) { // re-usable buffer: re-use it if (reusable_buffer->GetBufferType() == FileBufferType::BLOCK) { diff --git a/src/duckdb/src/storage/buffer/block_manager.cpp b/src/duckdb/src/storage/buffer/block_manager.cpp index 47fef0ebf..3f2064a5c 100644 --- a/src/duckdb/src/storage/buffer/block_manager.cpp +++ b/src/duckdb/src/storage/buffer/block_manager.cpp @@ -34,19 +34,32 @@ shared_ptr BlockManager::RegisterBlock(block_id_t block_id) { } shared_ptr BlockManager::ConvertToPersistent(QueryContext context, block_id_t block_id, - shared_ptr old_block, BufferHandle old_handle) { + shared_ptr old_block, BufferHandle old_handle, + ConvertToPersistentMode mode) { // register a block with the new block id auto new_block = RegisterBlock(block_id); D_ASSERT(new_block->GetState() == BlockState::BLOCK_UNLOADED); D_ASSERT(new_block->Readers() == 0); + if (mode == ConvertToPersistentMode::THREAD_SAFE) { + // safe mode - create a copy of the old block and operate on that + // this ensures we don't modify the old block - which allows other concurrent operations on the old block to + // continue + auto old_block_copy = buffer_manager.AllocateMemory(old_block->GetMemoryTag(), this, false); + auto copy_pin = buffer_manager.Pin(old_block_copy); + memcpy(copy_pin.Ptr(), old_handle.Ptr(), GetBlockSize()); + old_block = std::move(old_block_copy); + old_handle = std::move(copy_pin); + } + auto lock = old_block->GetLock(); D_ASSERT(old_block->GetState() == BlockState::BLOCK_LOADED); D_ASSERT(old_block->GetBuffer(lock)); if (old_block->Readers() > 1) { - throw InternalException("BlockManager::ConvertToPersistent - cannot be called for block %d as old_block has " - "multiple readers active", - block_id); + throw InternalException( + "BlockManager::ConvertToPersistent in destructive mode - cannot be called for block %d as old_block has " + "multiple readers active", + block_id); } // Temp buffers can be larger than the storage block size. 
@@ -76,10 +89,11 @@ shared_ptr BlockManager::ConvertToPersistent(QueryContext context, } shared_ptr BlockManager::ConvertToPersistent(QueryContext context, block_id_t block_id, - shared_ptr old_block) { + shared_ptr old_block, + ConvertToPersistentMode mode) { // pin the old block to ensure we have it loaded in memory auto handle = buffer_manager.Pin(old_block); - return ConvertToPersistent(context, block_id, std::move(old_block), std::move(handle)); + return ConvertToPersistent(context, block_id, std::move(old_block), std::move(handle), mode); } void BlockManager::UnregisterBlock(block_id_t id) { diff --git a/src/duckdb/src/storage/caching_file_system.cpp b/src/duckdb/src/storage/caching_file_system.cpp index 3de905228..36a33ccb9 100644 --- a/src/duckdb/src/storage/caching_file_system.cpp +++ b/src/duckdb/src/storage/caching_file_system.cpp @@ -41,6 +41,9 @@ CachingFileHandle::CachingFileHandle(QueryContext context, CachingFileSystem &ca const auto &open_options = path.extended_info->options; const auto validate_entry = open_options.find("validate_external_file_cache"); if (validate_entry != open_options.end()) { + if (validate_entry->second.IsNull()) { + throw InvalidInputException("Cannot use NULL as argument for validate_external_file_cache"); + } validate = BooleanValue::Get(validate_entry->second); } } @@ -79,6 +82,21 @@ FileHandle &CachingFileHandle::GetFileHandle() { return *file_handle; } +static bool ShouldExpandToFillGap(const idx_t current_length, const idx_t added_length) { + const idx_t MAX_BOUND_TO_BE_ADDED_LENGTH = 1048576; + + if (added_length > MAX_BOUND_TO_BE_ADDED_LENGTH) { + // Absolute value of what would be needed to added is too high + return false; + } + if (added_length > current_length) { + // Relative value of what would be needed to added is too high + return false; + } + + return true; +} + BufferHandle CachingFileHandle::Read(data_ptr_t &buffer, const idx_t nr_bytes, const idx_t location) { BufferHandle result; if (!external_file_cache.IsEnabled()) { @@ -90,30 +108,42 @@ BufferHandle CachingFileHandle::Read(data_ptr_t &buffer, const idx_t nr_bytes, c // Try to read from the cache, filling overlapping_ranges in the process vector> overlapping_ranges; - result = TryReadFromCache(buffer, nr_bytes, location, overlapping_ranges); + optional_idx start_location_of_next_range; + result = TryReadFromCache(buffer, nr_bytes, location, overlapping_ranges, start_location_of_next_range); if (result.IsValid()) { return result; // Success } + idx_t new_nr_bytes = nr_bytes; + if (start_location_of_next_range.IsValid()) { + const idx_t nr_bytes_to_be_added = start_location_of_next_range.GetIndex() - location - nr_bytes; + if (ShouldExpandToFillGap(nr_bytes, nr_bytes_to_be_added)) { + // Grow the range from location to start_location_of_next_range, so that to fill gaps in the cached ranges + new_nr_bytes = nr_bytes + nr_bytes_to_be_added; + } + } + // Finally, if we weren't able to find the file range in the cache, we have to create a new file range - result = external_file_cache.GetBufferManager().Allocate(MemoryTag::EXTERNAL_FILE_CACHE, nr_bytes); - auto new_file_range = make_shared_ptr(result.GetBlockHandle(), nr_bytes, location, version_tag); + result = external_file_cache.GetBufferManager().Allocate(MemoryTag::EXTERNAL_FILE_CACHE, new_nr_bytes); + auto new_file_range = + make_shared_ptr(result.GetBlockHandle(), new_nr_bytes, location, version_tag); buffer = result.Ptr(); // Interleave reading and copying from cached buffers if (OnDiskFile()) { // On-disk file: prefer 
interleaving reading and copying from cached buffers - ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, nr_bytes, location, true); + ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, new_nr_bytes, location, true); } else { - // Remote file: prefer interleaving reading and copying from cached buffers only if reduces number of real reads - if (ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, nr_bytes, location, false) <= 1) { - ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, nr_bytes, location, true); + // Remote file: prefer interleaving reading and copying from cached buffers only if reduces number of real + // reads + if (ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, new_nr_bytes, location, false) <= 1) { + ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, new_nr_bytes, location, true); } else { - GetFileHandle().Read(context, buffer, nr_bytes, location); + GetFileHandle().Read(context, buffer, new_nr_bytes, location); } } - return TryInsertFileRange(result, buffer, nr_bytes, location, new_file_range); + return TryInsertFileRange(result, buffer, new_nr_bytes, location, new_file_range); } BufferHandle CachingFileHandle::Read(data_ptr_t &buffer, idx_t &nr_bytes) { @@ -131,7 +161,12 @@ BufferHandle CachingFileHandle::Read(data_ptr_t &buffer, idx_t &nr_bytes) { // Try to read from the cache first vector> overlapping_ranges; - result = TryReadFromCache(buffer, nr_bytes, position, overlapping_ranges); + { + optional_idx start_location_of_next_range; + result = TryReadFromCache(buffer, nr_bytes, position, overlapping_ranges, start_location_of_next_range); + // start_location_of_next_range is in this case discarded + } + if (result.IsValid()) { position += nr_bytes; return result; // Success @@ -214,7 +249,8 @@ const string &CachingFileHandle::GetVersionTag(const unique_ptr } BufferHandle CachingFileHandle::TryReadFromCache(data_ptr_t &buffer, idx_t nr_bytes, idx_t location, - vector> &overlapping_ranges) { + vector> &overlapping_ranges, + optional_idx &start_location_of_next_range) { BufferHandle result; // Get read lock for cached ranges @@ -246,7 +282,8 @@ BufferHandle CachingFileHandle::TryReadFromCache(data_ptr_t &buffer, idx_t nr_by } while (it != ranges.end()) { if (it->second->location >= this_end) { - // We're past the requested location + // We're past the requested location, we are going to bail out, save start_location_of_next_range + start_location_of_next_range = it->second->location; break; } // Check if the cached range overlaps the requested one diff --git a/src/duckdb/src/storage/checkpoint/table_data_writer.cpp b/src/duckdb/src/storage/checkpoint/table_data_writer.cpp index 342ea1ff5..48c85a066 100644 --- a/src/duckdb/src/storage/checkpoint/table_data_writer.cpp +++ b/src/duckdb/src/storage/checkpoint/table_data_writer.cpp @@ -4,6 +4,7 @@ #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/common/serializer/binary_serializer.hpp" #include "duckdb/main/database.hpp" +#include "duckdb/main/settings.hpp" #include "duckdb/parallel/task_scheduler.hpp" #include "duckdb/storage/table/column_checkpoint_state.hpp" #include "duckdb/storage/table/table_statistics.hpp" @@ -119,15 +120,16 @@ void SingleFileTableDataWriter::FinalizeTable(const TableStatistics &global_stat } auto index_storage_infos = info.GetIndexes().SerializeToDisk(context, options); -#ifdef DUCKDB_BLOCK_VERIFICATION - for (auto &entry : index_storage_infos) { - for (auto &allocator : 
entry.allocator_infos) { - for (auto &block : allocator.block_pointers) { - checkpoint_manager.verify_block_usage_count[block.block_id]++; + auto debug_verify_blocks = DBConfig::GetSetting(GetDatabase()); + if (debug_verify_blocks) { + for (auto &entry : index_storage_infos) { + for (auto &allocator : entry.allocator_infos) { + for (auto &block : allocator.block_pointers) { + checkpoint_manager.verify_block_usage_count[block.block_id]++; + } } } } -#endif // write empty block pointers for forwards compatibility vector compat_block_pointers; diff --git a/src/duckdb/src/storage/checkpoint_manager.cpp b/src/duckdb/src/storage/checkpoint_manager.cpp index af361d4bf..6096c0a0e 100644 --- a/src/duckdb/src/storage/checkpoint_manager.cpp +++ b/src/duckdb/src/storage/checkpoint_manager.cpp @@ -21,7 +21,6 @@ #include "duckdb/parser/parsed_data/create_schema_info.hpp" #include "duckdb/parser/parsed_data/create_view_info.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/parsed_data/bound_create_table_info.hpp" #include "duckdb/storage/block_manager.hpp" #include "duckdb/storage/checkpoint/table_data_reader.hpp" @@ -214,33 +213,35 @@ void SingleFileCheckpointWriter::CreateCheckpoint() { header.vector_size = STANDARD_VECTOR_SIZE; block_manager.WriteHeader(context, header); -#ifdef DUCKDB_BLOCK_VERIFICATION - // extend verify_block_usage_count - auto metadata_info = storage_manager.GetMetadataInfo(); - for (auto &info : metadata_info) { - verify_block_usage_count[info.block_id]++; - } - for (auto &entry_ref : catalog_entries) { - auto &entry = entry_ref.get(); - if (entry.type == CatalogType::TABLE_ENTRY) { - auto &table = entry.Cast(); - auto &storage = table.GetStorage(); - auto segment_info = storage.GetColumnSegmentInfo(); - for (auto &segment : segment_info) { - verify_block_usage_count[segment.block_id]++; - if (StringUtil::Contains(segment.segment_info, "Overflow String Block Ids: ")) { - auto overflow_blocks = StringUtil::Replace(segment.segment_info, "Overflow String Block Ids: ", ""); - auto splits = StringUtil::Split(overflow_blocks, ", "); - for (auto &split : splits) { - auto overflow_block_id = std::stoll(split); - verify_block_usage_count[overflow_block_id]++; + auto debug_verify_blocks = DBConfig::GetSetting(db.GetDatabase()); + if (debug_verify_blocks) { + // extend verify_block_usage_count + auto metadata_info = storage_manager.GetMetadataInfo(); + for (auto &info : metadata_info) { + verify_block_usage_count[info.block_id]++; + } + for (auto &entry_ref : catalog_entries) { + auto &entry = entry_ref.get(); + if (entry.type == CatalogType::TABLE_ENTRY) { + auto &table = entry.Cast(); + auto &storage = table.GetStorage(); + auto segment_info = storage.GetColumnSegmentInfo(context); + for (auto &segment : segment_info) { + verify_block_usage_count[segment.block_id]++; + if (StringUtil::Contains(segment.segment_info, "Overflow String Block Ids: ")) { + auto overflow_blocks = + StringUtil::Replace(segment.segment_info, "Overflow String Block Ids: ", ""); + auto splits = StringUtil::Split(overflow_blocks, ", "); + for (auto &split : splits) { + auto overflow_block_id = std::stoll(split); + verify_block_usage_count[overflow_block_id]++; + } } } } } + block_manager.VerifyBlocks(verify_block_usage_count); } - block_manager.VerifyBlocks(verify_block_usage_count); -#endif if (debug_checkpoint_abort == CheckpointAbort::DEBUG_ABORT_BEFORE_TRUNCATE) { throw FatalException("Checkpoint aborted before truncate because of PRAGMA 
checkpoint_abort flag"); @@ -566,7 +567,6 @@ void CheckpointReader::ReadTable(CatalogTransaction transaction, Deserializer &d void CheckpointReader::ReadTableData(CatalogTransaction transaction, Deserializer &deserializer, BoundCreateTableInfo &bound_info) { - // written in "SingleFileTableDataWriter::FinalizeTable" auto table_pointer = deserializer.ReadProperty(101, "table_pointer"); auto total_rows = deserializer.ReadProperty(102, "total_rows"); diff --git a/src/duckdb/src/storage/compression/bitpacking.cpp b/src/duckdb/src/storage/compression/bitpacking.cpp index fa1ffaeba..628ff19a9 100644 --- a/src/duckdb/src/storage/compression/bitpacking.cpp +++ b/src/duckdb/src/storage/compression/bitpacking.cpp @@ -19,6 +19,7 @@ namespace duckdb { +constexpr const idx_t BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; static constexpr const idx_t BITPACKING_METADATA_GROUP_SIZE = STANDARD_VECTOR_SIZE > 512 ? STANDARD_VECTOR_SIZE : 2048; BitpackingMode BitpackingModeFromString(const string &str) { @@ -341,8 +342,6 @@ unique_ptr BitpackingInitAnalyze(ColumnData &col_data, PhysicalTyp template bool BitpackingAnalyze(AnalyzeState &state, Vector &input, idx_t count) { - auto &analyze_state = state.Cast>(); - // We use BITPACKING_METADATA_GROUP_SIZE tuples, which can exceed the block size. // In that case, we disable bitpacking. // we are conservative here by multiplying by 2 @@ -351,6 +350,7 @@ bool BitpackingAnalyze(AnalyzeState &state, Vector &input, idx_t count) { return false; } + auto &analyze_state = state.Cast>(); UnifiedVectorFormat vdata; input.ToUnifiedFormat(count, vdata); @@ -629,9 +629,9 @@ static T DeltaDecode(T *data, T previous_value, const size_t size) { template ::type> struct BitpackingScanState : public SegmentScanState { public: - explicit BitpackingScanState(ColumnSegment &segment) : current_segment(segment) { + explicit BitpackingScanState(const QueryContext &context, ColumnSegment &segment) : current_segment(segment) { auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - handle = buffer_manager.Pin(segment.block); + handle = buffer_manager.Pin(context, segment.block); auto data_ptr = handle.Ptr(); // load offset to bitpacking widths pointer @@ -720,7 +720,6 @@ struct BitpackingScanState : public SegmentScanState { // This skips straight to the correct metadata group idx_t meta_groups_to_skip = (skip_count + current_group_offset) / BITPACKING_METADATA_GROUP_SIZE; if (meta_groups_to_skip) { - // bitpacking_metadata_ptr points to the next metadata: this means we need to advance the pointer by n-1 bitpacking_metadata_ptr -= (meta_groups_to_skip - 1) * sizeof(bitpacking_metadata_encoded_t); LoadNextGroup(); @@ -782,8 +781,8 @@ struct BitpackingScanState : public SegmentScanState { }; template -unique_ptr BitpackingInitScan(ColumnSegment &segment) { - auto result = make_uniq>(segment); +unique_ptr BitpackingInitScan(const QueryContext &context, ColumnSegment &segment) { + auto result = make_uniq>(context, segment); return std::move(result); } @@ -892,7 +891,7 @@ void BitpackingScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_c template void BitpackingFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { - BitpackingScanState scan_state(segment); + BitpackingScanState scan_state(state.context, segment); scan_state.Skip(segment, NumericCast(row_id)); D_ASSERT(scan_state.current_group_offset < BITPACKING_METADATA_GROUP_SIZE); @@ -956,10 +955,10 @@ void BitpackingSkip(ColumnSegment &segment, 
ColumnScanState &state, idx_t skip_c // GetSegmentInfo //===--------------------------------------------------------------------===// template -InsertionOrderPreservingMap BitpackingGetSegmentInfo(ColumnSegment &segment) { +InsertionOrderPreservingMap BitpackingGetSegmentInfo(QueryContext context, ColumnSegment &segment) { map counts; auto tuple_count = segment.count.load(); - BitpackingScanState scan_state(segment); + BitpackingScanState scan_state(context, segment); for (idx_t i = 0; i < tuple_count; i += BITPACKING_METADATA_GROUP_SIZE) { if (i) { scan_state.LoadNextGroup(); diff --git a/src/duckdb/src/storage/compression/bitpacking_hugeint.cpp b/src/duckdb/src/storage/compression/bitpacking_hugeint.cpp index 3c84c6ec6..c363c9280 100644 --- a/src/duckdb/src/storage/compression/bitpacking_hugeint.cpp +++ b/src/duckdb/src/storage/compression/bitpacking_hugeint.cpp @@ -119,7 +119,6 @@ static void UnpackDelta128(const uint32_t *__restrict in, uhugeint_t *__restrict static void PackSingle(const uhugeint_t in, uint32_t *__restrict &out, uint16_t delta, uint16_t shl, uhugeint_t mask) { if (delta + shl < 32) { - if (shl == 0) { out[0] = static_cast(in & mask); } else { @@ -127,7 +126,6 @@ static void PackSingle(const uhugeint_t in, uint32_t *__restrict &out, uint16_t } } else if (delta + shl >= 32 && delta + shl < 64) { - if (shl == 0) { out[0] = static_cast(in & mask); } else { @@ -141,7 +139,6 @@ static void PackSingle(const uhugeint_t in, uint32_t *__restrict &out, uint16_t } else if (delta + shl >= 64 && delta + shl < 96) { - if (shl == 0) { out[0] = static_cast(in & mask); } else { diff --git a/src/duckdb/src/storage/compression/dict_fsst.cpp b/src/duckdb/src/storage/compression/dict_fsst.cpp index c43567c52..18c5dac21 100644 --- a/src/duckdb/src/storage/compression/dict_fsst.cpp +++ b/src/duckdb/src/storage/compression/dict_fsst.cpp @@ -56,7 +56,7 @@ struct DictFSSTCompressionStorage { static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); static void FinalizeCompress(CompressionState &state_p); - static unique_ptr StringInitScan(ColumnSegment &segment); + static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); template static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); @@ -111,12 +111,15 @@ void DictFSSTCompressionStorage::FinalizeCompress(CompressionState &state_p) { //===--------------------------------------------------------------------===// // Scan //===--------------------------------------------------------------------===// -unique_ptr DictFSSTCompressionStorage::StringInitScan(ColumnSegment &segment) { +unique_ptr DictFSSTCompressionStorage::StringInitScan(const QueryContext &context, + ColumnSegment &segment) { auto &buffer_manager = BufferManager::GetBufferManager(segment.db); auto state = make_uniq(segment, buffer_manager.Pin(segment.block)); state->Initialize(true); - if (StringStats::HasMaxStringLength(segment.stats.statistics)) { - state->all_values_inlined = StringStats::MaxStringLength(segment.stats.statistics) <= string_t::INLINE_LENGTH; + + const auto &stats = segment.stats.statistics; + if (stats.GetStatsType() == StatisticsType::STRING_STATS && StringStats::HasMaxStringLength(stats)) { + state->all_values_inlined = StringStats::MaxStringLength(stats) <= string_t::INLINE_LENGTH; } return std::move(state); } @@ -187,12 +190,13 @@ static void DictFSSTFilter(ColumnSegment &segment, ColumnScanState &state, idx_t 
scan_state.filter_result = make_unsafe_uniq_array(scan_state.dict_count); // apply the filter + auto &dict_data = scan_state.dictionary->data; UnifiedVectorFormat vdata; - scan_state.dictionary->ToUnifiedFormat(scan_state.dict_count, vdata); + dict_data.ToUnifiedFormat(scan_state.dict_count, vdata); SelectionVector dict_sel; idx_t filter_count = scan_state.dict_count; - ColumnSegment::FilterSelection(dict_sel, *scan_state.dictionary, vdata, filter, filter_state, - scan_state.dict_count, filter_count); + ColumnSegment::FilterSelection(dict_sel, dict_data, vdata, filter, filter_state, scan_state.dict_count, + filter_count); // now set all matching tuples to true for (idx_t i = 0; i < filter_count; i++) { @@ -217,8 +221,7 @@ static void DictFSSTFilter(ColumnSegment &segment, ColumnScanState &state, idx_t } sel_count = approved_tuple_count; - result.Dictionary(*(scan_state.dictionary), scan_state.dict_count, dict_sel, vector_count); - DictionaryVector::SetDictionaryId(result, to_string(CastPointerToValue(&segment))); + result.Dictionary(scan_state.dictionary, dict_sel); return; } // fallback: scan + filter diff --git a/src/duckdb/src/storage/compression/dict_fsst/compression.cpp b/src/duckdb/src/storage/compression/dict_fsst/compression.cpp index 580a5cfc5..9c2cdf85a 100644 --- a/src/duckdb/src/storage/compression/dict_fsst/compression.cpp +++ b/src/duckdb/src/storage/compression/dict_fsst/compression.cpp @@ -4,6 +4,10 @@ #include "fsst.h" #include "duckdb/common/fsst.hpp" +#if defined(__MVS__) && !defined(alloca) +#define alloca __builtin_alloca +#endif + namespace duckdb { namespace dict_fsst { @@ -11,6 +15,11 @@ DictFSSTCompressionState::DictFSSTCompressionState(ColumnDataCheckpointData &che unique_ptr &&analyze_p) : CompressionState(analyze_p->info), checkpoint_data(checkpoint_data_p), function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_DICT_FSST)), + current_string_map( + info.GetBlockManager().buffer_manager.GetBufferAllocator(), + MinValue(analyze_p.get()->total_count, info.GetBlockSize()) / 2, // maximum_size_p (amount of elements) + 1 // maximum_target_capacity_p (byte capacity) + ), analyze(std::move(analyze_p)) { CreateEmptySegment(checkpoint_data.GetRowGroup().start); } @@ -251,7 +260,7 @@ void DictFSSTCompressionState::CreateEmptySegment(idx_t row_start) { D_ASSERT(string_lengths.empty()); string_lengths.push_back(0); dict_count = 1; - D_ASSERT(current_string_map.empty()); + D_ASSERT(current_string_map.GetSize() == 0); symbol_table_size = DConstants::INVALID_INDEX; dictionary_offset = 0; @@ -280,11 +289,7 @@ void DictFSSTCompressionState::Flush(bool final) { D_ASSERT(dictionary_encoding_buffer.empty()); D_ASSERT(to_encode_string_sum == 0); - auto old_size = current_string_map.size(); - current_string_map.clear(); - if (!final) { - current_string_map.reserve(old_size); - } + current_string_map.Clear(); string_lengths.clear(); dictionary_indices.clear(); if (encoder) { @@ -444,7 +449,7 @@ static inline bool AddToDictionary(DictFSSTCompressionState &state, const string } state.to_encode_string_sum += str_len; auto &uncompressed_string = state.dictionary_encoding_buffer.back(); - state.current_string_map[uncompressed_string] = state.dict_count; + state.current_string_map.Insert(uncompressed_string); } else { state.string_lengths.push_back(str_len); auto baseptr = @@ -452,7 +457,7 @@ static inline bool AddToDictionary(DictFSSTCompressionState &state, const string memcpy(baseptr + state.dictionary_offset, str.GetData(), str_len); string_t 
dictionary_string((const char *)(baseptr + state.dictionary_offset), str_len); // NOLINT state.dictionary_offset += str_len; - state.current_string_map[dictionary_string] = state.dict_count; + state.current_string_map.Insert(dictionary_string); } state.dict_count++; @@ -490,8 +495,8 @@ bool DictFSSTCompressionState::CompressInternal(UnifiedVectorFormat &vector_form if (append_state == DictionaryAppendState::ENCODED_ALL_UNIQUE || is_null) { lookup = 0; } else { - auto it = current_string_map.find(str); - lookup = it == current_string_map.end() ? DConstants::INVALID_INDEX : it->second; + auto it = current_string_map.Lookup(str); + lookup = it.IsEmpty() ? DConstants::INVALID_INDEX : it.index + 1; } switch (append_state) { @@ -785,8 +790,7 @@ DictionaryAppendState DictFSSTCompressionState::TryEncode() { #endif // Rewrite the dictionary - current_string_map.clear(); - current_string_map.reserve(dict_count); + current_string_map.Clear(); if (new_state == DictionaryAppendState::ENCODED) { offset = 0; auto uncompressed_dictionary_ptr = dict_copy.GetData(); @@ -797,7 +801,7 @@ DictionaryAppendState DictFSSTCompressionState::TryEncode() { auto uncompressed_str_len = string_lengths[dictionary_index]; string_t dictionary_string(uncompressed_dictionary_ptr + offset, uncompressed_str_len); - current_string_map.insert({dictionary_string, dictionary_index}); + current_string_map.Insert(dictionary_string); #ifdef DEBUG //! Verify that we can decompress the string @@ -822,7 +826,7 @@ DictionaryAppendState DictFSSTCompressionState::TryEncode() { string_lengths[dictionary_index] = size; string_t dictionary_string((const char *)start, UnsafeNumericCast(size)); // NOLINT - current_string_map.insert({dictionary_string, dictionary_index}); + current_string_map.Insert(dictionary_string); } } dictionary_offset = new_size; diff --git a/src/duckdb/src/storage/compression/dict_fsst/decompression.cpp b/src/duckdb/src/storage/compression/dict_fsst/decompression.cpp index 0546096bb..f6befffe5 100644 --- a/src/duckdb/src/storage/compression/dict_fsst/decompression.cpp +++ b/src/duckdb/src/storage/compression/dict_fsst/decompression.cpp @@ -98,17 +98,18 @@ void CompressedStringScanState::Initialize(bool initialize_dictionary) { return; } - dictionary = make_buffer(segment.type, dict_count); - auto dict_child_data = FlatVector::GetData(*(dictionary)); - auto &validity = FlatVector::Validity(*dictionary); + dictionary = DictionaryVector::CreateReusableDictionary(segment.type, dict_count); + auto dict_child_data = FlatVector::GetData(dictionary->data); + auto &validity = FlatVector::Validity(dictionary->data); D_ASSERT(dict_count >= 1); validity.SetInvalid(0); + auto &dict_data = dictionary->data; uint32_t offset = 0; for (uint32_t i = 0; i < dict_count; i++) { //! 
We can uncompress during fetching, we need the length of the string inside the dictionary auto string_len = string_lengths[i]; - dict_child_data[i] = FetchStringFromDict(*dictionary, offset, i); + dict_child_data[i] = FetchStringFromDict(dict_data, offset, i); offset += string_len; } } @@ -158,7 +159,7 @@ void CompressedStringScanState::ScanToFlatVector(Vector &result, idx_t result_of if (dictionary) { // We have prepared the full dictionary, we can reference these strings directly - auto dictionary_values = FlatVector::GetData(*dictionary); + auto dictionary_values = FlatVector::GetData(dictionary->data); for (idx_t i = 0; i < scan_count; i++) { // Lookup dict offset in index buffer auto string_number = selvec.get_index(i + start_offset); @@ -223,8 +224,7 @@ void CompressedStringScanState::ScanToDictionaryVector(ColumnSegment &segment, V D_ASSERT(result_offset == 0); auto &selvec = GetSelVec(start, scan_count); - result.Dictionary(*(dictionary), dict_count, selvec, scan_count); - DictionaryVector::SetDictionaryId(result, to_string(CastPointerToValue(&segment))); + result.Dictionary(dictionary, selvec); result.Verify(result_offset + scan_count); } diff --git a/src/duckdb/src/storage/compression/dictionary/analyze.cpp b/src/duckdb/src/storage/compression/dictionary/analyze.cpp index 3d12bc2e1..538ad543c 100644 --- a/src/duckdb/src/storage/compression/dictionary/analyze.cpp +++ b/src/duckdb/src/storage/compression/dictionary/analyze.cpp @@ -44,10 +44,14 @@ bool DictionaryAnalyzeState::CalculateSpaceRequirements(bool new_string, idx_t s void DictionaryAnalyzeState::Flush(bool final) { segment_count++; current_tuple_count = 0; + max_unique_count_across_segments = MaxValue(max_unique_count_across_segments, current_unique_count); current_unique_count = 0; current_dict_size = 0; current_set.clear(); } +void DictionaryAnalyzeState::UpdateMaxUniqueCount() { + max_unique_count_across_segments = MaxValue(max_unique_count_across_segments, current_unique_count); +} void DictionaryAnalyzeState::Verify() { } diff --git a/src/duckdb/src/storage/compression/dictionary/compression.cpp b/src/duckdb/src/storage/compression/dictionary/compression.cpp index 48b02a42a..d1de02fa9 100644 --- a/src/duckdb/src/storage/compression/dictionary/compression.cpp +++ b/src/duckdb/src/storage/compression/dictionary/compression.cpp @@ -1,12 +1,18 @@ #include "duckdb/storage/compression/dictionary/compression.hpp" -#include "duckdb/storage/segment/uncompressed.hpp" namespace duckdb { DictionaryCompressionCompressState::DictionaryCompressionCompressState(ColumnDataCheckpointData &checkpoint_data_p, - const CompressionInfo &info) + const CompressionInfo &info, + const idx_t max_unique_count_across_all_segments) : DictionaryCompressionState(info), checkpoint_data(checkpoint_data_p), - function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_DICTIONARY)) { + function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_DICTIONARY)), + current_string_map( + info.GetBlockManager().buffer_manager.GetBufferAllocator(), + max_unique_count_across_all_segments * 2, // * 2 results in less linear probing, improving performance + 1 // maximum_target_capacity_p, 1 because we don't care about target for our use-case, as we + // only use PrimitiveDictionary for duplicate checks, and not for writing to any target + ) { CreateEmptySegment(checkpoint_data.GetRowGroup().start); } @@ -19,7 +25,7 @@ void DictionaryCompressionCompressState::CreateEmptySegment(idx_t row_start) { current_segment = 
std::move(compressed_segment); // Reset the buffers and the string map. - current_string_map.clear(); + current_string_map.Clear(); index_buffer.clear(); // Reserve index 0 for null strings. @@ -42,15 +48,14 @@ void DictionaryCompressionCompressState::Verify() { D_ASSERT(DictionaryCompression::HasEnoughSpace(current_segment->count.load(), index_buffer.size(), current_dictionary.size, current_width, info.GetBlockSize())); D_ASSERT(current_dictionary.end == info.GetBlockSize()); - D_ASSERT(index_buffer.size() == current_string_map.size() + 1); // +1 is for null value + D_ASSERT(index_buffer.size() == current_string_map.GetSize() + 1); // +1 is for null value } bool DictionaryCompressionCompressState::LookupString(string_t str) { - auto search = current_string_map.find(str); - auto has_result = search != current_string_map.end(); - + const auto &entry = current_string_map.Lookup(str); + const auto has_result = !entry.IsEmpty(); if (has_result) { - latest_lookup_result = search->second; + latest_lookup_result = entry.index + 1; } return has_result; } @@ -69,11 +74,11 @@ void DictionaryCompressionCompressState::AddNewString(string_t str) { index_buffer.push_back(current_dictionary.size); selection_buffer.push_back(UnsafeNumericCast(index_buffer.size() - 1)); if (str.IsInlined()) { - current_string_map.insert({str, index_buffer.size() - 1}); + current_string_map.Insert(str); } else { string_t dictionary_string((const char *)dict_pos, UnsafeNumericCast(str.GetSize())); // NOLINT D_ASSERT(!dictionary_string.IsInlined()); - current_string_map.insert({dictionary_string, index_buffer.size() - 1}); + current_string_map.Insert(dictionary_string); } DictionaryCompression::SetDictionary(*current_segment, current_handle, current_dictionary); diff --git a/src/duckdb/src/storage/compression/dictionary/decompression.cpp b/src/duckdb/src/storage/compression/dictionary/decompression.cpp index 51f6945e2..6e389d026 100644 --- a/src/duckdb/src/storage/compression/dictionary/decompression.cpp +++ b/src/duckdb/src/storage/compression/dictionary/decompression.cpp @@ -48,10 +48,10 @@ void CompressedStringScanState::Initialize(ColumnSegment &segment, bool initiali return; } - dictionary = make_buffer(segment.type, index_buffer_count); + dictionary = DictionaryVector::CreateReusableDictionary(segment.type, index_buffer_count); dictionary_size = index_buffer_count; - auto dict_child_data = FlatVector::GetData(*(dictionary)); - FlatVector::SetNull(*dictionary, 0, true); + auto dict_child_data = FlatVector::GetData(dictionary->data); + FlatVector::SetNull(dictionary->data, 0, true); for (uint32_t i = 1; i < index_buffer_count; i++) { // NOTE: the passing of dict_child_vector, will not be used, its for big strings uint16_t str_len = GetStringLength(i); @@ -114,8 +114,7 @@ void CompressedStringScanState::ScanToDictionaryVector(ColumnSegment &segment, V } } - result.Dictionary(*(dictionary), dictionary_size, *sel_vec, scan_count); - DictionaryVector::SetDictionaryId(result, to_string(CastPointerToValue(&segment))); + result.Dictionary(dictionary, *sel_vec); } } // namespace duckdb diff --git a/src/duckdb/src/storage/compression/dictionary_compression.cpp b/src/duckdb/src/storage/compression/dictionary_compression.cpp index fa027edd9..e3a976ccc 100644 --- a/src/duckdb/src/storage/compression/dictionary_compression.cpp +++ b/src/duckdb/src/storage/compression/dictionary_compression.cpp @@ -4,8 +4,6 @@ #include "duckdb/common/bitpacking.hpp" #include "duckdb/common/numeric_utils.hpp" -#include 
"duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/string_map_set.hpp" #include "duckdb/common/types/vector_buffer.hpp" #include "duckdb/function/compression/compression.hpp" #include "duckdb/function/compression_function.hpp" @@ -57,7 +55,7 @@ struct DictionaryCompressionStorage { static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); static void FinalizeCompress(CompressionState &state_p); - static unique_ptr StringInitScan(ColumnSegment &segment); + static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); template static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); @@ -89,6 +87,10 @@ idx_t DictionaryCompressionStorage::StringFinalAnalyze(AnalyzeState &state_p) { auto &analyze_state = state_p.Cast(); auto &state = *analyze_state.analyze_state; + if (state.current_tuple_count != 0) { + state.UpdateMaxUniqueCount(); + } + auto width = BitpackingPrimitives::MinimumBitWidth(state.current_unique_count + 1); auto req_space = DictionaryCompression::RequiredSpace(state.current_tuple_count, state.current_unique_count, state.current_dict_size, width); @@ -102,7 +104,10 @@ idx_t DictionaryCompressionStorage::StringFinalAnalyze(AnalyzeState &state_p) { //===--------------------------------------------------------------------===// unique_ptr DictionaryCompressionStorage::InitCompression(ColumnDataCheckpointData &checkpoint_data, unique_ptr state) { - return make_uniq(checkpoint_data, state->info); + const auto &analyze_state = state->Cast(); + auto &actual_state = *analyze_state.analyze_state; + return make_uniq(checkpoint_data, state->info, + actual_state.max_unique_count_across_segments); } void DictionaryCompressionStorage::Compress(CompressionState &state_p, Vector &scan_vector, idx_t count) { @@ -118,7 +123,8 @@ void DictionaryCompressionStorage::FinalizeCompress(CompressionState &state_p) { //===--------------------------------------------------------------------===// // Scan //===--------------------------------------------------------------------===// -unique_ptr DictionaryCompressionStorage::StringInitScan(ColumnSegment &segment) { +unique_ptr DictionaryCompressionStorage::StringInitScan(const QueryContext &context, + ColumnSegment &segment) { auto &buffer_manager = BufferManager::GetBufferManager(segment.db); auto state = make_uniq(buffer_manager.Pin(segment.block)); state->Initialize(segment, true); diff --git a/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp b/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp index afd335dab..89c718525 100644 --- a/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp +++ b/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp @@ -143,10 +143,10 @@ struct FixedSizeScanState : public SegmentScanState { BufferHandle handle; }; -unique_ptr FixedSizeInitScan(ColumnSegment &segment) { +unique_ptr FixedSizeInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq(); auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - result->handle = buffer_manager.Pin(segment.block); + result->handle = buffer_manager.Pin(context, segment.block); return std::move(result); } diff --git a/src/duckdb/src/storage/compression/fsst.cpp b/src/duckdb/src/storage/compression/fsst.cpp index cbb3b3ac7..57b769432 100644 --- a/src/duckdb/src/storage/compression/fsst.cpp +++ b/src/duckdb/src/storage/compression/fsst.cpp @@ -50,7 +50,7 @@ 
struct FSSTStorage { static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); static void FinalizeCompress(CompressionState &state_p); - static unique_ptr StringInitScan(ColumnSegment &segment); + static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); template static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); @@ -569,7 +569,7 @@ struct FSSTScanState : public StringScanState { } }; -unique_ptr FSSTStorage::StringInitScan(ColumnSegment &segment) { +unique_ptr FSSTStorage::StringInitScan(const QueryContext &context, ColumnSegment &segment) { auto block_size = segment.GetBlockManager().GetBlockSize(); auto string_block_limit = StringUncompressed::GetStringBlockLimit(block_size); auto state = make_uniq(string_block_limit); @@ -585,8 +585,9 @@ unique_ptr FSSTStorage::StringInitScan(ColumnSegment &segment) } state->duckdb_fsst_decoder_ptr = state->duckdb_fsst_decoder.get(); - if (StringStats::HasMaxStringLength(segment.stats.statistics)) { - state->all_values_inlined = StringStats::MaxStringLength(segment.stats.statistics) <= string_t::INLINE_LENGTH; + const auto &stats = segment.stats.statistics; + if (stats.GetStatsType() == StatisticsType::STRING_STATS && StringStats::HasMaxStringLength(stats)) { + state->all_values_inlined = StringStats::MaxStringLength(stats) <= string_t::INLINE_LENGTH; } return std::move(state); @@ -640,7 +641,6 @@ void FSSTStorage::EndScan(FSSTScanState &scan_state, bp_delta_offsets_t &offsets template void FSSTStorage::StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset) { - auto &scan_state = state.scan_state->Cast(); auto start = segment.GetRelativeIndex(state.row_index); @@ -734,7 +734,6 @@ void FSSTStorage::Select(ColumnSegment &segment, ColumnScanState &state, idx_t v //===--------------------------------------------------------------------===// void FSSTStorage::StringFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); auto handle = buffer_manager.Pin(segment.block); auto base_ptr = handle.Ptr() + segment.GetBlockOffset(); diff --git a/src/duckdb/src/storage/compression/numeric_constant.cpp b/src/duckdb/src/storage/compression/numeric_constant.cpp index a4d1e789b..411a85f6d 100644 --- a/src/duckdb/src/storage/compression/numeric_constant.cpp +++ b/src/duckdb/src/storage/compression/numeric_constant.cpp @@ -11,7 +11,7 @@ namespace duckdb { //===--------------------------------------------------------------------===// // Scan //===--------------------------------------------------------------------===// -unique_ptr ConstantInitScan(ColumnSegment &segment) { +unique_ptr ConstantInitScan(const QueryContext &context, ColumnSegment &segment) { return nullptr; } diff --git a/src/duckdb/src/storage/compression/rle.cpp b/src/duckdb/src/storage/compression/rle.cpp index 57ebaf1fa..ed26d824d 100644 --- a/src/duckdb/src/storage/compression/rle.cpp +++ b/src/duckdb/src/storage/compression/rle.cpp @@ -303,7 +303,7 @@ struct RLEScanState : public SegmentScanState { }; template -unique_ptr RLEInitScan(ColumnSegment &segment) { +unique_ptr RLEInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq>(segment); return std::move(result); } diff --git a/src/duckdb/src/storage/compression/roaring/common.cpp 
b/src/duckdb/src/storage/compression/roaring/common.cpp index 80f7004de..3d9230787 100644 --- a/src/duckdb/src/storage/compression/roaring/common.cpp +++ b/src/duckdb/src/storage/compression/roaring/common.cpp @@ -208,7 +208,7 @@ void RoaringFinalizeCompress(CompressionState &state_p) { state.Finalize(); } -unique_ptr RoaringInitScan(ColumnSegment &segment) { +unique_ptr RoaringInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq(segment); return std::move(result); } diff --git a/src/duckdb/src/storage/compression/string_uncompressed.cpp b/src/duckdb/src/storage/compression/string_uncompressed.cpp index af3b826bf..201e97787 100644 --- a/src/duckdb/src/storage/compression/string_uncompressed.cpp +++ b/src/duckdb/src/storage/compression/string_uncompressed.cpp @@ -77,7 +77,8 @@ void UncompressedStringInitPrefetch(ColumnSegment &segment, PrefetchState &prefe } } -unique_ptr UncompressedStringStorage::StringInitScan(ColumnSegment &segment) { +unique_ptr UncompressedStringStorage::StringInitScan(const QueryContext &context, + ColumnSegment &segment) { auto result = make_uniq(); auto &buffer_manager = BufferManager::GetBufferManager(segment.db); result->handle = buffer_manager.Pin(segment.block); diff --git a/src/duckdb/src/storage/compression/validity_uncompressed.cpp b/src/duckdb/src/storage/compression/validity_uncompressed.cpp index 5a71b8974..e43f237a3 100644 --- a/src/duckdb/src/storage/compression/validity_uncompressed.cpp +++ b/src/duckdb/src/storage/compression/validity_uncompressed.cpp @@ -207,7 +207,7 @@ struct ValidityScanState : public SegmentScanState { block_id_t block_id; }; -unique_ptr ValidityInitScan(ColumnSegment &segment) { +unique_ptr ValidityInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq(); auto &buffer_manager = BufferManager::GetBufferManager(segment.db); result->handle = buffer_manager.Pin(segment.block); @@ -287,6 +287,13 @@ void ValidityUncompressed::UnalignedScan(data_ptr_t input, idx_t input_size, idx // otherwise the subsequent bitwise & will modify values outside of the range of values we want to alter input_mask |= ValidityUncompressed::UPPER_MASKS[shift_amount]; + if (pos == 0) { + // We also need to set the lower bits, which are to the left of the relevant bits (x), to 1 + // These are the bits that are "behind" this scan window, and should not affect this scan + auto non_relevant_mask = ValidityUncompressed::LOWER_MASKS[result_idx]; + input_mask |= non_relevant_mask; + } + // after this, we move to the next input_entry offset = ValidityMask::BITS_PER_VALUE - input_idx; input_entry++; diff --git a/src/duckdb/src/storage/compression/zstd.cpp b/src/duckdb/src/storage/compression/zstd.cpp index 408855284..829a76f05 100644 --- a/src/duckdb/src/storage/compression/zstd.cpp +++ b/src/duckdb/src/storage/compression/zstd.cpp @@ -81,7 +81,7 @@ struct ZSTDStorage { static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); static void FinalizeCompress(CompressionState &state_p); - static unique_ptr StringInitScan(ColumnSegment &segment); + static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); static void StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result); @@ -142,6 +142,11 @@ struct ZSTDAnalyzeState : public AnalyzeState { unique_ptr 
ZSTDStorage::StringInitAnalyze(ColumnData &col_data, PhysicalType type) { // check if the storage version we are writing to supports sztd auto &storage = col_data.GetStorageManager(); + auto &block_manager = col_data.GetBlockManager(); + if (block_manager.InMemory()) { + //! Can't use ZSTD in in-memory environment + return nullptr; + } if (storage.GetStorageVersion() < 4) { // compatibility mode with old versions - disable zstd return nullptr; @@ -231,7 +236,6 @@ class ZSTDCompressionState : public CompressionState { checkpoint_data(checkpoint_data), partial_block_manager(checkpoint_data.GetCheckpointState().GetPartialBlockManager()), function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_ZSTD)) { - total_vector_count = GetVectorCount(analyze_state->count); total_segment_count = analyze_state->segment_count; vectors_per_segment = analyze_state->vectors_per_segment; @@ -249,6 +253,7 @@ class ZSTDCompressionState : public CompressionState { public: void ResetOutBuffer() { + D_ASSERT(GetCurrentOffset() <= GetWritableSpace(info)); out_buffer.dst = current_buffer_ptr; out_buffer.pos = 0; @@ -347,6 +352,7 @@ class ZSTDCompressionState : public CompressionState { void InitializeVector() { D_ASSERT(!in_vector); if (vector_count + 1 >= total_vector_count) { + //! Last vector vector_size = analyze_state->count - (ZSTD_VECTOR_SIZE * vector_count); } else { vector_size = ZSTD_VECTOR_SIZE; @@ -355,6 +361,7 @@ class ZSTDCompressionState : public CompressionState { current_offset = UnsafeNumericCast( AlignValue(UnsafeNumericCast(current_offset))); current_buffer_ptr = current_buffer->Ptr() + current_offset; + D_ASSERT(GetCurrentOffset() <= GetWritableSpace(info)); compressed_size = 0; uncompressed_size = 0; @@ -413,15 +420,11 @@ class ZSTDCompressionState : public CompressionState { throw InvalidInputException("ZSTD Compression failed: %s", duckdb_zstd::ZSTD_getErrorName(compress_result)); } + D_ASSERT(GetCurrentOffset() <= GetWritableSpace(info)); if (compress_result == 0) { // Finished break; } - if (out_buffer.pos != out_buffer.size) { - throw InternalException("Expected ZSTD_compressStream2 to fully utilize the current buffer, but pos is " - "%d, while size is %d", - out_buffer.pos, out_buffer.size); - } NewPage(); } } @@ -691,7 +694,7 @@ struct ZSTDScanState : public SegmentScanState { explicit ZSTDScanState(ColumnSegment &segment) : state(segment.GetSegmentState()->Cast()), block_manager(segment.GetBlockManager()), buffer_manager(BufferManager::GetBufferManager(segment.db)), - segment_block_offset(segment.GetBlockOffset()) { + segment_block_offset(segment.GetBlockOffset()), segment(segment) { decompression_context = duckdb_zstd::ZSTD_createDCtx(); segment_handle = buffer_manager.Pin(segment.block); @@ -791,14 +794,23 @@ struct ZSTDScanState : public SegmentScanState { auto vector_size = metadata.count; + auto string_lengths_size = (sizeof(string_length_t) * vector_size); scan_state.string_lengths = reinterpret_cast(scan_state.current_buffer_ptr); - scan_state.current_buffer_ptr += (sizeof(string_length_t) * vector_size); + scan_state.current_buffer_ptr += string_lengths_size; // Update the in_buffer to point to the start of the compressed data frame idx_t current_offset = UnsafeNumericCast(scan_state.current_buffer_ptr - handle_start); scan_state.in_buffer.src = scan_state.current_buffer_ptr; scan_state.in_buffer.pos = 0; - scan_state.in_buffer.size = block_manager.GetBlockSize() - sizeof(block_id_t) - current_offset; + if (scan_state.metadata.block_offset + 
string_lengths_size + scan_state.metadata.compressed_size > + (segment.SegmentSize() - sizeof(block_id_t))) { + //! We know that the compressed size is too big to fit on the current page + scan_state.in_buffer.size = + MinValue(metadata.compressed_size, block_manager.GetBlockSize() - sizeof(block_id_t) - current_offset); + } else { + scan_state.in_buffer.size = + MinValue(metadata.compressed_size, block_manager.GetBlockSize() - current_offset); + } // Initialize the context for streaming decompression duckdb_zstd::ZSTD_DCtx_reset(decompression_context, duckdb_zstd::ZSTD_reset_session_only); @@ -832,7 +844,7 @@ struct ZSTDScanState : public SegmentScanState { scan_state.in_buffer.src = ptr; scan_state.in_buffer.pos = 0; - idx_t page_size = block_manager.GetBlockSize() - sizeof(block_id_t); + idx_t page_size = segment.SegmentSize() - sizeof(block_id_t); idx_t remaining_compressed_data = scan_state.metadata.compressed_size - scan_state.compressed_scan_count; scan_state.in_buffer.size = MinValue(page_size, remaining_compressed_data); } @@ -842,6 +854,7 @@ struct ZSTDScanState : public SegmentScanState { return; } + auto &in_buffer = scan_state.in_buffer; duckdb_zstd::ZSTD_outBuffer out_buffer; out_buffer.dst = destination; @@ -849,18 +862,25 @@ struct ZSTDScanState : public SegmentScanState { out_buffer.size = uncompressed_length; while (true) { - idx_t old_pos = scan_state.in_buffer.pos; + idx_t old_pos = in_buffer.pos; size_t res = duckdb_zstd::ZSTD_decompressStream( /* zds = */ decompression_context, /* output =*/&out_buffer, - /* input =*/&scan_state.in_buffer); - scan_state.compressed_scan_count += scan_state.in_buffer.pos - old_pos; + /* input =*/&in_buffer); + scan_state.compressed_scan_count += in_buffer.pos - old_pos; if (duckdb_zstd::ZSTD_isError(res)) { throw InvalidInputException("ZSTD Decompression failed: %s", duckdb_zstd::ZSTD_getErrorName(res)); } if (out_buffer.pos == out_buffer.size) { + //! Done decompressing the relevant portion + break; + } + if (!res) { + D_ASSERT(out_buffer.pos == out_buffer.size); + D_ASSERT(in_buffer.pos == in_buffer.size); break; } + D_ASSERT(in_buffer.pos == in_buffer.size); // Did not fully decompress, it needs a new page to read from LoadNextPageForVector(scan_state); } @@ -956,12 +976,13 @@ struct ZSTDScanState : public SegmentScanState { idx_t segment_count; //! The amount of tuples consumed idx_t scanned_count = 0; + ColumnSegment &segment; //! Buffer for skipping data AllocatedData skip_buffer; }; -unique_ptr ZSTDStorage::StringInitScan(ColumnSegment &segment) { +unique_ptr ZSTDStorage::StringInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq(segment); return std::move(result); } diff --git a/src/duckdb/src/storage/data_table.cpp b/src/duckdb/src/storage/data_table.cpp index 7d19449bb..d4a246e01 100644 --- a/src/duckdb/src/storage/data_table.cpp +++ b/src/duckdb/src/storage/data_table.cpp @@ -146,7 +146,6 @@ DataTable::DataTable(ClientContext &context, DataTable &parent, idx_t removed_co DataTable::DataTable(ClientContext &context, DataTable &parent, BoundConstraint &constraint) : db(parent.db), info(parent.info), row_groups(parent.row_groups), version(DataTableVersion::MAIN_TABLE) { - // ALTER COLUMN to add a new constraint. // Clone the storage info vector or the table. 
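The hunks above all make the same mechanical change: each compression scheme's init-scan entry point gains a const QueryContext & parameter and forwards it to BufferManager::Pin, so the pin happens with knowledge of the calling query (presumably for per-query accounting or scan-related decisions; the diff itself only shows the plumbing). The sketch below is a minimal, self-contained model of that pattern; QueryContext, BufferManager, ColumnSegment and the rest are toy stand-ins, not DuckDB's real classes.

#include <cstdio>
#include <memory>

// Toy stand-ins for DuckDB's QueryContext / BufferManager / ColumnSegment.
struct QueryContext {
    long long query_id = -1; // -1 models an "empty" context (no active query)
};
struct BlockHandle { int block_id; };
struct ColumnSegment { BlockHandle block; };
struct BufferHandle { int block_id; };

struct BufferManager {
    // The patched Pin variant: the context rides along with every pin request.
    BufferHandle Pin(const QueryContext &context, const BlockHandle &block) {
        std::printf("pin block %d for query %lld\n", block.block_id, context.query_id);
        return BufferHandle {block.block_id};
    }
};

struct SegmentScanState {
    virtual ~SegmentScanState() = default;
};
struct FixedSizeScanState : SegmentScanState {
    BufferHandle handle {};
};

// Mirrors the new init-scan shape: the caller's QueryContext is forwarded to Pin.
std::unique_ptr<SegmentScanState> FixedSizeInitScan(const QueryContext &context, ColumnSegment &segment,
                                                    BufferManager &buffer_manager) {
    auto result = std::make_unique<FixedSizeScanState>();
    result->handle = buffer_manager.Pin(context, segment.block);
    return result;
}

int main() {
    BufferManager buffer_manager;
    ColumnSegment segment {BlockHandle {42}};
    QueryContext context {7};
    auto state = FixedSizeInitScan(context, segment, buffer_manager);
    (void)state;
    return 0;
}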
@@ -173,7 +172,6 @@ DataTable::DataTable(ClientContext &context, DataTable &parent, BoundConstraint DataTable::DataTable(ClientContext &context, DataTable &parent, idx_t changed_idx, const LogicalType &target_type, const vector &bound_columns, Expression &cast_expr) : db(parent.db), info(parent.info), version(DataTableVersion::MAIN_TABLE) { - auto &local_storage = LocalStorage::Get(context, db); // prevent any tuples from being added to the parent lock_guard lock(append_lock); @@ -245,7 +243,7 @@ void DataTable::InitializeScan(ClientContext &context, DuckTransaction &transact state.checkpoint_lock = transaction.SharedLockTable(*info); auto &local_storage = LocalStorage::Get(transaction); state.Initialize(column_ids, context, table_filters); - row_groups->InitializeScan(state.table_state, column_ids, table_filters); + row_groups->InitializeScan(context, state.table_state, column_ids, table_filters); local_storage.InitializeScan(*this, state.local_state, table_filters); } @@ -253,7 +251,7 @@ void DataTable::InitializeScanWithOffset(DuckTransaction &transaction, TableScan const vector &column_ids, idx_t start_row, idx_t end_row) { state.checkpoint_lock = transaction.SharedLockTable(*info); state.Initialize(column_ids); - row_groups->InitializeScanWithOffset(state.table_state, column_ids, start_row, end_row); + row_groups->InitializeScanWithOffset(QueryContext(), state.table_state, column_ids, start_row, end_row); } idx_t DataTable::GetRowGroupSize() const { @@ -681,7 +679,7 @@ void DataTable::VerifyNewConstraint(LocalStorage &local_storage, DataTable &pare throw NotImplementedException("FIXME: ALTER COLUMN with such constraint is not supported yet"); } - parent.row_groups->VerifyNewConstraint(parent, constraint); + parent.row_groups->VerifyNewConstraint(local_storage.GetClientContext(), parent, constraint); local_storage.VerifyNewConstraint(parent, constraint); } @@ -768,7 +766,6 @@ void DataTable::VerifyUniqueIndexes(TableIndexList &indexes, optional_ptr storage, optional_ptr manager) { - auto &table = constraint_state.table; if (table.HasGeneratedColumns()) { // Verify the generated columns against the inserted values. 
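One detail in the ScanTableSegment hunk above: scanning restarts at a vector boundary, computed as the row group's start plus vector_index * STANDARD_VECTOR_SIZE. A quick worked example of that alignment follows; it assumes the default STANDARD_VECTOR_SIZE of 2048 and that vector_index identifies the vector containing the requested row, both of which are assumptions about the build configuration and the scan-state initialization rather than facts shown in the diff.

#include <cstdio>

int main() {
    const unsigned long long STANDARD_VECTOR_SIZE = 2048; // assumption: default build configuration
    unsigned long long row_group_start = 122880;           // hypothetical row group start
    unsigned long long row_start = 125000;                 // first row the caller actually wants

    // vector index of the vector that contains row_start (assumed initialization)
    unsigned long long vector_index = (row_start - row_group_start) / STANDARD_VECTOR_SIZE;
    // scanning begins at this vector boundary; rows before row_start are presumably skipped by the scan loop
    unsigned long long row_start_aligned = row_group_start + vector_index * STANDARD_VECTOR_SIZE;

    std::printf("vector_index=%llu aligned=%llu\n", vector_index, row_start_aligned);
    // prints: vector_index=1 aligned=124928
    return 0;
}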
@@ -958,7 +955,6 @@ void DataTable::LocalAppend(TableCatalogEntry &table, ClientContext &context, Da void DataTable::LocalAppend(TableCatalogEntry &table, ClientContext &context, ColumnDataCollection &collection, const vector> &bound_constraints, optional_ptr> column_ids) { - LocalAppendState append_state; auto &storage = table.GetStorage(); storage.InitializeLocalAppend(append_state, table, context, bound_constraints); @@ -1062,7 +1058,8 @@ void DataTable::ScanTableSegment(DuckTransaction &transaction, idx_t row_start, CreateIndexScanState state; InitializeScanWithOffset(transaction, state, column_ids, row_start, row_start + count); - auto row_start_aligned = state.table_state.row_group->start + state.table_state.vector_index * STANDARD_VECTOR_SIZE; + auto row_start_aligned = + state.table_state.row_group->node->start + state.table_state.vector_index * STANDARD_VECTOR_SIZE; idx_t current_row = row_start_aligned; while (current_row < end) { @@ -1197,7 +1194,7 @@ ErrorData DataTable::AppendToIndexes(TableIndexList &indexes, optional_ptr(); - unbound_index.BufferChunk(index_chunk, row_ids, mapped_column_ids); + unbound_index.BufferChunk(index_chunk, row_ids, mapped_column_ids, BufferedIndexReplay::INSERT_ENTRY); return false; } @@ -1270,9 +1267,9 @@ void DataTable::RemoveFromIndexes(TableAppendState &state, DataChunk &chunk, Vec }); } -void DataTable::RemoveFromIndexes(Vector &row_identifiers, idx_t count) { +void DataTable::RemoveFromIndexes(const QueryContext &context, Vector &row_identifiers, idx_t count) { D_ASSERT(IsMainTable()); - row_groups->RemoveFromIndexes(info->indexes, row_identifiers, count); + row_groups->RemoveFromIndexes(context, info->indexes, row_identifiers, count); } //===--------------------------------------------------------------------===// @@ -1544,7 +1541,7 @@ void DataTable::Update(TableUpdateState &state, ClientContext &context, Vector & row_ids_slice.Slice(row_ids, sel_global_update, n_global_update); row_ids_slice.Flatten(n_global_update); - row_groups->Update(transaction, FlatVector::GetData(row_ids_slice), column_ids, updates_slice); + row_groups->Update(transaction, *this, FlatVector::GetData(row_ids_slice), column_ids, updates_slice); } } @@ -1568,7 +1565,7 @@ void DataTable::UpdateColumn(TableCatalogEntry &table, ClientContext &context, V updates.Flatten(); row_ids.Flatten(updates.size()); - row_groups->UpdateColumn(transaction, row_ids, column_path, updates); + row_groups->UpdateColumn(transaction, *this, row_ids, column_path, updates); } //===--------------------------------------------------------------------===// @@ -1649,9 +1646,9 @@ void DataTable::CommitDropTable() { //===--------------------------------------------------------------------===// // Column Segment Info //===--------------------------------------------------------------------===// -vector DataTable::GetColumnSegmentInfo() { +vector DataTable::GetColumnSegmentInfo(const QueryContext &context) { auto lock = GetSharedCheckpointLock(); - return row_groups->GetColumnSegmentInfo(); + return row_groups->GetColumnSegmentInfo(context); } //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/storage/external_file_cache.cpp b/src/duckdb/src/storage/external_file_cache.cpp index bcd5730f0..304116c7f 100644 --- a/src/duckdb/src/storage/external_file_cache.cpp +++ b/src/duckdb/src/storage/external_file_cache.cpp @@ -57,7 +57,8 @@ void ExternalFileCache::CachedFileRange::VerifyCheckSum() { #endif } -ExternalFileCache::CachedFile::CachedFile(string 
path_p) : path(std::move(path_p)) { +ExternalFileCache::CachedFile::CachedFile(string path_p) + : path(std::move(path_p)), file_size(0), last_modified(0), can_seek(false), on_disk_file(false) { } void ExternalFileCache::CachedFile::Verify(const unique_ptr &guard) const { diff --git a/src/duckdb/src/storage/index.cpp b/src/duckdb/src/storage/index.cpp index ca136d631..fc9ffcb85 100644 --- a/src/duckdb/src/storage/index.cpp +++ b/src/duckdb/src/storage/index.cpp @@ -7,7 +7,6 @@ namespace duckdb { Index::Index(const vector &column_ids, TableIOManager &table_io_manager, AttachedDatabase &db) : column_ids(column_ids), table_io_manager(table_io_manager), db(db) { - if (!Radix::IsLittleEndian()) { throw NotImplementedException("indexes are not supported on big endian architectures"); } diff --git a/src/duckdb/src/storage/local_storage.cpp b/src/duckdb/src/storage/local_storage.cpp index e3cbb8f3b..f28a294ba 100644 --- a/src/duckdb/src/storage/local_storage.cpp +++ b/src/duckdb/src/storage/local_storage.cpp @@ -16,12 +16,11 @@ namespace duckdb { LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &table) - : table_ref(table), allocator(Allocator::Get(table.db)), deleted_rows(0), optimistic_writer(context, table), - merged_storage(false) { - + : context(context), table_ref(table), allocator(Allocator::Get(table.db)), deleted_rows(0), + optimistic_writer(context, table), merged_storage(false) { auto types = table.GetTypes(); auto data_table_info = table.GetDataTableInfo(); - row_groups = OptimisticDataWriter::CreateCollection(table, types); + row_groups = optimistic_writer.CreateCollection(table, types, OptimisticWritePartialManagers::GLOBAL); auto &collection = *row_groups->collection; collection.InitializeEmpty(); @@ -63,10 +62,9 @@ LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &table) LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &new_data_table, LocalTableStorage &parent, const idx_t alter_column_index, const LogicalType &target_type, const vector &bound_columns, Expression &cast_expr) - : table_ref(new_data_table), allocator(Allocator::Get(new_data_table.db)), deleted_rows(parent.deleted_rows), - optimistic_collections(std::move(parent.optimistic_collections)), + : context(context), table_ref(new_data_table), allocator(Allocator::Get(new_data_table.db)), + deleted_rows(parent.deleted_rows), optimistic_collections(std::move(parent.optimistic_collections)), optimistic_writer(new_data_table, parent.optimistic_writer), merged_storage(parent.merged_storage) { - // Alter the column type. auto &parent_collection = *parent.row_groups->collection; auto new_collection = @@ -83,7 +81,6 @@ LocalTableStorage::LocalTableStorage(DataTable &new_data_table, LocalTableStorag : table_ref(new_data_table), allocator(Allocator::Get(new_data_table.db)), deleted_rows(parent.deleted_rows), optimistic_collections(std::move(parent.optimistic_collections)), optimistic_writer(new_data_table, parent.optimistic_writer), merged_storage(parent.merged_storage) { - // Remove the column from the previous table storage. 
auto &parent_collection = *parent.row_groups->collection; auto new_collection = parent_collection.RemoveColumn(drop_column_index); @@ -99,7 +96,6 @@ LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &new_dt, : table_ref(new_dt), allocator(Allocator::Get(new_dt.db)), deleted_rows(parent.deleted_rows), optimistic_collections(std::move(parent.optimistic_collections)), optimistic_writer(new_dt, parent.optimistic_writer), merged_storage(parent.merged_storage) { - auto &parent_collection = *parent.row_groups->collection; auto new_collection = parent_collection.AddColumn(context, new_column, default_executor); row_groups = std::move(parent.row_groups); @@ -115,7 +111,7 @@ void LocalTableStorage::InitializeScan(CollectionScanState &state, optional_ptr< if (collection.GetTotalRows() == 0) { throw InternalException("No rows in LocalTableStorage row group for scan"); } - collection.InitializeScan(state, state.GetColumnIds(), table_filters.get()); + collection.InitializeScan(context, state, state.GetColumnIds(), table_filters.get()); } idx_t LocalTableStorage::EstimatedSize() { @@ -164,12 +160,25 @@ void LocalTableStorage::FlushBlocks() { ErrorData LocalTableStorage::AppendToIndexes(DuckTransaction &transaction, RowGroupCollection &source, TableIndexList &index_list, const vector &table_types, row_t &start_row) { - // In this function, we only care about scanning the indexed columns of a table. + // mapped_column_ids contains the physical column indices of each Indexed column in the table. + // This mapping is used to retrieve the physical column index for the corresponding vector of an index chunk scan. + // For example, if we are processing data for index_chunk.data[i], we can retrieve the physical column index + // by getting the value at mapped_column_ids[i]. + // An important note is that the index_chunk orderings are created in accordance with this mapping, not the other + // way around. (Check the scan code below, where the mapped_column_ids is passed as a parameter to the scan. + // The index_chunk inside of that lambda is ordered according to the mapping that is a parameter to the scan). + + // mapped_column_ids is used in two places: + // 1) To create the physical table chunk in this function. + // 2) If we are in an unbound state (i.e., WAL replay is happening right now), this mapping and the index_chunk + // are buffered in unbound_index. However, there can also be buffered deletes happening, so it is important + // to maintain a canonical representation of the mapping, which is just sorting. auto indexed_columns = index_list.GetRequiredColumns(); vector mapped_column_ids; for (auto &col : indexed_columns) { mapped_column_ids.emplace_back(col); } + std::sort(mapped_column_ids.begin(), mapped_column_ids.end()); // However, because the bound expressions of the indexes (and their bound // column references) are in relation to ALL table columns, we create an @@ -178,6 +187,7 @@ ErrorData LocalTableStorage::AppendToIndexes(DuckTransaction &transaction, RowGr DataChunk table_chunk; table_chunk.InitializeEmpty(table_types); + // index_chunk scans are created here in the mapped_column_ids ordering (see note above). 
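// A compact, self-contained model of the mapping described in the comment above. Toy std::vector
// containers stand in for DataChunk and the expression-executor wiring is omitted; this only
// illustrates the ordering contract, it is not DuckDB's actual scan/append code.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    // Physical column ids required by the table's indexes, in whatever order the index list
    // happens to yield them.
    std::vector<size_t> mapped_column_ids = {4, 1};
    // Canonical ordering: buffered index inserts and buffered deletes replayed later must agree
    // on the same mapping, so it is normalized by sorting.
    std::sort(mapped_column_ids.begin(), mapped_column_ids.end()); // -> {1, 4}

    // The scan produces an index chunk with one column per entry of mapped_column_ids, in that
    // exact order: index_chunk[i] holds table column mapped_column_ids[i].
    std::vector<std::vector<int>> index_chunk = {{10, 20}, {30, 40}}; // columns 1 and 4

    // Rebuild a sparse "table chunk" covering all 6 table columns; only the indexed ones are
    // filled, which is enough for the indexes' bound column references to resolve.
    std::vector<std::vector<int>> table_chunk(6);
    for (size_t i = 0; i < mapped_column_ids.size(); i++) {
        table_chunk[mapped_column_ids[i]] = index_chunk[i];
    }

    std::printf("table column 1 -> %d, table column 4 -> %d\n", table_chunk[1][0], table_chunk[4][0]);
    return 0;
}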
ErrorData error; source.Scan(transaction, mapped_column_ids, [&](DataChunk &index_chunk) -> bool { D_ASSERT(index_chunk.ColumnCount() == mapped_column_ids.size()); @@ -205,7 +215,6 @@ void LocalTableStorage::AppendToIndexes(DuckTransaction &transaction, TableAppen bool append_to_table) { // In this function, we might scan all table columns, // as we might also append to the table itself (append_to_table). - auto &table = table_ref.get(); if (append_to_table) { table.InitializeAppend(transaction, append_state); @@ -564,7 +573,7 @@ idx_t LocalStorage::Delete(DataTable &table, Vector &row_ids, idx_t count) { // delete from unique indices (if any) if (!storage->append_indexes.Empty()) { - storage->GetCollection().RemoveFromIndexes(storage->append_indexes, row_ids, count); + storage->GetCollection().RemoveFromIndexes(context, storage->append_indexes, row_ids, count); } auto ids = FlatVector::GetData(row_ids); @@ -580,7 +589,7 @@ void LocalStorage::Update(DataTable &table, Vector &row_ids, const vector(row_ids); - storage->GetCollection().Update(TransactionData(0, 0), ids, column_ids, updates); + storage->GetCollection().Update(TransactionData(0, 0), table, ids, column_ids, updates); } void LocalStorage::Flush(DataTable &table, LocalTableStorage &storage, optional_ptr commit_state) { @@ -752,7 +761,7 @@ void LocalStorage::VerifyNewConstraint(DataTable &parent, const BoundConstraint if (!storage) { return; } - storage->GetCollection().VerifyNewConstraint(parent, constraint); + storage->GetCollection().VerifyNewConstraint(context, parent, constraint); } } // namespace duckdb diff --git a/src/duckdb/src/storage/metadata/metadata_manager.cpp b/src/duckdb/src/storage/metadata/metadata_manager.cpp index 8674f742d..3f80fe44e 100644 --- a/src/duckdb/src/storage/metadata/metadata_manager.cpp +++ b/src/duckdb/src/storage/metadata/metadata_manager.cpp @@ -99,12 +99,16 @@ MetadataHandle MetadataManager::Pin(const MetadataPointer &pointer) { return Pin(QueryContext(), pointer); } -MetadataHandle MetadataManager::Pin(QueryContext context, const MetadataPointer &pointer) { +MetadataHandle MetadataManager::Pin(const QueryContext &context, const MetadataPointer &pointer) { D_ASSERT(pointer.index < METADATA_BLOCK_COUNT); shared_ptr block_handle; { lock_guard guard(block_lock); - auto &block = blocks[UnsafeNumericCast(pointer.block_index)]; + auto entry = blocks.find(UnsafeNumericCast(pointer.block_index)); + if (entry == blocks.end()) { + throw InternalException("Trying to pin block %llu - but the block did not exist", pointer.block_index); + } + auto &block = entry->second; #ifdef DEBUG for (auto &free_block : block.free_blocks) { if (free_block == pointer.index) { @@ -272,15 +276,18 @@ void MetadataManager::Flush() { } continue; } - auto handle = buffer_manager.Pin(block.block); + auto block_handle = block.block; + auto handle = buffer_manager.Pin(block_handle); // zero-initialize the few leftover bytes memset(handle.Ptr() + total_metadata_size, 0, block_manager.GetBlockSize() - total_metadata_size); D_ASSERT(kv.first == block.block_id); - if (block.block->BlockId() >= MAXIMUM_BLOCK) { - auto new_block = - block_manager.ConvertToPersistent(QueryContext(), kv.first, block.block, std::move(handle)); - + if (block_handle->BlockId() >= MAXIMUM_BLOCK) { // Convert the temporary block to a persistent block. 
+ // we cannot use ConvertToPersistent as another thread might still be reading the block + // so we use the safe version of ConvertToPersistent + auto new_block = block_manager.ConvertToPersistent(QueryContext(), kv.first, std::move(block_handle), + std::move(handle), ConvertToPersistentMode::THREAD_SAFE); + guard.lock(); block.block = std::move(new_block); guard.unlock(); @@ -366,6 +373,7 @@ void MetadataBlock::FreeBlocksFromInteger(idx_t free_list) { } void MetadataManager::MarkBlocksAsModified() { + unique_lock guard(block_lock); // for any blocks that were modified in the last checkpoint - set them to free blocks currently for (auto &kv : modified_blocks) { auto block_id = kv.first; @@ -379,7 +387,10 @@ void MetadataManager::MarkBlocksAsModified() { if (new_free_blocks == NumericLimits::Maximum()) { // if new free_blocks is all blocks - mark entire block as modified blocks.erase(entry); + + guard.unlock(); block_manager.MarkBlockAsModified(block_id); + guard.lock(); } else { // set the new set of free blocks block.FreeBlocksFromInteger(new_free_blocks); @@ -414,6 +425,18 @@ void MetadataManager::ClearModifiedBlocks(const vector &pointe } } +bool MetadataManager::BlockHasBeenCleared(const MetaBlockPointer &pointer) { + unique_lock guard(block_lock); + auto block_id = pointer.GetBlockId(); + auto block_index = pointer.GetBlockIndex(); + auto entry = modified_blocks.find(block_id); + if (entry == modified_blocks.end()) { + throw InternalException("BlockHasBeenCleared - Block id %llu not found in modified_blocks", block_id); + } + auto &modified_list = entry->second; + return (modified_list & (1ULL << block_index)) == 0ULL; +} + vector MetadataManager::GetMetadataInfo() const { vector result; unique_lock guard(block_lock); diff --git a/src/duckdb/src/storage/metadata/metadata_reader.cpp b/src/duckdb/src/storage/metadata/metadata_reader.cpp index 06c2b1c1b..342833448 100644 --- a/src/duckdb/src/storage/metadata/metadata_reader.cpp +++ b/src/duckdb/src/storage/metadata/metadata_reader.cpp @@ -4,11 +4,8 @@ namespace duckdb { MetadataReader::MetadataReader(MetadataManager &manager, MetaBlockPointer pointer, optional_ptr> read_pointers_p, BlockReaderType type) - : manager(manager), type(type), next_pointer(FromDiskPointer(pointer)), has_next_block(true), - read_pointers(read_pointers_p), index(0), offset(0), next_offset(pointer.offset), capacity(0) { - if (read_pointers) { - read_pointers->push_back(pointer); - } + : manager(manager), type(type), next_pointer(pointer), has_next_block(true), read_pointers(read_pointers_p), + index(0), offset(0), next_offset(pointer.offset), capacity(0) { } MetadataReader::MetadataReader(MetadataManager &manager, BlockPointer pointer) @@ -59,11 +56,10 @@ MetaBlockPointer MetadataReader::GetMetaBlockPointer() { vector MetadataReader::GetRemainingBlocks(MetaBlockPointer last_block) { vector result; while (has_next_block) { - auto next_block_pointer = manager.GetDiskPointer(next_pointer, UnsafeNumericCast(next_offset)); - if (last_block.IsValid() && next_block_pointer.block_pointer == last_block.block_pointer) { + if (last_block.IsValid() && next_pointer.block_pointer == last_block.block_pointer) { break; } - result.push_back(next_block_pointer); + result.push_back(next_pointer); ReadNextBlock(); } return result; @@ -77,18 +73,18 @@ void MetadataReader::ReadNextBlock(QueryContext context) { if (!has_next_block) { throw IOException("No more data remaining in MetadataReader"); } - block = manager.Pin(context, next_pointer); - index = next_pointer.index; + if 
(read_pointers) { + read_pointers->push_back(next_pointer); + } + auto next_disk_pointer = FromDiskPointer(next_pointer); + block = manager.Pin(context, next_disk_pointer); + index = next_disk_pointer.index; idx_t next_block = Load(BasePtr()); if (next_block == idx_t(-1)) { has_next_block = false; } else { - next_pointer = FromDiskPointer(MetaBlockPointer(next_block, 0)); - MetaBlockPointer next_block_pointer(next_block, 0); - if (read_pointers) { - read_pointers->push_back(next_block_pointer); - } + next_pointer = MetaBlockPointer(next_block, 0); } if (next_offset < sizeof(block_id_t)) { next_offset = sizeof(block_id_t); diff --git a/src/duckdb/src/storage/metadata/metadata_writer.cpp b/src/duckdb/src/storage/metadata/metadata_writer.cpp index 69d8ea87e..8e7138b7d 100644 --- a/src/duckdb/src/storage/metadata/metadata_writer.cpp +++ b/src/duckdb/src/storage/metadata/metadata_writer.cpp @@ -32,7 +32,7 @@ MetaBlockPointer MetadataWriter::GetMetaBlockPointer() { void MetadataWriter::SetWrittenPointers(optional_ptr> written_pointers_p) { written_pointers = written_pointers_p; - if (written_pointers && capacity > 0) { + if (written_pointers && capacity > 0 && offset < capacity) { written_pointers->push_back(manager.GetDiskPointer(current_pointer)); } } diff --git a/src/duckdb/src/storage/optimistic_data_writer.cpp b/src/duckdb/src/storage/optimistic_data_writer.cpp index 4f595223f..fbe481364 100644 --- a/src/duckdb/src/storage/optimistic_data_writer.cpp +++ b/src/duckdb/src/storage/optimistic_data_writer.cpp @@ -6,6 +6,9 @@ namespace duckdb { +OptimisticWriteCollection::~OptimisticWriteCollection() { +} + OptimisticDataWriter::OptimisticDataWriter(ClientContext &context, DataTable &table) : context(context), table(table) { } @@ -28,14 +31,14 @@ bool OptimisticDataWriter::PrepareWrite() { // allocate the partial block-manager if none is allocated yet if (!partial_manager) { auto &block_manager = table.GetTableIOManager().GetBlockManagerForRowData(); - partial_manager = - make_uniq(QueryContext(context), block_manager, PartialBlockType::APPEND_TO_TABLE); + partial_manager = make_uniq(context, block_manager, PartialBlockType::APPEND_TO_TABLE); } return true; } unique_ptr OptimisticDataWriter::CreateCollection(DataTable &storage, - const vector &insert_types) { + const vector &insert_types, + OptimisticWritePartialManagers type) { auto table_info = storage.GetDataTableInfo(); auto &io_manager = TableIOManager::Get(storage); @@ -45,6 +48,13 @@ unique_ptr OptimisticDataWriter::CreateCollection(Dat auto result = make_uniq(); result->collection = std::move(row_groups); + if (type == OptimisticWritePartialManagers::PER_COLUMN) { + for (idx_t i = 0; i < insert_types.size(); i++) { + auto &block_manager = table.GetTableIOManager().GetBlockManagerForRowData(); + result->partial_block_managers.push_back(make_uniq( + QueryContext(context), block_manager, PartialBlockType::APPEND_TO_TABLE)); + } + } return result; } @@ -62,7 +72,7 @@ void OptimisticDataWriter::WriteNewRowGroup(OptimisticWriteCollection &row_group for (idx_t i = row_groups.last_flushed; i < row_groups.complete_row_groups; i++) { to_flush.push_back(*row_groups.collection->GetRowGroup(NumericCast(i))); } - FlushToDisk(to_flush); + FlushToDisk(row_groups, to_flush); row_groups.last_flushed = row_groups.complete_row_groups; } } @@ -79,30 +89,40 @@ void OptimisticDataWriter::WriteLastRowGroup(OptimisticWriteCollection &row_grou } // add the last (incomplete) row group to_flush.push_back(*row_groups.collection->GetRowGroup(-1)); - 
FlushToDisk(to_flush); + FlushToDisk(row_groups, to_flush); + + for (auto &partial_manager : row_groups.partial_block_managers) { + Merge(partial_manager); + } + row_groups.partial_block_managers.clear(); } -void OptimisticDataWriter::FlushToDisk(const vector> &row_groups) { +void OptimisticDataWriter::FlushToDisk(OptimisticWriteCollection &collection, + const vector> &row_groups) { //! The set of column compression types (if any) vector compression_types; D_ASSERT(compression_types.empty()); for (auto &column : table.Columns()) { compression_types.push_back(column.CompressionType()); } - RowGroupWriteInfo info(*partial_manager, compression_types); + RowGroupWriteInfo info(*partial_manager, compression_types, collection.partial_block_managers); RowGroup::WriteToDisk(info, row_groups); } -void OptimisticDataWriter::Merge(OptimisticDataWriter &other) { - if (!other.partial_manager) { +void OptimisticDataWriter::Merge(unique_ptr &other_manager) { + if (!other_manager) { return; } if (!partial_manager) { - partial_manager = std::move(other.partial_manager); + partial_manager = std::move(other_manager); return; } - partial_manager->Merge(*other.partial_manager); - other.partial_manager.reset(); + partial_manager->Merge(*other_manager); + other_manager.reset(); +} + +void OptimisticDataWriter::Merge(OptimisticDataWriter &other) { + Merge(other.partial_manager); } void OptimisticDataWriter::FinalFlush() { diff --git a/src/duckdb/src/storage/partial_block_manager.cpp b/src/duckdb/src/storage/partial_block_manager.cpp index 27fe86cd3..61f103c5a 100644 --- a/src/duckdb/src/storage/partial_block_manager.cpp +++ b/src/duckdb/src/storage/partial_block_manager.cpp @@ -22,7 +22,6 @@ void PartialBlock::AddSegmentToTail(ColumnData &data, ColumnSegment &segment, ui } void PartialBlock::FlushInternal(const idx_t free_space_left) { - // ensure that we do not leak any data if (free_space_left > 0 || !uninitialized_regions.empty()) { auto buffer_handle = block_manager.buffer_manager.Pin(block_handle); @@ -45,7 +44,6 @@ PartialBlockManager::PartialBlockManager(QueryContext context, BlockManager &blo uint32_t max_use_count) : context(context.GetClientContext()), block_manager(block_manager), partial_block_type(partial_block_type), max_use_count(max_use_count) { - if (max_partial_block_size_p.IsValid()) { max_partial_block_size = NumericCast(max_partial_block_size_p.GetIndex()); return; diff --git a/src/duckdb/src/storage/serialization/serialize_nodes.cpp b/src/duckdb/src/storage/serialization/serialize_nodes.cpp index b87ba38ec..ac3959177 100644 --- a/src/duckdb/src/storage/serialization/serialize_nodes.cpp +++ b/src/duckdb/src/storage/serialization/serialize_nodes.cpp @@ -252,7 +252,7 @@ ColumnList ColumnList::Deserialize(Deserializer &deserializer) { void CommonTableExpressionInfo::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault>(100, "aliases", aliases); serializer.WritePropertyWithDefault>(101, "query", query); - serializer.WriteProperty(102, "materialized", materialized); + serializer.WriteProperty(102, "materialized", GetMaterializedForSerialization(serializer)); serializer.WritePropertyWithDefault>>(103, "key_targets", key_targets); } diff --git a/src/duckdb/src/storage/serialization/serialize_query_node.cpp b/src/duckdb/src/storage/serialization/serialize_query_node.cpp index 50ab535d2..25b167558 100644 --- a/src/duckdb/src/storage/serialization/serialize_query_node.cpp +++ b/src/duckdb/src/storage/serialization/serialize_query_node.cpp @@ -38,6 +38,9 @@ unique_ptr 
QueryNode::Deserialize(Deserializer &deserializer) { } result->modifiers = std::move(modifiers); result->cte_map = std::move(cte_map); + if (type == QueryNodeType::CTE_NODE) { + result = std::move(result->Cast().child); + } return result; } diff --git a/src/duckdb/src/storage/serialization/serialize_types.cpp b/src/duckdb/src/storage/serialization/serialize_types.cpp index 453961009..963d5646e 100644 --- a/src/duckdb/src/storage/serialization/serialize_types.cpp +++ b/src/duckdb/src/storage/serialization/serialize_types.cpp @@ -42,6 +42,9 @@ shared_ptr ExtraTypeInfo::Deserialize(Deserializer &deserializer) case ExtraTypeInfoType::GENERIC_TYPE_INFO: result = make_shared_ptr(type); break; + case ExtraTypeInfoType::GEO_TYPE_INFO: + result = GeoTypeInfo::Deserialize(deserializer); + break; case ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO: result = IntegerLiteralTypeInfo::Deserialize(deserializer); break; @@ -136,6 +139,15 @@ unique_ptr ExtensionTypeInfo::Deserialize(Deserializer &deser return result; } +void GeoTypeInfo::Serialize(Serializer &serializer) const { + ExtraTypeInfo::Serialize(serializer); +} + +shared_ptr GeoTypeInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::shared_ptr(new GeoTypeInfo()); + return std::move(result); +} + void IntegerLiteralTypeInfo::Serialize(Serializer &serializer) const { ExtraTypeInfo::Serialize(serializer); serializer.WriteProperty(200, "constant_value", constant_value); diff --git a/src/duckdb/src/storage/single_file_block_manager.cpp b/src/duckdb/src/storage/single_file_block_manager.cpp index 6d22ff423..127fa795c 100644 --- a/src/duckdb/src/storage/single_file_block_manager.cpp +++ b/src/duckdb/src/storage/single_file_block_manager.cpp @@ -72,7 +72,6 @@ void GenerateDBIdentifier(uint8_t *db_identifier) { void EncryptCanary(MainHeader &main_header, const shared_ptr &encryption_state, const_data_ptr_t derived_key) { - uint8_t canary_buffer[MainHeader::CANARY_BYTE_SIZE]; // we zero-out the iv and the (not yet) encrypted canary diff --git a/src/duckdb/src/storage/standard_buffer_manager.cpp b/src/duckdb/src/storage/standard_buffer_manager.cpp index e15986e1c..f1170a562 100644 --- a/src/duckdb/src/storage/standard_buffer_manager.cpp +++ b/src/duckdb/src/storage/standard_buffer_manager.cpp @@ -338,7 +338,7 @@ BufferHandle StandardBufferManager::Pin(shared_ptr &handle) { return Pin(QueryContext(), handle); } -BufferHandle StandardBufferManager::Pin(QueryContext context, shared_ptr &handle) { +BufferHandle StandardBufferManager::Pin(const QueryContext &context, shared_ptr &handle) { // we need to be careful not to return the BufferHandle to this block while holding the BlockHandle's lock // as exiting this function's scope may cause the destructor of the BufferHandle to be called while holding the lock // the destructor calls Unpin, which grabs the BlockHandle's lock again, causing a deadlock @@ -495,7 +495,6 @@ void StandardBufferManager::RequireTemporaryDirectory() { } void StandardBufferManager::WriteTemporaryBuffer(MemoryTag tag, block_id_t block_id, FileBuffer &buffer) { - // WriteTemporaryBuffer assumes that we never write a buffer below DEFAULT_BLOCK_ALLOC_SIZE. 
RequireTemporaryDirectory(); @@ -543,8 +542,10 @@ unique_ptr StandardBufferManager::ReadTemporaryBuffer(QueryContext c BlockHandle &block, unique_ptr reusable_buffer) { D_ASSERT(!temporary_directory.path.empty()); - D_ASSERT(temporary_directory.handle.get()); auto id = block.BlockId(); + if (!temporary_directory.handle) { + throw InternalException("ReadTemporaryBuffer called but temporary directory has not been instantiated yet"); + } if (temporary_directory.handle->GetTempFile().HasTemporaryBuffer(id)) { // This is a block that was offloaded to a regular .tmp file, the file contains blocks of a fixed size return temporary_directory.handle->GetTempFile().ReadTemporaryBuffer(context, id, std::move(reusable_buffer)); @@ -642,6 +643,10 @@ bool StandardBufferManager::HasFilesInTemporaryDirectory() const { return found; } +BlockManager &StandardBufferManager::GetTemporaryBlockManager() { + return *temp_block_manager; +} + vector StandardBufferManager::GetTemporaryFiles() { vector result; if (temporary_directory.path.empty()) { diff --git a/src/duckdb/src/storage/statistics/base_statistics.cpp b/src/duckdb/src/storage/statistics/base_statistics.cpp index 89ae9cb61..9eac3b9aa 100644 --- a/src/duckdb/src/storage/statistics/base_statistics.cpp +++ b/src/duckdb/src/storage/statistics/base_statistics.cpp @@ -62,6 +62,9 @@ StatisticsType BaseStatistics::GetStatsType(const LogicalType &type) { if (type.id() == LogicalTypeId::SQLNULL) { return StatisticsType::BASE_STATS; } + if (type.id() == LogicalTypeId::GEOMETRY) { + return StatisticsType::GEOMETRY_STATS; + } switch (type.InternalType()) { case PhysicalType::BOOL: case PhysicalType::INT8: @@ -153,6 +156,9 @@ void BaseStatistics::Merge(const BaseStatistics &other) { case StatisticsType::ARRAY_STATS: ArrayStats::Merge(*this, other); break; + case StatisticsType::GEOMETRY_STATS: + GeometryStats::Merge(*this, other); + break; default: break; } @@ -174,6 +180,8 @@ BaseStatistics BaseStatistics::CreateUnknownType(LogicalType type) { return StructStats::CreateUnknown(std::move(type)); case StatisticsType::ARRAY_STATS: return ArrayStats::CreateUnknown(std::move(type)); + case StatisticsType::GEOMETRY_STATS: + return GeometryStats::CreateUnknown(std::move(type)); default: return BaseStatistics(std::move(type)); } @@ -191,6 +199,8 @@ BaseStatistics BaseStatistics::CreateEmptyType(LogicalType type) { return StructStats::CreateEmpty(std::move(type)); case StatisticsType::ARRAY_STATS: return ArrayStats::CreateEmpty(std::move(type)); + case StatisticsType::GEOMETRY_STATS: + return GeometryStats::CreateEmpty(std::move(type)); default: return BaseStatistics(std::move(type)); } @@ -329,6 +339,9 @@ void BaseStatistics::Serialize(Serializer &serializer) const { case StatisticsType::ARRAY_STATS: ArrayStats::Serialize(*this, serializer); break; + case StatisticsType::GEOMETRY_STATS: + GeometryStats::Serialize(*this, serializer); + break; default: break; } @@ -367,6 +380,9 @@ BaseStatistics BaseStatistics::Deserialize(Deserializer &deserializer) { case StatisticsType::ARRAY_STATS: ArrayStats::Deserialize(obj, stats); break; + case StatisticsType::GEOMETRY_STATS: + GeometryStats::Deserialize(obj, stats); + break; default: break; } @@ -397,6 +413,9 @@ string BaseStatistics::ToString() const { case StatisticsType::ARRAY_STATS: result = ArrayStats::ToString(*this) + result; break; + case StatisticsType::GEOMETRY_STATS: + result = GeometryStats::ToString(*this) + result; + break; default: break; } @@ -421,6 +440,9 @@ void BaseStatistics::Verify(Vector &vector, const 
SelectionVector &sel, idx_t co case StatisticsType::ARRAY_STATS: ArrayStats::Verify(*this, vector, sel, count); break; + case StatisticsType::GEOMETRY_STATS: + GeometryStats::Verify(*this, vector, sel, count); + break; default: break; } @@ -505,6 +527,14 @@ BaseStatistics BaseStatistics::FromConstantType(const Value &input) { } return result; } + case StatisticsType::GEOMETRY_STATS: { + auto result = GeometryStats::CreateEmpty(input.type()); + if (!input.IsNull()) { + auto &string_value = StringValue::Get(input); + GeometryStats::Update(result, string_t(string_value)); + } + return result; + } default: return BaseStatistics(input.type()); } diff --git a/src/duckdb/src/storage/statistics/geometry_stats.cpp b/src/duckdb/src/storage/statistics/geometry_stats.cpp new file mode 100644 index 000000000..91ebeaa5f --- /dev/null +++ b/src/duckdb/src/storage/statistics/geometry_stats.cpp @@ -0,0 +1,280 @@ +#include "duckdb/storage/statistics/geometry_stats.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +vector GeometryTypeSet::ToString(bool snake_case) const { + vector result; + for (auto d = 0; d < VERT_TYPES; d++) { + for (auto i = 0; i < PART_TYPES; i++) { + if (sets[d] & (1 << i)) { + string str; + switch (i) { + case 1: + str = snake_case ? "point" : "Point"; + break; + case 2: + str = snake_case ? "linestring" : "LineString"; + break; + case 3: + str = snake_case ? "polygon" : "Polygon"; + break; + case 4: + str = snake_case ? "multipoint" : "MultiPoint"; + break; + case 5: + str = snake_case ? "multilinestring" : "MultiLineString"; + break; + case 6: + str = snake_case ? "multipolygon" : "MultiPolygon"; + break; + case 7: + str = snake_case ? "geometrycollection" : "GeometryCollection"; + break; + default: + str = snake_case ? "unknown" : "Unknown"; + break; + } + switch (d) { + case 1: + str += snake_case ? "_z" : " Z"; + break; + case 2: + str += snake_case ? "_m" : " M"; + break; + case 3: + str += snake_case ? 
"_zm" : " ZM"; + break; + default: + break; + } + + result.push_back(str); + } + } + } + return result; +} + +BaseStatistics GeometryStats::CreateUnknown(LogicalType type) { + BaseStatistics result(std::move(type)); + result.InitializeUnknown(); + GetDataUnsafe(result).SetUnknown(); + return result; +} + +BaseStatistics GeometryStats::CreateEmpty(LogicalType type) { + BaseStatistics result(std::move(type)); + result.InitializeEmpty(); + GetDataUnsafe(result).SetEmpty(); + return result; +} + +void GeometryStats::Serialize(const BaseStatistics &stats, Serializer &serializer) { + const auto &data = GetDataUnsafe(stats); + + // Write extent + serializer.WriteObject(200, "extent", [&](Serializer &extent) { + extent.WriteProperty(101, "x_min", data.extent.x_min); + extent.WriteProperty(102, "x_max", data.extent.x_max); + extent.WriteProperty(103, "y_min", data.extent.y_min); + extent.WriteProperty(104, "y_max", data.extent.y_max); + extent.WriteProperty(105, "z_min", data.extent.z_min); + extent.WriteProperty(106, "z_max", data.extent.z_max); + extent.WriteProperty(107, "m_min", data.extent.m_min); + extent.WriteProperty(108, "m_max", data.extent.m_max); + }); + + // Write types + serializer.WriteObject(201, "types", [&](Serializer &types) { + types.WriteProperty(101, "types_xy", data.types.sets[0]); + types.WriteProperty(102, "types_xyz", data.types.sets[1]); + types.WriteProperty(103, "types_xym", data.types.sets[2]); + types.WriteProperty(104, "types_xyzm", data.types.sets[3]); + }); +} + +void GeometryStats::Deserialize(Deserializer &deserializer, BaseStatistics &base) { + auto &data = GetDataUnsafe(base); + + // Read extent + deserializer.ReadObject(200, "extent", [&](Deserializer &extent) { + extent.ReadProperty(101, "x_min", data.extent.x_min); + extent.ReadProperty(102, "x_max", data.extent.x_max); + extent.ReadProperty(103, "y_min", data.extent.y_min); + extent.ReadProperty(104, "y_max", data.extent.y_max); + extent.ReadProperty(105, "z_min", data.extent.z_min); + extent.ReadProperty(106, "z_max", data.extent.z_max); + extent.ReadProperty(107, "m_min", data.extent.m_min); + extent.ReadProperty(108, "m_max", data.extent.m_max); + }); + + // Read types + deserializer.ReadObject(201, "types", [&](Deserializer &types) { + types.ReadProperty(101, "types_xy", data.types.sets[0]); + types.ReadProperty(102, "types_xyz", data.types.sets[1]); + types.ReadProperty(103, "types_xym", data.types.sets[2]); + types.ReadProperty(104, "types_xyzm", data.types.sets[3]); + }); +} + +string GeometryStats::ToString(const BaseStatistics &stats) { + const auto &data = GetDataUnsafe(stats); + string result; + + result += "["; + result += StringUtil::Format("Extent: [X: [%f, %f], Y: [%f, %f], Z: [%f, %f], M: [%f, %f]", data.extent.x_min, + data.extent.x_max, data.extent.y_min, data.extent.y_max, data.extent.z_min, + data.extent.z_max, data.extent.m_min, data.extent.m_max); + result += StringUtil::Format("], Types: [%s]", StringUtil::Join(data.types.ToString(true), ", ")); + + result += "]"; + return result; +} + +void GeometryStats::Update(BaseStatistics &stats, const string_t &value) { + auto &data = GetDataUnsafe(stats); + data.Update(value); +} + +void GeometryStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { + if (other.GetType().id() == LogicalTypeId::VALIDITY) { + return; + } + if (other.GetType().id() == LogicalTypeId::SQLNULL) { + return; + } + + auto &target = GetDataUnsafe(stats); + auto &source = GetDataUnsafe(other); + target.Merge(source); +} + +void GeometryStats::Verify(const 
BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { + // TODO: Verify stats +} + +const GeometryStatsData &GeometryStats::GetDataUnsafe(const BaseStatistics &stats) { + D_ASSERT(stats.GetStatsType() == StatisticsType::GEOMETRY_STATS); + return stats.stats_union.geometry_data; +} + +GeometryStatsData &GeometryStats::GetDataUnsafe(BaseStatistics &stats) { + D_ASSERT(stats.GetStatsType() == StatisticsType::GEOMETRY_STATS); + return stats.stats_union.geometry_data; +} + +GeometryExtent &GeometryStats::GetExtent(BaseStatistics &stats) { + return GetDataUnsafe(stats).extent; +} + +GeometryTypeSet &GeometryStats::GetTypes(BaseStatistics &stats) { + return GetDataUnsafe(stats).types; +} + +const GeometryExtent &GeometryStats::GetExtent(const BaseStatistics &stats) { + return GetDataUnsafe(stats).extent; +} + +const GeometryTypeSet &GeometryStats::GetTypes(const BaseStatistics &stats) { + return GetDataUnsafe(stats).types; +} + +// Expression comparison pruning +static FilterPropagateResult CheckIntersectionFilter(const GeometryStatsData &data, const Value &constant) { + if (constant.IsNull() || constant.type().id() != LogicalTypeId::GEOMETRY) { + // Cannot prune against NULL + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + // This has been checked before and needs to be true for the checks below to be valid + D_ASSERT(data.extent.HasXY()); + + const auto &geom = StringValue::Get(constant); + auto extent = GeometryExtent::Empty(); + if (Geometry::GetExtent(string_t(geom), extent) == 0) { + // If the geometry is empty, the predicate will never match + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + + // Check if the bounding boxes intersect + // If the bounding boxes do not intersect, the predicate will never match + if (!extent.IntersectsXY(data.extent)) { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + + // If the column is completely inside the bounds, the predicate will always match + if (extent.ContainsXY(data.extent)) { + return FilterPropagateResult::FILTER_ALWAYS_TRUE; + } + + // We cannot prune, as this column may contain geometries that intersect + return FilterPropagateResult::NO_PRUNING_POSSIBLE; +} + +FilterPropagateResult GeometryStats::CheckZonemap(const BaseStatistics &stats, const unique_ptr &expr) { + if (expr->GetExpressionType() != ExpressionType::BOUND_FUNCTION) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + if (expr->return_type != LogicalType::BOOLEAN) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + const auto &func = expr->Cast(); + if (func.children.size() != 2) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + if (func.children[0]->return_type.id() != LogicalTypeId::GEOMETRY || + func.children[1]->return_type.id() != LogicalTypeId::GEOMETRY) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + // The set of geometry predicates that can be optimized using the bounding box + static constexpr const char *geometry_predicates[2] = {"&&", "st_intersects_extent"}; + + auto found = false; + for (const auto &name : geometry_predicates) { + if (StringUtil::CIEquals(func.function.name.c_str(), name)) { + found = true; + break; + } + } + if (!found) { + // Not a geometry predicate we can optimize + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + const auto lhs_kind = func.children[0]->GetExpressionType(); + const auto rhs_kind = func.children[1]->GetExpressionType(); + const auto lhs_is_const = lhs_kind == ExpressionType::VALUE_CONSTANT && rhs_kind == 
ExpressionType::BOUND_REF; + const auto rhs_is_const = rhs_kind == ExpressionType::VALUE_CONSTANT && lhs_kind == ExpressionType::BOUND_REF; + + if (!stats.CanHaveNoNull()) { + // no non-null values are possible: always false + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + + auto &data = GetDataUnsafe(stats); + + if (!data.extent.HasXY()) { + // If the extent is empty or unknown, we cannot prune + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + if (lhs_is_const) { + return CheckIntersectionFilter(data, func.children[0]->Cast().value); + } + if (rhs_is_const) { + return CheckIntersectionFilter(data, func.children[1]->Cast().value); + } + // Else, no constant argument + return FilterPropagateResult::NO_PRUNING_POSSIBLE; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/statistics/string_stats.cpp b/src/duckdb/src/storage/statistics/string_stats.cpp index e7d232692..3fe22ecac 100644 --- a/src/duckdb/src/storage/statistics/string_stats.cpp +++ b/src/duckdb/src/storage/statistics/string_stats.cpp @@ -170,6 +170,14 @@ void StringStats::Update(BaseStatistics &stats, const string_t &value) { } } +void StringStats::SetMin(BaseStatistics &stats, const string_t &value) { + ConstructValue(const_data_ptr_cast(value.GetData()), value.GetSize(), GetDataUnsafe(stats).min); +} + +void StringStats::SetMax(BaseStatistics &stats, const string_t &value) { + ConstructValue(const_data_ptr_cast(value.GetData()), value.GetSize(), GetDataUnsafe(stats).max); +} + void StringStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { if (other.GetType().id() == LogicalTypeId::VALIDITY) { return; diff --git a/src/duckdb/src/storage/storage_info.cpp b/src/duckdb/src/storage/storage_info.cpp index 616aa3039..c55602367 100644 --- a/src/duckdb/src/storage/storage_info.cpp +++ b/src/duckdb/src/storage/storage_info.cpp @@ -4,6 +4,10 @@ #include "duckdb/common/optional_idx.hpp" namespace duckdb { +constexpr idx_t Storage::MAX_ROW_GROUP_SIZE; +constexpr idx_t Storage::MAX_BLOCK_ALLOC_SIZE; +constexpr idx_t Storage::MIN_BLOCK_ALLOC_SIZE; +constexpr idx_t Storage::DEFAULT_BLOCK_HEADER_SIZE; const uint64_t VERSION_NUMBER = 64; const uint64_t VERSION_NUMBER_LOWER = 64; @@ -83,6 +87,8 @@ static const StorageVersionInfo storage_version_info[] = { {"v1.3.1", 66}, {"v1.3.2", 66}, {"v1.4.0", 67}, + {"v1.4.1", 67}, + {"v1.4.2", 67}, {"v1.5.0", 67}, {nullptr, 0} }; @@ -109,6 +115,8 @@ static const SerializationVersionInfo serialization_version_info[] = { {"v1.3.1", 5}, {"v1.3.2", 5}, {"v1.4.0", 6}, + {"v1.4.1", 6}, + {"v1.4.2", 6}, {"v1.5.0", 7}, {"latest", 7}, {nullptr, 0} diff --git a/src/duckdb/src/storage/storage_manager.cpp b/src/duckdb/src/storage/storage_manager.cpp index d82905452..3a7c8ce28 100644 --- a/src/duckdb/src/storage/storage_manager.cpp +++ b/src/duckdb/src/storage/storage_manager.cpp @@ -5,6 +5,7 @@ #include "duckdb/common/serializer/memory_stream.hpp" #include "duckdb/main/attached_database.hpp" #include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" #include "duckdb/main/database.hpp" #include "duckdb/storage/checkpoint_manager.hpp" #include "duckdb/storage/in_memory_block_manager.hpp" @@ -80,7 +81,6 @@ void StorageOptions::Initialize(const unordered_map &options) { StorageManager::StorageManager(AttachedDatabase &db, string path_p, const AttachOptions &options) : db(db), path(std::move(path_p)), read_only(options.access_mode == AccessMode::READ_ONLY), in_memory_change_size(0) { - if (path.empty()) { path = IN_MEMORY_PATH; return; @@ -110,7 +110,10 
@@ ObjectCache &ObjectCache::GetObjectCache(ClientContext &context) { } idx_t StorageManager::GetWALSize() { - return InMemory() ? in_memory_change_size.load() : wal->GetWALSize(); + if (InMemory() || wal->GetDatabase().GetRecoveryMode() == RecoveryMode::NO_WAL_WRITES) { + return in_memory_change_size.load(); + } + return wal->GetWALSize(); } optional_ptr StorageManager::GetWAL() { @@ -275,6 +278,7 @@ void SingleFileStorageManager::LoadDatabase(QueryContext context) { block_manager = std::move(sf_block_manager); table_io_manager = make_uniq(*block_manager, row_group_size); wal = make_uniq(db, wal_path); + } else { // Either the file exists, or we are in read-only mode, so we // try to read the existing file on disk. @@ -322,13 +326,40 @@ void SingleFileStorageManager::LoadDatabase(QueryContext context) { } } - // load the db from storage + // Start timing the storage load step. + auto client_context = context.GetClientContext(); + if (client_context) { + auto profiler = client_context->client_data->profiler; + profiler->StartTimer(MetricsType::ATTACH_LOAD_STORAGE_LATENCY); + } + + // Load the checkpoint from storage. auto checkpoint_reader = SingleFileCheckpointReader(*this); checkpoint_reader.LoadFromStorage(); + // End timing the storage load step. + if (client_context) { + auto profiler = client_context->client_data->profiler; + profiler->EndTimer(MetricsType::ATTACH_LOAD_STORAGE_LATENCY); + } + + // Start timing the WAL replay step. + if (client_context) { + auto profiler = client_context->client_data->profiler; + profiler->StartTimer(MetricsType::ATTACH_REPLAY_WAL_LATENCY); + } + + // Replay the WAL. auto wal_path = GetWALPath(); - wal = WriteAheadLog::Replay(fs, db, wal_path); + wal = WriteAheadLog::Replay(context, fs, db, wal_path); + + // End timing the WAL replay step. + if (client_context) { + auto profiler = client_context->client_data->profiler; + profiler->EndTimer(MetricsType::ATTACH_REPLAY_WAL_LATENCY); + } } + if (row_group_size > 122880ULL && GetStorageVersion() < 4) { throw InvalidInputException("Unsupported row group size %llu - row group sizes >= 122_880 are only supported " "with STORAGE_VERSION '1.2.0' or above.\nExplicitly specify a newer storage " @@ -476,17 +507,35 @@ void SingleFileStorageManager::CreateCheckpoint(QueryContext context, Checkpoint if (db.GetStorageExtension()) { db.GetStorageExtension()->OnCheckpointStart(db, options); } + auto &config = DBConfig::Get(db); - if (GetWALSize() > 0 || config.options.force_checkpoint || options.action == CheckpointAction::ALWAYS_CHECKPOINT) { - // we only need to checkpoint if there is anything in the WAL + // We only need to checkpoint if there is anything in the WAL. + auto wal_size = GetWALSize(); + if (wal_size > 0 || config.options.force_checkpoint || options.action == CheckpointAction::ALWAYS_CHECKPOINT) { try { + // Start timing the checkpoint. + auto client_context = context.GetClientContext(); + if (client_context) { + auto profiler = client_context->client_data->profiler; + profiler->StartTimer(MetricsType::CHECKPOINT_LATENCY); + } + + // Write the checkpoint. auto checkpointer = CreateCheckpointWriter(context, options); checkpointer->CreateCheckpoint(); + + // End timing the checkpoint. 
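// Illustrative sketch, not part of the patch: the three-way decision made by
// CheckIntersectionFilter/GeometryStats::CheckZonemap added earlier in this change, reduced to
// plain 2D boxes. Extent and Decision are simplified stand-ins for DuckDB's GeometryExtent and
// FilterPropagateResult.
enum class Decision { ALWAYS_FALSE, ALWAYS_TRUE, NO_PRUNING };

struct Extent {
    double x_min, x_max, y_min, y_max;

    bool Intersects(const Extent &o) const {
        return x_min <= o.x_max && o.x_min <= x_max && y_min <= o.y_max && o.y_min <= y_max;
    }
    bool Contains(const Extent &o) const {
        return x_min <= o.x_min && o.x_max <= x_max && y_min <= o.y_min && o.y_max <= y_max;
    }
};

// column_extent is the zonemap of the stored column, filter_extent the bounding box of the
// constant geometry in a predicate such as `geom && <constant>`.
inline Decision CheckIntersection(const Extent &column_extent, const Extent &filter_extent) {
    if (!filter_extent.Intersects(column_extent)) {
        // the boxes are disjoint: no row in this zone can satisfy the predicate
        return Decision::ALWAYS_FALSE;
    }
    if (filter_extent.Contains(column_extent)) {
        // every geometry in this zone lies inside the filter box, so its bounding box intersects it
        return Decision::ALWAYS_TRUE;
    }
    // partial overlap: individual rows may or may not match
    return Decision::NO_PRUNING;
}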
+ if (client_context) { + auto profiler = client_context->client_data->profiler; + profiler->EndTimer(MetricsType::CHECKPOINT_LATENCY); + } + } catch (std::exception &ex) { ErrorData error(ex); throw FatalException("Failed to create checkpoint because of error: %s", error.RawMessage()); } } + if (!InMemory() && options.wal_action == CheckpointWALAction::DELETE_WAL) { ResetWAL(); } diff --git a/src/duckdb/src/storage/table/array_column_data.cpp b/src/duckdb/src/storage/table/array_column_data.cpp index 7c8a12f13..dbca6b0fe 100644 --- a/src/duckdb/src/storage/table/array_column_data.cpp +++ b/src/duckdb/src/storage/table/array_column_data.cpp @@ -120,7 +120,7 @@ void ArrayColumnData::Select(TransactionData transaction, idx_t vector_index, Co // not consecutive - break break; } - end_idx = next_idx; + end_idx = next_idx + 1; } consecutive_ranges++; } @@ -224,13 +224,14 @@ idx_t ArrayColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &resul throw NotImplementedException("Array Fetch"); } -void ArrayColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void ArrayColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count) { throw NotImplementedException("Array Update is not supported."); } -void ArrayColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void ArrayColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth) { throw NotImplementedException("Array Update Column is not supported"); } @@ -240,7 +241,6 @@ unique_ptr ArrayColumnData::GetUpdateStatistics() { void ArrayColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { - // Create state for validity & child column if (state.child_states.empty()) { state.child_states.push_back(make_uniq()); @@ -256,7 +256,7 @@ void ArrayColumnData::FetchRow(TransactionData transaction, ColumnFetchState &st // We need to fetch between [row_id * array_size, (row_id + 1) * array_size) auto child_state = make_uniq(); - child_state->Initialize(child_type, nullptr); + child_state->Initialize(state.context, child_type, nullptr); const auto child_offset = start + (UnsafeNumericCast(row_id) - start) * array_size; @@ -302,8 +302,8 @@ unique_ptr ArrayColumnData::CreateCheckpointState(RowGrou unique_ptr ArrayColumnData::Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info) { - - auto checkpoint_state = make_uniq(row_group, *this, checkpoint_info.info.manager); + auto &partial_block_manager = checkpoint_info.GetPartialBlockManager(); + auto checkpoint_state = make_uniq(row_group, *this, partial_block_manager); checkpoint_state->validity_state = validity.Checkpoint(row_group, checkpoint_info); checkpoint_state->child_state = child_column->Checkpoint(row_group, checkpoint_info); return std::move(checkpoint_state); @@ -332,12 +332,12 @@ void ArrayColumnData::InitializeColumn(PersistentColumnData &column_data, BaseSt this->count = validity.count.load(); } -void ArrayColumnData::GetColumnSegmentInfo(idx_t row_group_index, vector col_path, +void ArrayColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, vector &result) { 
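// Illustrative sketch, not part of the patch: the StartTimer/EndTimer brackets added around
// checkpoint loading, WAL replay and checkpoint writing above follow one pattern, which could
// also be expressed as a scope guard. Profiler and Metric are simplified stand-ins for the
// profiler API used by the patch; the guard type itself is only an illustration.
#include <chrono>
#include <cstdio>

enum class Metric { ATTACH_LOAD_STORAGE, ATTACH_REPLAY_WAL, CHECKPOINT };

struct Profiler {
    std::chrono::steady_clock::time_point start;
    void StartTimer(Metric) {
        start = std::chrono::steady_clock::now();
    }
    void EndTimer(Metric metric) {
        auto us = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - start);
        std::printf("metric %d took %lld us\n", static_cast<int>(metric), static_cast<long long>(us.count()));
    }
};

// Starts the timer on construction and ends it when the scope exits; a null profiler disables
// timing, mirroring the `if (client_context)` checks in the patch.
class ScopedTimer {
public:
    ScopedTimer(Profiler *profiler_p, Metric metric_p) : profiler(profiler_p), metric(metric_p) {
        if (profiler) {
            profiler->StartTimer(metric);
        }
    }
    ~ScopedTimer() {
        if (profiler) {
            profiler->EndTimer(metric);
        }
    }

private:
    Profiler *profiler;
    Metric metric;
};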
col_path.push_back(0); - validity.GetColumnSegmentInfo(row_group_index, col_path, result); + validity.GetColumnSegmentInfo(context, row_group_index, col_path, result); col_path.back() = 1; - child_column->GetColumnSegmentInfo(row_group_index, col_path, result); + child_column->GetColumnSegmentInfo(context, row_group_index, col_path, result); } void ArrayColumnData::Verify(RowGroup &parent) { diff --git a/src/duckdb/src/storage/table/chunk_info.cpp b/src/duckdb/src/storage/table/chunk_info.cpp index 3b7b11d7b..702b4beb6 100644 --- a/src/duckdb/src/storage/table/chunk_info.cpp +++ b/src/duckdb/src/storage/table/chunk_info.cpp @@ -1,10 +1,12 @@ #include "duckdb/storage/table/chunk_info.hpp" + #include "duckdb/transaction/transaction.hpp" #include "duckdb/common/exception/transaction_exception.hpp" #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" #include "duckdb/common/serializer/memory_stream.hpp" #include "duckdb/transaction/delete_info.hpp" +#include "duckdb/execution/index/fixed_size_allocator.hpp" namespace duckdb { @@ -40,7 +42,7 @@ void ChunkInfo::Write(WriteStream &writer) const { writer.Write(type); } -unique_ptr ChunkInfo::Read(ReadStream &reader) { +unique_ptr ChunkInfo::Read(FixedSizeAllocator &allocator, ReadStream &reader) { auto type = reader.Read(); switch (type) { case ChunkInfoType::EMPTY_INFO: @@ -48,7 +50,7 @@ unique_ptr ChunkInfo::Read(ReadStream &reader) { case ChunkInfoType::CONSTANT_INFO: return ChunkConstantInfo::Read(reader); case ChunkInfoType::VECTOR_INFO: - return ChunkVectorInfo::Read(reader); + return ChunkVectorInfo::Read(allocator, reader); default: throw SerializationException("Could not deserialize Chunk Info Type: unrecognized type"); } @@ -71,7 +73,7 @@ idx_t ChunkConstantInfo::TemplatedGetSelVector(transaction_t start_time, transac return 0; } -idx_t ChunkConstantInfo::GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) { +idx_t ChunkConstantInfo::GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) const { return TemplatedGetSelVector(transaction.start_time, transaction.transaction_id, sel_vector, max_count); } @@ -95,7 +97,7 @@ bool ChunkConstantInfo::HasDeletes() const { return is_deleted; } -idx_t ChunkConstantInfo::GetCommittedDeletedCount(idx_t max_count) { +idx_t ChunkConstantInfo::GetCommittedDeletedCount(idx_t max_count) const { return delete_id < TRANSACTION_ID_START ? 
max_count : 0; } @@ -128,49 +130,70 @@ unique_ptr ChunkConstantInfo::Read(ReadStream &reader) { //===--------------------------------------------------------------------===// // Vector info //===--------------------------------------------------------------------===// -ChunkVectorInfo::ChunkVectorInfo(idx_t start) - : ChunkInfo(start, ChunkInfoType::VECTOR_INFO), insert_id(0), same_inserted_id(true), any_deleted(false) { - for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { - inserted[i] = 0; - deleted[i] = NOT_DELETED_ID; +ChunkVectorInfo::ChunkVectorInfo(FixedSizeAllocator &allocator_p, idx_t start, transaction_t insert_id_p) + : ChunkInfo(start, ChunkInfoType::VECTOR_INFO), allocator(allocator_p), constant_insert_id(insert_id_p) { +} + +ChunkVectorInfo::~ChunkVectorInfo() { + if (AnyDeleted()) { + allocator.Free(deleted_data); + } + if (!HasConstantInsertionId()) { + allocator.Free(inserted_data); } } template idx_t ChunkVectorInfo::TemplatedGetSelVector(transaction_t start_time, transaction_t transaction_id, SelectionVector &sel_vector, idx_t max_count) const { - idx_t count = 0; - if (same_inserted_id && !any_deleted) { - // all tuples have the same inserted id: and no tuples were deleted - if (OP::UseInsertedVersion(start_time, transaction_id, insert_id)) { - return max_count; - } else { - return 0; + if (HasConstantInsertionId()) { + if (!AnyDeleted()) { + // all tuples have the same inserted id: and no tuples were deleted + if (OP::UseInsertedVersion(start_time, transaction_id, ConstantInsertId())) { + return max_count; + } else { + return 0; + } } - } else if (same_inserted_id) { - if (!OP::UseInsertedVersion(start_time, transaction_id, insert_id)) { + if (!OP::UseInsertedVersion(start_time, transaction_id, ConstantInsertId())) { return 0; } // have to check deleted flag + idx_t count = 0; + auto segment = allocator.GetHandle(GetDeletedPointer()); + auto deleted = segment.GetPtr(); for (idx_t i = 0; i < max_count; i++) { if (OP::UseDeletedVersion(start_time, transaction_id, deleted[i])) { sel_vector.set_index(count++, i); } } - } else if (!any_deleted) { + return count; + } + if (!AnyDeleted()) { // have to check inserted flag + auto insert_segment = allocator.GetHandle(GetInsertedPointer()); + auto inserted = insert_segment.GetPtr(); + + idx_t count = 0; for (idx_t i = 0; i < max_count; i++) { if (OP::UseInsertedVersion(start_time, transaction_id, inserted[i])) { sel_vector.set_index(count++, i); } } - } else { - // have to check both flags - for (idx_t i = 0; i < max_count; i++) { - if (OP::UseInsertedVersion(start_time, transaction_id, inserted[i]) && - OP::UseDeletedVersion(start_time, transaction_id, deleted[i])) { - sel_vector.set_index(count++, i); - } + return count; + } + + idx_t count = 0; + // have to check both flags + auto insert_segment = allocator.GetHandle(GetInsertedPointer()); + auto inserted = insert_segment.GetPtr(); + + auto delete_segment = allocator.GetHandle(GetDeletedPointer()); + auto deleted = delete_segment.GetPtr(); + for (idx_t i = 0; i < max_count; i++) { + if (OP::UseInsertedVersion(start_time, transaction_id, inserted[i]) && + OP::UseDeletedVersion(start_time, transaction_id, deleted[i])) { + sel_vector.set_index(count++, i); } } return count; @@ -186,16 +209,76 @@ idx_t ChunkVectorInfo::GetCommittedSelVector(transaction_t min_start_id, transac return TemplatedGetSelVector(min_start_id, min_transaction_id, sel_vector, max_count); } -idx_t ChunkVectorInfo::GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) { 
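// Illustrative sketch, not part of the patch: the "constant until proven otherwise" idea behind
// the reworked ChunkVectorInfo in this hunk. Per-row insertion ids are only materialized once a
// second, different commit id shows up; std::vector stands in for the FixedSizeAllocator-backed
// segment the patch uses, and VECTOR_SIZE for STANDARD_VECTOR_SIZE.
#include <cstddef>
#include <cstdint>
#include <vector>

using transaction_t = uint64_t;
constexpr std::size_t VECTOR_SIZE = 2048;

class VersionInfo {
public:
    explicit VersionInfo(transaction_t insert_id) : constant_insert_id(insert_id) {
    }

    bool HasConstantInsertId() const {
        return inserted.empty();
    }

    void Append(std::size_t start, std::size_t end, transaction_t commit_id) {
        if (start == 0) {
            // first append into this vector: one id covers everything so far
            constant_insert_id = commit_id;
            return;
        }
        if (HasConstantInsertId() && constant_insert_id == commit_id) {
            // still the same id: keep the compact representation
            return;
        }
        MaterializeInserted();
        for (std::size_t i = start; i < end; i++) {
            inserted[i] = commit_id;
        }
    }

    transaction_t InsertId(std::size_t row) const {
        return HasConstantInsertId() ? constant_insert_id : inserted[row];
    }

private:
    void MaterializeInserted() {
        if (!HasConstantInsertId()) {
            return;
        }
        // spill the constant id into a per-row array before storing differing ids
        inserted.assign(VECTOR_SIZE, constant_insert_id);
    }

    transaction_t constant_insert_id;
    std::vector<transaction_t> inserted;
};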
+idx_t ChunkVectorInfo::GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) const { return GetSelVector(transaction.start_time, transaction.transaction_id, sel_vector, max_count); } bool ChunkVectorInfo::Fetch(TransactionData transaction, row_t row) { - return UseVersion(transaction, inserted[row]) && !UseVersion(transaction, deleted[row]); + transaction_t fetch_insert_id; + transaction_t fetch_deleted_id; + if (HasConstantInsertionId()) { + fetch_insert_id = ConstantInsertId(); + } else { + auto insert_segment = allocator.GetHandle(GetInsertedPointer()); + auto inserted = insert_segment.GetPtr(); + fetch_insert_id = inserted[row]; + } + if (!AnyDeleted()) { + fetch_deleted_id = NOT_DELETED_ID; + } else { + auto delete_segment = allocator.GetHandle(GetDeletedPointer()); + auto deleted = delete_segment.GetPtr(); + fetch_deleted_id = deleted[row]; + } + + return UseVersion(transaction, fetch_insert_id) && !UseVersion(transaction, fetch_deleted_id); +} + +IndexPointer ChunkVectorInfo::GetInsertedPointer() const { + if (HasConstantInsertionId()) { + throw InternalException("ChunkVectorInfo: insert id requested but insertions were not initialized"); + } + return inserted_data; +} + +IndexPointer ChunkVectorInfo::GetDeletedPointer() const { + if (!AnyDeleted()) { + throw InternalException("ChunkVectorInfo: deleted id requested but deletions were not initialized"); + } + return deleted_data; +} + +IndexPointer ChunkVectorInfo::GetInitializedInsertedPointer() { + if (HasConstantInsertionId()) { + transaction_t constant_id = ConstantInsertId(); + + inserted_data = allocator.New(); + inserted_data.SetMetadata(1); + auto segment = allocator.GetHandle(inserted_data); + auto inserted = segment.GetPtr(); + for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { + inserted[i] = constant_id; + } + } + return inserted_data; +} + +IndexPointer ChunkVectorInfo::GetInitializedDeletedPointer() { + if (!AnyDeleted()) { + deleted_data = allocator.New(); + deleted_data.SetMetadata(1); + auto segment = allocator.GetHandle(deleted_data); + auto deleted = segment.GetPtr(); + for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { + deleted[i] = NOT_DELETED_ID; + } + } + return deleted_data; } idx_t ChunkVectorInfo::Delete(transaction_t transaction_id, row_t rows[], idx_t count) { - any_deleted = true; + auto segment = allocator.GetHandle(GetInitializedDeletedPointer()); + auto deleted = segment.GetPtr(); idx_t deleted_tuples = 0; for (idx_t i = 0; i < count; i++) { @@ -220,6 +303,9 @@ idx_t ChunkVectorInfo::Delete(transaction_t transaction_id, row_t rows[], idx_t } void ChunkVectorInfo::CommitDelete(transaction_t commit_id, const DeleteInfo &info) { + auto segment = allocator.GetHandle(GetDeletedPointer()); + auto deleted = segment.GetPtr(); + if (info.is_consecutive) { for (idx_t i = 0; i < info.count; i++) { deleted[i] = commit_id; @@ -234,32 +320,45 @@ void ChunkVectorInfo::CommitDelete(transaction_t commit_id, const DeleteInfo &in void ChunkVectorInfo::Append(idx_t start, idx_t end, transaction_t commit_id) { if (start == 0) { - insert_id = commit_id; - } else if (insert_id != commit_id) { - same_inserted_id = false; - insert_id = NOT_DELETED_ID; + // first insert to this vector - just assign the commit id + constant_insert_id = commit_id; + return; + } + if (HasConstantInsertionId() && ConstantInsertId() == commit_id) { + // we are inserting again, but we have the same id as before - still the same insert id + return; } + + auto segment = 
allocator.GetHandle(GetInitializedInsertedPointer()); + auto inserted = segment.GetPtr(); for (idx_t i = start; i < end; i++) { inserted[i] = commit_id; } } void ChunkVectorInfo::CommitAppend(transaction_t commit_id, idx_t start, idx_t end) { - if (same_inserted_id) { - insert_id = commit_id; + if (HasConstantInsertionId()) { + constant_insert_id = commit_id; + return; } + auto segment = allocator.GetHandle(GetInsertedPointer()); + auto inserted = segment.GetPtr(); + for (idx_t i = start; i < end; i++) { inserted[i] = commit_id; } } bool ChunkVectorInfo::Cleanup(transaction_t lowest_transaction, unique_ptr &result) const { - if (any_deleted) { + if (AnyDeleted()) { // if any rows are deleted we can't clean-up return false; } // check if the insertion markers have to be used by all transactions going forward - if (!same_inserted_id) { + if (!HasConstantInsertionId()) { + auto segment = allocator.GetHandle(GetInsertedPointer()); + auto inserted = segment.GetPtr(); + for (idx_t idx = 1; idx < STANDARD_VECTOR_SIZE; idx++) { if (inserted[idx] > lowest_transaction) { // transaction was inserted after the lowest transaction start @@ -267,7 +366,7 @@ bool ChunkVectorInfo::Cleanup(transaction_t lowest_transaction, unique_ptr lowest_transaction) { + } else if (ConstantInsertId() > lowest_transaction) { // transaction was inserted after the lowest transaction start // we still need to use an older version - cannot compress return false; @@ -276,13 +375,31 @@ bool ChunkVectorInfo::Cleanup(transaction_t lowest_transaction, unique_ptr(); + idx_t delete_count = 0; for (idx_t i = 0; i < max_count; i++) { if (deleted[i] < TRANSACTION_ID_START) { @@ -319,15 +436,17 @@ void ChunkVectorInfo::Write(WriteStream &writer) const { mask.Write(writer, STANDARD_VECTOR_SIZE); } -unique_ptr ChunkVectorInfo::Read(ReadStream &reader) { +unique_ptr ChunkVectorInfo::Read(FixedSizeAllocator &allocator, ReadStream &reader) { auto start = reader.Read(); - auto result = make_uniq(start); - result->any_deleted = true; + auto result = make_uniq(allocator, start); ValidityMask mask; mask.Read(reader, STANDARD_VECTOR_SIZE); + + auto segment = allocator.GetHandle(result->GetInitializedDeletedPointer()); + auto deleted = segment.GetPtr(); for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { if (mask.RowIsValid(i)) { - result->deleted[i] = 0; + deleted[i] = 0; } } return std::move(result); diff --git a/src/duckdb/src/storage/table/column_data.cpp b/src/duckdb/src/storage/table/column_data.cpp index c212fcb18..64a9fce0e 100644 --- a/src/duckdb/src/storage/table/column_data.cpp +++ b/src/duckdb/src/storage/table/column_data.cpp @@ -81,14 +81,14 @@ bool ColumnData::HasChanges() const { auto l = data.Lock(); auto &nodes = data.ReferenceLoadedSegments(l); for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { - auto segment = nodes[segment_idx].node.get(); - if (segment->segment_type == ColumnSegmentType::TRANSIENT) { + auto &segment = *nodes[segment_idx]->node; + if (segment.segment_type == ColumnSegmentType::TRANSIENT) { // transient segment: always need to write to disk return true; } // persistent segment; check if there were any updates or deletions in this segment - idx_t start_row_idx = segment->start - start; - idx_t end_row_idx = start_row_idx + segment->count; + idx_t start_row_idx = segment.start - start; + idx_t end_row_idx = start_row_idx + segment.count; if (HasChanges(start_row_idx, end_row_idx)) { return true; } @@ -112,7 +112,7 @@ idx_t ColumnData::GetMaxEntry() { void 
ColumnData::InitializeScan(ColumnScanState &state) { state.current = data.GetRootSegment(); state.segment_tree = &data; - state.row_index = state.current ? state.current->start : 0; + state.row_index = state.current ? state.current->row_start : 0; state.internal_index = state.row_index; state.initialized = false; state.scan_state.reset(); @@ -123,7 +123,7 @@ void ColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) state.current = data.GetSegment(row_idx); state.segment_tree = &data; state.row_index = row_idx; - state.internal_index = state.current->start; + state.internal_index = state.current->row_start; state.initialized = false; state.scan_state.reset(); state.last_offset = 0; @@ -139,7 +139,8 @@ ScanVectorType ColumnData::GetVectorScanType(ColumnScanState &state, idx_t scan_ return ScanVectorType::SCAN_FLAT_VECTOR; } // check if the current segment has enough data remaining - idx_t remaining_in_segment = state.current->start + state.current->count - state.row_index; + auto ¤t = *state.current->node; + idx_t remaining_in_segment = current.start + current.count - state.row_index; if (remaining_in_segment < scan_count) { // there is not enough data remaining in the current segment so we need to scan across segments // we need flat vectors here @@ -155,19 +156,20 @@ void ColumnData::InitializePrefetch(PrefetchState &prefetch_state, ColumnScanSta } if (!scan_state.initialized) { // need to prefetch for the current segment if we have not yet initialized the scan for this segment - scan_state.current->InitializePrefetch(prefetch_state, scan_state); + current_segment->node->InitializePrefetch(prefetch_state, scan_state); } idx_t row_index = scan_state.row_index; while (remaining > 0) { - idx_t scan_count = MinValue(remaining, current_segment->start + current_segment->count - row_index); + auto ¤t = *current_segment->node; + idx_t scan_count = MinValue(remaining, current.start + current.count - row_index); remaining -= scan_count; row_index += scan_count; if (remaining > 0) { - auto next = data.GetNextSegment(current_segment); + auto next = data.GetNextSegment(*current_segment); if (!next) { break; } - next->InitializePrefetch(prefetch_state, scan_state); + next->node->InitializePrefetch(prefetch_state, scan_state); current_segment = next; } } @@ -176,17 +178,18 @@ void ColumnData::InitializePrefetch(PrefetchState &prefetch_state, ColumnScanSta void ColumnData::BeginScanVectorInternal(ColumnScanState &state) { state.previous_states.clear(); if (!state.initialized) { - D_ASSERT(state.current); - state.current->InitializeScan(state); - state.internal_index = state.current->start; + auto ¤t = *state.current->node; + current.InitializeScan(state); + state.internal_index = current.start; state.initialized = true; } - D_ASSERT(data.HasSegment(state.current)); + D_ASSERT(data.HasSegment(*state.current)); D_ASSERT(state.internal_index <= state.row_index); if (state.internal_index < state.row_index) { - state.current->Skip(state); + auto ¤t = *state.current->node; + current.Skip(state); } - D_ASSERT(state.current->type == type); + D_ASSERT(state.current->node->type == type); } idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remaining, ScanVectorType scan_type, @@ -197,19 +200,19 @@ idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remai BeginScanVectorInternal(state); idx_t initial_remaining = remaining; while (remaining > 0) { - D_ASSERT(state.row_index >= state.current->start && - state.row_index <= state.current->start + 
state.current->count); - idx_t scan_count = MinValue(remaining, state.current->start + state.current->count - state.row_index); + auto ¤t = *state.current->node; + D_ASSERT(state.row_index >= current.start && state.row_index <= current.start + current.count); + idx_t scan_count = MinValue(remaining, current.start + current.count - state.row_index); idx_t result_offset = base_result_offset + initial_remaining - remaining; if (scan_count > 0) { if (state.scan_options && state.scan_options->force_fetch_row) { for (idx_t i = 0; i < scan_count; i++) { ColumnFetchState fetch_state; - state.current->FetchRow(fetch_state, UnsafeNumericCast(state.row_index + i), result, - result_offset + i); + current.FetchRow(fetch_state, UnsafeNumericCast(state.row_index + i), result, + result_offset + i); } } else { - state.current->Scan(state, scan_count, result, result_offset, scan_type); + current.Scan(state, scan_count, result, result_offset, scan_type); } state.row_index += scan_count; @@ -217,16 +220,16 @@ idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remai } if (remaining > 0) { - auto next = data.GetNextSegment(state.current); + auto next = data.GetNextSegment(*state.current); if (!next) { break; } state.previous_states.emplace_back(std::move(state.scan_state)); state.current = next; - state.current->InitializeScan(state); + state.current->node->InitializeScan(state); state.segment_checked = false; - D_ASSERT(state.row_index >= state.current->start && - state.row_index <= state.current->start + state.current->count); + D_ASSERT(state.row_index >= state.current->node->start && + state.row_index <= state.current->node->start + state.current->node->count); } } state.internal_index = state.row_index; @@ -236,17 +239,18 @@ idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remai void ColumnData::SelectVector(ColumnScanState &state, Vector &result, idx_t target_count, const SelectionVector &sel, idx_t sel_count) { BeginScanVectorInternal(state); - if (state.current->start + state.current->count - state.row_index < target_count) { + auto ¤t = *state.current->node; + if (current.start + current.count - state.row_index < target_count) { throw InternalException("ColumnData::SelectVector should be able to fetch everything from one segment"); } if (state.scan_options && state.scan_options->force_fetch_row) { for (idx_t i = 0; i < sel_count; i++) { auto source_idx = sel.get_index(i); ColumnFetchState fetch_state; - state.current->FetchRow(fetch_state, UnsafeNumericCast(state.row_index + source_idx), result, i); + current.FetchRow(fetch_state, UnsafeNumericCast(state.row_index + source_idx), result, i); } } else { - state.current->Select(state, target_count, result, sel, sel_count); + current.Select(state, target_count, result, sel, sel_count); } state.row_index += target_count; state.internal_index = state.row_index; @@ -255,10 +259,11 @@ void ColumnData::SelectVector(ColumnScanState &state, Vector &result, idx_t targ void ColumnData::FilterVector(ColumnScanState &state, Vector &result, idx_t target_count, SelectionVector &sel, idx_t &sel_count, const TableFilter &filter, TableFilterState &filter_state) { BeginScanVectorInternal(state); - if (state.current->start + state.current->count - state.row_index < target_count) { + auto ¤t = *state.current->node; + if (current.start + current.count - state.row_index < target_count) { throw InternalException("ColumnData::Filter should be able to fetch everything from one segment"); } - state.current->Filter(state, 
target_count, result, sel, sel_count, filter, filter_state); + current.Filter(state, target_count, result, sel, sel_count, filter, filter_state); state.row_index += target_count; state.internal_index = state.row_index; } @@ -293,13 +298,13 @@ void ColumnData::FetchUpdateRow(TransactionData transaction, row_t row_id, Vecto updates->FetchRow(transaction, NumericCast(row_id), result, result_idx); } -void ColumnData::UpdateInternal(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count, Vector &base_vector) { +void ColumnData::UpdateInternal(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count, Vector &base_vector) { lock_guard update_guard(update_lock); if (!updates) { updates = make_uniq(*this); } - updates->Update(transaction, column_index, update_vector, row_ids, update_count, base_vector); + updates->Update(transaction, data_table, column_index, update_vector, row_ids, update_count, base_vector); } idx_t ColumnData::ScanVector(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result, @@ -420,7 +425,7 @@ FilterPropagateResult ColumnData::CheckZonemap(ColumnScanState &state, TableFilt FilterPropagateResult prune_result; { lock_guard l(stats_lock); - prune_result = filter.CheckStatistics(state.current->stats.statistics); + prune_result = filter.CheckStatistics(state.current->node->stats.statistics); if (prune_result == FilterPropagateResult::NO_PRUNING_POSSIBLE) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } @@ -478,28 +483,31 @@ void ColumnData::InitializeAppend(ColumnAppendState &state) { AppendTransientSegment(l, start); } auto segment = data.GetLastSegment(l); - if (segment->segment_type == ColumnSegmentType::PERSISTENT || !segment->GetCompressionFunction().init_append) { + auto &last_segment = *segment->node; + if (last_segment.segment_type == ColumnSegmentType::PERSISTENT || + !last_segment.GetCompressionFunction().init_append) { // we cannot append to this segment - append a new segment - auto total_rows = segment->start + segment->count; + auto total_rows = last_segment.start + last_segment.count; AppendTransientSegment(l, total_rows); state.current = data.GetLastSegment(l); } else { state.current = segment; } - - D_ASSERT(state.current->segment_type == ColumnSegmentType::TRANSIENT); - state.current->InitializeAppend(state); - D_ASSERT(state.current->GetCompressionFunction().append); + auto &append_segment = *state.current->node; + D_ASSERT(append_segment.segment_type == ColumnSegmentType::TRANSIENT); + append_segment.InitializeAppend(state); + D_ASSERT(append_segment.GetCompressionFunction().append); } void ColumnData::AppendData(BaseStatistics &append_stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t append_count) { idx_t offset = 0; - this->count += append_count; while (true) { // append the data from the vector - idx_t copied_elements = state.current->Append(state, vdata, offset, append_count); - append_stats.Merge(state.current->stats.statistics); + auto &append_segment = *state.current->node; + idx_t copied_elements = append_segment.Append(state, vdata, offset, append_count); + this->count += copied_elements; + append_stats.Merge(append_segment.stats.statistics); if (copied_elements == append_count) { // finished copying everything break; @@ -508,9 +516,9 @@ void ColumnData::AppendData(BaseStatistics &append_stats, ColumnAppendState &sta // we couldn't fit everything we wanted in the current 
column segment, create a new one { auto l = data.Lock(); - AppendTransientSegment(l, state.current->start + state.current->count); + AppendTransientSegment(l, append_segment.start + append_segment.count); state.current = data.GetLastSegment(l); - state.current->InitializeAppend(state); + state.current->node->InitializeAppend(state); } offset += copied_elements; append_count -= copied_elements; @@ -521,19 +529,20 @@ void ColumnData::RevertAppend(row_t start_row_p) { idx_t start_row = NumericCast(start_row_p); auto l = data.Lock(); // check if this row is in the segment tree at all - auto last_segment = data.GetLastSegment(l); - if (!last_segment) { + auto last_segment_node = data.GetLastSegment(l); + if (!last_segment_node) { return; } - if (start_row >= last_segment->start + last_segment->count) { + auto &last_segment = *last_segment_node->node; + if (start_row >= last_segment.start + last_segment.count) { // the start row is equal to the final portion of the column data: nothing was ever appended here - D_ASSERT(start_row == last_segment->start + last_segment->count); + D_ASSERT(start_row == last_segment.start + last_segment.count); return; } // find the segment index that the current row belongs to idx_t segment_index = data.GetSegmentIndex(l, start_row); auto segment = data.GetSegmentByIndex(l, UnsafeNumericCast(segment_index)); - if (segment->start == start_row) { + if (segment->node->start == start_row) { // we are truncating exactly this segment - erase it entirely data.EraseSegments(l, segment_index); } else { @@ -541,7 +550,7 @@ void ColumnData::RevertAppend(row_t start_row_p) { // remove any segments AFTER this segment: they should be deleted entirely data.EraseSegments(l, segment_index + 1); - auto &transient = *segment; + auto &transient = *segment->node; D_ASSERT(transient.segment_type == ColumnSegmentType::TRANSIENT); segment->next = nullptr; transient.RevertAppend(start_row); @@ -557,7 +566,7 @@ idx_t ColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { state.row_index = start + ((UnsafeNumericCast(row_id) - start) / STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE); state.current = data.GetSegment(state.row_index); - state.internal_index = state.current->start; + state.internal_index = state.current->node->start; return ScanVector(state, result, STANDARD_VECTOR_SIZE, ScanVectorType::SCAN_FLAT_VECTOR); } @@ -566,7 +575,7 @@ void ColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, auto segment = data.GetSegment(UnsafeNumericCast(row_id)); // now perform the fetch within the segment - segment->FetchRow(state, row_id, result, result_idx); + segment->node->FetchRow(state, row_id, result, result_idx); // merge any updates made to this row FetchUpdateRow(transaction, row_id, result, result_idx); @@ -578,24 +587,23 @@ idx_t ColumnData::FetchUpdateData(ColumnScanState &state, row_t *row_ids, Vector return fetch_count; } -void ColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void ColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) { Vector base_vector(type); ColumnScanState state; FetchUpdateData(state, row_ids, base_vector); - UpdateInternal(transaction, column_index, update_vector, row_ids, update_count, base_vector); + UpdateInternal(transaction, data_table, column_index, update_vector, row_ids, update_count, base_vector); } -void 
ColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) { +void ColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { // this method should only be called at the end of the path in the base column case D_ASSERT(depth >= column_path.size()); - ColumnData::Update(transaction, column_path[0], update_vector, row_ids, update_count); + ColumnData::Update(transaction, data_table, column_path[0], update_vector, row_ids, update_count); } void ColumnData::AppendTransientSegment(SegmentLock &l, idx_t start_row) { - const auto block_size = block_manager.GetBlockSize(); const auto type_size = GetTypeIdSize(type.InternalType()); auto vector_segment_size = block_size; @@ -673,7 +681,8 @@ void ColumnData::CheckpointScan(ColumnSegment &segment, ColumnScanState &state, unique_ptr ColumnData::Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info) { // scan the segments of the column data // set up the checkpoint state - auto checkpoint_state = CreateCheckpointState(row_group, checkpoint_info.info.manager); + auto &partial_block_manager = checkpoint_info.GetPartialBlockManager(); + auto checkpoint_state = CreateCheckpointState(row_group, partial_block_manager); checkpoint_state->global_stats = BaseStatistics::CreateEmpty(type).ToUnique(); auto &nodes = data.ReferenceSegments(); @@ -699,6 +708,7 @@ void ColumnData::InitializeColumn(PersistentColumnData &column_data, BaseStatist this->count = 0; for (auto &data_pointer : column_data.pointers) { // Update the count and statistics + data_pointer.row_start = start + count; this->count += data_pointer.tuple_count; // Merge the statistics. 
If this is a child column, the target_stats reference will point into the parents stats @@ -909,7 +919,7 @@ shared_ptr ColumnData::Deserialize(BlockManager &block_manager, Data return entry; } -void ColumnData::GetColumnSegmentInfo(idx_t row_group_index, vector col_path, +void ColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, vector &result) { D_ASSERT(!col_path.empty()); @@ -925,40 +935,39 @@ void ColumnData::GetColumnSegmentInfo(idx_t row_group_index, vector col_p // iterate over the segments idx_t segment_idx = 0; - auto segment = data.GetRootSegment(); - while (segment) { + for (auto &segment : data.Segments()) { ColumnSegmentInfo column_info; column_info.row_group_index = row_group_index; column_info.column_id = col_path[0]; column_info.column_path = col_path_str; column_info.segment_idx = segment_idx; column_info.segment_type = type.ToString(); - column_info.segment_start = segment->start; - column_info.segment_count = segment->count; - column_info.compression_type = CompressionTypeToString(segment->GetCompressionFunction().type); + column_info.segment_start = segment.start; + column_info.segment_count = segment.count; + column_info.compression_type = CompressionTypeToString(segment.GetCompressionFunction().type); { lock_guard l(stats_lock); - column_info.segment_stats = segment->stats.statistics.ToString(); + column_info.segment_stats = segment.stats.statistics.ToString(); } column_info.has_updates = ColumnData::HasUpdates(); // persistent // block_id // block_offset - if (segment->segment_type == ColumnSegmentType::PERSISTENT) { + if (segment.segment_type == ColumnSegmentType::PERSISTENT) { column_info.persistent = true; - column_info.block_id = segment->GetBlockId(); - column_info.block_offset = segment->GetBlockOffset(); + column_info.block_id = segment.GetBlockId(); + column_info.block_offset = segment.GetBlockOffset(); } else { column_info.persistent = false; } - auto &compression_function = segment->GetCompressionFunction(); - auto segment_state = segment->GetSegmentState(); + auto &compression_function = segment.GetCompressionFunction(); + auto segment_state = segment.GetSegmentState(); if (segment_state) { column_info.segment_info = segment_state->GetSegmentInfo(); column_info.additional_blocks = segment_state->GetAdditionalBlocks(); } if (compression_function.get_segment_info) { - auto segment_info = compression_function.get_segment_info(*segment); + auto segment_info = compression_function.get_segment_info(context, segment); vector sinfo; for (auto &item : segment_info) { auto &mode = item.first; @@ -970,7 +979,6 @@ void ColumnData::GetColumnSegmentInfo(idx_t row_group_index, vector col_p result.emplace_back(column_info); segment_idx++; - segment = data.GetNextSegment(segment); } } @@ -986,11 +994,11 @@ void ColumnData::Verify(RowGroup &parent) { idx_t current_index = 0; idx_t current_start = this->start; idx_t total_count = 0; - for (auto &segment : data.Segments()) { + for (auto &segment : data.SegmentNodes()) { D_ASSERT(segment.index == current_index); - D_ASSERT(segment.start == current_start); - current_start += segment.count; - total_count += segment.count; + D_ASSERT(segment.row_start == current_start); + current_start += segment.node->count; + total_count += segment.node->count; current_index++; } D_ASSERT(this->count == total_count); diff --git a/src/duckdb/src/storage/table/column_data_checkpointer.cpp b/src/duckdb/src/storage/table/column_data_checkpointer.cpp index 68c35f842..0ad658f2c 100644 --- 
a/src/duckdb/src/storage/table/column_data_checkpointer.cpp +++ b/src/duckdb/src/storage/table/column_data_checkpointer.cpp @@ -65,7 +65,6 @@ ColumnDataCheckpointer::ColumnDataCheckpointer(vectorCommitDropSegment(); + auto &segment = *nodes[segment_idx]->node; + segment.CommitDropSegment(); } } } @@ -374,12 +375,12 @@ void ColumnDataCheckpointer::WritePersistentSegments(ColumnCheckpointState &stat idx_t current_row = row_group.start; for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { - auto segment = nodes[segment_idx].node.get(); - if (segment->start != current_row) { + auto &segment = *nodes[segment_idx]->node; + if (segment.start != current_row) { string extra_info; for (auto &s : nodes) { extra_info += "\n"; - extra_info += StringUtil::Format("Start %d, count %d", s.node->start, s.node->count.load()); + extra_info += StringUtil::Format("Start %d, count %d", segment.start, segment.count.load()); } const_reference root = col_data; while (root.get().HasParent()) { @@ -389,18 +390,18 @@ void ColumnDataCheckpointer::WritePersistentSegments(ColumnCheckpointState &stat "Failure in RowGroup::Checkpoint - column data pointer is unaligned with row group " "start\nRow group start: %d\nRow group count %d\nCurrent row: %d\nSegment start: %d\nColumn index: " "%d\nColumn type: %s\nRoot type: %s\nTable: %s.%s\nAll segments:%s", - row_group.start, row_group.count.load(), current_row, segment->start, root.get().column_index, + row_group.start, row_group.count.load(), current_row, segment.start, root.get().column_index, col_data.type, root.get().type, root.get().info.GetSchemaName(), root.get().info.GetTableName(), extra_info); } - current_row += segment->count; - auto pointer = segment->GetDataPointer(); + current_row += segment.count; + auto pointer = segment.GetDataPointer(); // merge the persistent stats into the global column stats - state.global_stats->Merge(segment->stats.statistics); + state.global_stats->Merge(segment.stats.statistics); // directly append the current segment to the new tree - state.new_tree.AppendSegment(std::move(nodes[segment_idx].node)); + state.new_tree.AppendSegment(std::move(nodes[segment_idx]->node)); state.data_pointers.push_back(std::move(pointer)); } @@ -447,7 +448,7 @@ void ColumnDataCheckpointer::FinalizeCheckpoint() { auto new_segments = state.new_tree.MoveSegments(); auto l = col_data.data.Lock(); for (auto &new_segment : new_segments) { - col_data.AppendSegment(l, std::move(new_segment.node)); + col_data.AppendSegment(l, std::move(new_segment->node)); } col_data.ClearUpdates(); } diff --git a/src/duckdb/src/storage/table/column_segment.cpp b/src/duckdb/src/storage/table/column_segment.cpp index 347463fbe..df0c665f9 100644 --- a/src/duckdb/src/storage/table/column_segment.cpp +++ b/src/duckdb/src/storage/table/column_segment.cpp @@ -30,7 +30,6 @@ unique_ptr ColumnSegment::CreatePersistentSegment(DatabaseInstanc CompressionType compression_type, BaseStatistics statistics, unique_ptr segment_state) { - auto &config = DBConfig::GetConfig(db); optional_ptr function; shared_ptr block; @@ -48,7 +47,6 @@ unique_ptr ColumnSegment::CreatePersistentSegment(DatabaseInstanc unique_ptr ColumnSegment::CreateTransientSegment(DatabaseInstance &db, CompressionFunction &function, const LogicalType &type, const idx_t start, const idx_t segment_size, BlockManager &block_manager) { - // Allocate a buffer for the uncompressed segment. 
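// Illustrative sketch, not part of the patch: the indirection introduced throughout
// column_data.cpp above, where scan and checkpoint code no longer holds a ColumnSegment directly
// but a node that carries the segment together with its row_start. Segment and SegmentNode are
// simplified stand-ins.
#include <cstddef>
#include <memory>

struct Segment {
    std::size_t start = 0; // first row covered by this segment
    std::size_t count = 0; // number of rows stored in this segment
};

struct SegmentNode {
    std::size_t row_start = 0;     // position of the segment within the column
    std::unique_ptr<Segment> node; // the segment itself, accessed as segment_node.node->...
};

// How many of `requested` rows can still be read from the node the scan currently points at,
// mirroring the `current.start + current.count - state.row_index` computations in the patch.
inline std::size_t RemainingInSegment(const SegmentNode &current, std::size_t row_index, std::size_t requested) {
    auto &segment = *current.node;
    std::size_t remaining = segment.start + segment.count - row_index;
    return remaining < requested ? remaining : requested;
}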
auto &buffer_manager = BufferManager::GetBufferManager(db); D_ASSERT(&buffer_manager == &block_manager.buffer_manager); @@ -70,7 +68,6 @@ ColumnSegment::ColumnSegment(DatabaseInstance &db, shared_ptr block : SegmentBase(start, count), db(db), type(type), type_size(GetTypeIdSize(type.InternalType())), segment_type(segment_type), stats(std::move(statistics)), block(std::move(block_p)), function(function_p), block_id(block_id_p), offset(offset), segment_size(segment_size_p) { - if (function.get().init_segment) { segment_state = function.get().init_segment(*this, block_id, segment_state_p.get()); } @@ -80,12 +77,10 @@ ColumnSegment::ColumnSegment(DatabaseInstance &db, shared_ptr block } ColumnSegment::ColumnSegment(ColumnSegment &other, const idx_t start) - : SegmentBase(start, other.count.load()), db(other.db), type(std::move(other.type)), type_size(other.type_size), segment_type(other.segment_type), stats(std::move(other.stats)), block(std::move(other.block)), function(other.function), block_id(other.block_id), offset(other.offset), segment_size(other.segment_size), segment_state(std::move(other.segment_state)) { - // For constant segments (CompressionType::COMPRESSION_CONSTANT) the block is a nullptr. D_ASSERT(!block || segment_size <= GetBlockManager().GetBlockSize()); } @@ -109,7 +104,7 @@ void ColumnSegment::InitializePrefetch(PrefetchState &prefetch_state, ColumnScan } void ColumnSegment::InitializeScan(ColumnScanState &state) { - state.scan_state = function.get().init_scan(*this); + state.scan_state = function.get().init_scan(state.context, *this); } void ColumnSegment::Scan(ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset, diff --git a/src/duckdb/src/storage/table/list_column_data.cpp b/src/duckdb/src/storage/table/list_column_data.cpp index 7685d16ca..0d9793e9c 100644 --- a/src/duckdb/src/storage/table/list_column_data.cpp +++ b/src/duckdb/src/storage/table/list_column_data.cpp @@ -58,7 +58,7 @@ uint64_t ListColumnData::FetchListOffset(idx_t row_idx) { auto segment = data.GetSegment(row_idx); ColumnFetchState fetch_state; Vector result(LogicalType::UBIGINT, 1); - segment->FetchRow(fetch_state, UnsafeNumericCast(row_idx), result, 0U); + segment->node->FetchRow(fetch_state, UnsafeNumericCast(row_idx), result, 0U); // initialize the child scan with the required offset return FlatVector::GetData(result)[0]; @@ -263,13 +263,14 @@ idx_t ListColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result throw NotImplementedException("List Fetch"); } -void ListColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void ListColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count) { throw NotImplementedException("List Update is not supported."); } -void ListColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void ListColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth) { throw NotImplementedException("List Update Column is not supported"); } @@ -312,7 +313,7 @@ void ListColumnData::FetchRow(TransactionData transaction, ColumnFetchState &sta auto &child_type = ListType::GetChildType(result.GetType()); Vector child_scan(child_type, child_scan_count); // seek the 
scan towards the specified position and read [length] entries - child_state->Initialize(child_type, nullptr); + child_state->Initialize(state.context, child_type, nullptr); child_column->InitializeScanWithOffset(*child_state, start + start_offset); D_ASSERT(child_type.InternalType() == PhysicalType::STRUCT || child_state->row_index + child_scan_count - this->start <= child_column->GetMaxEntry()); @@ -391,13 +392,13 @@ void ListColumnData::InitializeColumn(PersistentColumnData &column_data, BaseSta child_column->InitializeColumn(column_data.child_columns[1], child_stats); } -void ListColumnData::GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) { - ColumnData::GetColumnSegmentInfo(row_group_index, col_path, result); +void ListColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, + vector &result) { + ColumnData::GetColumnSegmentInfo(context, row_group_index, col_path, result); col_path.push_back(0); - validity.GetColumnSegmentInfo(row_group_index, col_path, result); + validity.GetColumnSegmentInfo(context, row_group_index, col_path, result); col_path.back() = 1; - child_column->GetColumnSegmentInfo(row_group_index, col_path, result); + child_column->GetColumnSegmentInfo(context, row_group_index, col_path, result); } } // namespace duckdb diff --git a/src/duckdb/src/storage/table/row_group.cpp b/src/duckdb/src/storage/table/row_group.cpp index 40e20d2d4..659555bed 100644 --- a/src/duckdb/src/storage/table/row_group.cpp +++ b/src/duckdb/src/storage/table/row_group.cpp @@ -26,14 +26,14 @@ namespace duckdb { RowGroup::RowGroup(RowGroupCollection &collection_p, idx_t start, idx_t count) - : SegmentBase(start, count), collection(collection_p), version_info(nullptr), allocation_size(0), - row_id_is_loaded(false), has_changes(false) { + : SegmentBase(start, count), collection(collection_p), version_info(nullptr), deletes_is_loaded(false), + allocation_size(0), row_id_is_loaded(false), has_changes(false) { Verify(); } RowGroup::RowGroup(RowGroupCollection &collection_p, RowGroupPointer pointer) : SegmentBase(pointer.row_start, pointer.tuple_count), collection(collection_p), version_info(nullptr), - allocation_size(0), row_id_is_loaded(false), has_changes(false) { + deletes_is_loaded(false), allocation_size(0), row_id_is_loaded(false), has_changes(false) { // deserialize the columns if (pointer.data_pointers.size() != collection_p.GetTypes().size()) { throw IOException("Row group column count is unaligned with table column count. 
Corrupt file?"); @@ -45,7 +45,6 @@ RowGroup::RowGroup(RowGroupCollection &collection_p, RowGroupPointer pointer) this->is_loaded[c] = false; } this->deletes_pointers = std::move(pointer.deletes_pointers); - this->deletes_is_loaded = false; this->has_metadata_blocks = pointer.has_metadata_blocks; this->extra_metadata_blocks = std::move(pointer.extra_metadata_blocks); @@ -54,7 +53,7 @@ RowGroup::RowGroup(RowGroupCollection &collection_p, RowGroupPointer pointer) RowGroup::RowGroup(RowGroupCollection &collection_p, PersistentRowGroupData &data) : SegmentBase(data.start, data.count), collection(collection_p), version_info(nullptr), - allocation_size(0), row_id_is_loaded(false), has_changes(false) { + deletes_is_loaded(false), allocation_size(0), row_id_is_loaded(false), has_changes(false) { auto &block_manager = GetBlockManager(); auto &info = GetTableInfo(); auto &types = collection.get().GetTypes(); @@ -86,12 +85,6 @@ void RowGroup::MoveToCollection(RowGroupCollection &collection_p, idx_t new_star if (row_id_is_loaded) { row_id_column_data->SetStart(new_start); } - if (!HasUnloadedDeletes()) { - auto vinfo = GetVersionInfo(); - if (vinfo) { - vinfo->SetStart(new_start); - } - } } RowGroup::~RowGroup() { @@ -183,10 +176,11 @@ void RowGroup::InitializeEmpty(const vector &types) { } } -void ColumnScanState::Initialize(const LogicalType &type, const vector &children, - optional_ptr options) { +void ColumnScanState::Initialize(const QueryContext &context_p, const LogicalType &type, + const vector &children, optional_ptr options) { // Register the options in the state scan_options = options; + context = context_p; if (type.id() == LogicalTypeId::VALIDITY) { // validity - nothing to initialize @@ -201,7 +195,7 @@ void ColumnScanState::Initialize(const LogicalType &type, const vector options) { +void ColumnScanState::Initialize(const QueryContext &context_p, const LogicalType &type, + optional_ptr options) { vector children; - Initialize(type, children, options); + Initialize(context_p, type, children, options); } -void CollectionScanState::Initialize(const vector &types) { +void CollectionScanState::Initialize(const QueryContext &context, const vector &types) { auto &column_ids = GetColumnIds(); column_scans = make_unsafe_uniq_array(column_ids.size()); for (idx_t i = 0; i < column_ids.size(); i++) { @@ -245,18 +240,21 @@ void CollectionScanState::Initialize(const vector &types) { continue; } auto col_id = column_ids[i].GetPrimaryIndex(); - column_scans[i].Initialize(types[col_id], column_ids[i].GetChildIndexes(), &GetOptions()); + column_scans[i].Initialize(context, types[col_id], column_ids[i].GetChildIndexes(), &GetOptions()); } } -bool RowGroup::InitializeScanWithOffset(CollectionScanState &state, idx_t vector_offset) { +bool RowGroup::InitializeScanWithOffset(CollectionScanState &state, SegmentNode &node, idx_t vector_offset) { auto &column_ids = state.GetColumnIds(); auto &filters = state.GetFilterInfo(); if (!CheckZonemap(filters)) { return false; } + if (!RefersToSameObject(*node.node, *this)) { + throw InternalException("RowGroup::InitializeScanWithOffset segment node mismatch"); + } - state.row_group = this; + state.row_group = node; state.vector_index = vector_offset; state.max_row_group_row = this->start > state.max_row ? 
0 : MinValue(this->count, state.max_row - this->start); @@ -275,13 +273,16 @@ bool RowGroup::InitializeScanWithOffset(CollectionScanState &state, idx_t vector return true; } -bool RowGroup::InitializeScan(CollectionScanState &state) { +bool RowGroup::InitializeScan(CollectionScanState &state, SegmentNode &node) { auto &column_ids = state.GetColumnIds(); auto &filters = state.GetFilterInfo(); if (!CheckZonemap(filters)) { return false; } - state.row_group = this; + if (!RefersToSameObject(*node.node, *this)) { + throw InternalException("RowGroup::InitializeScan segment node mismatch"); + } + state.row_group = node; state.vector_index = 0; state.max_row_group_row = this->start > state.max_row ? 0 : MinValue(this->count, state.max_row - this->start); @@ -300,7 +301,8 @@ bool RowGroup::InitializeScan(CollectionScanState &state) { unique_ptr RowGroup::AlterType(RowGroupCollection &new_collection, const LogicalType &target_type, idx_t changed_idx, ExpressionExecutor &executor, - CollectionScanState &scan_state, DataChunk &scan_chunk) { + CollectionScanState &scan_state, SegmentNode &node, + DataChunk &scan_chunk) { Verify(); // construct a new column data for this type @@ -310,8 +312,8 @@ unique_ptr RowGroup::AlterType(RowGroupCollection &new_collection, con column_data->InitializeAppend(append_state); // scan the original table, and fill the new column with the transformed value - scan_state.Initialize(GetCollection().GetTypes()); - InitializeScan(scan_state); + scan_state.Initialize(executor.GetContext(), GetCollection().GetTypes()); + InitializeScan(scan_state, node); DataChunk append_chunk; vector append_types; @@ -478,7 +480,7 @@ bool RowGroup::CheckZonemapSegments(CollectionScanState &state) { // no segment to skip continue; } - idx_t target_row = current_segment->start + current_segment->count; + idx_t target_row = current_segment->node->start + current_segment->node->count; if (target_row >= state.max_row) { target_row = state.max_row; } @@ -529,19 +531,20 @@ void RowGroup::TemplatedScan(TransactionData transaction, CollectionScanState &s if (!CheckZonemapSegments(state)) { continue; } + auto ¤t_row_group = *state.row_group->node; // second, scan the version chunk manager to figure out which tuples to load for this transaction idx_t count; if (TYPE == TableScanType::TABLE_SCAN_REGULAR) { - count = state.row_group->GetSelVector(transaction, state.vector_index, state.valid_sel, max_count); + count = current_row_group.GetSelVector(transaction, state.vector_index, state.valid_sel, max_count); if (count == 0) { // nothing to scan for this vector, skip the entire vector NextVector(state); continue; } } else if (TYPE == TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED) { - count = state.row_group->GetCommittedSelVector(transaction.start_time, transaction.transaction_id, - state.vector_index, state.valid_sel, max_count); + count = current_row_group.GetCommittedSelVector(transaction.start_time, transaction.transaction_id, + state.vector_index, state.valid_sel, max_count); if (count == 0) { // nothing to scan for this vector, skip the entire vector NextVector(state); @@ -706,7 +709,7 @@ optional_ptr RowGroup::GetVersionInfo() { } // deletes are not loaded - reload auto root_delete = deletes_pointers[0]; - auto loaded_info = RowVersionManager::Deserialize(root_delete, GetBlockManager().GetMetadataManager(), start); + auto loaded_info = RowVersionManager::Deserialize(root_delete, GetBlockManager().GetMetadataManager()); SetVersionInfo(std::move(loaded_info)); deletes_is_loaded = 
true; return version_info; @@ -721,7 +724,8 @@ shared_ptr RowGroup::GetOrCreateVersionInfoInternal() { // version info does not exist - need to create it lock_guard lock(row_group_lock); if (!owned_version_info) { - auto new_info = make_shared_ptr(start); + auto &buffer_manager = GetBlockManager().GetBufferManager(); + auto new_info = make_shared_ptr(buffer_manager); SetVersionInfo(std::move(new_info)); } return owned_version_info; @@ -852,8 +856,8 @@ void RowGroup::CleanupAppend(transaction_t lowest_transaction, idx_t start, idx_ vinfo.CleanupAppend(lowest_transaction, start, count); } -void RowGroup::Update(TransactionData transaction, DataChunk &update_chunk, row_t *ids, idx_t offset, idx_t count, - const vector &column_ids) { +void RowGroup::Update(TransactionData transaction, DataTable &data_table, DataChunk &update_chunk, row_t *ids, + idx_t offset, idx_t count, const vector &column_ids) { #ifdef DEBUG for (size_t i = offset; i < offset + count; i++) { D_ASSERT(ids[i] >= row_t(this->start) && ids[i] < row_t(this->start + this->count)); @@ -866,16 +870,16 @@ void RowGroup::Update(TransactionData transaction, DataChunk &update_chunk, row_ if (offset > 0) { Vector sliced_vector(update_chunk.data[i], offset, offset + count); sliced_vector.Flatten(count); - col_data.Update(transaction, column.index, sliced_vector, ids + offset, count); + col_data.Update(transaction, data_table, column.index, sliced_vector, ids + offset, count); } else { - col_data.Update(transaction, column.index, update_chunk.data[i], ids, count); + col_data.Update(transaction, data_table, column.index, update_chunk.data[i], ids, count); } MergeStatistics(column.index, *col_data.GetUpdateStatistics()); } } -void RowGroup::UpdateColumn(TransactionData transaction, DataChunk &updates, Vector &row_ids, idx_t offset, idx_t count, - const vector &column_path) { +void RowGroup::UpdateColumn(TransactionData transaction, DataTable &data_table, DataChunk &updates, Vector &row_ids, + idx_t offset, idx_t count, const vector &column_path) { D_ASSERT(updates.ColumnCount() == 1); auto ids = FlatVector::GetData(row_ids); @@ -885,9 +889,9 @@ void RowGroup::UpdateColumn(TransactionData transaction, DataChunk &updates, Vec if (offset > 0) { Vector sliced_vector(updates.data[0], offset, offset + count); sliced_vector.Flatten(count); - col_data.UpdateColumn(transaction, column_path, sliced_vector, ids + offset, count, 1); + col_data.UpdateColumn(transaction, data_table, column_path, sliced_vector, ids + offset, count, 1); } else { - col_data.UpdateColumn(transaction, column_path, updates.data[0], ids, count, 1); + col_data.UpdateColumn(transaction, data_table, column_path, updates.data[0], ids, count, 1); } MergeStatistics(primary_column_idx, *col_data.GetUpdateStatistics()); } @@ -914,6 +918,32 @@ void RowGroup::MergeIntoStatistics(TableStatistics &other) { } } +ColumnCheckpointInfo::ColumnCheckpointInfo(RowGroupWriteInfo &info, idx_t column_idx) + : column_idx(column_idx), info(info) { +} + +RowGroupWriteInfo::RowGroupWriteInfo(PartialBlockManager &manager, const vector &compression_types, + CheckpointType checkpoint_type) + : manager(manager), compression_types(compression_types), checkpoint_type(checkpoint_type) { +} + +RowGroupWriteInfo::RowGroupWriteInfo(PartialBlockManager &manager, const vector &compression_types, + vector> &column_partial_block_managers_p) + : manager(manager), compression_types(compression_types), checkpoint_type(CheckpointType::FULL_CHECKPOINT), + column_partial_block_managers(column_partial_block_managers_p) { 
+} + +PartialBlockManager &RowGroupWriteInfo::GetPartialBlockManager(idx_t column_idx) { + if (column_partial_block_managers && !column_partial_block_managers->empty()) { + return *column_partial_block_managers->at(column_idx); + } + return manager; +} + +PartialBlockManager &ColumnCheckpointInfo::GetPartialBlockManager() { + return info.GetPartialBlockManager(column_idx); +} + CompressionType ColumnCheckpointInfo::GetCompressionType() { return info.compression_types[column_idx]; } @@ -991,21 +1021,15 @@ bool RowGroup::HasUnloadedDeletes() const { return !deletes_is_loaded; } -vector RowGroup::GetColumnPointers() { - if (has_metadata_blocks) { - // we have the column metadata from the file itself - no need to deserialize metadata to fetch it - // read if from "column_pointers" and "extra_metadata_blocks" - auto result = column_pointers; - for (auto &block_pointer : extra_metadata_blocks) { - result.emplace_back(block_pointer, 0); - } - return result; +vector RowGroup::GetOrComputeExtraMetadataBlocks(bool force_compute) { + if (has_metadata_blocks && !force_compute) { + return extra_metadata_blocks; } - vector result; if (column_pointers.empty()) { // no pointers - return result; + return {}; } + vector read_pointers; // column_pointers stores the beginning of each column // if columns are big - they may span multiple metadata blocks // we need to figure out all blocks that this row group points to @@ -1016,13 +1040,25 @@ vector RowGroup::GetColumnPointers() { // for all but the last column pointer - we can just follow the linked list until we reach the last column MetadataReader reader(metadata_manager, column_pointers[0]); auto last_pointer = column_pointers[last_idx]; - result = reader.GetRemainingBlocks(last_pointer); + read_pointers = reader.GetRemainingBlocks(last_pointer); } // for the last column we need to deserialize the column - because we don't know where it stops auto &types = GetCollection().GetTypes(); - MetadataReader reader(metadata_manager, column_pointers[last_idx], &result); + MetadataReader reader(metadata_manager, column_pointers[last_idx], &read_pointers); ColumnData::Deserialize(GetBlockManager(), GetTableInfo(), last_idx, start, reader, types[last_idx]); - return result; + + unordered_set result_as_set; + for (auto &ptr : read_pointers) { + result_as_set.emplace(ptr.block_pointer); + } + for (auto &ptr : column_pointers) { + result_as_set.erase(ptr.block_pointer); + } + return {result_as_set.begin(), result_as_set.end()}; +} + +const vector &RowGroup::GetColumnStartPointers() const { + return column_pointers; } RowGroupWriteData RowGroup::WriteToDisk(RowGroupWriter &writer) { @@ -1031,7 +1067,8 @@ RowGroupWriteData RowGroup::WriteToDisk(RowGroupWriter &writer) { // we have existing metadata and the row group has not been changed // re-use previous metadata RowGroupWriteData result; - result.existing_pointers = GetColumnPointers(); + result.reuse_existing_metadata_blocks = true; + result.existing_extra_metadata_blocks = GetOrComputeExtraMetadataBlocks(); return result; } auto &compression_types = writer.GetCompressionTypes(); @@ -1059,14 +1096,26 @@ RowGroupPointer RowGroup::Checkpoint(RowGroupWriteData write_data, RowGroupWrite // construct the row group pointer and write the column meta data to disk row_group_pointer.row_start = start; row_group_pointer.tuple_count = count; - if (!write_data.existing_pointers.empty()) { + if (write_data.reuse_existing_metadata_blocks) { // we are re-using the previous metadata row_group_pointer.data_pointers = column_pointers; - 
row_group_pointer.has_metadata_blocks = has_metadata_blocks; - row_group_pointer.extra_metadata_blocks = extra_metadata_blocks; + row_group_pointer.has_metadata_blocks = true; + row_group_pointer.extra_metadata_blocks = write_data.existing_extra_metadata_blocks; row_group_pointer.deletes_pointers = deletes_pointers; - metadata_manager->ClearModifiedBlocks(write_data.existing_pointers); - metadata_manager->ClearModifiedBlocks(deletes_pointers); + if (metadata_manager) { + vector extra_metadata_block_pointers; + extra_metadata_block_pointers.reserve(write_data.existing_extra_metadata_blocks.size()); + for (auto &block_pointer : write_data.existing_extra_metadata_blocks) { + extra_metadata_block_pointers.emplace_back(block_pointer, 0); + } + metadata_manager->ClearModifiedBlocks(column_pointers); + metadata_manager->ClearModifiedBlocks(extra_metadata_block_pointers); + metadata_manager->ClearModifiedBlocks(deletes_pointers); + + // remember metadata_blocks to avoid loading them on future checkpoints + has_metadata_blocks = true; + extra_metadata_blocks = row_group_pointer.extra_metadata_blocks; + } return row_group_pointer; } D_ASSERT(write_data.states.size() == columns.size()); @@ -1109,6 +1158,7 @@ RowGroupPointer RowGroup::Checkpoint(RowGroupWriteData write_data, RowGroupWrite } // this metadata block is not stored - add it to the extra metadata blocks row_group_pointer.extra_metadata_blocks.push_back(column_pointer.block_pointer); + metadata_blocks.insert(column_pointer.block_pointer); } // set up the pointers correctly within this row group for future operations column_pointers = row_group_pointer.data_pointers; @@ -1130,6 +1180,7 @@ bool RowGroup::HasChanges() const { // we have deletes return true; } + D_ASSERT(!deletes_is_loaded.load()); // check if any of the columns have changes // avoid loading unloaded columns - unloaded columns can never have changes for (idx_t c = 0; c < columns.size(); c++) { @@ -1220,10 +1271,11 @@ PartitionStatistics RowGroup::GetPartitionStats() const { //===--------------------------------------------------------------------===// // GetColumnSegmentInfo //===--------------------------------------------------------------------===// -void RowGroup::GetColumnSegmentInfo(idx_t row_group_index, vector &result) { +void RowGroup::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, + vector &result) { for (idx_t col_idx = 0; col_idx < GetColumnCount(); col_idx++) { auto &col_data = GetColumn(col_idx); - col_data.GetColumnSegmentInfo(row_group_index, {col_idx}, result); + col_data.GetColumnSegmentInfo(context, row_group_index, {col_idx}, result); } } diff --git a/src/duckdb/src/storage/table/row_group_collection.cpp b/src/duckdb/src/storage/table/row_group_collection.cpp index d44ca0544..ec9c3906a 100644 --- a/src/duckdb/src/storage/table/row_group_collection.cpp +++ b/src/duckdb/src/storage/table/row_group_collection.cpp @@ -17,6 +17,7 @@ #include "duckdb/storage/table/scan_state.hpp" #include "duckdb/storage/table_storage_info.hpp" #include "duckdb/main/settings.hpp" +#include "duckdb/transaction/duck_transaction.hpp" namespace duckdb { @@ -133,7 +134,11 @@ void RowGroupCollection::AppendRowGroup(SegmentLock &l, idx_t start_row) { } optional_ptr RowGroupCollection::GetRowGroup(int64_t index) { - return row_groups->GetSegmentByIndex(index); + auto result = row_groups->GetSegmentByIndex(index); + if (!result) { + return nullptr; + } + return result->node.get(); } void RowGroupCollection::Verify() { @@ -153,15 +158,17 @@ void 
RowGroupCollection::Verify() { //===--------------------------------------------------------------------===// // Scan //===--------------------------------------------------------------------===// -void RowGroupCollection::InitializeScan(CollectionScanState &state, const vector &column_ids, +void RowGroupCollection::InitializeScan(const QueryContext &context, CollectionScanState &state, + const vector &column_ids, optional_ptr table_filters) { - auto row_group = row_groups->GetRootSegment(); + state.row_groups = row_groups.get(); + auto row_group = state.GetRootSegment(); D_ASSERT(row_group); state.row_groups = row_groups.get(); state.max_row = row_start + total_rows; - state.Initialize(GetTypes()); - while (row_group && !row_group->InitializeScan(state)) { - row_group = row_groups->GetNextSegment(row_group); + state.Initialize(context, GetTypes()); + while (row_group && !row_group->node->InitializeScan(state, *row_group)) { + row_group = state.GetNextRowGroup(*row_group); } } @@ -169,33 +176,35 @@ void RowGroupCollection::InitializeCreateIndexScan(CreateIndexScanState &state) state.segment_lock = row_groups->Lock(); } -void RowGroupCollection::InitializeScanWithOffset(CollectionScanState &state, const vector &column_ids, - idx_t start_row, idx_t end_row) { +void RowGroupCollection::InitializeScanWithOffset(const QueryContext &context, CollectionScanState &state, + const vector &column_ids, idx_t start_row, + idx_t end_row) { auto row_group = row_groups->GetSegment(start_row); D_ASSERT(row_group); state.row_groups = row_groups.get(); state.max_row = end_row; - state.Initialize(GetTypes()); - idx_t start_vector = (start_row - row_group->start) / STANDARD_VECTOR_SIZE; - if (!row_group->InitializeScanWithOffset(state, start_vector)) { + state.Initialize(context, GetTypes()); + idx_t start_vector = (start_row - row_group->node->start) / STANDARD_VECTOR_SIZE; + if (!row_group->node->InitializeScanWithOffset(state, *row_group, start_vector)) { throw InternalException("Failed to initialize row group scan with offset"); } } -bool RowGroupCollection::InitializeScanInRowGroup(CollectionScanState &state, RowGroupCollection &collection, - RowGroup &row_group, idx_t vector_index, idx_t max_row) { +bool RowGroupCollection::InitializeScanInRowGroup(const QueryContext &context, CollectionScanState &state, + RowGroupCollection &collection, SegmentNode &row_group, + idx_t vector_index, idx_t max_row) { state.max_row = max_row; state.row_groups = collection.row_groups.get(); if (!state.column_scans) { // initialize the scan state - state.Initialize(collection.GetTypes()); + state.Initialize(context, collection.GetTypes()); } - return row_group.InitializeScanWithOffset(state, vector_index); + return row_group.node->InitializeScanWithOffset(state, row_group, vector_index); } void RowGroupCollection::InitializeParallelScan(ParallelCollectionScanState &state) { state.collection = this; - state.current_row_group = row_groups->GetRootSegment(); + state.current_row_group = state.GetRootSegment(*row_groups); state.vector_index = 0; state.max_row = row_start + total_rows; state.batch_index = 0; @@ -208,32 +217,36 @@ bool RowGroupCollection::NextParallelScan(ClientContext &context, ParallelCollec idx_t vector_index; idx_t max_row; RowGroupCollection *collection; - RowGroup *row_group; + optional_ptr> row_group; { // select the next row group to scan from the parallel state lock_guard l(state.lock); - if (!state.current_row_group || state.current_row_group->count == 0) { + if (!state.current_row_group) { // no more 
data left to scan break; } + auto ¤t_row_group = *state.current_row_group->node; + if (current_row_group.count == 0) { + break; + } collection = state.collection; row_group = state.current_row_group; if (ClientConfig::GetConfig(context).verify_parallelism) { vector_index = state.vector_index; - max_row = state.current_row_group->start + - MinValue(state.current_row_group->count, + max_row = current_row_group.start + + MinValue(current_row_group.count, STANDARD_VECTOR_SIZE * state.vector_index + STANDARD_VECTOR_SIZE); - D_ASSERT(vector_index * STANDARD_VECTOR_SIZE < state.current_row_group->count); + D_ASSERT(vector_index * STANDARD_VECTOR_SIZE < current_row_group.count); state.vector_index++; - if (state.vector_index * STANDARD_VECTOR_SIZE >= state.current_row_group->count) { - state.current_row_group = row_groups->GetNextSegment(state.current_row_group); + if (state.vector_index * STANDARD_VECTOR_SIZE >= current_row_group.count) { + state.current_row_group = state.GetNextRowGroup(*row_groups, *row_group).get(); state.vector_index = 0; } } else { - state.processed_rows += state.current_row_group->count; + state.processed_rows += current_row_group.count; vector_index = 0; - max_row = state.current_row_group->start + state.current_row_group->count; - state.current_row_group = row_groups->GetNextSegment(state.current_row_group); + max_row = current_row_group.start + current_row_group.count; + state.current_row_group = state.GetNextRowGroup(*row_groups, *row_group).get(); } max_row = MinValue(max_row, state.max_row); scan_state.batch_index = ++state.batch_index; @@ -242,7 +255,8 @@ bool RowGroupCollection::NextParallelScan(ClientContext &context, ParallelCollec D_ASSERT(row_group); // initialize the scan for this row group - bool need_to_scan = InitializeScanInRowGroup(scan_state, *collection, *row_group, vector_index, max_row); + bool need_to_scan = + InitializeScanInRowGroup(context, scan_state, *collection, *row_group, vector_index, max_row); if (!need_to_scan) { // skip this row group continue; @@ -266,7 +280,7 @@ bool RowGroupCollection::Scan(DuckTransaction &transaction, const vector> row_group; { idx_t segment_index; auto l = row_groups->Lock(); @@ -309,17 +323,18 @@ void RowGroupCollection::Fetch(TransactionData transaction, DataChunk &result, c } row_group = row_groups->GetSegmentByIndex(l, UnsafeNumericCast(segment_index)); } - if (!row_group->Fetch(transaction, UnsafeNumericCast(row_id) - row_group->start)) { + auto ¤t_row_group = *row_group->node; + if (!current_row_group.Fetch(transaction, UnsafeNumericCast(row_id) - current_row_group.start)) { continue; } - row_group->FetchRow(transaction, state, column_ids, row_id, result, count); + current_row_group.FetchRow(transaction, state, column_ids, row_id, result, count); count++; } result.SetCardinality(count); } bool RowGroupCollection::CanFetch(TransactionData transaction, const row_t row_id) { - RowGroup *row_group; + optional_ptr> row_group; { idx_t segment_index; auto l = row_groups->Lock(); @@ -328,7 +343,8 @@ bool RowGroupCollection::CanFetch(TransactionData transaction, const row_t row_i } row_group = row_groups->GetSegmentByIndex(l, UnsafeNumericCast(segment_index)); } - return row_group->Fetch(transaction, UnsafeNumericCast(row_id) - row_group->start); + auto ¤t_row_group = *row_group->node; + return current_row_group.Fetch(transaction, UnsafeNumericCast(row_id) - current_row_group.start); } //===--------------------------------------------------------------------===// @@ -363,8 +379,8 @@ void 
RowGroupCollection::InitializeAppend(TransactionData transaction, TableAppe AppendRowGroup(l, row_start + total_rows); } state.start_row_group = row_groups->GetLastSegment(l); - D_ASSERT(this->row_start + total_rows == state.start_row_group->start + state.start_row_group->count); - state.start_row_group->InitializeAppend(state.row_group_append_state); + D_ASSERT(this->row_start + total_rows == state.start_row_group->row_start + state.start_row_group->node->count); + state.start_row_group->node->InitializeAppend(state.row_group_append_state); state.transaction = transaction; // initialize thread-local stats so we have less lock contention when updating distinct statistics @@ -415,7 +431,7 @@ bool RowGroupCollection::Append(DataChunk &chunk, TableAppendState &state) { AppendRowGroup(l, next_start); // set up the append state for this row_group auto last_row_group = row_groups->GetLastSegment(l); - last_row_group->InitializeAppend(state.row_group_append_state); + last_row_group->node->InitializeAppend(state.row_group_append_state); continue; } else { break; @@ -439,10 +455,11 @@ void RowGroupCollection::FinalizeAppend(TransactionData transaction, TableAppend auto remaining = state.total_append_count; auto row_group = state.start_row_group; while (remaining > 0) { - auto append_count = MinValue(remaining, row_group_size - row_group->count); - row_group->AppendVersionInfo(transaction, append_count); + auto ¤t_row_group = *row_group->node; + auto append_count = MinValue(remaining, row_group_size - current_row_group.count); + current_row_group.AppendVersionInfo(transaction, append_count); remaining -= append_count; - row_group = row_groups->GetNextSegment(row_group); + row_group = row_groups->GetNextSegment(*row_group); } total_rows += state.total_append_count; @@ -472,17 +489,18 @@ void RowGroupCollection::CommitAppend(transaction_t commit_id, idx_t row_start, idx_t current_row = row_start; idx_t remaining = count; while (true) { - idx_t start_in_row_group = current_row - row_group->start; - idx_t append_count = MinValue(row_group->count - start_in_row_group, remaining); + auto ¤t_row_group = *row_group->node; + idx_t start_in_row_group = current_row - current_row_group.start; + idx_t append_count = MinValue(current_row_group.count - start_in_row_group, remaining); - row_group->CommitAppend(commit_id, start_in_row_group, append_count); + current_row_group.CommitAppend(commit_id, start_in_row_group, append_count); current_row += append_count; remaining -= append_count; if (remaining == 0) { break; } - row_group = row_groups->GetNextSegment(row_group); + row_group = row_groups->GetNextSegment(*row_group); } } @@ -502,7 +520,7 @@ void RowGroupCollection::RevertAppendInternal(idx_t start_row) { segment_index = segment_count - 1; } auto &segment = *row_groups->GetSegmentByIndex(l, UnsafeNumericCast(segment_index)); - if (segment.start == start_row) { + if (segment.row_start == start_row) { // we are truncating exactly this row group - erase it entirely row_groups->EraseSegments(l, segment_index); } else { @@ -511,7 +529,7 @@ void RowGroupCollection::RevertAppendInternal(idx_t start_row) { row_groups->EraseSegments(l, segment_index + 1); segment.next = nullptr; - segment.RevertAppend(start_row); + segment.node->RevertAppend(start_row); } } @@ -521,17 +539,18 @@ void RowGroupCollection::CleanupAppend(transaction_t lowest_transaction, idx_t s idx_t current_row = start; idx_t remaining = count; while (true) { - idx_t start_in_row_group = current_row - row_group->start; - idx_t append_count = 
MinValue(row_group->count - start_in_row_group, remaining); + auto ¤t_row_group = *row_group->node; + idx_t start_in_row_group = current_row - current_row_group.start; + idx_t append_count = MinValue(current_row_group.count - start_in_row_group, remaining); - row_group->CleanupAppend(lowest_transaction, start_in_row_group, append_count); + current_row_group.CleanupAppend(lowest_transaction, start_in_row_group, append_count); current_row += append_count; remaining -= append_count; if (remaining == 0) { break; } - row_group = row_groups->GetNextSegment(row_group); + row_group = row_groups->GetNextSegment(*row_group); } } @@ -557,7 +576,7 @@ void RowGroupCollection::MergeStorage(RowGroupCollection &data, optional_ptrnode; if (!row_group.IsPersistent()) { break; } @@ -568,7 +587,7 @@ void RowGroupCollection::MergeStorage(RowGroupCollection &data, optional_ptrnode; row_group->MoveToCollection(*this, index); if (commit_state && (index - start_index) < optimistically_written_count) { @@ -601,19 +620,21 @@ idx_t RowGroupCollection::Delete(TransactionData transaction, DataTable &table, do { idx_t start = pos; auto row_group = row_groups->GetSegment(UnsafeNumericCast(ids[start])); + + auto ¤t_row_group = *row_group->node; for (pos++; pos < count; pos++) { D_ASSERT(ids[pos] >= 0); // check if this id still belongs to this row group - if (idx_t(ids[pos]) < row_group->start) { + if (idx_t(ids[pos]) < current_row_group.start) { // id is before row_group start -> it does not break; } - if (idx_t(ids[pos]) >= row_group->start + row_group->count) { + if (idx_t(ids[pos]) >= current_row_group.start + current_row_group.count) { // id is after row group end -> it does not break; } } - delete_count += row_group->Delete(transaction, table, ids + start, pos - start); + delete_count += current_row_group.Delete(transaction, table, ids + start, pos - start); } while (pos < count); return delete_count; @@ -622,14 +643,15 @@ idx_t RowGroupCollection::Delete(TransactionData transaction, DataTable &table, //===--------------------------------------------------------------------===// // Update //===--------------------------------------------------------------------===// -optional_ptr RowGroupCollection::NextUpdateRowGroup(row_t *ids, idx_t &pos, idx_t count) const { +optional_ptr> RowGroupCollection::NextUpdateRowGroup(row_t *ids, idx_t &pos, idx_t count) const { auto row_group = row_groups->GetSegment(UnsafeNumericCast(ids[pos])); - row_t base_id = - UnsafeNumericCast(row_group->start + ((UnsafeNumericCast(ids[pos]) - row_group->start) / - STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE)); - auto max_id = - MinValue(base_id + STANDARD_VECTOR_SIZE, UnsafeNumericCast(row_group->start + row_group->count)); + auto ¤t_row_group = *row_group->node; + row_t base_id = UnsafeNumericCast( + current_row_group.start + + ((UnsafeNumericCast(ids[pos]) - current_row_group.start) / STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE)); + auto max_id = MinValue(base_id + STANDARD_VECTOR_SIZE, + UnsafeNumericCast(current_row_group.start + current_row_group.count)); for (pos++; pos < count; pos++) { D_ASSERT(ids[pos] >= 0); // check if this id still belongs to this vector in this row group @@ -645,34 +667,39 @@ optional_ptr RowGroupCollection::NextUpdateRowGroup(row_t *ids, idx_t return row_group; } -void RowGroupCollection::Update(TransactionData transaction, row_t *ids, const vector &column_ids, - DataChunk &updates) { +void RowGroupCollection::Update(TransactionData transaction, DataTable &data_table, row_t *ids, + const vector &column_ids, 
DataChunk &updates) { D_ASSERT(updates.size() >= 1); idx_t pos = 0; do { idx_t start = pos; auto row_group = NextUpdateRowGroup(ids, pos, updates.size()); - row_group->Update(transaction, updates, ids, start, pos - start, column_ids); + + auto ¤t_row_group = *row_group->node; + current_row_group.Update(transaction, data_table, updates, ids, start, pos - start, column_ids); auto l = stats.GetLock(); for (idx_t i = 0; i < column_ids.size(); i++) { auto column_id = column_ids[i]; - stats.MergeStats(*l, column_id.index, *row_group->GetStatistics(column_id.index)); + stats.MergeStats(*l, column_id.index, *current_row_group.GetStatistics(column_id.index)); } } while (pos < updates.size()); } -void RowGroupCollection::RemoveFromIndexes(TableIndexList &indexes, Vector &row_identifiers, idx_t count) { +void RowGroupCollection::RemoveFromIndexes(const QueryContext &context, TableIndexList &indexes, + Vector &row_identifiers, idx_t count) { auto row_ids = FlatVector::GetData(row_identifiers); - // Collect all indexed columns. + // Collect all Indexed columns on the table. unordered_set indexed_column_id_set; indexes.Scan([&](Index &index) { - D_ASSERT(index.IsBound()); auto &set = index.GetColumnIdSet(); indexed_column_id_set.insert(set.begin(), set.end()); return false; }); + + // If we are in WAL replay, delete data will be buffered, and so we sort the column_ids + // since the sorted form will be the mapping used to get back physical IDs from the buffered index chunk. vector column_ids; for (auto &col : indexed_column_id_set) { column_ids.emplace_back(col); @@ -686,10 +713,10 @@ void RowGroupCollection::RemoveFromIndexes(TableIndexList &indexes, Vector &row_ // Initialize the fetch state. Only use indexed columns. TableScanState state; - state.Initialize(std::move(column_ids)); + auto column_ids_copy = column_ids; + state.Initialize(std::move(column_ids_copy)); state.table_state.max_row = row_start + total_rows; - // Used for scanning data. Only contains the indexed columns. DataChunk fetch_chunk; fetch_chunk.Initialize(GetAllocator(), column_types); @@ -713,13 +740,15 @@ void RowGroupCollection::RemoveFromIndexes(TableIndexList &indexes, Vector &row_ // Figure out which row_group to fetch from. auto row_id = row_ids[r]; auto row_group = row_groups->GetSegment(UnsafeNumericCast(row_id)); - auto row_group_vector_idx = (UnsafeNumericCast(row_id) - row_group->start) / STANDARD_VECTOR_SIZE; - auto base_row_id = row_group_vector_idx * STANDARD_VECTOR_SIZE + row_group->start; + + auto ¤t_row_group = *row_group->node; + auto row_group_vector_idx = (UnsafeNumericCast(row_id) - current_row_group.start) / STANDARD_VECTOR_SIZE; + auto base_row_id = row_group_vector_idx * STANDARD_VECTOR_SIZE + current_row_group.start; // Fetch the current vector into fetch_chunk. - state.table_state.Initialize(GetTypes()); - row_group->InitializeScanWithOffset(state.table_state, row_group_vector_idx); - row_group->ScanCommitted(state.table_state, fetch_chunk, TableScanType::TABLE_SCAN_COMMITTED_ROWS); + state.table_state.Initialize(context, GetTypes()); + current_row_group.InitializeScanWithOffset(state.table_state, *row_group, row_group_vector_idx); + current_row_group.ScanCommitted(state.table_state, fetch_chunk, TableScanType::TABLE_SCAN_COMMITTED_ROWS); fetch_chunk.Verify(); // Check for any remaining row ids, if they also fall into this vector. 
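The RowGroupCollection hunks above thread a QueryContext from the caller through every scan-initialization path (state.Initialize(context, types), InitializeScanWithOffset(context, ...), RemoveFromIndexes(context, ...)), so lower layers such as ColumnSegment::InitializeScan and compression callbacks can see the originating query. The snippet below is only a call-shape sketch with placeholder types, since the real QueryContext and scan-state classes are declared in headers outside this diff; it is not DuckDB's implementation.

#include <vector>

struct QueryContext {}; // stand-in: carries client/query information in the real code
struct LogicalType {};

struct CollectionScanStateSketch {
	// old shape:  void Initialize(const std::vector<LogicalType> &types);
	// new shape:  void Initialize(const QueryContext &context, const std::vector<LogicalType> &types);
	void Initialize(const QueryContext & /*context*/, const std::vector<LogicalType> & /*types*/) {
	}
};

void InitializeScanSketch(const QueryContext &context, CollectionScanStateSketch &state,
                          const std::vector<LogicalType> &types) {
	// Each initialization site in this patch forwards the context first, then the types,
	// mirroring state.Initialize(context, GetTypes()) in the hunks above.
	state.Initialize(context, types);
}

int main() {
	QueryContext context;
	CollectionScanStateSketch state;
	std::vector<LogicalType> types;
	InitializeScanSketch(context, state, types);
	return 0;
}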
@@ -749,34 +778,43 @@ void RowGroupCollection::RemoveFromIndexes(TableIndexList &indexes, Vector &row_ result_chunk.SetCardinality(fetch_chunk); // Slice the vector with all rows that are present in this vector. - // Then, erase all values from the indexes. + // If the index is bound, delete the data. If unbound, buffer into unbound_index. result_chunk.Slice(sel, sel_count); indexes.Scan([&](Index &index) { if (index.IsBound()) { index.Cast().Delete(result_chunk, row_identifiers); return false; } - throw MissingExtensionException( - "Cannot delete from index '%s', unknown index type '%s'. You need to load the " - "extension that provides this index type before table '%s' can be modified.", - index.GetIndexName(), index.GetIndexType(), info->GetTableName()); + // Buffering takes only the indexed columns in ordering of the column_ids mapping. + DataChunk index_column_chunk; + index_column_chunk.InitializeEmpty(column_types); + for (idx_t i = 0; i < column_types.size(); i++) { + auto col_id = column_ids[i].GetPrimaryIndex(); + index_column_chunk.data[i].Reference(result_chunk.data[col_id]); + } + index_column_chunk.SetCardinality(result_chunk.size()); + auto &unbound_index = index.Cast(); + unbound_index.BufferChunk(index_column_chunk, row_identifiers, column_ids, BufferedIndexReplay::DEL_ENTRY); + return false; }); } } -void RowGroupCollection::UpdateColumn(TransactionData transaction, Vector &row_ids, const vector &column_path, - DataChunk &updates) { +void RowGroupCollection::UpdateColumn(TransactionData transaction, DataTable &data_table, Vector &row_ids, + const vector &column_path, DataChunk &updates) { D_ASSERT(updates.size() >= 1); auto ids = FlatVector::GetData(row_ids); idx_t pos = 0; do { idx_t start = pos; auto row_group = NextUpdateRowGroup(ids, pos, updates.size()); - row_group->UpdateColumn(transaction, updates, row_ids, start, pos - start, column_path); + auto ¤t_row_group = *row_group->node; + current_row_group.UpdateColumn(transaction, data_table, updates, row_ids, start, pos - start, column_path); auto lock = stats.GetLock(); auto primary_column_idx = column_path[0]; - row_group->MergeIntoStatistics(primary_column_idx, stats.GetStats(*lock, primary_column_idx).Statistics()); + current_row_group.MergeIntoStatistics(primary_column_idx, + stats.GetStats(*lock, primary_column_idx).Statistics()); } while (pos < updates.size()); } @@ -785,7 +823,7 @@ void RowGroupCollection::UpdateColumn(TransactionData transaction, Vector &row_i //===--------------------------------------------------------------------===// struct CollectionCheckpointState { CollectionCheckpointState(RowGroupCollection &collection, TableDataWriter &writer, - vector> &segments, TableStatistics &global_stats) + vector>> &segments, TableStatistics &global_stats) : collection(collection), writer(writer), executor(writer.CreateTaskExecutor()), segments(segments), global_stats(global_stats) { writers.resize(segments.size()); @@ -795,7 +833,7 @@ struct CollectionCheckpointState { RowGroupCollection &collection; TableDataWriter &writer; unique_ptr executor; - vector> &segments; + vector>> &segments; vector> writers; vector write_data; TableStatistics &global_stats; @@ -820,8 +858,8 @@ class CheckpointTask : public BaseCheckpointTask { void ExecuteTask() override { auto &entry = checkpoint_state.segments[index]; - auto &row_group = *entry.node; - checkpoint_state.writers[index] = checkpoint_state.writer.GetRowGroupWriter(*entry.node); + auto &row_group = *entry->node; + checkpoint_state.writers[index] = 
checkpoint_state.writer.GetRowGroupWriter(row_group); checkpoint_state.write_data[index] = row_group.WriteToDisk(*checkpoint_state.writers[index]); } @@ -887,7 +925,7 @@ class VacuumTask : public BaseCheckpointTask { TableScanState scan_state; scan_state.Initialize(column_ids); - scan_state.table_state.Initialize(types); + scan_state.table_state.Initialize(QueryContext(), types); scan_state.table_state.max_row = idx_t(-1); idx_t merged_groups = 0; idx_t total_row_groups = vacuum_state.row_group_counts.size(); @@ -897,9 +935,9 @@ class VacuumTask : public BaseCheckpointTask { } merged_groups++; - auto ¤t_row_group = *checkpoint_state.segments[c_idx].node; + auto ¤t_row_group = *checkpoint_state.segments[c_idx]->node; - current_row_group.InitializeScan(scan_state.table_state); + current_row_group.InitializeScan(scan_state.table_state, *checkpoint_state.segments[c_idx]); while (true) { scan_chunk.Reset(); @@ -929,7 +967,7 @@ class VacuumTask : public BaseCheckpointTask { } // drop the row group after merging current_row_group.CommitDrop(); - checkpoint_state.segments[c_idx].node.reset(); + checkpoint_state.segments[c_idx]->node.reset(); } idx_t total_append_count = 0; for (idx_t target_idx = 0; target_idx < target_count; target_idx++) { @@ -937,7 +975,7 @@ class VacuumTask : public BaseCheckpointTask { row_group->Verify(); // assign the new row group to the current segment - checkpoint_state.segments[segment_idx + target_idx].node = std::move(row_group); + checkpoint_state.segments[segment_idx + target_idx]->node = std::move(row_group); total_append_count += append_counts[target_idx]; } if (total_append_count != merge_rows) { @@ -964,7 +1002,7 @@ class VacuumTask : public BaseCheckpointTask { }; void RowGroupCollection::InitializeVacuumState(CollectionCheckpointState &checkpoint_state, VacuumState &state, - vector> &segments) { + vector>> &segments) { auto checkpoint_type = checkpoint_state.writer.GetCheckpointType(); bool vacuum_is_allowed = checkpoint_type != CheckpointType::CONCURRENT_CHECKPOINT; // currently we can only vacuum deletes if we are doing a full checkpoint and there are no indexes @@ -975,12 +1013,12 @@ void RowGroupCollection::InitializeVacuumState(CollectionCheckpointState &checkp // obtain the set of committed row counts for each row group state.row_group_counts.reserve(segments.size()); for (auto &entry : segments) { - auto &row_group = *entry.node; + auto &row_group = *entry->node; auto row_group_count = row_group.GetCommittedRowCount(); if (row_group_count == 0) { // empty row group - we can drop it entirely row_group.CommitDrop(); - entry.node.reset(); + entry->node.reset(); } state.row_group_counts.push_back(row_group_count); } @@ -1000,7 +1038,7 @@ bool RowGroupCollection::ScheduleVacuumTasks(CollectionCheckpointState &checkpoi } if (state.row_group_counts[segment_idx] == 0) { // segment was already dropped - skip - D_ASSERT(!checkpoint_state.segments[segment_idx].node); + D_ASSERT(!checkpoint_state.segments[segment_idx]->node); return false; } if (!schedule_vacuum) { @@ -1102,19 +1140,20 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl total_vacuum_tasks++; continue; } - if (!entry.node) { + if (!entry->node) { // row group was vacuumed/dropped - skip continue; } // schedule a checkpoint task for this row group - entry.node->MoveToCollection(*this, vacuum_state.row_start); + auto &row_group = *entry->node; + row_group.MoveToCollection(*this, vacuum_state.row_start); if (writer.GetCheckpointType() != CheckpointType::VACUUM_ONLY) { 
DUCKDB_LOG(checkpoint_state.writer.GetDatabase(), CheckpointLogType, GetAttached(), *info, segment_idx, - *entry.node); + row_group); auto checkpoint_task = GetCheckpointTask(checkpoint_state, segment_idx); checkpoint_state.executor->ScheduleTask(std::move(checkpoint_task)); } - vacuum_state.row_start += entry.node->count; + vacuum_state.row_start += row_group.count; } } catch (const std::exception &e) { ErrorData error(e); @@ -1131,12 +1170,12 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl bool table_has_changes = false; for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) { auto &entry = segments[segment_idx]; - if (!entry.node) { + if (!entry->node) { table_has_changes = true; break; } auto &write_state = checkpoint_state.write_data[segment_idx]; - if (write_state.existing_pointers.empty()) { + if (!write_state.reuse_existing_metadata_blocks) { table_has_changes = true; break; } @@ -1148,11 +1187,18 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl auto &metadata_manager = writer.GetMetadataManager(); for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) { auto &entry = segments[segment_idx]; - auto &row_group = *entry.node; + auto &row_group = *entry->node; auto &write_state = checkpoint_state.write_data[segment_idx]; - metadata_manager.ClearModifiedBlocks(write_state.existing_pointers); + metadata_manager.ClearModifiedBlocks(row_group.GetColumnStartPointers()); + D_ASSERT(write_state.reuse_existing_metadata_blocks); + vector extra_metadata_block_pointers; + extra_metadata_block_pointers.reserve(write_state.existing_extra_metadata_blocks.size()); + for (auto &block_pointer : write_state.existing_extra_metadata_blocks) { + extra_metadata_block_pointers.emplace_back(block_pointer, 0); + } + metadata_manager.ClearModifiedBlocks(extra_metadata_block_pointers); metadata_manager.ClearModifiedBlocks(row_group.GetDeletesPointers()); - row_groups->AppendSegment(l, std::move(entry.node)); + row_groups->AppendSegment(l, std::move(entry->node)); } writer.WriteUnchangedTable(metadata_pointer, total_rows.load()); return; @@ -1162,15 +1208,15 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl idx_t new_total_rows = 0; for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) { auto &entry = segments[segment_idx]; - if (!entry.node) { + if (!entry->node) { // row group was vacuumed/dropped - skip continue; } - auto &row_group = *entry.node; + auto &row_group = *entry->node; if (!checkpoint_state.writers[segment_idx]) { // row group was not checkpointed - this can happen if compressing is disabled for in-memory tables D_ASSERT(writer.GetCheckpointType() == CheckpointType::VACUUM_ONLY); - row_groups->AppendSegment(l, std::move(entry.node)); + row_groups->AppendSegment(l, std::move(entry->node)); new_total_rows += row_group.count; continue; } @@ -1178,11 +1224,98 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl if (!row_group_writer) { throw InternalException("Missing row group writer for index %llu", segment_idx); } + bool metadata_reuse = checkpoint_state.write_data[segment_idx].reuse_existing_metadata_blocks; auto pointer = row_group.Checkpoint(std::move(checkpoint_state.write_data[segment_idx]), *row_group_writer, global_stats); + + auto debug_verify_blocks = DBConfig::GetSetting(GetAttached().GetDatabase()) && + dynamic_cast(&checkpoint_state.writer) != nullptr; + RowGroupPointer pointer_copy; + if 
(debug_verify_blocks) { + pointer_copy = pointer; + } writer.AddRowGroup(std::move(pointer), std::move(row_group_writer)); - row_groups->AppendSegment(l, std::move(entry.node)); + row_groups->AppendSegment(l, std::move(entry->node)); new_total_rows += row_group.count; + + if (debug_verify_blocks) { + if (!pointer_copy.has_metadata_blocks) { + throw InternalException("Checkpointing should always remember metadata blocks"); + } + if (metadata_reuse && pointer_copy.data_pointers != row_group.GetColumnStartPointers()) { + throw InternalException("Column start pointers changed during metadata reuse"); + } + + // Capture blocks that have been written + vector all_written_blocks = pointer_copy.data_pointers; + vector all_metadata_blocks; + for (auto &block : pointer_copy.extra_metadata_blocks) { + all_written_blocks.emplace_back(block, 0); + all_metadata_blocks.emplace_back(block, 0); + } + + // Verify that we can load the metadata correctly again + vector all_quick_read_blocks; + for (auto &ptr : row_group.GetColumnStartPointers()) { + all_quick_read_blocks.emplace_back(ptr); + if (metadata_reuse && !block_manager.GetMetadataManager().BlockHasBeenCleared(ptr)) { + throw InternalException("Found column start block that was not cleared"); + } + } + auto extra_metadata_blocks = row_group.GetOrComputeExtraMetadataBlocks(/* force_compute: */ true); + for (auto &ptr : extra_metadata_blocks) { + auto block_pointer = MetaBlockPointer(ptr, 0); + all_quick_read_blocks.emplace_back(block_pointer); + if (metadata_reuse && !block_manager.GetMetadataManager().BlockHasBeenCleared(block_pointer)) { + throw InternalException("Found extra metadata block that was not cleared"); + } + } + + // Deserialize all columns to check if the quick read via GetOrComputeExtraMetadataBlocks was correct + vector all_full_read_blocks; + auto column_start_pointers = row_group.GetColumnStartPointers(); + auto &types = row_group.GetCollection().GetTypes(); + auto &metadata_manager = row_group.GetCollection().GetMetadataManager(); + for (idx_t i = 0; i < column_start_pointers.size(); i++) { + MetadataReader reader(metadata_manager, column_start_pointers[i], &all_full_read_blocks); + ColumnData::Deserialize(GetBlockManager(), GetTableInfo(), i, row_group.start, reader, types[i]); + } + + // Derive sets of blocks to compare + set all_written_block_ids; + for (auto &ptr : all_written_blocks) { + all_written_block_ids.insert(ptr.block_pointer); + } + set all_quick_read_block_ids; + for (auto &ptr : all_quick_read_blocks) { + all_quick_read_block_ids.insert(ptr.block_pointer); + } + set all_full_read_block_ids; + for (auto &ptr : all_full_read_blocks) { + all_full_read_block_ids.insert(ptr.block_pointer); + } + if (all_written_block_ids != all_quick_read_block_ids || + all_quick_read_block_ids != all_full_read_block_ids) { + std::stringstream oss; + oss << "Written: "; + for (auto &block : all_written_blocks) { + oss << block << ", "; + } + oss << "\n"; + oss << "Quick read: "; + for (auto &block : all_quick_read_blocks) { + oss << block << ", "; + } + oss << "\n"; + oss << "Full read: "; + for (auto &block : all_full_read_blocks) { + oss << block << ", "; + } + oss << "\n"; + + throw InternalException("Reloading blocks just written does not yield same blocks: " + oss.str()); + } + } } total_rows = new_total_rows; l.Release(); @@ -1213,7 +1346,7 @@ void RowGroupCollection::Destroy() { TaskExecutor executor(TaskScheduler::GetScheduler(GetAttached().GetDatabase())); for (auto &segment : segments) { - auto destroy_task = make_uniq(executor, 
std::move(segment.node)); + auto destroy_task = make_uniq(executor, std::move(segment->node)); executor.ScheduleTask(std::move(destroy_task)); } executor.WorkOnTasks(); @@ -1248,11 +1381,12 @@ vector RowGroupCollection::GetPartitionStats() const { //===--------------------------------------------------------------------===// // GetColumnSegmentInfo //===--------------------------------------------------------------------===// -vector RowGroupCollection::GetColumnSegmentInfo() { +vector RowGroupCollection::GetColumnSegmentInfo(const QueryContext &context) { vector result; auto lock = row_groups->Lock(); - for (auto &row_group : row_groups->Segments(lock)) { - row_group.GetColumnSegmentInfo(row_group.index, result); + for (auto &node : row_groups->SegmentNodes(lock)) { + auto &row_group = *node.node; + row_group.GetColumnSegmentInfo(context, node.index, result); } return result; } @@ -1340,16 +1474,18 @@ shared_ptr RowGroupCollection::AlterType(ClientContext &cont // now alter the type of the column within all of the row_groups individually auto lock = result->stats.GetLock(); auto &changed_stats = result->stats.GetStats(*lock, changed_idx); - for (auto ¤t_row_group : row_groups->Segments()) { + for (auto &node : row_groups->SegmentNodes()) { + auto ¤t_row_group = *node.node; auto new_row_group = current_row_group.AlterType(*result, target_type, changed_idx, executor, - scan_state.table_state, scan_chunk); + scan_state.table_state, node, scan_chunk); new_row_group->MergeIntoStatistics(changed_idx, changed_stats.Statistics()); result->row_groups->AppendSegment(std::move(new_row_group)); } return result; } -void RowGroupCollection::VerifyNewConstraint(DataTable &parent, const BoundConstraint &constraint) { +void RowGroupCollection::VerifyNewConstraint(const QueryContext &context, DataTable &parent, + const BoundConstraint &constraint) { if (total_rows == 0) { return; } @@ -1371,7 +1507,7 @@ void RowGroupCollection::VerifyNewConstraint(DataTable &parent, const BoundConst CreateIndexScanState state; auto scan_type = TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED; state.Initialize(column_ids, nullptr); - InitializeScan(state.table_state, column_ids, nullptr); + InitializeScan(context, state.table_state, column_ids, nullptr); InitializeCreateIndexScan(state); diff --git a/src/duckdb/src/storage/table/row_id_column_data.cpp b/src/duckdb/src/storage/table/row_id_column_data.cpp index d869913bf..4bc3c4148 100644 --- a/src/duckdb/src/storage/table/row_id_column_data.cpp +++ b/src/duckdb/src/storage/table/row_id_column_data.cpp @@ -138,13 +138,14 @@ void RowIdColumnData::RevertAppend(row_t start_row) { throw InternalException("RowIdColumnData cannot be appended to"); } -void RowIdColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void RowIdColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count) { throw InternalException("RowIdColumnData cannot be updated"); } -void RowIdColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void RowIdColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth) { throw InternalException("RowIdColumnData cannot be updated"); } diff --git 
a/src/duckdb/src/storage/table/row_version_manager.cpp b/src/duckdb/src/storage/table/row_version_manager.cpp index df4e463da..6b5f4b9bd 100644 --- a/src/duckdb/src/storage/table/row_version_manager.cpp +++ b/src/duckdb/src/storage/table/row_version_manager.cpp @@ -7,19 +7,10 @@ namespace duckdb { -RowVersionManager::RowVersionManager(idx_t start) noexcept : start(start), has_changes(false) { -} - -void RowVersionManager::SetStart(idx_t new_start) { - lock_guard l(version_lock); - this->start = new_start; - idx_t current_start = start; - for (auto &info : vector_info) { - if (info) { - info->start = current_start; - } - current_start += STANDARD_VECTOR_SIZE; - } +RowVersionManager::RowVersionManager(BufferManager &buffer_manager_p) noexcept + : allocator(STANDARD_VECTOR_SIZE * sizeof(transaction_t), buffer_manager_p.GetTemporaryBlockManager(), + MemoryTag::BASE_TABLE), + has_changes(false) { } idx_t RowVersionManager::GetCommittedDeletedCount(idx_t count) { @@ -103,7 +94,7 @@ void RowVersionManager::AppendVersionInfo(TransactionData transaction, idx_t cou vector_idx == end_vector_idx ? row_group_end - end_vector_idx * STANDARD_VECTOR_SIZE : STANDARD_VECTOR_SIZE; if (vector_start == 0 && vector_end == STANDARD_VECTOR_SIZE) { // entire vector is encapsulated by append: append a single constant - auto constant_info = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); + auto constant_info = make_uniq(vector_idx * STANDARD_VECTOR_SIZE); constant_info->insert_id = transaction.transaction_id; constant_info->delete_id = NOT_DELETED_ID; vector_info[vector_idx] = std::move(constant_info); @@ -112,7 +103,7 @@ void RowVersionManager::AppendVersionInfo(TransactionData transaction, idx_t cou optional_ptr new_info; if (!vector_info[vector_idx]) { // first time appending to this vector: create new info - auto insert_info = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); + auto insert_info = make_uniq(allocator, vector_idx * STANDARD_VECTOR_SIZE); new_info = insert_info.get(); vector_info[vector_idx] = std::move(insert_info); } else if (vector_info[vector_idx]->type == ChunkInfoType::VECTOR_INFO) { @@ -188,15 +179,11 @@ ChunkVectorInfo &RowVersionManager::GetVectorInfo(idx_t vector_idx) { if (!vector_info[vector_idx]) { // no info yet: create it - vector_info[vector_idx] = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); + vector_info[vector_idx] = make_uniq(allocator, vector_idx * STANDARD_VECTOR_SIZE); } else if (vector_info[vector_idx]->type == ChunkInfoType::CONSTANT_INFO) { auto &constant = vector_info[vector_idx]->Cast(); // info exists but it's a constant info: convert to a vector info - auto new_info = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); - new_info->insert_id = constant.insert_id; - for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { - new_info->inserted[i] = constant.insert_id; - } + auto new_info = make_uniq(allocator, vector_idx * STANDARD_VECTOR_SIZE, constant.insert_id); vector_info[vector_idx] = std::move(new_info); } D_ASSERT(vector_info[vector_idx]->type == ChunkInfoType::VECTOR_INFO); @@ -257,12 +244,12 @@ vector RowVersionManager::Checkpoint(MetadataManager &manager) return storage_pointers; } -shared_ptr RowVersionManager::Deserialize(MetaBlockPointer delete_pointer, MetadataManager &manager, - idx_t start) { +shared_ptr RowVersionManager::Deserialize(MetaBlockPointer delete_pointer, + MetadataManager &manager) { if (!delete_pointer.IsValid()) { return nullptr; } - auto version_info = make_shared_ptr(start); + auto version_info = 
make_shared_ptr(manager.GetBufferManager()); MetadataReader source(manager, delete_pointer, &version_info->storage_pointers); auto chunk_count = source.Read(); D_ASSERT(chunk_count > 0); @@ -275,7 +262,7 @@ shared_ptr RowVersionManager::Deserialize(MetaBlockPointer de } version_info->FillVectorInfo(vector_index); - version_info->vector_info[vector_index] = ChunkInfo::Read(source); + version_info->vector_info[vector_index] = ChunkInfo::Read(version_info->GetAllocator(), source); } version_info->has_changes = false; return version_info; diff --git a/src/duckdb/src/storage/table/scan_state.cpp b/src/duckdb/src/storage/table/scan_state.cpp index f7f5f727d..a5da219e7 100644 --- a/src/duckdb/src/storage/table/scan_state.cpp +++ b/src/duckdb/src/storage/table/scan_state.cpp @@ -110,6 +110,67 @@ void ScanFilterInfo::SetFilterAlwaysTrue(idx_t filter_idx) { always_true_filters++; } +RowGroupReorderer::RowGroupReorderer(const RowGroupOrderOptions &options) + : column_idx(options.column_idx), order_by(options.order_by), order_type(options.order_type), + column_type(options.column_type), offset(0), initialized(false) { +} + +optional_ptr> RowGroupReorderer::GetNextRowGroup(SegmentNode &row_group) { + D_ASSERT(RefersToSameObject(ordered_row_groups[offset].get(), row_group)); + if (offset >= ordered_row_groups.size() - 1) { + return nullptr; + } + return ordered_row_groups[++offset].get(); +} + +Value RowGroupReorderer::RetrieveStat(const BaseStatistics &stats, OrderByStatistics order_by, + OrderByColumnType column_type) { + switch (order_by) { + case OrderByStatistics::MIN: + return column_type == OrderByColumnType::NUMERIC ? NumericStats::Min(stats) : StringStats::Min(stats); + case OrderByStatistics::MAX: + return column_type == OrderByColumnType::NUMERIC ? NumericStats::Max(stats) : StringStats::Max(stats); + } + return Value(); +} + +optional_ptr> RowGroupReorderer::GetRootSegment(RowGroupSegmentTree &row_groups) { + if (initialized) { + if (ordered_row_groups.empty()) { + return nullptr; + } + return ordered_row_groups[0].get(); + } + + initialized = true; + + multimap>> row_group_map; + for (auto &node : row_groups.SegmentNodes()) { + auto &row_group = *node.node; + auto stats = row_group.GetStatistics(column_idx); + Value comparison_value = RetrieveStat(*stats, order_by, column_type); + + row_group_map.emplace(comparison_value, reference>(node)); + } + + if (row_group_map.empty()) { + return nullptr; + } + + ordered_row_groups.reserve(row_group_map.size()); + if (order_type == RowGroupOrderType::ASC) { + for (auto &row_group : row_group_map) { + ordered_row_groups.emplace_back(row_group.second); + } + } else { + for (auto it = row_group_map.rbegin(); it != row_group_map.rend(); ++it) { + ordered_row_groups.emplace_back(it->second); + } + } + + return ordered_row_groups[0].get(); +} + optional_ptr ScanFilterInfo::GetAdaptiveFilter() { return adaptive_filter.get(); } @@ -134,15 +195,16 @@ void ColumnScanState::NextInternal(idx_t count) { return; } row_index += count; - while (row_index >= current->start + current->count) { - current = segment_tree->GetNextSegment(current); + while (row_index >= current->node->start + current->node->count) { + current = segment_tree->GetNextSegment(*current); initialized = false; segment_checked = false; if (!current) { break; } } - D_ASSERT(!current || (row_index >= current->start && row_index < current->start + current->count)); + D_ASSERT(!current || + (row_index >= current->node->start && row_index < current->node->start + current->node->count)); } void 
ColumnScanState::Next(idx_t count) { @@ -174,28 +236,63 @@ ParallelCollectionScanState::ParallelCollectionScanState() : collection(nullptr), current_row_group(nullptr), processed_rows(0) { } +optional_ptr> ParallelCollectionScanState::GetRootSegment(RowGroupSegmentTree &row_groups) const { + if (reorderer) { + return reorderer->GetRootSegment(row_groups); + } + return row_groups.GetRootSegment(); +} + +optional_ptr> +ParallelCollectionScanState::GetNextRowGroup(RowGroupSegmentTree &row_groups, SegmentNode &row_group) const { + if (reorderer) { + return reorderer->GetNextRowGroup(row_group); + } + return row_groups.GetNextSegment(row_group); +} + CollectionScanState::CollectionScanState(TableScanState &parent_p) : row_group(nullptr), vector_index(0), max_row_group_row(0), row_groups(nullptr), max_row(0), batch_index(0), valid_sel(STANDARD_VECTOR_SIZE), random(-1), parent(parent_p) { } +optional_ptr> CollectionScanState::GetNextRowGroup(SegmentNode &row_group) const { + if (reorderer) { + return reorderer->GetNextRowGroup(row_group); + } + return row_groups->GetNextSegment(row_group); +} + +optional_ptr> CollectionScanState::GetNextRowGroup(SegmentLock &l, + SegmentNode &row_group) const { + D_ASSERT(!reorderer); + return row_groups->GetNextSegment(l, row_group); +} + +optional_ptr> CollectionScanState::GetRootSegment() const { + if (reorderer) { + return reorderer->GetRootSegment(*row_groups); + } + return row_groups->GetRootSegment(); +} + bool CollectionScanState::Scan(DuckTransaction &transaction, DataChunk &result) { while (row_group) { - row_group->Scan(transaction, *this, result); + row_group->node->Scan(transaction, *this, result); if (result.size() > 0) { return true; - } else if (max_row <= row_group->start + row_group->count) { + } else if (max_row <= row_group->node->start + row_group->node->count) { row_group = nullptr; return false; } else { do { - row_group = row_groups->GetNextSegment(row_group); + row_group = GetNextRowGroup(*row_group).get(); if (row_group) { - if (row_group->start >= max_row) { + if (row_group->node->start >= max_row) { row_group = nullptr; break; } - bool scan_row_group = row_group->InitializeScan(*this); + bool scan_row_group = row_group->node->InitializeScan(*this, *row_group); if (scan_row_group) { // scan this row group break; @@ -209,13 +306,13 @@ bool CollectionScanState::Scan(DuckTransaction &transaction, DataChunk &result) bool CollectionScanState::ScanCommitted(DataChunk &result, SegmentLock &l, TableScanType type) { while (row_group) { - row_group->ScanCommitted(*this, result, type); + row_group->node->ScanCommitted(*this, result, type); if (result.size() > 0) { return true; } else { - row_group = row_groups->GetNextSegment(l, row_group); + row_group = GetNextRowGroup(l, *row_group).get(); if (row_group) { - row_group->InitializeScan(*this); + row_group->node->InitializeScan(*this, *row_group); } } } @@ -224,14 +321,14 @@ bool CollectionScanState::ScanCommitted(DataChunk &result, SegmentLock &l, Table bool CollectionScanState::ScanCommitted(DataChunk &result, TableScanType type) { while (row_group) { - row_group->ScanCommitted(*this, result, type); + row_group->node->ScanCommitted(*this, result, type); if (result.size() > 0) { return true; - } else { - row_group = row_groups->GetNextSegment(row_group); - if (row_group) { - row_group->InitializeScan(*this); - } + } + + row_group = GetNextRowGroup(*row_group).get(); + if (row_group) { + row_group->node->InitializeScan(*this, *row_group); } } return false; diff --git 
a/src/duckdb/src/storage/table/standard_column_data.cpp b/src/duckdb/src/storage/table/standard_column_data.cpp index c657c63ee..ad8814ab4 100644 --- a/src/duckdb/src/storage/table/standard_column_data.cpp +++ b/src/duckdb/src/storage/table/standard_column_data.cpp @@ -152,8 +152,8 @@ idx_t StandardColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &re return scan_count; } -void StandardColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void StandardColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count) { ColumnScanState standard_state, validity_state; Vector base_vector(type); auto standard_fetch = FetchUpdateData(standard_state, row_ids, base_vector); @@ -162,18 +162,19 @@ void StandardColumnData::Update(TransactionData transaction, idx_t column_index, throw InternalException("Unaligned fetch in validity and main column data for update"); } - UpdateInternal(transaction, column_index, update_vector, row_ids, update_count, base_vector); - validity.UpdateInternal(transaction, column_index, update_vector, row_ids, update_count, base_vector); + UpdateInternal(transaction, data_table, column_index, update_vector, row_ids, update_count, base_vector); + validity.UpdateInternal(transaction, data_table, column_index, update_vector, row_ids, update_count, base_vector); } -void StandardColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void StandardColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth) { if (depth >= column_path.size()) { // update this column - ColumnData::Update(transaction, column_path[0], update_vector, row_ids, update_count); + ColumnData::Update(transaction, data_table, column_path[0], update_vector, row_ids, update_count); } else { // update the child column (i.e. 
the validity column) - validity.UpdateColumn(transaction, column_path, update_vector, row_ids, update_count, depth + 1); + validity.UpdateColumn(transaction, data_table, column_path, update_vector, row_ids, update_count, depth + 1); } } @@ -241,9 +242,10 @@ unique_ptr StandardColumnData::Checkpoint(RowGroup &row_g // to prevent reading the validity data immediately after it is checkpointed we first checkpoint the main column // this is necessary for concurrent checkpointing as due to the partial block manager checkpointed data might be // flushed to disk by a different thread than the one that wrote it, causing a data race - auto base_state = CreateCheckpointState(row_group, checkpoint_info.info.manager); + auto &partial_block_manager = checkpoint_info.GetPartialBlockManager(); + auto base_state = CreateCheckpointState(row_group, partial_block_manager); base_state->global_stats = BaseStatistics::CreateEmpty(type).ToUnique(); - auto validity_state_p = validity.CreateCheckpointState(row_group, checkpoint_info.info.manager); + auto validity_state_p = validity.CreateCheckpointState(row_group, partial_block_manager); validity_state_p->global_stats = BaseStatistics::CreateEmpty(validity.type).ToUnique(); auto &validity_state = *validity_state_p; @@ -294,11 +296,12 @@ void StandardColumnData::InitializeColumn(PersistentColumnData &column_data, Bas validity.InitializeColumn(column_data.child_columns[0], target_stats); } -void StandardColumnData::GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, +void StandardColumnData::GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, + vector col_path, vector &result) { - ColumnData::GetColumnSegmentInfo(row_group_index, col_path, result); + ColumnData::GetColumnSegmentInfo(context, row_group_index, col_path, result); col_path.push_back(0); - validity.GetColumnSegmentInfo(row_group_index, std::move(col_path), result); + validity.GetColumnSegmentInfo(context, row_group_index, std::move(col_path), result); } void StandardColumnData::Verify(RowGroup &parent) { diff --git a/src/duckdb/src/storage/table/struct_column_data.cpp b/src/duckdb/src/storage/table/struct_column_data.cpp index 5137330ef..b1de02b2d 100644 --- a/src/duckdb/src/storage/table/struct_column_data.cpp +++ b/src/duckdb/src/storage/table/struct_column_data.cpp @@ -207,17 +207,18 @@ idx_t StructColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &resu return scan_count; } -void StructColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { - validity.Update(transaction, column_index, update_vector, row_ids, update_count); +void StructColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count) { + validity.Update(transaction, data_table, column_index, update_vector, row_ids, update_count); auto &child_entries = StructVector::GetEntries(update_vector); for (idx_t i = 0; i < child_entries.size(); i++) { - sub_columns[i]->Update(transaction, column_index, *child_entries[i], row_ids, update_count); + sub_columns[i]->Update(transaction, data_table, column_index, *child_entries[i], row_ids, update_count); } } -void StructColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void StructColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector 
&column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth) { // we can never DIRECTLY update a struct column if (depth >= column_path.size()) { throw InternalException("Attempting to directly update a struct column - this should not be possible"); @@ -225,13 +226,13 @@ void StructColumnData::UpdateColumn(TransactionData transaction, const vector sub_columns.size()) { throw InternalException("Update column_path out of range"); } - sub_columns[update_column - 1]->UpdateColumn(transaction, column_path, update_vector, row_ids, update_count, - depth + 1); + sub_columns[update_column - 1]->UpdateColumn(transaction, data_table, column_path, update_vector, row_ids, + update_count, depth + 1); } } @@ -311,7 +312,8 @@ unique_ptr StructColumnData::CreateCheckpointState(RowGro unique_ptr StructColumnData::Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info) { - auto checkpoint_state = make_uniq(row_group, *this, checkpoint_info.info.manager); + auto &partial_block_manager = checkpoint_info.GetPartialBlockManager(); + auto checkpoint_state = make_uniq(row_group, *this, partial_block_manager); checkpoint_state->validity_state = validity.Checkpoint(row_group, checkpoint_info); for (auto &sub_column : sub_columns) { checkpoint_state->child_states.push_back(sub_column->Checkpoint(row_group, checkpoint_info)); @@ -361,13 +363,13 @@ void StructColumnData::InitializeColumn(PersistentColumnData &column_data, BaseS this->count = validity.count.load(); } -void StructColumnData::GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) { +void StructColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, + vector &result) { col_path.push_back(0); - validity.GetColumnSegmentInfo(row_group_index, col_path, result); + validity.GetColumnSegmentInfo(context, row_group_index, col_path, result); for (idx_t i = 0; i < sub_columns.size(); i++) { col_path.back() = i + 1; - sub_columns[i]->GetColumnSegmentInfo(row_group_index, col_path, result); + sub_columns[i]->GetColumnSegmentInfo(context, row_group_index, col_path, result); } } diff --git a/src/duckdb/src/storage/table/update_segment.cpp b/src/duckdb/src/storage/table/update_segment.cpp index 8056907bc..00a6bbbc2 100644 --- a/src/duckdb/src/storage/table/update_segment.cpp +++ b/src/duckdb/src/storage/table/update_segment.cpp @@ -7,6 +7,7 @@ #include "duckdb/transaction/duck_transaction.hpp" #include "duckdb/transaction/update_info.hpp" #include "duckdb/transaction/undo_buffer.hpp" +#include "duckdb/storage/data_table.hpp" #include @@ -104,9 +105,10 @@ idx_t UpdateInfo::GetAllocSize(idx_t type_size) { return AlignValue(sizeof(UpdateInfo) + (sizeof(sel_t) + type_size) * STANDARD_VECTOR_SIZE); } -void UpdateInfo::Initialize(UpdateInfo &info, transaction_t transaction_id) { +void UpdateInfo::Initialize(UpdateInfo &info, DataTable &data_table, transaction_t transaction_id) { info.max = STANDARD_VECTOR_SIZE; info.version_number = transaction_id; + info.table = &data_table; info.segment = nullptr; info.prev.entry = nullptr; info.next.entry = nullptr; @@ -1236,11 +1238,11 @@ static idx_t SortSelectionVector(SelectionVector &sel, idx_t count, row_t *ids) return pos; } -UpdateInfo *CreateEmptyUpdateInfo(TransactionData transaction, idx_t type_size, idx_t count, +UpdateInfo *CreateEmptyUpdateInfo(TransactionData transaction, DataTable &data_table, idx_t type_size, idx_t count, unsafe_unique_array &data) { data = 
make_unsafe_uniq_array_uninitialized(UpdateInfo::GetAllocSize(type_size)); auto update_info = reinterpret_cast(data.get()); - UpdateInfo::Initialize(*update_info, transaction.transaction_id); + UpdateInfo::Initialize(*update_info, data_table, transaction.transaction_id); return update_info; } @@ -1258,8 +1260,8 @@ void UpdateSegment::InitializeUpdateInfo(idx_t vector_idx) { } } -void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vector &update_p, row_t *ids, idx_t count, - Vector &base_data) { +void UpdateSegment::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_p, + row_t *ids, idx_t count, Vector &base_data) { // obtain an exclusive lock auto write_lock = lock.GetExclusiveLock(); @@ -1322,10 +1324,10 @@ void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vect // no updates made yet by this transaction: initially the update info to empty if (transaction.transaction) { auto &dtransaction = transaction.transaction->Cast(); - node_ref = dtransaction.CreateUpdateInfo(type_size, count); + node_ref = dtransaction.CreateUpdateInfo(type_size, data_table, count); node = &UpdateInfo::Get(node_ref); } else { - node = CreateEmptyUpdateInfo(transaction, type_size, count, update_info_data); + node = CreateEmptyUpdateInfo(transaction, data_table, type_size, count, update_info_data); } node->segment = this; node->vector_index = vector_index; @@ -1354,13 +1356,12 @@ void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vect base_info.Verify(); node->Verify(); } else { - // there is no version info yet: create the top level update info and fill it with the updates // allocate space for the UpdateInfo in the allocator idx_t alloc_size = UpdateInfo::GetAllocSize(type_size); auto handle = root->allocator.Allocate(alloc_size); auto &update_info = UpdateInfo::Get(handle); - UpdateInfo::Initialize(update_info, TRANSACTION_ID_START - 1); + UpdateInfo::Initialize(update_info, data_table, TRANSACTION_ID_START - 1); update_info.column_index = column_index; InitializeUpdateInfo(update_info, ids, sel, count, vector_index, vector_offset); @@ -1370,10 +1371,10 @@ void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vect UndoBufferReference node_ref; optional_ptr transaction_node; if (transaction.transaction) { - node_ref = transaction.transaction->CreateUpdateInfo(type_size, count); + node_ref = transaction.transaction->CreateUpdateInfo(type_size, data_table, count); transaction_node = &UpdateInfo::Get(node_ref); } else { - transaction_node = CreateEmptyUpdateInfo(transaction, type_size, count, update_info_data); + transaction_node = CreateEmptyUpdateInfo(transaction, data_table, type_size, count, update_info_data); } InitializeUpdateInfo(*transaction_node, ids, sel, count, vector_index, vector_offset); diff --git a/src/duckdb/src/storage/table_index_list.cpp b/src/duckdb/src/storage/table_index_list.cpp index ade84cdc8..77f1f6581 100644 --- a/src/duckdb/src/storage/table_index_list.cpp +++ b/src/duckdb/src/storage/table_index_list.cpp @@ -147,11 +147,17 @@ void TableIndexList::Bind(ClientContext &context, DataTableInfo &table_info, con // Create an IndexBinder to bind the index IndexBinder idx_binder(*binder, context); - // Apply any outstanding appends and replace the unbound index with a bound index. + // Apply any outstanding buffered replays and replace the unbound index with a bound index. 
auto &unbound_index = index_entry->index->Cast(); auto bound_idx = idx_binder.BindIndex(unbound_index); - if (unbound_index.HasBufferedAppends()) { - bound_idx->ApplyBufferedAppends(column_types, unbound_index.GetBufferedAppends(), + if (unbound_index.HasBufferedReplays()) { + // For replaying buffered index operations, we only want the physical column types (skip over + // generated column types). + vector physical_column_types; + for (auto &col : table.GetColumns().Physical()) { + physical_column_types.push_back(col.Type()); + } + bound_idx->ApplyBufferedReplays(physical_column_types, unbound_index.GetBufferedReplays(), unbound_index.GetMappedColumnIds()); } @@ -255,11 +261,18 @@ void TableIndexList::InitializeIndexChunk(DataChunk &index_chunk, const vector index_types; + // Store the mapped_column_ids and index_types in sorted canonical form, needed for + // buffering WAL index operations during replay (see notes in unbound_index.hpp). + // First sort mapped_column_ids, then populate index_types according to the sorted order. for (auto &col : indexed_columns) { - index_types.push_back(table_types[col]); mapped_column_ids.emplace_back(col); } + std::sort(mapped_column_ids.begin(), mapped_column_ids.end()); + + vector index_types; + for (auto &col : mapped_column_ids) { + index_types.push_back(table_types[col.GetPrimaryIndex()]); + } index_chunk.InitializeEmpty(index_types); } diff --git a/src/duckdb/src/storage/temporary_file_manager.cpp b/src/duckdb/src/storage/temporary_file_manager.cpp index b8ab5a7b0..e27ef0729 100644 --- a/src/duckdb/src/storage/temporary_file_manager.cpp +++ b/src/duckdb/src/storage/temporary_file_manager.cpp @@ -73,7 +73,6 @@ TemporaryFileIdentifier::TemporaryFileIdentifier(TemporaryBufferSize size_p, idx TemporaryFileIdentifier::TemporaryFileIdentifier(DatabaseInstance &db, TemporaryBufferSize size_p, idx_t file_index_p, bool encrypted_p) : size(size_p), file_index(file_index_p), encrypted(encrypted_p) { - if (encrypted) { // generate a random encryption key ID and corresponding key EncryptionEngine::AddTempKeyToCache(db); diff --git a/src/duckdb/src/storage/wal_replay.cpp b/src/duckdb/src/storage/wal_replay.cpp index 77eca9cf7..ea271f488 100644 --- a/src/duckdb/src/storage/wal_replay.cpp +++ b/src/duckdb/src/storage/wal_replay.cpp @@ -32,6 +32,7 @@ #include "duckdb/storage/table/delete_state.hpp" #include "duckdb/storage/write_ahead_log.hpp" #include "duckdb/transaction/meta_transaction.hpp" +#include "duckdb/main/client_data.hpp" namespace duckdb { @@ -256,28 +257,34 @@ class WriteAheadLogDeserializer { //===--------------------------------------------------------------------===// // Replay //===--------------------------------------------------------------------===// -unique_ptr WriteAheadLog::Replay(FileSystem &fs, AttachedDatabase &db, const string &wal_path) { +unique_ptr WriteAheadLog::Replay(QueryContext context, FileSystem &fs, AttachedDatabase &db, + const string &wal_path) { auto handle = fs.OpenFile(wal_path, FileFlags::FILE_FLAGS_READ | FileFlags::FILE_FLAGS_NULL_IF_NOT_EXISTS); if (!handle) { // WAL does not exist - instantiate an empty WAL return make_uniq(db, wal_path); } - auto wal_handle = ReplayInternal(db, std::move(handle)); + + // context is passed for metric collection purposes only!! 
+ auto wal_handle = ReplayInternal(context, db, std::move(handle)); if (wal_handle) { return wal_handle; } // replay returning NULL indicates we can nuke the WAL entirely - but only if this is not a read-only connection if (!db.IsReadOnly()) { - fs.RemoveFile(wal_path); + fs.TryRemoveFile(wal_path); } return make_uniq(db, wal_path); } -unique_ptr WriteAheadLog::ReplayInternal(AttachedDatabase &database, unique_ptr handle) { + +// QueryContext is passed for metric collection purposes only!! +unique_ptr WriteAheadLog::ReplayInternal(QueryContext context, AttachedDatabase &database, + unique_ptr handle) { Connection con(database.GetDatabase()); auto wal_path = handle->GetPath(); BufferedFileReader reader(FileSystem::Get(database), std::move(handle)); if (reader.Finished()) { - // WAL file exists but it is empty - we can delete the file + // WAL file exists, but it is empty - we can delete the file return nullptr; } @@ -289,7 +296,9 @@ unique_ptr WriteAheadLog::ReplayInternal(AttachedDatabase &databa // if there is a checkpoint flag, we might have already flushed the contents of the WAL to disk ReplayState checkpoint_state(database, *con.context); try { + idx_t replay_entry_count = 0; while (true) { + replay_entry_count++; // read the current entry (deserialize only) auto deserializer = WriteAheadLogDeserializer::Open(checkpoint_state, reader, true); if (deserializer.ReplayEntry()) { @@ -300,6 +309,11 @@ unique_ptr WriteAheadLog::ReplayInternal(AttachedDatabase &databa } } } + auto client_context = context.GetClientContext(); + if (client_context) { + auto &profiler = *client_context->client_data->profiler; + profiler.AddToCounter(MetricsType::WAL_REPLAY_ENTRY_COUNT, replay_entry_count); + } } catch (std::exception &ex) { // LCOV_EXCL_START ErrorData error(ex); // ignore serialization exceptions - they signal a torn WAL @@ -562,7 +576,6 @@ void WriteAheadLogDeserializer::ReplayIndexData(IndexStorageInfo &info) { // Read the data into buffer handles and convert them to blocks on disk. for (idx_t j = 0; j < data_info.allocation_sizes.size(); j++) { - // Read the data into a buffer handle. auto buffer_handle = buffer_manager.Allocate(MemoryTag::ART_INDEX, block_manager.get(), false); auto block_handle = buffer_handle.GetBlockHandle(); @@ -573,7 +586,7 @@ void WriteAheadLogDeserializer::ReplayIndexData(IndexStorageInfo &info) { // Convert the buffer handle to a persistent block and store the block id. if (!deserialize_only) { auto block_id = block_manager->GetFreeBlockId(); - block_manager->ConvertToPersistent(QueryContext(context), block_id, std::move(block_handle), + block_manager->ConvertToPersistent(context, block_id, std::move(block_handle), std::move(buffer_handle)); data_info.block_pointers[j].block_id = block_id; } diff --git a/src/duckdb/src/transaction/cleanup_state.cpp b/src/duckdb/src/transaction/cleanup_state.cpp index f9a17f265..4633a9b1a 100644 --- a/src/duckdb/src/transaction/cleanup_state.cpp +++ b/src/duckdb/src/transaction/cleanup_state.cpp @@ -13,7 +13,7 @@ namespace duckdb { -CleanupState::CleanupState(transaction_t lowest_active_transaction) +CleanupState::CleanupState(const QueryContext &context, transaction_t lowest_active_transaction) : lowest_active_transaction(lowest_active_transaction), current_table(nullptr), count(0) { } @@ -95,10 +95,15 @@ void CleanupState::Flush() { // set up the row identifiers vector Vector row_identifiers(LogicalType::ROW_TYPE, data_ptr_cast(row_numbers)); - // delete the tuples from all the indexes + // delete the tuples from all the indexes. 
+ // If there is any issue with removal, a FatalException must be thrown since there may be a corruption of + // data, hence the transaction cannot be guaranteed. try { - current_table->RemoveFromIndexes(row_identifiers, count); - } catch (...) { // NOLINT: ignore errors here + current_table->RemoveFromIndexes(context, row_identifiers, count); + } catch (std::exception &ex) { + throw FatalException(ErrorData(ex).Message()); + } catch (...) { + throw FatalException("unknown failure in CleanupState::Flush"); } count = 0; diff --git a/src/duckdb/src/transaction/commit_state.cpp b/src/duckdb/src/transaction/commit_state.cpp index 0f5d75bd2..6eba8ab10 100644 --- a/src/duckdb/src/transaction/commit_state.cpp +++ b/src/duckdb/src/transaction/commit_state.cpp @@ -165,6 +165,12 @@ void CommitState::CommitEntry(UndoFlags type, data_ptr_t data) { case UndoFlags::INSERT_TUPLE: { // append: auto info = reinterpret_cast(data); + if (!info->table->IsMainTable()) { + auto table_name = info->table->GetTableName(); + auto table_modification = info->table->TableModification(); + throw TransactionException("Attempting to modify table %s but another transaction has %s this table", + table_name, table_modification); + } // mark the tuples as committed info->table->CommitAppend(commit_id, info->start_row, info->count); break; @@ -172,6 +178,12 @@ void CommitState::CommitEntry(UndoFlags type, data_ptr_t data) { case UndoFlags::DELETE_TUPLE: { // deletion: auto info = reinterpret_cast(data); + if (!info->table->IsMainTable()) { + auto table_name = info->table->GetTableName(); + auto table_modification = info->table->TableModification(); + throw TransactionException("Attempting to modify table %s but another transaction has %s this table", + table_name, table_modification); + } // mark the tuples as committed info->version_info->CommitDelete(info->vector_idx, commit_id, *info); break; @@ -179,6 +191,12 @@ void CommitState::CommitEntry(UndoFlags type, data_ptr_t data) { case UndoFlags::UPDATE_TUPLE: { // update: auto info = reinterpret_cast(data); + if (!info->table->IsMainTable()) { + auto table_name = info->table->GetTableName(); + auto table_modification = info->table->TableModification(); + throw TransactionException("Attempting to modify table %s but another transaction has %s this table", + table_name, table_modification); + } info->version_number = commit_id; break; } diff --git a/src/duckdb/src/transaction/duck_transaction.cpp b/src/duckdb/src/transaction/duck_transaction.cpp index dc6afccb7..3362c52d0 100644 --- a/src/duckdb/src/transaction/duck_transaction.cpp +++ b/src/duckdb/src/transaction/duck_transaction.cpp @@ -32,8 +32,8 @@ TransactionData::TransactionData(transaction_t transaction_id_p, transaction_t s DuckTransaction::DuckTransaction(DuckTransactionManager &manager, ClientContext &context_p, transaction_t start_time, transaction_t transaction_id, idx_t catalog_version_p) : Transaction(manager, context_p), start_time(start_time), transaction_id(transaction_id), commit_id(0), - highest_active_query(0), catalog_version(catalog_version_p), awaiting_cleanup(false), - transaction_manager(manager), undo_buffer(*this, context_p), storage(make_uniq(context_p, *this)) { + catalog_version(catalog_version_p), awaiting_cleanup(false), transaction_manager(manager), + undo_buffer(*this, context_p), storage(make_uniq(context_p, *this)) { } DuckTransaction::~DuckTransaction() { @@ -126,11 +126,11 @@ void DuckTransaction::PushAppend(DataTable &table, idx_t start_row, idx_t row_co append_info->count = row_count; } 
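The commit_state.cpp hunks above add the same guard to INSERT, DELETE and UPDATE undo entries: before any tuples are marked as committed, the entry's table must still be the live ("main") table, otherwise the commit fails with a TransactionException. A minimal standalone sketch of that guard pattern follows; ToyTable, CheckTableStillMain and std::runtime_error are illustrative stand-ins, not DuckDB API.

#include <stdexcept>
#include <string>

struct ToyTable {
    std::string name = "t";
    bool is_main_table = true;            // becomes false once another transaction has altered or dropped the table
    std::string modification = "altered"; // how that other transaction modified it
};

// Mirrors the shape of the per-entry guard: refuse to commit undo entries
// that touch a table which is no longer the main table.
static void CheckTableStillMain(const ToyTable &table) {
    if (!table.is_main_table) {
        throw std::runtime_error("Attempting to modify table " + table.name + " but another transaction has " +
                                 table.modification + " this table");
    }
}

int main() {
    ToyTable table;
    CheckTableStillMain(table); // passes while the table is still the main table
    table.is_main_table = false;
    try {
        CheckTableStillMain(table); // now throws: the commit is rejected
    } catch (const std::exception &) {
    }
    return 0;
}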
-UndoBufferReference DuckTransaction::CreateUpdateInfo(idx_t type_size, idx_t entries) { +UndoBufferReference DuckTransaction::CreateUpdateInfo(idx_t type_size, DataTable &data_table, idx_t entries) { idx_t alloc_size = UpdateInfo::GetAllocSize(type_size); auto undo_entry = undo_buffer.CreateEntry(UndoFlags::UPDATE_TUPLE, alloc_size); auto &update_info = UpdateInfo::Get(undo_entry); - UpdateInfo::Initialize(update_info, transaction_id); + UpdateInfo::Initialize(update_info, data_table, transaction_id); return undo_entry; } @@ -208,10 +208,10 @@ ErrorData DuckTransaction::WriteToWAL(AttachedDatabase &db, unique_ptrCommit(commit_state.get()); - undo_buffer.WriteToWAL(*log, commit_state.get()); + undo_buffer.WriteToWAL(*wal, commit_state.get()); if (commit_state->HasRowGroupData()) { // if we have optimistically written any data AND we are writing to the WAL, we have written references to // optimistically written blocks @@ -246,14 +246,6 @@ ErrorData DuckTransaction::Commit(AttachedDatabase &db, transaction_t new_commit // no need to flush anything if we made no changes return ErrorData(); } - for (auto &entry : modified_tables) { - auto &tbl = entry.first.get(); - if (!tbl.IsMainTable()) { - return ErrorData( - TransactionException("Attempting to modify table %s but another transaction has %s this table", - tbl.GetTableName(), tbl.TableModification())); - } - } D_ASSERT(db.IsSystem() || db.IsTemporary() || !IsReadOnly()); UndoBuffer::IteratorState iterator_state; diff --git a/src/duckdb/src/transaction/duck_transaction_manager.cpp b/src/duckdb/src/transaction/duck_transaction_manager.cpp index eace5283c..1b15bb33a 100644 --- a/src/duckdb/src/transaction/duck_transaction_manager.cpp +++ b/src/duckdb/src/transaction/duck_transaction_manager.cpp @@ -1,5 +1,7 @@ #include "duckdb/transaction/duck_transaction_manager.hpp" +#include "duckdb/main/client_data.hpp" + #include "duckdb/catalog/catalog_set.hpp" #include "duckdb/common/exception/transaction_exception.hpp" #include "duckdb/common/exception.hpp" @@ -216,7 +218,7 @@ void DuckTransactionManager::Checkpoint(ClientContext &context, bool force) { options.type = CheckpointType::CONCURRENT_CHECKPOINT; } - storage_manager.CreateCheckpoint(QueryContext(context), options); + storage_manager.CreateCheckpoint(context, options); } unique_ptr DuckTransactionManager::SharedCheckpointLock() { @@ -268,7 +270,14 @@ ErrorData DuckTransactionManager::CommitTransaction(ClientContext &context, Tran t_lock.unlock(); // grab the WAL lock and hold it until the entire commit is finished held_wal_lock = make_uniq>(wal_lock); - error = transaction.WriteToWAL(db, commit_state); + + // Commit the changes to the WAL. 
+ if (db.GetRecoveryMode() == RecoveryMode::DEFAULT) { + auto &profiler = *context.client_data->profiler; + profiler.StartTimer(MetricsType::COMMIT_WRITE_WAL_LATENCY); + error = transaction.WriteToWAL(db, commit_state); + profiler.EndTimer(MetricsType::COMMIT_WRITE_WAL_LATENCY); + } // after we finish writing to the WAL we grab the transaction lock again t_lock.lock(); @@ -276,7 +285,7 @@ ErrorData DuckTransactionManager::CommitTransaction(ClientContext &context, Tran // in-memory databases don't have a WAL - we estimate how large their changeset is based on the undo properties if (!db.IsSystem()) { auto &storage_manager = db.GetStorageManager(); - if (storage_manager.InMemory()) { + if (storage_manager.InMemory() || db.GetRecoveryMode() == RecoveryMode::NO_WAL_WRITES) { storage_manager.AddInMemoryChange(undo_properties.estimated_size); } } @@ -324,7 +333,7 @@ ErrorData DuckTransactionManager::CommitTransaction(ClientContext &context, Tran } // We do not need to hold the transaction lock during cleanup of transactions, - // as they (1) have been removed, or (2) exited old_transactions. + // as they (1) have been removed, or (2) enter cleanup_info. t_lock.unlock(); { @@ -353,7 +362,7 @@ ErrorData DuckTransactionManager::CommitTransaction(ClientContext &context, Tran options.type = checkpoint_decision.type; auto &storage_manager = db.GetStorageManager(); try { - storage_manager.CreateCheckpoint(QueryContext(context), options); + storage_manager.CreateCheckpoint(context, options); } catch (std::exception &ex) { error.Merge(ErrorData(ex)); } @@ -412,7 +421,6 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa idx_t t_index = active_transactions.size(); auto lowest_start_time = TRANSACTION_ID_START; auto lowest_transaction_id = MAX_TRANSACTION_ID; - auto lowest_active_query = MAXIMUM_QUERY_ID; for (idx_t i = 0; i < active_transactions.size(); i++) { if (active_transactions[i].get() == &transaction) { t_index = i; @@ -420,8 +428,6 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa } lowest_start_time = MinValue(lowest_start_time, active_transactions[i]->start_time); lowest_transaction_id = MinValue(lowest_transaction_id, active_transactions[i]->transaction_id); - transaction_t active_query = active_transactions[i]->active_query; - lowest_active_query = MinValue(lowest_active_query, active_query); } lowest_active_start = lowest_start_time; lowest_active_id = lowest_transaction_id; @@ -429,7 +435,6 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa // Decide if we need to store the transaction, or if we can schedule it for cleanup. auto current_transaction = std::move(active_transactions[t_index]); - auto current_query = DatabaseManager::Get(db).ActiveQueryNumber(); if (store_transaction) { // If the transaction made any changes, we need to keep it around. if (transaction.commit_id != 0) { @@ -438,9 +443,7 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa recently_committed_transactions.push_back(std::move(current_transaction)); } else { // The transaction was aborted. - // We might still need its information; add it to the set of transactions awaiting GC. - current_transaction->highest_active_query = current_query; - old_transactions.push_back(std::move(current_transaction)); + cleanup_info->transactions.push_back(std::move(current_transaction)); } } else if (transaction.ChangesMade()) { // We do not need to store the transaction, directly schedule it for cleanup. 
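The CommitTransaction hunks above combine two behaviours: the WAL write is skipped entirely when the database runs in a no-WAL recovery mode (its change set is then accounted for like an in-memory database), and when the WAL is written, the write is wrapped in a commit-latency profiler counter. A rough standalone sketch of that control flow follows; RecoveryMode, Profiler, WriteToWAL and CommitChangeSet are simplified assumptions, not the DuckDB classes.

#include <chrono>
#include <cstdint>

enum class RecoveryMode { DEFAULT, NO_WAL_WRITES };

struct Profiler {
    std::chrono::steady_clock::time_point started;
    std::int64_t wal_write_ns = 0;
    void StartTimer() { started = std::chrono::steady_clock::now(); }
    void EndTimer() {
        auto elapsed = std::chrono::steady_clock::now() - started;
        wal_write_ns += std::chrono::duration_cast<std::chrono::nanoseconds>(elapsed).count();
    }
};

static void WriteToWAL() { /* stand-in for the actual WAL write */ }

// Sketch of the commit path: write (and time) the WAL only in the default recovery mode,
// otherwise account for the estimated change set size as an in-memory change.
static void CommitChangeSet(RecoveryMode mode, Profiler &profiler, std::int64_t estimated_size,
                            std::int64_t &in_memory_changes) {
    if (mode == RecoveryMode::DEFAULT) {
        profiler.StartTimer();
        WriteToWAL();
        profiler.EndTimer();
    } else {
        in_memory_changes += estimated_size;
    }
}

int main() {
    Profiler profiler;
    std::int64_t in_memory_changes = 0;
    CommitChangeSet(RecoveryMode::DEFAULT, profiler, 1024, in_memory_changes);       // timed WAL write
    CommitChangeSet(RecoveryMode::NO_WAL_WRITES, profiler, 1024, in_memory_changes); // no WAL, size is tracked
    return 0;
}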
@@ -464,18 +467,8 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa break; } - // Changes made BEFORE this transaction are no longer relevant. - // We can schedule the transaction and its undo buffer for cleanup. recently_committed_transactions[i]->awaiting_cleanup = true; - - // HOWEVER: Any currently running QUERY can still be using - // the version information of the transaction. - // If we remove the UndoBuffer immediately, we have a race condition. - - // Store the current highest active query. - recently_committed_transactions[i]->highest_active_query = current_query; - // Move it to the list of transactions awaiting GC. - old_transactions.push_back(std::move(recently_committed_transactions[i])); + cleanup_info->transactions.push_back(std::move(recently_committed_transactions[i])); } if (i > 0) { @@ -485,34 +478,6 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa recently_committed_transactions.erase(start, end); } - // Check if we can clean up and free the memory of any old transactions. - i = active_transactions.empty() ? old_transactions.size() : 0; - for (; i < old_transactions.size(); i++) { - D_ASSERT(old_transactions[i]); - D_ASSERT(old_transactions[i]->highest_active_query > 0); - if (old_transactions[i]->highest_active_query >= lowest_active_query) { - // There is still a query running that could be using - // this transactions' data. - break; - } - } - - if (i > 0) { - // We garbage-collected old transactions: - // - Remove them from the list and schedule them for cleanup. - - // We can only safely do the actual memory cleanup when all the - // currently active queries have finished running! (actually, - // when all the currently active scans have finished running...). - - // Because we clean up asynchronously, we only clean up once we - // no longer need the transaction for anything (i.e., we can move it). 
- for (idx_t t_idx = 0; t_idx < i; t_idx++) { - cleanup_info->transactions.push_back(std::move(old_transactions[t_idx])); - } - old_transactions.erase(old_transactions.begin(), old_transactions.begin() + static_cast(i)); - } - return cleanup_info; } diff --git a/src/duckdb/src/transaction/undo_buffer.cpp b/src/duckdb/src/transaction/undo_buffer.cpp index 4408a972b..beec05ec9 100644 --- a/src/duckdb/src/transaction/undo_buffer.cpp +++ b/src/duckdb/src/transaction/undo_buffer.cpp @@ -15,6 +15,7 @@ #include "duckdb/transaction/delete_info.hpp" #include "duckdb/transaction/rollback_state.hpp" #include "duckdb/transaction/wal_write_state.hpp" +#include "duckdb/transaction/duck_transaction.hpp" namespace duckdb { constexpr uint32_t UNDO_ENTRY_HEADER_SIZE = sizeof(UndoFlags) + sizeof(uint32_t); @@ -176,7 +177,7 @@ void UndoBuffer::Cleanup(transaction_t lowest_active_transaction) { // the chunks) // (2) there is no active transaction with start_id < commit_id of this // transaction - CleanupState state(lowest_active_transaction); + CleanupState state(QueryContext(), lowest_active_transaction); UndoBuffer::IteratorState iterator_state; IterateEntries(iterator_state, [&](UndoFlags type, data_ptr_t data) { state.CleanupEntry(type, data); }); diff --git a/src/duckdb/src/transaction/wal_write_state.cpp b/src/duckdb/src/transaction/wal_write_state.cpp index 5fe17e050..0036ad0c6 100644 --- a/src/duckdb/src/transaction/wal_write_state.cpp +++ b/src/duckdb/src/transaction/wal_write_state.cpp @@ -27,10 +27,10 @@ WALWriteState::WALWriteState(DuckTransaction &transaction_p, WriteAheadLog &log, : transaction(transaction_p), log(log), commit_state(commit_state), current_table_info(nullptr) { } -void WALWriteState::SwitchTable(DataTableInfo *table_info, UndoFlags new_op) { - if (current_table_info != table_info) { +void WALWriteState::SwitchTable(DataTableInfo &table_info, UndoFlags new_op) { + if (current_table_info != &table_info) { // write the current table to the log - log.WriteSetTable(table_info->GetSchemaName(), table_info->GetTableName()); + log.WriteSetTable(table_info.GetSchemaName(), table_info.GetTableName()); current_table_info = table_info; } } @@ -171,7 +171,7 @@ void WALWriteState::WriteCatalogEntry(CatalogEntry &entry, data_ptr_t dataptr) { void WALWriteState::WriteDelete(DeleteInfo &info) { // switch to the current table, if necessary - SwitchTable(info.table->GetDataTableInfo().get(), UndoFlags::DELETE_TUPLE); + SwitchTable(*info.table->GetDataTableInfo(), UndoFlags::DELETE_TUPLE); if (!delete_chunk) { delete_chunk = make_uniq(); @@ -198,7 +198,7 @@ void WALWriteState::WriteUpdate(UpdateInfo &info) { auto &column_data = info.segment->column_data; auto &table_info = column_data.GetTableInfo(); - SwitchTable(&table_info, UndoFlags::UPDATE_TUPLE); + SwitchTable(table_info, UndoFlags::UPDATE_TUPLE); // initialize the update chunk vector update_types; diff --git a/src/duckdb/src/verification/deserialized_statement_verifier.cpp b/src/duckdb/src/verification/deserialized_statement_verifier.cpp index 1ade815d7..3d72d6159 100644 --- a/src/duckdb/src/verification/deserialized_statement_verifier.cpp +++ b/src/duckdb/src/verification/deserialized_statement_verifier.cpp @@ -13,7 +13,6 @@ DeserializedStatementVerifier::DeserializedStatementVerifier( unique_ptr DeserializedStatementVerifier::Create(const SQLStatement &statement, optional_ptr> parameters) { - auto &select_stmt = statement.Cast(); Allocator allocator; MemoryStream stream(allocator); diff --git 
a/src/duckdb/src/verification/statement_verifier.cpp b/src/duckdb/src/verification/statement_verifier.cpp index 81f4c4aba..14e4c0491 100644 --- a/src/duckdb/src/verification/statement_verifier.cpp +++ b/src/duckdb/src/verification/statement_verifier.cpp @@ -1,5 +1,9 @@ #include "duckdb/verification/statement_verifier.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/query_node/set_operation_node.hpp" +#include "duckdb/parser/query_node/cte_node.hpp" + #include "duckdb/common/error_data.hpp" #include "duckdb/common/types/column/column_data_collection.hpp" #include "duckdb/parser/parser.hpp" @@ -15,13 +19,24 @@ namespace duckdb { +const vector> &StatementVerifier::GetSelectList(QueryNode &node) { + switch (node.type) { + case QueryNodeType::SELECT_NODE: + return node.Cast().select_list; + case QueryNodeType::SET_OPERATION_NODE: + return GetSelectList(*node.Cast().children[0]); + default: + return empty_select_list; + } +} + StatementVerifier::StatementVerifier(VerificationType type, string name, unique_ptr statement_p, optional_ptr> parameters_p) : type(type), name(std::move(name)), statement(std::move(statement_p)), select_statement(statement->type == StatementType::SELECT_STATEMENT ? &statement->Cast() : nullptr), parameters(parameters_p), - select_list(select_statement ? select_statement->node->GetSelectList() : empty_select_list) { + select_list(select_statement ? GetSelectList(*select_statement->node) : empty_select_list) { } StatementVerifier::StatementVerifier(unique_ptr statement_p, diff --git a/src/duckdb/third_party/httplib/httplib.hpp b/src/duckdb/third_party/httplib/httplib.hpp index 4aa0458dc..409c47d0b 100644 --- a/src/duckdb/third_party/httplib/httplib.hpp +++ b/src/duckdb/third_party/httplib/httplib.hpp @@ -7077,7 +7077,12 @@ inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { } auto location = res.get_header_value("location"); - if (location.empty()) { return false; } + if (location.empty()) { + // s3 requests will not return a location header, and instead a + // X-Amx-Region-Bucket header. Return true so all response headers + // are returned to the httpfs/calling extension + return true; + } const Regex re( R"((?:(https?):)?(?://(?:\[([\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); diff --git a/src/duckdb/third_party/libpg_query/include/pg_functions.hpp b/src/duckdb/third_party/libpg_query/include/pg_functions.hpp index bb591f75d..f33723183 100644 --- a/src/duckdb/third_party/libpg_query/include/pg_functions.hpp +++ b/src/duckdb/third_party/libpg_query/include/pg_functions.hpp @@ -3,7 +3,9 @@ #include #include +#ifndef __MVS__ #define fprintf(...) 
+#endif #include "pg_definitions.hpp" diff --git a/src/duckdb/third_party/libpg_query/pg_functions.cpp b/src/duckdb/third_party/libpg_query/pg_functions.cpp index 3b7a7515e..36bed9dcb 100644 --- a/src/duckdb/third_party/libpg_query/pg_functions.cpp +++ b/src/duckdb/third_party/libpg_query/pg_functions.cpp @@ -30,13 +30,8 @@ struct pg_parser_state_str { }; #ifdef __MVS__ -// -------------------------------------------------------- -// Permanent - WIP -// static __tlssim pg_parser_state_impl(); -// #define pg_parser_state (*pg_parser_state_impl.access()) -// -------------------------------------------------------- -// Temporary -static parser_state pg_parser_state; +static __tlssim pg_parser_state_impl; +#define pg_parser_state (*pg_parser_state_impl.access()) #else static __thread parser_state pg_parser_state; #endif diff --git a/src/duckdb/third_party/parquet/parquet_types.cpp b/src/duckdb/third_party/parquet/parquet_types.cpp index 95cfbc3f7..a508a69f2 100644 --- a/src/duckdb/third_party/parquet/parquet_types.cpp +++ b/src/duckdb/third_party/parquet/parquet_types.cpp @@ -1,5 +1,5 @@ /** - * Autogenerated by Thrift Compiler (0.21.0) + * Autogenerated by Thrift Compiler (0.22.0) * * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING * @generated @@ -13,6 +13,14 @@ namespace duckdb_parquet { +template +static typename ENUM::type SafeEnumCast(const std::map &values_to_names, const int &ecast) { + if (values_to_names.find(ecast) == values_to_names.end()) { + throw duckdb_apache::thrift::protocol::TProtocolException(duckdb_apache::thrift::protocol::TProtocolException::INVALID_DATA); + } + return static_cast(ecast); +} + int _kTypeValues[] = { Type::BOOLEAN, Type::INT32, @@ -176,7 +184,14 @@ int _kConvertedTypeValues[] = { * the provided duration. This duration of time is independent of any * particular timezone or date. */ - ConvertedType::INTERVAL + ConvertedType::INTERVAL, + /** + * Non-standard NULL value + * + * This was written by old writers - it is kept here for compatibility purposes. + * See https://github.com/duckdb/duckdb/pull/11774 + */ + ConvertedType::PARQUET_NULL }; const char* _kConvertedTypeNames[] = { /** @@ -300,9 +315,16 @@ const char* _kConvertedTypeNames[] = { * the provided duration. This duration of time is independent of any * particular timezone or date. */ - "INTERVAL" + "INTERVAL", + /** + * Non-standard NULL value + * + * This was written by old writers - it is kept here for compatibility purposes. 
+  * See https://github.com/duckdb/duckdb/pull/11774
+  */
+  "PARQUET_NULL"
 };
-const std::map<int, const char*> _ConvertedType_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(22, _kConvertedTypeValues, _kConvertedTypeNames), ::apache::thrift::TEnumIterator(-1, nullptr, nullptr));
+const std::map<int, const char*> _ConvertedType_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(23, _kConvertedTypeValues, _kConvertedTypeNames), ::apache::thrift::TEnumIterator(-1, nullptr, nullptr));

 std::ostream& operator<<(std::ostream& out, const ConvertedType::type& val) {
   std::map<int, const char*>::const_iterator it = _ConvertedType_VALUES_TO_NAMES.find(val);
@@ -3446,7 +3468,7 @@ GeographyType::~GeographyType() noexcept {

 GeographyType::GeographyType() noexcept
    : crs(),
-     algorithm(static_cast<EdgeInterpolationAlgorithm::type>(0)) {
+     algorithm(SafeEnumCast(_EdgeInterpolationAlgorithm_VALUES_TO_NAMES, 0)) {
 }

 void GeographyType::__set_crs(const std::string& val) {
@@ -3498,7 +3520,7 @@ uint32_t GeographyType::read(::apache::thrift::protocol::TProtocol* iprot) {
         if (ftype == ::apache::thrift::protocol::T_I32) {
           int32_t ecast114;
           xfer += iprot->readI32(ecast114);
-          this->algorithm = static_cast<EdgeInterpolationAlgorithm::type>(ecast114);
+          this->algorithm = SafeEnumCast(_EdgeInterpolationAlgorithm_VALUES_TO_NAMES, ecast114);
           this->__isset.algorithm = true;
         } else {
           xfer += iprot->skip(ftype);
@@ -4067,12 +4089,12 @@ SchemaElement::~SchemaElement() noexcept {
 }

 SchemaElement::SchemaElement() noexcept
-   : type(static_cast<Type::type>(0)),
+   : type(SafeEnumCast(_Type_VALUES_TO_NAMES, 0)),
     type_length(0),
-    repetition_type(static_cast<FieldRepetitionType::type>(0)),
+    repetition_type(SafeEnumCast(_FieldRepetitionType_VALUES_TO_NAMES, 0)),
     name(),
     num_children(0),
-    converted_type(static_cast<ConvertedType::type>(0)),
+    converted_type(SafeEnumCast(_ConvertedType_VALUES_TO_NAMES, 0)),
     scale(0),
     precision(0),
     field_id(0) {
@@ -4159,7 +4181,7 @@ uint32_t SchemaElement::read(::apache::thrift::protocol::TProtocol* iprot) {
         if (ftype == ::apache::thrift::protocol::T_I32) {
           int32_t ecast123;
           xfer += iprot->readI32(ecast123);
-          this->type = static_cast<Type::type>(ecast123);
+          this->type = SafeEnumCast(_Type_VALUES_TO_NAMES, ecast123);
           this->__isset.type = true;
         } else {
           xfer += iprot->skip(ftype);
@@ -4177,7 +4199,7 @@ uint32_t SchemaElement::read(::apache::thrift::protocol::TProtocol* iprot) {
         if (ftype == ::apache::thrift::protocol::T_I32) {
           int32_t ecast124;
           xfer += iprot->readI32(ecast124);
-          this->repetition_type = static_cast<FieldRepetitionType::type>(ecast124);
+          this->repetition_type = SafeEnumCast(_FieldRepetitionType_VALUES_TO_NAMES, ecast124);
           this->__isset.repetition_type = true;
         } else {
           xfer += iprot->skip(ftype);
@@ -4203,7 +4225,7 @@ uint32_t SchemaElement::read(::apache::thrift::protocol::TProtocol* iprot) {
         if (ftype == ::apache::thrift::protocol::T_I32) {
           int32_t ecast125;
           xfer += iprot->readI32(ecast125);
-          this->converted_type = static_cast<ConvertedType::type>(ecast125);
+          this->converted_type = SafeEnumCast(_ConvertedType_VALUES_TO_NAMES, ecast125);
           this->__isset.converted_type = true;
         } else {
           xfer += iprot->skip(ftype);
@@ -4405,9 +4427,9 @@ DataPageHeader::~DataPageHeader() noexcept {

 DataPageHeader::DataPageHeader() noexcept
    : num_values(0),
-     encoding(static_cast<Encoding::type>(0)),
-     definition_level_encoding(static_cast<Encoding::type>(0)),
-     repetition_level_encoding(static_cast<Encoding::type>(0)) {
+     encoding(SafeEnumCast(_Encoding_VALUES_TO_NAMES, 0)),
+     definition_level_encoding(SafeEnumCast(_Encoding_VALUES_TO_NAMES, 0)),
+     repetition_level_encoding(SafeEnumCast(_Encoding_VALUES_TO_NAMES, 0)) {
 }

 void DataPageHeader::__set_num_values(const int32_t val) {
@@ -4474,7 +4496,7 @@ uint32_t DataPageHeader::read(::apache::thrift::protocol::TProtocol* iprot) {
         if (ftype == ::apache::thrift::protocol::T_I32) {
           int32_t ecast130;
           xfer += iprot->readI32(ecast130);
-          this->encoding = static_cast<Encoding::type>(ecast130);
+          this->encoding = SafeEnumCast(_Encoding_VALUES_TO_NAMES, ecast130);
           isset_encoding = true;
         } else {
           xfer += iprot->skip(ftype);
@@ -4484,7 +4506,7 @@ uint32_t DataPageHeader::read(::apache::thrift::protocol::TProtocol* iprot) {
         if (ftype == ::apache::thrift::protocol::T_I32) {
           int32_t ecast131;
           xfer += iprot->readI32(ecast131);
-          this->definition_level_encoding = static_cast<Encoding::type>(ecast131);
+          this->definition_level_encoding = SafeEnumCast(_Encoding_VALUES_TO_NAMES, ecast131);
           isset_definition_level_encoding = true;
         } else {
           xfer += iprot->skip(ftype);
@@ -4494,7 +4516,7 @@ uint32_t DataPageHeader::read(::apache::thrift::protocol::TProtocol* iprot) {
         if (ftype == ::apache::thrift::protocol::T_I32) {
           int32_t ecast132;
           xfer += iprot->readI32(ecast132);
-          this->repetition_level_encoding = static_cast<Encoding::type>(ecast132);
+          this->repetition_level_encoding = SafeEnumCast(_Encoding_VALUES_TO_NAMES, ecast132);
           isset_repetition_level_encoding = true;
         } else {
           xfer += iprot->skip(ftype);
@@ -4697,7 +4719,7 @@ DictionaryPageHeader::~DictionaryPageHeader() noexcept {

 DictionaryPageHeader::DictionaryPageHeader() noexcept
    : num_values(0),
-     encoding(static_cast<Encoding::type>(0)),
+     encoding(SafeEnumCast(_Encoding_VALUES_TO_NAMES, 0)),
     is_sorted(0) {
 }

@@ -4755,7 +4777,7 @@ uint32_t DictionaryPageHeader::read(::apache::thrift::protocol::TProtocol* iprot
         if (ftype == ::apache::thrift::protocol::T_I32) {
           int32_t ecast141;
           xfer += iprot->readI32(ecast141);
-          this->encoding = static_cast<Encoding::type>(ecast141);
+          this->encoding = SafeEnumCast(_Encoding_VALUES_TO_NAMES, ecast141);
           isset_encoding = true;
         } else {
           xfer += iprot->skip(ftype);
@@ -4859,7 +4881,7 @@ DataPageHeaderV2::DataPageHeaderV2() noexcept
    : num_values(0),
     num_nulls(0),
     num_rows(0),
-    encoding(static_cast<Encoding::type>(0)),
+    encoding(SafeEnumCast(_Encoding_VALUES_TO_NAMES, 0)),
     definition_levels_byte_length(0),
     repetition_levels_byte_length(0),
     is_compressed(true) {
@@ -4960,7 +4982,7 @@ uint32_t DataPageHeaderV2::read(::apache::thrift::protocol::TProtocol* iprot) {
         if (ftype == ::apache::thrift::protocol::T_I32) {
           int32_t ecast146;
           xfer += iprot->readI32(ecast146);
-          this->encoding = static_cast<Encoding::type>(ecast146);
+          this->encoding = SafeEnumCast(_Encoding_VALUES_TO_NAMES, ecast146);
           isset_encoding = true;
         } else {
           xfer += iprot->skip(ftype);
@@ -5867,7 +5889,7 @@ PageHeader::~PageHeader() noexcept {
 }

 PageHeader::PageHeader() noexcept
-   : type(static_cast<PageType::type>(0)),
+   : type(SafeEnumCast(_PageType_VALUES_TO_NAMES, 0)),
     uncompressed_page_size(0),
     compressed_page_size(0),
     crc(0) {
@@ -5944,7 +5966,7 @@ uint32_t PageHeader::read(::apache::thrift::protocol::TProtocol* iprot) {
         if (ftype == ::apache::thrift::protocol::T_I32) {
           int32_t ecast179;
           xfer += iprot->readI32(ecast179);
-          this->type = static_cast<PageType::type>(ecast179);
+          this->type = SafeEnumCast(_PageType_VALUES_TO_NAMES, ecast179);
           isset_type = true;
         } else {
           xfer += iprot->skip(ftype);
@@ -6435,8 +6457,8 @@ PageEncodingStats::~PageEncodingStats() noexcept {
 }

 PageEncodingStats::PageEncodingStats() noexcept
-   : page_type(static_cast<PageType::type>(0)),
-    encoding(static_cast<Encoding::type>(0)),
+   : page_type(SafeEnumCast(_PageType_VALUES_TO_NAMES, 0)),
+    encoding(SafeEnumCast(_Encoding_VALUES_TO_NAMES, 0)),
     count(0) {
 }

@@ -6486,7 +6508,7 @@ uint32_t PageEncodingStats::read(::apache::thrift::protocol::TProtocol* iprot) {
         if (ftype == ::apache::thrift::protocol::T_I32) {
           int32_t ecast192;
           xfer += iprot->readI32(ecast192);
-          this->page_type = static_cast<PageType::type>(ecast192);
+          this->page_type = SafeEnumCast(_PageType_VALUES_TO_NAMES, ecast192);
           isset_page_type = true;
         } else {
           xfer += iprot->skip(ftype);
@@ -6496,7 +6518,7 @@ uint32_t PageEncodingStats::read(::apache::thrift::protocol::TProtocol* iprot) {
         if (ftype == ::apache::thrift::protocol::T_I32) {
           int32_t ecast193;
           xfer += iprot->readI32(ecast193);
-          this->encoding = static_cast<Encoding::type>(ecast193);
+          this->encoding = SafeEnumCast(_Encoding_VALUES_TO_NAMES, ecast193);
           isset_encoding = true;
         } else {
           xfer += iprot->skip(ftype);
@@ -6593,8 +6615,8 @@ ColumnMetaData::~ColumnMetaData() noexcept {
 }

 ColumnMetaData::ColumnMetaData() noexcept
-   : type(static_cast<Type::type>(0)),
-    codec(static_cast<CompressionCodec::type>(0)),
+   : type(SafeEnumCast(_Type_VALUES_TO_NAMES, 0)),
+    codec(SafeEnumCast(_CompressionCodec_VALUES_TO_NAMES, 0)),
     num_values(0),
     total_uncompressed_size(0),
     total_compressed_size(0),
@@ -6721,7 +6743,7 @@ uint32_t ColumnMetaData::read(::apache::thrift::protocol::TProtocol* iprot) {
         if (ftype == ::apache::thrift::protocol::T_I32) {
           int32_t ecast198;
           xfer += iprot->readI32(ecast198);
-          this->type = static_cast<Type::type>(ecast198);
+          this->type = SafeEnumCast(_Type_VALUES_TO_NAMES, ecast198);
           isset_type = true;
         } else {
           xfer += iprot->skip(ftype);
@@ -6740,7 +6762,7 @@ uint32_t ColumnMetaData::read(::apache::thrift::protocol::TProtocol* iprot) {
            {
              int32_t ecast204;
              xfer += iprot->readI32(ecast204);
-             this->encodings[_i203] = static_cast<Encoding::type>(ecast204);
+             this->encodings[_i203] = SafeEnumCast(_Encoding_VALUES_TO_NAMES, ecast204);
            }
            xfer += iprot->readListEnd();
          }
@@ -6773,7 +6795,7 @@ uint32_t ColumnMetaData::read(::apache::thrift::protocol::TProtocol* iprot) {
         if (ftype == ::apache::thrift::protocol::T_I32) {
           int32_t ecast210;
           xfer += iprot->readI32(ecast210);
-          this->codec = static_cast<CompressionCodec::type>(ecast210);
+          this->codec = SafeEnumCast(_CompressionCodec_VALUES_TO_NAMES, ecast210);
           isset_codec = true;
         } else {
           xfer += iprot->skip(ftype);
@@ -8651,7 +8673,7 @@ ColumnIndex::~ColumnIndex() noexcept {
 }

 ColumnIndex::ColumnIndex() noexcept
-   : boundary_order(static_cast<BoundaryOrder::type>(0)) {
+   : boundary_order(SafeEnumCast(_BoundaryOrder_VALUES_TO_NAMES, 0)) {
 }

 void ColumnIndex::__set_null_pages(const duckdb::vector<bool> & val) {
@@ -8780,7 +8802,7 @@ uint32_t ColumnIndex::read(::apache::thrift::protocol::TProtocol* iprot) {
         if (ftype == ::apache::thrift::protocol::T_I32) {
           int32_t ecast310;
           xfer += iprot->readI32(ecast310);
-          this->boundary_order = static_cast<BoundaryOrder::type>(ecast310);
+          this->boundary_order = SafeEnumCast(_BoundaryOrder_VALUES_TO_NAMES, ecast310);
           isset_boundary_order = true;
         } else {
           xfer += iprot->skip(ftype);
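Note on the pattern in the hunks above: every `static_cast` on an enum decoded from Parquet metadata is replaced by a call of the form `SafeEnumCast(<generated *_VALUES_TO_NAMES map>, <raw i32>)`. The helper itself is not defined anywhere in this patch, so its exact signature is unknown here; the following is only a hypothetical sketch (name `SafeEnumCastSketch`, fallback behavior, and template parameter are assumptions) of the kind of guard such a call implies, validating a raw Thrift value against the generated lookup table instead of blindly casting whatever an untrusted Parquet file contains:

// Hypothetical sketch only; not DuckDB's actual SafeEnumCast implementation.
#include <cstdint>
#include <map>

template <typename ENUM_TYPE>
ENUM_TYPE SafeEnumCastSketch(const std::map<int, const char *> &values_to_names, int32_t raw_value) {
	// Only cast values that the generated VALUES_TO_NAMES map knows about;
	// anything else (corrupt or newer-than-supported metadata) falls back to
	// the zero enumerator (the fallback choice is an assumption).
	if (values_to_names.find(raw_value) == values_to_names.end()) {
		return static_cast<ENUM_TYPE>(0);
	}
	return static_cast<ENUM_TYPE>(raw_value);
}

// e.g. this->encoding = SafeEnumCastSketch<Encoding::type>(_Encoding_VALUES_TO_NAMES, ecast130);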
diff --git a/src/duckdb/third_party/parquet/parquet_types.h b/src/duckdb/third_party/parquet/parquet_types.h
index a872a3d6b..762d3533a 100644
--- a/src/duckdb/third_party/parquet/parquet_types.h
+++ b/src/duckdb/third_party/parquet/parquet_types.h
@@ -1,5 +1,5 @@
 /**
- * Autogenerated by Thrift Compiler (0.21.0)
+ * Autogenerated by Thrift Compiler (0.22.0)
  *
  * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
  *  @generated
@@ -178,7 +178,14 @@ struct ConvertedType {
    * the provided duration. This duration of time is independent of any
    * particular timezone or date.
    */
-  INTERVAL = 21
+  INTERVAL = 21,
+  /**
+   * Non-standard NULL value
+   *
+   * This was written by old writers - it is kept here for compatibility purposes.
+   * See https://github.com/duckdb/duckdb/pull/11774
+   */
+  PARQUET_NULL = 24
   };
 };

diff --git a/src/duckdb/third_party/re2/re2/re2.h b/src/duckdb/third_party/re2/re2/re2.h
index f34936011..538594a2c 100644
--- a/src/duckdb/third_party/re2/re2/re2.h
+++ b/src/duckdb/third_party/re2/re2/re2.h
@@ -985,7 +985,7 @@ namespace hooks {
 // As per https://github.com/google/re2/issues/325, thread_local support in
 // MinGW seems to be buggy. (FWIW, Abseil folks also avoid it.)
 #define RE2_HAVE_THREAD_LOCAL
-#if (defined(__APPLE__) && !(defined(TARGET_OS_OSX) && TARGET_OS_OSX)) || defined(__MINGW32__)
+#if (defined(__APPLE__) && !(defined(TARGET_OS_OSX) && TARGET_OS_OSX)) || defined(__MINGW32__) || defined(__MVS__)
 #undef RE2_HAVE_THREAD_LOCAL
 #endif

diff --git a/src/duckdb/third_party/yyjson/include/yyjson_utils.hpp b/src/duckdb/third_party/yyjson/include/yyjson_utils.hpp
new file mode 100644
index 000000000..a848d44a9
--- /dev/null
+++ b/src/duckdb/third_party/yyjson/include/yyjson_utils.hpp
@@ -0,0 +1,33 @@
+//===----------------------------------------------------------------------===//
+//                         DuckDB
+//
+// yyjson_utils.hpp
+//
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "yyjson.hpp"
+
+using namespace duckdb_yyjson; // NOLINT
+
+namespace duckdb {
+
+struct ConvertedJSONHolder {
+public:
+	~ConvertedJSONHolder() {
+		if (doc) {
+			yyjson_mut_doc_free(doc);
+		}
+		if (stringified_json) {
+			free(stringified_json);
+		}
+	}
+
+public:
+	yyjson_mut_doc *doc = nullptr;
+	char *stringified_json = nullptr;
+};
+
+} // namespace duckdb
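The new `yyjson_utils.hpp` header above only declares the RAII holder; as a usage illustration (not taken from this patch, and assuming only the vendored yyjson calls `yyjson_mut_doc_new`, `yyjson_mut_str`, `yyjson_mut_doc_set_root`, and `yyjson_mut_write`), the holder pairs with yyjson's malloc-based allocation like so:

// Illustrative usage sketch of ConvertedJSONHolder; the function name is hypothetical.
#include "yyjson_utils.hpp"

static void ConvertedJSONHolderExample() {
	duckdb::ConvertedJSONHolder holder;
	holder.doc = yyjson_mut_doc_new(nullptr);
	yyjson_mut_doc_set_root(holder.doc, yyjson_mut_str(holder.doc, "hello"));
	// yyjson_mut_write allocates the string with malloc, which matches the
	// free() call in ~ConvertedJSONHolder.
	holder.stringified_json = yyjson_mut_write(holder.doc, YYJSON_WRITE_NOFLAG, nullptr);
}	// both the document and the string are released when holder goes out of scope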
"src/common/sorting/hashed_sort.cpp" - -#include "src/common/sorting/sort.cpp" - -#include "src/common/sorting/sorted_run.cpp" - -#include "src/common/sorting/sorted_run_merger.cpp" - diff --git a/src/duckdb/ub_src_common_tree_renderer.cpp b/src/duckdb/ub_src_common_tree_renderer.cpp index 65e8dfeba..bf7f6001e 100644 --- a/src/duckdb/ub_src_common_tree_renderer.cpp +++ b/src/duckdb/ub_src_common_tree_renderer.cpp @@ -8,5 +8,7 @@ #include "src/common/tree_renderer/yaml_tree_renderer.cpp" +#include "src/common/tree_renderer/mermaid_tree_renderer.cpp" + #include "src/common/tree_renderer/tree_renderer.cpp" diff --git a/src/duckdb/ub_src_common_types.cpp b/src/duckdb/ub_src_common_types.cpp index 7f181227e..5bcfc4f96 100644 --- a/src/duckdb/ub_src_common_types.cpp +++ b/src/duckdb/ub_src_common_types.cpp @@ -54,3 +54,5 @@ #include "src/common/types/vector_constants.cpp" +#include "src/common/types/geometry.cpp" + diff --git a/src/duckdb/ub_src_common_types_row.cpp b/src/duckdb/ub_src_common_types_row.cpp index 3d4ff32c2..b82384bcc 100644 --- a/src/duckdb/ub_src_common_types_row.cpp +++ b/src/duckdb/ub_src_common_types_row.cpp @@ -1,13 +1,5 @@ -#include "src/common/types/row/block_iterator.cpp" - #include "src/common/types/row/partitioned_tuple_data.cpp" -#include "src/common/types/row/row_data_collection.cpp" - -#include "src/common/types/row/row_data_collection_scanner.cpp" - -#include "src/common/types/row/row_layout.cpp" - #include "src/common/types/row/tuple_data_allocator.cpp" #include "src/common/types/row/tuple_data_collection.cpp" diff --git a/src/duckdb/ub_src_function_cast.cpp b/src/duckdb/ub_src_function_cast.cpp index fcf41bbee..99f3378ca 100644 --- a/src/duckdb/ub_src_function_cast.cpp +++ b/src/duckdb/ub_src_function_cast.cpp @@ -12,6 +12,8 @@ #include "src/function/cast/enum_casts.cpp" +#include "src/function/cast/geo_casts.cpp" + #include "src/function/cast/list_casts.cpp" #include "src/function/cast/map_cast.cpp" diff --git a/src/duckdb/ub_src_function_scalar_geometry.cpp b/src/duckdb/ub_src_function_scalar_geometry.cpp new file mode 100644 index 000000000..4c1e73842 --- /dev/null +++ b/src/duckdb/ub_src_function_scalar_geometry.cpp @@ -0,0 +1,2 @@ +#include "src/function/scalar/geometry/geometry_functions.cpp" + diff --git a/src/duckdb/ub_src_function_scalar_variant.cpp b/src/duckdb/ub_src_function_scalar_variant.cpp index a3276cf42..6fd6a062d 100644 --- a/src/duckdb/ub_src_function_scalar_variant.cpp +++ b/src/duckdb/ub_src_function_scalar_variant.cpp @@ -4,3 +4,5 @@ #include "src/function/scalar/variant/variant_typeof.cpp" +#include "src/function/scalar/variant/variant_normalize.cpp" + diff --git a/src/duckdb/ub_src_function_table_system.cpp b/src/duckdb/ub_src_function_table_system.cpp index afa17b21b..5ca818791 100644 --- a/src/duckdb/ub_src_function_table_system.cpp +++ b/src/duckdb/ub_src_function_table_system.cpp @@ -1,3 +1,5 @@ +#include "src/function/table/system/duckdb_connection_count.cpp" + #include "src/function/table/system/duckdb_approx_database_count.cpp" #include "src/function/table/system/duckdb_columns.cpp" diff --git a/src/duckdb/ub_src_main.cpp b/src/duckdb/ub_src_main.cpp index d3709dc92..f86af90b8 100644 --- a/src/duckdb/ub_src_main.cpp +++ b/src/duckdb/ub_src_main.cpp @@ -56,6 +56,8 @@ #include "src/main/query_result.cpp" +#include "src/main/result_set_manager.cpp" + #include "src/main/stream_query_result.cpp" #include "src/main/valid_checker.cpp" diff --git a/src/duckdb/ub_src_main_capi.cpp b/src/duckdb/ub_src_main_capi.cpp index 30ba6a200..84b7bc21e 
100644 --- a/src/duckdb/ub_src_main_capi.cpp +++ b/src/duckdb/ub_src_main_capi.cpp @@ -8,6 +8,10 @@ #include "src/main/capi/config-c.cpp" +#include "src/main/capi/config_options-c.cpp" + +#include "src/main/capi/copy_function-c.cpp" + #include "src/main/capi/data_chunk-c.cpp" #include "src/main/capi/datetime-c.cpp" diff --git a/src/duckdb/ub_src_optimizer.cpp b/src/duckdb/ub_src_optimizer.cpp index f8238dab4..5c94d24f9 100644 --- a/src/duckdb/ub_src_optimizer.cpp +++ b/src/duckdb/ub_src_optimizer.cpp @@ -8,6 +8,8 @@ #include "src/optimizer/common_aggregate_optimizer.cpp" +#include "src/optimizer/common_subplan_optimizer.cpp" + #include "src/optimizer/compressed_materialization.cpp" #include "src/optimizer/cse_optimizer.cpp" @@ -34,6 +36,8 @@ #include "src/optimizer/late_materialization.cpp" +#include "src/optimizer/late_materialization_helper.cpp" + #include "src/optimizer/optimizer.cpp" #include "src/optimizer/regex_range_filter.cpp" @@ -48,6 +52,8 @@ #include "src/optimizer/topn_optimizer.cpp" +#include "src/optimizer/topn_window_elimination.cpp" + #include "src/optimizer/unnest_rewriter.cpp" #include "src/optimizer/sampling_pushdown.cpp" diff --git a/src/duckdb/ub_src_optimizer_rule.cpp b/src/duckdb/ub_src_optimizer_rule.cpp index 3fa057ede..2a2c56c3c 100644 --- a/src/duckdb/ub_src_optimizer_rule.cpp +++ b/src/duckdb/ub_src_optimizer_rule.cpp @@ -36,3 +36,5 @@ #include "src/optimizer/rule/timestamp_comparison.cpp" +#include "src/optimizer/rule/constant_order_normalization.cpp" + diff --git a/src/duckdb/ub_src_parallel.cpp b/src/duckdb/ub_src_parallel.cpp index eee589714..a95258810 100644 --- a/src/duckdb/ub_src_parallel.cpp +++ b/src/duckdb/ub_src_parallel.cpp @@ -1,3 +1,5 @@ +#include "src/parallel/async_result.cpp" + #include "src/parallel/base_pipeline_event.cpp" #include "src/parallel/meta_pipeline.cpp" diff --git a/src/duckdb/ub_src_parser_query_node.cpp b/src/duckdb/ub_src_parser_query_node.cpp index f0fefe80e..131571749 100644 --- a/src/duckdb/ub_src_parser_query_node.cpp +++ b/src/duckdb/ub_src_parser_query_node.cpp @@ -6,3 +6,5 @@ #include "src/parser/query_node/set_operation_node.cpp" +#include "src/parser/query_node/statement_node.cpp" + diff --git a/src/duckdb/ub_src_planner_binder_query_node.cpp b/src/duckdb/ub_src_planner_binder_query_node.cpp index 2250c80ca..acecbaf63 100644 --- a/src/duckdb/ub_src_planner_binder_query_node.cpp +++ b/src/duckdb/ub_src_planner_binder_query_node.cpp @@ -6,14 +6,12 @@ #include "src/planner/binder/query_node/bind_cte_node.cpp" +#include "src/planner/binder/query_node/bind_statement_node.cpp" + #include "src/planner/binder/query_node/bind_table_macro_node.cpp" #include "src/planner/binder/query_node/plan_query_node.cpp" -#include "src/planner/binder/query_node/plan_recursive_cte_node.cpp" - -#include "src/planner/binder/query_node/plan_cte_node.cpp" - #include "src/planner/binder/query_node/plan_select_node.cpp" #include "src/planner/binder/query_node/plan_setop.cpp" diff --git a/src/duckdb/ub_src_planner_binder_tableref.cpp b/src/duckdb/ub_src_planner_binder_tableref.cpp index b06304d78..641fd88f6 100644 --- a/src/duckdb/ub_src_planner_binder_tableref.cpp +++ b/src/duckdb/ub_src_planner_binder_tableref.cpp @@ -22,23 +22,5 @@ #include "src/planner/binder/tableref/bind_named_parameters.cpp" -#include "src/planner/binder/tableref/plan_basetableref.cpp" - -#include "src/planner/binder/tableref/plan_delimgetref.cpp" - -#include "src/planner/binder/tableref/plan_dummytableref.cpp" - -#include 
"src/planner/binder/tableref/plan_expressionlistref.cpp" - -#include "src/planner/binder/tableref/plan_column_data_ref.cpp" - #include "src/planner/binder/tableref/plan_joinref.cpp" -#include "src/planner/binder/tableref/plan_subqueryref.cpp" - -#include "src/planner/binder/tableref/plan_table_function.cpp" - -#include "src/planner/binder/tableref/plan_cteref.cpp" - -#include "src/planner/binder/tableref/plan_pivotref.cpp" - diff --git a/src/duckdb/ub_src_storage_statistics.cpp b/src/duckdb/ub_src_storage_statistics.cpp index 637a311d7..5f8380c90 100644 --- a/src/duckdb/ub_src_storage_statistics.cpp +++ b/src/duckdb/ub_src_storage_statistics.cpp @@ -16,3 +16,5 @@ #include "src/storage/statistics/struct_stats.cpp" +#include "src/storage/statistics/geometry_stats.cpp" +