diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000..5a23a1addcc55 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,629 @@ +*-t +*.ctest +*.reject +*.spec +*.bak +*.dgcov +*.rpm +.*.swp +*.ninja +.ninja_* +*.mri +*.mri.tpl +/.cproject +/.project +.gdb_history +.vs/ +/.settings/ +errmsg.sys +typescript +_CPack_Packages +CMakeCache.txt +CMakeFiles/ +MakeFile +install_manifest*.txt +CPackConfig.cmake +CPackSourceConfig.cmake +CTestTestfile.cmake +Docs/INFO_BIN +Docs/INFO_SRC +Makefile +TAGS +Testing/ +tmp/ +VERSION.dep +configure +client/async_example +client/mysql +client/mysql_plugin +client/mysql_upgrade +client/mysqladmin +client/mysqlbinlog +client/mysqlcheck +client/mysqldump +client/mysqlimport +client/mysqlshow +client/mysqlslap +client/mysqltest +client/mariadb-conv +cmake_install.cmake +dbug/*.r +dbug/factorial +dbug/tests +dbug/user.ps +dbug/user.t +extra/comp_err +extra/innochecksum +extra/jemalloc/build/ +extra/jemalloc/tmp/ +extra/mariabackup/mariabackup +extra/mariabackup/mbstream +extra/my_print_defaults +extra/mysql_waitpid +extra/mysqld_safe_helper +extra/perror +extra/replace +extra/resolve_stack_dump +extra/resolveip +extra/wolfssl/user_settings.h +import_executables.cmake +include/*.h.tmp +include/config.h +include/my_config.h +include/mysql_version.h +include/mysqld_ername.h +include/mysqld_error.h +include/sql_state.h +include/probes_mysql.d +include/probes_mysql_dtrace.h +include/probes_mysql_nodtrace.h +include/source_revision.h +info_macros.cmake +libmysql*/libmysql*_exports_file.cc +libmysql*/merge_archives_mysql*.cmake +libmysql*/mysql*_depends.c +libmysql/libmysql_versions.ld +libmysqld/examples/mysql_client_test_embedded +libmysqld/examples/mysql_embedded +libmysqld/examples/mysqltest_embedded +make_dist.cmake +mariadb-*.*.*.tar.gz +mariadb-*.*.*/ +mysql-test/lib/My/SafeProcess/my_safe_process +mysql-test/lib/My/SafeProcess/wsrep_check_version +mysql-test/mtr +mysql-test/mysql-test-run +mysql-test/mariadb-test-run +mysql-test/mysql-stress-test.pl +mysql-test/mysql-test-run.pl +mysql-test/var* +mysql-test-gcov.err +mysql-test-gcov.msg +mysys/test_hash +mysys/thr_lock +mysys/thr_timer +packaging/rpm-oel/mysql.spec +packaging/rpm-uln/mysql.10.0.11.spec +packaging/solaris/postinstall-solaris +extra/pcre2 +extra/libfmt +plugin/auth_pam/auth_pam_tool +plugin/auth_pam/config_auth_pam.h +plugin/aws_key_management/aws-sdk-cpp +plugin/aws_key_management/aws_sdk_cpp +plugin/aws_key_management/aws_sdk_cpp-prefix +scripts/comp_sql +scripts/make_binary_distribution +scripts/msql2mysql +scripts/mysql_config +scripts/mysql_config.pl +scripts/mysql_convert_table_format +scripts/mysql_find_rows +scripts/mysql_fix_extensions +scripts/mysql_fix_privilege_tables.sql +scripts/mysql_fix_privilege_tables_sql.c +scripts/mysql_install_db +scripts/mysql_secure_installation +scripts/mysql_setpermission +scripts/mysql_zap +scripts/mysqlaccess +scripts/mysqlbug +scripts/mysqld_multi +scripts/mysqld_safe +scripts/mysqldumpslow +scripts/mysqlhotcopy +scripts/mytop +scripts/wsrep_sst_backup +scripts/wsrep_sst_common +scripts/wsrep_sst_mysqldump +scripts/wsrep_sst_rsync +scripts/wsrep_sst_rsync_wan +scripts/wsrep_sst_mariabackup +scripts/wsrep_sst_xtrabackup +scripts/wsrep_sst_xtrabackup-v2 +scripts/maria_add_gis_sp.sql +scripts/maria_add_gis_sp_bootstrap.sql +scripts/galera_new_cluster +scripts/galera_recovery +scripts/mysql_convert_table_format.pl +scripts/mysql_sys_schema.sql +scripts/mysqld_multi.pl +scripts/mysqldumpslow.pl +scripts/mysqlhotcopy.pl +sql-bench/bench-count-distinct +sql-bench/bench-init.pl +sql-bench/compare-results +sql-bench/copy-db +sql-bench/crash-me +sql-bench/graph-compare-results +sql-bench/innotest1 +sql-bench/innotest1a +sql-bench/innotest1b +sql-bench/innotest2 +sql-bench/innotest2a +sql-bench/innotest2b +sql-bench/run-all-tests +sql-bench/server-cfg +sql-bench/test-ATIS +sql-bench/test-alter-table +sql-bench/test-big-tables +sql-bench/test-connect +sql-bench/test-create +sql-bench/test-insert +sql-bench/test-select +sql-bench/test-table-elimination +sql-bench/test-transactions +sql-bench/test-wisconsin +sql-bench/bench-count-distinct.pl +sql-bench/compare-results.pl +sql-bench/copy-db.pl +sql-bench/crash-me.pl +sql-bench/graph-compare-results.pl +sql-bench/innotest1.pl +sql-bench/innotest1a.pl +sql-bench/innotest1b.pl +sql-bench/innotest2.pl +sql-bench/innotest2a.pl +sql-bench/innotest2b.pl +sql-bench/run-all-tests.pl +sql-bench/server-cfg.pl +sql-bench/test-ATIS.pl +sql-bench/test-alter-table.pl +sql-bench/test-big-tables.pl +sql-bench/test-connect.pl +sql-bench/test-create.pl +sql-bench/test-insert.pl +sql-bench/test-select.pl +sql-bench/test-table-elimination.pl +sql-bench/test-transactions.pl +sql-bench/test-wisconsin.pl +sql/make_mysqld_lib.cmake +sql/lex_token.h +sql/gen_lex_token +sql/gen_lex_hash +sql/lex_hash.h +sql/myskel.m4 +sql/mysql_tzinfo_to_sql +sql/mysqld +sql/sql_builtin.cc +sql/yy_mariadb.cc +sql/yy_mariadb.hh +sql/yy_mariadb.yy +sql/yy_oracle.cc +sql/yy_oracle.hh +sql/yy_oracle.yy +storage/heap/hp_test1 +storage/heap/hp_test2 +storage/maria/aria_chk +storage/maria/aria_dump_log +storage/maria/aria_ftdump +storage/maria/aria_pack +storage/maria/aria_read_log +storage/maria/aria_s3_copy +storage/maria/ma_rt_test +storage/maria/ma_sp_test +storage/maria/ma_test1 +storage/maria/ma_test2 +storage/maria/ma_test3 +storage/maria/test_ma_backup +storage/myisam/mi_test1 +storage/myisam/mi_test2 +storage/myisam/mi_test3 +storage/myisam/myisam_ftdump +storage/myisam/myisamchk +storage/myisam/myisamlog +storage/myisam/myisampack +storage/myisam/rt_test +storage/myisam/sp_test +storage/perfschema/pfs_config.h +storage/rocksdb/ldb +storage/rocksdb/myrocks_hotbackup +storage/rocksdb/mysql_ldb +storage/rocksdb/rdb_source_revision.h +storage/rocksdb/sst_dump +strings/conf_to_src +support-files/MySQL-shared-compat.spec +support-files/binary-configure +support-files/config.huge.ini +support-files/config.medium.ini +support-files/config.small.ini +support-files/mariadb.pc +support-files/mariadb.pp +support-files/mariadb.service +support-files/mariadb.socket +support-files/mariadb-extra.socket +support-files/mariadb@.service +support-files/mariadb@.socket +support-files/mariadb-extra@.socket +support-files/mini-benchmark +support-files/my-huge.cnf +support-files/my-innodb-heavy-4G.cnf +support-files/my-large.cnf +support-files/my-medium.cnf +support-files/my-small.cnf +support-files/mariadb.logrotate +support-files/mysql.10.0.11.spec +support-files/mysql.server +support-files/mysql.service +support-files/mysql.spec +support-files/mysqld.service +support-files/mysqld_multi.server +support-files/policy/selinux/mysqld-safe.pp +support-files/sysusers.conf +support-files/tmpfiles.conf +support-files/wsrep.cnf +support-files/wsrep_notify +tags +tests/async_queries +tests/bug25714 +tests/mysql_client_test +storage/mroonga/config.sh +storage/mroonga/mrn_version.h +storage/mroonga/data/install.sql +storage/mroonga/vendor/groonga/config.h +storage/mroonga/vendor/groonga/config.sh +storage/mroonga/vendor/groonga/groonga.pc +storage/mroonga/vendor/groonga/src/grnslap +storage/mroonga/vendor/groonga/src/groonga +storage/mroonga/vendor/groonga/src/groonga-benchmark +storage/mroonga/vendor/groonga/src/suggest/groonga-suggest-create-dataset +storage/mroonga/mysql-test/mroonga/storage/r/information_schema_plugins.result +storage/mroonga/mysql-test/mroonga/storage/r/variable_version.result +zlib/zconf.h +xxx/* +yyy/* +zzz/* + +# C and C++ + +# Compiled Object files +*.slo +*.o +*.ko +*.obj +*.elf +*.exp +*.dep +*.idb +*.res +*.tlog + +# Precompiled Headers +*.gch +*.pch + +# Compiled Static libraries +*.lib +*.a +*.la +*.lai +*.lo + +# Compiled Dynamic libraries +*.so +*.so.* +*.dylib +*.dll + +# Executables +*.exe +*.out +*.app +*.i*86 +*.x86_64 +*.hex + + +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. + +# User-specific files +*.suo +*.user +*.userosscache +*.sln.docstates +*.sln + +*.vcproj +*.vcproj.* +*.vcproj.*.* +*.vcproj.*.*.* +*.vcxproj +*.vcxproj.* +*.vcxproj.*.* +*.vcxproj.*.*.* + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +build/ +bld/ +[Bb]in/ +/cmake-build-debug/ +[Oo]bj/ + +# Roslyn cache directories +*.ide/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +#NUNIT +*.VisualState.xml +TestResult.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +*_i.c +*_p.c +*_i.h +*.ilk +*.meta +*.pdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*.log +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opensdf +*.sdf +*.cachefile + +# Visual Studio profiler +*.psess +*.vsp +*.vspx + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# JustCode is a .NET coding addin-in +.JustCode + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# NCrunch +_NCrunch_* +.*crunch*.local.xml + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# TODO: Comment the next line if you want to checkin your web deploy settings +# but database connection strings (with potential passwords) will be unencrypted +*.pubxml +*.publishproj + +# NuGet Packages +*.nupkg +# The packages folder can be ignored because of Package Restore +**/packages/* +# except build/, which is used as an MSBuild target. +!**/packages/build/ +# If using the old MSBuild-Integrated Package Restore, uncomment this: +#!**/packages/repositories.config + +# Windows Azure Build Output +csx/ +*.build.csdef + +# Windows Store app package directory +AppPackages/ + +# Others +# sql/ +*.Cache +ClientBin/ +[Ss]tyle[Cc]op.* +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.pfx +*.publishsettings +node_modules/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm +*.stackdump + +# SQL Server files +*.mdf +*.ldf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings + +# Microsoft Fakes +FakesAssemblies/ + +# macOS garbage +.DS_Store + +# QtCreator && CodeBlocks +*.cbp + +compile_commands.json +.clang-format +.kscope/ +.vimrc +.editorconfig +.kateconfig +*.kdev4 + +# Visual Studio Code workspace +.vscode/ + +# Clion && other JetBrains ides +/.idea/ + +.cache/clangd + + +client/mariadb +client/mariadb-admin +client/mariadb-binlog +client/mariadb-check +client/mariadb-dump +client/mariadb-import +client/mariadb-plugin +client/mariadb-show +client/mariadb-slap +client/mariadb-test +client/mariadb-upgrade +extra/mariabackup/mariadb-backup +extra/mariadbd-safe-helper +extra/mariadb-waitpid +libmysqld/examples/mariadb-client-test-embedded +libmysqld/examples/mariadb-embedded +libmysqld/examples/mariadb-test-embedded +man/mariadb.1 +man/mariadb-access.1 +man/mariadb-admin.1 +man/mariadb-backup.1 +man/mariadb-binlog.1 +man/mariadb-check.1 +man/mariadb-client-test.1 +man/mariadb-client-test-embedded.1 +man/mariadb_config.1 +man/mariadb-convert-table-format.1 +man/mariadbd.8 +man/mariadbd-multi.1 +man/mariadbd-safe.1 +man/mariadbd-safe-helper.1 +man/mariadb-dump.1 +man/mariadb-dumpslow.1 +man/mariadb-embedded.1 +man/mariadb-find-rows.1 +man/mariadb-fix-extensions.1 +man/mariadb-hotcopy.1 +man/mariadb-import.1 +man/mariadb-install-db.1 +man/mariadb-ldb.1 +man/mariadb-plugin.1 +man/mariadb-secure-installation.1 +man/mariadb-setpermission.1 +man/mariadb-show.1 +man/mariadb-slap.1 +man/mariadb-test.1 +man/mariadb-test-embedded.1 +man/mariadb-tzinfo-to-sql.1 +man/mariadb-upgrade.1 +man/mariadb-waitpid.1 +scripts/mariadb-access +scripts/mariadb-convert-table-format +scripts/mariadbd-multi +scripts/mariadbd-safe +scripts/mariadb-dumpslow +scripts/mariadb-find-rows +scripts/mariadb-fix-extensions +scripts/mariadb-hotcopy +scripts/mariadb-install-db +scripts/mariadb-secure-installation +scripts/mariadb-setpermission +sql/mariadbd +sql/mariadb-tzinfo-to-sql +storage/rocksdb/mariadb-ldb +strings/ctype-uca1400data.h +strings/uca-dump +tests/mariadb-client-test +versioninfo_dll.rc +versioninfo_exe.rc +win/packaging/ca/symlinks.cc + +# rust output +**/target/ +**/.git + diff --git a/.github/workflows/validation-rust.yaml b/.github/workflows/validation-rust.yaml new file mode 100644 index 0000000000000..57f0cc7bce684 --- /dev/null +++ b/.github/workflows/validation-rust.yaml @@ -0,0 +1,178 @@ +--- + +# Main "useful" actions config file +# Cache config comes from https://github.com/actions/cache/blob/main/examples.md#rust---cargo +# actions-rs/toolchain configures rustup +# actions-rs/cargo actually runs cargo + +on: + push: + branches: + - rust + # pull_request: + +name: Rust Validation + +jobs: + check: + name: "Check (cargo check)" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/cache@v3 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + rust/target/ + CMakeFiles/ + CMakeCache.txt + CPackConfig.cmake + CPackSourceConfig.cmake + CTestTestfile.cmake + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + - run: sudo apt-get update && sudo apt-get install cmake + - uses: dtolnay/rust-toolchain@stable + - run: cargo check --manifest-path rust/Cargo.toml + + test: + strategy: + fail-fast: true + matrix: + os: [ubuntu-latest] #, windows-latest, macos-latest] + include: + - os: ubuntu-latest + name: linux + # - os: windows-latest + # name: windows + # - os: macos-latest + # name: mac + + name: "Test on ${{ matrix.name }} (cargo test)" + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - uses: actions/cache@v3 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + rust/target/ + CMakeFiles/ + CMakeCache.txt + CPackConfig.cmake + CPackSourceConfig.cmake + CTestTestfile.cmake + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + - run: sudo apt-get update && sudo apt-get install cmake + - uses: dtolnay/rust-toolchain@stable + - run: cargo test --manifest-path rust/Cargo.toml + env: + RUST_BACKTRACE: "1" + + + fmt: + name: "Format (cargo fmt)" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/cache@v3 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + rust/target/ + CMakeFiles/ + CMakeCache.txt + CPackConfig.cmake + CPackSourceConfig.cmake + CTestTestfile.cmake + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + - run: cargo fmt --manifest-path rust/Cargo.toml --all -- --check + + clippy: + name: "Clippy (cargo clippy)" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/cache@v3 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + rust/target/ + CMakeFiles/ + CMakeCache.txt + CPackConfig.cmake + CPackSourceConfig.cmake + CTestTestfile.cmake + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + - run: sudo apt-get update && sudo apt-get install cmake + - uses: dtolnay/rust-toolchain@nightly + with: + components: clippy + - run: cargo clippy --manifest-path rust/Cargo.toml -- -D warnings + + doc: + name: "Docs (cargo doc) & Pub" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/cache@v3 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + rust/target/ + CMakeFiles/ + CMakeCache.txt + CPackConfig.cmake + CPackSourceConfig.cmake + CTestTestfile.cmake + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + - run: sudo apt-get update && sudo apt-get install cmake + - uses: dtolnay/rust-toolchain@stable + # test docs for everything + - name: Test build all docs + run: cargo doc --manifest-path rust/Cargo.toml --no-deps + # create docs for the crate we care about + - name: Build docs for publish + run: | + rm -rf target/doc/ rust/target/doc/ + cargo doc --manifest-path rust/mariadb/Cargo.toml --no-deps + cargo doc --manifest-path rust/bindings/Cargo.toml --no-deps + - run: | + echo `pwd`/rust/target/doc >> $GITHUB_PATH + # fake index.html so github likes us + echo "" > rust/target/doc/index.html + - name: Deploy GitHub Pages + run: | + git worktree add gh-pages + git config user.name "Deploy from CI" + git config user.email "" + cd gh-pages + # Delete the ref to avoid keeping history. + git update-ref -d refs/heads/gh-pages + rm -rf * + mv ../rust/target/doc/* . + git add . + git commit -m "Deploy $GITHUB_SHA to gh-pages" + git push --force --set-upstream origin gh-pages diff --git a/.gitignore b/.gitignore index 2fb3857120c1f..d427014224d7c 100644 --- a/.gitignore +++ b/.gitignore @@ -622,3 +622,7 @@ tests/mariadb-client-test versioninfo_dll.rc versioninfo_exe.rc win/packaging/ca/symlinks.cc + +# rust output +rust/target +Cargo.lock diff --git a/README.md b/README.md index 58dbf105fb90a..e09f97ad69769 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,13 @@ +# MariaDB: Rust Support Fork + +This is a fork of [MariaDB server](https://github.com/MariaDB/server) with the +goal of experimenting with a Rust plugin API. It is a work in progress, +everything here should be considered unstable. + +View docs for the Rust API here: +. + + Code status: ------------ diff --git a/docker_build.sh b/docker_build.sh new file mode 100644 index 0000000000000..b9d7da7362e97 --- /dev/null +++ b/docker_build.sh @@ -0,0 +1,11 @@ +# docker run -it -v $(pwd):/build/server ubuntu +apt-get update +apt-get build-dep mariadb-server +apt-get install -y build-essential libncurses5-dev gnutls-dev bison zlib1g-dev ccache g++ cmake ninja-build vim wget + +cd build +mkdir build-mariadb-server-debug +cd build-mariadb-server-debug +cmake ../server -DCONC_WITH_{UNITTEST,SSL}=OFF -DWITH_UNIT_TESTS=OFF -DCMAKE_BUILD_TYPE=Debug -DWITH_SAFEMALLOC=OFF -DWITH_SSL=bundled -DMYSQL_MAINTAINER_MODE=OFF -G Ninja + +# cmake --build . --parallel 4 diff --git a/include/mysql/plugin_auth_common.h b/include/mysql/plugin_auth_common.h index 8edd712875461..7bbdca5aae214 100644 --- a/include/mysql/plugin_auth_common.h +++ b/include/mysql/plugin_auth_common.h @@ -128,4 +128,3 @@ typedef struct st_plugin_vio } MYSQL_PLUGIN_VIO; #endif - diff --git a/rust/.clippy.toml b/rust/.clippy.toml new file mode 100644 index 0000000000000..96995253d74c2 --- /dev/null +++ b/rust/.clippy.toml @@ -0,0 +1 @@ +doc-valid-idents = ["MySQL", "MariaDB"] diff --git a/rust/.rustfmt.toml b/rust/.rustfmt.toml new file mode 100644 index 0000000000000..4dc35d3684f6d --- /dev/null +++ b/rust/.rustfmt.toml @@ -0,0 +1,6 @@ +imports_granularity = "Module" +newline_style = "Unix" +group_imports = "StdExternalCrate" +format_code_in_doc_comments = true +format_macro_bodies = true +format_macro_matchers = true diff --git a/rust/Cargo.toml b/rust/Cargo.toml new file mode 100644 index 0000000000000..cd5c15254140e --- /dev/null +++ b/rust/Cargo.toml @@ -0,0 +1,13 @@ +# Define our crates within a workspace + +[workspace] + +members = [ + "bindings", + "mariadb", + "macros", + "examples/keymgt-basic", + "examples/keymgt-debug", + "examples/encryption", + "plugins/keymgt-clevis", +] diff --git a/rust/Dockerfile b/rust/Dockerfile new file mode 100644 index 0000000000000..206d657aa251a --- /dev/null +++ b/rust/Dockerfile @@ -0,0 +1,65 @@ +# Quick test for our example plugins, build against the current repo but runs +# with the published 10.11 image +# +# ``` +# # Build the image. Change the directory (../) if not building in `rust/` +# docker build -f Dockerfile ../ --tag mdb-plugin-ex +# +# # Run the container, select default plugins as desired +# docker run --rm -e MARIADB_ROOT_PASSWORD=example --name mdb-plugin-ex-c \ +# mdb-plugin-ex \ +# --plugin-maturity=experimental \ +# --plugin-load=libbasic \ +# --plugin-load=libencryption +# --plugin-load=libkeymgt_debug +# +# # Enter a SQL console +# docker exec -it mdb-plugin-ex-c mysql -pexample +# +# # Install desired plugins +# INSTALL PLUGIN basic_key_management SONAME 'libbasic.so'; +# INSTALL PLUGIN encryption_example SONAME 'libencryption.so'; +# INSTALL PLUGIN debug_key_management SONAME 'libkeymgt_debug.so'; +# +# # Stop server +# docker stop mdb-plugin-ex-c +# ``` + +# use nighlty image for faster builds +FROM rustlang/rust:nightly AS build + +ENV CARGO_UNSTABLE_SPARSE_REGISTRY=true +WORKDIR /build + +RUN apt-get update \ + # build requirements + && apt-get install -y cmake clang bison \ + && mkdir /output + +COPY . . + +WORKDIR /build/rust + +RUN --mount=type=cache,target=/usr/local/cargo/registry \ + --mount=type=cache,target=target \ + cargo build \ + # -p basic + # -p encryption + # --release \ + # && cp target/release/*.so /output + && cp target/debug/*.so /output + +# RUN cp target/debug/*.so /output + +RUN export RUST_BACKTRACE=1 + +FROM mariadb:10.11-rc + +# Deb utils +RUN apt-get update \ + && apt-get install -y xxd less vim-tiny binutils + +COPY --from=build /output/* /usr/lib/mysql/plugin/ + +# create database db; use db; create table t1 (id int) encrypted=yes; +# flush tables t1 for export; diff --git a/rust/README.md b/rust/README.md new file mode 100644 index 0000000000000..80850afbdd38c --- /dev/null +++ b/rust/README.md @@ -0,0 +1,33 @@ +# Rust support for MariaDB + +The purpose of this module is to be able to write plugins for MariaDB in Rust. + +## Building + +The Rust portion of this repository does not yet integrate with the main MariaDB +CMake build system to statically link plugins (adding this is a goal). + +To build dynamically, simply run `cargo build` within this `/rust` directory. + +## Testing with Docker + + +```sh +# Build the image. Change the directory (../) if not building in `rust/` +docker build -f Dockerfile ../ --tag mdb-plugin-ex + +# Run the container +docker run --rm -e MARIADB_ROOT_PASSWORD=example --name mdb-plugin-ex-c \ + mdb-plugin-ex \ + --plugin-maturity=experimental +# --plugin-load=libbasic \ +# --plugin-load=libencryption +# --plugin-load=libdebug_key_management + +# Enter a SQL console +docker exec -it mdb-plugin-ex-c mysql -pexample + +# Install desired plugins +INSTALL PLUGIN basic_key_management SONAME 'libbasic.so'; +INSTALL PLUGIN encryption_example SONAME 'libencryption_example.so'; +``` diff --git a/rust/bindings/Cargo.toml b/rust/bindings/Cargo.toml new file mode 100644 index 0000000000000..4417de0938517 --- /dev/null +++ b/rust/bindings/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "mariadb-sys" +version = "0.1.0" +edition = '2021' + + +[build-dependencies] +bindgen = "0.63.0" +cmake = "0.1" +doxygen-rs = "0.2" diff --git a/rust/bindings/build.rs b/rust/bindings/build.rs new file mode 100644 index 0000000000000..026486b89da48 --- /dev/null +++ b/rust/bindings/build.rs @@ -0,0 +1,126 @@ +//! This file runs `cmake` as needed, then `bindgen` to produce the rust bindings + +use std::collections::HashSet; +use std::env; +use std::path::PathBuf; +use std::process::Command; + +use bindgen::callbacks::{DeriveInfo, MacroParsingBehavior, ParseCallbacks}; +use bindgen::EnumVariation; + +// `math.h` seems to double define some things, To avoid this, we ignore them. +const IGNORE_MACROS: [&str; 20] = [ + "FE_DIVBYZERO", + "FE_DOWNWARD", + "FE_INEXACT", + "FE_INVALID", + "FE_OVERFLOW", + "FE_TONEAREST", + "FE_TOWARDZERO", + "FE_UNDERFLOW", + "FE_UPWARD", + "FP_INFINITE", + "FP_INT_DOWNWARD", + "FP_INT_TONEAREST", + "FP_INT_TONEARESTFROMZERO", + "FP_INT_TOWARDZERO", + "FP_INT_UPWARD", + "FP_NAN", + "FP_NORMAL", + "FP_SUBNORMAL", + "FP_ZERO", + "IPPORT_RESERVED", +]; + +const DERIVE_COPY_NAMES: [&str; 1] = ["enum_field_types"]; + +#[derive(Debug)] +struct BuildCallbacks(HashSet); + +impl ParseCallbacks for BuildCallbacks { + /// Ignore macros that are in the ignored list + fn will_parse_macro(&self, name: &str) -> MacroParsingBehavior { + if self.0.contains(name) { + MacroParsingBehavior::Ignore + } else { + MacroParsingBehavior::Default + } + } + + /// Use a converter to turn doxygen comments into rustdoc + fn process_comment(&self, comment: &str) -> Option { + Some(doxygen_rs::transform(comment)) + } + + fn add_derives(&self, _info: &DeriveInfo<'_>) -> Vec { + if DERIVE_COPY_NAMES.contains(&_info.name) { + vec!["Copy".to_owned()] + } else { + vec![] + } + } +} + +impl BuildCallbacks { + fn new() -> Self { + Self(IGNORE_MACROS.into_iter().map(|s| s.to_owned()).collect()) + } +} + +fn main() { + // Tell cargo to invalidate the built crate whenever the wrapper changes + println!("cargo:rerun-if-changed=src/wrapper.h"); + + // Run cmake to configure only + Command::new("cmake") + .args(["../../", "-B../../"]) + .output() + .expect("failed to invoke cmake"); + + // The bindgen::Builder is the main entry point + // to bindgen, and lets you build up options for + // the resulting bindings. + let bindings = bindgen::Builder::default() + // The input header we would like to generate + // bindings for. + .header("src/wrapper.h") + // Fix math.h double defines + .parse_callbacks(Box::new(BuildCallbacks::new())) + .clang_arg("-I../../include") + .clang_arg("-I../../sql") + .clang_arg("-xc++") + .clang_arg("-std=c++17") + // Don't derive copy for structs + .derive_copy(false) + // Use rust-style enums labeled with non_exhaustive to represent C enums + .default_enum_style(EnumVariation::Rust { + non_exhaustive: true, + }) + // LLVM has some issues with long dobule and ABI compatibility + // disabling the only relevant function here to suppress errors + .blocklist_function("strfroml") + .blocklist_function("strfromf64x") + .blocklist_function("strtof64x_l") + .blocklist_function("strtof64x") + .blocklist_function("strtold") + .blocklist_function("strtold_l") + // qvct, evct, qfcvt_r, ... + .blocklist_function("[a-z]{1,2}cvt(?:_r)?") + // c++ things that aren't supported + .blocklist_item("List_iterator") + .blocklist_type("std::char_traits") + .opaque_type("std_.*") + .blocklist_item("std_basic_string") + .blocklist_item("std_collate.*") + .blocklist_item("__gnu_cxx.*") + // Finish the builder and generate the bindings. + .generate() + // Unwrap the Result and panic on failure. + .expect("Unable to generate bindings"); + + // Write the bindings to the $OUT_DIR/bindings.rs file. + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + bindings + .write_to_file(out_path.join("bindings.rs")) + .expect("couldn't write bindings"); +} diff --git a/rust/bindings/src/hand_impls.rs b/rust/bindings/src/hand_impls.rs new file mode 100644 index 0000000000000..c70f6734d6f53 --- /dev/null +++ b/rust/bindings/src/hand_impls.rs @@ -0,0 +1,122 @@ +use std::ffi::{c_char, c_double, c_int, c_long, c_longlong, c_uint, c_ulong, c_ulonglong}; + +use super::{mysql_var_check_func, mysql_var_update_func, TYPELIB}; + +// Defined in service_encryption.h but not imported because of tilde syntax +pub const ENCRYPTION_KEY_VERSION_INVALID: c_uint = !0; + +#[allow(dead_code)] // not sure why this lint hits +pub const PLUGIN_VAR_MASK: u32 = super::PLUGIN_VAR_READONLY + | super::PLUGIN_VAR_NOSYSVAR + | super::PLUGIN_VAR_NOCMDOPT + | super::PLUGIN_VAR_NOCMDARG + | super::PLUGIN_VAR_OPCMDARG + | super::PLUGIN_VAR_RQCMDARG + | super::PLUGIN_VAR_DEPRECATED + | super::PLUGIN_VAR_MEMALLOC; + +// We hand write these stucts because the definition is tricky, not all fields are +// always present + +// no support for THD yet +macro_rules! declare_sysvar_type { + (@common $name:ident: $(#[$doc:meta] $fname:ident: $fty:ty),* $(,)*) => { + // Common implementation + #[repr(C)] + #[derive(Debug)] + pub struct $name { + /// Variable flags + pub flags: c_int, + /// Name of the variable + pub name: *const c_char, + /// Variable description + pub comment: *const c_char, + /// Function for getting the variable + pub check: mysql_var_check_func, + /// Function for setting the variable + pub update: mysql_var_update_func, + + // Repeated fields + $( + #[$doc] + pub $fname: $fty + ),* + } + }; + (basic: $name:ident, $ty:ty) => { + // A "basic" sysvar + declare_sysvar_type!{ + @common $name: + #[doc = "Pointer to the value"] + value: *mut $ty, + #[doc = "Default value"] + def_val: $ty, + } + }; + (const basic: $name:ident, $ty:ty) => { + // A "basic" sysvar + declare_sysvar_type!{ + @common $name: + #[doc = "Pointer to the value"] + value: *const $ty, + #[doc = "Default value"] + def_val: $ty, + } + }; + (simple: $name:ident, $ty:ty) => { + // A "simple" sysvar, with minimum maximum and block size + declare_sysvar_type!{ + @common $name: + #[doc = "Pointer to the value"] + value: *mut $ty, + #[doc = "Default value"] + def_val: $ty, + #[doc = "Min value"] + min_val: $ty, + #[doc = "Max value"] + max_val: $ty, + #[doc = "Block size"] + blk_sz: $ty, + } + }; + (typelib: $name:ident, $ty:ty) => { + // A "typelib" sysvar + declare_sysvar_type!{ + @common $name: + #[doc = "Pointer to the value"] + value: *mut $ty, + #[doc = "Default value"] + def_val: $ty, + #[doc = "Typelib"] + typelib: *const TYPELIB + } + }; + + + // (typelib: $name:ident, $ty:ty) => { + + // }; + // (thd: $name:ident, $ty:ty) => { + + // }; +} + +declare_sysvar_type!(@common sysvar_common_t:); +declare_sysvar_type!(basic: sysvar_bool_t, bool); +declare_sysvar_type!(basic: sysvar_str_t, *mut c_char); +declare_sysvar_type!(typelib: sysvar_enum_t, c_ulong); +declare_sysvar_type!(typelib: sysvar_set_t, c_ulonglong); +declare_sysvar_type!(simple: sysvar_int_t, c_int); +declare_sysvar_type!(simple: sysvar_long_t, c_long); +declare_sysvar_type!(simple: sysvar_longlong_t, c_longlong); +declare_sysvar_type!(simple: sysvar_uint_t, c_uint); +declare_sysvar_type!(simple: sysvar_ulong_t, c_ulong); +declare_sysvar_type!(simple: sysvar_ulonglong_t, c_ulonglong); +declare_sysvar_type!(simple: sysvar_double_t, c_double); + +// declare_sysvar_type!(thdbasic: thdvar_bool_t, bool); +// declare_sysvar_type!(thdbasic: thdvar_str_t, *mut c_char); +// declare_sysvar_type!(typelib: sysvar_enum_t, c_ulong); +// declare_sysvar_type!(typelib: sysvar_set_t, c_ulonglong); + +// type THDVAR_FUNC = Option *mut T>; diff --git a/rust/bindings/src/lib.rs b/rust/bindings/src/lib.rs new file mode 100644 index 0000000000000..402b4d3501ea8 --- /dev/null +++ b/rust/bindings/src/lib.rs @@ -0,0 +1,15 @@ +//! Bindings module +//! +//! Autogenerated bindings to C interfaces for Rust +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] +#![allow(clippy::useless_transmute)] +#![allow(clippy::too_many_arguments)] +#![allow(clippy::missing_safety_doc)] + +// Bindings are autogenerated at build time using build.rs +include!(concat!(env!("OUT_DIR"), "/bindings.rs")); + +mod hand_impls; +pub use hand_impls::*; diff --git a/rust/bindings/src/wrapper.h b/rust/bindings/src/wrapper.h new file mode 100644 index 0000000000000..cdc2c07e514f3 --- /dev/null +++ b/rust/bindings/src/wrapper.h @@ -0,0 +1,6 @@ +// Directives here indicate what to include in bindings + +// #include +#include +#include +#include diff --git a/rust/examples/encryption/Cargo.toml b/rust/examples/encryption/Cargo.toml new file mode 100644 index 0000000000000..71e17a51ce84b --- /dev/null +++ b/rust/examples/encryption/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "encryption-example" +version = "0.1.0" +edition = '2021' + +[lib] +crate-type = ["cdylib"] + +[dependencies] +mariadb = { path = "../../mariadb" } +aes-gcm = "0.10" +sha2 = "0.10" +rand = "0.8" diff --git a/rust/examples/encryption/src/lib.rs b/rust/examples/encryption/src/lib.rs new file mode 100644 index 0000000000000..763f0278d22c3 --- /dev/null +++ b/rust/examples/encryption/src/lib.rs @@ -0,0 +1,168 @@ +//! Basic encryption plugin using: +//! +//! - SHA256 as the hasher + +#![allow(unused)] + +use std::sync::Mutex; +use std::time::{Duration, Instant}; + +use aes_gcm::{ + aead::{Aead, KeyInit, OsRng}, + Aes256Gcm, + Nonce, // Or `Aes128Gcm` +}; +use mariadb::plugin::encryption::{Encryption, EncryptionError, Flags, KeyError, KeyManager}; +use mariadb::plugin::*; +use rand::Rng; +use sha2::{Digest, Sha256 as Hasher}; + +/// Range of key rotations, as seconds +const KEY_ROTATION_MIN: f32 = 45.0; +const KEY_ROTATION_MAX: f32 = 90.0; +const KEY_ROTATION_INTERVAL: f32 = KEY_ROTATION_MAX - KEY_ROTATION_MIN; +const SHA256_SIZE: usize = 32; +// const KEY_ROTATION_INTERVAL: Duration = +// KEY_ROTATION_MAX - KEY_ROTATION_MIN; + +/// Our global key version state +static KEY_VERSIONS: Mutex> = Mutex::new(None); + +/// Contain the state of our keys. We use `Instant` (the monotonically) +/// increasing clock) instead of `SystemTime` (which may occasionally go +/// backwards) +#[derive(Debug)] +struct KeyVersions { + /// Initialization time of the struct, reference point for key version + start: Instant, + /// Most recent key update time + current: Instant, + /// Next time for a key update + next: Instant, +} + +impl KeyVersions { + /// Initialize with a new value. Returns the struct + fn new_now() -> Self { + let now = Instant::now(); + let mut ret = Self { + start: now, + current: now, + next: now, + }; + ret.update_next(); + ret + } + + fn update_next(&mut self) { + let mult = rand::thread_rng().gen_range(0.0..1.0); + let add_duration = KEY_ROTATION_MIN + mult * KEY_ROTATION_INTERVAL; + self.next += Duration::from_secs_f32(add_duration); + } + + /// Update the internal duration if needed, and return the elapsed time + fn update_returning_version(&mut self) -> u64 { + let now = Instant::now(); + if now > self.next { + self.current = now; + self.update_next(); + } + (self.next - self.start).as_secs() + } +} + +struct RustEncryption; + +impl Init for RustEncryption { + /// Initialize function: + fn init() -> Result<(), InitError> { + eprintln!("init called for RustEncryption"); + let mut guard = KEY_VERSIONS.lock().unwrap(); + *guard = Some(KeyVersions::new_now()); + Ok(()) + } + + fn deinit() -> Result<(), InitError> { + eprintln!("deinit called for RustEncryption"); + Ok(()) + } +} + +impl KeyManager for RustEncryption { + fn get_latest_key_version(_key_id: u32) -> Result { + dbg!(_key_id); + let mut guard = KEY_VERSIONS.lock().unwrap(); + let mut vers = guard.as_mut().unwrap(); + Ok(vers.update_returning_version() as u32) + } + + /// Given a key ID and a version, create its hash + fn get_key(key_id: u32, key_version: u32, dst: &mut [u8]) -> Result<(), KeyError> { + dbg!(key_id, key_version, dst.len()); + + let output_size = Hasher::output_size(); + if dst.len() < output_size { + return Err(KeyError::BufferTooSmall); + } + let mut hasher = Hasher::new(); + hasher.update(key_id.to_ne_bytes()); + hasher.update(key_version.to_ne_bytes()); + dst[..output_size].copy_from_slice(&hasher.finalize()); + Ok(()) + } + + fn key_length(key_id: u32, key_version: u32) -> Result { + dbg!(key_id, key_version); + // All keys have the same length + Ok(Hasher::output_size()) + } +} + +impl Encryption for RustEncryption { + fn init( + key_id: u32, + key_version: u32, + key: &[u8], + iv: &[u8], + flags: Flags, + ) -> Result { + eprintln!("encryption init"); + dbg!(&key_id, &key_version); + eprintln!("key: {:x?}", &key); + eprintln!("iv: {:x?}", &iv); + dbg!(flags); + Ok(Self) + } + + fn update(&mut self, src: &[u8], dst: &mut [u8]) -> Result<(), EncryptionError> { + eprintln!("encryption update"); + dbg!(src.len(), dst.len()); + dst[..src.len()].copy_from_slice(src); + Ok(()) + } + + fn finish(&mut self, dst: &mut [u8]) -> Result<(), EncryptionError> { + eprintln!("encryption finish"); + dbg!(dst.len()); + Ok(()) + } + + fn encrypted_length(key_id: u32, key_version: u32, src_len: usize) -> usize { + eprintln!("encryption length"); + dbg!(key_id, key_version, src_len); + src_len + } +} + +register_plugin! { + RustEncryption, + ptype: PluginType::MariaEncryption, + name: "encryption_example", + author: "Trevor Gross", + description: "Example key management / encryption plugin", + license: License::Gpl, + maturity: Maturity::Experimental, + version: "0.1", + init: RustEncryption, // optional + encryption: true, +} diff --git a/rust/examples/encryption/test.sql b/rust/examples/encryption/test.sql new file mode 100644 index 0000000000000..df3059e00e12b --- /dev/null +++ b/rust/examples/encryption/test.sql @@ -0,0 +1,19 @@ +SET GLOBAL innodb_encryption_threads=1; +SET GLOBAL innodb_encrypt_tables=ON; +SET SESSION innodb_default_encryption_key_id=100; + +CREATE DATABASE db; +USE db; + +CREATE TABLE t1 ( + id int PRIMARY KEY, + str varchar(50) +); + +INSERT INTO t1(id, str) VALUES + (1, 'abc'), + (2, 'def'), + (3, 'ghi'), + (4, 'jkl'); + +FLUSH TABLES t1 FOR EXPORT; diff --git a/rust/examples/keymgt-basic/Cargo.toml b/rust/examples/keymgt-basic/Cargo.toml new file mode 100644 index 0000000000000..28031957effc9 --- /dev/null +++ b/rust/examples/keymgt-basic/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "keymgt-basic" +version = "0.1.0" +edition = '2021' + +[lib] +crate-type = ["cdylib"] + +[dependencies] +mariadb = { path = "../../mariadb" } diff --git a/rust/examples/keymgt-basic/src/lib.rs b/rust/examples/keymgt-basic/src/lib.rs new file mode 100644 index 0000000000000..0a1d6968a344b --- /dev/null +++ b/rust/examples/keymgt-basic/src/lib.rs @@ -0,0 +1,72 @@ +//! Debug key management +//! +//! Use to debug the encryption code with a fixed key that changes only on user +//! request. The only valid key ID is 1. +//! +//! EXAMPLE ONLY: DO NOT USE IN PRODUCTION! + +#![allow(unused)] + +use std::cell::UnsafeCell; +use std::ffi::c_void; +use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering}; + +use mariadb::plugin::encryption::{Encryption, Flags, KeyError, KeyManager}; +use mariadb::plugin::{register_plugin, PluginType, SysVarOpt, *}; + +struct BasicKeyMgt; + +static COUNTER: AtomicUsize = AtomicUsize::new(30); + +impl Init for BasicKeyMgt { + fn init() -> Result<(), InitError> { + eprintln!("init for BasicKeyMgt"); + Ok(()) + } + + fn deinit() -> Result<(), InitError> { + eprintln!("deinit for BasicKeyMgt"); + Ok(()) + } +} + +impl KeyManager for BasicKeyMgt { + fn get_latest_key_version(key_id: u32) -> Result { + eprintln!("get latest key version with {key_id}"); + static KCOUNT: AtomicU32 = AtomicU32::new(1); + let ret = KCOUNT.fetch_add(1, Ordering::Relaxed); + dbg!(ret); + Ok(ret) + } + + fn get_key(key_id: u32, key_version: u32, dst: &mut [u8]) -> Result<(), KeyError> { + let s = format!("get_key: {key_id}:{key_version}"); + eprintln!("{s}, {}", dst.len()); + + if dst.len() < dbg!(COUNTER.fetch_add(1, Ordering::Relaxed)) { + return Err(KeyError::BufferTooSmall); + } + + // Copy our slice to the buffer, return the copied length + dst[..s.len()].copy_from_slice(s.as_str().as_bytes()); + Ok(()) + } + + fn key_length(key_id: u32, key_version: u32) -> Result { + eprintln!("get key length with {key_id}:{key_version}"); + Ok(COUNTER.load(Ordering::Relaxed)) + } +} + +register_plugin! { + BasicKeyMgt, + ptype: PluginType::MariaEncryption, + name: "basic_key_management", + author: "Trevor Gross", + description: "Basic key management plugin", + license: License::Gpl, + maturity: Maturity::Experimental, + version: "0.1", + init: BasicKeyMgt, // optional + encryption: false, +} diff --git a/rust/examples/keymgt-debug/Cargo.toml b/rust/examples/keymgt-debug/Cargo.toml new file mode 100644 index 0000000000000..0bf4ca5980ba2 --- /dev/null +++ b/rust/examples/keymgt-debug/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "keymgt-debug" +version = "0.1.0" +edition = '2021' + +[lib] +crate-type = ["cdylib"] + +[dependencies] +mariadb = { path = "../../mariadb" } diff --git a/rust/examples/keymgt-debug/src/lib.rs b/rust/examples/keymgt-debug/src/lib.rs new file mode 100644 index 0000000000000..2a6c12c5ee0b7 --- /dev/null +++ b/rust/examples/keymgt-debug/src/lib.rs @@ -0,0 +1,129 @@ +//! Debug key management +//! +//! Use to debug the encryption code with a fixed key that changes only on user +//! request. The only valid key ID is 1. +//! +//! EXAMPLE ONLY: DO NOT USE IN PRODUCTION! + +#![allow(unused)] + +use std::cell::UnsafeCell; +use std::ffi::c_void; +use std::sync::atomic::{AtomicI32, AtomicU32, Ordering}; +use std::sync::Mutex; + +use mariadb::log::{self, debug, trace}; +use mariadb::plugin::encryption::{Encryption, Flags, KeyError, KeyManager}; +use mariadb::plugin::{ + register_plugin, Init, InitError, License, Maturity, PluginType, SysVarConstString, SysVarOpt, + SysVarString, +}; + +const KEY_LENGTH: usize = 4; +static KEY_VERSION: AtomicU32 = AtomicU32::new(1); +static TEST_SYSVAR_CONST_STR: SysVarConstString = SysVarConstString::new(); +static TEST_SYSVAR_STR: SysVarString = SysVarString::new(); +static TEST_SYSVAR_I32: AtomicI32 = AtomicI32::new(10); + +struct DebugKeyMgmt; + +impl Init for DebugKeyMgmt { + fn init() -> Result<(), InitError> { + log::set_max_level(log::LevelFilter::Trace); + debug!("DebugKeyMgmt get_latest_key_version"); + trace!( + "current const str sysvar: {:?}", + TEST_SYSVAR_CONST_STR.get() + ); + trace!("current str sysvar: {:?}", TEST_SYSVAR_STR.get()); + trace!( + "current sysvar: {}", + TEST_SYSVAR_I32.load(Ordering::Relaxed) + ); + Ok(()) + } + + fn deinit() -> Result<(), InitError> { + eprintln!("deinit for DebugKeyMgmt"); + Ok(()) + } +} + +impl KeyManager for DebugKeyMgmt { + fn get_latest_key_version(key_id: u32) -> Result { + debug!("DebugKeyMgmt get_latest_key_version"); + if key_id != 1 { + Err(KeyError::VersionInvalid) + } else { + Ok(KEY_VERSION.load(Ordering::Relaxed)) + } + } + + fn get_key(key_id: u32, key_version: u32, dst: &mut [u8]) -> Result<(), KeyError> { + debug!("DebugKeyMgmt get_key"); + if key_id != 1 { + return Err(KeyError::VersionInvalid); + } + + // Convert our integer to a native endian byte array + let key_buf = KEY_VERSION.load(Ordering::Relaxed).to_ne_bytes(); + + if dst.len() < key_buf.len() { + return Err(KeyError::BufferTooSmall); + } + + // Copy our slice to the buffer, return the copied length + dst[..key_buf.len()].copy_from_slice(key_buf.as_slice()); + Ok(()) + } + + fn key_length(key_id: u32, key_version: u32) -> Result { + debug!("DebugKeyMgmt key_length"); + // Return the length of our u32 in bytes + // Just verify our types don't change + debug_assert_eq!( + KEY_LENGTH, + KEY_VERSION.load(Ordering::Relaxed).to_ne_bytes().len() + ); + Ok(KEY_LENGTH) + } +} + +register_plugin! { + DebugKeyMgmt, + ptype: PluginType::MariaEncryption, + name: "debug_key_management", + author: "Trevor Gross", + description: "Debug key management plugin", + license: License::Gpl, + maturity: Maturity::Experimental, + version: "0.2", + init: DebugKeyMgmt, // optional + encryption: false, + variables: [ + SysVar { + ident: TEST_SYSVAR_CONST_STR, + vtype: SysVarConstString, + name: "test_sysvar_const_string", + description: "this is a description", + options: [SysVarOpt::OptCmdArd], + default: "default value" + }, + SysVar { + ident: TEST_SYSVAR_STR, + vtype: SysVarString, + name: "test_sysvar_string", + description: "this is a description", + options: [SysVarOpt::OptCmdArd], + default: "other default value" + }, + SysVar { + ident: TEST_SYSVAR_I32, + vtype: AtomicI32, + name: "test_sysvar_i32", + description: "this is a description", + options: [SysVarOpt::OptCmdArd], + default: 67 + } + ] +} diff --git a/rust/macros/Cargo.toml b/rust/macros/Cargo.toml new file mode 100644 index 0000000000000..8288921d27429 --- /dev/null +++ b/rust/macros/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "mariadb-macros" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[[test]] +name = "tests" +path = "tests/entry.rs" + +[dependencies] +# heck = "0.4.0" +# lazy_static = "1.4.0" +proc-macro2 = "1.0.49" +quote = "1.0.23" +syn = { version = "1.0.107", features = [], default-features = false } +# syn = { version = "1.0.107", features = ["full", "extra-traits", "parsing"] } + +[dev-dependencies] +# trybuild = { version = "1.0.65", features = ["diff"] } +mariadb = { path = "../mariadb" } +trybuild = { version = "1.0.77", features = ["diff"] } diff --git a/rust/macros/README.md b/rust/macros/README.md new file mode 100644 index 0000000000000..88494c0b5e24f --- /dev/null +++ b/rust/macros/README.md @@ -0,0 +1,3 @@ +# Macros + +This directory contains procedural macros used by the Rust portion of MariaDB diff --git a/rust/macros/src/fields.rs b/rust/macros/src/fields.rs new file mode 100644 index 0000000000000..9a7a06d182683 --- /dev/null +++ b/rust/macros/src/fields.rs @@ -0,0 +1,55 @@ +pub mod plugin { + /// All fields, in expected order + pub const ALL_FIELDS: &[&str] = &[ + "ptype", + "name", + "author", + "description", + "license", + "maturity", + "version", + "init", + "encryption", + "variables", + ]; + + /// Always required + pub const REQ_FIELDS: &[&str] = &[ + "ptype", + "name", + "author", + "description", + "license", + "maturity", + "version", + ]; + + pub const ENCR_REQ_FIELDS: &[&str] = &["encryption"]; + + pub const ENCR_OPT_FIELDS: &[&str] = &["init", "sysvars"]; +} + +pub mod sysvar { + /// All fields, in expected order + pub const ALL_FIELDS: &[&str] = &[ + "ident", + "vtype", + "name", + "description", + "options", + "default", + "min", + "max", + "interval", + ]; + + /// Always required + pub const REQ_FIELDS: &[&str] = &["ident", "vtype", "name", "description"]; + pub const OPT_FIELDS: &[&str] = &["default", "min", "max", "interval"]; + + // unused since we switched to full generics + pub const _STR_REQ_FIELDS: &[&str] = &[]; + pub const _STR_OPT_FIELDS: &[&str] = &["default"]; + pub const _NUM_REQ_FIELDS: &[&str] = &[]; + pub const _NUM_OPT_FIELDS: &[&str] = &["default", "min", "max", "interval"]; +} diff --git a/rust/macros/src/helpers.rs b/rust/macros/src/helpers.rs new file mode 100644 index 0000000000000..edb0ce75c7a8e --- /dev/null +++ b/rust/macros/src/helpers.rs @@ -0,0 +1,36 @@ +use proc_macro2::Span; +use syn::{parse_quote, Error, Expr, Ident, Lit, LitStr}; + +/// Get the field as a boolean +pub fn expect_bool(field_opt: &Option) -> syn::Result { + let field = field_opt.as_ref().unwrap(); + + if field == &parse_quote! {true} { + Ok(true) + } else if field == &parse_quote! {false} { + Ok(false) + } else { + let msg = "unexpected value: only 'true' or 'false' allowed"; + Err(Error::new_spanned(field, msg)) + } +} + +/// Expect a literal string, error if that's not the case +pub fn expect_litstr(field_opt: &Option) -> syn::Result<&LitStr> { + let field = field_opt.as_ref().unwrap(); + let Expr::Lit(lit) = field else { // got non-literal + let msg = "expected literal expression for this field"; + return Err(Error::new_spanned(field, msg)); + }; + let Lit::Str(litstr) = &lit.lit else { // got literal that wasn't a string + let msg = "only literal strings are allowed for this field"; + return Err(Error::new_spanned(field, msg)); + }; + + Ok(litstr) +} + +/// Create an identifier from a string with span at the macro call site +pub fn make_ident(s: &str) -> Ident { + Ident::new(s, Span::call_site()) +} diff --git a/rust/macros/src/lib.rs b/rust/macros/src/lib.rs new file mode 100644 index 0000000000000..0eb97cf11c3f3 --- /dev/null +++ b/rust/macros/src/lib.rs @@ -0,0 +1,21 @@ +#![warn(clippy::pedantic)] +#![warn(clippy::nursery)] +#![warn(clippy::str_to_string)] +#![warn(clippy::missing_inline_in_public_items)] +#![allow(clippy::missing_panics_doc)] +#![allow(clippy::must_use_candidate)] +#![allow(clippy::option_if_let_else)] + +mod fields; +mod helpers; +mod parse_vars; +mod register_plugin; +use proc_macro::TokenStream; + +/// Macro to use to register a plugin +/// +/// See the `plugin` module in the main `mariadb` crate for examples. +#[proc_macro] +pub fn register_plugin(item: TokenStream) -> TokenStream { + register_plugin::entry(item) +} diff --git a/rust/macros/src/parse_vars.rs b/rust/macros/src/parse_vars.rs new file mode 100644 index 0000000000000..beaf120f1c166 --- /dev/null +++ b/rust/macros/src/parse_vars.rs @@ -0,0 +1,347 @@ +//! Parse sysvar syntax +//! +//! ```ignore +//! { +//! ident: SOME_IDENT, +//! vtype: String, +//! name: "sql_name", +//! description: "this is a description", +//! options: [SysVarOpt::ReadOnly, SysVarOpt::NoCmdOpt], +//! default: "something" +//! } +//! ``` +//! +//! or +//! +//! ```ignore +//! { +//! ident: OTHER_IDENT, +//! vtype: AtomicI32, +//! name: "other_sql_name", +//! description: "this is a description", +//! options: [SysVarOpt::ReqCmdArg], +//! default: 100, +//! min: 10, +//! max: 500, +//! interval: 10 +//! } +//! ``` + +#![allow(unused)] + +// use proc_macro::TokenStream; +use proc_macro2::{Literal, Span, TokenStream, TokenTree}; +use quote::{quote, ToTokens}; +use syn::parse::{Parse, ParseBuffer, ParseStream}; +use syn::punctuated::Punctuated; +use syn::spanned::Spanned; +use syn::token::Group; +use syn::{ + bracketed, parse_macro_input, parse_quote, Attribute, DeriveInput, Error, Expr, ExprArray, + ExprLit, ExprStruct, FieldValue, Ident, ImplItem, ImplItemType, Item, ItemImpl, Lit, LitStr, + Path, PathSegment, Token, Type, TypePath, TypeReference, +}; + +use crate::fields::sysvar::{ALL_FIELDS, OPT_FIELDS, REQ_FIELDS}; +use crate::helpers::expect_litstr; + +#[derive(Clone, Copy, Debug)] +enum VarTypeInner { + SysVar, + ShowVar, +} + +/// Identifiers and their bodies +#[derive(Clone, Debug, Default)] +pub struct Variables { + pub sys: Vec, + pub sys_idents: Vec, + pub show: Vec, + pub show_idents: Vec, +} + +impl Parse for Variables { + fn parse(input: ParseStream) -> syn::Result { + let content; + let _ = bracketed!(content in input); + let var_decl = Punctuated::::parse_terminated(&content)?; + let mut ret = Self::default(); + for var in var_decl { + let ty = var.var_type.unwrap(); + match ty { + VarTypeInner::SysVar => { + let tmp = var.make_usable()?; + ret.sys_idents.push(tmp.0); + ret.sys.push(tmp.1); + } + VarTypeInner::ShowVar => { + let tmp = var.make_usable()?; + ret.show_idents.push(tmp.0); + ret.show.push(tmp.1); + } + } + } + Ok(ret) + } +} + +#[derive(Clone, Debug, Default)] +struct VariableInfo { + span: Option, + var_type: Option, + ident: Option, + vtype: Option, + name: Option, + description: Option, + options: Option, + default: Option, + min: Option, + max: Option, + interval: Option, +} + +impl Parse for VariableInfo { + fn parse(input: ParseStream) -> syn::Result { + let span = input.span(); + let mut ret = Self { + span: Some(span), + ..Default::default() + }; + + // parse struct-like syntax + let st: ExprStruct = input.parse()?; + // verify the weird possible things of a struct are empty + if !st.attrs.is_empty() { + return Err(Error::new_spanned(&st.attrs[0], "no attributes expected")); + } + if st.path != parse_quote! { SysVar } { + return Err(Error::new_spanned( + &st.path, + "only path 'SysVar' is allowed", + )); + } + ret.var_type = Some(VarTypeInner::SysVar); + if st.rest.is_some() { + return Err(Error::new_spanned(&st.rest, "unexpected 'rest' section")); + } + + let fields = st.fields; + let mut field_order: Vec = Vec::new(); + for field in fields.clone() { + let syn::Member::Named(name) = &field.member else { + return Err(Error::new_spanned(field, "missing field name")); + }; + + let name_str = name.to_string(); + let expr = field.expr; + + match name_str.as_str() { + "ident" => ret.ident = Some(expr), + "vtype" => ret.vtype = Some(expr), + "name" => ret.name = Some(expr), + "description" => ret.description = Some(expr), + "options" => ret.options = Some(expr), + "default" => ret.default = Some(expr), + "min" => ret.min = Some(expr), + "max" => ret.max = Some(expr), + "interval" => ret.interval = Some(expr), + _ => { + return Err(Error::new_spanned( + name, + format!("unexpected field '{name_str}'"), + )) + } + } + field_order.push(name_str); + } + + if let Err(msg) = verify_field_order(field_order.as_slice()) { + return Err(Error::new_spanned(fields, msg)); + } + + Ok(ret) + } +} + +impl VariableInfo { + fn make_usable(&self) -> syn::Result<(Ident, TokenStream)> { + match self.var_type.unwrap() { + VarTypeInner::SysVar => self.make_sysvar(), + VarTypeInner::ShowVar => self.make_showvar(), + } + } + + fn make_sysvar(&self) -> syn::Result<(Ident, TokenStream)> { + let Some(vtype) = &self.vtype else { + return Err(Error::new_spanned(&self.vtype, "missing required field 'vtype'")); + }; + + self.validate_correct_fields(REQ_FIELDS, OPT_FIELDS); + + let ty_as_svwrap = quote! { <#vtype as ::mariadb::plugin::internals::SysVarInterface> }; + let name = expect_litstr(&self.name)?; + let ident = self.ident.as_ref().unwrap(); + let opts = self.make_option_fields()?; + let flags = quote! { #ty_as_svwrap::DEFAULT_OPTS #opts }; + let description = expect_litstr(&self.description)?; + + let default = process_default_override(&self.default, "def_val")?; + let min = process_default_override(&self.min, "min_val")?; + let max = process_default_override(&self.max, "max_val")?; + let interval = process_default_override(&self.interval, "blk_sz")?; + + let st_ident = Ident::new(&format!("_sysvar_st_{}", name.value()), Span::call_site()); + let st_tycheck = Ident::new( + &format!("_sysvar_tychk_{}", name.value()), + Span::call_site(), + ); + // https://github.com/rust-lang/rust/issues/86935#issuecomment-1146670057 + let ty_wrap = Ident::new(&format!("_sysvar_Type{}", name.value()), Span::call_site()); + // check to verify our vars are of the right type for our idents + let ty_check = quote! { static #st_tycheck: &#vtype = &#ident; }; + + let usynccell = quote! { ::mariadb::internals::UnsafeSyncCell }; + + let res = quote! { + type #ty_wrap = T; + + #ty_check + + static #st_ident: #usynccell<#ty_wrap::<#ty_as_svwrap::CStructType>> = unsafe { + #usynccell::new( + #ty_wrap::<#ty_as_svwrap::CStructType> { + flags: #flags, + name: ::mariadb::internals::cstr!(#name).as_ptr(), + comment: ::mariadb::internals::cstr!(#description).as_ptr(), + value: ::std::ptr::addr_of!(#ident).cast_mut().cast(), // *mut *mut c_char, + + #default + #min + #max + #interval + + ..#ty_as_svwrap::DEFAULT_C_STRUCT + // def_val: #default, + } + ) + }; + + }; + + Ok((st_ident, res)) + } + + /// Take the options vector, parse it as an array, bitwise or the output, + /// cast to i32 + fn make_option_fields(&self) -> syn::Result { + let Some(input) = &self.options else { + return Ok(TokenStream::new()); + }; + let flags: OptFields = syn::parse(input.to_token_stream().into())?; + let opts = flags.flags; + if opts.is_empty() { + return Ok(TokenStream::new()); + } + let ret = quote! { + | (( #( #opts .as_plugin_var_info() )|* ) & + ::mariadb::bindings::PLUGIN_VAR_MASK as i32) + }; + Ok(ret) + } + + fn make_showvar(&self) -> syn::Result<(Ident, TokenStream)> { + todo!() + } + + fn validate_correct_fields(&self, required: &[&str], optional: &[&str]) -> syn::Result<()> { + // These are all required for all plugin types + let name_map = [ + (&self.ident, "ident"), + (&self.vtype, "vtype"), + (&self.name, "name"), + (&self.description, "description"), + (&self.options, "options"), + (&self.default, "default"), + (&self.min, "min"), + (&self.max, "max"), + (&self.interval, "interval"), + ]; + let vtype = self.vtype.as_ref().unwrap(); + let mut req = REQ_FIELDS.to_vec(); + req.extend_from_slice(required); + + for req_field in &req { + let (field_val, fname) = name_map.iter().find(|f| f.1 == *req_field).unwrap(); + + if field_val.is_none() { + let msg = format!( + "field '{fname}' is expected for variables of type {vtype:?}, but not provided" + ); + return Err(Error::new(self.span.unwrap(), msg)); + } + } + + for (field, fname) in name_map { + if field.is_some() && !req.contains(&fname) && !optional.contains(&fname) { + let msg = + format!("field '{fname}' is not expected for variables of type {vtype:?}"); + return Err(Error::new_spanned(field.as_ref().unwrap(), msg)); + } + } + + Ok(()) + } +} + +#[derive(Clone, Debug)] +struct OptFields { + flags: Vec, +} + +impl Parse for OptFields { + fn parse(input: ParseStream) -> syn::Result { + let content; + let _ = bracketed!(content in input); + let flags = Punctuated::::parse_terminated(&content)?; + let v = Vec::from_iter(flags); + Ok(Self { flags: v }) + } +} + +fn verify_field_order(fields: &[String]) -> Result<(), String> { + let mut expected_order = ALL_FIELDS.to_vec(); + + expected_order.retain(|expected| fields.iter().any(|f| f == expected)); + + if expected_order == fields { + return Ok(()); + } + + Err(format!( + "fields not in expected order. reorder as:\n{expected_order:?}", + )) +} + +/// Process "default override" style fields by these rules: +/// +/// - If `field` is `None`, return an empty `TokenStream` +/// - Enforce it is a literal +/// - If it is a literal string, change it to a `cstr` +/// +/// Might want to relax and take consts at some point, but that's someday... +fn process_default_override(field: &Option, fname: &str) -> syn::Result { + let Some(f_inner) = field.as_ref() else { + return Ok(TokenStream::new()); + }; + + let Expr::Lit(exprlit) = f_inner else { + return Err(Error::new_spanned(f_inner, "only literal values are allowed in this position")); + }; + + let fid = Ident::new(fname, f_inner.span()); + if let syn::Lit::Str(litstr) = &exprlit.lit { + Ok(quote! { #fid: ::mariadb::internals::cstr!(#litstr).as_ptr().cast_mut(), }) + } else { + Ok(quote! { #fid: #exprlit, }) + } +} diff --git a/rust/macros/src/register_plugin.rs b/rust/macros/src/register_plugin.rs new file mode 100644 index 0000000000000..02ce3fbc4bdfc --- /dev/null +++ b/rust/macros/src/register_plugin.rs @@ -0,0 +1,442 @@ +//! + +#![allow(unused)] + +use proc_macro2::{Literal, Span, TokenStream, TokenTree}; +use quote::{quote, ToTokens}; +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::token::Group; +use syn::{ + parse_macro_input, parse_quote, Attribute, DeriveInput, Error, Expr, ExprLit, FieldValue, + Ident, ImplItem, ImplItemType, Item, ItemImpl, Lit, LitStr, Path, PathSegment, Token, Type, + TypePath, TypeReference, +}; + +use crate::fields::plugin::{ALL_FIELDS, ENCR_OPT_FIELDS, ENCR_REQ_FIELDS, REQ_FIELDS}; +use crate::helpers::{expect_bool, expect_litstr, make_ident}; +use crate::parse_vars::{self, Variables}; + +/// Entrypoint for this proc macro +pub fn entry(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { + let tokens_pm2: proc_macro2::TokenStream = tokens.clone().into(); + let input = parse_macro_input!(tokens as PluginInfo); + let plugindef = input.into_encryption_struct(); + match plugindef { + Ok(ts) => ts.into_output().into(), + Err(e) => e.into_compile_error().into(), + } +} + +/// A representation of the contents of a registration macro. This macro will be +/// the same for +#[derive(Clone, Debug)] +struct PluginInfo { + /// The main type that has required methods implemented on it + main_ty: Ident, + span: Span, + ptype: Option, + name: Option, + author: Option, + description: Option, + license: Option, + maturity: Option, + version: Option, + init: Option, + encryption: Option, + variables: Option, +} + +impl Parse for PluginInfo { + fn parse(input: ParseStream) -> syn::Result { + let main_ty = input.parse()?; + // FIXME: span is only the beginning + let span = input.span(); + let mut ret = Self::new(main_ty, span); + let _: Token![,] = input.parse()?; + + let fields = Punctuated::::parse_terminated(input)?; + let mut field_order: Vec = Vec::new(); + for field in fields.clone() { + let syn::Member::Named(name) = &field.member else { + return Err(Error::new_spanned(field, "missing field name")); + }; + + let name_str = name.to_string(); + let expr = field.expr; + + match name_str.as_str() { + "ptype" => ret.ptype = Some(expr), + "name" => ret.name = Some(expr), + "author" => ret.author = Some(expr), + "description" => ret.description = Some(expr), + "license" => ret.license = Some(expr), + "maturity" => ret.maturity = Some(expr), + "version" => ret.version = Some(expr), + "init" => ret.init = Some(expr), + "encryption" => ret.encryption = Some(expr), + "variables" => ret.variables = Some(expr), + _ => { + return Err(Error::new_spanned( + name, + format!("unexpected field '{name_str}'"), + )) + } + } + field_order.push(name_str); + } + + if let Err(msg) = verify_field_order(field_order.as_slice()) { + return Err(Error::new_spanned(fields, msg)); + } + Ok(ret) + } +} + +impl PluginInfo { + const fn new(main_ty: Ident, span: Span) -> Self { + Self { + main_ty, + span, + ptype: None, + name: None, + author: None, + description: None, + license: None, + maturity: None, + version: None, + init: None, + encryption: None, + variables: None, + } + } + + /// Ensure we have the fields that are required for all plugin types + fn validate_correct_fields( + &self, + required: &[&str], + optional: &[&str], + ptype: &str, + ) -> syn::Result<()> { + // These are all required for all plugin types + let name_map = [ + (&self.ptype, "ptype"), + (&self.name, "name"), + (&self.author, "author"), + (&self.description, "description"), + (&self.license, "license"), + (&self.maturity, "maturity"), + (&self.version, "version"), + (&self.init, "init"), + (&self.encryption, "encryption"), + (&self.variables, "sysvars"), + ]; + + let mut req = REQ_FIELDS.to_vec(); + req.extend_from_slice(required); + + for req_field in &req { + let (field_val, fname) = name_map.iter().find(|f| f.1 == *req_field).unwrap(); + + if field_val.is_none() { + let msg = format!("field '{fname}' is expected for plugins of type {ptype}, but not provided\n(in macro 'register_plugin')"); + return Err(Error::new(self.span, msg)); + } + } + + for (field, fname) in name_map { + if field.is_some() && !req.contains(&fname) && !optional.contains(&fname) { + let msg = format!("field '{fname}' is not expected for plugins of type {ptype}\n(in macro 'register_plugin')"); + return Err(Error::new_spanned(field.as_ref().unwrap(), msg)); + } + } + + Ok(()) + } + + /// Validate the sysvars definition and create structure. Returns + fn make_variables(&self) -> syn::Result { + let mut ret = VariableBodies { + sysvar_body: TokenStream::new(), + sysvar_field: quote! { ::std::ptr::null_mut() }, + }; + let Some(vars_decl) = &self.variables else { + return Ok(ret); + }; + let vars: Variables = syn::parse(vars_decl.to_token_stream().into())?; + let name = expect_litstr(&self.name)?.value(); + let sysvar_arr_ident = Ident::new(&format!("_plugin_{name}_sysvars"), Span::call_site()); + + let sysvar_bodies = &vars.sys; + let sysvar_idents = &vars.sys_idents; + assert_eq!(sysvar_bodies.len(), sysvar_idents.len()); + + if sysvar_bodies.is_empty() { + return Ok(ret); + } + + let len = sysvar_bodies.len() + 1; + let usynccell = quote! { ::mariadb::internals::UnsafeSyncCell }; + + ret.sysvar_body = quote! { + #( #sysvar_bodies )* + + pub static #sysvar_arr_ident: + [#usynccell<*mut ::mariadb::bindings::sysvar_common_t>; #len] + = + unsafe { [ + #( #usynccell::new(#sysvar_idents.get().cast()), )* + #usynccell::new(::std::ptr::null_mut()) + ] }; + }; + // panic!("{}", ret.sysvar_body); + ret.sysvar_field = quote! { #sysvar_arr_ident.as_ptr().cast_mut().cast() }; + Ok(ret) + } + + /// Ensure we have the fields required for an encryption plugin + fn validate_as_encryption(&self) -> syn::Result<()> { + self.validate_correct_fields(ENCR_REQ_FIELDS, ENCR_OPT_FIELDS, "encryption")?; + Ok(()) + } + + /// Turn `self` into a tokenstream of a single `st_maria_plugin` for an + /// encryption struct. Uses `idx` to mangle the name and avoid conflicts + fn into_encryption_struct(self) -> syn::Result { + self.validate_as_encryption()?; + + let type_ = &self.main_ty; + let name = expect_litstr(&self.name)?; + let plugin_st_name = Ident::new(&format!("_ST_PLUGIN_{}", name.value()), Span::call_site()); + + let ty_as_wkeymgt = quote! { <#type_ as ::mariadb::plugin::internals::WrapKeyMgr> }; + let ty_as_wenc = quote! { <#type_ as ::mariadb::plugin::internals::WrapEncryption> }; + let interface_version = quote! { ::mariadb::bindings::MariaDB_ENCRYPTION_INTERFACE_VERSION as ::std::ffi::c_int }; + let get_key_vers = quote! { Some(#ty_as_wkeymgt::wrap_get_latest_key_version) }; + let get_key = quote! { Some(#ty_as_wkeymgt::wrap_get_key) }; + let variables = self.make_variables()?; + let variables_body = variables.sysvar_body; + + let (crypt_size, crypt_init, crypt_update, crypt_finish, crypt_len); + + if expect_bool(&self.encryption)? { + // Use encryption if given + crypt_size = quote! { Some(#ty_as_wenc::wrap_crypt_ctx_size) }; + crypt_init = quote! { Some(#ty_as_wenc::wrap_crypt_ctx_init) }; + crypt_update = quote! { Some(#ty_as_wenc::wrap_crypt_ctx_update) }; + crypt_finish = quote! { Some(#ty_as_wenc::wrap_crypt_ctx_finish) }; + crypt_len = quote! { Some(#ty_as_wenc::wrap_encrypted_length) }; + } else { + // Default to builtin encryption + let none = quote! { None }; + ( + crypt_size, + crypt_init, + crypt_update, + crypt_finish, + crypt_len, + ) = (none.clone(), none.clone(), none.clone(), none.clone(), none); + } + + let info_struct = quote! { + static #plugin_st_name: ::mariadb::internals::UnsafeSyncCell< + ::mariadb::bindings::st_mariadb_encryption, + > = unsafe { + ::mariadb::internals::UnsafeSyncCell::new( + ::mariadb::bindings::st_mariadb_encryption { + interface_version: #interface_version, + get_latest_key_version: #get_key_vers, + get_key: #get_key, + crypt_ctx_size: #crypt_size, + crypt_ctx_init: #crypt_init, + crypt_ctx_update: #crypt_update, + crypt_ctx_finish: #crypt_finish, + encrypted_length: #crypt_len, + } + ) + }; + }; + + let version_str = &expect_litstr(&self.version)?.value(); + let version_int = + version_int(version_str).map_err(|e| Error::new_spanned(&self.version, e))?; + let author = expect_litstr(&self.author)?; + let description = expect_litstr(&self.description)?; + let license = self.license.unwrap(); + let maturity = self.maturity.unwrap(); + let ptype = self.ptype.unwrap(); + let system_vars_ptr = variables.sysvar_field; + + let init_fn_name = make_ident(&format!("_{}_init_fn", name.value())); + + // We always initialize the logger, maybe do init/deinit if struct requires + let fn_init = quote! { Some(#init_fn_name) }; + let (fn_deinit, init_fn_body); + if let Some(init_ty) = self.init { + let ty_as_init = quote! { <#init_ty as ::mariadb::plugin::internals::WrapInit> }; + init_fn_body = quote! { ::mariadb::configure_logger!(); #ty_as_init::wrap_init(_p) }; + fn_deinit = quote! { Some(#ty_as_init::wrap_deinit) }; + } else { + init_fn_body = quote! { ::mariadb::configure_logger!(); 0 }; + fn_deinit = quote! { None }; + } + + let init_fn = quote! { + unsafe extern "C" fn #init_fn_name(_p: *mut std::ffi::c_void) -> std::ffi::c_int { + #init_fn_body + } + }; + + let plugin_struct = quote! { + ::mariadb::bindings::st_maria_plugin { + type_: #ptype.to_ptype_registration(), + info: #plugin_st_name.as_ptr().cast_mut().cast(), + name: ::mariadb::internals::cstr!(#name).as_ptr(), + author: ::mariadb::internals::cstr!(#author).as_ptr(), + descr: ::mariadb::internals::cstr!(#description).as_ptr(), + license: #license.to_license_registration(), + init: #fn_init, + deinit: #fn_deinit, + version: #version_int as ::std::ffi::c_uint, + status_vars: ::std::ptr::null_mut(), + system_vars: #system_vars_ptr, + version_info: ::mariadb::internals::cstr!(#version_str).as_ptr(), + maturity: #maturity.to_maturity_registration(), + }, + }; + + Ok(PluginDef { + name: name.value(), + init_fn, + info_struct, + plugin_struct, + variable_body: variables_body, + }) + } +} + +struct VariableBodies { + /// This body will be added + sysvar_body: TokenStream, + /// What to put in the registering `st_mariadb_plugin + sysvar_field: TokenStream, +} + +/// Contains a struct definition of type `st_mariadb_encryption` or whatever is +/// applicable, plus the struct of `st_maria_plugin` that references it +struct PluginDef { + name: String, + init_fn: TokenStream, + info_struct: TokenStream, + plugin_struct: TokenStream, + variable_body: TokenStream, +} + +impl PluginDef { + fn into_output(self) -> TokenStream { + // static and dynamic identifiers + let vers_ident_stc = make_ident(&format!("builtin_{}_plugin_interface_version", self.name)); + let vers_ident_dyn = make_ident("_maria_plugin_interface_version_"); + let size_ident_stc = make_ident(&format!("builtin_{}_sizeof_struct_st_plugin", self.name)); + let size_ident_dyn = make_ident("_maria_sizeof_struct_st_plugin_"); + let decl_ident_stc = make_ident(&format!("builtin_{}_plugin", self.name)); + let decl_ident_dyn = make_ident("_maria_plugin_declarations_"); + + let plugin_ty = quote! { ::mariadb::bindings::st_maria_plugin }; + let version_val = + quote! { mariadb::bindings::MARIA_PLUGIN_INTERFACE_VERSION as ::std::ffi::c_int }; + let size_val = quote! { ::std::mem::size_of::<#plugin_ty>() as ::std::ffi::c_int }; + + let usynccell = quote! { ::mariadb::internals::UnsafeSyncCell }; + let null_ps = quote! { ::mariadb::plugin::internals::new_null_st_maria_plugin() }; + + let info_st = self.info_struct; + let plugin_st = self.plugin_struct; + let init_fn = self.init_fn; + let variable_body = self.variable_body; + + quote! { ::std::ptr::null_mut() }; + + let ret: TokenStream = quote! { + #info_st + #init_fn + #variable_body + + // Different config based on statically or dynamically lynked + #[no_mangle] + #[cfg(make_static_lib)] + static #vers_ident_stc: ::std::ffi::c_int = #version_val; + + #[no_mangle] + #[cfg(not(make_static_lib))] + static #vers_ident_dyn: ::std::ffi::c_int = #version_val; + + #[no_mangle] + #[cfg(make_static_lib)] + static #size_ident_stc: ::std::ffi::c_int = #size_val; + + #[no_mangle] + #[cfg(not(make_static_lib))] + static #size_ident_dyn: ::std::ffi::c_int = #size_val; + + #[no_mangle] + #[cfg(make_static_lib)] + static #decl_ident_stc: [#usynccell<#plugin_ty>; 2] = unsafe { [ + #usynccell::new(#plugin_st), + #usynccell::new(#null_ps), + ] }; + + #[no_mangle] + #[cfg(not(make_static_lib))] + static #decl_ident_dyn: [#usynccell<#plugin_ty>; 2] = unsafe { [ + #usynccell::new(#plugin_st), + #usynccell::new(#null_ps), + ] }; + }; + ret + } +} + +/// Verify attribute order +fn verify_field_order(fields: &[String]) -> Result<(), String> { + let mut expected_order = ALL_FIELDS.to_vec(); + + expected_order.retain(|expected| fields.iter().any(|f| f == expected)); + + if expected_order == fields { + return Ok(()); + } + + Err(format!( + "fields not in expected order. reorder as:\n{expected_order:?}", + )) +} + +/// Convert a string like "1.2" to a hex like "0x0102". Error if no decimal, or +/// if either value exceeds a u8. +fn version_int(s: &str) -> Result { + const USAGE_MSG: &str = r#"expected a two position semvar string, e.g. "1.2""#; + if s.chars().filter(|x| *x == '.').count() != 1 { + return Err(USAGE_MSG.to_owned()); + } + + let splt = s.split_once('.').unwrap(); + let fmt_err = |e| format!("{e}\n{USAGE_MSG}"); + + let major: u16 = splt.0.parse::().map_err(fmt_err)?.into(); + let minor: u16 = splt.1.parse::().map_err(fmt_err)?.into(); + let res: u16 = (major << 8) + minor; + + Ok(res) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_version_int() { + assert_eq!(version_int("5.2"), Ok(0x0502)); + assert_eq!(version_int("11.0"), Ok(0x0b00)); + } +} diff --git a/rust/macros/tests/entry.rs b/rust/macros/tests/entry.rs new file mode 100644 index 0000000000000..c17efeb546783 --- /dev/null +++ b/rust/macros/tests/entry.rs @@ -0,0 +1,6 @@ +#[test] +fn test_build() { + let t = trybuild::TestCases::new(); + t.pass("tests/pass/*.rs"); + t.compile_fail("tests/fail/*.rs"); +} diff --git a/rust/macros/tests/fail/01-extra-args.rs b/rust/macros/tests/fail/01-extra-args.rs new file mode 100644 index 0000000000000..2233336ac4c39 --- /dev/null +++ b/rust/macros/tests/fail/01-extra-args.rs @@ -0,0 +1,17 @@ +include!("../include.rs"); + +register_plugin! { + TestPlugin, + ptype: PluginType::MariaEncryption, + name: "debug_key_management", + author: "Trevor Gross", + description: "Debug key management plugin", + license: License::Gpl, + maturity: Maturity::Experimental, + version: "0.1", + init: TestPlugin, + encryption: false, + extra: false +} + +fn main() {} diff --git a/rust/macros/tests/fail/01-extra-args.stderr b/rust/macros/tests/fail/01-extra-args.stderr new file mode 100644 index 0000000000000..a8b2d53a367e1 --- /dev/null +++ b/rust/macros/tests/fail/01-extra-args.stderr @@ -0,0 +1,5 @@ +error: unexpected field 'extra' + --> tests/fail/01-extra-args.rs:14:5 + | +14 | extra: false + | ^^^^^ diff --git a/rust/macros/tests/fail/02-extra-sysargs.rs b/rust/macros/tests/fail/02-extra-sysargs.rs new file mode 100644 index 0000000000000..342b35b833af1 --- /dev/null +++ b/rust/macros/tests/fail/02-extra-sysargs.rs @@ -0,0 +1,27 @@ +include!("../include.rs"); + +register_plugin! { + TestPlugin, + ptype: PluginType::MariaEncryption, + name: "debug_key_management", + author: "Trevor Gross", + description: "Debug key management plugin", + license: License::Gpl, + maturity: Maturity::Experimental, + version: "0.1", + init: TestPlugin, + encryption: false, + variables: [ + SysVar { + ident: _SYSVAR_CONST_STR, + vtype: SysVarConstString, + name: "test_sysvar", + description: "this is a description", + options: [SysVarOpt::ReadOnly, SysVarOpt::NoCmdOpt], + default: "default value", + interval: "50" + } + ] +} + +fn main() {} diff --git a/rust/macros/tests/fail/02-extra-sysargs.stderr b/rust/macros/tests/fail/02-extra-sysargs.stderr new file mode 100644 index 0000000000000..05aa910a006f2 --- /dev/null +++ b/rust/macros/tests/fail/02-extra-sysargs.stderr @@ -0,0 +1,7 @@ +error[E0560]: struct `sysvar_str_t` has no field named `blk_sz` + --> tests/fail/02-extra-sysargs.rs:22:23 + | +22 | interval: "50" + | ^^^^ `sysvar_str_t` does not have this field + | + = note: available fields are: `flags`, `name`, `comment`, `check`, `update` ... and 2 others diff --git a/rust/macros/tests/fail/03-wrong-types.rs b/rust/macros/tests/fail/03-wrong-types.rs new file mode 100644 index 0000000000000..eb25475a9b959 --- /dev/null +++ b/rust/macros/tests/fail/03-wrong-types.rs @@ -0,0 +1,30 @@ +/* + * Verify our added check for identifier-type mismatch + */ + +include!("../include.rs"); + +register_plugin! { + TestPlugin, + ptype: PluginType::MariaEncryption, + name: "debug_key_management", + author: "Trevor Gross", + description: "Debug key management plugin", + license: License::Gpl, + maturity: Maturity::Experimental, + version: "0.1", + init: TestPlugin, + encryption: false, + variables: [ + SysVar { + ident: _SYSVAR_ATOMIC, + vtype: SysVarConstString, + name: "test_sysvar", + description: "this is a description", + options: [SysVarOpt::ReadOnly, SysVarOpt::NoCmdOpt], + default: "default value" + } + ] +} + +fn main() {} diff --git a/rust/macros/tests/fail/03-wrong-types.stderr b/rust/macros/tests/fail/03-wrong-types.stderr new file mode 100644 index 0000000000000..31dee11bc8163 --- /dev/null +++ b/rust/macros/tests/fail/03-wrong-types.stderr @@ -0,0 +1,15 @@ +error[E0308]: mismatched types + --> tests/fail/03-wrong-types.rs:7:1 + | +7 | / register_plugin! { +8 | | TestPlugin, +9 | | ptype: PluginType::MariaEncryption, +10 | | name: "debug_key_management", +... | +27 | | ] +28 | | } + | |_^ expected `&SysVarConstString`, found `&AtomicI32` + | + = note: expected reference `&'static mariadb::plugin::SysVarConstString` + found reference `&AtomicI32` + = note: this error originates in the macro `register_plugin` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/rust/macros/tests/include.rs b/rust/macros/tests/include.rs new file mode 100644 index 0000000000000..e618cd5828bba --- /dev/null +++ b/rust/macros/tests/include.rs @@ -0,0 +1,34 @@ +/* Simple setup for dummy proc macro testing. Simple compile pass/fail only. */ + +use std::sync::atomic::AtomicI32; + +use mariadb::plugin::encryption::*; +use mariadb::plugin::*; +pub use mariadb_macros::register_plugin; + +static _SYSVAR_ATOMIC: AtomicI32 = AtomicI32::new(0); +static _SYSVAR_CONST_STR: SysVarConstString = SysVarConstString::new(); + +struct TestPlugin; + +impl KeyManager for TestPlugin { + fn get_latest_key_version(_key_id: u32) -> Result { + todo!() + } + fn get_key(_key_id: u32, _key_version: u32, _dst: &mut [u8]) -> Result<(), KeyError> { + todo!() + } + fn key_length(_key_id: u32, _key_version: u32) -> Result { + todo!() + } +} + +impl Init for TestPlugin { + fn init() -> Result<(), InitError> { + todo!() + } + + fn deinit() -> Result<(), InitError> { + todo!() + } +} diff --git a/rust/macros/tests/pass/01-simple.rs b/rust/macros/tests/pass/01-simple.rs new file mode 100644 index 0000000000000..339089927eecf --- /dev/null +++ b/rust/macros/tests/pass/01-simple.rs @@ -0,0 +1,51 @@ +include!("../include.rs"); + +register_plugin! { + TestPlugin, + ptype: PluginType::MariaEncryption, + name: "test_plugin_name", + author: "Test Author", + description: "Debug key management plugin", + license: License::Gpl, + maturity: Maturity::Experimental, + version: "1.2", + encryption: false, +} + +fn main() { + use std::ffi::CStr; + + use mariadb::bindings::st_maria_plugin; + + // verify correct symbols are created + let _: i32 = _maria_plugin_interface_version_; + let _: i32 = _maria_sizeof_struct_st_plugin_; + let plugin_def: &st_maria_plugin = unsafe { &*(_maria_plugin_declarations_[0]).get() }; + + // verify struct has correct fields + let type_ = plugin_def.type_; + let name = unsafe { CStr::from_ptr(plugin_def.name).to_str().unwrap() }; + let author = unsafe { CStr::from_ptr(plugin_def.author).to_str().unwrap() }; + let descr = unsafe { CStr::from_ptr(plugin_def.descr).to_str().unwrap() }; + let license = plugin_def.license; + let init = plugin_def.init; + let deinit = plugin_def.deinit; + let version = plugin_def.version; + let status_vars = plugin_def.status_vars; + let system_vars = plugin_def.system_vars; + let version_info = unsafe { CStr::from_ptr(plugin_def.version_info).to_str().unwrap() }; + let maturity = plugin_def.maturity; + + assert_eq!(type_, PluginType::MariaEncryption as i32); + assert_eq!(name, "test_plugin_name"); + assert_eq!(author, "Test Author"); + assert_eq!(descr, "Debug key management plugin"); + assert_eq!(license, License::Gpl as i32); + assert!(init.is_some()); // we always have an init function to setup logging + assert!(deinit.is_none()); + assert_eq!(version, 0x0102); + assert!(status_vars.is_null()); + assert!(system_vars.is_null()); + assert_eq!(version_info, "1.2"); + assert_eq!(maturity, Maturity::Experimental as u32); +} diff --git a/rust/macros/tests/pass/02-with-init.rs b/rust/macros/tests/pass/02-with-init.rs new file mode 100644 index 0000000000000..dc87882ddfc7c --- /dev/null +++ b/rust/macros/tests/pass/02-with-init.rs @@ -0,0 +1,23 @@ +include!("../include.rs"); + +register_plugin! { + TestPlugin, + ptype: PluginType::MariaEncryption, + name: "debug_key_management", + author: "Trevor Gross", + description: "Debug key management plugin", + license: License::Gpl, + maturity: Maturity::Experimental, + version: "0.1", + init: TestPlugin, + encryption: false, +} + +fn main() { + use mariadb::bindings::st_maria_plugin; + + let plugin_def: &st_maria_plugin = unsafe { &*(_maria_plugin_declarations_[0]).get() }; + + assert!(plugin_def.init.is_some()); + assert!(plugin_def.deinit.is_some()); +} diff --git a/rust/macros/tests/pass/03-with-sysargs.rs b/rust/macros/tests/pass/03-with-sysargs.rs new file mode 100644 index 0000000000000..01f3cc8053a4e --- /dev/null +++ b/rust/macros/tests/pass/03-with-sysargs.rs @@ -0,0 +1,63 @@ +include!("../include.rs"); + +register_plugin! { + TestPlugin, + ptype: PluginType::MariaEncryption, + name: "debug_key_management", + author: "Trevor Gross", + description: "Debug key management plugin", + license: License::Gpl, + maturity: Maturity::Experimental, + version: "0.1", + init: TestPlugin, + encryption: false, + variables: [ + SysVar { + ident: _SYSVAR_CONST_STR, + vtype: SysVarConstString, + name: "test_sysvar", + description: "this is a description", + options: [SysVarOpt::ReadOnly, SysVarOpt::NoCmdOpt], + default: "default value" + } + ] +} + +fn main() { + use std::ffi::CStr; + use std::ptr; + + use mariadb::bindings; + use mariadb::bindings::{st_maria_plugin, st_mysql_sys_var, sysvar_common_t, sysvar_str_t}; + use mariadb::internals::UnsafeSyncCell; + + // verify correct symbols are created + let _: i32 = _maria_plugin_interface_version_; + let _: i32 = _maria_sizeof_struct_st_plugin_; + let plugin_def: &st_maria_plugin = unsafe { &*(_maria_plugin_declarations_[0]).get() }; + + let sysv_ptr: *mut *mut st_mysql_sys_var = plugin_def.system_vars; + let sysvar_st: *const sysvar_str_t = _sysvar_st_test_sysvar.get(); + let sysvar_arr: &[UnsafeSyncCell<*mut sysvar_common_t>] = &_plugin_debug_key_management_sysvars; + let idx_0: *mut sysvar_common_t = unsafe { *sysvar_arr[0].get() }; + let idx_1: *mut sysvar_common_t = unsafe { *sysvar_arr[1].get() }; + assert_eq!(idx_0, sysvar_st.cast_mut().cast()); + assert_eq!(idx_1, ptr::null_mut()); + assert_eq!(sysv_ptr, sysvar_arr.as_ptr().cast_mut().cast()); + + // try the C way, slow casting steps to avoid errors here + let sv1_ptr: *mut st_mysql_sys_var = unsafe { *plugin_def.system_vars.add(0) }; + let sv1: &sysvar_str_t = unsafe { &*sv1_ptr.cast() }; + let flags = sv1.flags; + let sv1_name = unsafe { CStr::from_ptr(sv1.name).to_str().unwrap() }; + let sv1_comment = unsafe { CStr::from_ptr(sv1.comment).to_str().unwrap() }; + let sv1_default = unsafe { CStr::from_ptr(sv1.def_val).to_str().unwrap() }; + + let expected_flags = bindings::PLUGIN_VAR_STR + | ((bindings::PLUGIN_VAR_READONLY | bindings::PLUGIN_VAR_NOCMDOPT) + & bindings::PLUGIN_VAR_MASK); + assert_eq!(flags, expected_flags as i32); + assert_eq!(sv1_name, "test_sysvar"); + assert_eq!(sv1_comment, "this is a description"); + assert_eq!(sv1_default, "default value"); +} diff --git a/rust/mariadb/Cargo.toml b/rust/mariadb/Cargo.toml new file mode 100644 index 0000000000000..4598efd452a0a --- /dev/null +++ b/rust/mariadb/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "mariadb" +version = "0.1.0" +edition = '2021' + + +[dependencies] +mariadb-sys = { path = "../bindings" } +mariadb-macros = { path = "../macros" } +cstr = "0.2.11" +concat-idents = "1.1.4" +time = { version = "0.3.17", features = ["formatting"]} +log = "0.4.17" +env_logger = "0.10.0" diff --git a/rust/mariadb/README.md b/rust/mariadb/README.md new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/rust/mariadb/src/common.rs b/rust/mariadb/src/common.rs new file mode 100644 index 0000000000000..90ceeb249ede5 --- /dev/null +++ b/rust/mariadb/src/common.rs @@ -0,0 +1,71 @@ +use std::ffi::c_void; +use std::slice; + +use crate::bindings; + +/// A SQL type and value +#[non_exhaustive] +pub enum Value<'a> { + Decimal(&'a [u8]), + Tiny(i8), + Short(i16), + Long(i64), + Float(f32), + Double(f64), + Null, + Time, // todo + TimeStamp, // todo + Date, // todo + DateTime, // todo + Year, // todo + Varchar(&'a [u8]), + Json(&'a [u8]), +} + +impl<'a> Value<'a> { + /// Supply a + pub(crate) unsafe fn from_ptr( + ty: bindings::enum_field_types, + ptr: *const c_void, + len: usize, + ) -> Self { + // helper function to make a slice + let buf_callback = || slice::from_raw_parts(ptr.cast(), len); + + match ty { + bindings::enum_field_types::MYSQL_TYPE_DECIMAL => Self::Decimal(buf_callback()), + bindings::enum_field_types::MYSQL_TYPE_TINY => Self::Tiny(*ptr.cast()), + bindings::enum_field_types::MYSQL_TYPE_SHORT => Self::Short(*ptr.cast()), + bindings::enum_field_types::MYSQL_TYPE_LONG => Self::Long(*ptr.cast()), + bindings::enum_field_types::MYSQL_TYPE_FLOAT => Self::Float(*ptr.cast()), + bindings::enum_field_types::MYSQL_TYPE_DOUBLE => Self::Double(*ptr.cast()), + bindings::enum_field_types::MYSQL_TYPE_NULL => Self::Null, + bindings::enum_field_types::MYSQL_TYPE_TIMESTAMP => todo!(), + bindings::enum_field_types::MYSQL_TYPE_LONGLONG => todo!(), + bindings::enum_field_types::MYSQL_TYPE_INT24 => todo!(), + bindings::enum_field_types::MYSQL_TYPE_DATE => todo!(), + bindings::enum_field_types::MYSQL_TYPE_TIME => todo!(), + bindings::enum_field_types::MYSQL_TYPE_DATETIME => todo!(), + bindings::enum_field_types::MYSQL_TYPE_YEAR => todo!(), + bindings::enum_field_types::MYSQL_TYPE_NEWDATE => todo!(), + bindings::enum_field_types::MYSQL_TYPE_VARCHAR => todo!(), + bindings::enum_field_types::MYSQL_TYPE_BIT => todo!(), + bindings::enum_field_types::MYSQL_TYPE_TIMESTAMP2 => todo!(), + bindings::enum_field_types::MYSQL_TYPE_DATETIME2 => todo!(), + bindings::enum_field_types::MYSQL_TYPE_TIME2 => todo!(), + bindings::enum_field_types::MYSQL_TYPE_BLOB_COMPRESSED => todo!(), + bindings::enum_field_types::MYSQL_TYPE_VARCHAR_COMPRESSED => todo!(), + bindings::enum_field_types::MYSQL_TYPE_NEWDECIMAL => todo!(), + bindings::enum_field_types::MYSQL_TYPE_ENUM => todo!(), + bindings::enum_field_types::MYSQL_TYPE_SET => todo!(), + bindings::enum_field_types::MYSQL_TYPE_TINY_BLOB => todo!(), + bindings::enum_field_types::MYSQL_TYPE_MEDIUM_BLOB => todo!(), + bindings::enum_field_types::MYSQL_TYPE_LONG_BLOB => todo!(), + bindings::enum_field_types::MYSQL_TYPE_BLOB => todo!(), + bindings::enum_field_types::MYSQL_TYPE_VAR_STRING => todo!(), + bindings::enum_field_types::MYSQL_TYPE_STRING => todo!(), + bindings::enum_field_types::MYSQL_TYPE_GEOMETRY => todo!(), + _ => todo!(), + } + } +} diff --git a/rust/mariadb/src/helpers.rs b/rust/mariadb/src/helpers.rs new file mode 100644 index 0000000000000..89a61b56dbd1f --- /dev/null +++ b/rust/mariadb/src/helpers.rs @@ -0,0 +1,48 @@ +use std::cell::UnsafeCell; +use std::ptr; + +use super::bindings; + +/// Used for plugin registrations, which are in global scope. +#[doc(hidden)] +#[derive(Debug)] +#[repr(transparent)] +pub struct UnsafeSyncCell(UnsafeCell); + +impl UnsafeSyncCell { + /// # Safety + /// + /// This inner value be used in a Sync/Send way + pub const unsafe fn new(value: T) -> Self { + Self(UnsafeCell::new(value)) + } + + pub const fn as_ptr(&self) -> *const T { + self.0.get() + } + + pub const fn get(&self) -> *mut T { + self.0.get() + } + + pub fn get_mut(&mut self) -> &mut T { + self.0.get_mut() + } +} + +#[allow(clippy::non_send_fields_in_send_ty)] +unsafe impl Send for UnsafeSyncCell {} +unsafe impl Sync for UnsafeSyncCell {} + +pub fn str2bool(s: &str) -> Option { + const TRUE_VALS: [&str; 3] = ["on", "true", "1"]; + const FALSE_VALS: [&str; 3] = ["off", "false", "0"]; + let lower = s.to_lowercase(); + if TRUE_VALS.contains(&lower.as_str()) { + Some(true) + } else if FALSE_VALS.contains(&lower.as_str()) { + Some(false) + } else { + None + } +} diff --git a/rust/mariadb/src/lib.rs b/rust/mariadb/src/lib.rs new file mode 100644 index 0000000000000..240482d7b4aaf --- /dev/null +++ b/rust/mariadb/src/lib.rs @@ -0,0 +1,109 @@ +//! Crate representing safe abstractions over MariaDB bindings +#![warn(clippy::pedantic)] +#![warn(clippy::nursery)] +#![warn(clippy::str_to_string)] +#![allow(clippy::option_if_let_else)] +#![allow(clippy::missing_errors_doc)] +#![allow(clippy::must_use_candidate)] +#![allow(clippy::useless_conversion)] +#![allow(clippy::cast_possible_wrap)] +#![allow(clippy::missing_safety_doc)] +#![allow(clippy::missing_const_for_fn)] +#![allow(clippy::module_name_repetitions)] +#![allow(clippy::missing_inline_in_public_items)] +#![allow(unused)] + +use time::{format_description, OffsetDateTime}; + +mod common; +mod helpers; +pub mod plugin; +pub mod service_sql; +use std::fmt::Write; + +#[doc(inline)] +pub use common::*; +pub use log; +#[doc(hidden)] +pub use mariadb_sys as bindings; + +#[doc(hidden)] +pub mod internals { + pub use cstr::cstr; + + pub use super::helpers::UnsafeSyncCell; +} + +/// Our main logger config +/// +/// Writes a timestamp, log level, and message. For debug & trace, also log the +/// file name. +#[doc(hidden)] +pub struct MariaLogger; + +impl MariaLogger { + pub const fn new() -> Self { + Self + } +} + +impl log::Log for MariaLogger { + fn enabled(&self, metadata: &log::Metadata) -> bool { + // metadata.level() <= log::Level::Info + true + } + + fn log(&self, record: &log::Record) { + if !self.enabled(record.metadata()) { + return; + } + + let t = time::OffsetDateTime::now_utc(); + let fmt = time::format_description::parse( + "[year]-[month]-[day] [hour]:[minute]:[second][offset_hour sign:mandatory]:[offset_minute]", + ) + .unwrap(); + + // Format our string + let mut out_str = t.format(&fmt).unwrap(); + write!(out_str, " [{}", record.level()).unwrap(); + + if record.level() == log::Level::Debug || record.level() == log::Level::Trace { + write!(out_str, " {}", record.file().unwrap_or("")).unwrap(); + if let Some(line) = record.line() { + write!(out_str, ":{line}").unwrap(); + } + } + + eprintln!("{out_str}]: {}", record.args()); + } + + fn flush(&self) {} +} + +/// Configure the default logger. This is currently called by default for +/// plugins in the `init` function. +#[macro_export] +macro_rules! configure_logger { + () => { + $crate::configure_logger!($crate::log::LevelFilter::Warn) + }; + ($level:expr) => {{ + static LOGGER: $crate::MariaLogger = $crate::MariaLogger::new(); + $crate::log::set_logger(&LOGGER) + .map(|()| $crate::log::set_max_level($level)) + .expect("failed to configure logger"); + }}; +} + +/// Provide the name of the calling function (full path) +macro_rules! function_name { + () => {{ + fn f() {} + fn type_name_of(_: T) -> &'static str { + std::any::type_name::() + } + let name = type_name_of(f); + &name[..name.len() - 3] + }}; +} diff --git a/rust/mariadb/src/plugin.rs b/rust/mariadb/src/plugin.rs new file mode 100644 index 0000000000000..cf151447eae16 --- /dev/null +++ b/rust/mariadb/src/plugin.rs @@ -0,0 +1,183 @@ +//! Module for everything relevant to plugins +//! +//! Usage: +//! +//! ``` +//! use mariadb::plugin::*; +//! use mariadb::plugin::encryption::*; +//! use mariadb::plugin::SysVarConstString; +//! +//! static SYSVAR_STR: SysVarConstString = SysVarConstString::new(); +//! +//! +//! // May be empty or not +//! struct ExampleKeyManager; +//! +//! impl KeyManager for ExampleKeyManager { +//! // ... +//! # fn get_latest_key_version(key_id: u32) -> Result { todo!() } +//! # fn get_key(key_id: u32, key_version: u32, dst: &mut [u8]) -> Result<(), KeyError> { todo!() } +//! # fn key_length(_key_id: u32, _key_version: u32) -> Result { todo!() } +//! } +//! +//! impl Init for ExampleKeyManager { +//! // ... +//! # fn init() -> Result<(), InitError> { todo!() } +//! # fn deinit() -> Result<(), InitError> { todo!() } +//! } +//! +//! register_plugin! { +//! ExampleKeyManager, // Name of the struct implementing KeyManager +//! ptype: PluginType::MariaEncryption, // plugin type; only encryption supported for now +//! name: "name_as_sql_server_sees_it", // loadable plugin name +//! author: "Author Name", // author's name +//! description: "Sample key managment plugin", // give a description +//! license: License::Gpl, // select a license type +//! maturity: Maturity::Experimental, // indicate license maturity +//! version: "0.1", // provide an "a.b" version +//! init: ExampleKeyManager, // optional: struct implementing Init if needed +//! encryption: false, // false to use default encryption, true if your +//! // struct implements 'Encryption' +//! variables: [ // variables should be a list of typed identifiers +//! SysVar { +//! ident: SYSVAR_STR, +//! vtype: SysVarConstString, +//! name: "sql_name", +//! description: "this is a description", +//! options: [SysVarOpt::ReadOnly, SysVarOpt::NoCmdOpt], +//! default: "something" +//! }, +//! // SysVar { +//! // ident: OTHER_IDENT, +//! // vtype: AtomicI32, +//! // name: "other_sql_name", +//! // description: "this is a description", +//! // options: [SysVarOpt::ReqCmdArg], +//! // default: 100, +//! // min: 10, +//! // max: 500, +//! // interval: 10 +//! // } +//! ] +//! } +//! ``` + +use std::ffi::{c_int, c_uint}; +use std::str::FromStr; + +use mariadb_sys as bindings; +pub mod encryption; +mod encryption_wrapper; +mod variables; +mod variables_parse; +mod wrapper; +pub use mariadb_macros::register_plugin; +pub use variables::{SysVarConstString, SysVarOpt, SysVarString}; + +/// Commonly used plugin types for reexport +pub mod prelude { + pub use super::{register_plugin, Init, InitError, License, Maturity, PluginType}; +} + +/// Reexports for use in proc macros +#[doc(hidden)] +pub mod internals { + pub use super::encryption_wrapper::{WrapEncryption, WrapKeyMgr}; + pub use super::variables::SysVarInterface; + pub use super::wrapper::{new_null_st_maria_plugin, WrapInit}; +} + +/// Defines possible licenses for plugins +#[non_exhaustive] +#[derive(Clone, Copy, Debug)] +#[allow(clippy::cast_possible_wrap)] +pub enum License { + Proprietary = bindings::PLUGIN_LICENSE_PROPRIETARY as isize, + Gpl = bindings::PLUGIN_LICENSE_GPL as isize, + Bsd = bindings::PLUGIN_LICENSE_BSD as isize, +} + +impl License { + #[must_use] + #[doc(hidden)] + pub const fn to_license_registration(self) -> c_int { + self as c_int + } +} + +impl FromStr for License { + type Err = String; + + /// Create a license type from a string + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "proprietary" => Ok(Self::Proprietary), + "gpl" => Ok(Self::Gpl), + "bsd" => Ok(Self::Bsd), + _ => Err(format!("'{s}' has no matching license type")), + } + } +} + +/// Defines a type of plugin. This determines the required implementation. +#[non_exhaustive] +#[allow(clippy::cast_possible_wrap)] +pub enum PluginType { + MyUdf = bindings::MYSQL_UDF_PLUGIN as isize, + MyStorageEngine = bindings::MYSQL_STORAGE_ENGINE_PLUGIN as isize, + MyFtParser = bindings::MYSQL_FTPARSER_PLUGIN as isize, + MyDaemon = bindings::MYSQL_DAEMON_PLUGIN as isize, + MyInformationSchema = bindings::MYSQL_INFORMATION_SCHEMA_PLUGIN as isize, + MyAudit = bindings::MYSQL_AUDIT_PLUGIN as isize, + MyReplication = bindings::MYSQL_REPLICATION_PLUGIN as isize, + MyAuthentication = bindings::MYSQL_AUTHENTICATION_PLUGIN as isize, + MariaPasswordValidation = bindings::MariaDB_PASSWORD_VALIDATION_PLUGIN as isize, + /// Use this plugin type both for key managers and for full encryption plugins + MariaEncryption = bindings::MariaDB_ENCRYPTION_PLUGIN as isize, + MariaDataType = bindings::MariaDB_DATA_TYPE_PLUGIN as isize, + MariaFunction = bindings::MariaDB_FUNCTION_PLUGIN as isize, +} + +impl PluginType { + #[must_use] + #[doc(hidden)] + pub const fn to_ptype_registration(self) -> c_int { + self as c_int + } +} + +/// Defines possible licenses for plugins +#[non_exhaustive] +#[allow(clippy::cast_possible_wrap)] +pub enum Maturity { + Unknown = bindings::MariaDB_PLUGIN_MATURITY_UNKNOWN as isize, + Experimental = bindings::MariaDB_PLUGIN_MATURITY_EXPERIMENTAL as isize, + Alpha = bindings::MariaDB_PLUGIN_MATURITY_ALPHA as isize, + Beta = bindings::MariaDB_PLUGIN_MATURITY_BETA as isize, + Gamma = bindings::MariaDB_PLUGIN_MATURITY_GAMMA as isize, + Stable = bindings::MariaDB_PLUGIN_MATURITY_STABLE as isize, +} + +impl Maturity { + #[must_use] + #[doc(hidden)] + pub const fn to_maturity_registration(self) -> c_uint { + self as c_uint + } +} + +/// Indicate that an error occured during initialization or deinitialization +pub struct InitError; + +/// Implement this trait if your plugin requires init or deinit functions +pub trait Init { + /// Initialize the plugin + fn init() -> Result<(), InitError> { + Ok(()) + } + + /// Deinitialize the plugin + fn deinit() -> Result<(), InitError> { + Ok(()) + } +} diff --git a/rust/mariadb/src/plugin/authentication.rs b/rust/mariadb/src/plugin/authentication.rs new file mode 100644 index 0000000000000..f9e317c0af9ae --- /dev/null +++ b/rust/mariadb/src/plugin/authentication.rs @@ -0,0 +1,76 @@ +// st_mysql_server_auth_info +// st_mysql_auth + +//! +//! +//! +//! +//! +//! # Implementation +//! +//! `st_mysql_auth` requires: +//! +//! - `interface_version`: int, set by macro +//! - `client_auth_plugin`: `char*`, indicates client's required plugin for +//! authentication. Set by macro +//! - `authenticate_user`: function, wraps `authenticate_user` +//! +//! +//! +//! +//! +//! + +use std::slice; + +use crate::plugins::vio::Vio; + +#[repr(transparent)] +struct AuthInfo(bindings::MYSQL_SERVER_AUTH_INFO); + +enum PasswordUsage { + NotUsed = bindings::PASSWORD_USED_NO, + NotUsedMention = bindings::PASSWORD_USED_NO_MENTION, + Used = bindings::PASSWORD_USED_YES, +} + +impl AuthInfo { + fn get_user_name(&self) -> &[u8] { + // SAFETY: caller guarantees validity of self + unsafe { slice::from_raw_parts(self.0.user_name, self.0.user_name_length as usize) } + } + + fn get_auth_string(&self) -> &[u8] { + // SAFETY: caller guarantees validity of self + unsafe { slice::from_raw_parts(self.0.auth_string, self.0.auth_string_length as usize) } + } + + fn get_authenticated_as(&self) -> &[u8] { + + } + + fn set_authenticated_as(&self, s: AsRef<[u8]>) -> Result<(), TruncatedError> { + + } + + fn set_password_usage(u: PasswordUsage) { + + } + + fn set_host(&self, s: AsRef<[u8]>) -> Result<(), TruncatedError> { + + } +} + +struct AuthError; + +trait Authentication { + fn authenticate_user(vio: &Vio, info: &AuthInfo) -> Result; + + /// Hash the provided password and write the output to `hash`. Return the + /// number of written bytes if successful, or `Err` if not. + fn hash_password(password: &[u8], hash: &mut [u8]) -> Result; + + /// Prepare the password hash for authentication + fn preprocess_hash(hash: &[u8], out: &mut [u8]) -> Result; +} diff --git a/rust/mariadb/src/plugin/common.rs b/rust/mariadb/src/plugin/common.rs new file mode 100644 index 0000000000000..b308d6dadf30f --- /dev/null +++ b/rust/mariadb/src/plugin/common.rs @@ -0,0 +1,4 @@ +/// Trait for plugins that want to use the init/deinit functions +trait Init { + fn init() -> Self; +} diff --git a/rust/mariadb/src/plugin/encryption.rs b/rust/mariadb/src/plugin/encryption.rs new file mode 100644 index 0000000000000..22f0ca10515c9 --- /dev/null +++ b/rust/mariadb/src/plugin/encryption.rs @@ -0,0 +1,121 @@ +//! Requirements to implement an encryption plugin +//! +//! # Usage +//! +//! - Keep key storage context in globals. These need to be mutex-protected +//! +//! # Implementation +//! +//! `plugin_encryption.h` defines `st_mariadb_encryption`, with the following members: +//! +//! - `interface_version`: integer, set via macro +//! - `get_latest_key_version`: function, wrapped in `Encryption::get_latest_key_version` +//! - `get_key`: function, wrapped in `Encryption::get_key` +//! - `crypt_ctx_size`: function, wrapped in `Encryption::size` +//! - `crypt_ctx_init`: function, wrapped in `Encryption::init` +//! - `crypt_ctx_update`: function, wrapped in `Encryption::update` +//! - `crypt_ctx_finish`: function, wrapped in `Encryption::finish` +//! - `encrypted_length`: function, macro provides call to `std::mem::size_of` + +// use core::cell::UnsafeCell; +use mariadb_sys as bindings; + +/// A type of error to be used by key functions +#[repr(u32)] +#[non_exhaustive] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum KeyError { + // Values must be nonzero + /// A key ID is invalid or not found. Maps to `ENCRYPTION_KEY_VERSION_INVALID` in C. + VersionInvalid = bindings::ENCRYPTION_KEY_VERSION_INVALID, + /// A key buffer is too small. Maps to `ENCRYPTION_KEY_BUFFER_TOO_SMALL` in C. + BufferTooSmall = bindings::ENCRYPTION_KEY_BUFFER_TOO_SMALL, + Other = 3, +} + +#[repr(i32)] +#[non_exhaustive] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EncryptionError { + BadData = bindings::MY_AES_BAD_DATA, + BadKeySize = bindings::MY_AES_BAD_KEYSIZE, + Other = bindings::MY_AES_OPENSSL_ERROR, +} + +/// Representation of the flags integer +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Flags(i32); + +impl Flags { + pub(crate) const fn new(value: i32) -> Self { + Self(value) + } + + pub(crate) const fn should_encrypt(self) -> bool { + (self.0 & bindings::ENCRYPTION_FLAG_ENCRYPT as i32) != 0 + } + + pub(crate) const fn should_decrypt(self) -> bool { + // (self.0 & bindings::ENCRYPTION_FLAG_DECRYPT as i32) != 0 + !self.should_encrypt() + } + + pub const fn nopad(&self) -> bool { + (self.0 & bindings::ENCRYPTION_FLAG_NOPAD as i32) != 0 + } +} + +/// Implement this trait on a struct that will serve as encryption context +/// +/// +/// The type of context data that will be passed to various encryption +/// function calls. +#[allow(unused_variables)] +pub trait KeyManager: Send + Sized { + // type Context: Send; + + /// Get the latest version of a key ID. Return `VersionInvalid` if not found. + fn get_latest_key_version(key_id: u32) -> Result; + + /// Return a key for a key version and ID. + /// + /// Given a key ID and a version, write the key to the `key` buffer. If the + /// buffer is too small, return [`KeyError::BufferTooSmall`]. + fn get_key(key_id: u32, key_version: u32, dst: &mut [u8]) -> Result<(), KeyError>; + + /// Calculate the length of a key. Usually this is constant, but the key ID + /// and version can be taken into account if needed. + /// + /// On the C side, this function is combined with `get_key`. + fn key_length(key_id: u32, key_version: u32) -> Result; +} + +// TODO: Maybe split into `Encrypt` and `Decrypt` traits +pub trait Encryption: Sized { + /// Initialize the encryption context object + fn init( + key_id: u32, + key_version: u32, + key: &[u8], + iv: &[u8], + flags: Flags, + ) -> Result; + + /// Update the encryption context with new data, return the number of bytes + /// written + fn update(&mut self, src: &[u8], dst: &mut [u8]) -> Result<(), EncryptionError>; + + /// Write the remaining bytes to the buffer + fn finish(&mut self, dst: &mut [u8]) -> Result<(), EncryptionError>; + + /// Return the exact length of the encrypted data based on the source length + /// + /// As this function must have a definitive answer, this API only supports + /// encryption algorithms where this is possible to compute (i.e., + /// compression is not supported). + fn encrypted_length(key_id: u32, key_version: u32, src_len: usize) -> usize; +} + +// get_latest_key_version: #type::get_latest_key_version +// crypt_ctx_size: std::mem::size_of<#type::Context>() +// crypt_ctx_init: #type:init() diff --git a/rust/mariadb/src/plugin/encryption_wrapper.rs b/rust/mariadb/src/plugin/encryption_wrapper.rs new file mode 100644 index 0000000000000..77a586f631324 --- /dev/null +++ b/rust/mariadb/src/plugin/encryption_wrapper.rs @@ -0,0 +1,186 @@ +//! Wrappers needed for the `st_mariadb_encryption` type + +use std::ffi::{c_int, c_uchar, c_uint, c_void}; +use std::{mem, slice}; + +use log::{error, warn}; +use mariadb_sys as bindings; + +use super::encryption::{Encryption, Flags, KeyError, KeyManager}; + +/// +pub trait WrapKeyMgr: KeyManager { + /// Get the key version, simple wrapper + extern "C" fn wrap_get_latest_key_version(key_id: c_uint) -> c_uint { + match Self::get_latest_key_version(key_id) { + Ok(v) => { + if v == bindings::ENCRYPTION_KEY_NOT_ENCRYPTED { + error!("get_latest_key_version returned value {v}, which is reserved for unencrypted keys."); + error!("the server will likely shut down now."); + } + v + } + Err(_) => KeyError::VersionInvalid as c_uint, + } + } + + /// If key == NULL, return the required buffer size for the key + /// + /// # Safety + /// + /// `dstbuf` must be valid for `buflen` + unsafe extern "C" fn wrap_get_key( + key_id: c_uint, + version: c_uint, + dstbuf: *mut c_uchar, + buflen: *mut c_uint, + ) -> c_uint { + dbg!(key_id, version, dstbuf, *buflen); + if dstbuf.is_null() { + match dbg!(Self::key_length(key_id, version)) { + // FIXME: don't unwrap + Ok(v) => *buflen = v.try_into().unwrap(), + Err(e) => { + return e as c_uint; + } + } + return bindings::ENCRYPTION_KEY_BUFFER_TOO_SMALL; + } + + // SAFETY: caller guarantees validity + let buf = slice::from_raw_parts_mut(dstbuf, *buflen as usize); + + // If successful, return 0. If an error occurs, return it + match dbg!(Self::get_key(key_id, version, buf)) { + Ok(_) => 0, + Err(e) => { + dbg!(e); + + // match e { + // // Set the desired buffer size if available + // KeyError::BufferTooSmall => { + // *buflen = dbg!(Self::key_length(key_id, version) + // .unwrap_or(0) + // .try_into() + // .unwrap()) + // } + // _ => (), + // } + dbg!(e as c_uint) + } + } + } +} + +impl WrapKeyMgr for T where T: KeyManager {} +impl WrapEncryption for T where T: Encryption {} + +pub trait WrapEncryption: Encryption { + extern "C" fn wrap_crypt_ctx_size(_key_id: c_uint, _key_version: c_uint) -> c_uint { + // I believe that key_id and key_version are provided in case this plugin + // uses different structs for different keys. However, it seems safer & more + // user friendly to sidestep that and just make everything the same size + mem::size_of::().try_into().unwrap() + } + + /// # Safety + /// + /// The caller must guarantee that the following is tre + /// + /// - `ctx` points to memory with the size of T (may be uninitialized) + /// - `key` exists for `klen` + /// - `iv` exists for `ivlen` + unsafe extern "C" fn wrap_crypt_ctx_init( + ctx: *mut c_void, + key: *const c_uchar, + klen: c_uint, + iv: *const c_uchar, + ivlen: c_uint, + flags: c_int, + key_id: c_uint, + key_version: c_uint, + ) -> c_int { + /// SAFETY: caller guarantees buffer validity + let keybuf = slice::from_raw_parts(key, klen as usize); + let ivbuf = slice::from_raw_parts(iv, ivlen as usize); + let flags = Flags::new(flags); + match Self::init(key_id, key_version, keybuf, ivbuf, flags) { + Ok(newctx) => { + ctx.cast::().write(newctx); + bindings::MY_AES_OK.try_into().unwrap() + } + Err(e) => e as c_int, + } + } + + /// # Safety + /// + /// The caller must guarantee that the following is true: + /// + /// - `ctx` points to a valid, initialized object of type T + /// - `src` exists for `slen` + /// - ~~`dst` exists for `dlen`~~ + /// + /// FIXME: the `*dlen` we receive from the server is unitialized. For now we + /// assume the destination buffer is equal to source buffer length, but this + /// is a bit of a workaround until MDEV-30389 is resolved. + unsafe extern "C" fn wrap_crypt_ctx_update( + ctx: *mut c_void, + src: *const c_uchar, + slen: c_uint, + dst: *mut c_uchar, + dlen: *mut c_uint, + ) -> c_int { + // dbg!(slen, dlen, *dlen); + let sbuf = slice::from_raw_parts(src, slen as usize); + let dbuf = slice::from_raw_parts_mut(dst, slen as usize); + + let c: &mut Self = &mut *ctx.cast(); + let (ret, written) = match c.update(sbuf, dbuf) { + // FIXME dlen + Ok(_) => (bindings::MY_AES_OK.try_into().unwrap(), 0), + Err(e) => (e as c_int, 0), + }; + *dlen = written; + ret + } + + unsafe extern "C" fn wrap_crypt_ctx_finish( + ctx: *mut c_void, + dst: *mut c_uchar, + dlen: *mut c_uint, + ) -> c_int { + dbg!(*dlen); + let dbuf = slice::from_raw_parts_mut(dst, dlen as usize); + + let c: &mut Self = &mut *ctx.cast(); + let (ret, written) = match c.finish(dbuf) { + // FIXME dlen + Ok(_) => (bindings::MY_AES_OK.try_into().unwrap(), 0), + Err(e) => (e as c_int, 0), + }; + + ctx.drop_in_place(); + ret + } + + unsafe extern "C" fn wrap_encrypted_length( + slen: c_uint, + key_id: c_uint, + key_version: c_uint, + ) -> c_uint { + Self::encrypted_length(key_id, key_version, slen.try_into().unwrap()) + .try_into() + .unwrap() + } +} + +unsafe fn set_buflen_with_check(buflen: *mut c_uint, val: u32) { + if val > 32 { + eprintln!( + "The default encryption does not seem to allow keys above 32 bits. If the server \ + crashes after this message, that is the likely error" + ); + } + *buflen = val.try_into().unwrap(); +} diff --git a/rust/mariadb/src/plugin/ftparser.rs b/rust/mariadb/src/plugin/ftparser.rs new file mode 100644 index 0000000000000..8f473441711eb --- /dev/null +++ b/rust/mariadb/src/plugin/ftparser.rs @@ -0,0 +1,4 @@ +trait FtParser { + fn init() -> Self; + fn parse(&self); +} diff --git a/rust/mariadb/src/plugin/variables.rs b/rust/mariadb/src/plugin/variables.rs new file mode 100644 index 0000000000000..c7e20037f9c49 --- /dev/null +++ b/rust/mariadb/src/plugin/variables.rs @@ -0,0 +1,343 @@ +//! "show variables" and "system variables" + +use std::cell::UnsafeCell; +use std::ffi::{c_double, c_int, c_long, c_longlong, c_ulong, c_ulonglong, c_void, CStr, CString}; +use std::marker::PhantomPinned; +use std::mem::ManuallyDrop; +use std::os::raw::{c_char, c_uint}; +use std::ptr; +use std::sync::atomic::{self, AtomicBool, AtomicI32, AtomicPtr, AtomicU32, Ordering}; +use std::sync::Mutex; + +use bindings::THD; +use cstr::cstr; +use log::{trace, warn}; +use mariadb_sys as bindings; + +use super::variables_parse::{CliMysqlValue, CliValue}; + +/// Possible flags for plugin variables +#[repr(i32)] +#[non_exhaustive] +#[derive(Clone, Copy, PartialEq, Eq)] +#[allow(clippy::cast_possible_wrap)] +pub enum SysVarOpt { + // ThdLocal = bindings::PLUGIN_VAR_THDLOCAL as i32, + /// Variable is read only + ReadOnly = bindings::PLUGIN_VAR_READONLY as i32, + /// Variable is not a server variable + NoSysVar = bindings::PLUGIN_VAR_NOSYSVAR as i32, + /// No command line option + NoCmdOpt = bindings::PLUGIN_VAR_NOCMDOPT as i32, + /// No argument for the command line + NoCmdArg = bindings::PLUGIN_VAR_NOCMDARG as i32, + /// Required CLI argument + ReqCmdArg = bindings::PLUGIN_VAR_RQCMDARG as i32, + /// Optional CLI argument + OptCmdArd = bindings::PLUGIN_VAR_OPCMDARG as i32, + /// Variable is deprecated + Deprecated = bindings::PLUGIN_VAR_DEPRECATED as i32, + // String needs memory allocation - don't expose this + // MemAlloc= bindings::PLUGIN_VAR_MEMALLOC, +} + +type SVInfoInner = ManuallyDrop>; + +/// Basicallly, a system variable will be one of these types, which are dynamic +/// structures on C. Kind of yucky to work with but I think the generic union is +/// a lot more clear. +#[repr(C)] +pub union SysVarInfoU { + bool_t: SVInfoInner, + str_t: SVInfoInner, + enum_t: SVInfoInner, + set_t: SVInfoInner, + int_t: SVInfoInner, + long_t: SVInfoInner, + longlong_t: SVInfoInner, + uint_t: SVInfoInner, + ulong_t: SVInfoInner, + ulonglong_t: SVInfoInner, + double_t: SVInfoInner, + // THD types have a function `resolve` that takes a thread pointer and an + // offset (also a field) +} + +impl SysVarOpt { + pub const fn as_plugin_var_info(self) -> i32 { + self as i32 + } +} + +/// `bindings::mysql_var_update_func` without the `Option` +type SvUpdateFn = + unsafe extern "C" fn(*mut THD, *mut bindings::st_mysql_sys_var, *mut c_void, *const c_void); + +/// A wrapper for system variables. This won't be exposed externally. +/// +/// This provides the interface of update functions. Trait is unsafe because +/// using the wrong structures would cause UB. +pub unsafe trait SysVarInterface: Sized { + /// The C struct representation, e.g. `sysvar_str_t` + type CStructType; + + /// Intermediate type, pointed to by the `CStructType's` `value` pointer + type Intermediate: Copy; + + /// Options to implement by default + const DEFAULT_OPTS: i32; + + /// C struct filled with default values. + const DEFAULT_C_STRUCT: Self::CStructType; + + /// Wrapper for the task of storing the result of the `check` function. + /// Simply converts to our safe rust function "update". + /// + /// - `thd`: thread pointer + /// - `var`: pointer to the c struct + /// - `var_ptr`: where to stash the value + /// - `save`: stash from the `check` function + unsafe extern "C" fn update_wrap( + thd: *mut THD, + var: *mut bindings::st_mysql_sys_var, + target: *mut c_void, + save: *const c_void, + ) { + let new_save: *const Self::Intermediate = save.cast(); + assert!( + !new_save.is_null(), + "got a null pointer from the C interface" + ); + Self::update(&*target.cast(), &*var.cast(), *new_save); + } + + /// Update function: override this if it is pointed to by `UPDATE_FUNC` + unsafe fn update(&self, var: &Self::CStructType, save: Self::Intermediate) { + unimplemented!() + } +} + +/// A const string system variable +/// +/// Consider this very unstable because I don't 100% understand what the SQL +/// side of things does with the malloc / const options +/// +/// Bug: it seems like after updating, the SQL server cannot read the +/// variable... but we can? Do we need to do more in our `update` function? +#[repr(transparent)] +pub struct SysVarConstString(AtomicPtr); + +impl SysVarConstString { + pub const fn new() -> Self { + Self(AtomicPtr::new(std::ptr::null_mut())) + } + + /// Get the current value of this variable. This isn't very efficient since + /// it copies the string, but fixes will come later + /// + /// # Panics + /// + /// Panics if it gets a non-UTF8 C string + pub fn get(&self) -> String { + let ptr = self.0.load(Ordering::SeqCst); + let cs = unsafe { CStr::from_ptr(ptr) }; + cs.to_str() + .unwrap_or_else(|_| panic!("got non-UTF8 string like {}", cs.to_string_lossy())) + .to_owned() + } +} + +unsafe impl SysVarInterface for SysVarConstString { + type CStructType = bindings::sysvar_str_t; + type Intermediate = *mut c_char; + const DEFAULT_OPTS: i32 = bindings::PLUGIN_VAR_STR as i32; + const DEFAULT_C_STRUCT: Self::CStructType = Self::CStructType { + flags: 0, + name: ptr::null(), + comment: ptr::null(), + check: None, + update: None, + value: ptr::null_mut(), + def_val: cstr!("").as_ptr().cast_mut(), + }; +} + +/// An editable c string +/// +/// Note on race conditions: +/// +/// There is a race if the C side reads data while being updated on the Rust +/// side. No worse than what would exist if the plugin was written in C, but +/// important to note it does exist. +#[repr(C)] +pub struct SysVarString { + /// This points to our c string + ptr: AtomicPtr, + mutex: Mutex>, +} + +impl SysVarString { + pub const fn new() -> Self { + Self { + ptr: AtomicPtr::new(std::ptr::null_mut()), + mutex: Mutex::new(None), + } + } + + /// Get the current value of this variable + /// + /// # Panics + /// + /// Panics if the mutex can't be locked + pub fn get(&self) -> Option { + let guard = &*self.mutex.lock().expect("failed to lock mutex"); + let ptr = self.ptr.load(Ordering::SeqCst); + + if !ptr.is_null() && guard.is_some() { + let cs = guard.as_ref().unwrap(); + assert!( + ptr.cast_const() == cs.as_ptr(), + "pointer and c string unsynchronized" + ); + Some(cstr_to_string(cs)) + } else if ptr.is_null() && guard.is_none() { + None + } else { + trace!("pointer {ptr:p} mismatch with guard {guard:?} (likely init condition)"); + // prefer the pointer, must have been set on the C side + let cs = unsafe { CStr::from_ptr(ptr) }; + Some(cstr_to_string(cs)) + } + } +} + +unsafe impl SysVarInterface for SysVarString { + type CStructType = bindings::sysvar_str_t; + type Intermediate = *mut c_char; + const DEFAULT_OPTS: i32 = bindings::PLUGIN_VAR_STR as i32; + const DEFAULT_C_STRUCT: Self::CStructType = Self::CStructType { + flags: 0, + name: ptr::null(), + comment: ptr::null(), + check: None, + update: Some(Self::update_wrap), + value: ptr::null_mut(), + def_val: cstr!("").as_ptr().cast_mut(), + }; + + unsafe fn update(&self, var: &Self::CStructType, save: Self::Intermediate) { + let to_save = save + .as_ref() + .map(|ptr| unsafe { CStr::from_ptr(ptr).to_owned() }); + let guard = &mut *self.mutex.lock().expect("failed to lock mutex"); + *guard = to_save; + let new_ptr = guard + .as_deref() + .map_or(ptr::null_mut(), |cs| cs.as_ptr().cast_mut()); + self.ptr.store(new_ptr, Ordering::SeqCst); + trace!("updated sysvar with inner: {guard:?}"); + } +} + +fn cstr_to_string(cs: &CStr) -> String { + cs.to_str() + .unwrap_or_else(|_| panic!("got non-UTF8 string like {}", cs.to_string_lossy())) + .to_owned() +} + +/// Macro to easily create implementations for all the atomics +macro_rules! atomic_svinterface { + // Special case for boolean, which doesn't have as many fields + ( $atomic_type:ty, + $c_struct_type:ty, + bool, + $default_options:expr $(,)? + ) => { + atomic_svinterface!{ + $atomic_type, + $c_struct_type, + bool, + $default_options, + { def_val: false } + } + }; + + // All other integer types have the same fields + ( $atomic_type:ty, + $c_struct_type:ty, + $inter_type:ty, + $default_options:expr $(,)? + ) => { + atomic_svinterface!{ + $atomic_type, + $c_struct_type, + $inter_type, + $default_options, + { def_val: 0, min_val: <$inter_type>::MIN, max_val: <$inter_type>::MAX, blk_sz: 0 } + } + }; + + // Full generic implementation + ( $atomic_type:ty, // e.g., AtomicI32 + $c_struct_type:ty, // e.g. sysvar_int_t + $inter_type:ty, // e.g. i32 + $default_options:expr, // e.g. PLUGIN_VAR_INT + { $( $extra_struct_fields:tt )* } $(,)? // e.g. default, min, max fields + ) => { + unsafe impl SysVarInterface for $atomic_type { + type CStructType = $c_struct_type; + type Intermediate = $inter_type; + const DEFAULT_OPTS: i32 = ($default_options) as i32; + const DEFAULT_C_STRUCT: Self::CStructType = Self::CStructType { + flags: 0, + name: ptr::null(), + comment: ptr::null(), + check: None, + update: Some(Self::update_wrap), + value: ptr::null_mut(), + $( $extra_struct_fields )* + }; + + unsafe fn update(&self, var: &Self::CStructType, save: Self::Intermediate) { + trace!( + "updated {} system variable to '{:?}'", + std::any::type_name::<$atomic_type>(), save + ); + // based on sql_plugin.cc, seems like there are no null integers + // (can't represent that anyway) + self.store(save, Ordering::SeqCst); + } + } + }; +} + +atomic_svinterface!( + atomic::AtomicBool, + bindings::sysvar_bool_t, + bool, + bindings::PLUGIN_VAR_BOOL, +); +atomic_svinterface!( + atomic::AtomicI32, + bindings::sysvar_int_t, + c_int, + bindings::PLUGIN_VAR_INT, +); +atomic_svinterface!( + atomic::AtomicU32, + bindings::sysvar_uint_t, + c_uint, + bindings::PLUGIN_VAR_INT | bindings::PLUGIN_VAR_UNSIGNED +); +atomic_svinterface!( + atomic::AtomicI64, + bindings::sysvar_longlong_t, + c_longlong, + bindings::PLUGIN_VAR_LONGLONG +); +atomic_svinterface!( + atomic::AtomicU64, + bindings::sysvar_ulonglong_t, + c_ulonglong, + bindings::PLUGIN_VAR_LONGLONG | bindings::PLUGIN_VAR_UNSIGNED +); diff --git a/rust/mariadb/src/plugin/variables_parse.rs b/rust/mariadb/src/plugin/variables_parse.rs new file mode 100644 index 0000000000000..7c8149088e728 --- /dev/null +++ b/rust/mariadb/src/plugin/variables_parse.rs @@ -0,0 +1,152 @@ +//! Variables parser +//! +//! Reimplementation of `check_func_x` in `sql_plugin.cc`. It's just easier to +//! reimpelment these because it means we can use thread safe types. +//! +//! All these functions share similar signatures, see `plugin.h` +//! +//! # Check functions +//! +//! Parse input +//! +//! - `thd`: thread handle +//! - `var`: system variable union. SAFETY: must be correct type (caller varifies) +//! - `save`: pointer to temporary storage +//! - `value`: user-provided value +//! +//! # Update function +//! +//! +//! +//! - `thd`: thread handle +//! - `var`: system variable union. SAFETY: must be correct type (caller varifies) +//! - `save`: pointer to temporary storage +//! - `value`: user-provided value + +use std::cell::UnsafeCell; +use std::ffi::{c_int, c_void, CStr}; +use std::os::raw::c_char; +use std::sync::atomic::{AtomicBool, Ordering}; + +use super::variables::SysVarInfoU; +use crate::bindings; +use crate::helpers::str2bool; + +/// # Safety +/// +/// Variable has to be of the right type, bool +pub unsafe fn check_func_atomic_bool( + thd: *const c_void, + var: *mut c_void, + // var: *mut SysVarInfoU, + save: *mut c_void, + value: *const bindings::st_mysql_value, +) -> c_int { + todo!() + // let sql_val = MySqlValue::from_ptr(value); + // let dest: *const AtomicBool = save.cast(); + // let new_val = match sql_val.value() { + // Value::Int(v) => { + // let tmp = v.unwrap_or(0); + // match tmp { + // 0 => false, + // 1 => true, + // _ => return 1, + // } + // } + // Value::String(s) => { + // let inner = s.expect("got null string"); + // str2bool(&inner) + // .unwrap_or_else(|| panic!("value '{inner}' is not a valid bool indicator")) + // } + // Value::Real(_) => panic!("unexpected real value"), + // }; + // (*dest).store(new_val, Ordering::Relaxed); + // 0 +} + +// pub unsafe fn update_func_atomic_bool( +// thd: *const c_void, +// // var: *mut SysVarInfoU, +// var: *mut c_void, +// var_ptr: *mut c_void, +// save: *const c_void, +// ) { +// let dest: *const AtomicBool = var_ptr.cast(); +// let new_val: u8 = *save.cast(); +// let new_val_bool = match new_val { +// 1 => true, +// 0 => false, +// n => panic!("invalid boolean value {n}"), +// }; +// (*dest).store(new_val_bool, Ordering::Relaxed); +// } + +#[derive(Debug, PartialEq)] +pub enum CliValue { + Int(Option), + Real(Option), + String(Option), +} + +pub struct CliMysqlValue(UnsafeCell); + +impl CliMysqlValue { + /// `item_val_str function`, `item_val_int`, `item_val_real` + pub(crate) fn value(&self) -> CliValue { + unsafe { + match (*self.0.get()).value_type.unwrap()(self.0.get()) + .try_into() + .unwrap() + { + bindings::MYSQL_VALUE_TYPE_INT => self.as_int(), + bindings::MYSQL_VALUE_TYPE_REAL => self.as_real(), + bindings::MYSQL_VALUE_TYPE_STRING => self.as_string(), + x => panic!("unrecognized value type {x}"), + } + } + } + + const unsafe fn from_ptr<'a>(ptr: *const bindings::st_mysql_value) -> &'a Self { + &*ptr.cast() + } + + unsafe fn as_int(&self) -> CliValue { + let mut res = 0i64; + let nul = (*self.0.get()).val_int.unwrap()(self.0.get(), &mut res); + if nul == 0 { + CliValue::Int(Some(res)) + } else { + CliValue::Int(None) + } + } + + unsafe fn as_real(&self) -> CliValue { + let mut res = 0f64; + let nul = (*self.0.get()).val_real.unwrap()(self.0.get(), &mut res); + if nul == 0 { + CliValue::Real(Some(res)) + } else { + CliValue::Real(None) + } + } + + unsafe fn as_string(&self) -> CliValue { + let mut buf = vec![0u8; 512]; + let mut len: c_int = buf.len().try_into().unwrap(); + // This function copies ot the buffer if it fits, returns a temp + // string otherwisw + let ptr = (*self.0.get()).val_str.unwrap()(self.0.get(), buf.as_mut_ptr().cast(), &mut len); + if ptr.is_null() { + return CliValue::String(None); + } + if ptr.cast() == buf.as_ptr() { + buf.truncate(len.try_into().unwrap()); + let res = String::from_utf8(buf).expect("got a buffer that isn't utf8"); + CliValue::String(Some(res)) + } else { + // figure out where the buffer lives otherwise + panic!("buffer too long: needs length {len}"); + } + } +} diff --git a/rust/mariadb/src/plugin/vio.rs b/rust/mariadb/src/plugin/vio.rs new file mode 100644 index 0000000000000..b2e430c10110d --- /dev/null +++ b/rust/mariadb/src/plugin/vio.rs @@ -0,0 +1,36 @@ +//! Representation of the `MYSQL_PLUGIN_VIO` struct, which has methods for +//! reading and writing packets + +#[repr(transparent)] +pub struct Vio (MYSQL_PLUGIN_VIO); + +#[repr(transparent)] +pub struct VioInfo(MYSQL_PLUGIN_VIO_INFO); + +pub struct WriteFailure; + +impl Vio { + fn read_packet(&self) { + let read_fn = self.0.read_packet.expect("read function is null!"); + } + + fn write_packet(&self, packet: &[u8]) -> Result<(), WriteFailure> { + let write_fn = self.0.write_packet.expect("write function is null!"); + let res = unsafe { write_fn(&self, packet.as_ptr(), packet.len()) }; + if res == 0 { + Ok(()) + } else { + Err(WriteFailure) + } + } + + fn info(&self) -> VioInfo { + let info_fn = self.0.write_packet.expect("write function is null!"); + let vi: MaybeUninit::zeroed(); + unsafe { + info_fn(&self, &vi); + vi.assume_init(); + } + vi + } +} diff --git a/rust/mariadb/src/plugin/wrapper.rs b/rust/mariadb/src/plugin/wrapper.rs new file mode 100644 index 0000000000000..f494b806cdfaa --- /dev/null +++ b/rust/mariadb/src/plugin/wrapper.rs @@ -0,0 +1,47 @@ +use std::ffi::{c_int, c_uint, c_void}; +use std::ptr; + +use super::{Init, License, Maturity, PluginType}; +use crate::bindings; + +/// Trait for easily wrapping init/deinit functions +pub trait WrapInit: Init { + #[must_use] + unsafe extern "C" fn wrap_init(_: *mut c_void) -> c_int { + match Self::init() { + Ok(_) => 0, + Err(_) => 1, + } + } + + #[must_use] + unsafe extern "C" fn wrap_deinit(_: *mut c_void) -> c_int { + match Self::deinit() { + Ok(_) => 0, + Err(_) => 1, + } + } +} + +impl WrapInit for T where T: Init {} + +/// New struct with all null values +#[must_use] +#[doc(hidden)] +pub const fn new_null_st_maria_plugin() -> bindings::st_maria_plugin { + bindings::st_maria_plugin { + type_: 0, + info: ptr::null_mut(), + name: ptr::null(), + author: ptr::null(), + descr: ptr::null(), + license: 0, + init: None, + deinit: None, + version: 0, + status_vars: ptr::null_mut(), + system_vars: ptr::null_mut(), + version_info: ptr::null(), + maturity: 0, + } +} diff --git a/rust/mariadb/src/service_sql.rs b/rust/mariadb/src/service_sql.rs new file mode 100644 index 0000000000000..742380b7b25d0 --- /dev/null +++ b/rust/mariadb/src/service_sql.rs @@ -0,0 +1,80 @@ +//! Safe API for `include/mysql/service_sql.h` + +//! +//! FIXME: I think we need to use a different `GLOBAL_SQL_SERVICE` if statically +//! linked, but not yet sure where this is + +use std::cell::UnsafeCell; +use std::ffi::CString; +use std::marker::PhantomData; +use std::ptr::{self, NonNull}; + +mod error; +mod raw; +use bindings::sql_service as GLOBAL_SQL_SERVICE; +use raw::RawConnection; + +pub use self::error::ClientError; +use self::raw::{ClientResult, FetchedRow, RState, RawResult}; +pub use self::raw::{Fetch, Store}; +use crate::bindings; +use crate::helpers::UnsafeSyncCell; + +/// A connection to a local or remote SQL server +pub struct MySqlConn(RawConnection); + +impl MySqlConn { + /// Connect to the local server + /// + /// # Errors + /// + /// Error if the client could not connect + #[inline] + pub fn connect_local() -> ClientResult { + let mut conn = RawConnection::new(); + conn.connect_local()?; + Ok(Self(conn)) + } + + /// Run a query and discard the results + /// + /// # Errors + /// + /// Error if the query could not be completed + #[inline] + pub fn execute(&mut self, q: &str) -> ClientResult<()> { + self.0.query(q)?; + Ok(()) + } + + /// Run a query for lazily loaded results + /// + /// # Errors + /// + /// Error if the row could not be fetched + #[inline] + pub fn query<'a>(&'a mut self, q: &str) -> ClientResult> { + self.0.query(q)?; + let res = self.0.prep_fetch_result()?; + // let cols = + Ok(FetchedRows(res)) + } +} + +/// Representation of returned rows from a query +pub struct FetchedRows<'a>(RawResult<'a, Fetch>); + +impl<'a> FetchedRows<'a> { + #[inline] + pub fn next_row(&mut self) -> Option { + self.0.fetch_next_row() + } +} + +impl Drop for FetchedRows<'_> { + /// Consume the rest of the rows + #[inline] + fn drop(&mut self) { + while self.next_row().is_some() {} + } +} diff --git a/rust/mariadb/src/service_sql/error.rs b/rust/mariadb/src/service_sql/error.rs new file mode 100644 index 0000000000000..4ceacded958d0 --- /dev/null +++ b/rust/mariadb/src/service_sql/error.rs @@ -0,0 +1,30 @@ +use std::fmt::Display; + +use crate::bindings; + +#[non_exhaustive] +pub enum ClientError { + // CommandsOutOfSync = bindings::CR_COMMANDS_OUT_OF_SYNC + /// Error connecting + ConnectError(u32, String), + QueryError(u32, String), + FetchError(u32, String), + Unspecified, +} + +impl From for ClientError { + fn from(value: i32) -> Self { + Self::Unspecified + } +} + +impl Display for ClientError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ConnectError(n, e) => write!(f, "connection failed with {n}: '{e}'"), + Self::QueryError(n, e) => write!(f, "query failed with {n}: '{e}'"), + Self::FetchError(n, e) => write!(f, "fetch failed with {n}: '{e}'"), + Self::Unspecified => write!(f, "unspecified error"), + } + } +} diff --git a/rust/mariadb/src/service_sql/raw.rs b/rust/mariadb/src/service_sql/raw.rs new file mode 100644 index 0000000000000..99647451be793 --- /dev/null +++ b/rust/mariadb/src/service_sql/raw.rs @@ -0,0 +1,305 @@ +//! Safe API for a `MySql` connection +//! +//! `RawConnection` comes almost directly from the `diesel` client crate, since +//! they have that all figured out pretty well. Reference: +//! + +use std::cell::UnsafeCell; +use std::ffi::{c_void, CStr, CString}; +use std::marker::PhantomData; +use std::ptr::NonNull; +use std::sync::Once; +use std::{mem, ptr, slice, str}; + +use bindings::{sql_service as GLOBAL_SQL_SERVICE, sql_service_st}; + +use super::error::ClientError; +use crate::{bindings, Value}; + +/// Get a function from our global SQL service +macro_rules! global_func { + ($fname:ident) => { + unsafe { (*GLOBAL_SQL_SERVICE).$fname.unwrap() } + }; +} + +pub struct Fetch; +pub struct Store; +pub trait RState {} + +impl RState for Fetch {} +impl RState for Store {} + +/// Type wrapper for `Result` with a `ClientError` error variant +pub type ClientResult = Result; + +/// A connection to a remote or local server +pub struct RawConnection(NonNull); + +/// Options for connecting to a remote SQL server +pub struct ConnOpts { + host: Option, + database: Option, + user: Option, + password: Option, + port: Option, + unix_socket: Option, + flags: u32, +} + +impl RawConnection { + /// Create a new connection + pub(super) fn new() -> Self { + fn_thread_unsafe_lib_init(); + // Attempt to connect, fail if nonzero (unexpected) + let p_conn = unsafe { global_func!(mysql_init_func)(ptr::null_mut()) }; + let p_conn = NonNull::new(p_conn).expect("OOM: connection allocation failed"); + let result = Self(p_conn); + + // Validate we are using an expected charset + let charset = unsafe { + global_func!(mysql_options_func)( + result.0.as_ptr(), + bindings::mysql_option::MYSQL_SET_CHARSET_NAME, + b"utf8mb4\0".as_ptr().cast(), + ) + }; + assert_eq!( + 0, charset, + "MYSQL_SET_CHARSET_NAME value of {charset} not recognized" + ); + + result + } + + /// Connect to the local SQL server + pub(super) fn connect_local(&mut self) -> ClientResult<()> { + let res = unsafe { global_func!(mysql_real_connect_local_func)(self.0.as_ptr()) }; + self.check_for_errors(ClientError::ConnectError)?; + if res.is_null() { + Ok(()) + } else { + Err(ClientError::ConnectError( + 0, + "unspecified connect error".to_owned(), + )) + } + } + + /// Connect to a remote server + pub fn connect(&mut self, conn_opts: &ConnOpts) -> ClientResult<()> { + let host = conn_opts.host.as_ref(); + let user = conn_opts.user.as_ref(); + let pw = conn_opts.password.as_ref(); + let db = conn_opts.database.as_ref(); + let port = conn_opts.port; + let sock = conn_opts.unix_socket.as_ref(); + + // TODO: see CLIENT_X flags in mariadb_com.h + let res = unsafe { + // Make sure you don't use the fake one! + global_func!(mysql_real_connect_func)( + self.0.as_ptr(), + host.map_or(ptr::null(), |c| c.as_ptr()), + user.map_or(ptr::null(), |c| c.as_ptr()), + pw.map_or(ptr::null(), |c| c.as_ptr()), + db.map_or(ptr::null(), |c| c.as_ptr()), + port.map_or(0, std::convert::Into::into), + sock.map_or(ptr::null(), |c| c.as_ptr()), + conn_opts.flags.into(), + ) + }; + + self.check_for_errors(ClientError::ConnectError)?; + + if res.is_null() { + Ok(()) + } else { + Err(ClientError::QueryError( + 0, + "unspecified query error".to_owned(), + )) + } + } + + /// Execute a query + pub fn query(&mut self, q: &str) -> ClientResult<()> { + unsafe { + let p_self: *const Self = self; + // mysql_real_query in mariadb_lib.c. Real just means use buffers + // instead of c strings + let res = global_func!(mysql_real_query_func)( + p_self.cast_mut().cast(), + q.as_ptr().cast(), + q.len().try_into().unwrap(), + ); + self.check_for_errors(ClientError::QueryError)?; + + if res == 0 { + Ok(()) + } else { + Err(ClientError::QueryError( + 0, + "unspecified query error".to_owned(), + )) + } + } + } + + /// Prepare the result for iteration, but do not store + pub fn prep_fetch_result(&mut self) -> ClientResult> { + let res = unsafe { bindings::mysql_use_result(self.0.as_ptr()) }; + self.check_for_errors(ClientError::QueryError)?; + + match NonNull::new(res) { + Some(ptr) => unsafe { + let field_count = get_field_count(self, ptr.as_ptr())?; + let field_ptr = bindings::mysql_fetch_fields(ptr.as_ptr()); + let fields = slice::from_raw_parts(field_ptr, field_count as usize); + Ok(RawResult { + inner: ptr, + fields: *fields.as_ptr().cast(), + _marker: PhantomData, + }) + }, + None => Err(ClientError::FetchError( + 0, + "unspecified fetch error".to_owned(), + )), + } + } + + /// Get the last error message if available and if so, apply it to function `f` + /// + /// `f` is usually a variant of `ClientError::SomeError`, since those are functions + pub fn check_for_errors(&mut self, f: F) -> ClientResult<()> + where + F: FnOnce(u32, String) -> ClientError, + { + unsafe { + let cs = CStr::from_ptr(global_func!(mysql_error_func)(self.0.as_ptr())); + let emsg = cs.to_string_lossy(); + let errno = global_func!(mysql_errno_func)(self.0.as_ptr()); + + if emsg.is_empty() && errno == 0 { + Ok(()) + } else { + Err(f(errno, emsg.into_owned())) + } + } + } +} + +impl Drop for RawConnection { + /// Close the connection + fn drop(&mut self) { + unsafe { global_func!(mysql_close_func)(self.0.as_ptr()) }; + } +} + +/// Thin wrapper over a result +pub struct RawResult<'a, S: RState> { + inner: NonNull, + fields: &'a [Field], + _marker: PhantomData, +} + +impl<'a> RawResult<'a, Fetch> { + pub fn fetch_next_row(&mut self) -> Option { + let rptr = unsafe { global_func!(mysql_fetch_row_func)(self.inner.as_ptr()) }; + + if rptr.is_null() { + None + } else { + Some(FetchedRow { + inner: rptr, + fields: self.fields, + }) + } + } +} + +impl<'a, S: RState> Drop for RawResult<'a, S> { + /// Free the result + fn drop(&mut self) { + unsafe { global_func!(mysql_free_result_func)(self.inner.as_ptr()) }; + } +} + +/// Representation of a single row, as part of a SQL query +pub struct FetchedRow<'a> { + // *mut *mut c_char + inner: bindings::MYSQL_ROW, + fields: &'a [Field], +} + +impl FetchedRow<'_> { + /// Get the field of a given index + pub fn field_value(&self, index: usize) -> Value { + let field = &self.fields[index]; + assert!(index < self.fields.len()); // extra sanity check + unsafe { + let ptr = *self.inner.add(index); + Value::from_ptr(field.ftype(), ptr.cast(), field.length()) + } + } + + pub const fn field_info(&self, index: usize) -> &Field { + &self.fields[index] + } + + /// Get the total number of fields + pub const fn field_count(&self) -> usize { + self.fields.len() + } +} + +/// Transparant wrapper around a `MYSQL_FIELD` +#[repr(transparent)] +pub struct Field(UnsafeCell); + +impl Field { + pub fn length(&self) -> usize { + unsafe { (*self.0.get()).length.try_into().unwrap() } + } + + pub fn max_length(&self) -> usize { + unsafe { (*self.0.get()).max_length.try_into().unwrap() } + } + + pub fn name(&self) -> &str { + unsafe { + let inner = &*self.0.get(); + let name_slice = slice::from_raw_parts(inner.name.cast(), inner.name_length as usize); + str::from_utf8(name_slice).expect("unexpected: non-utf8 identifier") + } + } + + fn ftype(&self) -> bindings::enum_field_types { + unsafe { (*self.0.get()).type_ } + } +} + +unsafe fn get_field_count( + conn: &mut RawConnection, + q_result: *const bindings::MYSQL_RES, +) -> ClientResult { + let res = unsafe { global_func!(mysql_num_fields_func)(q_result.cast_mut()) }; + conn.check_for_errors(ClientError::QueryError)?; + Ok(res) +} + +fn fn_thread_unsafe_lib_init() { + /// + static MYSQL_THREAD_UNSAFE_INIT: Once = Once::new(); + + MYSQL_THREAD_UNSAFE_INIT.call_once(|| { + // mysql_library_init is defined by `#define mysql_library_init mysql_server_init` + // which isn't picked up by bindgen + let ret = unsafe { bindings::mysql_server_init(0, ptr::null_mut(), ptr::null_mut()) }; + assert_eq!( + ret, 0, + "Unable to perform MySQL global initialization. Return code: {ret}" + ); + }); +} diff --git a/rust/plugins/README.md b/rust/plugins/README.md new file mode 100644 index 0000000000000..dd2a57e6bf868 --- /dev/null +++ b/rust/plugins/README.md @@ -0,0 +1,4 @@ +# Plugins + +This directory contains plugins that are intended to be usable in practice, as +opposed to example plugins. diff --git a/rust/plugins/keymgt-clevis/Cargo.toml b/rust/plugins/keymgt-clevis/Cargo.toml new file mode 100644 index 0000000000000..a9384cd5c4227 --- /dev/null +++ b/rust/plugins/keymgt-clevis/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "keymgt-clevis" +version = "0.1.0" +edition = '2021' + +[lib] +crate-type = ["cdylib"] + +[dependencies] +josekit = "0.8.2" +mariadb = { path = "../../mariadb" } +ureq = "2.6.2" diff --git a/rust/plugins/keymgt-clevis/src/lib.rs b/rust/plugins/keymgt-clevis/src/lib.rs new file mode 100644 index 0000000000000..f9c783859166a --- /dev/null +++ b/rust/plugins/keymgt-clevis/src/lib.rs @@ -0,0 +1,185 @@ +//! EXAMPLE ONLY: DO NOT USE IN PRODUCTION! + +#![allow(unused)] + +use std::cell::UnsafeCell; +use std::ffi::c_void; +use std::fmt::Write; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::Mutex; + +use josekit::jws; +use mariadb::log::{debug, error, info}; +use mariadb::plugin::encryption::{Encryption, Flags, KeyError, KeyManager}; +use mariadb::plugin::{ + register_plugin, Init, InitError, License, Maturity, PluginType, SysVarConstString, SysVarOpt, +}; +use mariadb::service_sql::{ClientError, Fetch, FetchedRows, MySqlConn}; + +const KEY_TABLE: &str = "mysql.clevis_keys"; +const SERVER_TABLE: &str = "mysql.clevis_servers"; +/// Max length a key can be, used for table size and buffer checking +const KEY_MAX_BYTES: usize = 16; + +/// String system variable to set server address +static TANG_SERVER: SysVarConstString = SysVarConstString::new(); + +struct KeyMgtClevis; + +/// Get the JWS body from a server +fn fetch_jws() -> String { + // FIXME: error handling + let url = format!("https://{}", TANG_SERVER.get()); + let body: String = ureq::get("http://example.com") + .call() + .unwrap_or_else(|_| panic!("http request for '{url}' failed")) + .into_string() + .expect("http request larger than 10MB"); + todo!(); + body +} + +fn make_new_key(conn: &MySqlConn) -> Result { + let server = TANG_SERVER.get(); + format!( + "INSERT IGNORE INTO {KEY_TABLE} + SET key_server = {server} + RETURNING jws" + ); + + // get the jws value + let jws: &str; + + todo!() +} + +impl Init for KeyMgtClevis { + /// Create needed tables + fn init() -> Result<(), InitError> { + debug!("init for KeyMgtClevis"); + + let mut conn = MySqlConn::connect_local().map_err(|_| InitError)?; + conn.execute(&format!( + "CREATE TABLE IF NOT EXISTS {KEY_TABLE} ( + key_id INT UNSIGNED NOT NULL, + key_version INT UNSIGNED NOT NULL, + key_server VARBINARY(64) NOT NULL, + key VARBINARY((16) NOT NULL, + PRIMARY KEY (key_id, key_version) + ) ENGINE=InnoDB" + )) + .map_err(|_| InitError)?; + conn.execute(&format!( + "CREATE TABLE IF NOT EXISTS {SERVER_TABLE} ( + server VARBINARY(64) NOT NULL PRIMARY KEY, + verify VARBINARY(2048) + enc VARBINARY(2048) + ) ENGINE=InnoDB" + )) + .map_err(|_| InitError)?; + + debug!("finished init for KeyMgtClevis"); + Ok(()) + } + + fn deinit() -> Result<(), InitError> { + debug!("deinit for KeyMgtClevis"); + Ok(()) + } +} + +/// Execute a query, printing an error and returning KeyError if needed. No result +fn run_execute(conn: &mut MySqlConn, q: &str, key_id: u32) -> Result<(), KeyError> { + conn.execute(q).map_err(|e| { + error!("clevis: get_latest_key_version {key_id} - SQL error on {q} - {e}"); + KeyError::Other + }) +} + +/// Execute a query, printing an error, return the result +fn run_query<'a>( + conn: &'a mut MySqlConn, + q: &str, + key_id: u32, +) -> Result, KeyError> { + conn.query(q).map_err(|e| { + error!("clevis: get_latest_key_version {key_id} - SQL error on {q} - {e}"); + KeyError::Other + }) +} + +impl KeyManager for KeyMgtClevis { + fn get_latest_key_version(key_id: u32) -> Result { + let mut conn = MySqlConn::connect_local().map_err(|_| KeyError::Other)?; + let mut q = format!("SELECT key_version FROM {KEY_TABLE} WHERE key_id = {key_id}"); + let _ = run_query(&mut conn, &q, key_id)?; + + // fuund! fetch result, parse to int + // if let Some(row) = todo!() { + if false { + todo!() + // return Ok(); + } + + // directly push format string + let key_version: u32 = 1; + write!(q, "AND key_version = {key_version} FOR UPDATE"); + + run_execute(&mut conn, "START TRANSACTION", key_id)?; + run_query(&mut conn, &q, key_id)?; + + let Ok(new_key) = make_new_key(&conn) else { + run_execute(&mut conn, "ROLLBACK", key_id)?; + todo!(); + }; + + let q = format!( + r#"INSERT INTO {KEY_TABLE} VALUES ( + {key_id}, {key_version}, "{server_name}", {new_key} )"#, + server_name = TANG_SERVER.get() + ); + run_execute(&mut conn, &q, key_id)?; + + todo!() + } + + fn get_key(key_id: u32, key_version: u32, dst: &mut [u8]) -> Result<(), KeyError> { + let mut conn = MySqlConn::connect_local().map_err(|_| KeyError::Other)?; + let q = format!( + "SELECT key FROM {KEY_TABLE} WHERE key_id = {key_id} AND key_version = {key_version}" + ); + conn.query(&q).map_err(|_| KeyError::Other)?; + // TODO: generate key with server + let key: &[u8]; + todo!(); + dst[..key.len()].copy_from_slice(key); + Ok(()) + } + + fn key_length(_key_id: u32, _key_version: u32) -> Result { + Ok(KEY_MAX_BYTES) + } +} + +register_plugin! { + KeyMgtClevis, + ptype: PluginType::MariaEncryption, + name: "clevis_key_management", + author: "Daniel Black & Trevor Gross", + description: "Clevis key management plugin", + license: License::Gpl, + maturity: Maturity::Experimental, + version: "0.1", + init: KeyMgtClevis, + encryption: false, + variables: [ + SysVar { + ident: TANG_SERVER, + vtype: SysVarConstString, + name: "tang_server", + description: "the tang server for key exchange", + options: [SysVarOpt::OptCmdArd], + default: "localhost" + } + ] +}