From c6218ed41cdf755dd665a9b6611f28d44803aa48 Mon Sep 17 00:00:00 2001 From: wudidapaopao Date: Mon, 22 Sep 2025 03:18:10 +0800 Subject: [PATCH 01/13] feat: add support for querying Python Arrow tables directly --- programs/local/ArrowSchema.cpp | 17 ++ programs/local/ArrowSchema.h | 21 ++ programs/local/ArrowStreamWrapper.cpp | 255 +++++++++++++++++++++++++ programs/local/ArrowStreamWrapper.h | 99 ++++++++++ programs/local/ArrowTableReader.cpp | 204 ++++++++++++++++++++ programs/local/ArrowTableReader.h | 86 +++++++++ programs/local/CMakeLists.txt | 22 +++ programs/local/LocalServer.cpp | 7 +- programs/local/PandasDataFrame.cpp | 11 +- programs/local/PyArrowCacheItem.h | 47 +++++ programs/local/PyArrowTable.cpp | 115 +++++++++++ programs/local/PyArrowTable.h | 27 +++ programs/local/PybindWrapper.h | 10 +- programs/local/PythonImportCache.h | 4 +- programs/local/PythonImportCacheItem.h | 7 + programs/local/PythonSource.cpp | 46 +++-- programs/local/PythonSource.h | 12 +- programs/local/StoragePython.cpp | 16 +- programs/local/TableFunctionPython.cpp | 5 +- programs/local/chdb-arrow.cpp | 185 ++++++++++++++++++ programs/local/chdb-internal.h | 26 +++ programs/local/chdb.cpp | 32 +--- programs/local/chdb.h | 70 +++++++ src/Core/FormatFactorySettings.h | 2 +- tests/test_query_json.py | 7 +- tests/test_query_py.py | 52 ++--- 26 files changed, 1287 insertions(+), 98 deletions(-) create mode 100644 programs/local/ArrowSchema.cpp create mode 100644 programs/local/ArrowSchema.h create mode 100644 programs/local/ArrowStreamWrapper.cpp create mode 100644 programs/local/ArrowStreamWrapper.h create mode 100644 programs/local/ArrowTableReader.cpp create mode 100644 programs/local/ArrowTableReader.h create mode 100644 programs/local/PyArrowCacheItem.h create mode 100644 programs/local/PyArrowTable.cpp create mode 100644 programs/local/PyArrowTable.h create mode 100644 programs/local/chdb-arrow.cpp diff --git a/programs/local/ArrowSchema.cpp b/programs/local/ArrowSchema.cpp new file mode 100644 index 00000000000..d0b3a391e61 --- /dev/null +++ b/programs/local/ArrowSchema.cpp @@ -0,0 +1,17 @@ +#include "ArrowSchema.h" + +#include + +namespace CHDB +{ + +ArrowSchemaWrapper::~ArrowSchemaWrapper() +{ + if (arrow_schema.release != nullptr) + { + arrow_schema.release(&arrow_schema); + chassert(!arrow_schema.release); + } +} + +} // namespace CHDB diff --git a/programs/local/ArrowSchema.h b/programs/local/ArrowSchema.h new file mode 100644 index 00000000000..8b0318381aa --- /dev/null +++ b/programs/local/ArrowSchema.h @@ -0,0 +1,21 @@ +#pragma once + +#include + +namespace CHDB +{ + +class ArrowSchemaWrapper +{ +public: + ArrowSchema arrow_schema; + + ArrowSchemaWrapper() + { + arrow_schema.release = nullptr; + } + + ~ArrowSchemaWrapper(); +}; + +} // namespace CHDB diff --git a/programs/local/ArrowStreamWrapper.cpp b/programs/local/ArrowStreamWrapper.cpp new file mode 100644 index 00000000000..7d298ea1b3d --- /dev/null +++ b/programs/local/ArrowStreamWrapper.cpp @@ -0,0 +1,255 @@ +#include "ArrowStreamWrapper.h" +#include "PyArrowTable.h" +#include "PybindWrapper.h" +#include "PythonImporter.h" + +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int PY_EXCEPTION_OCCURED; +} + +} + +namespace py = pybind11; +using namespace DB; + +namespace CHDB +{ + +/// ArrowSchemaWrapper implementation +ArrowSchemaWrapper::~ArrowSchemaWrapper() +{ + if (arrow_schema.release) + { + arrow_schema.release(&arrow_schema); + } +} + 
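+/// Illustrative usage sketch (commentary only; `process` is a placeholder,
+/// not part of this patch): the wrappers own the Arrow C structures and call
+/// the release callback exactly once, so a typical consumer looks like:
+///
+///     ArrowArrayStreamWrapper stream;             // filled in via _export_to_c
+///     ArrowSchemaWrapper schema;
+///     stream.getSchema(schema);                   // throws on failure
+///     while (auto chunk = stream.getNextChunk())  // nullptr at end of stream
+///         process(chunk->arrow_array);
+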
+ArrowSchemaWrapper::ArrowSchemaWrapper(ArrowSchemaWrapper && other) noexcept + : arrow_schema(other.arrow_schema) +{ + other.arrow_schema.release = nullptr; +} + +ArrowSchemaWrapper & ArrowSchemaWrapper::operator=(ArrowSchemaWrapper && other) noexcept +{ + if (this != &other) + { + if (arrow_schema.release) + { + arrow_schema.release(&arrow_schema); + } + arrow_schema = other.arrow_schema; + other.arrow_schema.release = nullptr; + } + return *this; +} + +/// ArrowArrayWrapper implementation +ArrowArrayWrapper::~ArrowArrayWrapper() +{ + if (arrow_array.release) + { + arrow_array.release(&arrow_array); + } +} + +ArrowArrayWrapper::ArrowArrayWrapper(ArrowArrayWrapper && other) noexcept + : arrow_array(other.arrow_array) +{ + other.arrow_array.release = nullptr; +} + +ArrowArrayWrapper & ArrowArrayWrapper::operator=(ArrowArrayWrapper && other) noexcept +{ + if (this != &other) + { + if (arrow_array.release) + { + arrow_array.release(&arrow_array); + } + arrow_array = other.arrow_array; + other.arrow_array.release = nullptr; + } + return *this; +} + +/// ArrowArrayStreamWrapper implementation +ArrowArrayStreamWrapper::~ArrowArrayStreamWrapper() +{ + if (arrow_array_stream.release) + { + arrow_array_stream.release(&arrow_array_stream); + } +} + +ArrowArrayStreamWrapper::ArrowArrayStreamWrapper(ArrowArrayStreamWrapper&& other) noexcept + : arrow_array_stream(other.arrow_array_stream) +{ + other.arrow_array_stream.release = nullptr; +} + +ArrowArrayStreamWrapper & ArrowArrayStreamWrapper::operator=(ArrowArrayStreamWrapper && other) noexcept +{ + if (this != &other) + { + if (arrow_array_stream.release) + { + arrow_array_stream.release(&arrow_array_stream); + } + arrow_array_stream = other.arrow_array_stream; + other.arrow_array_stream.release = nullptr; + } + return *this; +} + +void ArrowArrayStreamWrapper::getSchema(ArrowSchemaWrapper& schema) +{ + if (!isValid()) + { + throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, "ArrowArrayStream is not valid"); + } + + if (arrow_array_stream.get_schema(&arrow_array_stream, &schema.arrow_schema) != 0) + { + throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + "Failed to get schema from ArrowArrayStream: {}", getError()); + } + + if (!schema.arrow_schema.release) + { + throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, "Released schema returned from ArrowArrayStream"); + } +} + +std::unique_ptr ArrowArrayStreamWrapper::getNextChunk() +{ + chassert(isValid()); + + auto chunk = std::make_unique(); + + /// Get next non-empty chunk, skipping empty ones + do + { + chunk->reset(); + if (arrow_array_stream.get_next(&arrow_array_stream, &chunk->arrow_array) != 0) + { + throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + "Failed to get next chunk from ArrowArrayStream: {}", getError()); + } + + /// Check if we've reached the end of the stream + if (!chunk->arrow_array.release) + { + return nullptr; + } + } + while (chunk->arrow_array.length == 0); + + return chunk; +} + +const char* ArrowArrayStreamWrapper::getError() +{ + if (!isValid()) + { + return "ArrowArrayStream is not valid"; + } + + return arrow_array_stream.get_last_error(&arrow_array_stream); +} + +std::unique_ptr PyArrowStreamFactory::createFromPyObject( + py::object & py_obj, + const Names & column_names) +{ + py::gil_scoped_acquire acquire; + + try + { + auto arrow_object_type = PyArrowTable::getArrowType(py_obj); + + switch (arrow_object_type) + { + case PyArrowObjectType::Table: + return createFromTable(py_obj, column_names); + default: + throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + 
"Unsupported PyArrow object type: {}", arrow_object_type); + } + } + catch (const py::error_already_set & e) + { + throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + "Failed to convert PyArrow object to arrow array stream: {}", e.what()); + } +} + +std::unique_ptr PyArrowStreamFactory::createFromTable( + py::object & table, + const Names & column_names) +{ + chassert(py::gil_check()); + + py::handle table_handle(table); + auto & import_cache = PythonImporter::ImportCache(); + auto arrow_dataset = import_cache.pyarrow.dataset().attr("dataset"); + + auto dataset = arrow_dataset(table_handle); + py::object arrow_scanner = dataset.attr("__class__").attr("scanner"); + + py::dict kwargs; + if (!column_names.empty()) { + ArrowSchemaWrapper schema; + auto obj_schema = table_handle.attr("schema"); + auto export_to_c = obj_schema.attr("_export_to_c"); + export_to_c(reinterpret_cast(&schema.arrow_schema)); + + /// Get available column names from schema + std::unordered_set available_columns; + if (schema.arrow_schema.n_children > 0 && schema.arrow_schema.children) + { + for (int64_t i = 0; i < schema.arrow_schema.n_children; ++i) + { + if (schema.arrow_schema.children[i] && schema.arrow_schema.children[i]->name) + { + available_columns.insert(schema.arrow_schema.children[i]->name); + } + } + } + + /// Only add column names that exist in the schema + py::list projection_list; + for (const auto & name : column_names) + { + if (available_columns.contains(name)) + { + projection_list.append(name); + } + } + + /// Only set columns if we have valid projections + if (projection_list.size() > 0) + { + kwargs["columns"] = projection_list; + } + } + + auto scanner = arrow_scanner(dataset, **kwargs); + + auto record_batches = scanner.attr("to_reader")(); + auto res = std::make_unique(); + auto export_to_c = record_batches.attr("_export_to_c"); + export_to_c(reinterpret_cast(&res->arrow_array_stream)); + return res; +} + +} // namespace CHDB diff --git a/programs/local/ArrowStreamWrapper.h b/programs/local/ArrowStreamWrapper.h new file mode 100644 index 00000000000..51a646c497b --- /dev/null +++ b/programs/local/ArrowStreamWrapper.h @@ -0,0 +1,99 @@ +#pragma once + +#include +#include +#include +#include + +namespace CHDB +{ + +/// Wrapper for Arrow C Data Interface structures with RAII resource management +class ArrowSchemaWrapper +{ +public: + ArrowSchema arrow_schema; + + ArrowSchemaWrapper() { + arrow_schema.release = nullptr; + } + + ~ArrowSchemaWrapper(); + + /// Non-copyable but moveable + ArrowSchemaWrapper(const ArrowSchemaWrapper &) = delete; + ArrowSchemaWrapper & operator=(const ArrowSchemaWrapper &) = delete; + ArrowSchemaWrapper(ArrowSchemaWrapper && other) noexcept; + ArrowSchemaWrapper & operator=(ArrowSchemaWrapper && other) noexcept; +}; + +class ArrowArrayWrapper +{ +public: + ArrowArray arrow_array; + + ArrowArrayWrapper() + { + reset(); + } + + ~ArrowArrayWrapper(); + + void reset() + { + arrow_array.length = 0; + arrow_array.release = nullptr; + } + + /// Non-copyable but moveable + ArrowArrayWrapper(const ArrowArrayWrapper &) = delete; + ArrowArrayWrapper & operator=(const ArrowArrayWrapper &) = delete; + ArrowArrayWrapper(ArrowArrayWrapper && other) noexcept; + ArrowArrayWrapper & operator=(ArrowArrayWrapper && other) noexcept; +}; + +class ArrowArrayStreamWrapper +{ +public: + ArrowArrayStream arrow_array_stream; + + ArrowArrayStreamWrapper() { + arrow_array_stream.release = nullptr; + } + + ~ArrowArrayStreamWrapper(); + + // Non-copyable but moveable + ArrowArrayStreamWrapper(const 
ArrowArrayStreamWrapper&) = delete; + ArrowArrayStreamWrapper& operator=(const ArrowArrayStreamWrapper&) = delete; + ArrowArrayStreamWrapper(ArrowArrayStreamWrapper&& other) noexcept; + ArrowArrayStreamWrapper& operator=(ArrowArrayStreamWrapper&& other) noexcept; + + /// Get schema from the stream + void getSchema(ArrowSchemaWrapper& schema); + + /// Get next chunk from the stream + std::unique_ptr getNextChunk(); + + /// Get last error message + const char* getError(); + + /// Check if stream is valid + bool isValid() const { return arrow_array_stream.release != nullptr; } +}; + +/// Factory class for creating ArrowArrayStream from Python objects +class PyArrowStreamFactory +{ +public: + static std::unique_ptr createFromPyObject( + pybind11::object & py_obj, + const DB::Names & column_names); + +private: + static std::unique_ptr createFromTable( + pybind11::object & table, + const DB::Names & column_names); +}; + +} // namespace CHDB diff --git a/programs/local/ArrowTableReader.cpp b/programs/local/ArrowTableReader.cpp new file mode 100644 index 00000000000..822530706c3 --- /dev/null +++ b/programs/local/ArrowTableReader.cpp @@ -0,0 +1,204 @@ +#include "ArrowTableReader.h" + +#include +#include +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int PY_EXCEPTION_OCCURED; +} + +} + +namespace py = pybind11; +using namespace DB; + +namespace CHDB +{ + +ArrowTableReader::ArrowTableReader( + py::object & data_source_, + const DB::Block & sample_block_, + const DB::FormatSettings & format_settings_, + size_t num_streams_, + size_t max_block_size_) + : sample_block(sample_block_), + format_settings(format_settings_), + num_streams(num_streams_), + max_block_size(max_block_size_), + scan_states(num_streams_) +{ + initializeStream(data_source_); +} + +void ArrowTableReader::initializeStream(py::object & data_source_) +{ + try + { + /// Create Arrow stream from Python object + arrow_stream = PyArrowStreamFactory::createFromPyObject(data_source_, sample_block.getNames()); + + if (!arrow_stream || !arrow_stream->isValid()) + { + throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + "Failed to create valid ArrowArrayStream from Python object"); + } + } + catch (const py::error_already_set & e) + { + throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + "Failed to initialize Arrow stream from Python object: {}", e.what()); + } + + /// Get schema from stream + arrow_stream->getSchema(schema); + auto arrow_schema_result = arrow::ImportSchema(&schema.arrow_schema); + if (!arrow_schema_result.ok()) + { + throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + "Failed to import Arrow schema during initialization: {}", arrow_schema_result.status().message()); + } + cached_arrow_schema = arrow_schema_result.ValueOrDie(); +} + +Chunk ArrowTableReader::readNextChunk(size_t stream_index) +{ + if (stream_index >= num_streams) + { + throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + "Stream index {} is out of range [0, {})", stream_index, num_streams); + } + + auto & state = scan_states[stream_index]; + + if (state.exhausted) + { + return {}; + } + + try + { + /// If we don't have a current array or it's exhausted, get the next one + if (!state.current_array || state.current_offset >= static_cast(state.current_array->arrow_array.length)) + { + auto next_array = getNextArrowArray(); + if (!next_array) + { + state.exhausted = true; + return {}; + } + state.current_array = std::move(next_array); + state.current_offset = 0; + state.cached_record_batch.reset(); + } + + /// 
Calculate how many rows to read from current array + size_t available_rows = static_cast(state.current_array->arrow_array.length) - state.current_offset; + size_t rows_to_read = std::min(max_block_size, available_rows); + + /// Convert the slice to chunk + auto chunk = convertArrowArrayToChunk(*state.current_array, state.current_offset, rows_to_read, stream_index); + + /// Update offset + state.current_offset += rows_to_read; + + return chunk; + } + catch (const Exception &) + { + state.exhausted = true; + throw; + } +} + +std::unique_ptr ArrowTableReader::getNextArrowArray() +{ + std::lock_guard lock(stream_mutex); + + if (global_stream_exhausted || !arrow_stream || !arrow_stream->isValid()) + { + return nullptr; + } + + try + { + auto arrow_array = arrow_stream->getNextChunk(); + + if (!arrow_array || arrow_array->arrow_array.length == 0) + { + global_stream_exhausted = true; + return nullptr; + } + + return arrow_array; + } + catch (const Exception &) + { + global_stream_exhausted = true; + throw; + } +} + +Chunk ArrowTableReader::convertArrowArrayToChunk(const ArrowArrayWrapper & arrow_array_wrapper, size_t offset, size_t count, size_t stream_index) +{ + chassert(arrow_array_wrapper.arrow_array.length && count && offset < arrow_array_wrapper.arrow_array.length); + chassert(count <= arrow_array_wrapper.arrow_array.length - offset); + chassert(stream_index < num_streams); + + auto & state = scan_states[stream_index]; + std::shared_ptr record_batch; + + /// Check if we have a cached RecordBatch for this ArrowArray + if (!state.cached_record_batch) + { + /// Import the full ArrowArray to RecordBatch and cache it + ArrowArray array_copy = arrow_array_wrapper.arrow_array; + + /// Set a dummy release function to prevent Arrow from freeing the underlying data + static auto dummy_release = [](ArrowArray* array) + { + // No-op: ArrowArrayWrapper will handle the actual cleanup + // But we must set release to nullptr to follow Arrow C ABI convention + array->release = nullptr; + }; + array_copy.release = dummy_release; + + /// Import the full Arrow array to Arrow RecordBatch + auto arrow_batch_result = arrow::ImportRecordBatch(&array_copy, cached_arrow_schema); + if (!arrow_batch_result.ok()) + { + throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + "Failed to import Arrow RecordBatch: {}", arrow_batch_result.status().message()); + } + + state.cached_record_batch = arrow_batch_result.ValueOrDie(); + } + + /// Use the cached RecordBatch and slice it + record_batch = state.cached_record_batch; + auto sliced_batch = record_batch->Slice(offset, count); + auto arrow_table = arrow::Table::FromRecordBatches({sliced_batch}).ValueOrDie(); + + /// Use ArrowColumnToCHColumn to convert the batch + ArrowColumnToCHColumn converter( + sample_block, + "Arrow", + format_settings.arrow.allow_missing_columns, + format_settings.null_as_default, + format_settings.date_time_overflow_behavior, + format_settings.parquet.allow_geoparquet_parser, + format_settings.arrow.case_insensitive_column_matching, + false + ); + + return converter.arrowTableToCHChunk(arrow_table, sliced_batch->num_rows(), nullptr); +} + +} // namespace CHDB diff --git a/programs/local/ArrowTableReader.h b/programs/local/ArrowTableReader.h new file mode 100644 index 00000000000..5cba71745d5 --- /dev/null +++ b/programs/local/ArrowTableReader.h @@ -0,0 +1,86 @@ +#pragma once + +#include "ArrowStreamWrapper.h" + +#include +#include +#include +#include +#include + +namespace CHDB +{ + +/// Scan state for each stream +struct ArrowScanState +{ + /// 
Current Arrow array being processed + std::unique_ptr current_array; + /// Current offset within the array + size_t current_offset = 0; + /// Whether this stream is exhausted + bool exhausted = false; + /// Cached imported RecordBatch to avoid repeated imports + std::shared_ptr cached_record_batch; + + void reset() + { + current_array.reset(); + current_offset = 0; + exhausted = false; + cached_record_batch.reset(); + } +}; + +class ArrowTableReader; +using ArrowTableReaderPtr = std::shared_ptr; + +class ArrowTableReader +{ +public: + ArrowTableReader( + pybind11::object & data_source_, + const DB::Block & sample_block_, + const DB::FormatSettings & format_settings_, + size_t num_streams_, + size_t max_block_size_); + + ~ArrowTableReader() = default; + + /// Read next chunk from the specified stream + DB::Chunk readNextChunk(size_t stream_index); + +private: + /// Initialize the Arrow stream from Python object + void initializeStream(pybind11::object & data_source_); + + /// Convert Arrow array slice to ClickHouse chunk + DB::Chunk convertArrowArrayToChunk(const ArrowArrayWrapper & arrow_array, size_t offset, size_t count, size_t stream_index); + + /// Get next Arrow array from stream + std::unique_ptr getNextArrowArray(); + + DB::Block sample_block; + DB::FormatSettings format_settings; + std::unique_ptr arrow_stream; + ArrowSchemaWrapper schema; + + /// Cached Arrow schema to avoid repeated imports + std::shared_ptr cached_arrow_schema; + + /// Multi-stream scanning parameters + size_t num_streams; + size_t max_block_size; + + /// Scan states for each stream + std::vector scan_states; + + /// Global stream state + bool global_stream_exhausted = false; + size_t total_rows_hint = 0; + + /// Mutex for thread-safe access to arrow_stream + mutable std::mutex stream_mutex; +}; + +} // namespace CHDB diff --git a/programs/local/CMakeLists.txt b/programs/local/CMakeLists.txt index 83095fe2dd0..9e32398c67f 100644 --- a/programs/local/CMakeLists.txt +++ b/programs/local/CMakeLists.txt @@ -3,6 +3,21 @@ set (CLICKHOUSE_LOCAL_SOURCES LocalServer.cpp ) +# Add ArrowStream table function sources when not using Python +if (NOT USE_PYTHON) + set (CHDB_ARROW_SOURCES + chdb-arrow.cpp + ArrowStreamRegistry.h + ArrowStreamSource.cpp + ArrowStreamSource.h + StorageArrowStream.cpp + StorageArrowStream.h + TableFunctionArrowStream.cpp + TableFunctionArrowStream.h + ) + set (CLICKHOUSE_LOCAL_SOURCES ${CLICKHOUSE_LOCAL_SOURCES} ${CHDB_ARROW_SOURCES}) +endif() + # Add force function references only for static library builds if (CHDB_STATIC_LIBRARY_BUILD) list(APPEND CLICKHOUSE_LOCAL_SOURCES ForceFunctionReferences.cpp) @@ -12,6 +27,9 @@ endif() if (USE_PYTHON) set (CHDB_SOURCES chdb.cpp + ArrowSchema.cpp + ArrowStreamWrapper.cpp + ArrowTableReader.cpp FormatHelper.cpp ListScan.cpp LocalChdb.cpp @@ -20,6 +38,7 @@ if (USE_PYTHON) PandasAnalyzer.cpp PandasDataFrame.cpp PandasScan.cpp + PyArrowTable.cpp PybindWrapper.cpp PythonConversion.cpp PythonDict.cpp @@ -117,6 +136,9 @@ endif() if (TARGET ch_contrib::utf8proc) target_link_libraries(clickhouse-local-lib PRIVATE ch_contrib::utf8proc) endif() +if (TARGET ch_contrib::arrow) + target_link_libraries(clickhouse-local-lib PRIVATE ch_contrib::arrow) +endif() if (TARGET ch_contrib::pybind11_stubs) target_link_libraries(clickhouse-local-lib PRIVATE ch_contrib::pybind11_stubs) target_compile_definitions(clickhouse-local-lib PRIVATE Py_LIMITED_API=0x03080000) diff --git a/programs/local/LocalServer.cpp b/programs/local/LocalServer.cpp index edf3f67ad20..319c83df28a 100644 --- 
a/programs/local/LocalServer.cpp +++ b/programs/local/LocalServer.cpp @@ -5,6 +5,8 @@ #include "TableFunctionPython.h" #include #include +#else +#include "TableFunctionArrowStream.h" #endif #include @@ -645,9 +647,12 @@ try registerAggregateFunctions(); registerTableFunctions(); -#if USE_PYTHON + auto & table_function_factory = TableFunctionFactory::instance(); +#if USE_PYTHON registerTableFunctionPython(table_function_factory); +#else + registerTableFunctionArrowStream(table_function_factory); #endif registerDatabases(); diff --git a/programs/local/PandasDataFrame.cpp b/programs/local/PandasDataFrame.cpp index e6841ce3937..c304dabbf0a 100644 --- a/programs/local/PandasDataFrame.cpp +++ b/programs/local/PandasDataFrame.cpp @@ -22,13 +22,6 @@ using namespace DB; namespace CHDB { -template -static bool ModuleIsLoaded() -{ - auto dict = pybind11::module_::import("sys").attr("modules"); - return dict.contains(py::str(T::Name)); -} - struct PandasBindColumn { public: PandasBindColumn(py::handle name, py::handle type, py::object column) @@ -92,6 +85,8 @@ static DataTypePtr inferDataTypeFromPandasColumn(PandasBindColumn & column, Cont ColumnsDescription PandasDataFrame::getActualTableStructure(const py::object & object, ContextPtr & context) { + chassert(py::gil_check()); + NamesAndTypesList names_and_types; PandasDataFrameBind df(object); @@ -116,6 +111,8 @@ ColumnsDescription PandasDataFrame::getActualTableStructure(const py::object & o bool PandasDataFrame::isPandasDataframe(const py::object & object) { + chassert(py::gil_check()); + if (!ModuleIsLoaded()) return false; diff --git a/programs/local/PyArrowCacheItem.h b/programs/local/PyArrowCacheItem.h new file mode 100644 index 00000000000..494cf149870 --- /dev/null +++ b/programs/local/PyArrowCacheItem.h @@ -0,0 +1,47 @@ +#pragma once + +#include "PythonImportCacheItem.h" + +namespace CHDB +{ + +struct PyarrowIpcCacheItem : public PythonImportCacheItem +{ + explicit PyarrowIpcCacheItem(PythonImportCacheItem * parent) + : PythonImportCacheItem("ipc", parent), message_reader("MessageReader", this) + {} + ~PyarrowIpcCacheItem() override = default; + + PythonImportCacheItem message_reader; +}; + +struct PyarrowDatasetCacheItem : public PythonImportCacheItem +{ + static constexpr const char * Name = "pyarrow.dataset"; + + PyarrowDatasetCacheItem() + : PythonImportCacheItem("pyarrow.dataset"), scanner("Scanner", this), dataset("Dataset", this) + {} + ~PyarrowDatasetCacheItem() override = default; + + PythonImportCacheItem scanner; + PythonImportCacheItem dataset; +}; + +struct PyarrowCacheItem : public PythonImportCacheItem +{ + static constexpr const char * Name = "pyarrow"; + + PyarrowCacheItem() + : PythonImportCacheItem("pyarrow"), dataset(), table("Table", this), + record_batch_reader("RecordBatchReader", this), ipc(this) + {} + ~PyarrowCacheItem() override = default; + + PyarrowDatasetCacheItem dataset; + PythonImportCacheItem table; + PythonImportCacheItem record_batch_reader; + PyarrowIpcCacheItem ipc; +}; + +} // namespace CHDB diff --git a/programs/local/PyArrowTable.cpp b/programs/local/PyArrowTable.cpp new file mode 100644 index 00000000000..83bdc99b761 --- /dev/null +++ b/programs/local/PyArrowTable.cpp @@ -0,0 +1,115 @@ +#include "PyArrowTable.h" +#include "ArrowSchema.h" +#include "PyArrowCacheItem.h" +#include "PythonImporter.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int BAD_ARGUMENTS; +extern const int 
PY_EXCEPTION_OCCURED; +} + +} + +using namespace DB; + +namespace CHDB +{ + +static void convertArrowSchema( + ArrowSchemaWrapper & schema, + NamesAndTypesList & names_and_types, + ContextPtr & context) +{ + if (!schema.arrow_schema.release) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "ArrowSchema is already released"); + } + + /// Import ArrowSchema to arrow::Schema + auto arrow_schema_result = arrow::ImportSchema(&schema.arrow_schema); + if (!arrow_schema_result.ok()) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "Failed to import Arrow schema: {}", arrow_schema_result.status().message()); + } + + const auto & arrow_schema = arrow_schema_result.ValueOrDie(); + + const auto format_settings = getFormatSettings(context); + + /// Convert Arrow schema to ClickHouse header + auto block = ArrowColumnToCHColumn::arrowSchemaToCHHeader( + *arrow_schema, + nullptr, + "Arrow", + format_settings.arrow.skip_columns_with_unsupported_types_in_schema_inference, + format_settings.schema_inference_make_columns_nullable != 0, + false, + format_settings.parquet.allow_geoparquet_parser); + + for (const auto & column : block) + { + names_and_types.emplace_back(column.name, column.type); + } +} + +PyArrowObjectType PyArrowTable::getArrowType(const py::object & obj) +{ + chassert(py::gil_check()); + + if (ModuleIsLoaded()) + { + auto & import_cache = PythonImporter::ImportCache(); + auto table_class = import_cache.pyarrow.table(); + + if (py::isinstance(obj, table_class)) + return PyArrowObjectType::Table; + } + + return PyArrowObjectType::Invalid; +} + +bool PyArrowTable::isPyArrowTable(const py::object & object) +{ + try + { + return getArrowType(object) == PyArrowObjectType::Table; + } + catch (const py::error_already_set &) + { + return false; + } +} + +ColumnsDescription PyArrowTable::getActualTableStructure(const py::object & object, ContextPtr & context) +{ + chassert(py::gil_check()); + chassert(isPyArrowTable(object)); + + NamesAndTypesList names_and_types; + + auto obj_schema = object.attr("schema"); + auto export_to_c = obj_schema.attr("_export_to_c"); + ArrowSchemaWrapper schema; + export_to_c(reinterpret_cast(&schema.arrow_schema)); + + convertArrowSchema(schema, names_and_types, context); + + return ColumnsDescription(names_and_types); +} + +} // namespace CHDB diff --git a/programs/local/PyArrowTable.h b/programs/local/PyArrowTable.h new file mode 100644 index 00000000000..d4fccb634ca --- /dev/null +++ b/programs/local/PyArrowTable.h @@ -0,0 +1,27 @@ +#pragma once + +#include "PybindWrapper.h" + +#include +#include + +namespace CHDB +{ + +enum class PyArrowObjectType +{ + Invalid, + Table +}; + +class PyArrowTable +{ +public: + static DB::ColumnsDescription getActualTableStructure(const py::object & object, DB::ContextPtr & context); + + static bool isPyArrowTable(const py::object & object); + + static PyArrowObjectType getArrowType(const py::object & object); +}; + +} // namespace CHDB diff --git a/programs/local/PybindWrapper.h b/programs/local/PybindWrapper.h index d653ab1ea73..17e630f10eb 100644 --- a/programs/local/PybindWrapper.h +++ b/programs/local/PybindWrapper.h @@ -4,15 +4,19 @@ #include #include -namespace pybind11 { +namespace pybind11 +{ +bool gil_check(); void gil_assert(); } -namespace CHDB { +namespace CHDB +{ -namespace py { +namespace py +{ using namespace pybind11; diff --git a/programs/local/PythonImportCache.h b/programs/local/PythonImportCache.h index 516ed057875..6bdf5cf7c8f 100644 --- a/programs/local/PythonImportCache.h +++ b/programs/local/PythonImportCache.h 
@@ -3,6 +3,7 @@ #include "DatetimeCacheItem.h" #include "DecimalCacheItem.h" #include "PandasCacheItem.h" +#include "PyArrowCacheItem.h" #include "PythonImportCacheItem.h" #include @@ -18,12 +19,11 @@ struct PythonImportCache { ~PythonImportCache(); -public: PandasCacheItem pandas; + PyarrowCacheItem pyarrow; DatetimeCacheItem datetime; DecimalCacheItem decimal; -public: py::handle AddCache(py::object item); private: diff --git a/programs/local/PythonImportCacheItem.h b/programs/local/PythonImportCacheItem.h index bea49e5e02c..5908bbf57ac 100644 --- a/programs/local/PythonImportCacheItem.h +++ b/programs/local/PythonImportCacheItem.h @@ -6,6 +6,13 @@ namespace CHDB { +template +static bool ModuleIsLoaded() +{ + auto dict = pybind11::module_::import("sys").attr("modules"); + return dict.contains(py::str(T::Name)); +} + struct PythonImportCache; struct PythonImportCacheItem { diff --git a/programs/local/PythonSource.cpp b/programs/local/PythonSource.cpp index cfecdbaab1d..d2a1435eb01 100644 --- a/programs/local/PythonSource.cpp +++ b/programs/local/PythonSource.cpp @@ -5,9 +5,7 @@ #include "StoragePython.h" #include -#include #include -#include #include #include #include @@ -20,7 +18,6 @@ #include #include #include -#include "PythonUtils.h" #include #include #include @@ -37,6 +34,8 @@ #include #include +using namespace CHDB; + namespace DB { @@ -58,7 +57,8 @@ PythonSource::PythonSource( size_t max_block_size_, size_t stream_index, size_t num_streams, - const FormatSettings & format_settings_) + const FormatSettings & format_settings_, + ArrowTableReaderPtr arrow_table_reader_) : ISource(sample_block_.cloneEmpty()) , data_source(data_source_) , isInheritsFromPyReader(isInheritsFromPyReader_) @@ -70,6 +70,7 @@ PythonSource::PythonSource( , num_streams(num_streams) , cursor(0) , format_settings(format_settings_) + , arrow_table_reader(arrow_table_reader_) { } @@ -438,14 +439,7 @@ Chunk PythonSource::scanDataToChunk() if (names.size() != columns.size()) throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, "Column cache size mismatch"); - auto rows_per_stream = data_source_row_count / num_streams; - auto start = stream_index * rows_per_stream; - auto end = (stream_index + 1) * rows_per_stream; - if (stream_index == num_streams - 1) - end = data_source_row_count; - if (cursor == 0) - cursor = start; - auto count = std::min(max_block_size, end - cursor); + auto [offset, count] = calculateOffsetAndCount(); if (count == 0) return {}; LOG_DEBUG(logger, "Stream index {} Reading {} rows from {}", stream_index, count, cursor); @@ -554,7 +548,6 @@ Chunk PythonSource::scanDataToChunk() return Chunk(std::move(columns), count); } - Chunk PythonSource::generate() { size_t num_rows = 0; @@ -564,6 +557,12 @@ Chunk PythonSource::generate() try { + if (arrow_table_reader) + { + auto chunk = arrow_table_reader->readNextChunk(stream_index); + return chunk; + } + if (isInheritsFromPyReader) { PyObjectVecPtr data; @@ -574,10 +573,8 @@ Chunk PythonSource::generate() return std::move(genChunk(num_rows, data)); } - else - { - return std::move(scanDataToChunk()); - } + + return std::move(scanDataToChunk()); } catch (const Exception & e) { @@ -596,4 +593,19 @@ Chunk PythonSource::generate() throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, "Python data handling unknown exception"); } } + +std::pair PythonSource::calculateOffsetAndCount() +{ + auto rows_per_stream = data_source_row_count / num_streams; + auto start = stream_index * rows_per_stream; + auto end = (stream_index + 1) * rows_per_stream; + if (stream_index == 
num_streams - 1) + end = data_source_row_count; + if (cursor == 0) + cursor = start; + auto count = std::min(max_block_size, end - cursor); + + return std::make_pair(cursor, count); } + +} \ No newline at end of file diff --git a/programs/local/PythonSource.h b/programs/local/PythonSource.h index 610fe823679..1ef138c68a6 100644 --- a/programs/local/PythonSource.h +++ b/programs/local/PythonSource.h @@ -1,16 +1,15 @@ #pragma once +#include "ArrowTableReader.h" +#include "PythonUtils.h" #include "config.h" -#include #include - #include #include #include #include #include -#include "PythonUtils.h" namespace DB { @@ -32,7 +31,8 @@ class PythonSource : public ISource size_t max_block_size_, size_t stream_index, size_t num_streams, - const FormatSettings & format_settings_); + const FormatSettings & format_settings_, + CHDB::ArrowTableReaderPtr arrow_table_reader_ = nullptr); ~PythonSource() override = default; @@ -59,6 +59,8 @@ class PythonSource : public ISource const FormatSettings format_settings; + CHDB::ArrowTableReaderPtr arrow_table_reader; + Chunk genChunk(size_t & num_rows, PyObjectVecPtr data); PyObjectVecPtr scanData(const py::object & data, const std::vector & col_names, size_t & cursor, size_t count); @@ -78,8 +80,8 @@ class PythonSource : public ISource void insert_string_from_array(py::handle obj, const MutableColumnPtr & column); - void prepareColumnCache(Names & names, Columns & columns); Chunk scanDataToChunk(); void destory(PyObjectVecPtr & data); + std::pair calculateOffsetAndCount(); }; } diff --git a/programs/local/StoragePython.cpp b/programs/local/StoragePython.cpp index c97ad161117..a60f52dc555 100644 --- a/programs/local/StoragePython.cpp +++ b/programs/local/StoragePython.cpp @@ -2,6 +2,7 @@ #include "FormatHelper.h" #include "PybindWrapper.h" #include "PythonSource.h" +#include "PyArrowTable.h" #include #include @@ -31,6 +32,8 @@ #include +using namespace CHDB; + namespace DB { @@ -80,11 +83,20 @@ Pipe StoragePython::read( prepareColumnCache(column_names, sample_block.getColumns(), sample_block); + ArrowTableReaderPtr arrow_table_reader; + { + py::gil_scoped_acquire acquire; + if (PyArrowTable::isPyArrowTable(data_source)) + { + arrow_table_reader = std::make_shared(data_source, sample_block, + format_settings, num_streams, max_block_size); + } + } + Pipes pipes; - // num_streams = 32; // for chdb testing for (size_t stream = 0; stream < num_streams; ++stream) pipes.emplace_back(std::make_shared( - data_source, false, sample_block, column_cache, data_source_row_count, max_block_size, stream, num_streams, format_settings)); + data_source, false, sample_block, column_cache, data_source_row_count, max_block_size, stream, num_streams, format_settings, arrow_table_reader)); return Pipe::unitePipes(std::move(pipes)); } diff --git a/programs/local/TableFunctionPython.cpp b/programs/local/TableFunctionPython.cpp index 042a41fcaa8..5967bebf381 100644 --- a/programs/local/TableFunctionPython.cpp +++ b/programs/local/TableFunctionPython.cpp @@ -1,5 +1,6 @@ #include "StoragePython.h" #include "PandasDataFrame.h" +#include "PyArrowTable.h" #include "PythonDict.h" #include "PythonReader.h" #include "PythonTableCache.h" @@ -39,7 +40,6 @@ extern const int UNKNOWN_FORMAT; void TableFunctionPython::parseArguments(const ASTPtr & ast_function, ContextPtr context) { - // py::gil_scoped_acquire acquire; const auto & func_args = ast_function->as(); if (!func_args.arguments) @@ -115,6 +115,9 @@ ColumnsDescription TableFunctionPython::getActualTableStructure(ContextPtr conte if 
(PandasDataFrame::isPandasDataframe(reader)) return PandasDataFrame::getActualTableStructure(reader, context); + if (PyArrowTable::isPyArrowTable(reader)) + return PyArrowTable::getActualTableStructure(reader, context); + if (PythonDict::isPythonDict(reader)) return PythonDict::getActualTableStructure(reader, context); diff --git a/programs/local/chdb-arrow.cpp b/programs/local/chdb-arrow.cpp new file mode 100644 index 00000000000..d8d07d7e764 --- /dev/null +++ b/programs/local/chdb-arrow.cpp @@ -0,0 +1,185 @@ +#include "chdb.h" +#include "chdb-internal.h" +#include "ArrowStreamRegistry.h" + +#include +#include +#include + +namespace CHDB +{ + +struct PrivateData { + ArrowSchema * schema; + ArrowArray * array; + bool done = false; +}; + +void EmptySchemaRelease(ArrowSchema * schema) +{ + schema->release = nullptr; +} + +void EmptyArrayRelease(ArrowArray * array) +{ + array->release = nullptr; +} + +void EmptyStreamRelease(ArrowArrayStream * stream) +{ + stream->release = nullptr; +} + +int GetSchema(struct ArrowArrayStream * stream, struct ArrowSchema * out) +{ + auto * private_data = static_cast((stream->private_data)); + if (private_data->schema == nullptr) + return CHDBError; + + *out = *private_data->schema; + out->release = EmptySchemaRelease; + return CHDBSuccess; +} + +int GetNext(struct ArrowArrayStream * stream, struct ArrowArray * out) +{ + auto * private_data = static_cast((stream->private_data)); + *out = *private_data->array; + if (private_data->done) + { + out->release = nullptr; + } + else + { + out->release = EmptyArrayRelease; + } + + private_data->done = true; + return CHDBSuccess; +} + +const char * GetLastError(struct ArrowArrayStream * /*stream*/) +{ + return nullptr; +} + +void Release(struct ArrowArrayStream * stream) +{ + if (stream->private_data != nullptr) + delete reinterpret_cast(stream->private_data); + + stream->private_data = nullptr; + stream->release = nullptr; +} + +} // namespace CHDB + +chdb_state chdb_arrow_scan( + chdb_connection conn, const char * table_name, + chdb_arrow_stream arrow_stream) +{ + ChdbDestructorGuard guard; + + std::shared_lock global_lock(global_connection_mutex); + + if (!table_name || !arrow_stream) + return CHDBError; + + auto * connection = reinterpret_cast(conn); + if (!checkConnectionValidity(connection)) + return CHDBError; + + auto * stream = reinterpret_cast(arrow_stream); + + ArrowSchema schema; + if (stream->get_schema(stream, &schema) == CHDBError) + return CHDBError; + + using ReleaseFunction = void (*)(ArrowSchema *); + std::vector releases(static_cast(schema.n_children)); + for (size_t i = 0; i < static_cast(schema.n_children); i++) + { + auto * child = schema.children[i]; + releases[i] = child->release; + child->release = CHDB::EmptySchemaRelease; + } + + try + { + bool success = DB::ArrowStreamRegistry::instance().registerArrowStream(String(table_name), stream); + return success ? CHDBSuccess : CHDBError; + } + catch (...) 
+ { + return CHDBError; + } + + for (size_t i = 0; i < static_cast(schema.n_children); ++i) + { + schema.children[i]->release = releases[i]; + } + + return CHDBSuccess; +} + +chdb_state chdb_arrow_array_scan( + chdb_connection conn, const char * table_name, + chdb_arrow_schema arrow_schema, chdb_arrow_array arrow_array, + chdb_arrow_stream * out_stream) +{ + auto * private_data = new CHDB::PrivateData; + private_data->schema = reinterpret_cast(arrow_schema); + private_data->array = reinterpret_cast(arrow_array); + private_data->done = false; + + auto * stream = new ArrowArrayStream(); + *out_stream = reinterpret_cast(stream); + stream->get_schema = CHDB::GetSchema; + stream->get_next = CHDB::GetNext; + stream->get_last_error = CHDB::GetLastError; + stream->release = CHDB::Release; + stream->private_data = private_data; + + return chdb_arrow_scan(conn, table_name, reinterpret_cast(stream)); +} + +void chdb_destroy_arrow_stream(chdb_arrow_stream * arrow_stream) +{ + if (!arrow_stream) + return; + + auto * stream = reinterpret_cast(*arrow_stream); + if (!stream) + return; + + if (stream->release) + stream->release(stream); + chassert(!stream->release); + + delete stream; + *arrow_stream = nullptr; +} + +chdb_state chdb_arrow_unregister_table(chdb_connection conn, const char * table_name) +{ + ChdbDestructorGuard guard; + + std::shared_lock global_lock(global_connection_mutex); + + if (!table_name) + return CHDBError; + + auto * connection = reinterpret_cast(conn); + if (!checkConnectionValidity(connection)) + return CHDBError; + + try + { + DB::ArrowStreamRegistry::instance().unregisterArrowStream(String(table_name)); + return CHDBSuccess; + } + catch (...) + { + return CHDBError; + } +} diff --git a/programs/local/chdb-internal.h b/programs/local/chdb-internal.h index 0b93907d981..45036098d3b 100644 --- a/programs/local/chdb-internal.h +++ b/programs/local/chdb-internal.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -14,6 +15,31 @@ namespace DB class LocalServer; } +extern std::shared_mutex global_connection_mutex; +extern thread_local bool chdb_destructor_cleanup_in_progress; + +/** + * RAII guard for accurate memory tracking in chDB external interfaces + * used at the beginning of execution to provide thread marking, enabling MemoryTracker + * to accurately track memory changes. 
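+ *
+ * When Python (or other programming language) threads call chDB-provided
+ * interfaces such as chdb_destroy_query_result, the memory released cannot
+ * be accurately tracked by ClickHouse's MemoryTracker, which may lead to
+ * false reports of insufficient memory.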
+ */ +class ChdbDestructorGuard +{ +public: + ChdbDestructorGuard() { chdb_destructor_cleanup_in_progress = true; } + ~ChdbDestructorGuard() { chdb_destructor_cleanup_in_progress = false; } + ChdbDestructorGuard(const ChdbDestructorGuard &) = delete; + ChdbDestructorGuard & operator=(const ChdbDestructorGuard &) = delete; + ChdbDestructorGuard(ChdbDestructorGuard &&) = delete; + ChdbDestructorGuard & operator=(ChdbDestructorGuard &&) = delete; +}; + +/// Connection validity check function +inline bool checkConnectionValidity(chdb_conn * connection) +{ + return connection && connection->connected && connection->queue; +} + namespace CHDB { diff --git a/programs/local/chdb.cpp b/programs/local/chdb.cpp index bd933e6ab93..034f8b5fecc 100644 --- a/programs/local/chdb.cpp +++ b/programs/local/chdb.cpp @@ -22,35 +22,10 @@ namespace DB #endif extern thread_local bool chdb_destructor_cleanup_in_progress; +std::shared_mutex global_connection_mutex; namespace CHDB { - -/** - * RAII guard for accurate memory tracking in chDB external interfaces - * - * When Python (or other programming language) threads call chDB-provided interfaces - * such as chdb_destroy_query_result, the memory released cannot be accurately tracked - * by ClickHouse's MemoryTracker, which may lead to false reports of insufficient memory. - * - * Therefore, for all externally exposed chDB interfaces, ChdbDestructorGuard must be - * used at the beginning of execution to provide thread marking, enabling MemoryTracker - * to accurately track memory changes. - */ -class ChdbDestructorGuard -{ -public: - ChdbDestructorGuard() { chdb_destructor_cleanup_in_progress = true; } - - ~ChdbDestructorGuard() { chdb_destructor_cleanup_in_progress = false; } - - ChdbDestructorGuard(const ChdbDestructorGuard &) = delete; - ChdbDestructorGuard & operator=(const ChdbDestructorGuard &) = delete; - ChdbDestructorGuard(ChdbDestructorGuard &&) = delete; - ChdbDestructorGuard & operator=(ChdbDestructorGuard &&) = delete; -}; - -static std::shared_mutex global_connection_mutex; static std::mutex CHDB_MUTEX; chdb_conn * global_conn_ptr = nullptr; std::string global_db_path; @@ -298,11 +273,6 @@ static std::pair createQueryResult(DB::LocalServer * serve return std::make_pair(std::move(query_result), is_end); } -static bool checkConnectionValidity(chdb_conn * conn) -{ - return conn && conn->connected && conn->queue; -} - static QueryResultPtr executeQueryRequest( CHDB::QueryQueue * queue, const char * query, diff --git a/programs/local/chdb.h b/programs/local/chdb.h index e4ecaddb4d9..cebee1a36c1 100644 --- a/programs/local/chdb.h +++ b/programs/local/chdb.h @@ -66,6 +66,13 @@ typedef struct #endif +// Return state enumeration for chDB API functions +typedef enum chdb_state +{ + CHDBSuccess = 0, + CHDBError = 1 +} chdb_state; + // Opaque handle for query results. // Internal data structure managed by chDB implementation. // Users should only interact through API functions. @@ -82,6 +89,25 @@ typedef struct chdb_connection_ void * internal_data; } * chdb_connection; +// Holds an arrow array stream. Wraps ArrowArrayStream for chdb usage. +// Must be released with chdb_destroy_arrow_stream when no longer needed. +typedef struct _chdb_arrow_stream +{ + void * internal_data; +} * chdb_arrow_stream; + +// Holds an arrow schema. Wraps ArrowSchema for chdb usage. +typedef struct _chdb_arrow_schema +{ + void * internal_data; +} * chdb_arrow_schema; + +// Holds an arrow array. Wraps ArrowArray for chdb usage. 
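+// Note: the schema/array pointers passed to chdb_arrow_array_scan are only
+// borrowed (their release callbacks are stubbed out), so the caller must keep
+// them alive for as long as queries may scan the registered table.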
+typedef struct _chdb_arrow_array +{ + void * internal_data; +} * chdb_arrow_array; + #ifndef CHDB_NO_DEPRECATED // WARNING: The following interfaces are deprecated and will be removed in a future version. CHDB_EXPORT struct local_result * query_stable(int argc, char ** argv); @@ -375,6 +401,50 @@ CHDB_EXPORT uint64_t chdb_result_storage_bytes_read(chdb_result * result); */ CHDB_EXPORT const char * chdb_result_error(chdb_result * result); +//===--------------------------------------------------------------------===// +// Arrow Integration +//===--------------------------------------------------------------------===// + +/** + * Registers an Arrow stream as an arrow stream table function with the given name + * @param conn The connection on which to execute the registration + * @param table_name Name to register for the arrow stream table function + * @param arrow_stream chdb Arrow stream handle + * @param arrow_schema chdb Arrow schema handle + * @return CHDBSuccess on success, CHDBError on failure + */ +CHDB_EXPORT chdb_state chdb_arrow_scan( + chdb_connection conn, const char * table_name, + chdb_arrow_stream arrow_stream); + +/** + * Registers an Arrow array as an arrow stream table function with the given name + * @param conn The connection on which to execute the registration + * @param table_name Name to register for the arrow stream table function + * @param arrow_schema chdb Arrow schema handle + * @param arrow_array chdb Arrow array handle + * @param out_stream Optional output stream handle for result streaming + * @return CHDBSuccess on success, CHDBError on failure + */ +CHDB_EXPORT chdb_state chdb_arrow_array_scan( + chdb_connection conn, const char * table_name, + chdb_arrow_schema arrow_schema, chdb_arrow_array arrow_array, + chdb_arrow_stream * out_stream); + +/** + * Destroys and releases resources for an Arrow stream handle + * @param arrow_stream Pointer to the Arrow stream handle to destroy + */ +CHDB_EXPORT void chdb_destroy_arrow_stream(chdb_arrow_stream * arrow_stream); + +/** + * Unregisters an arrow stream table function that was previously registered via chdb_arrow_scan + * @param conn The connection on which to execute the unregister operation + * @param table_name Name of the arrow stream table function to unregister + * @return CHDBSuccess on success, CHDBError on failure + */ +CHDB_EXPORT chdb_state chdb_arrow_unregister_table(chdb_connection conn, const char * table_name); + #ifdef __cplusplus } #endif diff --git a/src/Core/FormatFactorySettings.h b/src/Core/FormatFactorySettings.h index 0fe39a82028..02ae9763348 100644 --- a/src/Core/FormatFactorySettings.h +++ b/src/Core/FormatFactorySettings.h @@ -301,7 +301,7 @@ Skip columns with unsupported types while schema inference for format CapnProto DECLARE(Bool, input_format_orc_skip_columns_with_unsupported_types_in_schema_inference, false, R"( Skip columns with unsupported types while schema inference for format ORC )", 0) \ - DECLARE(Bool, input_format_arrow_skip_columns_with_unsupported_types_in_schema_inference, false, R"( + DECLARE(Bool, input_format_arrow_skip_columns_with_unsupported_types_in_schema_inference, true, R"( Skip columns with unsupported types while schema inference for format Arrow )", 0) \ DECLARE(String, column_names_for_schema_inference, "", R"( diff --git a/tests/test_query_json.py b/tests/test_query_json.py index c52a24f589c..b4b1f3b35d1 100644 --- a/tests/test_query_json.py +++ b/tests/test_query_json.py @@ -21,6 +21,9 @@ \\N,\\N,"[1,666]" """ EXPECTED2 = '"apple1",3,\\N\n\\N,4,2\n' 
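+# EXPECTED3 (below) appears to reflect the PyArrow-backed reader added in this
+# patch: struct fields absent in a row come back as empty defaults ("[]", 0)
+# instead of NULL (compare EXPECTED1); test_query_pyarrow_table1 is updated
+# accordingly.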
+EXPECTED3 = """"['urgent','important']",100.3,"[]" +"[]",0,"[1,666]" +""" dict1 = { "c1": [1, 2, 3, 4, 5, 6, 7, 8], @@ -389,9 +392,9 @@ def test_special_numpy_types(self): self.assertEqual(str(ret), '"2025-05-30 20:08:08.123000000"\n') def test_query_pyarrow_table1(self): - ret = self.sess.query("SELECT c4.tags, c3.deep.level2.level3, c3.mixed_list[].a FROM Python(arrow_table1) WHERE c1 <= 2 ORDER BY c1") + ret = self.sess.query("SELECT c4.tags, c3.deep.level2.level3, c3.mixed_list.a FROM Python(arrow_table1) WHERE c1 <= 2 ORDER BY c1") - self.assertEqual(str(ret), EXPECTED1) + self.assertEqual(str(ret), EXPECTED3) def test_pyarrow_complex_types(self): struct_type = pa.struct([ diff --git a/tests/test_query_py.py b/tests/test_query_py.py index ea6a2074665..b33097dfede 100644 --- a/tests/test_query_py.py +++ b/tests/test_query_py.py @@ -185,7 +185,7 @@ def test_query_arrow3(self): ) self.assertEqual( str(ret), - "5872873,587287.3,553446.5,470878.25,3,0,7,10\n", + "5872873,587287.3,553446.5,582813.5,3,0,7,10\n", ) def test_query_arrow4(self): @@ -209,17 +209,17 @@ def test_query_arrow5(self): self.assertDictEqual( schema_dict, { - "quadkey": "String", - "tile": "String", - "tile_x": "Float64", - "tile_y": "Float64", - "avg_d_kbps": "Int64", - "avg_u_kbps": "Int64", - "avg_lat_ms": "Int64", - "avg_lat_down_ms": "Float64", - "avg_lat_up_ms": "Float64", - "tests": "Int64", - "devices": "Int64", + "quadkey": "Nullable(String)", + "tile": "Nullable(String)", + "tile_x": "Nullable(Float64)", + "tile_y": "Nullable(Float64)", + "avg_d_kbps": "Nullable(Int64)", + "avg_u_kbps": "Nullable(Int64)", + "avg_lat_ms": "Nullable(Int64)", + "avg_lat_down_ms": "Nullable(Float64)", + "avg_lat_up_ms": "Nullable(Float64)", + "tests": "Nullable(Int64)", + "devices": "Nullable(Int64)", }, ) ret = chdb.query( @@ -237,20 +237,20 @@ def test_query_arrow5(self): self.assertDictEqual( {x["name"]: x["type"] for x in json.loads(str(ret)).get("meta")}, { - "max(avg_d_kbps)": "Int64", - "max(avg_lat_down_ms)": "Float64", - "max(avg_lat_ms)": "Int64", - "max(avg_lat_up_ms)": "Float64", - "max(avg_u_kbps)": "Int64", - "max(devices)": "Int64", - "max(tests)": "Int64", - "round(median(avg_d_kbps), 2)": "Float64", - "round(median(avg_lat_down_ms), 2)": "Float64", - "round(median(avg_lat_ms), 2)": "Float64", - "round(median(avg_lat_up_ms), 2)": "Float64", - "round(median(avg_u_kbps), 2)": "Float64", - "round(median(devices), 2)": "Float64", - "round(median(tests), 2)": "Float64", + "max(avg_d_kbps)": "Nullable(Int64)", + "max(avg_lat_down_ms)": "Nullable(Float64)", + "max(avg_lat_ms)": "Nullable(Int64)", + "max(avg_lat_up_ms)": "Nullable(Float64)", + "max(avg_u_kbps)": "Nullable(Int64)", + "max(devices)": "Nullable(Int64)", + "max(tests)": "Nullable(Int64)", + "round(median(avg_d_kbps), 2)": "Nullable(Float64)", + "round(median(avg_lat_down_ms), 2)": "Nullable(Float64)", + "round(median(avg_lat_ms), 2)": "Nullable(Float64)", + "round(median(avg_lat_up_ms), 2)": "Nullable(Float64)", + "round(median(avg_u_kbps), 2)": "Nullable(Float64)", + "round(median(devices), 2)": "Nullable(Float64)", + "round(median(tests), 2)": "Nullable(Float64)", }, ) From b85f9ea75f40b96bed30ad26fe4cc60ab64554c9 Mon Sep 17 00:00:00 2001 From: wudidapaopao Date: Tue, 23 Sep 2025 19:22:13 +0800 Subject: [PATCH 02/13] feat: add chdb_arrow_scan, chdb_arrow_array_scan and chdb_arrow_unregister_table C APIs --- .../workflows/build_linux_arm64_wheels-gh.yml | 3 + .github/workflows/build_linux_x86_wheels.yml | 3 + .../workflows/build_macos_arm64_wheels.yml | 3 + 
.github/workflows/build_macos_x86_wheels.yml | 3 + README.md | 6 +- examples/chdbArrowTest.cpp | 784 ++++++++++++++++++ examples/runArrowTest.sh | 35 + programs/local/ArrowScanState.h | 34 + programs/local/ArrowSchema.cpp | 74 +- programs/local/ArrowSchema.h | 25 +- programs/local/ArrowStreamRegistry.h | 100 +++ programs/local/ArrowStreamSource.cpp | 51 ++ programs/local/ArrowStreamSource.h | 31 + programs/local/ArrowStreamWrapper.cpp | 138 +-- programs/local/ArrowStreamWrapper.h | 47 +- programs/local/ArrowTableReader.cpp | 39 +- programs/local/ArrowTableReader.h | 30 +- programs/local/CMakeLists.txt | 12 +- programs/local/LocalServer.cpp | 9 +- programs/local/PyArrowStreamFactory.cpp | 113 +++ programs/local/PyArrowStreamFactory.h | 25 + programs/local/PyArrowTable.cpp | 59 +- programs/local/StorageArrowStream.cpp | 97 +++ programs/local/StorageArrowStream.h | 44 + programs/local/StoragePython.cpp | 10 +- programs/local/TableFunctionArrowStream.cpp | 127 +++ programs/local/TableFunctionArrowStream.h | 41 + programs/local/TableFunctionPython.cpp | 3 +- programs/local/TableFunctionPython.h | 6 +- programs/local/chdb-arrow.cpp | 59 +- programs/local/chdb-internal.h | 5 +- programs/local/chdb.cpp | 21 +- programs/local/chdb.h | 24 +- tests/test_arrow_table_queries.py | 173 ++++ 34 files changed, 1877 insertions(+), 357 deletions(-) create mode 100644 examples/chdbArrowTest.cpp create mode 100755 examples/runArrowTest.sh create mode 100644 programs/local/ArrowScanState.h create mode 100644 programs/local/ArrowStreamRegistry.h create mode 100644 programs/local/ArrowStreamSource.cpp create mode 100644 programs/local/ArrowStreamSource.h create mode 100644 programs/local/PyArrowStreamFactory.cpp create mode 100644 programs/local/PyArrowStreamFactory.h create mode 100644 programs/local/StorageArrowStream.cpp create mode 100644 programs/local/StorageArrowStream.h create mode 100644 programs/local/TableFunctionArrowStream.cpp create mode 100644 programs/local/TableFunctionArrowStream.h create mode 100644 tests/test_arrow_table_queries.py diff --git a/.github/workflows/build_linux_arm64_wheels-gh.yml b/.github/workflows/build_linux_arm64_wheels-gh.yml index 37cf62ad904..e12fca991e4 100644 --- a/.github/workflows/build_linux_arm64_wheels-gh.yml +++ b/.github/workflows/build_linux_arm64_wheels-gh.yml @@ -138,6 +138,9 @@ jobs: - name: Run libchdb stub in examples dir run: | bash -x ./examples/runStub.sh + - name: Run Arrow functions test in examples dir + run: | + bash -x ./examples/runArrowTest.sh - name: Check ccache statistics run: | ccache -s diff --git a/.github/workflows/build_linux_x86_wheels.yml b/.github/workflows/build_linux_x86_wheels.yml index 3ff06698d27..32ad5853766 100644 --- a/.github/workflows/build_linux_x86_wheels.yml +++ b/.github/workflows/build_linux_x86_wheels.yml @@ -138,6 +138,9 @@ jobs: - name: Run libchdb stub in examples dir run: | bash -x ./examples/runStub.sh + - name: Run Arrow functions test in examples dir + run: | + bash -x ./examples/runArrowTest.sh - name: Check ccache statistics run: | ccache -s diff --git a/.github/workflows/build_macos_arm64_wheels.yml b/.github/workflows/build_macos_arm64_wheels.yml index 96ef0b988a6..39cd4d2f1e0 100644 --- a/.github/workflows/build_macos_arm64_wheels.yml +++ b/.github/workflows/build_macos_arm64_wheels.yml @@ -141,6 +141,9 @@ jobs: - name: Run libchdb stub in examples dir run: | bash -x ./examples/runStub.sh + - name: Run Arrow functions test in examples dir + run: | + bash -x ./examples/runArrowTest.sh - name: Keep killall 
ccache and wait for ccache to finish if: always() run: | diff --git a/.github/workflows/build_macos_x86_wheels.yml b/.github/workflows/build_macos_x86_wheels.yml index 85ebe048c87..148edc45971 100644 --- a/.github/workflows/build_macos_x86_wheels.yml +++ b/.github/workflows/build_macos_x86_wheels.yml @@ -142,6 +142,9 @@ jobs: - name: Run libchdb stub in examples dir run: | bash -x ./examples/runStub.sh + - name: Run Arrow functions test in examples dir + run: | + bash -x ./examples/runArrowTest.sh - name: Keep killall ccache and wait for ccache to finish if: always() run: | diff --git a/README.md b/README.md index 83760e75590..c8052829995 100644 --- a/README.md +++ b/README.md @@ -416,11 +416,7 @@ chDB automatically converts Python dictionary objects to ClickHouse JSON types f ``` - Columns are converted to `String` if sampling finds non-dictionary values. -2. **Arrow Table** - - `struct` type columns are automatically mapped to JSON columns. - - Nested structures preserve type information. - -3. **chdb.PyReader** +2. **chdb.PyReader** - Implement custom schema mapping in `get_schema()`: ```python def get_schema(self): diff --git a/examples/chdbArrowTest.cpp b/examples/chdbArrowTest.cpp new file mode 100644 index 00000000000..6cc1884beec --- /dev/null +++ b/examples/chdbArrowTest.cpp @@ -0,0 +1,784 @@ +#include +#include +#include +#include +#include + +#include "../programs/local/chdb.h" +#include "../contrib/arrow/cpp/src/arrow/c/abi.h" + +// Custom ArrowArrayStream implementation data +struct CustomStreamData +{ + bool schema_sent; + size_t current_row; + size_t total_rows; + size_t batch_size; + std::string last_error; + + CustomStreamData() : schema_sent(false), current_row(0), total_rows(1000000), batch_size(10000) {} + + // Reset the stream to allow reading from the beginning + void reset() + { + current_row = 0; + last_error.clear(); + } +}; + +// Helper function to create schema with 2 columns: id(int64), value(string) +static void create_schema(struct ArrowSchema * schema) { + schema->format = "+s"; // struct format + schema->name = nullptr; + schema->metadata = nullptr; + schema->flags = 0; + schema->n_children = 2; + schema->children = static_cast(malloc(2 * sizeof(struct ArrowSchema *))); + schema->dictionary = nullptr; + schema->release = [](struct ArrowSchema * s) + { + if (s->children) { + for (int64_t i = 0; i < s->n_children; i++) { + if (s->children[i] && s->children[i]->release) { + s->children[i]->release(s->children[i]); + } + free(s->children[i]); + } + free(s->children); + } + s->release = nullptr; + }; + + // Field 0: id (int64) + schema->children[0] = static_cast(malloc(sizeof(struct ArrowSchema))); + schema->children[0]->format = "l"; // int64 + schema->children[0]->name = "id"; + schema->children[0]->metadata = nullptr; + schema->children[0]->flags = 0; + schema->children[0]->n_children = 0; + schema->children[0]->children = nullptr; + schema->children[0]->dictionary = nullptr; + schema->children[0]->release = [](struct ArrowSchema* s) { s->release = nullptr; }; + + // Field 1: value (string) + schema->children[1] = static_cast(malloc(sizeof(struct ArrowSchema))); + schema->children[1]->format = "u"; // utf8 string + schema->children[1]->name = "value"; + schema->children[1]->metadata = nullptr; + schema->children[1]->flags = 0; + schema->children[1]->n_children = 0; + schema->children[1]->children = nullptr; + schema->children[1]->dictionary = nullptr; + schema->children[1]->release = [](struct ArrowSchema* s) { s->release = nullptr; }; +} + +// Helper function 
to create a batch of data +static void create_batch(struct ArrowArray* array, size_t start_row, size_t batch_size) +{ + // Main array structure + array->length = batch_size; + array->null_count = 0; + array->offset = 0; + array->n_buffers = 1; + array->n_children = 2; + array->buffers = static_cast(malloc(1 * sizeof(void*))); + array->buffers[0] = nullptr; // validity buffer (no nulls) + array->children = static_cast(malloc(2 * sizeof(struct ArrowArray*))); + array->dictionary = nullptr; + + // Create id column (int64) + array->children[0] = static_cast(malloc(sizeof(struct ArrowArray))); + struct ArrowArray* id_array = array->children[0]; + id_array->length = batch_size; + id_array->null_count = 0; + id_array->offset = 0; + id_array->n_buffers = 2; + id_array->n_children = 0; + id_array->buffers = static_cast(malloc(2 * sizeof(void*))); + id_array->buffers[0] = nullptr; // validity buffer + + // Allocate and fill id data + int64_t* id_data = static_cast(malloc(batch_size * sizeof(int64_t))); + for (size_t i = 0; i < batch_size; i++) + id_data[i] = start_row + i; + + id_array->buffers[1] = id_data; // data buffer + id_array->children = nullptr; + id_array->dictionary = nullptr; + id_array->release = [](struct ArrowArray* arr) + { + if (arr->buffers) { + free(const_cast(arr->buffers[1])); // free data buffer + free(const_cast(arr->buffers)); + } + arr->release = nullptr; + }; + + // Create value column (string) + array->children[1] = static_cast(malloc(sizeof(struct ArrowArray))); + struct ArrowArray* str_array = array->children[1]; + str_array->length = batch_size; + str_array->null_count = 0; + str_array->offset = 0; + str_array->n_buffers = 3; + str_array->n_children = 0; + str_array->buffers = static_cast(malloc(3 * sizeof(void*))); + str_array->buffers[0] = nullptr; // validity buffer + + // Create offset buffer (int32) + int32_t* offsets = static_cast(malloc((batch_size + 1) * sizeof(int32_t))); + offsets[0] = 0; + + // Calculate total string length and create strings + size_t total_str_len = 0; + std::vector strings; + for (size_t i = 0; i < batch_size; i++) + { + std::string str = "value_" + std::to_string(start_row + i); + strings.push_back(str); + total_str_len += str.length(); + offsets[i + 1] = total_str_len; + } + str_array->buffers[1] = offsets; // offset buffer + + // Create data buffer + char* str_data = static_cast(malloc(total_str_len)); + size_t pos = 0; + for (const auto& str : strings) + { + memcpy(str_data + pos, str.c_str(), str.length()); + pos += str.length(); + } + str_array->buffers[2] = str_data; // data buffer + + str_array->children = nullptr; + str_array->dictionary = nullptr; + str_array->release = [](struct ArrowArray* arr) + { + if (arr->buffers) { + free(const_cast(arr->buffers[1])); // free offset buffer + free(const_cast(arr->buffers[2])); // free data buffer + free(const_cast(arr->buffers)); + } + arr->release = nullptr; + }; + + // Main array release function + array->release = [](struct ArrowArray* arr) { + if (arr->children) { + for (int64_t i = 0; i < arr->n_children; i++) { + if (arr->children[i] && arr->children[i]->release) { + arr->children[i]->release(arr->children[i]); + } + free(arr->children[i]); + } + free(arr->children); + } + if (arr->buffers) { + free(const_cast(arr->buffers)); + } + arr->release = nullptr; + }; +} + +// Callback function to get schema +static int custom_get_schema(struct ArrowArrayStream * /* stream */, struct ArrowSchema * out) +{ + create_schema(out); + return 0; +} + +// Callback function to get next array +static 
int custom_get_next(struct ArrowArrayStream * stream, struct ArrowArray * out) +{ + auto* data = static_cast(stream->private_data); + if (!data) + return EINVAL; + + // Check if we've reached the end of the stream + if (data->current_row >= data->total_rows) + { + // End of stream - set release to nullptr to indicate no more data + out->release = nullptr; + return 0; + } + + // Calculate batch size for this iteration + size_t remaining_rows = data->total_rows - data->current_row; + size_t batch_size = std::min(data->batch_size, remaining_rows); + + // Create the batch + create_batch(out, data->current_row, batch_size); + + data->current_row += batch_size; + return 0; +} + +// Callback function to get last error +static const char* custom_get_last_error(struct ArrowArrayStream* stream) { + auto* data = static_cast(stream->private_data); + if (!data || data->last_error.empty()) + return nullptr; + + return data->last_error.c_str(); +} + +// Callback function to release stream resources +static void custom_release(struct ArrowArrayStream* stream) { + if (stream->private_data) + { + delete static_cast(stream->private_data); + stream->private_data = nullptr; + } + stream->release = nullptr; +} + +// Helper function to reset the ArrowArrayStream for reuse +static void reset_arrow_stream(struct ArrowArrayStream* stream) +{ + if (stream && stream->private_data) + { + auto* data = static_cast(stream->private_data); + data->reset(); + std::cout << "✓ ArrowArrayStream has been reset, ready for re-reading\n"; + } +} + +//===--------------------------------------------------------------------===// +// Unit Test Utilities +//===--------------------------------------------------------------------===// + +static void test_assert(bool condition, const std::string& test_name, const std::string& message = "") +{ + if (condition) + { + std::cout << "✓ PASS: " << test_name << std::endl; + } + else + { + std::cout << "✗ FAIL: " << test_name; + if (!message.empty()) + { + std::cout << " - " << message; + } + std::cout << std::endl; + exit(1); + } +} + +static void test_assert_chdb_state(chdb_state state, const std::string& operation_name) +{ + test_assert(state == CHDBSuccess, + "chDB operation: " + operation_name, + state == CHDBError ? "Operation failed" : "Unknown state"); +} + +static void test_assert_not_null(void* ptr, const std::string& test_name) +{ + test_assert(ptr != nullptr, test_name, "Pointer is null"); +} + +static void test_assert_no_error(chdb_result* result, const std::string& query_name) +{ + test_assert_not_null(result, query_name + " - Result is not null"); + + const char * error = chdb_result_error(result); + test_assert(error == nullptr, + query_name + " - No query error", + error ? 
std::string("Error: ") + error : ""); +} + +static void test_assert_query_result_contains(chdb_result* result, const std::string& expected_content, const std::string& query_name) +{ + test_assert_no_error(result, query_name); + + char * buffer = chdb_result_buffer(result); + test_assert_not_null(buffer, query_name + " - Result buffer is not null"); + + std::string result_str(buffer); + test_assert(result_str.find(expected_content) != std::string::npos, + query_name + " - Result contains expected content", + "Expected: " + expected_content + ", Actual: " + result_str); +} + +static void test_assert_row_count(chdb_result* result, uint64_t expected_rows, const std::string& query_name) +{ + test_assert_no_error(result, query_name); + + char* buffer = chdb_result_buffer(result); + test_assert_not_null(buffer, query_name + " - Result buffer is not null"); + + // Parse the count result (assuming CSV format with just the number) + std::string result_str(buffer); + // Remove trailing whitespace/newlines + result_str.erase(result_str.find_last_not_of(" \t\n\r\f\v") + 1); + + uint64_t actual_rows = std::stoull(result_str); + test_assert(actual_rows == expected_rows, + query_name + " - Row count matches", + "Expected: " + std::to_string(expected_rows) + ", Actual: " + std::to_string(actual_rows)); +} + +void test_arrow_scan(chdb_connection conn) +{ + std::cout << "\n=== Creating Custom ArrowArrayStream ===\n"; + std::cout << "Data specification: 1,000,000 rows × 2 columns (id: int64, value: string)\n"; + + struct ArrowArrayStream stream; + memset(&stream, 0, sizeof(stream)); + + // Create and initialize stream data + auto * stream_data = new CustomStreamData(); + + // Set up the ArrowArrayStream callbacks + stream.get_schema = custom_get_schema; + stream.get_next = custom_get_next; + stream.get_last_error = custom_get_last_error; + stream.release = custom_release; + stream.private_data = stream_data; + + std::cout << "✓ ArrowArrayStream initialization completed\n"; + std::cout << "Starting registration with chDB...\n"; + + const char * table_name = "test_arrow_table"; + const char * non_exist_table_name = "non_exist_table"; + + chdb_arrow_stream arrow_stream = reinterpret_cast(&stream); + chdb_state result = chdb_arrow_scan(conn, table_name, arrow_stream); + + // Test 1: Verify arrow registration succeeded + test_assert_chdb_state(result, "Register ArrowArrayStream to table: " + std::string(table_name)); + + // Test 2: Unregister non-existent table should handle gracefully + result = chdb_arrow_unregister_table(conn, non_exist_table_name); + test_assert_chdb_state(result, "Unregister non-existent table: " + std::string(non_exist_table_name)); + + // Test 3: Count rows - should be exactly 1,000,000 + chdb_result * count_result = chdb_query(conn, "SELECT COUNT(*) as total_rows FROM arrowstream(test_arrow_table)", "CSV"); + test_assert_row_count(count_result, 1000000, "Count total rows"); + chdb_destroy_query_result(count_result); + + // Test 4: Sample first 5 rows - should contain id=0,1,2,3,4 + reset_arrow_stream(&stream); + chdb_result * sample_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_table) LIMIT 5", "CSV"); + test_assert_query_result_contains(sample_result, "0,\"value_0\"", "First 5 rows contain first row"); + test_assert_query_result_contains(sample_result, "4,\"value_4\"", "First 5 rows contain fifth row"); + chdb_destroy_query_result(sample_result); + + // Test 5: Sample last 5 rows - should contain id=999999,999998,999997,999996,999995 + reset_arrow_stream(&stream); + 
chdb_result * last_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_table) ORDER BY id DESC LIMIT 5", "CSV"); + test_assert_query_result_contains(last_result, "999999,\"value_999999\"", "Last 5 rows contain last row"); + test_assert_query_result_contains(last_result, "999995,\"value_999995\"", "Last 5 rows contain fifth row"); + chdb_destroy_query_result(last_result); + + // Test 6: Multiple table registration tests + // Create second ArrowArrayStream with different data (500,000 rows) + struct ArrowArrayStream stream2; + memset(&stream2, 0, sizeof(stream2)); + auto * stream_data2 = new CustomStreamData(); + stream_data2->total_rows = 500000; // Different row count + stream_data2->current_row = 0; + stream2.get_schema = custom_get_schema; + stream2.get_next = custom_get_next; + stream2.get_last_error = custom_get_last_error; + stream2.release = custom_release; + stream2.private_data = stream_data2; + + // Create third ArrowArrayStream with different data (100,000 rows) + struct ArrowArrayStream stream3; + memset(&stream3, 0, sizeof(stream3)); + auto * stream_data3 = new CustomStreamData(); + stream_data3->total_rows = 100000; // Different row count + stream_data3->current_row = 0; + stream3.get_schema = custom_get_schema; + stream3.get_next = custom_get_next; + stream3.get_last_error = custom_get_last_error; + stream3.release = custom_release; + stream3.private_data = stream_data3; + + const char * table_name2 = "test_arrow_table_2"; + const char * table_name3 = "test_arrow_table_3"; + + // Register second table + chdb_arrow_stream arrow_stream2 = reinterpret_cast(&stream2); + result = chdb_arrow_scan(conn, table_name2, arrow_stream2); + test_assert_chdb_state(result, "Register second ArrowArrayStream to table: " + std::string(table_name2)); + + // Register third table + chdb_arrow_stream arrow_stream3 = reinterpret_cast(&stream3); + result = chdb_arrow_scan(conn, table_name3, arrow_stream3); + test_assert_chdb_state(result, "Register third ArrowArrayStream to table: " + std::string(table_name3)); + + // Test 6a: Verify each table has correct row counts + reset_arrow_stream(&stream); + chdb_result * count1_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_table)", "CSV"); + test_assert_row_count(count1_result, 1000000, "First table row count"); + chdb_destroy_query_result(count1_result); + + reset_arrow_stream(&stream2); + chdb_result * count2_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_table_2)", "CSV"); + test_assert_row_count(count2_result, 500000, "Second table row count"); + chdb_destroy_query_result(count2_result); + + reset_arrow_stream(&stream3); + chdb_result * count3_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_table_3)", "CSV"); + test_assert_row_count(count3_result, 100000, "Third table row count"); + chdb_destroy_query_result(count3_result); + + // Test 6b: Test cross-table JOIN query + reset_arrow_stream(&stream); + reset_arrow_stream(&stream2); + chdb_result * join_result = chdb_query(conn, + "SELECT t1.id, t1.value, t2.value as value2 " + "FROM arrowstream(test_arrow_table) t1 " + "INNER JOIN arrowstream(test_arrow_table_2) t2 ON t1.id = t2.id " + "WHERE t1.id < 5 ORDER BY t1.id", "CSV"); + test_assert_query_result_contains(join_result, R"(0,"value_0","value_0")", "JOIN query contains expected data"); + test_assert_query_result_contains(join_result, R"(4,"value_4","value_4")", "JOIN query contains fifth row"); + chdb_destroy_query_result(join_result); + + // Test 6c: Test UNION query 
across multiple tables + reset_arrow_stream(&stream2); + reset_arrow_stream(&stream3); + chdb_result * union_result = chdb_query(conn, + "SELECT COUNT(*) FROM (" + "SELECT id FROM arrowstream(test_arrow_table_2) WHERE id < 10 " + "UNION ALL " + "SELECT id FROM arrowstream(test_arrow_table_3) WHERE id < 10" + ")", "CSV"); + test_assert_row_count(union_result, 20, "UNION query row count"); + chdb_destroy_query_result(union_result); + + // Cleanup additional tables + result = chdb_arrow_unregister_table(conn, table_name2); + test_assert_chdb_state(result, "Unregister second ArrowArrayStream table"); + + result = chdb_arrow_unregister_table(conn, table_name3); + test_assert_chdb_state(result, "Unregister third ArrowArrayStream table"); + + // Test 7: Unregister original table should succeed + result = chdb_arrow_unregister_table(conn, table_name); + test_assert_chdb_state(result, "Unregister ArrowArrayStream table: " + std::string(table_name)); + + // Test 8: Sample last 5 rows after unregister should fail + reset_arrow_stream(&stream); + chdb_result * unregister_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_table) ORDER BY id DESC LIMIT 5", "CSV"); + const char * error = chdb_result_error(unregister_result); + test_assert(error != nullptr, + "Query after unregister should fail", + error ? std::string("Got expected error: ") + error : "No error returned when error was expected"); + chdb_destroy_query_result(unregister_result); +} + +// Helper function to create ArrowArray with specified row count +static void create_arrow_array(struct ArrowArray * array, uint64_t row_count) +{ + array->length = row_count; + array->null_count = 0; + array->offset = 0; + array->n_buffers = 1; + array->n_children = 2; + array->buffers = static_cast(malloc(1 * sizeof(void *))); + array->buffers[0] = nullptr; // validity buffer + + array->children = static_cast(malloc(2 * sizeof(struct ArrowArray *))); + array->dictionary = nullptr; + + // Create id column (int64) + array->children[0] = static_cast(malloc(sizeof(struct ArrowArray))); + struct ArrowArray * id_array = array->children[0]; + id_array->length = row_count; + id_array->null_count = 0; + id_array->offset = 0; + id_array->n_buffers = 2; + id_array->n_children = 0; + id_array->children = nullptr; + id_array->dictionary = nullptr; + + id_array->buffers = static_cast(malloc(2 * sizeof(void *))); + id_array->buffers[0] = nullptr; // validity buffer + + // Allocate and populate id data + int64_t * id_data = static_cast(malloc(row_count * sizeof(int64_t))); + for (uint64_t i = 0; i < row_count; i++) + { + id_data[i] = static_cast(i); + } + id_array->buffers[1] = id_data; + + id_array->release = [](struct ArrowArray * a) + { + if (a->buffers) + { + free(const_cast(a->buffers[1])); // id data + free(const_cast(a->buffers)); + } + free(a); + }; + + // Create value column (string) + array->children[1] = static_cast(malloc(sizeof(struct ArrowArray))); + struct ArrowArray * value_array = array->children[1]; + value_array->length = row_count; + value_array->null_count = 0; + value_array->offset = 0; + value_array->n_buffers = 3; + value_array->n_children = 0; + value_array->children = nullptr; + value_array->dictionary = nullptr; + + value_array->buffers = static_cast(malloc(3 * sizeof(void *))); + value_array->buffers[0] = nullptr; // validity buffer + + // Calculate total string data size and create offset array + int32_t * offsets = static_cast(malloc((row_count + 1) * sizeof(int32_t))); + size_t total_string_size = 0; + offsets[0] = 0; + + for 
(uint64_t i = 0; i < row_count; i++) + { + std::string value_str = "value_" + std::to_string(i); + total_string_size += value_str.length(); + offsets[i + 1] = static_cast(total_string_size); + } + + value_array->buffers[1] = offsets; + + // Allocate and populate string data + char * string_data = static_cast(malloc(total_string_size)); + size_t current_pos = 0; + for (uint64_t i = 0; i < row_count; i++) { + std::string value_str = "value_" + std::to_string(i); + memcpy(string_data + current_pos, value_str.c_str(), value_str.length()); + current_pos += value_str.length(); + } + value_array->buffers[2] = string_data; + + value_array->release = [](struct ArrowArray * a) { + if (a->buffers) { + free(const_cast(a->buffers[1])); // offsets + free(const_cast(a->buffers[2])); // string data + free(const_cast(a->buffers)); + } + free(a); + }; + + // Set release callback for main array + array->release = [](struct ArrowArray * a) + { + if (a->children) + { + for (int64_t i = 0; i < a->n_children; i++) + { + if (a->children[i] && a->children[i]->release) + { + a->children[i]->release(a->children[i]); + } + } + free(a->children); + } + if (a->buffers) { + free(const_cast(a->buffers)); + } + }; +} + +void test_arrow_array_scan(chdb_connection conn) +{ + std::cout << "\n=== Testing ArrowArray Scan Functions ===\n"; + std::cout << "Data specification: 1,000,000 rows × 2 columns (id: int64, value: string)\n"; + + // Create ArrowSchema (reuse existing function) + struct ArrowSchema schema; + create_schema(&schema); + + // Create ArrowArray with 1,000,000 rows + struct ArrowArray array; + memset(&array, 0, sizeof(array)); + create_arrow_array(&array, 1000000); + + std::cout << "✓ ArrowArray initialization completed\n"; + std::cout << "Starting registration with chDB...\n"; + + const char * table_name = "test_arrow_array_table"; + const char * non_exist_table_name = "non_exist_array_table"; + + chdb_arrow_schema arrow_schema = reinterpret_cast(&schema); + chdb_arrow_array arrow_array = reinterpret_cast(&array); + + // Test 1: Register -> Query -> Unregister for row count + chdb_state result = chdb_arrow_array_scan(conn, table_name, arrow_schema, arrow_array); + test_assert_chdb_state(result, "Register ArrowArray to table: " + std::string(table_name)); + + chdb_result * count_result = chdb_query(conn, "SELECT COUNT(*) as total_rows FROM arrowstream(test_arrow_array_table)", "CSV"); + test_assert_row_count(count_result, 1000000, "Count total rows"); + chdb_destroy_query_result(count_result); + + result = chdb_arrow_unregister_table(conn, table_name); + test_assert_chdb_state(result, "Unregister ArrowArray table after count query"); + + // Test 2: Unregister non-existent table should handle gracefully + result = chdb_arrow_unregister_table(conn, non_exist_table_name); + test_assert_chdb_state(result, "Unregister non-existent array table: " + std::string(non_exist_table_name)); + + // Test 3: Register -> Query -> Unregister for first 5 rows + result = chdb_arrow_array_scan(conn, table_name, arrow_schema, arrow_array); + test_assert_chdb_state(result, "Register ArrowArray for sample query"); + + chdb_result * sample_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_array_table) LIMIT 5", "CSV"); + test_assert_query_result_contains(sample_result, "0,\"value_0\"", "First 5 rows contain first row"); + test_assert_query_result_contains(sample_result, "4,\"value_4\"", "First 5 rows contain fifth row"); + chdb_destroy_query_result(sample_result); + + result = chdb_arrow_unregister_table(conn, table_name); 
+ test_assert_chdb_state(result, "Unregister ArrowArray table after sample query"); + + // Test 4: Register -> Query -> Unregister for last 5 rows + result = chdb_arrow_array_scan(conn, table_name, arrow_schema, arrow_array); + test_assert_chdb_state(result, "Register ArrowArray for last rows query"); + + chdb_result * last_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_array_table) ORDER BY id DESC LIMIT 5", "CSV"); + test_assert_query_result_contains(last_result, "999999,\"value_999999\"", "Last 5 rows contain last row"); + test_assert_query_result_contains(last_result, "999995,\"value_999995\"", "Last 5 rows contain fifth row"); + chdb_destroy_query_result(last_result); + + result = chdb_arrow_unregister_table(conn, table_name); + test_assert_chdb_state(result, "Unregister ArrowArray table after last rows query"); + + // Test 5: Independent multiple table tests + // Create second ArrowArray with different data (500,000 rows) + struct ArrowSchema schema2; + create_schema(&schema2); + struct ArrowArray array2; + memset(&array2, 0, sizeof(array2)); + create_arrow_array(&array2, 500000); + + // Create third ArrowArray with different data (100,000 rows) + struct ArrowSchema schema3; + create_schema(&schema3); + struct ArrowArray array3; + memset(&array3, 0, sizeof(array3)); + create_arrow_array(&array3, 100000); + + const char * table_name2 = "test_arrow_array_table_2"; + const char * table_name3 = "test_arrow_array_table_3"; + + chdb_arrow_schema arrow_schema2 = reinterpret_cast(&schema2); + chdb_arrow_array arrow_array2 = reinterpret_cast(&array2); + chdb_arrow_schema arrow_schema3 = reinterpret_cast(&schema3); + chdb_arrow_array arrow_array3 = reinterpret_cast(&array3); + + // Test 5a: Register -> Query -> Unregister for second table (500K rows) + result = chdb_arrow_array_scan(conn, table_name2, arrow_schema2, arrow_array2); + test_assert_chdb_state(result, "Register second ArrowArray to table: " + std::string(table_name2)); + + chdb_result * count2_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_array_table_2)", "CSV"); + test_assert_row_count(count2_result, 500000, "Second array table row count"); + chdb_destroy_query_result(count2_result); + + result = chdb_arrow_unregister_table(conn, table_name2); + test_assert_chdb_state(result, "Unregister second ArrowArray table"); + + // Test 5b: Register -> Query -> Unregister for third table (100K rows) + result = chdb_arrow_array_scan(conn, table_name3, arrow_schema3, arrow_array3); + test_assert_chdb_state(result, "Register third ArrowArray to table: " + std::string(table_name3)); + + chdb_result * count3_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_array_table_3)", "CSV"); + test_assert_row_count(count3_result, 100000, "Third array table row count"); + chdb_destroy_query_result(count3_result); + + result = chdb_arrow_unregister_table(conn, table_name3); + test_assert_chdb_state(result, "Unregister third ArrowArray table"); + + // Test 6: Cross-table JOIN query (Register both -> Query -> Unregister both) + result = chdb_arrow_array_scan(conn, table_name, arrow_schema, arrow_array); + test_assert_chdb_state(result, "Register first ArrowArray for JOIN"); + + result = chdb_arrow_array_scan(conn, table_name2, arrow_schema2, arrow_array2); + test_assert_chdb_state(result, "Register second ArrowArray for JOIN"); + + chdb_result * join_result = chdb_query(conn, + "SELECT t1.id, t1.value, t2.value as value2 " + "FROM arrowstream(test_arrow_array_table) t1 " + "INNER JOIN 
arrowstream(test_arrow_array_table_2) t2 ON t1.id = t2.id " + "WHERE t1.id < 5 ORDER BY t1.id", "CSV"); + test_assert_query_result_contains(join_result, R"(0,"value_0","value_0")", "Array JOIN query contains expected data"); + test_assert_query_result_contains(join_result, R"(4,"value_4","value_4")", "Array JOIN query contains fifth row"); + chdb_destroy_query_result(join_result); + + result = chdb_arrow_unregister_table(conn, table_name); + test_assert_chdb_state(result, "Unregister first ArrowArray after JOIN"); + + result = chdb_arrow_unregister_table(conn, table_name2); + test_assert_chdb_state(result, "Unregister second ArrowArray after JOIN"); + + // Test 7: Cross-table UNION query (Register both -> Query -> Unregister both) + result = chdb_arrow_array_scan(conn, table_name2, arrow_schema2, arrow_array2); + test_assert_chdb_state(result, "Register second ArrowArray for UNION"); + + result = chdb_arrow_array_scan(conn, table_name3, arrow_schema3, arrow_array3); + test_assert_chdb_state(result, "Register third ArrowArray for UNION"); + + chdb_result * union_result = chdb_query(conn, + "SELECT COUNT(*) FROM (" + "SELECT id FROM arrowstream(test_arrow_array_table_2) WHERE id < 10 " + "UNION ALL " + "SELECT id FROM arrowstream(test_arrow_array_table_3) WHERE id < 10" + ")", "CSV"); + test_assert_row_count(union_result, 20, "Array UNION query row count"); + chdb_destroy_query_result(union_result); + + result = chdb_arrow_unregister_table(conn, table_name2); + test_assert_chdb_state(result, "Unregister second ArrowArray after UNION"); + + result = chdb_arrow_unregister_table(conn, table_name3); + test_assert_chdb_state(result, "Unregister third ArrowArray after UNION"); + + // Test 8: Query after unregister should fail + chdb_result * unregister_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_array_table) ORDER BY id DESC LIMIT 5", "CSV"); + const char * error = chdb_result_error(unregister_result); + test_assert(error != nullptr, + "Array query after unregister should fail", + error ? 
std::string("Got expected error: ") + error : "No error returned when error was expected"); + chdb_destroy_query_result(unregister_result); + + // Cleanup ArrowArrays and schemas + if (array.release) array.release(&array); + if (schema.release) schema.release(&schema); + if (array2.release) array2.release(&array2); + if (schema2.release) schema2.release(&schema2); + if (array3.release) array3.release(&array3); + if (schema3.release) schema3.release(&schema3); +} + +int main() +{ + const char *argv[] = {"clickhouse", "--multiquery"}; + int argc = sizeof(argv) / sizeof(argv[0]); + chdb_connection * conn_ptr; + chdb_connection conn; + + std::cout << "=== chDB Arrow Functions Test ===\n"; + + // Create connection + conn_ptr = chdb_connect(argc, const_cast(argv)); + if (!conn_ptr || !*conn_ptr) { + std::cout << "Failed to create chDB connection\n"; + return 1; + } + + conn = *conn_ptr; + std::cout << "✓ chDB connection established\n"; + + // Run test suites + test_arrow_scan(conn); + test_arrow_array_scan(conn); + + // Clean up + chdb_close_conn(conn_ptr); + + std::cout << "\n=== chDB Arrow Functions Test Completed ===\n"; + + return 0; +} diff --git a/examples/runArrowTest.sh b/examples/runArrowTest.sh new file mode 100755 index 00000000000..7f8d5e9bed0 --- /dev/null +++ b/examples/runArrowTest.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +set -e + +CXXFLAGS="-g -O0 -DDEBUG" + +# check current os type, and make ldd command +if [ "$(uname)" == "Darwin" ]; then + LDD="otool -L" + LIB_PATH="DYLD_LIBRARY_PATH" +elif [ "$(uname)" == "Linux" ]; then + LDD="ldd" + LIB_PATH="LD_LIBRARY_PATH" +else + echo "OS not supported" + exit 1 +fi + +# cd to the directory of this script +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +cd $DIR + +echo "Compile and link chdbArrowTest" +clang++ $CXXFLAGS chdbArrowTest.cpp -o chdbArrowTest \ + -I../programs/local/ \ + -I../contrib/arrow/cpp/src \ + -I../contrib/arrow-cmake/cpp/src \ + -I../src \ + -L../ -lchdb + +export ${LIB_PATH}=.. 
+${LDD} chdbArrowTest + +echo "Run Arrow API tests:" +./chdbArrowTest diff --git a/programs/local/ArrowScanState.h b/programs/local/ArrowScanState.h new file mode 100644 index 00000000000..98501498f57 --- /dev/null +++ b/programs/local/ArrowScanState.h @@ -0,0 +1,34 @@ +#pragma once + +#include "ArrowStreamWrapper.h" + +#include +#include + +namespace CHDB +{ + +/// Scan state for each stream - shared between ArrowTableReader and ArrowStreamReader +struct ArrowScanState +{ + /// Current Arrow array being processed (for ArrowTableReader) + std::unique_ptr current_array; + /// Current offset within the array + size_t current_offset = 0; + /// Whether this stream is exhausted + bool exhausted = false; + /// Cached imported RecordBatch to avoid repeated imports + std::shared_ptr cached_record_batch; + + virtual ~ArrowScanState() = default; + + virtual void reset() + { + current_array.reset(); + current_offset = 0; + exhausted = false; + cached_record_batch.reset(); + } +}; + +} // namespace CHDB diff --git a/programs/local/ArrowSchema.cpp b/programs/local/ArrowSchema.cpp index d0b3a391e61..8e0b3f21d0f 100644 --- a/programs/local/ArrowSchema.cpp +++ b/programs/local/ArrowSchema.cpp @@ -1,6 +1,20 @@ #include "ArrowSchema.h" -#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int BAD_ARGUMENTS; +} + +} + +using namespace DB; namespace CHDB { @@ -14,4 +28,62 @@ ArrowSchemaWrapper::~ArrowSchemaWrapper() } } +ArrowSchemaWrapper::ArrowSchemaWrapper(ArrowSchemaWrapper && other) noexcept + : arrow_schema(other.arrow_schema) +{ + other.arrow_schema.release = nullptr; +} + +ArrowSchemaWrapper & ArrowSchemaWrapper::operator=(ArrowSchemaWrapper && other) noexcept +{ + if (this != &other) + { + if (arrow_schema.release) + { + arrow_schema.release(&arrow_schema); + } + arrow_schema = other.arrow_schema; + other.arrow_schema.release = nullptr; + } + return *this; +} + +void ArrowSchemaWrapper::convertArrowSchema( + ArrowSchemaWrapper & schema, + NamesAndTypesList & names_and_types, + ContextPtr & context) +{ + if (!schema.arrow_schema.release) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "ArrowSchema is already released"); + } + + /// Import ArrowSchema to arrow::Schema + auto arrow_schema_result = arrow::ImportSchema(&schema.arrow_schema); + if (!arrow_schema_result.ok()) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "Failed to import Arrow schema: {}", arrow_schema_result.status().message()); + } + + const auto & arrow_schema = arrow_schema_result.ValueOrDie(); + + const auto format_settings = getFormatSettings(context); + + /// Convert Arrow schema to ClickHouse header + auto block = ArrowColumnToCHColumn::arrowSchemaToCHHeader( + *arrow_schema, + nullptr, + "Arrow", + format_settings.arrow.skip_columns_with_unsupported_types_in_schema_inference, + format_settings.schema_inference_make_columns_nullable != 0, + false, + format_settings.parquet.allow_geoparquet_parser); + + for (const auto & column : block) + { + names_and_types.emplace_back(column.name, column.type); + } +} + } // namespace CHDB diff --git a/programs/local/ArrowSchema.h b/programs/local/ArrowSchema.h index 8b0318381aa..5e6720386ae 100644 --- a/programs/local/ArrowSchema.h +++ b/programs/local/ArrowSchema.h @@ -1,21 +1,34 @@ #pragma once +#include +#include #include namespace CHDB { +/// Wrapper for Arrow C Data Interface structures with RAII resource management class ArrowSchemaWrapper { public: - ArrowSchema arrow_schema; + ArrowSchema arrow_schema; - ArrowSchemaWrapper() - { - 
arrow_schema.release = nullptr; - } + ArrowSchemaWrapper() { + arrow_schema.release = nullptr; + } - ~ArrowSchemaWrapper(); + ~ArrowSchemaWrapper(); + + /// Non-copyable but moveable + ArrowSchemaWrapper(const ArrowSchemaWrapper &) = delete; + ArrowSchemaWrapper & operator=(const ArrowSchemaWrapper &) = delete; + ArrowSchemaWrapper(ArrowSchemaWrapper && other) noexcept; + ArrowSchemaWrapper & operator=(ArrowSchemaWrapper && other) noexcept; + + static void convertArrowSchema( + ArrowSchemaWrapper & schema, + DB::NamesAndTypesList & names_and_types, + DB::ContextPtr & context); }; } // namespace CHDB diff --git a/programs/local/ArrowStreamRegistry.h b/programs/local/ArrowStreamRegistry.h new file mode 100644 index 00000000000..c1fb465999e --- /dev/null +++ b/programs/local/ArrowStreamRegistry.h @@ -0,0 +1,100 @@ +#pragma once + +#include "chdb-internal.h" + +#include +#include +#include +#include + +#include + +struct ArrowArrayStream; + +namespace CHDB +{ + +class ArrowStreamRegistry +{ +public: + struct ArrowStreamInfo + { + ArrowArrayStream * stream = nullptr; + bool is_owner = false; + }; + +private: + std::unordered_map registered_streams; + mutable std::shared_mutex registry_mutex; + +public: + static ArrowStreamRegistry & instance() + { + static ArrowStreamRegistry instance; + return instance; + } + + bool registerArrowStream(const String & name, ArrowArrayStream * arrow_stream, bool is_owner) + { + std::unique_lock lock(registry_mutex); + + ArrowStreamInfo info; + info.stream = arrow_stream; + info.is_owner = is_owner; + + auto [iter, inserted] = registered_streams.emplace(name, std::move(info)); + return inserted; + } + + std::optional getArrowStream(const String & name) const + { + std::shared_lock lock(registry_mutex); + auto it = registered_streams.find(name); + if (it != registered_streams.end()) + return it->second; + return {}; + } + + bool unregisterArrowStream(const String & name) + { + std::unique_lock lock(registry_mutex); + auto it = registered_streams.find(name); + if (it != registered_streams.end()) + { + if (it->second.is_owner && it->second.stream) + { + /// Clean up owned Arrow stream + chdb_destroy_arrow_stream(it->second.stream); + } + registered_streams.erase(it); + return true; + } + return false; + } + + std::vector listRegisteredNames() const + { + std::shared_lock lock(registry_mutex); + std::vector names; + names.reserve(registered_streams.size()); + + for (const auto& [name, info] : registered_streams) + names.push_back(name); + + return names; + } + + size_t size() const + { + std::shared_lock lock(registry_mutex); + return registered_streams.size(); + } + + void clear() + { + std::unique_lock lock(registry_mutex); + registered_streams.clear(); + } +}; + +} diff --git a/programs/local/ArrowStreamSource.cpp b/programs/local/ArrowStreamSource.cpp new file mode 100644 index 00000000000..c9545f8205c --- /dev/null +++ b/programs/local/ArrowStreamSource.cpp @@ -0,0 +1,51 @@ +#include "ArrowStreamSource.h" +#include + +#include + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int BAD_ARGUMENTS; +} + +ArrowStreamSource::ArrowStreamSource( + const Block & sample_block_, + CHDB::ArrowTableReaderPtr arrow_table_reader_, + size_t stream_index_) + : ISource(sample_block_.cloneEmpty()) + , arrow_table_reader(arrow_table_reader_) + , sample_block(sample_block_) + , stream_index(stream_index_) +{ +} + +Chunk ArrowStreamSource::generate() +{ + chassert(arrow_table_reader); + + if (sample_block.getNames().empty()) + return {}; + + try + { + auto chunk = 
arrow_table_reader->readNextChunk(stream_index); + return chunk; + } + catch (const Exception &) + { + throw; + } + catch (const std::exception & e) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "ArrowStreamSource error: {}", e.what()); + } + catch (...) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "ArrowStreamSource unknown exception"); + } +} + +} diff --git a/programs/local/ArrowStreamSource.h b/programs/local/ArrowStreamSource.h new file mode 100644 index 00000000000..2d5055f394f --- /dev/null +++ b/programs/local/ArrowStreamSource.h @@ -0,0 +1,31 @@ +#pragma once + +#include "ArrowTableReader.h" + +#include +#include +#include + +namespace DB +{ + +class ArrowStreamSource : public ISource +{ +public: + ArrowStreamSource( + const Block & sample_block_, + CHDB::ArrowTableReaderPtr arrow_table_reader_, + size_t stream_index_); + + String getName() const override { return "ArrowStream"; } + + Chunk generate() override; + +private: + CHDB::ArrowTableReaderPtr arrow_table_reader; + Block sample_block; + size_t stream_index; + Poco::Logger * logger = &Poco::Logger::get("ArrowStreamSource"); +}; + +} diff --git a/programs/local/ArrowStreamWrapper.cpp b/programs/local/ArrowStreamWrapper.cpp index 7d298ea1b3d..a1c834a49b1 100644 --- a/programs/local/ArrowStreamWrapper.cpp +++ b/programs/local/ArrowStreamWrapper.cpp @@ -1,58 +1,23 @@ #include "ArrowStreamWrapper.h" -#include "PyArrowTable.h" -#include "PybindWrapper.h" -#include "PythonImporter.h" #include #include -#include -#include namespace DB { namespace ErrorCodes { -extern const int PY_EXCEPTION_OCCURED; +extern const int BAD_ARGUMENTS; } } -namespace py = pybind11; using namespace DB; namespace CHDB { -/// ArrowSchemaWrapper implementation -ArrowSchemaWrapper::~ArrowSchemaWrapper() -{ - if (arrow_schema.release) - { - arrow_schema.release(&arrow_schema); - } -} - -ArrowSchemaWrapper::ArrowSchemaWrapper(ArrowSchemaWrapper && other) noexcept - : arrow_schema(other.arrow_schema) -{ - other.arrow_schema.release = nullptr; -} - -ArrowSchemaWrapper & ArrowSchemaWrapper::operator=(ArrowSchemaWrapper && other) noexcept -{ - if (this != &other) - { - if (arrow_schema.release) - { - arrow_schema.release(&arrow_schema); - } - arrow_schema = other.arrow_schema; - other.arrow_schema.release = nullptr; - } - return *this; -} - /// ArrowArrayWrapper implementation ArrowArrayWrapper::~ArrowArrayWrapper() { @@ -85,7 +50,7 @@ ArrowArrayWrapper & ArrowArrayWrapper::operator=(ArrowArrayWrapper && other) noe /// ArrowArrayStreamWrapper implementation ArrowArrayStreamWrapper::~ArrowArrayStreamWrapper() { - if (arrow_array_stream.release) + if (should_release_on_destroy && arrow_array_stream.release) { arrow_array_stream.release(&arrow_array_stream); } @@ -93,20 +58,24 @@ ArrowArrayStreamWrapper::~ArrowArrayStreamWrapper() ArrowArrayStreamWrapper::ArrowArrayStreamWrapper(ArrowArrayStreamWrapper&& other) noexcept : arrow_array_stream(other.arrow_array_stream) + , should_release_on_destroy(other.should_release_on_destroy) { other.arrow_array_stream.release = nullptr; + other.should_release_on_destroy = true; } ArrowArrayStreamWrapper & ArrowArrayStreamWrapper::operator=(ArrowArrayStreamWrapper && other) noexcept { if (this != &other) { - if (arrow_array_stream.release) + if (should_release_on_destroy && arrow_array_stream.release) { arrow_array_stream.release(&arrow_array_stream); } arrow_array_stream = other.arrow_array_stream; + should_release_on_destroy = other.should_release_on_destroy; other.arrow_array_stream.release = nullptr; + 
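+        /// The moved-from wrapper no longer owns a stream (its release pointer was cleared above),
+        /// so its ownership flag can safely return to the default.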
other.should_release_on_destroy = true; } return *this; } @@ -115,18 +84,18 @@ void ArrowArrayStreamWrapper::getSchema(ArrowSchemaWrapper& schema) { if (!isValid()) { - throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, "ArrowArrayStream is not valid"); + throw Exception(ErrorCodes::BAD_ARGUMENTS, "ArrowArrayStream is not valid"); } if (arrow_array_stream.get_schema(&arrow_array_stream, &schema.arrow_schema) != 0) { - throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Failed to get schema from ArrowArrayStream: {}", getError()); } if (!schema.arrow_schema.release) { - throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, "Released schema returned from ArrowArrayStream"); + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Released schema returned from ArrowArrayStream"); } } @@ -142,7 +111,7 @@ std::unique_ptr ArrowArrayStreamWrapper::getNextChunk() chunk->reset(); if (arrow_array_stream.get_next(&arrow_array_stream, &chunk->arrow_array) != 0) { - throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Failed to get next chunk from ArrowArrayStream: {}", getError()); } @@ -167,89 +136,4 @@ const char* ArrowArrayStreamWrapper::getError() return arrow_array_stream.get_last_error(&arrow_array_stream); } -std::unique_ptr PyArrowStreamFactory::createFromPyObject( - py::object & py_obj, - const Names & column_names) -{ - py::gil_scoped_acquire acquire; - - try - { - auto arrow_object_type = PyArrowTable::getArrowType(py_obj); - - switch (arrow_object_type) - { - case PyArrowObjectType::Table: - return createFromTable(py_obj, column_names); - default: - throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, - "Unsupported PyArrow object type: {}", arrow_object_type); - } - } - catch (const py::error_already_set & e) - { - throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, - "Failed to convert PyArrow object to arrow array stream: {}", e.what()); - } -} - -std::unique_ptr PyArrowStreamFactory::createFromTable( - py::object & table, - const Names & column_names) -{ - chassert(py::gil_check()); - - py::handle table_handle(table); - auto & import_cache = PythonImporter::ImportCache(); - auto arrow_dataset = import_cache.pyarrow.dataset().attr("dataset"); - - auto dataset = arrow_dataset(table_handle); - py::object arrow_scanner = dataset.attr("__class__").attr("scanner"); - - py::dict kwargs; - if (!column_names.empty()) { - ArrowSchemaWrapper schema; - auto obj_schema = table_handle.attr("schema"); - auto export_to_c = obj_schema.attr("_export_to_c"); - export_to_c(reinterpret_cast(&schema.arrow_schema)); - - /// Get available column names from schema - std::unordered_set available_columns; - if (schema.arrow_schema.n_children > 0 && schema.arrow_schema.children) - { - for (int64_t i = 0; i < schema.arrow_schema.n_children; ++i) - { - if (schema.arrow_schema.children[i] && schema.arrow_schema.children[i]->name) - { - available_columns.insert(schema.arrow_schema.children[i]->name); - } - } - } - - /// Only add column names that exist in the schema - py::list projection_list; - for (const auto & name : column_names) - { - if (available_columns.contains(name)) - { - projection_list.append(name); - } - } - - /// Only set columns if we have valid projections - if (projection_list.size() > 0) - { - kwargs["columns"] = projection_list; - } - } - - auto scanner = arrow_scanner(dataset, **kwargs); - - auto record_batches = scanner.attr("to_reader")(); - auto res = std::make_unique(); - auto export_to_c = record_batches.attr("_export_to_c"); 
- export_to_c(reinterpret_cast(&res->arrow_array_stream)); - return res; -} - } // namespace CHDB diff --git a/programs/local/ArrowStreamWrapper.h b/programs/local/ArrowStreamWrapper.h index 51a646c497b..0eb5c229d51 100644 --- a/programs/local/ArrowStreamWrapper.h +++ b/programs/local/ArrowStreamWrapper.h @@ -1,32 +1,13 @@ #pragma once +#include "ArrowSchema.h" + #include #include -#include -#include namespace CHDB { -/// Wrapper for Arrow C Data Interface structures with RAII resource management -class ArrowSchemaWrapper -{ -public: - ArrowSchema arrow_schema; - - ArrowSchemaWrapper() { - arrow_schema.release = nullptr; - } - - ~ArrowSchemaWrapper(); - - /// Non-copyable but moveable - ArrowSchemaWrapper(const ArrowSchemaWrapper &) = delete; - ArrowSchemaWrapper & operator=(const ArrowSchemaWrapper &) = delete; - ArrowSchemaWrapper(ArrowSchemaWrapper && other) noexcept; - ArrowSchemaWrapper & operator=(ArrowSchemaWrapper && other) noexcept; -}; - class ArrowArrayWrapper { public: @@ -57,20 +38,21 @@ class ArrowArrayStreamWrapper public: ArrowArrayStream arrow_array_stream; - ArrowArrayStreamWrapper() { + explicit ArrowArrayStreamWrapper(bool should_release = true) + : should_release_on_destroy(should_release) { arrow_array_stream.release = nullptr; } ~ArrowArrayStreamWrapper(); - // Non-copyable but moveable + /// Non-copyable but moveable ArrowArrayStreamWrapper(const ArrowArrayStreamWrapper&) = delete; ArrowArrayStreamWrapper& operator=(const ArrowArrayStreamWrapper&) = delete; ArrowArrayStreamWrapper(ArrowArrayStreamWrapper&& other) noexcept; ArrowArrayStreamWrapper& operator=(ArrowArrayStreamWrapper&& other) noexcept; /// Get schema from the stream - void getSchema(ArrowSchemaWrapper& schema); + void getSchema(ArrowSchemaWrapper & schema); /// Get next chunk from the stream std::unique_ptr getNextChunk(); @@ -80,20 +62,15 @@ class ArrowArrayStreamWrapper /// Check if stream is valid bool isValid() const { return arrow_array_stream.release != nullptr; } -}; -/// Factory class for creating ArrowArrayStream from Python objects -class PyArrowStreamFactory -{ -public: - static std::unique_ptr createFromPyObject( - pybind11::object & py_obj, - const DB::Names & column_names); + /// Set whether to release on destruction + void setShouldRelease(bool should_release) { should_release_on_destroy = should_release; } + + /// Get whether will release on destruction + bool getShouldRelease() const { return should_release_on_destroy; } private: - static std::unique_ptr createFromTable( - pybind11::object & table, - const DB::Names & column_names); + bool should_release_on_destroy = true; }; } // namespace CHDB diff --git a/programs/local/ArrowTableReader.cpp b/programs/local/ArrowTableReader.cpp index 822530706c3..21fd7efa1af 100644 --- a/programs/local/ArrowTableReader.cpp +++ b/programs/local/ArrowTableReader.cpp @@ -4,57 +4,44 @@ #include #include #include -#include -#include namespace DB { namespace ErrorCodes { -extern const int PY_EXCEPTION_OCCURED; +extern const int BAD_ARGUMENTS; } } -namespace py = pybind11; using namespace DB; namespace CHDB { ArrowTableReader::ArrowTableReader( - py::object & data_source_, + std::unique_ptr arrow_stream_, const DB::Block & sample_block_, const DB::FormatSettings & format_settings_, size_t num_streams_, size_t max_block_size_) : sample_block(sample_block_), format_settings(format_settings_), + arrow_stream(std::move(arrow_stream_)), num_streams(num_streams_), max_block_size(max_block_size_), scan_states(num_streams_) { - initializeStream(data_source_); + 
initializeStream(); } -void ArrowTableReader::initializeStream(py::object & data_source_) +void ArrowTableReader::initializeStream() { - try - { - /// Create Arrow stream from Python object - arrow_stream = PyArrowStreamFactory::createFromPyObject(data_source_, sample_block.getNames()); - - if (!arrow_stream || !arrow_stream->isValid()) - { - throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, - "Failed to create valid ArrowArrayStream from Python object"); - } - } - catch (const py::error_already_set & e) + if (!arrow_stream || !arrow_stream->isValid()) { - throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, - "Failed to initialize Arrow stream from Python object: {}", e.what()); + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "ArrowArrayStream is not valid"); } /// Get schema from stream @@ -62,7 +49,7 @@ void ArrowTableReader::initializeStream(py::object & data_source_) auto arrow_schema_result = arrow::ImportSchema(&schema.arrow_schema); if (!arrow_schema_result.ok()) { - throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Failed to import Arrow schema during initialization: {}", arrow_schema_result.status().message()); } cached_arrow_schema = arrow_schema_result.ValueOrDie(); @@ -72,7 +59,7 @@ Chunk ArrowTableReader::readNextChunk(size_t stream_index) { if (stream_index >= num_streams) { - throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Stream index {} is out of range [0, {})", stream_index, num_streams); } @@ -148,8 +135,8 @@ std::unique_ptr ArrowTableReader::getNextArrowArray() Chunk ArrowTableReader::convertArrowArrayToChunk(const ArrowArrayWrapper & arrow_array_wrapper, size_t offset, size_t count, size_t stream_index) { - chassert(arrow_array_wrapper.arrow_array.length && count && offset < arrow_array_wrapper.arrow_array.length); - chassert(count <= arrow_array_wrapper.arrow_array.length - offset); + chassert(arrow_array_wrapper.arrow_array.length && count && offset < static_cast(arrow_array_wrapper.arrow_array.length)); + chassert(count <= static_cast(arrow_array_wrapper.arrow_array.length) - offset); chassert(stream_index < num_streams); auto & state = scan_states[stream_index]; @@ -174,7 +161,7 @@ Chunk ArrowTableReader::convertArrowArrayToChunk(const ArrowArrayWrapper & arrow auto arrow_batch_result = arrow::ImportRecordBatch(&array_copy, cached_arrow_schema); if (!arrow_batch_result.ok()) { - throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Failed to import Arrow RecordBatch: {}", arrow_batch_result.status().message()); } diff --git a/programs/local/ArrowTableReader.h b/programs/local/ArrowTableReader.h index 5cba71745d5..c2a29fc82c4 100644 --- a/programs/local/ArrowTableReader.h +++ b/programs/local/ArrowTableReader.h @@ -1,37 +1,16 @@ #pragma once +#include "ArrowScanState.h" #include "ArrowStreamWrapper.h" #include #include #include -#include #include namespace CHDB { -/// Scan state for each stream -struct ArrowScanState -{ - /// Current Arrow array being processed - std::unique_ptr current_array; - /// Current offset within the array - size_t current_offset = 0; - /// Whether this stream is exhausted - bool exhausted = false; - /// Cached imported RecordBatch to avoid repeated imports - std::shared_ptr cached_record_batch; - - void reset() - { - current_array.reset(); - current_offset = 0; - exhausted = false; - cached_record_batch.reset(); - } -}; - class ArrowTableReader; using ArrowTableReaderPtr = std::shared_ptr; @@ -39,7 +18,7 @@ 
class ArrowTableReader { public: ArrowTableReader( - pybind11::object & data_source_, + std::unique_ptr arrow_stream_, const DB::Block & sample_block_, const DB::FormatSettings & format_settings_, size_t num_streams_, @@ -51,8 +30,8 @@ class ArrowTableReader DB::Chunk readNextChunk(size_t stream_index); private: - /// Initialize the Arrow stream from Python object - void initializeStream(pybind11::object & data_source_); + /// Initialize the Arrow stream from ArrowArrayStreamWrapper + void initializeStream(); /// Convert Arrow array slice to ClickHouse chunk DB::Chunk convertArrowArrayToChunk(const ArrowArrayWrapper & arrow_array, size_t offset, size_t count, size_t stream_index); @@ -77,7 +56,6 @@ class ArrowTableReader /// Global stream state bool global_stream_exhausted = false; - size_t total_rows_hint = 0; /// Mutex for thread-safe access to arrow_stream mutable std::mutex stream_mutex; diff --git a/programs/local/CMakeLists.txt b/programs/local/CMakeLists.txt index 9e32398c67f..f84770e6392 100644 --- a/programs/local/CMakeLists.txt +++ b/programs/local/CMakeLists.txt @@ -1,19 +1,17 @@ set (CLICKHOUSE_LOCAL_SOURCES chdb.cpp + ArrowSchema.cpp + ArrowStreamWrapper.cpp + ArrowTableReader.cpp LocalServer.cpp ) -# Add ArrowStream table function sources when not using Python if (NOT USE_PYTHON) set (CHDB_ARROW_SOURCES chdb-arrow.cpp - ArrowStreamRegistry.h ArrowStreamSource.cpp - ArrowStreamSource.h StorageArrowStream.cpp - StorageArrowStream.h TableFunctionArrowStream.cpp - TableFunctionArrowStream.h ) set (CLICKHOUSE_LOCAL_SOURCES ${CLICKHOUSE_LOCAL_SOURCES} ${CHDB_ARROW_SOURCES}) endif() @@ -27,9 +25,6 @@ endif() if (USE_PYTHON) set (CHDB_SOURCES chdb.cpp - ArrowSchema.cpp - ArrowStreamWrapper.cpp - ArrowTableReader.cpp FormatHelper.cpp ListScan.cpp LocalChdb.cpp @@ -38,6 +33,7 @@ if (USE_PYTHON) PandasAnalyzer.cpp PandasDataFrame.cpp PandasScan.cpp + PyArrowStreamFactory.cpp PyArrowTable.cpp PybindWrapper.cpp PythonConversion.cpp diff --git a/programs/local/LocalServer.cpp b/programs/local/LocalServer.cpp index 319c83df28a..c7b6026fea8 100644 --- a/programs/local/LocalServer.cpp +++ b/programs/local/LocalServer.cpp @@ -2,12 +2,13 @@ #include "chdb-internal.h" #if USE_PYTHON +#include "StoragePython.h" #include "TableFunctionPython.h" -#include -#include #else +#include "StorageArrowStream.h" #include "TableFunctionArrowStream.h" #endif +#include #include #include @@ -658,9 +659,11 @@ try registerDatabases(); registerStorages(); -#if USE_PYTHON auto & storage_factory = StorageFactory::instance(); +#if USE_PYTHON registerStoragePython(storage_factory); +#else + registerStorageArrowStream(storage_factory); #endif registerDictionaries(); diff --git a/programs/local/PyArrowStreamFactory.cpp b/programs/local/PyArrowStreamFactory.cpp new file mode 100644 index 00000000000..0272985194e --- /dev/null +++ b/programs/local/PyArrowStreamFactory.cpp @@ -0,0 +1,113 @@ +#include "PyArrowStreamFactory.h" +#include "PyArrowTable.h" +#include "PythonImporter.h" + +#include +#include + +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int PY_EXCEPTION_OCCURED; +} + +} + +using namespace DB; +namespace py = pybind11; + +namespace CHDB +{ + +std::unique_ptr PyArrowStreamFactory::createFromPyObject( + py::object & py_obj, + const Names & column_names) +{ + chassert(py::gil_check()); + + try + { + auto arrow_object_type = PyArrowTable::getArrowType(py_obj); + + switch (arrow_object_type) + { + case PyArrowObjectType::Table: + return createFromTable(py_obj, 
column_names); + default: + throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + "Unsupported PyArrow object type: {}", arrow_object_type); + } + } + catch (const py::error_already_set & e) + { + throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + "Failed to convert PyArrow object to arrow array stream: {}", e.what()); + } +} + +std::unique_ptr PyArrowStreamFactory::createFromTable( + py::object & table, + const Names & column_names) +{ + chassert(py::gil_check()); + + py::handle table_handle(table); + auto & import_cache = PythonImporter::ImportCache(); + auto arrow_dataset = import_cache.pyarrow.dataset().attr("dataset"); + + auto dataset = arrow_dataset(table_handle); + py::object arrow_scanner = dataset.attr("__class__").attr("scanner"); + + py::dict kwargs; + if (!column_names.empty()) { + ArrowSchemaWrapper schema; + auto obj_schema = table_handle.attr("schema"); + auto export_to_c = obj_schema.attr("_export_to_c"); + export_to_c(reinterpret_cast(&schema.arrow_schema)); + + /// Get available column names from schema + std::unordered_set available_columns; + if (schema.arrow_schema.n_children > 0 && schema.arrow_schema.children) + { + for (int64_t i = 0; i < schema.arrow_schema.n_children; ++i) + { + if (schema.arrow_schema.children[i] && schema.arrow_schema.children[i]->name) + { + available_columns.insert(schema.arrow_schema.children[i]->name); + } + } + } + + /// Only add column names that exist in the schema + py::list projection_list; + for (const auto & name : column_names) + { + if (available_columns.contains(name)) + { + projection_list.append(name); + } + } + + /// Only set columns if we have valid projections + if (projection_list.size() > 0) + { + kwargs["columns"] = projection_list; + } + } + + auto scanner = arrow_scanner(dataset, **kwargs); + + auto record_batches = scanner.attr("to_reader")(); + auto res = std::make_unique(); + auto export_to_c = record_batches.attr("_export_to_c"); + export_to_c(reinterpret_cast(&res->arrow_array_stream)); + return res; +} + +} // namespace CHDB diff --git a/programs/local/PyArrowStreamFactory.h b/programs/local/PyArrowStreamFactory.h new file mode 100644 index 00000000000..4c480d1d113 --- /dev/null +++ b/programs/local/PyArrowStreamFactory.h @@ -0,0 +1,25 @@ +#pragma once + +#include "ArrowStreamWrapper.h" + +#include +#include + +namespace CHDB +{ + +/// Factory class for creating ArrowArrayStream from Python objects +class PyArrowStreamFactory +{ +public: + static std::unique_ptr createFromPyObject( + pybind11::object & py_obj, + const DB::Names & column_names); + +private: + static std::unique_ptr createFromTable( + pybind11::object & table, + const DB::Names & column_names); +}; + +} // namespace CHDB diff --git a/programs/local/PyArrowTable.cpp b/programs/local/PyArrowTable.cpp index 83bdc99b761..5eef4b43ee4 100644 --- a/programs/local/PyArrowTable.cpp +++ b/programs/local/PyArrowTable.cpp @@ -3,70 +3,13 @@ #include "PyArrowCacheItem.h" #include "PythonImporter.h" -#include #include -#include -#include -#include -#include -#include -#include -#include - -namespace DB -{ - -namespace ErrorCodes -{ -extern const int BAD_ARGUMENTS; -extern const int PY_EXCEPTION_OCCURED; -} - -} using namespace DB; namespace CHDB { -static void convertArrowSchema( - ArrowSchemaWrapper & schema, - NamesAndTypesList & names_and_types, - ContextPtr & context) -{ - if (!schema.arrow_schema.release) - { - throw Exception(ErrorCodes::BAD_ARGUMENTS, "ArrowSchema is already released"); - } - - /// Import ArrowSchema to arrow::Schema - auto 
arrow_schema_result = arrow::ImportSchema(&schema.arrow_schema); - if (!arrow_schema_result.ok()) - { - throw Exception(ErrorCodes::BAD_ARGUMENTS, - "Failed to import Arrow schema: {}", arrow_schema_result.status().message()); - } - - const auto & arrow_schema = arrow_schema_result.ValueOrDie(); - - const auto format_settings = getFormatSettings(context); - - /// Convert Arrow schema to ClickHouse header - auto block = ArrowColumnToCHColumn::arrowSchemaToCHHeader( - *arrow_schema, - nullptr, - "Arrow", - format_settings.arrow.skip_columns_with_unsupported_types_in_schema_inference, - format_settings.schema_inference_make_columns_nullable != 0, - false, - format_settings.parquet.allow_geoparquet_parser); - - for (const auto & column : block) - { - names_and_types.emplace_back(column.name, column.type); - } -} - PyArrowObjectType PyArrowTable::getArrowType(const py::object & obj) { chassert(py::gil_check()); @@ -107,7 +50,7 @@ ColumnsDescription PyArrowTable::getActualTableStructure(const py::object & obje ArrowSchemaWrapper schema; export_to_c(reinterpret_cast(&schema.arrow_schema)); - convertArrowSchema(schema, names_and_types, context); + ArrowSchemaWrapper::convertArrowSchema(schema, names_and_types, context); return ColumnsDescription(names_and_types); } diff --git a/programs/local/StorageArrowStream.cpp b/programs/local/StorageArrowStream.cpp new file mode 100644 index 00000000000..403875886f1 --- /dev/null +++ b/programs/local/StorageArrowStream.cpp @@ -0,0 +1,97 @@ +#include "StorageArrowStream.h" +#include "ArrowStreamSource.h" +#include "ArrowStreamWrapper.h" +#include "ArrowTableReader.h" + +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} + +void registerStorageArrowStream(StorageFactory & factory) +{ + factory.registerStorage( + "ArrowStream", + [](const StorageFactory::Arguments & args) -> StoragePtr + { + if (args.engine_args.size() != 1) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "ArrowStream engine requires 1 argument: ArrowStreamInfo object"); + + CHDB::ArrowStreamRegistry::ArrowStreamInfo stream_info = std::any_cast(args.engine_args[0]); + return std::make_shared(args.table_id, stream_info, args.columns, args.getLocalContext()); + }, + { + .supports_settings = false, + .supports_parallel_insert = false, + }); +} + +StorageArrowStream::StorageArrowStream( + const StorageID & storage_id_, + const CHDB::ArrowStreamRegistry::ArrowStreamInfo & stream_info_, + const ColumnsDescription & columns_, + ContextPtr context_) + : IStorage(storage_id_) + , WithContext(context_) + , stream_info(stream_info_) +{ + StorageInMemoryMetadata storage_metadata; + storage_metadata.setColumns(columns_); + setInMemoryMetadata(storage_metadata); +} + +Pipe StorageArrowStream::read( + const Names & column_names, + const StorageSnapshotPtr & storage_snapshot, + SelectQueryInfo & /*query_info*/, + ContextPtr /*context*/, + QueryProcessingStage::Enum /*processed_stage*/, + size_t max_block_size, + size_t num_streams) +{ + chassert(stream_info.stream); + storage_snapshot->check(column_names); + + Block sample_block = prepareSampleBlock(column_names, storage_snapshot); + auto format_settings = getFormatSettings(getContext()); + + /// Create ArrowArrayStreamWrapper from the registered stream + auto arrow_stream_wrapper = std::make_unique(false); + arrow_stream_wrapper->arrow_array_stream = *stream_info.stream; + + auto arrow_table_reader = std::make_shared( + std::move(arrow_stream_wrapper), + 
sample_block, + format_settings, + num_streams, + max_block_size + ); + + Pipes pipes; + for (size_t stream = 0; stream < num_streams; ++stream) + { + pipes.emplace_back(std::make_shared( + sample_block, arrow_table_reader, stream)); + } + return Pipe::unitePipes(std::move(pipes)); +} + +Block StorageArrowStream::prepareSampleBlock(const Names & column_names, const StorageSnapshotPtr & storage_snapshot) +{ + Block sample_block; + for (const String & column_name : column_names) + { + auto column_data = storage_snapshot->metadata->getColumns().getPhysical(column_name); + sample_block.insert({column_data.type, column_data.name}); + } + return sample_block; +} + +} diff --git a/programs/local/StorageArrowStream.h b/programs/local/StorageArrowStream.h new file mode 100644 index 00000000000..46d0262680c --- /dev/null +++ b/programs/local/StorageArrowStream.h @@ -0,0 +1,44 @@ +#pragma once + +#include "ArrowStreamRegistry.h" + +#include +#include +#include +#include + +namespace DB +{ + +void registerStorageArrowStream(StorageFactory & factory); + +class StorageArrowStream : public IStorage, public WithContext +{ +public: + StorageArrowStream( + const StorageID & storage_id_, + const CHDB::ArrowStreamRegistry::ArrowStreamInfo & stream_info_, + const ColumnsDescription & columns_, + ContextPtr context_); + + ~StorageArrowStream() override = default; + + std::string getName() const override { return "ArrowStream"; } + + Pipe read( + const Names & column_names, + const StorageSnapshotPtr & storage_snapshot, + SelectQueryInfo & query_info, + ContextPtr context, + QueryProcessingStage::Enum processed_stage, + size_t max_block_size, + size_t num_streams) override; + + Block prepareSampleBlock(const Names & column_names, const StorageSnapshotPtr & storage_snapshot); + +private: + CHDB::ArrowStreamRegistry::ArrowStreamInfo stream_info; + Poco::Logger * logger = &Poco::Logger::get("StorageArrowStream"); +}; + +} diff --git a/programs/local/StoragePython.cpp b/programs/local/StoragePython.cpp index a60f52dc555..8f3b4f8002f 100644 --- a/programs/local/StoragePython.cpp +++ b/programs/local/StoragePython.cpp @@ -3,6 +3,7 @@ #include "PybindWrapper.h" #include "PythonSource.h" #include "PyArrowTable.h" +#include "PyArrowStreamFactory.h" #include #include @@ -81,18 +82,21 @@ Pipe StoragePython::read( std::make_shared(data_source, true, sample_block, column_cache, data_source_row_count, max_block_size, 0, 1, format_settings)); } - prepareColumnCache(column_names, sample_block.getColumns(), sample_block); - ArrowTableReaderPtr arrow_table_reader; { py::gil_scoped_acquire acquire; if (PyArrowTable::isPyArrowTable(data_source)) { - arrow_table_reader = std::make_shared(data_source, sample_block, + auto arrow_stream = PyArrowStreamFactory::createFromPyObject(data_source, sample_block.getNames()); + arrow_table_reader = std::make_shared( + std::move(arrow_stream), sample_block, format_settings, num_streams, max_block_size); } } + if (!arrow_table_reader) + prepareColumnCache(column_names, sample_block.getColumns(), sample_block); + Pipes pipes; for (size_t stream = 0; stream < num_streams; ++stream) pipes.emplace_back(std::make_shared( diff --git a/programs/local/TableFunctionArrowStream.cpp b/programs/local/TableFunctionArrowStream.cpp new file mode 100644 index 00000000000..a59176c71c0 --- /dev/null +++ b/programs/local/TableFunctionArrowStream.cpp @@ -0,0 +1,127 @@ +#include "TableFunctionArrowStream.h" +#include "ArrowSchema.h" +#include "ArrowStreamWrapper.h" +#include "StorageArrowStream.h" + +#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+    extern const int UNKNOWN_IDENTIFIER;
+    extern const int BAD_ARGUMENTS;
+}
+
+void TableFunctionArrowStream::parseArguments(const ASTPtr & ast_function, ContextPtr context)
+{
+    const auto & func_args = ast_function->as<ASTFunction &>();
+
+    if (!func_args.arguments)
+        throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
+                        "Table function 'arrowstream' must have arguments.");
+
+    ASTs & args = func_args.arguments->children;
+
+    if (args.size() != 1)
+        throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
+                        "Table function 'arrowstream' requires exactly 1 argument: the stream name");
+
+    auto stream_name_arg = evaluateConstantExpressionOrIdentifierAsLiteral(args[0], context);
+
+    try
+    {
+        stream_name = stream_name_arg->as<ASTLiteral &>().value.safeGet<String>();
+
+        stream_name.erase(
+            std::remove_if(stream_name.begin(), stream_name.end(),
+                           [](char c) { return c == '\'' || c == '\"' || c == '`'; }),
+            stream_name.end());
+
+        auto stream_opt = CHDB::ArrowStreamRegistry::instance().getArrowStream(stream_name);
+        if (!stream_opt)
+        {
+            throw Exception(ErrorCodes::UNKNOWN_IDENTIFIER,
+                            "ArrowStream '{}' not found in registry. "
+                            "Please register it first using chdb_arrow_scan.",
+                            stream_name);
+        }
+
+        stream_info = *stream_opt;
+    }
+    catch (const Exception &)
+    {
+        throw;
+    }
+    catch (const std::exception & e)
+    {
+        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Error parsing arrowstream argument: {}", e.what());
+    }
+    catch (...)
+    {
+        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Error parsing arrowstream argument");
+    }
+}
+
+StoragePtr TableFunctionArrowStream::executeImpl(
+    const ASTPtr & /*ast_function*/,
+    ContextPtr context,
+    const String & table_name,
+    ColumnsDescription /*cached_columns*/,
+    bool is_insert_query) const
+{
+    if (stream_name.empty() || !stream_info.stream)
+        throw Exception(ErrorCodes::BAD_ARGUMENTS, "ArrowStream name not initialized");
+
+    auto columns = getActualTableStructure(context, is_insert_query);
+
+    auto storage = std::make_shared<StorageArrowStream>(
+        StorageID(getDatabaseName(), table_name),
+        stream_info,
+        columns,
+        context);
+
+    storage->startup();
+    return storage;
+}
+
+ColumnsDescription TableFunctionArrowStream::getActualTableStructure(
+    ContextPtr context, bool /*is_insert_query*/) const
+{
+    auto * arrow_stream = reinterpret_cast<ArrowArrayStream *>(stream_info.stream);
+    CHDB::ArrowSchemaWrapper schema;
+
+    if (arrow_stream->get_schema(arrow_stream, &schema.arrow_schema) != 0)
+    {
+        throw Exception(ErrorCodes::BAD_ARGUMENTS,
+                        "Failed to get schema from ArrowStream '{}'", stream_name);
+    }
+
+    NamesAndTypesList names_and_types;
+    CHDB::ArrowSchemaWrapper::convertArrowSchema(schema, names_and_types, context);
+
+    return ColumnsDescription(names_and_types);
+}
+
+void registerTableFunctionArrowStream(TableFunctionFactory & factory)
+{
+    factory.registerFunction<TableFunctionArrowStream>(
+        {.documentation = {
+            .description = R"(
+Creates a table from a registered ArrowStream.
+This table function requires a single argument which is the name of a registered ArrowStream.
+Use chdb_arrow_scan() to register ArrowStreams first.
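+Arrow streams are single-pass: once a stream has been consumed by a query, it
+must be reset or registered again before the same name can be queried.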
+)", + .examples = {{"arrowstream", "SELECT * FROM arrowstream('my_data')", ""}}, + .category = FunctionDocumentation::Category::TableFunction + }}, + TableFunctionFactory::Case::Insensitive); +} + +} diff --git a/programs/local/TableFunctionArrowStream.h b/programs/local/TableFunctionArrowStream.h new file mode 100644 index 00000000000..761830e3633 --- /dev/null +++ b/programs/local/TableFunctionArrowStream.h @@ -0,0 +1,41 @@ +#pragma once + +#include "ArrowStreamRegistry.h" + +#include +#include +#include + +namespace DB +{ + +class TableFunctionFactory; +void registerTableFunctionArrowStream(TableFunctionFactory & factory); + +class TableFunctionArrowStream : public ITableFunction +{ +public: + static constexpr auto name = "arrowstream"; + std::string getName() const override { return name; } + +private: + Poco::Logger * logger = &Poco::Logger::get("TableFunctionArrowStream"); + + StoragePtr executeImpl( + const ASTPtr & ast_function, + ContextPtr context, + const std::string & table_name, + ColumnsDescription cached_columns, + bool is_insert_query) const override; + + const char * getStorageTypeName() const override { return "ArrowStream"; } + + void parseArguments(const ASTPtr & ast_function, ContextPtr context) override; + + ColumnsDescription getActualTableStructure(ContextPtr context, bool is_insert_query) const override; + + String stream_name; + CHDB::ArrowStreamRegistry::ArrowStreamInfo stream_info; +}; + +} diff --git a/programs/local/TableFunctionPython.cpp b/programs/local/TableFunctionPython.cpp index 5967bebf381..ecb4ca09ba8 100644 --- a/programs/local/TableFunctionPython.cpp +++ b/programs/local/TableFunctionPython.cpp @@ -1,3 +1,4 @@ +#include "TableFunctionPython.h" #include "StoragePython.h" #include "PandasDataFrame.h" #include "PyArrowTable.h" @@ -5,10 +6,8 @@ #include "PythonReader.h" #include "PythonTableCache.h" #include "PythonUtils.h" -#include "TableFunctionPython.h" #include - #include #include #include diff --git a/programs/local/TableFunctionPython.h b/programs/local/TableFunctionPython.h index 067a6fa4601..ffea035e24b 100644 --- a/programs/local/TableFunctionPython.h +++ b/programs/local/TableFunctionPython.h @@ -1,9 +1,7 @@ #pragma once -#include "StoragePython.h" #include "PybindWrapper.h" -#include "config.h" #include #include #include @@ -22,7 +20,7 @@ class TableFunctionPython : public ITableFunction ~TableFunctionPython() override { // Acquire the GIL before destroying the reader object - py::gil_scoped_acquire acquire; + pybind11::gil_scoped_acquire acquire; reader.dec_ref(); reader.release(); } @@ -40,7 +38,7 @@ class TableFunctionPython : public ITableFunction void parseArguments(const ASTPtr & ast_function, ContextPtr context) override; ColumnsDescription getActualTableStructure(ContextPtr context, bool is_insert_query) const override; - py::object reader; + pybind11::object reader; }; } diff --git a/programs/local/chdb-arrow.cpp b/programs/local/chdb-arrow.cpp index d8d07d7e764..5200c537a7e 100644 --- a/programs/local/chdb-arrow.cpp +++ b/programs/local/chdb-arrow.cpp @@ -9,7 +9,8 @@ namespace CHDB { -struct PrivateData { +struct PrivateData +{ ArrowSchema * schema; ArrowArray * array; bool done = false; @@ -72,11 +73,23 @@ void Release(struct ArrowArrayStream * stream) stream->release = nullptr; } +void chdb_destroy_arrow_stream(ArrowArrayStream * arrow_stream) +{ + if (!arrow_stream) + return; + + if (arrow_stream->release) + arrow_stream->release(arrow_stream); + chassert(!arrow_stream->release); + + delete arrow_stream; +} + } // namespace 
CHDB -chdb_state chdb_arrow_scan( +static chdb_state chdb_inner_arrow_scan( chdb_connection conn, const char * table_name, - chdb_arrow_stream arrow_stream) + chdb_arrow_stream arrow_stream, bool is_owner) { ChdbDestructorGuard guard; @@ -104,10 +117,10 @@ chdb_state chdb_arrow_scan( child->release = CHDB::EmptySchemaRelease; } + bool success = false; try { - bool success = DB::ArrowStreamRegistry::instance().registerArrowStream(String(table_name), stream); - return success ? CHDBSuccess : CHDBError; + success = CHDB::ArrowStreamRegistry::instance().registerArrowStream(String(table_name), stream, is_owner); } catch (...) { @@ -119,45 +132,33 @@ chdb_state chdb_arrow_scan( schema.children[i]->release = releases[i]; } - return CHDBSuccess; + return success ? CHDBSuccess : CHDBError; +} + +chdb_state chdb_arrow_scan( + chdb_connection conn, const char * table_name, + chdb_arrow_stream arrow_stream) +{ + return chdb_inner_arrow_scan(conn, table_name, arrow_stream, false); } chdb_state chdb_arrow_array_scan( chdb_connection conn, const char * table_name, - chdb_arrow_schema arrow_schema, chdb_arrow_array arrow_array, - chdb_arrow_stream * out_stream) + chdb_arrow_schema arrow_schema, chdb_arrow_array arrow_array) { - auto * private_data = new CHDB::PrivateData; + auto * private_data = new CHDB::PrivateData(); private_data->schema = reinterpret_cast(arrow_schema); private_data->array = reinterpret_cast(arrow_array); private_data->done = false; auto * stream = new ArrowArrayStream(); - *out_stream = reinterpret_cast(stream); stream->get_schema = CHDB::GetSchema; stream->get_next = CHDB::GetNext; stream->get_last_error = CHDB::GetLastError; stream->release = CHDB::Release; stream->private_data = private_data; - return chdb_arrow_scan(conn, table_name, reinterpret_cast(stream)); -} - -void chdb_destroy_arrow_stream(chdb_arrow_stream * arrow_stream) -{ - if (!arrow_stream) - return; - - auto * stream = reinterpret_cast(*arrow_stream); - if (!stream) - return; - - if (stream->release) - stream->release(stream); - chassert(!stream->release); - - delete stream; - *arrow_stream = nullptr; + return chdb_inner_arrow_scan(conn, table_name, reinterpret_cast(stream), true); } chdb_state chdb_arrow_unregister_table(chdb_connection conn, const char * table_name) @@ -175,7 +176,7 @@ chdb_state chdb_arrow_unregister_table(chdb_connection conn, const char * table_ try { - DB::ArrowStreamRegistry::instance().unregisterArrowStream(String(table_name)); + CHDB::ArrowStreamRegistry::instance().unregisterArrowStream(String(table_name)); return CHDBSuccess; } catch (...) 
diff --git a/programs/local/chdb-internal.h b/programs/local/chdb-internal.h index 45036098d3b..945cf4ba3ae 100644 --- a/programs/local/chdb-internal.h +++ b/programs/local/chdb-internal.h @@ -8,7 +8,7 @@ #include #include #include -#include +#include namespace DB { @@ -119,4 +119,7 @@ void cancelStreamQuery(DB::LocalServer * server, void * stream_result); const std::string & chdb_result_error_string(chdb_result * result); const std::string & chdb_streaming_result_error_string(chdb_streaming_result * result); + +void chdb_destroy_arrow_stream(ArrowArrayStream * arrow_stream); + } diff --git a/programs/local/chdb.cpp b/programs/local/chdb.cpp index 034f8b5fecc..26885cb6cdb 100644 --- a/programs/local/chdb.cpp +++ b/programs/local/chdb.cpp @@ -1,14 +1,11 @@ #include "chdb.h" -#include -#include -#include "Common/MemoryTracker.h" +#include "chdb-internal.h" #include "LocalServer.h" #include "QueryResult.h" -#include "chdb-internal.h" #if USE_PYTHON -# include "FormatHelper.h" -# include "PythonTableCache.h" +#include "FormatHelper.h" +#include "PythonTableCache.h" #endif #ifdef CHDB_STATIC_LIBRARY_BUILD @@ -26,6 +23,18 @@ std::shared_mutex global_connection_mutex; namespace CHDB { + +#if !USE_PYTHON +extern "C" +{ + extern chdb_state chdb_arrow_scan(chdb_connection, const char *, chdb_arrow_stream); +} + +[[maybe_unused]] void * force_link_arrow_functions[] = { + reinterpret_cast(chdb_arrow_scan) +}; +#endif + static std::mutex CHDB_MUTEX; chdb_conn * global_conn_ptr = nullptr; std::string global_db_path; diff --git a/programs/local/chdb.h b/programs/local/chdb.h index cebee1a36c1..30fc8942ea8 100644 --- a/programs/local/chdb.h +++ b/programs/local/chdb.h @@ -89,21 +89,20 @@ typedef struct chdb_connection_ void * internal_data; } * chdb_connection; -// Holds an arrow array stream. Wraps ArrowArrayStream for chdb usage. -// Must be released with chdb_destroy_arrow_stream when no longer needed. -typedef struct _chdb_arrow_stream +// Holds an arrow array stream. +typedef struct chdb_arrow_stream_ { void * internal_data; } * chdb_arrow_stream; -// Holds an arrow schema. Wraps ArrowSchema for chdb usage. -typedef struct _chdb_arrow_schema +// Holds an arrow schema. +typedef struct chdb_arrow_schema_ { void * internal_data; } * chdb_arrow_schema; -// Holds an arrow array. Wraps ArrowArray for chdb usage. -typedef struct _chdb_arrow_array +// Holds an arrow array. 
+typedef struct chdb_arrow_array_ { void * internal_data; } * chdb_arrow_array; @@ -410,7 +409,6 @@ CHDB_EXPORT const char * chdb_result_error(chdb_result * result); * @param conn The connection on which to execute the registration * @param table_name Name to register for the arrow stream table function * @param arrow_stream chdb Arrow stream handle - * @param arrow_schema chdb Arrow schema handle * @return CHDBSuccess on success, CHDBError on failure */ CHDB_EXPORT chdb_state chdb_arrow_scan( @@ -423,19 +421,11 @@ CHDB_EXPORT chdb_state chdb_arrow_scan( * @param table_name Name to register for the arrow stream table function * @param arrow_schema chdb Arrow schema handle * @param arrow_array chdb Arrow array handle - * @param out_stream Optional output stream handle for result streaming * @return CHDBSuccess on success, CHDBError on failure */ CHDB_EXPORT chdb_state chdb_arrow_array_scan( chdb_connection conn, const char * table_name, - chdb_arrow_schema arrow_schema, chdb_arrow_array arrow_array, - chdb_arrow_stream * out_stream); - -/** - * Destroys and releases resources for an Arrow stream handle - * @param arrow_stream Pointer to the Arrow stream handle to destroy - */ -CHDB_EXPORT void chdb_destroy_arrow_stream(chdb_arrow_stream * arrow_stream); + chdb_arrow_schema arrow_schema, chdb_arrow_array arrow_array); /** * Unregisters an arrow stream table function that was previously registered via chdb_arrow_scan diff --git a/tests/test_arrow_table_queries.py b/tests/test_arrow_table_queries.py new file mode 100644 index 00000000000..e4f9e3d0a3e --- /dev/null +++ b/tests/test_arrow_table_queries.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 + +import unittest +import tempfile +import os +import shutil +import pyarrow as pa +import pyarrow.parquet as pq +import chdb +from chdb import session +from urllib.request import urlretrieve + +if os.path.exists(".test_chdb_arrow_table"): + shutil.rmtree(".test_chdb_arrow_table", ignore_errors=True) +sess = session.Session(".test_chdb_arrow_table") + +class TestChDBArrowTable(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Download parquet file if it doesn't exist + cls.parquet_file = "hits_0.parquet" + if not os.path.exists(cls.parquet_file): + print(f"Downloading {cls.parquet_file}...") + url = "https://datasets.clickhouse.com/hits_compatible/athena_partitioned/hits_0.parquet" + urlretrieve(url, cls.parquet_file) + print("Download complete!") + + # Load parquet as PyArrow table + cls.arrow_table = pq.read_table(cls.parquet_file) + cls.table_size = cls.arrow_table.nbytes + cls.num_rows = cls.arrow_table.num_rows + cls.num_columns = cls.arrow_table.num_columns + + print(f"Loaded Arrow table: {cls.num_rows} rows, {cls.num_columns} columns, {cls.table_size} bytes") + + @classmethod + def tearDownClass(cls): + # Clean up session directory + if os.path.exists(".test_chdb_arrow_table"): + shutil.rmtree(".test_chdb_arrow_table", ignore_errors=True) + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_arrow_table_basic_info(self): + """Test basic Arrow table information""" + self.assertEqual(self.table_size, 729898624) + self.assertEqual(self.num_rows, 1000000) + self.assertEqual(self.num_columns, 105) + + def test_arrow_table_count(self): + """Test counting rows in Arrow table""" + my_arrow_table = self.arrow_table + result = sess.query("SELECT COUNT(*) as row_count FROM Python(my_arrow_table)", "CSV") + lines = str(result).strip().split('\n') + count = int(lines[0]) + self.assertEqual(count, self.num_rows, f"Count should 
match table rows: {self.num_rows}") + + def test_arrow_table_schema(self): + """Test querying Arrow table schema information""" + my_arrow_table = self.arrow_table + result = sess.query("DESCRIBE Python(my_arrow_table)", "CSV") + # print(result) + self.assertIn('WatchID', str(result)) + self.assertIn('URLHash', str(result)) + + def test_arrow_table_limit(self): + """Test LIMIT queries on Arrow table""" + my_arrow_table = self.arrow_table + result = sess.query("SELECT * FROM Python(my_arrow_table) LIMIT 5", "CSV") + lines = str(result).strip().split('\n') + self.assertEqual(len(lines), 5, "Should have 5 data rows") + + def test_arrow_table_select_columns(self): + """Test selecting specific columns from Arrow table""" + my_arrow_table = self.arrow_table + # Get first few column names from schema + schema = self.arrow_table.schema + first_col = schema.field(0).name + second_col = schema.field(1).name if len(schema) > 1 else first_col + + result = sess.query(f"SELECT {first_col}, {second_col} FROM Python(my_arrow_table) LIMIT 3", "CSV") + lines = str(result).strip().split('\n') + self.assertEqual(len(lines), 3, "Should have 3 data rows") + + def test_arrow_table_where_clause(self): + """Test WHERE clause filtering on Arrow table""" + my_arrow_table = self.arrow_table + # Find a numeric column for filtering + numeric_col = None + for field in self.arrow_table.schema: + if pa.types.is_integer(field.type) or pa.types.is_floating(field.type): + numeric_col = field.name + break + + result = sess.query(f"SELECT COUNT(*) FROM Python(my_arrow_table) WHERE {numeric_col} > 1", "CSV") + lines = str(result).strip().split('\n') + count = int(lines[0]) + self.assertEqual(count, 1000000) + + def test_arrow_table_group_by(self): + """Test GROUP BY queries on Arrow table""" + my_arrow_table = self.arrow_table + # Find a string column for grouping + string_col = None + for field in self.arrow_table.schema: + if pa.types.is_binary(field.type) or pa.types.is_large_binary(field.type): + string_col = field.name + break + + result = sess.query(f"SELECT {string_col}, COUNT(*) as cnt FROM Python(my_arrow_table) GROUP BY {string_col} ORDER BY cnt DESC LIMIT 5", "CSV") + lines = str(result).strip().split('\n') + self.assertEqual(len(lines), 5) + + def test_arrow_table_aggregations(self): + """Test aggregation functions on Arrow table""" + my_arrow_table = self.arrow_table + # Find a numeric column for aggregation + numeric_col = None + for field in self.arrow_table.schema: + if pa.types.is_integer(field.type) or pa.types.is_floating(field.type): + numeric_col = field.name + break + + result = sess.query(f"SELECT AVG({numeric_col}) as avg_val, MIN({numeric_col}) as min_val, MAX({numeric_col}) as max_val FROM Python(my_arrow_table)", "CSV") + lines = str(result).strip().split('\n') + self.assertEqual(len(lines), 1) + + def test_arrow_table_order_by(self): + """Test ORDER BY queries on Arrow table""" + my_arrow_table = self.arrow_table + # Use first column for ordering + first_col = self.arrow_table.schema.field(0).name + + result = sess.query(f"SELECT {first_col} FROM Python(my_arrow_table) ORDER BY {first_col} LIMIT 10", "CSV") + lines = str(result).strip().split('\n') + self.assertEqual(len(lines), 10) + + def test_arrow_table_subquery(self): + """Test subqueries with Arrow table""" + my_arrow_table = self.arrow_table + result = sess.query(""" + SELECT COUNT(*) as total_count + FROM ( + SELECT * FROM Python(my_arrow_table) + WHERE WatchID IS NOT NULL + LIMIT 1000 + ) subq + """, "CSV") + lines = 
str(result).strip().split('\n')
+        self.assertEqual(len(lines), 1)
+        count = int(lines[0])
+        self.assertEqual(count, 1000)
+
+    def test_arrow_table_multiple_tables(self):
+        """Test using multiple Arrow tables in one query"""
+        my_arrow_table = self.arrow_table
+        # Create a smaller subset table
+        subset_table = my_arrow_table.slice(0, min(100, my_arrow_table.num_rows))
+
+        result = sess.query("""
+            SELECT
+                (SELECT COUNT(*) FROM Python(my_arrow_table)) as full_count,
+                (SELECT COUNT(*) FROM Python(subset_table)) as subset_count
+        """, "CSV")
+        self.assertEqual(str(result).strip(), '1000000,100')
+
+
+if __name__ == '__main__':
+    unittest.main()

From bd764040c97c1a7e7530e20e5ed6b0f75a687a88 Mon Sep 17 00:00:00 2001
From: wudidapaopao
Date: Tue, 23 Sep 2025 20:34:39 +0800
Subject: [PATCH 03/13] fix: fix tests

---
 tests/test_arrow_table_queries.py | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/tests/test_arrow_table_queries.py b/tests/test_arrow_table_queries.py
index e4f9e3d0a3e..a05bf3b0e95 100644
--- a/tests/test_arrow_table_queries.py
+++ b/tests/test_arrow_table_queries.py
@@ -10,9 +10,7 @@
 from chdb import session
 from urllib.request import urlretrieve
 
-if os.path.exists(".test_chdb_arrow_table"):
-    shutil.rmtree(".test_chdb_arrow_table", ignore_errors=True)
-sess = session.Session(".test_chdb_arrow_table")
+# The session is created in setUpClass and closed in tearDownClass instead of at module import
 
 class TestChDBArrowTable(unittest.TestCase):
     @classmethod
@@ -33,11 +31,16 @@ def setUpClass(cls):
 
         print(f"Loaded Arrow table: {cls.num_rows} rows, {cls.num_columns} columns, {cls.table_size} bytes")
 
+        if os.path.exists(".test_chdb_arrow_table"):
+            shutil.rmtree(".test_chdb_arrow_table", ignore_errors=True)
+        cls.sess = session.Session(".test_chdb_arrow_table")
+
     @classmethod
     def tearDownClass(cls):
+        cls.sess.close()
         # Clean up session directory
         if os.path.exists(".test_chdb_arrow_table"):
             shutil.rmtree(".test_chdb_arrow_table", ignore_errors=True)
 
     def setUp(self):
         pass
@@ -54,7 +57,7 @@ def test_arrow_table_basic_info(self):
     def test_arrow_table_count(self):
         """Test counting rows in Arrow table"""
         my_arrow_table = self.arrow_table
-        result = sess.query("SELECT COUNT(*) as row_count FROM Python(my_arrow_table)", "CSV")
+        result = self.sess.query("SELECT COUNT(*) as row_count FROM Python(my_arrow_table)", "CSV")
         lines = str(result).strip().split('\n')
         count = int(lines[0])
         self.assertEqual(count, self.num_rows, f"Count should match table rows: {self.num_rows}")
@@ -62,7 +65,7 @@ def test_arrow_table_count(self):
     def test_arrow_table_schema(self):
         """Test querying Arrow table schema information"""
         my_arrow_table = self.arrow_table
-        result = sess.query("DESCRIBE Python(my_arrow_table)", "CSV")
+        result = self.sess.query("DESCRIBE Python(my_arrow_table)", "CSV")
         # print(result)
         self.assertIn('WatchID', str(result))
         self.assertIn('URLHash', str(result))
@@ -70,7 +73,7 @@ def test_arrow_table_schema(self):
     def test_arrow_table_limit(self):
         """Test LIMIT queries on Arrow table"""
         my_arrow_table = self.arrow_table
-        result = sess.query("SELECT * FROM Python(my_arrow_table) LIMIT 5", "CSV")
+        result = self.sess.query("SELECT * FROM Python(my_arrow_table) LIMIT 5", "CSV")
         lines = str(result).strip().split('\n')
         self.assertEqual(len(lines), 5, "Should have 5 data rows")
 
@@ -82,7 +85,7 @@ def test_arrow_table_select_columns(self):
         first_col = schema.field(0).name
         second_col = schema.field(1).name if len(schema) > 1 else first_col
 
-        result = sess.query(f"SELECT {first_col}, 
{second_col} FROM Python(my_arrow_table) LIMIT 3", "CSV") + result = self.sess.query(f"SELECT {first_col}, {second_col} FROM Python(my_arrow_table) LIMIT 3", "CSV") lines = str(result).strip().split('\n') self.assertEqual(len(lines), 3, "Should have 3 data rows") @@ -96,7 +99,7 @@ def test_arrow_table_where_clause(self): numeric_col = field.name break - result = sess.query(f"SELECT COUNT(*) FROM Python(my_arrow_table) WHERE {numeric_col} > 1", "CSV") + result = self.sess.query(f"SELECT COUNT(*) FROM Python(my_arrow_table) WHERE {numeric_col} > 1", "CSV") lines = str(result).strip().split('\n') count = int(lines[0]) self.assertEqual(count, 1000000) @@ -111,7 +114,7 @@ def test_arrow_table_group_by(self): string_col = field.name break - result = sess.query(f"SELECT {string_col}, COUNT(*) as cnt FROM Python(my_arrow_table) GROUP BY {string_col} ORDER BY cnt DESC LIMIT 5", "CSV") + result = self.sess.query(f"SELECT {string_col}, COUNT(*) as cnt FROM Python(my_arrow_table) GROUP BY {string_col} ORDER BY cnt DESC LIMIT 5", "CSV") lines = str(result).strip().split('\n') self.assertEqual(len(lines), 5) @@ -125,7 +128,7 @@ def test_arrow_table_aggregations(self): numeric_col = field.name break - result = sess.query(f"SELECT AVG({numeric_col}) as avg_val, MIN({numeric_col}) as min_val, MAX({numeric_col}) as max_val FROM Python(my_arrow_table)", "CSV") + result = self.sess.query(f"SELECT AVG({numeric_col}) as avg_val, MIN({numeric_col}) as min_val, MAX({numeric_col}) as max_val FROM Python(my_arrow_table)", "CSV") lines = str(result).strip().split('\n') self.assertEqual(len(lines), 1) @@ -135,14 +138,14 @@ def test_arrow_table_order_by(self): # Use first column for ordering first_col = self.arrow_table.schema.field(0).name - result = sess.query(f"SELECT {first_col} FROM Python(my_arrow_table) ORDER BY {first_col} LIMIT 10", "CSV") + result = self.sess.query(f"SELECT {first_col} FROM Python(my_arrow_table) ORDER BY {first_col} LIMIT 10", "CSV") lines = str(result).strip().split('\n') self.assertEqual(len(lines), 10) def test_arrow_table_subquery(self): """Test subqueries with Arrow table""" my_arrow_table = self.arrow_table - result = sess.query(""" + result = self.sess.query(""" SELECT COUNT(*) as total_count FROM ( SELECT * FROM Python(my_arrow_table) @@ -161,7 +164,7 @@ def test_arrow_table_multiple_tables(self): # Create a smaller subset table subset_table = my_arrow_table.slice(0, min(100, my_arrow_table.num_rows)) - result = sess.query(""" + result = self.sess.query(""" SELECT (SELECT COUNT(*) FROM Python(my_arrow_table)) as full_count, (SELECT COUNT(*) FROM Python(subset_table)) as subset_count From b243c7b6db07f8c98b906a27cf12b22d096074ae Mon Sep 17 00:00:00 2001 From: wudidapaopao Date: Tue, 23 Sep 2025 23:50:47 +0800 Subject: [PATCH 04/13] fix: fix tests --- .github/workflows/build_macos_arm64_wheels.yml | 2 +- .github/workflows/build_macos_x86_wheels.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_macos_arm64_wheels.yml b/.github/workflows/build_macos_arm64_wheels.yml index 39cd4d2f1e0..0c857fbb515 100644 --- a/.github/workflows/build_macos_arm64_wheels.yml +++ b/.github/workflows/build_macos_arm64_wheels.yml @@ -79,7 +79,7 @@ jobs: brew install ca-certificates lz4 mpdecimal openssl@3 readline sqlite xz z3 zstd brew install --ignore-dependencies llvm@19 brew install git ninja libtool gettext gcc binutils grep findutils nasm - brew install --build-from-source ccache + brew install ccache || echo "ccache installation failed, continuing without 
it" brew install go cd /usr/local/opt/ && sudo rm -f llvm && sudo ln -sf llvm@19 llvm export PATH=$(brew --prefix llvm@19)/bin:$PATH diff --git a/.github/workflows/build_macos_x86_wheels.yml b/.github/workflows/build_macos_x86_wheels.yml index 148edc45971..974b99d3f44 100644 --- a/.github/workflows/build_macos_x86_wheels.yml +++ b/.github/workflows/build_macos_x86_wheels.yml @@ -79,7 +79,7 @@ jobs: brew install ca-certificates lz4 mpdecimal openssl@3 readline sqlite xz z3 zstd brew install --ignore-dependencies llvm@19 brew install git ninja libtool gettext gcc binutils grep findutils nasm - brew install --build-from-source ccache + brew install ccache || echo "ccache installation failed, continuing without it" brew install go cd /usr/local/opt/ && sudo rm -f llvm && sudo ln -sf llvm@19 llvm export PATH=$(brew --prefix llvm@19)/bin:$PATH From d3b252de7f15da5c6e26bc72bc6a7c2cbbe7a82b Mon Sep 17 00:00:00 2001 From: wudidapaopao Date: Wed, 24 Sep 2025 14:29:49 +0800 Subject: [PATCH 05/13] feat: optimize Arrow table generation using memoryview --- chdb/__init__.py | 4 +++- chdb/state/sqlitelike.py | 4 +++- examples/runArrowTest.sh | 3 ++- programs/local/ArrowTableReader.cpp | 8 +++++++- programs/local/chdb.h | 14 -------------- 5 files changed, 15 insertions(+), 18 deletions(-) diff --git a/chdb/__init__.py b/chdb/__init__.py index 0674a46927c..b4132409aa8 100644 --- a/chdb/__init__.py +++ b/chdb/__init__.py @@ -103,7 +103,9 @@ def to_arrowTable(res): raise ImportError("Failed to import pyarrow or pandas") from None if len(res) == 0: return pa.Table.from_batches([], schema=pa.schema([])) - return pa.RecordBatchFileReader(res.bytes()).read_all() + + memview = res.get_memview() + return pa.RecordBatchFileReader(memview.view()).read_all() # return pandas dataframe diff --git a/chdb/state/sqlitelike.py b/chdb/state/sqlitelike.py index 7694cb42ece..e743e1f722b 100644 --- a/chdb/state/sqlitelike.py +++ b/chdb/state/sqlitelike.py @@ -62,7 +62,9 @@ def to_arrowTable(res): raise ImportError("Failed to import pyarrow or pandas") from None if len(res) == 0: return pa.Table.from_batches([], schema=pa.schema([])) - return pa.RecordBatchFileReader(res.bytes()).read_all() + + memview = res.get_memview() + return pa.RecordBatchFileReader(memview.view()).read_all() # return pandas dataframe diff --git a/examples/runArrowTest.sh b/examples/runArrowTest.sh index 7f8d5e9bed0..96ec084e674 100755 --- a/examples/runArrowTest.sh +++ b/examples/runArrowTest.sh @@ -2,7 +2,8 @@ set -e -CXXFLAGS="-g -O0 -DDEBUG" +# CXXFLAGS="-g -O0 -DDEBUG" +CXXFLAGS="-std=c++17" # check current os type, and make ldd command if [ "$(uname)" == "Darwin" ]; then diff --git a/programs/local/ArrowTableReader.cpp b/programs/local/ArrowTableReader.cpp index 21fd7efa1af..8e5d31a9cd1 100644 --- a/programs/local/ArrowTableReader.cpp +++ b/programs/local/ArrowTableReader.cpp @@ -171,7 +171,13 @@ Chunk ArrowTableReader::convertArrowArrayToChunk(const ArrowArrayWrapper & arrow /// Use the cached RecordBatch and slice it record_batch = state.cached_record_batch; auto sliced_batch = record_batch->Slice(offset, count); - auto arrow_table = arrow::Table::FromRecordBatches({sliced_batch}).ValueOrDie(); + auto table_result = arrow::Table::FromRecordBatches({sliced_batch}); + if (!table_result.ok()) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "Failed to create Arrow table from RecordBatch: {}", table_result.status().ToString()); + } + const auto & arrow_table = table_result.ValueOrDie(); /// Use ArrowColumnToCHColumn to convert the batch 
ArrowColumnToCHColumn converter( diff --git a/programs/local/chdb.h b/programs/local/chdb.h index 30fc8942ea8..d16f5172f43 100644 --- a/programs/local/chdb.h +++ b/programs/local/chdb.h @@ -291,20 +291,6 @@ CHDB_EXPORT chdb_result * chdb_query_cmdline(int argc, char ** argv); */ CHDB_EXPORT chdb_result * chdb_stream_query(chdb_connection conn, const char * query, const char * format); -/** - * Executes a query with explicit string lengths (binary-safe). - * @brief Thread-safe function that handles query execution with specified buffer lengths - * @param conn Connection to execute query on - * @param query SQL query buffer (may contain null bytes) - * @param query_len Length of query buffer in bytes - * @param format Output format buffer (may contain null bytes) - * @param format_len Length of format buffer in bytes - * @return Query result structure containing output or error message - * @note Strings do not need to be null-terminated - * @note Use this function when dealing with queries/formats containing null bytes - */ -CHDB_EXPORT chdb_result * chdb_query_n(chdb_connection conn, const char * query, size_t query_len, const char * format, size_t format_len); - /** * Executes a streaming query with explicit string lengths (binary-safe). * @brief Initializes streaming query execution with specified buffer lengths From 0c0f7fb9c6acc121469c5025ac78c4c54895748c Mon Sep 17 00:00:00 2001 From: wudidapaopao Date: Thu, 25 Sep 2025 18:31:43 +0800 Subject: [PATCH 06/13] test: update workflow --- .../workflows/build_linux_arm64_wheels-gh.yml | 15 +++++++++++++ .github/workflows/build_linux_x86_wheels.yml | 15 +++++++++++++ .../workflows/build_macos_arm64_wheels.yml | 15 +++++++++++++ .github/workflows/build_macos_x86_wheels.yml | 22 ++++++++++++++----- 4 files changed, 62 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build_linux_arm64_wheels-gh.yml b/.github/workflows/build_linux_arm64_wheels-gh.yml index e12fca991e4..a08c6fd59c3 100644 --- a/.github/workflows/build_linux_arm64_wheels-gh.yml +++ b/.github/workflows/build_linux_arm64_wheels-gh.yml @@ -24,6 +24,21 @@ jobs: build_universal_wheel: name: Build Universal Wheel (Linux ARM64) runs-on: GH-Linux-ARM64 + steps: + - name: Check machine architecture + run: | + echo "=== Machine Architecture Information ===" + echo "Machine type: $(uname -m)" + echo "Architecture: $(arch)" + echo "System info: $(uname -a)" + echo "Hardware info:" + system_profiler SPHardwareDataType | grep "Chip\|Processor" + if sysctl -n hw.optional.arm64 2>/dev/null | grep -q "1"; then + echo "This is an ARM64 (Apple Silicon) machine" + else + echo "This is an x86_64 (Intel) machine" + fi + echo "========================================" steps: - name: Install Python build dependencies run: | diff --git a/.github/workflows/build_linux_x86_wheels.yml b/.github/workflows/build_linux_x86_wheels.yml index 32ad5853766..0c038a280dc 100644 --- a/.github/workflows/build_linux_x86_wheels.yml +++ b/.github/workflows/build_linux_x86_wheels.yml @@ -24,6 +24,21 @@ jobs: build_universal_wheel: name: Build Universal Wheel (Linux x86_64) runs-on: gh-64c + steps: + - name: Check machine architecture + run: | + echo "=== Machine Architecture Information ===" + echo "Machine type: $(uname -m)" + echo "Architecture: $(arch)" + echo "System info: $(uname -a)" + echo "Hardware info:" + system_profiler SPHardwareDataType | grep "Chip\|Processor" + if sysctl -n hw.optional.arm64 2>/dev/null | grep -q "1"; then + echo "This is an ARM64 (Apple Silicon) machine" + else + echo "This is an 
x86_64 (Intel) machine" + fi + echo "========================================" steps: - name: Install Python build dependencies run: | diff --git a/.github/workflows/build_macos_arm64_wheels.yml b/.github/workflows/build_macos_arm64_wheels.yml index 0c857fbb515..f12dbebb245 100644 --- a/.github/workflows/build_macos_arm64_wheels.yml +++ b/.github/workflows/build_macos_arm64_wheels.yml @@ -23,6 +23,21 @@ jobs: build_universal_wheel: name: Build Universal Wheel (macOS ARM64) runs-on: macos-13-xlarge + steps: + - name: Check machine architecture + run: | + echo "=== Machine Architecture Information ===" + echo "Machine type: $(uname -m)" + echo "Architecture: $(arch)" + echo "System info: $(uname -a)" + echo "Hardware info:" + system_profiler SPHardwareDataType | grep "Chip\|Processor" + if sysctl -n hw.optional.arm64 2>/dev/null | grep -q "1"; then + echo "This is an ARM64 (Apple Silicon) machine" + else + echo "This is an x86_64 (Intel) machine" + fi + echo "========================================" steps: - name: Setup pyenv run: | diff --git a/.github/workflows/build_macos_x86_wheels.yml b/.github/workflows/build_macos_x86_wheels.yml index 974b99d3f44..9b50885119e 100644 --- a/.github/workflows/build_macos_x86_wheels.yml +++ b/.github/workflows/build_macos_x86_wheels.yml @@ -22,7 +22,22 @@ on: jobs: build_universal_wheel: name: Build Universal Wheel (macOS x86_64) - runs-on: macos-13 + runs-on: macos-14-large + steps: + - name: Check machine architecture + run: | + echo "=== Machine Architecture Information ===" + echo "Machine type: $(uname -m)" + echo "Architecture: $(arch)" + echo "System info: $(uname -a)" + echo "Hardware info:" + system_profiler SPHardwareDataType | grep "Chip\|Processor" + if sysctl -n hw.optional.arm64 2>/dev/null | grep -q "1"; then + echo "This is an ARM64 (Apple Silicon) machine" + else + echo "This is an x86_64 (Intel) machine" + fi + echo "========================================" steps: - name: Setup pyenv run: | @@ -97,7 +112,7 @@ jobs: - name: ccache uses: hendrikmuhs/ccache-action@v1.2 with: - key: macos-13-x86_64 + key: macos-14-x86_64 max-size: 5G append-timestamp: true - name: Run chdb/build.sh @@ -142,9 +157,6 @@ jobs: - name: Run libchdb stub in examples dir run: | bash -x ./examples/runStub.sh - - name: Run Arrow functions test in examples dir - run: | - bash -x ./examples/runArrowTest.sh - name: Keep killall ccache and wait for ccache to finish if: always() run: | From aa410256523ee0cfdff4711d2abc5ae1d35ba74f Mon Sep 17 00:00:00 2001 From: wudidapaopao Date: Thu, 25 Sep 2025 18:33:48 +0800 Subject: [PATCH 07/13] test: update workflow --- .github/workflows/build_linux_arm64_wheels-gh.yml | 15 --------------- .github/workflows/build_linux_x86_wheels.yml | 15 --------------- .github/workflows/build_macos_arm64_wheels.yml | 1 - .github/workflows/build_macos_x86_wheels.yml | 1 - 4 files changed, 32 deletions(-) diff --git a/.github/workflows/build_linux_arm64_wheels-gh.yml b/.github/workflows/build_linux_arm64_wheels-gh.yml index a08c6fd59c3..e12fca991e4 100644 --- a/.github/workflows/build_linux_arm64_wheels-gh.yml +++ b/.github/workflows/build_linux_arm64_wheels-gh.yml @@ -24,21 +24,6 @@ jobs: build_universal_wheel: name: Build Universal Wheel (Linux ARM64) runs-on: GH-Linux-ARM64 - steps: - - name: Check machine architecture - run: | - echo "=== Machine Architecture Information ===" - echo "Machine type: $(uname -m)" - echo "Architecture: $(arch)" - echo "System info: $(uname -a)" - echo "Hardware info:" - system_profiler SPHardwareDataType | grep 
"Chip\|Processor" - if sysctl -n hw.optional.arm64 2>/dev/null | grep -q "1"; then - echo "This is an ARM64 (Apple Silicon) machine" - else - echo "This is an x86_64 (Intel) machine" - fi - echo "========================================" steps: - name: Install Python build dependencies run: | diff --git a/.github/workflows/build_linux_x86_wheels.yml b/.github/workflows/build_linux_x86_wheels.yml index 0c038a280dc..32ad5853766 100644 --- a/.github/workflows/build_linux_x86_wheels.yml +++ b/.github/workflows/build_linux_x86_wheels.yml @@ -24,21 +24,6 @@ jobs: build_universal_wheel: name: Build Universal Wheel (Linux x86_64) runs-on: gh-64c - steps: - - name: Check machine architecture - run: | - echo "=== Machine Architecture Information ===" - echo "Machine type: $(uname -m)" - echo "Architecture: $(arch)" - echo "System info: $(uname -a)" - echo "Hardware info:" - system_profiler SPHardwareDataType | grep "Chip\|Processor" - if sysctl -n hw.optional.arm64 2>/dev/null | grep -q "1"; then - echo "This is an ARM64 (Apple Silicon) machine" - else - echo "This is an x86_64 (Intel) machine" - fi - echo "========================================" steps: - name: Install Python build dependencies run: | diff --git a/.github/workflows/build_macos_arm64_wheels.yml b/.github/workflows/build_macos_arm64_wheels.yml index f12dbebb245..484b67c0964 100644 --- a/.github/workflows/build_macos_arm64_wheels.yml +++ b/.github/workflows/build_macos_arm64_wheels.yml @@ -37,7 +37,6 @@ jobs: else echo "This is an x86_64 (Intel) machine" fi - echo "========================================" steps: - name: Setup pyenv run: | diff --git a/.github/workflows/build_macos_x86_wheels.yml b/.github/workflows/build_macos_x86_wheels.yml index 9b50885119e..9e1b5587054 100644 --- a/.github/workflows/build_macos_x86_wheels.yml +++ b/.github/workflows/build_macos_x86_wheels.yml @@ -37,7 +37,6 @@ jobs: else echo "This is an x86_64 (Intel) machine" fi - echo "========================================" steps: - name: Setup pyenv run: | From fd72aa4a7655ff7137bd7b1f26272b1cdcd798f4 Mon Sep 17 00:00:00 2001 From: wudidapaopao Date: Thu, 25 Sep 2025 18:36:07 +0800 Subject: [PATCH 08/13] test: update workflow --- .github/workflows/build_macos_arm64_wheels.yml | 1 - .github/workflows/build_macos_x86_wheels.yml | 1 - 2 files changed, 2 deletions(-) diff --git a/.github/workflows/build_macos_arm64_wheels.yml b/.github/workflows/build_macos_arm64_wheels.yml index 484b67c0964..5a8c72e8279 100644 --- a/.github/workflows/build_macos_arm64_wheels.yml +++ b/.github/workflows/build_macos_arm64_wheels.yml @@ -37,7 +37,6 @@ jobs: else echo "This is an x86_64 (Intel) machine" fi - steps: - name: Setup pyenv run: | curl https://pyenv.run | bash diff --git a/.github/workflows/build_macos_x86_wheels.yml b/.github/workflows/build_macos_x86_wheels.yml index 9e1b5587054..f61de5cde53 100644 --- a/.github/workflows/build_macos_x86_wheels.yml +++ b/.github/workflows/build_macos_x86_wheels.yml @@ -37,7 +37,6 @@ jobs: else echo "This is an x86_64 (Intel) machine" fi - steps: - name: Setup pyenv run: | curl https://pyenv.run | bash From 61a0ec52a2044ccab59b262d00866286fa0eb56e Mon Sep 17 00:00:00 2001 From: wudidapaopao Date: Fri, 26 Sep 2025 22:13:56 +0800 Subject: [PATCH 09/13] test: update workflow --- .../workflows/build_linux_arm64_wheels-gh.yml | 1 + .github/workflows/build_linux_x86_wheels.yml | 1 + .../workflows/build_macos_arm64_wheels.yml | 1 + .github/workflows/build_macos_x86_wheels.yml | 1 + examples/chdbArrowTest.c | 962 ++++++++++++++++++ 
examples/runArrowTestC.sh | 36 + 6 files changed, 1002 insertions(+) create mode 100644 examples/chdbArrowTest.c create mode 100755 examples/runArrowTestC.sh diff --git a/.github/workflows/build_linux_arm64_wheels-gh.yml b/.github/workflows/build_linux_arm64_wheels-gh.yml index e12fca991e4..7c6a50ab15c 100644 --- a/.github/workflows/build_linux_arm64_wheels-gh.yml +++ b/.github/workflows/build_linux_arm64_wheels-gh.yml @@ -138,6 +138,7 @@ jobs: - name: Run libchdb stub in examples dir run: | bash -x ./examples/runStub.sh + bash -x ./examples/runArrowTestC.sh - name: Run Arrow functions test in examples dir run: | bash -x ./examples/runArrowTest.sh diff --git a/.github/workflows/build_linux_x86_wheels.yml b/.github/workflows/build_linux_x86_wheels.yml index 32ad5853766..8d7c21e0411 100644 --- a/.github/workflows/build_linux_x86_wheels.yml +++ b/.github/workflows/build_linux_x86_wheels.yml @@ -138,6 +138,7 @@ jobs: - name: Run libchdb stub in examples dir run: | bash -x ./examples/runStub.sh + bash -x ./examples/runArrowTestC.sh - name: Run Arrow functions test in examples dir run: | bash -x ./examples/runArrowTest.sh diff --git a/.github/workflows/build_macos_arm64_wheels.yml b/.github/workflows/build_macos_arm64_wheels.yml index 5a8c72e8279..ac4dfa826bf 100644 --- a/.github/workflows/build_macos_arm64_wheels.yml +++ b/.github/workflows/build_macos_arm64_wheels.yml @@ -154,6 +154,7 @@ jobs: - name: Run libchdb stub in examples dir run: | bash -x ./examples/runStub.sh + bash -x ./examples/runArrowTestC.sh - name: Run Arrow functions test in examples dir run: | bash -x ./examples/runArrowTest.sh diff --git a/.github/workflows/build_macos_x86_wheels.yml b/.github/workflows/build_macos_x86_wheels.yml index f61de5cde53..48a52261156 100644 --- a/.github/workflows/build_macos_x86_wheels.yml +++ b/.github/workflows/build_macos_x86_wheels.yml @@ -155,6 +155,7 @@ jobs: - name: Run libchdb stub in examples dir run: | bash -x ./examples/runStub.sh + bash -x ./examples/runArrowTestC.sh - name: Keep killall ccache and wait for ccache to finish if: always() run: | diff --git a/examples/chdbArrowTest.c b/examples/chdbArrowTest.c new file mode 100644 index 00000000000..841be0cd9c3 --- /dev/null +++ b/examples/chdbArrowTest.c @@ -0,0 +1,962 @@ +#include +#include +#include +#include +#include +#include + +#include "../programs/local/chdb.h" +#include "../contrib/arrow/cpp/src/arrow/c/abi.h" + +// Custom ArrowArrayStream implementation data +typedef struct CustomStreamData +{ + bool schema_sent; + size_t current_row; + size_t total_rows; + size_t batch_size; + char* last_error; +} CustomStreamData; + +// Function to initialize CustomStreamData +static void init_custom_stream_data(CustomStreamData* data) { + data->schema_sent = false; + data->current_row = 0; + data->total_rows = 1000000; + data->batch_size = 10000; + data->last_error = NULL; +} + +// Reset the stream to allow reading from the beginning +static void reset_custom_stream_data(CustomStreamData* data) +{ + data->current_row = 0; + if (data->last_error) { + free(data->last_error); + data->last_error = NULL; + } +} + +// Release function prototypes +static void release_schema_child(struct ArrowSchema* s); +static void release_schema_main(struct ArrowSchema* s); +static void release_id_array(struct ArrowArray* arr); +static void release_string_array(struct ArrowArray* arr); +static void release_main_array(struct ArrowArray* arr); + +// Helper function to find minimum of two values +static size_t min_size_t(size_t a, size_t b) { + return (a < b) ? 
a : b; +} + +// Release function implementations +static void release_schema_child(struct ArrowSchema* s) { + s->release = NULL; +} + +static void release_schema_main(struct ArrowSchema* s) +{ + if (s->children) { + for (int64_t i = 0; i < s->n_children; i++) { + if (s->children[i] && s->children[i]->release) { + s->children[i]->release(s->children[i]); + } + free(s->children[i]); + } + free(s->children); + } + s->release = NULL; +} + +static void release_id_array(struct ArrowArray* arr) +{ + if (arr->buffers) { + free((void*)(uintptr_t)arr->buffers[1]); // free data buffer + free((void**)(uintptr_t)arr->buffers); + } + arr->release = NULL; +} + +static void release_string_array(struct ArrowArray* arr) +{ + if (arr->buffers) { + free((void*)(uintptr_t)arr->buffers[1]); // free offset buffer + free((void*)(uintptr_t)arr->buffers[2]); // free data buffer + free((void**)(uintptr_t)arr->buffers); + } + arr->release = NULL; +} + +static void release_main_array(struct ArrowArray* arr) { + if (arr->children) { + for (int64_t i = 0; i < arr->n_children; i++) { + if (arr->children[i] && arr->children[i]->release) { + arr->children[i]->release(arr->children[i]); + } + free(arr->children[i]); + } + free(arr->children); + } + if (arr->buffers) { + free((void**)(uintptr_t)arr->buffers); + } + arr->release = NULL; +} + +// Helper function to create schema with 2 columns: id(int64), value(string) +static void create_schema(struct ArrowSchema* schema) { + schema->format = "+s"; // struct format + schema->name = NULL; + schema->metadata = NULL; + schema->flags = 0; + schema->n_children = 2; + schema->children = (struct ArrowSchema**)malloc(2 * sizeof(struct ArrowSchema*)); + schema->dictionary = NULL; + schema->release = release_schema_main; + + // Field 0: id (int64) + schema->children[0] = (struct ArrowSchema*)malloc(sizeof(struct ArrowSchema)); + schema->children[0]->format = "l"; // int64 + schema->children[0]->name = "id"; + schema->children[0]->metadata = NULL; + schema->children[0]->flags = 0; + schema->children[0]->n_children = 0; + schema->children[0]->children = NULL; + schema->children[0]->dictionary = NULL; + schema->children[0]->release = release_schema_child; + + // Field 1: value (string) + schema->children[1] = (struct ArrowSchema*)malloc(sizeof(struct ArrowSchema)); + schema->children[1]->format = "u"; // utf8 string + schema->children[1]->name = "value"; + schema->children[1]->metadata = NULL; + schema->children[1]->flags = 0; + schema->children[1]->n_children = 0; + schema->children[1]->children = NULL; + schema->children[1]->dictionary = NULL; + schema->children[1]->release = release_schema_child; +} + +// Helper function to create a batch of data +static void create_batch(struct ArrowArray* array, size_t start_row, size_t batch_size) +{ + struct ArrowArray* id_array; + struct ArrowArray* str_array; + int64_t* id_data; + int32_t* offsets; + size_t total_str_len; + char** strings; + char* str_data; + size_t pos; + size_t i; + + // Main array structure + array->length = batch_size; + array->null_count = 0; + array->offset = 0; + array->n_buffers = 1; + array->n_children = 2; + array->buffers = (const void**)malloc(1 * sizeof(void*)); + array->buffers[0] = NULL; // validity buffer (no nulls) + array->children = (struct ArrowArray**)malloc(2 * sizeof(struct ArrowArray*)); + array->dictionary = NULL; + + // Create id column (int64) + array->children[0] = (struct ArrowArray*)malloc(sizeof(struct ArrowArray)); + id_array = array->children[0]; + id_array->length = batch_size; + 
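+    /* Arrow C data interface layout for a primitive array: buffers[0] is the
+       validity bitmap (left NULL here since null_count is 0) and buffers[1]
+       holds the values. */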
id_array->null_count = 0; + id_array->offset = 0; + id_array->n_buffers = 2; + id_array->n_children = 0; + id_array->buffers = (const void**)malloc(2 * sizeof(void*)); + id_array->buffers[0] = NULL; // validity buffer + + // Allocate and fill id data + id_data = (int64_t*)malloc(batch_size * sizeof(int64_t)); + for (i = 0; i < batch_size; i++) + id_data[i] = start_row + i; + + id_array->buffers[1] = id_data; // data buffer + id_array->children = NULL; + id_array->dictionary = NULL; + id_array->release = release_id_array; + + // Create value column (string) + array->children[1] = (struct ArrowArray*)malloc(sizeof(struct ArrowArray)); + str_array = array->children[1]; + str_array->length = batch_size; + str_array->null_count = 0; + str_array->offset = 0; + str_array->n_buffers = 3; + str_array->n_children = 0; + str_array->buffers = (const void**)malloc(3 * sizeof(void*)); + str_array->buffers[0] = NULL; // validity buffer + + // Create offset buffer (int32) + offsets = (int32_t*)malloc((batch_size + 1) * sizeof(int32_t)); + offsets[0] = 0; + + // Calculate total string length and create strings + total_str_len = 0; + strings = (char**)malloc(batch_size * sizeof(char*)); + for (i = 0; i < batch_size; i++) + { + char buffer[64]; + size_t len; + snprintf(buffer, sizeof(buffer), "value_%zu", start_row + i); + len = strlen(buffer); + strings[i] = (char*)malloc(len + 1); + strcpy(strings[i], buffer); + total_str_len += len; + offsets[i + 1] = total_str_len; + } + str_array->buffers[1] = offsets; // offset buffer + + // Create data buffer + str_data = (char*)malloc(total_str_len); + pos = 0; + for (i = 0; i < batch_size; i++) + { + size_t len = strlen(strings[i]); + memcpy(str_data + pos, strings[i], len); + pos += len; + free(strings[i]); + } + free(strings); + str_array->buffers[2] = str_data; // data buffer + + str_array->children = NULL; + str_array->dictionary = NULL; + str_array->release = release_string_array; + + // Main array release function + array->release = release_main_array; +} + +// Callback function to get schema +static int custom_get_schema(struct ArrowArrayStream* stream, struct ArrowSchema* out) +{ + (void)stream; // Suppress unused parameter warning + create_schema(out); + return 0; +} + +// Callback function to get next array +static int custom_get_next(struct ArrowArrayStream* stream, struct ArrowArray* out) +{ + CustomStreamData* data; + size_t remaining_rows; + size_t batch_size; + + data = (CustomStreamData*)stream->private_data; + if (!data) + return EINVAL; + + // Check if we've reached the end of the stream + if (data->current_row >= data->total_rows) + { + // End of stream - set release to NULL to indicate no more data + out->release = NULL; + return 0; + } + + // Calculate batch size for this iteration + remaining_rows = data->total_rows - data->current_row; + batch_size = min_size_t(data->batch_size, remaining_rows); + + // Create the batch + create_batch(out, data->current_row, batch_size); + + data->current_row += batch_size; + return 0; +} + +// Callback function to get last error +static const char* custom_get_last_error(struct ArrowArrayStream* stream) { + CustomStreamData* data = (CustomStreamData*)stream->private_data; + if (!data || !data->last_error) + return NULL; + + return data->last_error; +} + +// Callback function to release stream resources +static void custom_release(struct ArrowArrayStream* stream) { + if (stream->private_data) + { + CustomStreamData* data = (CustomStreamData*)stream->private_data; + if (data->last_error) { + 
free(data->last_error); + } + free(data); + stream->private_data = NULL; + } + stream->release = NULL; +} + +// Helper function to reset the ArrowArrayStream for reuse +static void reset_arrow_stream(struct ArrowArrayStream* stream) +{ + if (stream && stream->private_data) + { + CustomStreamData* data = (CustomStreamData*)stream->private_data; + reset_custom_stream_data(data); + printf("✓ ArrowArrayStream has been reset, ready for re-reading\n"); + } +} + +//===--------------------------------------------------------------------===// +// Unit Test Utilities +//===--------------------------------------------------------------------===// + +static void test_assert(bool condition, const char* test_name, const char* message) +{ + if (condition) + { + printf("✓ PASS: %s\n", test_name); + } + else + { + printf("✗ FAIL: %s", test_name); + if (message && strlen(message) > 0) + { + printf(" - %s", message); + } + printf("\n"); + exit(1); + } +} + +static void test_assert_chdb_state(chdb_state state, const char* operation_name) +{ + char message[256]; + if (state == CHDBError) { + strcpy(message, "Operation failed"); + } else { + strcpy(message, "Unknown state"); + } + + test_assert(state == CHDBSuccess, operation_name, + state == CHDBError ? message : NULL); +} + +static void test_assert_not_null(void* ptr, const char* test_name) +{ + test_assert(ptr != NULL, test_name, "Pointer is null"); +} + +static void test_assert_no_error(chdb_result* result, const char* query_name) +{ + char full_test_name[512]; + const char* error; + + snprintf(full_test_name, sizeof(full_test_name), "%s - Result is not null", query_name); + test_assert_not_null(result, full_test_name); + + error = chdb_result_error(result); + snprintf(full_test_name, sizeof(full_test_name), "%s - No query error", query_name); + + if (error) { + char error_message[512]; + snprintf(error_message, sizeof(error_message), "Error: %s", error); + test_assert(error == NULL, full_test_name, error_message); + } else { + test_assert(error == NULL, full_test_name, NULL); + } +} + +static void test_assert_query_result_contains(chdb_result* result, const char* expected_content, const char* query_name) +{ + char* buffer; + char full_test_name[512]; + bool contains; + + test_assert_no_error(result, query_name); + + buffer = chdb_result_buffer(result); + snprintf(full_test_name, sizeof(full_test_name), "%s - Result buffer is not null", query_name); + test_assert_not_null(buffer, full_test_name); + + snprintf(full_test_name, sizeof(full_test_name), "%s - Result contains expected content", query_name); + + contains = strstr(buffer, expected_content) != NULL; + if (!contains) { + char error_message[1024]; + snprintf(error_message, sizeof(error_message), "Expected: %s, Actual: %s", expected_content, buffer); + test_assert(contains, full_test_name, error_message); + } else { + test_assert(contains, full_test_name, NULL); + } +} + +static void test_assert_row_count(chdb_result* result, uint64_t expected_rows, const char* query_name) +{ + char* buffer; + char full_test_name[512]; + char* result_str; + char* end; + uint64_t actual_rows; + + test_assert_no_error(result, query_name); + + buffer = chdb_result_buffer(result); + snprintf(full_test_name, sizeof(full_test_name), "%s - Result buffer is not null", query_name); + test_assert_not_null(buffer, full_test_name); + + /* Parse the count result (assuming CSV format with just the number) */ + result_str = (char*)malloc(strlen(buffer) + 1); + strcpy(result_str, buffer); + + /* Remove trailing whitespace/newlines */ 
+ end = result_str + strlen(result_str) - 1; + while (end > result_str && (*end == ' ' || *end == '\t' || *end == '\n' || *end == '\r' || *end == '\f' || *end == '\v')) { + *end = '\0'; + end--; + } + + actual_rows = strtoull(result_str, NULL, 10); + + snprintf(full_test_name, sizeof(full_test_name), "%s - Row count matches", query_name); + + if (actual_rows != expected_rows) { + char error_message[256]; + snprintf(error_message, sizeof(error_message), "Expected: %llu, Actual: %llu", + (unsigned long long)expected_rows, (unsigned long long)actual_rows); + test_assert(actual_rows == expected_rows, full_test_name, error_message); + } else { + test_assert(actual_rows == expected_rows, full_test_name, NULL); + } + + free(result_str); +} + +void test_arrow_scan(chdb_connection conn) +{ + struct ArrowArrayStream stream; + struct ArrowArrayStream stream2; + struct ArrowArrayStream stream3; + CustomStreamData* stream_data; + CustomStreamData* stream_data2; + CustomStreamData* stream_data3; + const char* table_name = "test_arrow_table"; + const char* non_exist_table_name = "non_exist_table"; + const char* table_name2 = "test_arrow_table_2"; + const char* table_name3 = "test_arrow_table_3"; + chdb_arrow_stream arrow_stream; + chdb_arrow_stream arrow_stream2; + chdb_arrow_stream arrow_stream3; + chdb_state result; + chdb_result* count_result; + chdb_result* sample_result; + chdb_result* last_result; + chdb_result* count1_result; + chdb_result* count2_result; + chdb_result* count3_result; + chdb_result* join_result; + chdb_result* union_result; + chdb_result* unregister_result; + const char* error; + char error_message[512]; + + printf("\n=== Creating Custom ArrowArrayStream ===\n"); + printf("Data specification: 1,000,000 rows × 2 columns (id: int64, value: string)\n"); + + memset(&stream, 0, sizeof(stream)); + + /* Create and initialize stream data */ + stream_data = (CustomStreamData*)malloc(sizeof(CustomStreamData)); + init_custom_stream_data(stream_data); + + /* Set up the ArrowArrayStream callbacks */ + stream.get_schema = custom_get_schema; + stream.get_next = custom_get_next; + stream.get_last_error = custom_get_last_error; + stream.release = custom_release; + stream.private_data = stream_data; + + printf("✓ ArrowArrayStream initialization completed\n"); + printf("Starting registration with chDB...\n"); + + arrow_stream = (chdb_arrow_stream)&stream; + result = chdb_arrow_scan(conn, table_name, arrow_stream); + + /* Test 1: Verify arrow registration succeeded */ + test_assert_chdb_state(result, "Register ArrowArrayStream to table: test_arrow_table"); + + /* Test 2: Unregister non-existent table should handle gracefully */ + result = chdb_arrow_unregister_table(conn, non_exist_table_name); + test_assert_chdb_state(result, "Unregister non-existent table: non_exist_table"); + + /* Test 3: Count rows - should be exactly 1,000,000 */ + count_result = chdb_query(conn, "SELECT COUNT(*) as total_rows FROM arrowstream(test_arrow_table)", "CSV"); + test_assert_row_count(count_result, 1000000, "Count total rows"); + chdb_destroy_query_result(count_result); + + /* Test 4: Sample first 5 rows - should contain id=0,1,2,3,4 */ + reset_arrow_stream(&stream); + sample_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_table) LIMIT 5", "CSV"); + test_assert_query_result_contains(sample_result, "0,\"value_0\"", "First 5 rows contain first row"); + test_assert_query_result_contains(sample_result, "4,\"value_4\"", "First 5 rows contain fifth row"); + chdb_destroy_query_result(sample_result); + + /* 
Test 5: Sample last 5 rows - should contain id=999999,999998,999997,999996,999995 */ + reset_arrow_stream(&stream); + last_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_table) ORDER BY id DESC LIMIT 5", "CSV"); + test_assert_query_result_contains(last_result, "999999,\"value_999999\"", "Last 5 rows contain last row"); + test_assert_query_result_contains(last_result, "999995,\"value_999995\"", "Last 5 rows contain fifth row"); + chdb_destroy_query_result(last_result); + + /* Test 6: Multiple table registration tests */ + /* Create second ArrowArrayStream with different data (500,000 rows) */ + memset(&stream2, 0, sizeof(stream2)); + stream_data2 = (CustomStreamData*)malloc(sizeof(CustomStreamData)); + init_custom_stream_data(stream_data2); + stream_data2->total_rows = 500000; /* Different row count */ + stream_data2->current_row = 0; + stream2.get_schema = custom_get_schema; + stream2.get_next = custom_get_next; + stream2.get_last_error = custom_get_last_error; + stream2.release = custom_release; + stream2.private_data = stream_data2; + + /* Create third ArrowArrayStream with different data (100,000 rows) */ + memset(&stream3, 0, sizeof(stream3)); + stream_data3 = (CustomStreamData*)malloc(sizeof(CustomStreamData)); + init_custom_stream_data(stream_data3); + stream_data3->total_rows = 100000; /* Different row count */ + stream_data3->current_row = 0; + stream3.get_schema = custom_get_schema; + stream3.get_next = custom_get_next; + stream3.get_last_error = custom_get_last_error; + stream3.release = custom_release; + stream3.private_data = stream_data3; + + /* Register second table */ + arrow_stream2 = (chdb_arrow_stream)&stream2; + result = chdb_arrow_scan(conn, table_name2, arrow_stream2); + test_assert_chdb_state(result, "Register second ArrowArrayStream to table: test_arrow_table_2"); + + /* Register third table */ + arrow_stream3 = (chdb_arrow_stream)&stream3; + result = chdb_arrow_scan(conn, table_name3, arrow_stream3); + test_assert_chdb_state(result, "Register third ArrowArrayStream to table: test_arrow_table_3"); + + /* Test 6a: Verify each table has correct row counts */ + reset_arrow_stream(&stream); + count1_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_table)", "CSV"); + test_assert_row_count(count1_result, 1000000, "First table row count"); + chdb_destroy_query_result(count1_result); + + reset_arrow_stream(&stream2); + count2_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_table_2)", "CSV"); + test_assert_row_count(count2_result, 500000, "Second table row count"); + chdb_destroy_query_result(count2_result); + + reset_arrow_stream(&stream3); + count3_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_table_3)", "CSV"); + test_assert_row_count(count3_result, 100000, "Third table row count"); + chdb_destroy_query_result(count3_result); + + /* Test 6b: Test cross-table JOIN query */ + reset_arrow_stream(&stream); + reset_arrow_stream(&stream2); + join_result = chdb_query(conn, + "SELECT t1.id, t1.value, t2.value as value2 " + "FROM arrowstream(test_arrow_table) t1 " + "INNER JOIN arrowstream(test_arrow_table_2) t2 ON t1.id = t2.id " + "WHERE t1.id < 5 ORDER BY t1.id", "CSV"); + test_assert_query_result_contains(join_result, "0,\"value_0\",\"value_0\"", "JOIN query contains expected data"); + test_assert_query_result_contains(join_result, "4,\"value_4\",\"value_4\"", "JOIN query contains fifth row"); + chdb_destroy_query_result(join_result); + + /* Test 6c: Test UNION query across multiple tables */ 
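+    /* Both source streams were consumed by the preceding queries, so each one
+       must be reset before the UNION reads them again. Each branch selects
+       ids 0-9, so the UNION ALL subquery is expected to yield 10 + 10 = 20 rows. */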
+ reset_arrow_stream(&stream2); + reset_arrow_stream(&stream3); + union_result = chdb_query(conn, + "SELECT COUNT(*) FROM (" + "SELECT id FROM arrowstream(test_arrow_table_2) WHERE id < 10 " + "UNION ALL " + "SELECT id FROM arrowstream(test_arrow_table_3) WHERE id < 10" + ")", "CSV"); + test_assert_row_count(union_result, 20, "UNION query row count"); + chdb_destroy_query_result(union_result); + + /* Cleanup additional tables */ + result = chdb_arrow_unregister_table(conn, table_name2); + test_assert_chdb_state(result, "Unregister second ArrowArrayStream table"); + + result = chdb_arrow_unregister_table(conn, table_name3); + test_assert_chdb_state(result, "Unregister third ArrowArrayStream table"); + + /* Test 7: Unregister original table should succeed */ + result = chdb_arrow_unregister_table(conn, table_name); + test_assert_chdb_state(result, "Unregister ArrowArrayStream table: test_arrow_table"); + + /* Test 8: Sample last 5 rows after unregister should fail */ + reset_arrow_stream(&stream); + unregister_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_table) ORDER BY id DESC LIMIT 5", "CSV"); + error = chdb_result_error(unregister_result); + + if (error) { + snprintf(error_message, sizeof(error_message), "Got expected error: %s", error); + test_assert(error != NULL, "Query after unregister should fail", error_message); + } else { + test_assert(error != NULL, "Query after unregister should fail", "No error returned when error was expected"); + } + chdb_destroy_query_result(unregister_result); +} + +// Release function for array children in create_arrow_array +static void release_array_child_id(struct ArrowArray* a) +{ + if (a->buffers) + { + free((void*)(uintptr_t)a->buffers[1]); // id data + free((void**)(uintptr_t)a->buffers); + } + free(a); +} + +// Release function for array children (string) in create_arrow_array +static void release_array_child_string(struct ArrowArray* a) { + if (a->buffers) { + free((void*)(uintptr_t)a->buffers[1]); // offsets + free((void*)(uintptr_t)a->buffers[2]); // string data + free((void**)(uintptr_t)a->buffers); + } + free(a); +} + +// Release function for main array in create_arrow_array +static void release_arrow_array_main(struct ArrowArray* a) +{ + if (a->children) + { + for (int64_t i = 0; i < a->n_children; i++) + { + if (a->children[i] && a->children[i]->release) + { + a->children[i]->release(a->children[i]); + } + } + free(a->children); + } + if (a->buffers) { + free((void**)(uintptr_t)a->buffers); + } +} + +// Helper function to create ArrowArray with specified row count +static void create_arrow_array(struct ArrowArray* array, uint64_t row_count) +{ + struct ArrowArray* id_array; + struct ArrowArray* value_array; + int64_t* id_data; + int32_t* offsets; + size_t total_string_size; + char* string_data; + size_t current_pos; + uint64_t i; + + array->length = row_count; + array->null_count = 0; + array->offset = 0; + array->n_buffers = 1; + array->n_children = 2; + array->buffers = (const void**)malloc(1 * sizeof(void*)); + array->buffers[0] = NULL; // validity buffer + + array->children = (struct ArrowArray**)malloc(2 * sizeof(struct ArrowArray*)); + array->dictionary = NULL; + + // Create id column (int64) + array->children[0] = (struct ArrowArray*)malloc(sizeof(struct ArrowArray)); + id_array = array->children[0]; + id_array->length = row_count; + id_array->null_count = 0; + id_array->offset = 0; + id_array->n_buffers = 2; + id_array->n_children = 0; + id_array->children = NULL; + id_array->dictionary = NULL; + + 
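+    /* Per the Arrow C data interface, a fixed-width primitive array carries
+       two buffers: buffers[0] is the validity bitmap (left NULL here since
+       there are no nulls) and buffers[1] holds the raw int64 values. */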
id_array->buffers = (const void**)malloc(2 * sizeof(void*)); + id_array->buffers[0] = NULL; // validity buffer + + // Allocate and populate id data + id_data = (int64_t*)malloc(row_count * sizeof(int64_t)); + for (i = 0; i < row_count; i++) + { + id_data[i] = (int64_t)i; + } + id_array->buffers[1] = id_data; + id_array->release = release_array_child_id; + + // Create value column (string) + array->children[1] = (struct ArrowArray*)malloc(sizeof(struct ArrowArray)); + value_array = array->children[1]; + value_array->length = row_count; + value_array->null_count = 0; + value_array->offset = 0; + value_array->n_buffers = 3; + value_array->n_children = 0; + value_array->children = NULL; + value_array->dictionary = NULL; + + value_array->buffers = (const void**)malloc(3 * sizeof(void*)); + value_array->buffers[0] = NULL; // validity buffer + + // Calculate total string data size and create offset array + offsets = (int32_t*)malloc((row_count + 1) * sizeof(int32_t)); + total_string_size = 0; + offsets[0] = 0; + + for (i = 0; i < row_count; i++) + { + char value_str[64]; + size_t len; + snprintf(value_str, sizeof(value_str), "value_%llu", (unsigned long long)i); + len = strlen(value_str); + total_string_size += len; + offsets[i + 1] = (int32_t)total_string_size; + } + + value_array->buffers[1] = offsets; + + // Allocate and populate string data + string_data = (char*)malloc(total_string_size); + current_pos = 0; + for (i = 0; i < row_count; i++) { + char value_str[64]; + size_t len; + snprintf(value_str, sizeof(value_str), "value_%llu", (unsigned long long)i); + len = strlen(value_str); + memcpy(string_data + current_pos, value_str, len); + current_pos += len; + } + value_array->buffers[2] = string_data; + value_array->release = release_array_child_string; + + // Set release callback for main array + array->release = release_arrow_array_main; +} + +void test_arrow_array_scan(chdb_connection conn) +{ + struct ArrowSchema schema; + struct ArrowArray array; + struct ArrowSchema schema2; + struct ArrowArray array2; + struct ArrowSchema schema3; + struct ArrowArray array3; + const char* table_name = "test_arrow_array_table"; + const char* non_exist_table_name = "non_exist_array_table"; + const char* table_name2 = "test_arrow_array_table_2"; + const char* table_name3 = "test_arrow_array_table_3"; + chdb_arrow_schema arrow_schema; + chdb_arrow_array arrow_array; + chdb_arrow_schema arrow_schema2; + chdb_arrow_array arrow_array2; + chdb_arrow_schema arrow_schema3; + chdb_arrow_array arrow_array3; + chdb_state result; + chdb_result* count_result; + chdb_result* sample_result; + chdb_result* last_result; + chdb_result* count2_result; + chdb_result* count3_result; + chdb_result* join_result; + chdb_result* union_result; + chdb_result* unregister_result; + const char* error; + char error_message[512]; + + printf("\n=== Testing ArrowArray Scan Functions ===\n"); + printf("Data specification: 1,000,000 rows × 2 columns (id: int64, value: string)\n"); + + // Create ArrowSchema (reuse existing function) + create_schema(&schema); + + // Create ArrowArray with 1,000,000 rows + memset(&array, 0, sizeof(array)); + create_arrow_array(&array, 1000000); + + printf("✓ ArrowArray initialization completed\n"); + printf("Starting registration with chDB...\n"); + + arrow_schema = (chdb_arrow_schema)&schema; + arrow_array = (chdb_arrow_array)&array; + + // Test 1: Register -> Query -> Unregister for row count + result = chdb_arrow_array_scan(conn, table_name, arrow_schema, arrow_array); + test_assert_chdb_state(result, 
"Register ArrowArray to table: test_arrow_array_table"); + + count_result = chdb_query(conn, "SELECT COUNT(*) as total_rows FROM arrowstream(test_arrow_array_table)", "CSV"); + test_assert_row_count(count_result, 1000000, "Count total rows"); + chdb_destroy_query_result(count_result); + + result = chdb_arrow_unregister_table(conn, table_name); + test_assert_chdb_state(result, "Unregister ArrowArray table after count query"); + + // Test 2: Unregister non-existent table should handle gracefully + result = chdb_arrow_unregister_table(conn, non_exist_table_name); + test_assert_chdb_state(result, "Unregister non-existent array table: non_exist_array_table"); + + // Test 3: Register -> Query -> Unregister for first 5 rows + result = chdb_arrow_array_scan(conn, table_name, arrow_schema, arrow_array); + test_assert_chdb_state(result, "Register ArrowArray for sample query"); + + sample_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_array_table) LIMIT 5", "CSV"); + test_assert_query_result_contains(sample_result, "0,\"value_0\"", "First 5 rows contain first row"); + test_assert_query_result_contains(sample_result, "4,\"value_4\"", "First 5 rows contain fifth row"); + chdb_destroy_query_result(sample_result); + + result = chdb_arrow_unregister_table(conn, table_name); + test_assert_chdb_state(result, "Unregister ArrowArray table after sample query"); + + // Test 4: Register -> Query -> Unregister for last 5 rows + result = chdb_arrow_array_scan(conn, table_name, arrow_schema, arrow_array); + test_assert_chdb_state(result, "Register ArrowArray for last rows query"); + + last_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_array_table) ORDER BY id DESC LIMIT 5", "CSV"); + test_assert_query_result_contains(last_result, "999999,\"value_999999\"", "Last 5 rows contain last row"); + test_assert_query_result_contains(last_result, "999995,\"value_999995\"", "Last 5 rows contain fifth row"); + chdb_destroy_query_result(last_result); + + result = chdb_arrow_unregister_table(conn, table_name); + test_assert_chdb_state(result, "Unregister ArrowArray table after last rows query"); + + // Test 5: Independent multiple table tests + // Create second ArrowArray with different data (500,000 rows) + create_schema(&schema2); + memset(&array2, 0, sizeof(array2)); + create_arrow_array(&array2, 500000); + + // Create third ArrowArray with different data (100,000 rows) + create_schema(&schema3); + memset(&array3, 0, sizeof(array3)); + create_arrow_array(&array3, 100000); + + arrow_schema2 = (chdb_arrow_schema)&schema2; + arrow_array2 = (chdb_arrow_array)&array2; + arrow_schema3 = (chdb_arrow_schema)&schema3; + arrow_array3 = (chdb_arrow_array)&array3; + + // Test 5a: Register -> Query -> Unregister for second table (500K rows) + result = chdb_arrow_array_scan(conn, table_name2, arrow_schema2, arrow_array2); + test_assert_chdb_state(result, "Register second ArrowArray to table: test_arrow_array_table_2"); + + count2_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_array_table_2)", "CSV"); + test_assert_row_count(count2_result, 500000, "Second array table row count"); + chdb_destroy_query_result(count2_result); + + result = chdb_arrow_unregister_table(conn, table_name2); + test_assert_chdb_state(result, "Unregister second ArrowArray table"); + + // Test 5b: Register -> Query -> Unregister for third table (100K rows) + result = chdb_arrow_array_scan(conn, table_name3, arrow_schema3, arrow_array3); + test_assert_chdb_state(result, "Register third ArrowArray to table: 
test_arrow_array_table_3"); + + count3_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_array_table_3)", "CSV"); + test_assert_row_count(count3_result, 100000, "Third array table row count"); + chdb_destroy_query_result(count3_result); + + result = chdb_arrow_unregister_table(conn, table_name3); + test_assert_chdb_state(result, "Unregister third ArrowArray table"); + + // Test 6: Cross-table JOIN query (Register both -> Query -> Unregister both) + result = chdb_arrow_array_scan(conn, table_name, arrow_schema, arrow_array); + test_assert_chdb_state(result, "Register first ArrowArray for JOIN"); + + result = chdb_arrow_array_scan(conn, table_name2, arrow_schema2, arrow_array2); + test_assert_chdb_state(result, "Register second ArrowArray for JOIN"); + + join_result = chdb_query(conn, + "SELECT t1.id, t1.value, t2.value as value2 " + "FROM arrowstream(test_arrow_array_table) t1 " + "INNER JOIN arrowstream(test_arrow_array_table_2) t2 ON t1.id = t2.id " + "WHERE t1.id < 5 ORDER BY t1.id", "CSV"); + test_assert_query_result_contains(join_result, "0,\"value_0\",\"value_0\"", "Array JOIN query contains expected data"); + test_assert_query_result_contains(join_result, "4,\"value_4\",\"value_4\"", "Array JOIN query contains fifth row"); + chdb_destroy_query_result(join_result); + + result = chdb_arrow_unregister_table(conn, table_name); + test_assert_chdb_state(result, "Unregister first ArrowArray after JOIN"); + + result = chdb_arrow_unregister_table(conn, table_name2); + test_assert_chdb_state(result, "Unregister second ArrowArray after JOIN"); + + // Test 7: Cross-table UNION query (Register both -> Query -> Unregister both) + result = chdb_arrow_array_scan(conn, table_name2, arrow_schema2, arrow_array2); + test_assert_chdb_state(result, "Register second ArrowArray for UNION"); + + result = chdb_arrow_array_scan(conn, table_name3, arrow_schema3, arrow_array3); + test_assert_chdb_state(result, "Register third ArrowArray for UNION"); + + union_result = chdb_query(conn, + "SELECT COUNT(*) FROM (" + "SELECT id FROM arrowstream(test_arrow_array_table_2) WHERE id < 10 " + "UNION ALL " + "SELECT id FROM arrowstream(test_arrow_array_table_3) WHERE id < 10" + ")", "CSV"); + test_assert_row_count(union_result, 20, "Array UNION query row count"); + chdb_destroy_query_result(union_result); + + result = chdb_arrow_unregister_table(conn, table_name2); + test_assert_chdb_state(result, "Unregister second ArrowArray after UNION"); + + result = chdb_arrow_unregister_table(conn, table_name3); + test_assert_chdb_state(result, "Unregister third ArrowArray after UNION"); + + // Test 8: Query after unregister should fail + unregister_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_array_table) ORDER BY id DESC LIMIT 5", "CSV"); + error = chdb_result_error(unregister_result); + + if (error) { + snprintf(error_message, sizeof(error_message), "Got expected error: %s", error); + test_assert(error != NULL, "Array query after unregister should fail", error_message); + } else { + test_assert(error != NULL, "Array query after unregister should fail", "No error returned when error was expected"); + } + chdb_destroy_query_result(unregister_result); + + // Cleanup ArrowArrays and schemas + if (array.release) array.release(&array); + if (schema.release) schema.release(&schema); + if (array2.release) array2.release(&array2); + if (schema2.release) schema2.release(&schema2); + if (array3.release) array3.release(&array3); + if (schema3.release) schema3.release(&schema3); +} + +int main(void) +{ 
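+    /* Test driver: the overall call sequence exercised below is
+     *
+     *     chdb_connection * conn_ptr = chdb_connect(argc, argv);
+     *     chdb_arrow_scan(*conn_ptr, "t", stream);        // or chdb_arrow_array_scan()
+     *     chdb_result * r = chdb_query(*conn_ptr, "SELECT ...", "CSV");
+     *     chdb_destroy_query_result(r);
+     *     chdb_arrow_unregister_table(*conn_ptr, "t");
+     *     chdb_close_conn(conn_ptr);
+     */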
+ char* argv[] = {"clickhouse", "--multiquery"}; + int argc = sizeof(argv) / sizeof(argv[0]); + chdb_connection* conn_ptr; + chdb_connection conn; + + printf("=== chDB Arrow Functions Test ===\n"); + + /* Create connection */ + conn_ptr = chdb_connect(argc, argv); + if (!conn_ptr || !*conn_ptr) { + printf("Failed to create chDB connection\n"); + return 1; + } + + conn = *conn_ptr; + printf("✓ chDB connection established\n"); + + /* Run test suites */ + test_arrow_scan(conn); + test_arrow_array_scan(conn); + + /* Clean up */ + chdb_close_conn(conn_ptr); + + printf("\n=== chDB Arrow Functions Test Completed ===\n"); + + return 0; +} diff --git a/examples/runArrowTestC.sh b/examples/runArrowTestC.sh new file mode 100755 index 00000000000..db94b824500 --- /dev/null +++ b/examples/runArrowTestC.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +set -e + +# CFLAGS="-g -O0 -DDEBUG" +CFLAGS="-std=c99" + +# check current os type, and make ldd command +if [ "$(uname)" == "Darwin" ]; then + LDD="otool -L" + LIB_PATH="DYLD_LIBRARY_PATH" +elif [ "$(uname)" == "Linux" ]; then + LDD="ldd" + LIB_PATH="LD_LIBRARY_PATH" +else + echo "OS not supported" + exit 1 +fi + +# cd to the directory of this script +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +cd $DIR + +echo "Compile and link chdbArrowTest (C version)" +clang $CFLAGS chdbArrowTest.c -o chdbArrowTestC \ + -I../programs/local/ \ + -I../contrib/arrow/cpp/src \ + -I../contrib/arrow-cmake/cpp/src \ + -I../src \ + -L../ -lchdb + +export ${LIB_PATH}=.. +${LDD} chdbArrowTestC + +echo "Run Arrow API tests (C version):" +./chdbArrowTestC From e092a5c7d2b8dcc0339e3051387f67fee8bb671f Mon Sep 17 00:00:00 2001 From: wudidapaopao Date: Sat, 27 Sep 2025 21:16:53 +0800 Subject: [PATCH 10/13] test: update workflow --- .../workflows/build_linux_arm64_wheels-gh.yml | 3 - .github/workflows/build_linux_x86_wheels.yml | 3 - .../workflows/build_macos_arm64_wheels.yml | 3 - .github/workflows/build_macos_x86_wheels.yml | 2 +- examples/chdbArrowTest.c | 296 ++++++++++-------- examples/runArrowTest.sh | 9 +- examples/runArrowTestC.sh | 36 --- programs/local/chdb-arrow.cpp | 5 +- 8 files changed, 175 insertions(+), 182 deletions(-) delete mode 100755 examples/runArrowTestC.sh diff --git a/.github/workflows/build_linux_arm64_wheels-gh.yml b/.github/workflows/build_linux_arm64_wheels-gh.yml index 7c6a50ab15c..e661ad544aa 100644 --- a/.github/workflows/build_linux_arm64_wheels-gh.yml +++ b/.github/workflows/build_linux_arm64_wheels-gh.yml @@ -138,9 +138,6 @@ jobs: - name: Run libchdb stub in examples dir run: | bash -x ./examples/runStub.sh - bash -x ./examples/runArrowTestC.sh - - name: Run Arrow functions test in examples dir - run: | bash -x ./examples/runArrowTest.sh - name: Check ccache statistics run: | diff --git a/.github/workflows/build_linux_x86_wheels.yml b/.github/workflows/build_linux_x86_wheels.yml index 8d7c21e0411..d3b4ca61efe 100644 --- a/.github/workflows/build_linux_x86_wheels.yml +++ b/.github/workflows/build_linux_x86_wheels.yml @@ -138,9 +138,6 @@ jobs: - name: Run libchdb stub in examples dir run: | bash -x ./examples/runStub.sh - bash -x ./examples/runArrowTestC.sh - - name: Run Arrow functions test in examples dir - run: | bash -x ./examples/runArrowTest.sh - name: Check ccache statistics run: | diff --git a/.github/workflows/build_macos_arm64_wheels.yml b/.github/workflows/build_macos_arm64_wheels.yml index ac4dfa826bf..2cd312eb755 100644 --- a/.github/workflows/build_macos_arm64_wheels.yml +++ b/.github/workflows/build_macos_arm64_wheels.yml 
@@ -154,9 +154,6 @@ jobs: - name: Run libchdb stub in examples dir run: | bash -x ./examples/runStub.sh - bash -x ./examples/runArrowTestC.sh - - name: Run Arrow functions test in examples dir - run: | bash -x ./examples/runArrowTest.sh - name: Keep killall ccache and wait for ccache to finish if: always() diff --git a/.github/workflows/build_macos_x86_wheels.yml b/.github/workflows/build_macos_x86_wheels.yml index 48a52261156..35cc72e284d 100644 --- a/.github/workflows/build_macos_x86_wheels.yml +++ b/.github/workflows/build_macos_x86_wheels.yml @@ -155,7 +155,7 @@ jobs: - name: Run libchdb stub in examples dir run: | bash -x ./examples/runStub.sh - bash -x ./examples/runArrowTestC.sh + bash -x ./examples/runArrowTest.sh - name: Keep killall ccache and wait for ccache to finish if: always() run: | diff --git a/examples/chdbArrowTest.c b/examples/chdbArrowTest.c index 841be0cd9c3..f91d4d4c8c4 100644 --- a/examples/chdbArrowTest.c +++ b/examples/chdbArrowTest.c @@ -19,7 +19,8 @@ typedef struct CustomStreamData } CustomStreamData; // Function to initialize CustomStreamData -static void init_custom_stream_data(CustomStreamData* data) { +static void init_custom_stream_data(CustomStreamData * data) +{ data->schema_sent = false; data->current_row = 0; data->total_rows = 1000000; @@ -28,7 +29,7 @@ static void init_custom_stream_data(CustomStreamData* data) { } // Reset the stream to allow reading from the beginning -static void reset_custom_stream_data(CustomStreamData* data) +static void reset_custom_stream_data(CustomStreamData * data) { data->current_row = 0; if (data->last_error) { @@ -38,27 +39,32 @@ static void reset_custom_stream_data(CustomStreamData* data) } // Release function prototypes -static void release_schema_child(struct ArrowSchema* s); -static void release_schema_main(struct ArrowSchema* s); -static void release_id_array(struct ArrowArray* arr); -static void release_string_array(struct ArrowArray* arr); -static void release_main_array(struct ArrowArray* arr); +static void release_schema_child(struct ArrowSchema * s); +static void release_schema_main(struct ArrowSchema * s); +static void release_id_array(struct ArrowArray * arr); +static void release_string_array(struct ArrowArray * arr); +static void release_main_array(struct ArrowArray * arr); // Helper function to find minimum of two values -static size_t min_size_t(size_t a, size_t b) { +static size_t min_size_t(size_t a, size_t b) +{ return (a < b) ? 
a : b; } // Release function implementations -static void release_schema_child(struct ArrowSchema* s) { +static void release_schema_child(struct ArrowSchema * s) +{ s->release = NULL; } -static void release_schema_main(struct ArrowSchema* s) +static void release_schema_main(struct ArrowSchema * s) { - if (s->children) { - for (int64_t i = 0; i < s->n_children; i++) { - if (s->children[i] && s->children[i]->release) { + if (s->children) + { + for (int64_t i = 0; i < s->n_children; i++) + { + if (s->children[i] && s->children[i]->release) + { s->children[i]->release(s->children[i]); } free(s->children[i]); @@ -68,18 +74,20 @@ static void release_schema_main(struct ArrowSchema* s) s->release = NULL; } -static void release_id_array(struct ArrowArray* arr) +static void release_id_array(struct ArrowArray * arr) { - if (arr->buffers) { + if (arr->buffers) + { free((void*)(uintptr_t)arr->buffers[1]); // free data buffer free((void**)(uintptr_t)arr->buffers); } arr->release = NULL; } -static void release_string_array(struct ArrowArray* arr) +static void release_string_array(struct ArrowArray * arr) { - if (arr->buffers) { + if (arr->buffers) + { free((void*)(uintptr_t)arr->buffers[1]); // free offset buffer free((void*)(uintptr_t)arr->buffers[2]); // free data buffer free((void**)(uintptr_t)arr->buffers); @@ -87,10 +95,14 @@ static void release_string_array(struct ArrowArray* arr) arr->release = NULL; } -static void release_main_array(struct ArrowArray* arr) { - if (arr->children) { - for (int64_t i = 0; i < arr->n_children; i++) { - if (arr->children[i] && arr->children[i]->release) { +static void release_main_array(struct ArrowArray * arr) +{ + if (arr->children) + { + for (int64_t i = 0; i < arr->n_children; i++) + { + if (arr->children[i] && arr->children[i]->release) + { arr->children[i]->release(arr->children[i]); } free(arr->children[i]); @@ -104,7 +116,8 @@ static void release_main_array(struct ArrowArray* arr) { } // Helper function to create schema with 2 columns: id(int64), value(string) -static void create_schema(struct ArrowSchema* schema) { +static void create_schema(struct ArrowSchema * schema) +{ schema->format = "+s"; // struct format schema->name = NULL; schema->metadata = NULL; @@ -138,15 +151,15 @@ static void create_schema(struct ArrowSchema* schema) { } // Helper function to create a batch of data -static void create_batch(struct ArrowArray* array, size_t start_row, size_t batch_size) +static void create_batch(struct ArrowArray * array, size_t start_row, size_t batch_size) { - struct ArrowArray* id_array; - struct ArrowArray* str_array; - int64_t* id_data; - int32_t* offsets; + struct ArrowArray * id_array; + struct ArrowArray * str_array; + int64_t * id_data; + int32_t * offsets; size_t total_str_len; - char** strings; - char* str_data; + char ** strings; + char * str_data; size_t pos; size_t i; @@ -156,24 +169,24 @@ static void create_batch(struct ArrowArray* array, size_t start_row, size_t batc array->offset = 0; array->n_buffers = 1; array->n_children = 2; - array->buffers = (const void**)malloc(1 * sizeof(void*)); + array->buffers = (const void **)malloc(1 * sizeof(void *)); array->buffers[0] = NULL; // validity buffer (no nulls) - array->children = (struct ArrowArray**)malloc(2 * sizeof(struct ArrowArray*)); + array->children = (struct ArrowArray **)malloc(2 * sizeof(struct ArrowArray *)); array->dictionary = NULL; // Create id column (int64) - array->children[0] = (struct ArrowArray*)malloc(sizeof(struct ArrowArray)); + array->children[0] = (struct ArrowArray 
*)malloc(sizeof(struct ArrowArray)); id_array = array->children[0]; id_array->length = batch_size; id_array->null_count = 0; id_array->offset = 0; id_array->n_buffers = 2; id_array->n_children = 0; - id_array->buffers = (const void**)malloc(2 * sizeof(void*)); + id_array->buffers = (const void **)malloc(2 * sizeof(void *)); id_array->buffers[0] = NULL; // validity buffer // Allocate and fill id data - id_data = (int64_t*)malloc(batch_size * sizeof(int64_t)); + id_data = (int64_t *)malloc(batch_size * sizeof(int64_t)); for (i = 0; i < batch_size; i++) id_data[i] = start_row + i; @@ -183,23 +196,23 @@ static void create_batch(struct ArrowArray* array, size_t start_row, size_t batc id_array->release = release_id_array; // Create value column (string) - array->children[1] = (struct ArrowArray*)malloc(sizeof(struct ArrowArray)); + array->children[1] = (struct ArrowArray *)malloc(sizeof(struct ArrowArray)); str_array = array->children[1]; str_array->length = batch_size; str_array->null_count = 0; str_array->offset = 0; str_array->n_buffers = 3; str_array->n_children = 0; - str_array->buffers = (const void**)malloc(3 * sizeof(void*)); + str_array->buffers = (const void **)malloc(3 * sizeof(void *)); str_array->buffers[0] = NULL; // validity buffer // Create offset buffer (int32) - offsets = (int32_t*)malloc((batch_size + 1) * sizeof(int32_t)); + offsets = (int32_t *)malloc((batch_size + 1) * sizeof(int32_t)); offsets[0] = 0; // Calculate total string length and create strings total_str_len = 0; - strings = (char**)malloc(batch_size * sizeof(char*)); + strings = (char **)malloc(batch_size * sizeof(char *)); for (i = 0; i < batch_size; i++) { char buffer[64]; @@ -243,13 +256,13 @@ static int custom_get_schema(struct ArrowArrayStream* stream, struct ArrowSchema } // Callback function to get next array -static int custom_get_next(struct ArrowArrayStream* stream, struct ArrowArray* out) +static int custom_get_next(struct ArrowArrayStream * stream, struct ArrowArray * out) { - CustomStreamData* data; + CustomStreamData * data; size_t remaining_rows; size_t batch_size; - - data = (CustomStreamData*)stream->private_data; + + data = (CustomStreamData *)stream->private_data; if (!data) return EINVAL; @@ -273,8 +286,9 @@ static int custom_get_next(struct ArrowArrayStream* stream, struct ArrowArray* o } // Callback function to get last error -static const char* custom_get_last_error(struct ArrowArrayStream* stream) { - CustomStreamData* data = (CustomStreamData*)stream->private_data; +static const char * custom_get_last_error(struct ArrowArrayStream * stream) +{ + CustomStreamData * data = (CustomStreamData *)stream->private_data; if (!data || !data->last_error) return NULL; @@ -282,11 +296,13 @@ static const char* custom_get_last_error(struct ArrowArrayStream* stream) { } // Callback function to release stream resources -static void custom_release(struct ArrowArrayStream* stream) { +static void custom_release(struct ArrowArrayStream * stream) +{ if (stream->private_data) { - CustomStreamData* data = (CustomStreamData*)stream->private_data; - if (data->last_error) { + CustomStreamData * data = (CustomStreamData *)stream->private_data; + if (data->last_error) + { free(data->last_error); } free(data); @@ -296,11 +312,11 @@ static void custom_release(struct ArrowArrayStream* stream) { } // Helper function to reset the ArrowArrayStream for reuse -static void reset_arrow_stream(struct ArrowArrayStream* stream) +static void reset_arrow_stream(struct ArrowArrayStream * stream) { if (stream && stream->private_data) 
{ - CustomStreamData* data = (CustomStreamData*)stream->private_data; + CustomStreamData * data = (CustomStreamData *)stream->private_data; reset_custom_stream_data(data); printf("✓ ArrowArrayStream has been reset, ready for re-reading\n"); } @@ -310,7 +326,7 @@ static void reset_arrow_stream(struct ArrowArrayStream* stream) // Unit Test Utilities //===--------------------------------------------------------------------===// -static void test_assert(bool condition, const char* test_name, const char* message) +static void test_assert(bool condition, const char * test_name, const char * message) { if (condition) { @@ -328,50 +344,56 @@ static void test_assert(bool condition, const char* test_name, const char* messa } } -static void test_assert_chdb_state(chdb_state state, const char* operation_name) +static void test_assert_chdb_state(chdb_state state, const char * operation_name) { char message[256]; - if (state == CHDBError) { + if (state == CHDBError) + { strcpy(message, "Operation failed"); - } else { + } + else + { strcpy(message, "Unknown state"); } - + test_assert(state == CHDBSuccess, operation_name, state == CHDBError ? message : NULL); } -static void test_assert_not_null(void* ptr, const char* test_name) +static void test_assert_not_null(void * ptr, const char * test_name) { test_assert(ptr != NULL, test_name, "Pointer is null"); } -static void test_assert_no_error(chdb_result* result, const char* query_name) +static void test_assert_no_error(chdb_result * result, const char * query_name) { char full_test_name[512]; - const char* error; - + const char * error; + snprintf(full_test_name, sizeof(full_test_name), "%s - Result is not null", query_name); test_assert_not_null(result, full_test_name); error = chdb_result_error(result); snprintf(full_test_name, sizeof(full_test_name), "%s - No query error", query_name); - - if (error) { + + if (error) + { char error_message[512]; snprintf(error_message, sizeof(error_message), "Error: %s", error); test_assert(error == NULL, full_test_name, error_message); - } else { + } + else + { test_assert(error == NULL, full_test_name, NULL); } } -static void test_assert_query_result_contains(chdb_result* result, const char* expected_content, const char* query_name) +static void test_assert_query_result_contains(chdb_result * result, const char * expected_content, const char * query_name) { - char* buffer; + char * buffer; char full_test_name[512]; bool contains; - + test_assert_no_error(result, query_name); buffer = chdb_result_buffer(result); @@ -379,25 +401,28 @@ static void test_assert_query_result_contains(chdb_result* result, const char* e test_assert_not_null(buffer, full_test_name); snprintf(full_test_name, sizeof(full_test_name), "%s - Result contains expected content", query_name); - + contains = strstr(buffer, expected_content) != NULL; - if (!contains) { + if (!contains) + { char error_message[1024]; snprintf(error_message, sizeof(error_message), "Expected: %s, Actual: %s", expected_content, buffer); test_assert(contains, full_test_name, error_message); - } else { + } + else + { test_assert(contains, full_test_name, NULL); } } -static void test_assert_row_count(chdb_result* result, uint64_t expected_rows, const char* query_name) +static void test_assert_row_count(chdb_result * result, uint64_t expected_rows, const char * query_name) { - char* buffer; + char * buffer; char full_test_name[512]; - char* result_str; - char* end; + char * result_str; + char * end; uint64_t actual_rows; - + test_assert_no_error(result, query_name); buffer = 
chdb_result_buffer(result); @@ -407,7 +432,7 @@ static void test_assert_row_count(chdb_result* result, uint64_t expected_rows, c /* Parse the count result (assuming CSV format with just the number) */ result_str = (char*)malloc(strlen(buffer) + 1); strcpy(result_str, buffer); - + /* Remove trailing whitespace/newlines */ end = result_str + strlen(result_str) - 1; while (end > result_str && (*end == ' ' || *end == '\t' || *end == '\n' || *end == '\r' || *end == '\f' || *end == '\v')) { @@ -416,18 +441,21 @@ static void test_assert_row_count(chdb_result* result, uint64_t expected_rows, c } actual_rows = strtoull(result_str, NULL, 10); - + snprintf(full_test_name, sizeof(full_test_name), "%s - Row count matches", query_name); - - if (actual_rows != expected_rows) { + + if (actual_rows != expected_rows) + { char error_message[256]; - snprintf(error_message, sizeof(error_message), "Expected: %llu, Actual: %llu", + snprintf(error_message, sizeof(error_message), "Expected: %llu, Actual: %llu", (unsigned long long)expected_rows, (unsigned long long)actual_rows); test_assert(actual_rows == expected_rows, full_test_name, error_message); - } else { + } + else + { test_assert(actual_rows == expected_rows, full_test_name, NULL); } - + free(result_str); } @@ -436,9 +464,9 @@ void test_arrow_scan(chdb_connection conn) struct ArrowArrayStream stream; struct ArrowArrayStream stream2; struct ArrowArrayStream stream3; - CustomStreamData* stream_data; - CustomStreamData* stream_data2; - CustomStreamData* stream_data3; + CustomStreamData * stream_data; + CustomStreamData * stream_data2; + CustomStreamData * stream_data3; const char* table_name = "test_arrow_table"; const char* non_exist_table_name = "non_exist_table"; const char* table_name2 = "test_arrow_table_2"; @@ -447,19 +475,19 @@ void test_arrow_scan(chdb_connection conn) chdb_arrow_stream arrow_stream2; chdb_arrow_stream arrow_stream3; chdb_state result; - chdb_result* count_result; - chdb_result* sample_result; - chdb_result* last_result; - chdb_result* count1_result; - chdb_result* count2_result; - chdb_result* count3_result; - chdb_result* join_result; - chdb_result* union_result; - chdb_result* unregister_result; - const char* error; + chdb_result * count_result; + chdb_result * sample_result; + chdb_result * last_result; + chdb_result * count1_result; + chdb_result * count2_result; + chdb_result * count3_result; + chdb_result * join_result; + chdb_result * union_result; + chdb_result * unregister_result; + const char * error; char error_message[512]; - printf("\n=== Creating Custom ArrowArrayStream ===\n"); + printf("\n=== Testing ArrowArrayStream Scan Functions ===\n"); printf("Data specification: 1,000,000 rows × 2 columns (id: int64, value: string)\n"); memset(&stream, 0, sizeof(stream)); @@ -510,7 +538,7 @@ void test_arrow_scan(chdb_connection conn) /* Test 6: Multiple table registration tests */ /* Create second ArrowArrayStream with different data (500,000 rows) */ memset(&stream2, 0, sizeof(stream2)); - stream_data2 = (CustomStreamData*)malloc(sizeof(CustomStreamData)); + stream_data2 = (CustomStreamData *)malloc(sizeof(CustomStreamData)); init_custom_stream_data(stream_data2); stream_data2->total_rows = 500000; /* Different row count */ stream_data2->current_row = 0; @@ -522,7 +550,7 @@ void test_arrow_scan(chdb_connection conn) /* Create third ArrowArrayStream with different data (100,000 rows) */ memset(&stream3, 0, sizeof(stream3)); - stream_data3 = (CustomStreamData*)malloc(sizeof(CustomStreamData)); + stream_data3 = (CustomStreamData 
*)malloc(sizeof(CustomStreamData)); init_custom_stream_data(stream_data3); stream_data3->total_rows = 100000; /* Different row count */ stream_data3->current_row = 0; @@ -597,11 +625,14 @@ void test_arrow_scan(chdb_connection conn) reset_arrow_stream(&stream); unregister_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_table) ORDER BY id DESC LIMIT 5", "CSV"); error = chdb_result_error(unregister_result); - - if (error) { + + if (error) + { snprintf(error_message, sizeof(error_message), "Got expected error: %s", error); test_assert(error != NULL, "Query after unregister should fail", error_message); - } else { + } + else + { test_assert(error != NULL, "Query after unregister should fail", "No error returned when error was expected"); } chdb_destroy_query_result(unregister_result); @@ -619,8 +650,10 @@ static void release_array_child_id(struct ArrowArray* a) } // Release function for array children (string) in create_arrow_array -static void release_array_child_string(struct ArrowArray* a) { - if (a->buffers) { +static void release_array_child_string(struct ArrowArray* a) +{ + if (a->buffers) + { free((void*)(uintptr_t)a->buffers[1]); // offsets free((void*)(uintptr_t)a->buffers[2]); // string data free((void**)(uintptr_t)a->buffers); @@ -629,7 +662,7 @@ static void release_array_child_string(struct ArrowArray* a) { } // Release function for main array in create_arrow_array -static void release_arrow_array_main(struct ArrowArray* a) +static void release_arrow_array_main(struct ArrowArray * a) { if (a->children) { @@ -642,23 +675,25 @@ static void release_arrow_array_main(struct ArrowArray* a) } free(a->children); } - if (a->buffers) { + + if (a->buffers) + { free((void**)(uintptr_t)a->buffers); } } // Helper function to create ArrowArray with specified row count -static void create_arrow_array(struct ArrowArray* array, uint64_t row_count) +static void create_arrow_array(struct ArrowArray * array, uint64_t row_count) { - struct ArrowArray* id_array; - struct ArrowArray* value_array; - int64_t* id_data; - int32_t* offsets; + struct ArrowArray * id_array; + struct ArrowArray * value_array; + int64_t * id_data; + int32_t * offsets; size_t total_string_size; - char* string_data; + char * string_data; size_t current_pos; uint64_t i; - + array->length = row_count; array->null_count = 0; array->offset = 0; @@ -725,7 +760,7 @@ static void create_arrow_array(struct ArrowArray* array, uint64_t row_count) value_array->buffers[1] = offsets; // Allocate and populate string data - string_data = (char*)malloc(total_string_size); + string_data = (char *)malloc(total_string_size); current_pos = 0; for (i = 0; i < row_count; i++) { char value_str[64]; @@ -750,10 +785,10 @@ void test_arrow_array_scan(chdb_connection conn) struct ArrowArray array2; struct ArrowSchema schema3; struct ArrowArray array3; - const char* table_name = "test_arrow_array_table"; - const char* non_exist_table_name = "non_exist_array_table"; - const char* table_name2 = "test_arrow_array_table_2"; - const char* table_name3 = "test_arrow_array_table_3"; + const char * table_name = "test_arrow_array_table"; + const char * non_exist_table_name = "non_exist_array_table"; + const char * table_name2 = "test_arrow_array_table_2"; + const char * table_name3 = "test_arrow_array_table_3"; chdb_arrow_schema arrow_schema; chdb_arrow_array arrow_array; chdb_arrow_schema arrow_schema2; @@ -761,15 +796,15 @@ void test_arrow_array_scan(chdb_connection conn) chdb_arrow_schema arrow_schema3; chdb_arrow_array arrow_array3; chdb_state result; - 
chdb_result* count_result; - chdb_result* sample_result; - chdb_result* last_result; - chdb_result* count2_result; - chdb_result* count3_result; - chdb_result* join_result; - chdb_result* union_result; - chdb_result* unregister_result; - const char* error; + chdb_result * count_result; + chdb_result * sample_result; + chdb_result * last_result; + chdb_result * count2_result; + chdb_result * count3_result; + chdb_result * join_result; + chdb_result * union_result; + chdb_result * unregister_result; + const char * error; char error_message[512]; printf("\n=== Testing ArrowArray Scan Functions ===\n"); @@ -912,11 +947,14 @@ void test_arrow_array_scan(chdb_connection conn) // Test 8: Query after unregister should fail unregister_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_array_table) ORDER BY id DESC LIMIT 5", "CSV"); error = chdb_result_error(unregister_result); - - if (error) { + + if (error) + { snprintf(error_message, sizeof(error_message), "Got expected error: %s", error); test_assert(error != NULL, "Array query after unregister should fail", error_message); - } else { + } + else + { test_assert(error != NULL, "Array query after unregister should fail", "No error returned when error was expected"); } chdb_destroy_query_result(unregister_result); @@ -932,9 +970,9 @@ void test_arrow_array_scan(chdb_connection conn) int main(void) { - char* argv[] = {"clickhouse", "--multiquery"}; + char * argv[] = {"clickhouse", "--multiquery"}; int argc = sizeof(argv) / sizeof(argv[0]); - chdb_connection* conn_ptr; + chdb_connection * conn_ptr; chdb_connection conn; printf("=== chDB Arrow Functions Test ===\n"); @@ -943,7 +981,7 @@ int main(void) conn_ptr = chdb_connect(argc, argv); if (!conn_ptr || !*conn_ptr) { printf("Failed to create chDB connection\n"); - return 1; + exit(1); } conn = *conn_ptr; diff --git a/examples/runArrowTest.sh b/examples/runArrowTest.sh index 96ec084e674..7696f361162 100755 --- a/examples/runArrowTest.sh +++ b/examples/runArrowTest.sh @@ -2,8 +2,7 @@ set -e -# CXXFLAGS="-g -O0 -DDEBUG" -CXXFLAGS="-std=c++17" +CFLAGS="-std=c99" # check current os type, and make ldd command if [ "$(uname)" == "Darwin" ]; then @@ -21,8 +20,8 @@ fi DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" cd $DIR -echo "Compile and link chdbArrowTest" -clang++ $CXXFLAGS chdbArrowTest.cpp -o chdbArrowTest \ +echo "Compile and link chdbArrowTest (C version)" +clang $CFLAGS chdbArrowTest.c -o chdbArrowTest \ -I../programs/local/ \ -I../contrib/arrow/cpp/src \ -I../contrib/arrow-cmake/cpp/src \ @@ -32,5 +31,5 @@ clang++ $CXXFLAGS chdbArrowTest.cpp -o chdbArrowTest \ export ${LIB_PATH}=.. 
 ${LDD} chdbArrowTest
 
-echo "Run Arrow API tests:"
+echo "Run Arrow API tests (C version):"
 ./chdbArrowTest
diff --git a/examples/runArrowTestC.sh b/examples/runArrowTestC.sh
deleted file mode 100755
index db94b824500..00000000000
--- a/examples/runArrowTestC.sh
+++ /dev/null
@@ -1,36 +0,0 @@
-#!/bin/bash
-
-set -e
-
-# CFLAGS="-g -O0 -DDEBUG"
-CFLAGS="-std=c99"
-
-# check current os type, and make ldd command
-if [ "$(uname)" == "Darwin" ]; then
-    LDD="otool -L"
-    LIB_PATH="DYLD_LIBRARY_PATH"
-elif [ "$(uname)" == "Linux" ]; then
-    LDD="ldd"
-    LIB_PATH="LD_LIBRARY_PATH"
-else
-    echo "OS not supported"
-    exit 1
-fi
-
-# cd to the directory of this script
-DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
-cd $DIR
-
-echo "Compile and link chdbArrowTest (C version)"
-clang $CFLAGS chdbArrowTest.c -o chdbArrowTestC \
-    -I../programs/local/ \
-    -I../contrib/arrow/cpp/src \
-    -I../contrib/arrow-cmake/cpp/src \
-    -I../src \
-    -L../ -lchdb
-
-export ${LIB_PATH}=..
-${LDD} chdbArrowTestC
-
-echo "Run Arrow API tests (C version):"
-./chdbArrowTestC
diff --git a/programs/local/chdb-arrow.cpp b/programs/local/chdb-arrow.cpp
index 5200c537a7e..e899e47e0c5 100644
--- a/programs/local/chdb-arrow.cpp
+++ b/programs/local/chdb-arrow.cpp
@@ -91,8 +91,6 @@ static chdb_state chdb_inner_arrow_scan(
     chdb_connection conn, const char * table_name, chdb_arrow_stream arrow_stream, bool is_owner)
 {
-    ChdbDestructorGuard guard;
-
     std::shared_lock global_lock(global_connection_mutex);
 
     if (!table_name || !arrow_stream)
@@ -139,6 +137,7 @@ chdb_state chdb_arrow_scan(
     chdb_connection conn, const char * table_name, chdb_arrow_stream arrow_stream)
 {
+    ChdbDestructorGuard guard;
     return chdb_inner_arrow_scan(conn, table_name, arrow_stream, false);
 }
 
@@ -146,6 +145,8 @@ chdb_state chdb_arrow_array_scan(
     chdb_connection conn, const char * table_name, chdb_arrow_schema arrow_schema, chdb_arrow_array arrow_array)
 {
+    ChdbDestructorGuard guard;
+
     auto * private_data = new CHDB::PrivateData();
     private_data->schema = reinterpret_cast<ArrowSchema *>(arrow_schema);
     private_data->array = reinterpret_cast<ArrowArray *>(arrow_array);

From 25e87d964c4d34088a7e91952131531f4703b980 Mon Sep 17 00:00:00 2001
From: wudidapaopao
Date: Sun, 28 Sep 2025 09:17:10 +0800
Subject: [PATCH 11/13] chore: remove test file

---
 examples/chdbArrowTest.cpp | 784 ------------------------------------
 1 file changed, 784 deletions(-)
 delete mode 100644 examples/chdbArrowTest.cpp

diff --git a/examples/chdbArrowTest.cpp b/examples/chdbArrowTest.cpp
deleted file mode 100644
index 6cc1884beec..00000000000
--- a/examples/chdbArrowTest.cpp
+++ /dev/null
@@ -1,784 +0,0 @@
-#include 
-#include 
-#include 
-#include 
-#include 
-
-#include "../programs/local/chdb.h"
-#include "../contrib/arrow/cpp/src/arrow/c/abi.h"
-
-// Custom ArrowArrayStream implementation data
-struct CustomStreamData
-{
-    bool schema_sent;
-    size_t current_row;
-    size_t total_rows;
-    size_t batch_size;
-    std::string last_error;
-
-    CustomStreamData() : schema_sent(false), current_row(0), total_rows(1000000), batch_size(10000) {}
-
-    // Reset the stream to allow reading from the beginning
-    void reset()
-    {
-        current_row = 0;
-        last_error.clear();
-    }
-};
-
-// Helper function to create schema with 2 columns: id(int64), value(string)
-static void create_schema(struct ArrowSchema * schema) {
-    schema->format = "+s"; // struct format
-    schema->name = nullptr;
-    schema->metadata = nullptr;
-    schema->flags = 0;
-    schema->n_children = 2;
-    schema->children = static_cast<struct ArrowSchema **>(malloc(2 * sizeof(struct ArrowSchema *)));
-    schema->dictionary = nullptr;
-    schema->release = [](struct ArrowSchema * s)
-    {
-        if (s->children) {
-            for (int64_t i = 0; i < s->n_children; i++) {
-                if (s->children[i] && s->children[i]->release) {
-                    s->children[i]->release(s->children[i]);
-                }
-                free(s->children[i]);
-            }
-            free(s->children);
-        }
-        s->release = nullptr;
-    };
-
-    // Field 0: id (int64)
-    schema->children[0] = static_cast<struct ArrowSchema *>(malloc(sizeof(struct ArrowSchema)));
-    schema->children[0]->format = "l"; // int64
-    schema->children[0]->name = "id";
-    schema->children[0]->metadata = nullptr;
-    schema->children[0]->flags = 0;
-    schema->children[0]->n_children = 0;
-    schema->children[0]->children = nullptr;
-    schema->children[0]->dictionary = nullptr;
-    schema->children[0]->release = [](struct ArrowSchema* s) { s->release = nullptr; };
-
-    // Field 1: value (string)
-    schema->children[1] = static_cast<struct ArrowSchema *>(malloc(sizeof(struct ArrowSchema)));
-    schema->children[1]->format = "u"; // utf8 string
-    schema->children[1]->name = "value";
-    schema->children[1]->metadata = nullptr;
-    schema->children[1]->flags = 0;
-    schema->children[1]->n_children = 0;
-    schema->children[1]->children = nullptr;
-    schema->children[1]->dictionary = nullptr;
-    schema->children[1]->release = [](struct ArrowSchema* s) { s->release = nullptr; };
-}
-
-// Helper function to create a batch of data
-static void create_batch(struct ArrowArray* array, size_t start_row, size_t batch_size)
-{
-    // Main array structure
-    array->length = batch_size;
-    array->null_count = 0;
-    array->offset = 0;
-    array->n_buffers = 1;
-    array->n_children = 2;
-    array->buffers = static_cast<const void**>(malloc(1 * sizeof(void*)));
-    array->buffers[0] = nullptr; // validity buffer (no nulls)
-    array->children = static_cast<struct ArrowArray**>(malloc(2 * sizeof(struct ArrowArray*)));
-    array->dictionary = nullptr;
-
-    // Create id column (int64)
-    array->children[0] = static_cast<struct ArrowArray*>(malloc(sizeof(struct ArrowArray)));
-    struct ArrowArray* id_array = array->children[0];
-    id_array->length = batch_size;
-    id_array->null_count = 0;
-    id_array->offset = 0;
-    id_array->n_buffers = 2;
-    id_array->n_children = 0;
-    id_array->buffers = static_cast<const void**>(malloc(2 * sizeof(void*)));
-    id_array->buffers[0] = nullptr; // validity buffer
-
-    // Allocate and fill id data
-    int64_t* id_data = static_cast<int64_t*>(malloc(batch_size * sizeof(int64_t)));
-    for (size_t i = 0; i < batch_size; i++)
-        id_data[i] = start_row + i;
-
-    id_array->buffers[1] = id_data; // data buffer
-    id_array->children = nullptr;
-    id_array->dictionary = nullptr;
-    id_array->release = [](struct ArrowArray* arr)
-    {
-        if (arr->buffers) {
-            free(const_cast<void*>(arr->buffers[1])); // free data buffer
-            free(const_cast<void**>(arr->buffers));
-        }
-        arr->release = nullptr;
-    };
-
-    // Create value column (string)
-    array->children[1] = static_cast<struct ArrowArray*>(malloc(sizeof(struct ArrowArray)));
-    struct ArrowArray* str_array = array->children[1];
-    str_array->length = batch_size;
-    str_array->null_count = 0;
-    str_array->offset = 0;
-    str_array->n_buffers = 3;
-    str_array->n_children = 0;
-    str_array->buffers = static_cast<const void**>(malloc(3 * sizeof(void*)));
-    str_array->buffers[0] = nullptr; // validity buffer
-
-    // Create offset buffer (int32)
-    int32_t* offsets = static_cast<int32_t*>(malloc((batch_size + 1) * sizeof(int32_t)));
-    offsets[0] = 0;
-
-    // Calculate total string length and create strings
-    size_t total_str_len = 0;
-    std::vector<std::string> strings;
-    for (size_t i = 0; i < batch_size; i++)
-    {
-        std::string str = "value_" + std::to_string(start_row + i);
-        strings.push_back(str);
-        total_str_len += str.length();
-        offsets[i + 1] = total_str_len;
-    }
-    str_array->buffers[1] = offsets; // offset buffer
-
-    // Create data buffer
-    char* str_data = static_cast<char*>(malloc(total_str_len));
-    size_t pos = 0;
-    for (const auto& str : strings)
-    {
-        memcpy(str_data + pos, str.c_str(), str.length());
-        pos += str.length();
-    }
-    str_array->buffers[2] = str_data; // data buffer
-
-    str_array->children = nullptr;
-    str_array->dictionary = nullptr;
-    str_array->release = [](struct ArrowArray* arr)
-    {
-        if (arr->buffers) {
-            free(const_cast<void*>(arr->buffers[1])); // free offset buffer
-            free(const_cast<void*>(arr->buffers[2])); // free data buffer
-            free(const_cast<void**>(arr->buffers));
-        }
-        arr->release = nullptr;
-    };
-
-    // Main array release function
-    array->release = [](struct ArrowArray* arr) {
-        if (arr->children) {
-            for (int64_t i = 0; i < arr->n_children; i++) {
-                if (arr->children[i] && arr->children[i]->release) {
-                    arr->children[i]->release(arr->children[i]);
-                }
-                free(arr->children[i]);
-            }
-            free(arr->children);
-        }
-        if (arr->buffers) {
-            free(const_cast<void**>(arr->buffers));
-        }
-        arr->release = nullptr;
-    };
-}
-
-// Callback function to get schema
-static int custom_get_schema(struct ArrowArrayStream * /* stream */, struct ArrowSchema * out)
-{
-    create_schema(out);
-    return 0;
-}
-
-// Callback function to get next array
-static int custom_get_next(struct ArrowArrayStream * stream, struct ArrowArray * out)
-{
-    auto* data = static_cast<CustomStreamData*>(stream->private_data);
-    if (!data)
-        return EINVAL;
-
-    // Check if we've reached the end of the stream
-    if (data->current_row >= data->total_rows)
-    {
-        // End of stream - set release to nullptr to indicate no more data
-        out->release = nullptr;
-        return 0;
-    }
-
-    // Calculate batch size for this iteration
-    size_t remaining_rows = data->total_rows - data->current_row;
-    size_t batch_size = std::min(data->batch_size, remaining_rows);
-
-    // Create the batch
-    create_batch(out, data->current_row, batch_size);
-
-    data->current_row += batch_size;
-    return 0;
-}
-
-// Callback function to get last error
-static const char* custom_get_last_error(struct ArrowArrayStream* stream) {
-    auto* data = static_cast<CustomStreamData*>(stream->private_data);
-    if (!data || data->last_error.empty())
-        return nullptr;
-
-    return data->last_error.c_str();
-}
-
-// Callback function to release stream resources
-static void custom_release(struct ArrowArrayStream* stream) {
-    if (stream->private_data)
-    {
-        delete static_cast<CustomStreamData*>(stream->private_data);
-        stream->private_data = nullptr;
-    }
-    stream->release = nullptr;
-}
-
-// Helper function to reset the ArrowArrayStream for reuse
-static void reset_arrow_stream(struct ArrowArrayStream* stream)
-{
-    if (stream && stream->private_data)
-    {
-        auto* data = static_cast<CustomStreamData*>(stream->private_data);
-        data->reset();
-        std::cout << "✓ ArrowArrayStream has been reset, ready for re-reading\n";
-    }
-}
-
-//===--------------------------------------------------------------------===//
-// Unit Test Utilities
-//===--------------------------------------------------------------------===//
-
-static void test_assert(bool condition, const std::string& test_name, const std::string& message = "")
-{
-    if (condition)
-    {
-        std::cout << "✓ PASS: " << test_name << std::endl;
-    }
-    else
-    {
-        std::cout << "✗ FAIL: " << test_name;
-        if (!message.empty())
-        {
-            std::cout << " - " << message;
-        }
-        std::cout << std::endl;
-        exit(1);
-    }
-}
-
-static void test_assert_chdb_state(chdb_state state, const std::string& operation_name)
-{
-    test_assert(state == CHDBSuccess,
-                "chDB operation: " + operation_name,
CHDBError ? "Operation failed" : "Unknown state"); -} - -static void test_assert_not_null(void* ptr, const std::string& test_name) -{ - test_assert(ptr != nullptr, test_name, "Pointer is null"); -} - -static void test_assert_no_error(chdb_result* result, const std::string& query_name) -{ - test_assert_not_null(result, query_name + " - Result is not null"); - - const char * error = chdb_result_error(result); - test_assert(error == nullptr, - query_name + " - No query error", - error ? std::string("Error: ") + error : ""); -} - -static void test_assert_query_result_contains(chdb_result* result, const std::string& expected_content, const std::string& query_name) -{ - test_assert_no_error(result, query_name); - - char * buffer = chdb_result_buffer(result); - test_assert_not_null(buffer, query_name + " - Result buffer is not null"); - - std::string result_str(buffer); - test_assert(result_str.find(expected_content) != std::string::npos, - query_name + " - Result contains expected content", - "Expected: " + expected_content + ", Actual: " + result_str); -} - -static void test_assert_row_count(chdb_result* result, uint64_t expected_rows, const std::string& query_name) -{ - test_assert_no_error(result, query_name); - - char* buffer = chdb_result_buffer(result); - test_assert_not_null(buffer, query_name + " - Result buffer is not null"); - - // Parse the count result (assuming CSV format with just the number) - std::string result_str(buffer); - // Remove trailing whitespace/newlines - result_str.erase(result_str.find_last_not_of(" \t\n\r\f\v") + 1); - - uint64_t actual_rows = std::stoull(result_str); - test_assert(actual_rows == expected_rows, - query_name + " - Row count matches", - "Expected: " + std::to_string(expected_rows) + ", Actual: " + std::to_string(actual_rows)); -} - -void test_arrow_scan(chdb_connection conn) -{ - std::cout << "\n=== Creating Custom ArrowArrayStream ===\n"; - std::cout << "Data specification: 1,000,000 rows × 2 columns (id: int64, value: string)\n"; - - struct ArrowArrayStream stream; - memset(&stream, 0, sizeof(stream)); - - // Create and initialize stream data - auto * stream_data = new CustomStreamData(); - - // Set up the ArrowArrayStream callbacks - stream.get_schema = custom_get_schema; - stream.get_next = custom_get_next; - stream.get_last_error = custom_get_last_error; - stream.release = custom_release; - stream.private_data = stream_data; - - std::cout << "✓ ArrowArrayStream initialization completed\n"; - std::cout << "Starting registration with chDB...\n"; - - const char * table_name = "test_arrow_table"; - const char * non_exist_table_name = "non_exist_table"; - - chdb_arrow_stream arrow_stream = reinterpret_cast(&stream); - chdb_state result = chdb_arrow_scan(conn, table_name, arrow_stream); - - // Test 1: Verify arrow registration succeeded - test_assert_chdb_state(result, "Register ArrowArrayStream to table: " + std::string(table_name)); - - // Test 2: Unregister non-existent table should handle gracefully - result = chdb_arrow_unregister_table(conn, non_exist_table_name); - test_assert_chdb_state(result, "Unregister non-existent table: " + std::string(non_exist_table_name)); - - // Test 3: Count rows - should be exactly 1,000,000 - chdb_result * count_result = chdb_query(conn, "SELECT COUNT(*) as total_rows FROM arrowstream(test_arrow_table)", "CSV"); - test_assert_row_count(count_result, 1000000, "Count total rows"); - chdb_destroy_query_result(count_result); - - // Test 4: Sample first 5 rows - should contain id=0,1,2,3,4 - reset_arrow_stream(&stream); 
-    chdb_result * sample_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_table) LIMIT 5", "CSV");
-    test_assert_query_result_contains(sample_result, "0,\"value_0\"", "First 5 rows contain first row");
-    test_assert_query_result_contains(sample_result, "4,\"value_4\"", "First 5 rows contain fifth row");
-    chdb_destroy_query_result(sample_result);
-
-    // Test 5: Sample last 5 rows - should contain id=999999,999998,999997,999996,999995
-    reset_arrow_stream(&stream);
-    chdb_result * last_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_table) ORDER BY id DESC LIMIT 5", "CSV");
-    test_assert_query_result_contains(last_result, "999999,\"value_999999\"", "Last 5 rows contain last row");
-    test_assert_query_result_contains(last_result, "999995,\"value_999995\"", "Last 5 rows contain fifth row");
-    chdb_destroy_query_result(last_result);
-
-    // Test 6: Multiple table registration tests
-    // Create second ArrowArrayStream with different data (500,000 rows)
-    struct ArrowArrayStream stream2;
-    memset(&stream2, 0, sizeof(stream2));
-    auto * stream_data2 = new CustomStreamData();
-    stream_data2->total_rows = 500000; // Different row count
-    stream_data2->current_row = 0;
-    stream2.get_schema = custom_get_schema;
-    stream2.get_next = custom_get_next;
-    stream2.get_last_error = custom_get_last_error;
-    stream2.release = custom_release;
-    stream2.private_data = stream_data2;
-
-    // Create third ArrowArrayStream with different data (100,000 rows)
-    struct ArrowArrayStream stream3;
-    memset(&stream3, 0, sizeof(stream3));
-    auto * stream_data3 = new CustomStreamData();
-    stream_data3->total_rows = 100000; // Different row count
-    stream_data3->current_row = 0;
-    stream3.get_schema = custom_get_schema;
-    stream3.get_next = custom_get_next;
-    stream3.get_last_error = custom_get_last_error;
-    stream3.release = custom_release;
-    stream3.private_data = stream_data3;
-
-    const char * table_name2 = "test_arrow_table_2";
-    const char * table_name3 = "test_arrow_table_3";
-
-    // Register second table
-    chdb_arrow_stream arrow_stream2 = reinterpret_cast<chdb_arrow_stream>(&stream2);
-    result = chdb_arrow_scan(conn, table_name2, arrow_stream2);
-    test_assert_chdb_state(result, "Register second ArrowArrayStream to table: " + std::string(table_name2));
-
-    // Register third table
-    chdb_arrow_stream arrow_stream3 = reinterpret_cast<chdb_arrow_stream>(&stream3);
-    result = chdb_arrow_scan(conn, table_name3, arrow_stream3);
-    test_assert_chdb_state(result, "Register third ArrowArrayStream to table: " + std::string(table_name3));
-
-    // Test 6a: Verify each table has correct row counts
-    reset_arrow_stream(&stream);
-    chdb_result * count1_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_table)", "CSV");
-    test_assert_row_count(count1_result, 1000000, "First table row count");
-    chdb_destroy_query_result(count1_result);
-
-    reset_arrow_stream(&stream2);
-    chdb_result * count2_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_table_2)", "CSV");
-    test_assert_row_count(count2_result, 500000, "Second table row count");
-    chdb_destroy_query_result(count2_result);
-
-    reset_arrow_stream(&stream3);
-    chdb_result * count3_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_table_3)", "CSV");
-    test_assert_row_count(count3_result, 100000, "Third table row count");
-    chdb_destroy_query_result(count3_result);
-
-    // Test 6b: Test cross-table JOIN query
-    reset_arrow_stream(&stream);
-    reset_arrow_stream(&stream2);
-    chdb_result * join_result = chdb_query(conn,
-        "SELECT t1.id, t1.value, t2.value as value2 "
-        "FROM arrowstream(test_arrow_table) t1 "
-        "INNER JOIN arrowstream(test_arrow_table_2) t2 ON t1.id = t2.id "
-        "WHERE t1.id < 5 ORDER BY t1.id", "CSV");
-    test_assert_query_result_contains(join_result, R"(0,"value_0","value_0")", "JOIN query contains expected data");
-    test_assert_query_result_contains(join_result, R"(4,"value_4","value_4")", "JOIN query contains fifth row");
-    chdb_destroy_query_result(join_result);
-
-    // Test 6c: Test UNION query across multiple tables
-    reset_arrow_stream(&stream2);
-    reset_arrow_stream(&stream3);
-    chdb_result * union_result = chdb_query(conn,
-        "SELECT COUNT(*) FROM ("
-        "SELECT id FROM arrowstream(test_arrow_table_2) WHERE id < 10 "
-        "UNION ALL "
-        "SELECT id FROM arrowstream(test_arrow_table_3) WHERE id < 10"
-        ")", "CSV");
-    test_assert_row_count(union_result, 20, "UNION query row count");
-    chdb_destroy_query_result(union_result);
-
-    // Cleanup additional tables
-    result = chdb_arrow_unregister_table(conn, table_name2);
-    test_assert_chdb_state(result, "Unregister second ArrowArrayStream table");
-
-    result = chdb_arrow_unregister_table(conn, table_name3);
-    test_assert_chdb_state(result, "Unregister third ArrowArrayStream table");
-
-    // Test 7: Unregister original table should succeed
-    result = chdb_arrow_unregister_table(conn, table_name);
-    test_assert_chdb_state(result, "Unregister ArrowArrayStream table: " + std::string(table_name));
-
-    // Test 8: Sample last 5 rows after unregister should fail
-    reset_arrow_stream(&stream);
-    chdb_result * unregister_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_table) ORDER BY id DESC LIMIT 5", "CSV");
-    const char * error = chdb_result_error(unregister_result);
-    test_assert(error != nullptr,
-                "Query after unregister should fail",
-                error ? std::string("Got expected error: ") + error : "No error returned when error was expected");
-    chdb_destroy_query_result(unregister_result);
-}
-
-// Helper function to create ArrowArray with specified row count
-static void create_arrow_array(struct ArrowArray * array, uint64_t row_count)
-{
-    array->length = row_count;
-    array->null_count = 0;
-    array->offset = 0;
-    array->n_buffers = 1;
-    array->n_children = 2;
-    array->buffers = static_cast<const void **>(malloc(1 * sizeof(void *)));
-    array->buffers[0] = nullptr; // validity buffer
-
-    array->children = static_cast<struct ArrowArray **>(malloc(2 * sizeof(struct ArrowArray *)));
-    array->dictionary = nullptr;
-
-    // Create id column (int64)
-    array->children[0] = static_cast<struct ArrowArray *>(malloc(sizeof(struct ArrowArray)));
-    struct ArrowArray * id_array = array->children[0];
-    id_array->length = row_count;
-    id_array->null_count = 0;
-    id_array->offset = 0;
-    id_array->n_buffers = 2;
-    id_array->n_children = 0;
-    id_array->children = nullptr;
-    id_array->dictionary = nullptr;
-
-    id_array->buffers = static_cast<const void **>(malloc(2 * sizeof(void *)));
-    id_array->buffers[0] = nullptr; // validity buffer
-
-    // Allocate and populate id data
-    int64_t * id_data = static_cast<int64_t *>(malloc(row_count * sizeof(int64_t)));
-    for (uint64_t i = 0; i < row_count; i++)
-    {
-        id_data[i] = static_cast<int64_t>(i);
-    }
-    id_array->buffers[1] = id_data;
-
-    id_array->release = [](struct ArrowArray * a)
-    {
-        if (a->buffers)
-        {
-            free(const_cast<void *>(a->buffers[1])); // id data
-            free(const_cast<void **>(a->buffers));
-        }
-        free(a);
-    };
-
-    // Create value column (string)
-    array->children[1] = static_cast<struct ArrowArray *>(malloc(sizeof(struct ArrowArray)));
-    struct ArrowArray * value_array = array->children[1];
-    value_array->length = row_count;
-    value_array->null_count = 0;
-    value_array->offset = 0;
-    value_array->n_buffers = 3;
-    value_array->n_children = 0;
-    value_array->children = nullptr;
-    value_array->dictionary = nullptr;
-
-    value_array->buffers = static_cast<const void **>(malloc(3 * sizeof(void *)));
-    value_array->buffers[0] = nullptr; // validity buffer
-
-    // Calculate total string data size and create offset array
-    int32_t * offsets = static_cast<int32_t *>(malloc((row_count + 1) * sizeof(int32_t)));
-    size_t total_string_size = 0;
-    offsets[0] = 0;
-
-    for (uint64_t i = 0; i < row_count; i++)
-    {
-        std::string value_str = "value_" + std::to_string(i);
-        total_string_size += value_str.length();
-        offsets[i + 1] = static_cast<int32_t>(total_string_size);
-    }
-
-    value_array->buffers[1] = offsets;
-
-    // Allocate and populate string data
-    char * string_data = static_cast<char *>(malloc(total_string_size));
-    size_t current_pos = 0;
-    for (uint64_t i = 0; i < row_count; i++) {
-        std::string value_str = "value_" + std::to_string(i);
-        memcpy(string_data + current_pos, value_str.c_str(), value_str.length());
-        current_pos += value_str.length();
-    }
-    value_array->buffers[2] = string_data;
-
-    value_array->release = [](struct ArrowArray * a) {
-        if (a->buffers) {
-            free(const_cast<void *>(a->buffers[1])); // offsets
-            free(const_cast<void *>(a->buffers[2])); // string data
-            free(const_cast<void **>(a->buffers));
-        }
-        free(a);
-    };
-
-    // Set release callback for main array
-    array->release = [](struct ArrowArray * a)
-    {
-        if (a->children)
-        {
-            for (int64_t i = 0; i < a->n_children; i++)
-            {
-                if (a->children[i] && a->children[i]->release)
-                {
-                    a->children[i]->release(a->children[i]);
-                }
-            }
-            free(a->children);
-        }
-        if (a->buffers) {
-            free(const_cast<void **>(a->buffers));
-        }
-    };
-}
-
-void test_arrow_array_scan(chdb_connection conn)
-{
-    std::cout << "\n=== Testing ArrowArray Scan Functions ===\n";
std::cout << "Data specification: 1,000,000 rows × 2 columns (id: int64, value: string)\n"; - - // Create ArrowSchema (reuse existing function) - struct ArrowSchema schema; - create_schema(&schema); - - // Create ArrowArray with 1,000,000 rows - struct ArrowArray array; - memset(&array, 0, sizeof(array)); - create_arrow_array(&array, 1000000); - - std::cout << "✓ ArrowArray initialization completed\n"; - std::cout << "Starting registration with chDB...\n"; - - const char * table_name = "test_arrow_array_table"; - const char * non_exist_table_name = "non_exist_array_table"; - - chdb_arrow_schema arrow_schema = reinterpret_cast(&schema); - chdb_arrow_array arrow_array = reinterpret_cast(&array); - - // Test 1: Register -> Query -> Unregister for row count - chdb_state result = chdb_arrow_array_scan(conn, table_name, arrow_schema, arrow_array); - test_assert_chdb_state(result, "Register ArrowArray to table: " + std::string(table_name)); - - chdb_result * count_result = chdb_query(conn, "SELECT COUNT(*) as total_rows FROM arrowstream(test_arrow_array_table)", "CSV"); - test_assert_row_count(count_result, 1000000, "Count total rows"); - chdb_destroy_query_result(count_result); - - result = chdb_arrow_unregister_table(conn, table_name); - test_assert_chdb_state(result, "Unregister ArrowArray table after count query"); - - // Test 2: Unregister non-existent table should handle gracefully - result = chdb_arrow_unregister_table(conn, non_exist_table_name); - test_assert_chdb_state(result, "Unregister non-existent array table: " + std::string(non_exist_table_name)); - - // Test 3: Register -> Query -> Unregister for first 5 rows - result = chdb_arrow_array_scan(conn, table_name, arrow_schema, arrow_array); - test_assert_chdb_state(result, "Register ArrowArray for sample query"); - - chdb_result * sample_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_array_table) LIMIT 5", "CSV"); - test_assert_query_result_contains(sample_result, "0,\"value_0\"", "First 5 rows contain first row"); - test_assert_query_result_contains(sample_result, "4,\"value_4\"", "First 5 rows contain fifth row"); - chdb_destroy_query_result(sample_result); - - result = chdb_arrow_unregister_table(conn, table_name); - test_assert_chdb_state(result, "Unregister ArrowArray table after sample query"); - - // Test 4: Register -> Query -> Unregister for last 5 rows - result = chdb_arrow_array_scan(conn, table_name, arrow_schema, arrow_array); - test_assert_chdb_state(result, "Register ArrowArray for last rows query"); - - chdb_result * last_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_array_table) ORDER BY id DESC LIMIT 5", "CSV"); - test_assert_query_result_contains(last_result, "999999,\"value_999999\"", "Last 5 rows contain last row"); - test_assert_query_result_contains(last_result, "999995,\"value_999995\"", "Last 5 rows contain fifth row"); - chdb_destroy_query_result(last_result); - - result = chdb_arrow_unregister_table(conn, table_name); - test_assert_chdb_state(result, "Unregister ArrowArray table after last rows query"); - - // Test 5: Independent multiple table tests - // Create second ArrowArray with different data (500,000 rows) - struct ArrowSchema schema2; - create_schema(&schema2); - struct ArrowArray array2; - memset(&array2, 0, sizeof(array2)); - create_arrow_array(&array2, 500000); - - // Create third ArrowArray with different data (100,000 rows) - struct ArrowSchema schema3; - create_schema(&schema3); - struct ArrowArray array3; - memset(&array3, 0, sizeof(array3)); - 
-    create_arrow_array(&array3, 100000);
-
-    const char * table_name2 = "test_arrow_array_table_2";
-    const char * table_name3 = "test_arrow_array_table_3";
-
-    chdb_arrow_schema arrow_schema2 = reinterpret_cast<chdb_arrow_schema>(&schema2);
-    chdb_arrow_array arrow_array2 = reinterpret_cast<chdb_arrow_array>(&array2);
-    chdb_arrow_schema arrow_schema3 = reinterpret_cast<chdb_arrow_schema>(&schema3);
-    chdb_arrow_array arrow_array3 = reinterpret_cast<chdb_arrow_array>(&array3);
-
-    // Test 5a: Register -> Query -> Unregister for second table (500K rows)
-    result = chdb_arrow_array_scan(conn, table_name2, arrow_schema2, arrow_array2);
-    test_assert_chdb_state(result, "Register second ArrowArray to table: " + std::string(table_name2));
-
-    chdb_result * count2_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_array_table_2)", "CSV");
-    test_assert_row_count(count2_result, 500000, "Second array table row count");
-    chdb_destroy_query_result(count2_result);
-
-    result = chdb_arrow_unregister_table(conn, table_name2);
-    test_assert_chdb_state(result, "Unregister second ArrowArray table");
-
-    // Test 5b: Register -> Query -> Unregister for third table (100K rows)
-    result = chdb_arrow_array_scan(conn, table_name3, arrow_schema3, arrow_array3);
-    test_assert_chdb_state(result, "Register third ArrowArray to table: " + std::string(table_name3));
-
-    chdb_result * count3_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_array_table_3)", "CSV");
-    test_assert_row_count(count3_result, 100000, "Third array table row count");
-    chdb_destroy_query_result(count3_result);
-
-    result = chdb_arrow_unregister_table(conn, table_name3);
-    test_assert_chdb_state(result, "Unregister third ArrowArray table");
-
-    // Test 6: Cross-table JOIN query (Register both -> Query -> Unregister both)
-    result = chdb_arrow_array_scan(conn, table_name, arrow_schema, arrow_array);
-    test_assert_chdb_state(result, "Register first ArrowArray for JOIN");
-
-    result = chdb_arrow_array_scan(conn, table_name2, arrow_schema2, arrow_array2);
-    test_assert_chdb_state(result, "Register second ArrowArray for JOIN");
-
-    chdb_result * join_result = chdb_query(conn,
-        "SELECT t1.id, t1.value, t2.value as value2 "
-        "FROM arrowstream(test_arrow_array_table) t1 "
-        "INNER JOIN arrowstream(test_arrow_array_table_2) t2 ON t1.id = t2.id "
-        "WHERE t1.id < 5 ORDER BY t1.id", "CSV");
-    test_assert_query_result_contains(join_result, R"(0,"value_0","value_0")", "Array JOIN query contains expected data");
-    test_assert_query_result_contains(join_result, R"(4,"value_4","value_4")", "Array JOIN query contains fifth row");
-    chdb_destroy_query_result(join_result);
-
-    result = chdb_arrow_unregister_table(conn, table_name);
-    test_assert_chdb_state(result, "Unregister first ArrowArray after JOIN");
-
-    result = chdb_arrow_unregister_table(conn, table_name2);
-    test_assert_chdb_state(result, "Unregister second ArrowArray after JOIN");
-
-    // Test 7: Cross-table UNION query (Register both -> Query -> Unregister both)
-    result = chdb_arrow_array_scan(conn, table_name2, arrow_schema2, arrow_array2);
-    test_assert_chdb_state(result, "Register second ArrowArray for UNION");
-
-    result = chdb_arrow_array_scan(conn, table_name3, arrow_schema3, arrow_array3);
-    test_assert_chdb_state(result, "Register third ArrowArray for UNION");
-
-    chdb_result * union_result = chdb_query(conn,
-        "SELECT COUNT(*) FROM ("
-        "SELECT id FROM arrowstream(test_arrow_array_table_2) WHERE id < 10 "
-        "UNION ALL "
-        "SELECT id FROM arrowstream(test_arrow_array_table_3) WHERE id < 10"
-        ")", "CSV");
-    test_assert_row_count(union_result, 20, "Array UNION query row count");
-    chdb_destroy_query_result(union_result);
-
-    result = chdb_arrow_unregister_table(conn, table_name2);
-    test_assert_chdb_state(result, "Unregister second ArrowArray after UNION");
-
-    result = chdb_arrow_unregister_table(conn, table_name3);
-    test_assert_chdb_state(result, "Unregister third ArrowArray after UNION");
-
-    // Test 8: Query after unregister should fail
-    chdb_result * unregister_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_array_table) ORDER BY id DESC LIMIT 5", "CSV");
-    const char * error = chdb_result_error(unregister_result);
-    test_assert(error != nullptr,
-                "Array query after unregister should fail",
-                error ? std::string("Got expected error: ") + error : "No error returned when error was expected");
-    chdb_destroy_query_result(unregister_result);
-
-    // Cleanup ArrowArrays and schemas
-    if (array.release) array.release(&array);
-    if (schema.release) schema.release(&schema);
-    if (array2.release) array2.release(&array2);
-    if (schema2.release) schema2.release(&schema2);
-    if (array3.release) array3.release(&array3);
-    if (schema3.release) schema3.release(&schema3);
-}
-
-int main()
-{
-    const char *argv[] = {"clickhouse", "--multiquery"};
-    int argc = sizeof(argv) / sizeof(argv[0]);
-    chdb_connection * conn_ptr;
-    chdb_connection conn;
-
-    std::cout << "=== chDB Arrow Functions Test ===\n";
-
-    // Create connection
-    conn_ptr = chdb_connect(argc, const_cast<char **>(argv));
-    if (!conn_ptr || !*conn_ptr) {
-        std::cout << "Failed to create chDB connection\n";
-        return 1;
-    }
-
-    conn = *conn_ptr;
-    std::cout << "✓ chDB connection established\n";
-
-    // Run test suites
-    test_arrow_scan(conn);
-    test_arrow_array_scan(conn);
-
-    // Clean up
-    chdb_close_conn(conn_ptr);
-
-    std::cout << "\n=== chDB Arrow Functions Test Completed ===\n";
-
-    return 0;
-}

From 69149cc4560ea8b274d52788c0bfe45199758400 Mon Sep 17 00:00:00 2001
From: wudidapaopao
Date: Tue, 30 Sep 2025 02:14:08 +0800
Subject: [PATCH 12/13] test: add test_unsupported_arrow_types.py

---
 src/Core/FormatFactorySettings.h           |   2 +-
 .../Formats/Impl/ArrowColumnToCHColumn.cpp |  10 +-
 tests/test_query_py.py                     |  11 ++
 tests/test_unsupported_arrow_types.py      | 150 ++++++++++++++++++
 4 files changed, 165 insertions(+), 8 deletions(-)
 create mode 100644 tests/test_unsupported_arrow_types.py

diff --git a/src/Core/FormatFactorySettings.h b/src/Core/FormatFactorySettings.h
index 02ae9763348..0fe39a82028 100644
--- a/src/Core/FormatFactorySettings.h
+++ b/src/Core/FormatFactorySettings.h
@@ -301,7 +301,7 @@ Skip columns with unsupported types while schema inference for format CapnProto
     DECLARE(Bool, input_format_orc_skip_columns_with_unsupported_types_in_schema_inference, false, R"(
 Skip columns with unsupported types while schema inference for format ORC
 )", 0) \
-    DECLARE(Bool, input_format_arrow_skip_columns_with_unsupported_types_in_schema_inference, true, R"(
+    DECLARE(Bool, input_format_arrow_skip_columns_with_unsupported_types_in_schema_inference, false, R"(
 Skip columns with unsupported types while schema inference for format Arrow
 )", 0) \
     DECLARE(String, column_names_for_schema_inference, "", R"(
diff --git a/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp b/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp
index 273122ec09f..3a501a4d29a 100644
--- a/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp
+++ b/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp
@@ -1220,13 +1220,9 @@ static ColumnWithTypeAndName readNonNullableColumnFromArrowColumn(
         // TODO: read UUID as a string?
         case arrow::Type::NA:
         {
-            if (settings.allow_arrow_null_type)
-            {
-                auto type = std::make_shared<DataTypeNothing>();
-                auto column = ColumnNothing::create(arrow_column->length());
-                return {std::move(column), type, column_name};
-            }
-            [[fallthrough]];
+            auto type = std::make_shared<DataTypeNothing>();
+            auto column = ColumnNothing::create(arrow_column->length());
+            return {std::move(column), type, column_name};
         }
         default:
         {
diff --git a/tests/test_query_py.py b/tests/test_query_py.py
index b33097dfede..a35dccf8d47 100644
--- a/tests/test_query_py.py
+++ b/tests/test_query_py.py
@@ -254,6 +254,17 @@ def test_query_arrow5(self):
             },
         )
 
+    def test_query_arrow_null_type(self):
+        null_array = pa.array([None, None, None])
+        table = pa.table([null_array], names=["null_col"])
+        ret = chdb.query("SELECT * FROM Python(table)")
+        self.assertEqual(str(ret), "\\N\n\\N\n\\N\n")
+
+        null_array = pa.array([None, 1, None])
+        table = pa.table([null_array], names=["null_col"])
+        ret = chdb.query("SELECT * FROM Python(table)")
+        self.assertEqual(str(ret), "\\N\n1\n\\N\n")
+
     def test_random_float(self):
         x = {"col1": [random.uniform(0, 1) for _ in range(0, 100000)]}
         ret = chdb.sql(
diff --git a/tests/test_unsupported_arrow_types.py b/tests/test_unsupported_arrow_types.py
new file mode 100644
index 00000000000..88640876e1f
--- /dev/null
+++ b/tests/test_unsupported_arrow_types.py
@@ -0,0 +1,150 @@
+#!/usr/bin/env python3
+
+import unittest
+import pyarrow as pa
+import pyarrow.compute as pc
+import chdb
+from chdb import ChdbError
+
+
+class TestUnsupportedArrowTypes(unittest.TestCase):
+    """Test that chDB properly handles unsupported Arrow types"""
+
+    def setUp(self):
+        """Set up test data"""
+        self.sample_data = [1, 2, 3, 4, 5]
+        self.sample_strings = ["a", "b", "c", "d", "e"]
+
+    def test_sparse_union_type(self):
+        """Test SPARSE_UNION type - should fail"""
+        # Create a sparse union type
+        children = [
+            pa.array([1, None, 3, None, 5]),
+            pa.array([None, "b", None, "d", None])
+        ]
+        type_ids = pa.array([0, 1, 0, 1, 0], type=pa.int8())
+
+        union_array = pa.UnionArray.from_sparse(type_ids, children)
+        table = pa.table([union_array], names=["sparse_union_col"])
+
+        with self.assertRaises(Exception) as context:
+            chdb.query("SELECT * FROM Python(table)")
+
+        self.assertIn("Unsupported", str(context.exception))
+
+    def test_dense_union_type(self):
+        """Test DENSE_UNION type - should fail"""
+        # Create a dense union type
+        children = [
+            pa.array([1, 3, 5]),
+            pa.array(["b", "d"])
+        ]
+        type_ids = pa.array([0, 1, 0, 1, 0], type=pa.int8())
+        offsets = pa.array([0, 0, 1, 1, 2], type=pa.int32())
+
+        union_array = pa.UnionArray.from_dense(type_ids, offsets, children)
+        table = pa.table([union_array], names=["dense_union_col"])
+
+        with self.assertRaises(Exception) as context:
+            chdb.query("SELECT * FROM Python(table)")
+
+        self.assertIn("Unsupported", str(context.exception))
+
+    def test_interval_month_day_type(self):
+        """Test INTERVAL_MONTH_DAY type - should fail"""
+        pass
+
+    def test_interval_day_time_type(self):
+        """Test INTERVAL_DAY_TIME type - should fail"""
+        pass
+
+    def test_interval_month_day_nano_type(self):
+        """Test INTERVAL_MONTH_DAY_NANO type - should fail"""
+        start_timestamps = pc.strptime(
+            pa.array(["2021-01-01 00:00:00", "2022-01-01 00:00:00", "2023-01-01 00:00:00"]),
+            format="%Y-%m-%d %H:%M:%S",
+            unit="ns"
+        )
+
+        end_timestamps = pc.strptime(
+            pa.array(["2021-04-01 00:00:00", "2022-05-01 00:00:00", "2023-07-01 00:00:00"]),
+            format="%Y-%m-%d %H:%M:%S",
unit="ns" + ) + + interval_array = pc.month_day_nano_interval_between(start_timestamps, end_timestamps) + table = pa.table([interval_array], names=["interval_month_col"]) + + with self.assertRaises(Exception) as context: + chdb.query("SELECT * FROM Python(table)") + + self.assertIn("Unsupported", str(context.exception)) + + def test_list_view_type(self): + """Test LIST_VIEW type - should fail""" + # Create list view array + list_data = [[1, 2], [3, 4, 5], [6], [], [7, 8, 9]] + list_view_array = pa.array(list_data, type=pa.list_view(pa.int64())) + table = pa.table([list_view_array], names=["list_view_col"]) + + with self.assertRaises(Exception) as context: + chdb.query("SELECT * FROM Python(table)") + + self.assertIn("Unsupported", str(context.exception)) + + def test_large_list_view_type(self): + """Test LARGE_LIST_VIEW type - should fail""" + # Create large list view array (if available) + list_data = [[1, 2], [3, 4, 5], [6], [], [7, 8, 9]] + large_list_view_array = pa.array(list_data, type=pa.large_list_view(pa.int64())) + table = pa.table([large_list_view_array], names=["large_list_view_col"]) + + with self.assertRaises(Exception) as context: + chdb.query("SELECT * FROM Python(table)") + + self.assertIn("Unsupported", str(context.exception)) + + def test_run_end_encoded_type(self): + """Test RUN_END_ENCODED type - should fail""" + # Create run-end encoded array + values = pa.array([1, 2, 3]) + run_ends = pa.array([3, 7, 10], type=pa.int32()) + ree_array = pa.RunEndEncodedArray.from_arrays(run_ends, values) + table = pa.table([ree_array], names=["run_end_encoded_col"]) + + with self.assertRaises(Exception) as context: + chdb.query("SELECT * FROM Python(table)") + + self.assertIn("Unsupported", str(context.exception)) + + def test_skip_unsupported_columns_setting(self): + """Test input_format_arrow_skip_columns_with_unsupported_types_in_schema_inference=1 skips unsupported columns""" + # Create a table with both supported and unsupported columns + supported_col = pa.array([1, 2, 3, 4, 5]) # int64 - supported + # Create union array (unsupported) + union_children = [ + pa.array([10, None, 30, None, 50]), + pa.array([None, "b", None, "d", None]) + ] + union_type_ids = pa.array([0, 1, 0, 1, 0], type=pa.int8()) + unsupported_col = pa.UnionArray.from_sparse(union_type_ids, union_children) + + table = pa.table([ + supported_col, + unsupported_col + ], names=["supported_col", "unsupported_col"]) + + # Without the setting, query should fail + with self.assertRaises(Exception) as context: + chdb.query("SELECT * FROM Python(table)") + self.assertIn("Unsupported", str(context.exception)) + + # With the setting, query should succeed but skip unsupported column + result = chdb.query( + "SELECT * FROM Python(table) settings input_format_arrow_skip_columns_with_unsupported_types_in_schema_inference=1" + ) + self.assertEqual(str(result), "1\n2\n3\n4\n5\n") + + +if __name__ == "__main__": + unittest.main(verbosity=2) From 9b4f94d49dfb34bb9d35f85c287d76d61c5cd911 Mon Sep 17 00:00:00 2001 From: wudidapaopao Date: Tue, 30 Sep 2025 11:55:46 +0800 Subject: [PATCH 13/13] test: use macos-14-xlarge runner --- .github/workflows/build_macos_arm64_wheels.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_macos_arm64_wheels.yml b/.github/workflows/build_macos_arm64_wheels.yml index 2cd312eb755..aeb71d2fc41 100644 --- a/.github/workflows/build_macos_arm64_wheels.yml +++ b/.github/workflows/build_macos_arm64_wheels.yml @@ -22,7 +22,7 @@ on: jobs: build_universal_wheel: 
     name: Build Universal Wheel (macOS ARM64)
-    runs-on: macos-13-xlarge
+    runs-on: macos-14-xlarge
     steps:
       - name: Check machine architecture
         run: |
@@ -110,7 +110,7 @@ jobs:
       - name: ccache
         uses: hendrikmuhs/ccache-action@v1.2
         with:
-          key: macos-13-xlarge
+          key: macos-14-xlarge
          max-size: 5G
          append-timestamp: true
       - name: Run chdb/build.sh
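
For reference, the round trip that the new C API in this series expects (build a schema and an array in the Arrow C data interface, register them, query via the arrowstream() table function, unregister) can be sketched in a few dozen lines. The sketch below is illustrative and not part of the patches: it assumes the chdb.h header added in PATCH 01/13 also exposes the Arrow C data interface structs, as the test program above suggests, and the helper names make_int64_schema / make_int64_array and the table name demo_table are invented for the example. Error handling is kept to the bare minimum.

#include "chdb.h"
#include <cstdint>
#include <iostream>

// Describe one record batch with a single int64 column named "id".
// Static storage keeps the release callbacks trivial for this sketch.
static void make_int64_schema(struct ArrowSchema * schema)
{
    static struct ArrowSchema child;
    child = {};
    child.format = "l";    // Arrow C ABI format string for int64
    child.name = "id";
    child.release = [](struct ArrowSchema * s) { s->release = nullptr; };

    static struct ArrowSchema * children[] = {&child};
    *schema = {};
    schema->format = "+s"; // struct of columns, i.e. a record batch
    schema->n_children = 1;
    schema->children = children;
    schema->release = [](struct ArrowSchema * s) { s->release = nullptr; };
}

// Build the matching ArrowArray: a struct array whose only child carries the values.
static void make_int64_array(struct ArrowArray * array, const int64_t * data, int64_t len)
{
    static const void * child_buffers[2] = {nullptr, nullptr}; // validity, values
    child_buffers[1] = data;

    static struct ArrowArray child;
    child = {};
    child.length = len;
    child.n_buffers = 2;
    child.buffers = child_buffers;
    child.release = [](struct ArrowArray * a) { a->release = nullptr; };

    static struct ArrowArray * children[] = {&child};
    static const void * buffers[1] = {nullptr}; // validity buffer of the struct

    *array = {};
    array->length = len;
    array->n_buffers = 1;
    array->buffers = buffers;
    array->n_children = 1;
    array->children = children;
    array->release = [](struct ArrowArray * a) { a->release = nullptr; };
}

int main()
{
    const char * argv[] = {"clickhouse", "--multiquery"};
    int argc = sizeof(argv) / sizeof(argv[0]);

    chdb_connection * conn_ptr = chdb_connect(argc, const_cast<char **>(argv));
    if (!conn_ptr || !*conn_ptr)
        return 1;

    struct ArrowSchema schema;
    struct ArrowArray array;
    int64_t data[] = {1, 2, 3, 4, 5};
    make_int64_schema(&schema);
    make_int64_array(&array, data, 5);

    // Register the in-memory batch under a table name, query it, unregister it.
    chdb_state state = chdb_arrow_array_scan(
        *conn_ptr, "demo_table",
        reinterpret_cast<chdb_arrow_schema>(&schema),
        reinterpret_cast<chdb_arrow_array>(&array));
    if (state != CHDBSuccess)
        return 1;

    chdb_result * result = chdb_query(*conn_ptr, "SELECT sum(id) FROM arrowstream(demo_table)", "CSV");
    if (!chdb_result_error(result))
        std::cout << chdb_result_buffer(result); // expected output: 15
    chdb_destroy_query_result(result);

    chdb_arrow_unregister_table(*conn_ptr, "demo_table");
    chdb_close_conn(conn_ptr);
    return 0;
}

Building such a program requires linking against the chdb shared library; the exact compile invocation depends on the local build and is not prescribed by the patches.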