diff --git a/rust/examples/keymgt-debug/src/lib.rs b/rust/examples/keymgt-debug/src/lib.rs index 2405248cdc3d6..2a6c12c5ee0b7 100644 --- a/rust/examples/keymgt-debug/src/lib.rs +++ b/rust/examples/keymgt-debug/src/lib.rs @@ -9,18 +9,21 @@ use std::cell::UnsafeCell; use std::ffi::c_void; -use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::atomic::{AtomicI32, AtomicU32, Ordering}; use std::sync::Mutex; use mariadb::log::{self, debug, trace}; use mariadb::plugin::encryption::{Encryption, Flags, KeyError, KeyManager}; use mariadb::plugin::{ register_plugin, Init, InitError, License, Maturity, PluginType, SysVarConstString, SysVarOpt, + SysVarString, }; const KEY_LENGTH: usize = 4; static KEY_VERSION: AtomicU32 = AtomicU32::new(1); -static TEST_SYSVAR_STR: SysVarConstString = SysVarConstString::new(); +static TEST_SYSVAR_CONST_STR: SysVarConstString = SysVarConstString::new(); +static TEST_SYSVAR_STR: SysVarString = SysVarString::new(); +static TEST_SYSVAR_I32: AtomicI32 = AtomicI32::new(10); struct DebugKeyMgmt; @@ -28,7 +31,15 @@ impl Init for DebugKeyMgmt { fn init() -> Result<(), InitError> { log::set_max_level(log::LevelFilter::Trace); debug!("DebugKeyMgmt get_latest_key_version"); - trace!("current sysvar: {}", TEST_SYSVAR_STR.get()); + trace!( + "current const str sysvar: {:?}", + TEST_SYSVAR_CONST_STR.get() + ); + trace!("current str sysvar: {:?}", TEST_SYSVAR_STR.get()); + trace!( + "current sysvar: {}", + TEST_SYSVAR_I32.load(Ordering::Relaxed) + ); Ok(()) } @@ -91,12 +102,28 @@ register_plugin! { encryption: false, variables: [ SysVar { - ident: TEST_SYSVAR_STR, + ident: TEST_SYSVAR_CONST_STR, vtype: SysVarConstString, - name: "test_sysvar", + name: "test_sysvar_const_string", description: "this is a description", options: [SysVarOpt::OptCmdArd], default: "default value" + }, + SysVar { + ident: TEST_SYSVAR_STR, + vtype: SysVarString, + name: "test_sysvar_string", + description: "this is a description", + options: [SysVarOpt::OptCmdArd], + default: "other default value" + }, + SysVar { + ident: TEST_SYSVAR_I32, + vtype: AtomicI32, + name: "test_sysvar_i32", + description: "this is a description", + options: [SysVarOpt::OptCmdArd], + default: 67 } ] } diff --git a/rust/macros/src/lib.rs b/rust/macros/src/lib.rs index 93111f5bf78c4..0eb97cf11c3f3 100644 --- a/rust/macros/src/lib.rs +++ b/rust/macros/src/lib.rs @@ -2,11 +2,9 @@ #![warn(clippy::nursery)] #![warn(clippy::str_to_string)] #![warn(clippy::missing_inline_in_public_items)] -#![allow(clippy::missing_const_for_fn)] #![allow(clippy::missing_panics_doc)] #![allow(clippy::must_use_candidate)] #![allow(clippy::option_if_let_else)] -#![allow(clippy::cast_possible_truncation)] mod fields; mod helpers; diff --git a/rust/macros/src/parse_vars.rs b/rust/macros/src/parse_vars.rs index 60016ff3284d9..beaf120f1c166 100644 --- a/rust/macros/src/parse_vars.rs +++ b/rust/macros/src/parse_vars.rs @@ -190,18 +190,23 @@ impl VariableInfo { let max = process_default_override(&self.max, "max_val")?; let interval = process_default_override(&self.interval, "blk_sz")?; - let st_ident = Ident::new(&format!("_st_sysvar_{}", name.value()), Span::call_site()); - // https://github.com/rust-lang/rust/issues/86935#issuecomment-1146670057 - let ty_wrap = Ident::new( - &format!("_st_sysvar_Type{}", name.value()), + let st_ident = Ident::new(&format!("_sysvar_st_{}", name.value()), Span::call_site()); + let st_tycheck = Ident::new( + &format!("_sysvar_tychk_{}", name.value()), Span::call_site(), ); + // https://github.com/rust-lang/rust/issues/86935#issuecomment-1146670057 + let ty_wrap = Ident::new(&format!("_sysvar_Type{}", name.value()), Span::call_site()); + // check to verify our vars are of the right type for our idents + let ty_check = quote! { static #st_tycheck: &#vtype = &#ident; }; let usynccell = quote! { ::mariadb::internals::UnsafeSyncCell }; let res = quote! { type #ty_wrap = T; + #ty_check + static #st_ident: #usynccell<#ty_wrap::<#ty_as_svwrap::CStructType>> = unsafe { #usynccell::new( #ty_wrap::<#ty_as_svwrap::CStructType> { @@ -223,7 +228,7 @@ impl VariableInfo { }; - Ok((st_ident.clone(), res)) + Ok((st_ident, res)) } /// Take the options vector, parse it as an array, bitwise or the output, @@ -319,7 +324,7 @@ fn verify_field_order(fields: &[String]) -> Result<(), String> { /// Process "default override" style fields by these rules: /// -/// - If `field` is `None`, return an empty TokenStream +/// - If `field` is `None`, return an empty `TokenStream` /// - Enforce it is a literal /// - If it is a literal string, change it to a `cstr` /// diff --git a/rust/macros/src/register_plugin.rs b/rust/macros/src/register_plugin.rs index d9eec5dd52136..02ce3fbc4bdfc 100644 --- a/rust/macros/src/register_plugin.rs +++ b/rust/macros/src/register_plugin.rs @@ -94,7 +94,7 @@ impl Parse for PluginInfo { } impl PluginInfo { - fn new(main_ty: Ident, span: Span) -> Self { + const fn new(main_ty: Ident, span: Span) -> Self { Self { main_ty, span, diff --git a/rust/macros/tests/fail/02-extra-sysargs.rs b/rust/macros/tests/fail/02-extra-sysargs.rs index 13ae14eee1d35..342b35b833af1 100644 --- a/rust/macros/tests/fail/02-extra-sysargs.rs +++ b/rust/macros/tests/fail/02-extra-sysargs.rs @@ -13,7 +13,7 @@ register_plugin! { encryption: false, variables: [ SysVar { - ident: _SYSVAR_STR, + ident: _SYSVAR_CONST_STR, vtype: SysVarConstString, name: "test_sysvar", description: "this is a description", diff --git a/rust/macros/tests/fail/03-wrong-types.rs b/rust/macros/tests/fail/03-wrong-types.rs new file mode 100644 index 0000000000000..eb25475a9b959 --- /dev/null +++ b/rust/macros/tests/fail/03-wrong-types.rs @@ -0,0 +1,30 @@ +/* + * Verify our added check for identifier-type mismatch + */ + +include!("../include.rs"); + +register_plugin! { + TestPlugin, + ptype: PluginType::MariaEncryption, + name: "debug_key_management", + author: "Trevor Gross", + description: "Debug key management plugin", + license: License::Gpl, + maturity: Maturity::Experimental, + version: "0.1", + init: TestPlugin, + encryption: false, + variables: [ + SysVar { + ident: _SYSVAR_ATOMIC, + vtype: SysVarConstString, + name: "test_sysvar", + description: "this is a description", + options: [SysVarOpt::ReadOnly, SysVarOpt::NoCmdOpt], + default: "default value" + } + ] +} + +fn main() {} diff --git a/rust/macros/tests/fail/03-wrong-types.stderr b/rust/macros/tests/fail/03-wrong-types.stderr new file mode 100644 index 0000000000000..31dee11bc8163 --- /dev/null +++ b/rust/macros/tests/fail/03-wrong-types.stderr @@ -0,0 +1,15 @@ +error[E0308]: mismatched types + --> tests/fail/03-wrong-types.rs:7:1 + | +7 | / register_plugin! { +8 | | TestPlugin, +9 | | ptype: PluginType::MariaEncryption, +10 | | name: "debug_key_management", +... | +27 | | ] +28 | | } + | |_^ expected `&SysVarConstString`, found `&AtomicI32` + | + = note: expected reference `&'static mariadb::plugin::SysVarConstString` + found reference `&AtomicI32` + = note: this error originates in the macro `register_plugin` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/rust/macros/tests/include.rs b/rust/macros/tests/include.rs index 0b6eeda01e76f..e618cd5828bba 100644 --- a/rust/macros/tests/include.rs +++ b/rust/macros/tests/include.rs @@ -7,7 +7,7 @@ use mariadb::plugin::*; pub use mariadb_macros::register_plugin; static _SYSVAR_ATOMIC: AtomicI32 = AtomicI32::new(0); -static _SYSVAR_STR: SysVarConstString = SysVarConstString::new(); +static _SYSVAR_CONST_STR: SysVarConstString = SysVarConstString::new(); struct TestPlugin; diff --git a/rust/macros/tests/pass/03-with-sysargs.rs b/rust/macros/tests/pass/03-with-sysargs.rs index c678fa2d4f87a..01f3cc8053a4e 100644 --- a/rust/macros/tests/pass/03-with-sysargs.rs +++ b/rust/macros/tests/pass/03-with-sysargs.rs @@ -13,7 +13,7 @@ register_plugin! { encryption: false, variables: [ SysVar { - ident: _SYSVAR_STR, + ident: _SYSVAR_CONST_STR, vtype: SysVarConstString, name: "test_sysvar", description: "this is a description", @@ -37,7 +37,7 @@ fn main() { let plugin_def: &st_maria_plugin = unsafe { &*(_maria_plugin_declarations_[0]).get() }; let sysv_ptr: *mut *mut st_mysql_sys_var = plugin_def.system_vars; - let sysvar_st: *const sysvar_str_t = _st_sysvar_test_sysvar.get(); + let sysvar_st: *const sysvar_str_t = _sysvar_st_test_sysvar.get(); let sysvar_arr: &[UnsafeSyncCell<*mut sysvar_common_t>] = &_plugin_debug_key_management_sysvars; let idx_0: *mut sysvar_common_t = unsafe { *sysvar_arr[0].get() }; let idx_1: *mut sysvar_common_t = unsafe { *sysvar_arr[1].get() }; diff --git a/rust/mariadb/src/helpers.rs b/rust/mariadb/src/helpers.rs index 981070f94041d..89a61b56dbd1f 100644 --- a/rust/mariadb/src/helpers.rs +++ b/rust/mariadb/src/helpers.rs @@ -30,6 +30,7 @@ impl UnsafeSyncCell { } } +#[allow(clippy::non_send_fields_in_send_ty)] unsafe impl Send for UnsafeSyncCell {} unsafe impl Sync for UnsafeSyncCell {} diff --git a/rust/mariadb/src/lib.rs b/rust/mariadb/src/lib.rs index 8c7d4e4e63c82..240482d7b4aaf 100644 --- a/rust/mariadb/src/lib.rs +++ b/rust/mariadb/src/lib.rs @@ -1,11 +1,13 @@ //! Crate representing safe abstractions over MariaDB bindings #![warn(clippy::pedantic)] #![warn(clippy::nursery)] -#![warn(clippy::missing_inline_in_public_items)] #![warn(clippy::str_to_string)] +#![allow(clippy::option_if_let_else)] #![allow(clippy::missing_errors_doc)] #![allow(clippy::must_use_candidate)] #![allow(clippy::useless_conversion)] +#![allow(clippy::cast_possible_wrap)] +#![allow(clippy::missing_safety_doc)] #![allow(clippy::missing_const_for_fn)] #![allow(clippy::module_name_repetitions)] #![allow(clippy::missing_inline_in_public_items)] diff --git a/rust/mariadb/src/plugin.rs b/rust/mariadb/src/plugin.rs index 8d84e900b13dc..cf151447eae16 100644 --- a/rust/mariadb/src/plugin.rs +++ b/rust/mariadb/src/plugin.rs @@ -72,7 +72,7 @@ mod variables; mod variables_parse; mod wrapper; pub use mariadb_macros::register_plugin; -pub use variables::{SysVarConstString, SysVarOpt}; +pub use variables::{SysVarConstString, SysVarOpt, SysVarString}; /// Commonly used plugin types for reexport pub mod prelude { diff --git a/rust/mariadb/src/plugin/encryption.rs b/rust/mariadb/src/plugin/encryption.rs index 0a8a81fd425a0..22f0ca10515c9 100644 --- a/rust/mariadb/src/plugin/encryption.rs +++ b/rust/mariadb/src/plugin/encryption.rs @@ -47,20 +47,20 @@ pub enum EncryptionError { pub struct Flags(i32); impl Flags { - pub(crate) fn new(value: i32) -> Self { + pub(crate) const fn new(value: i32) -> Self { Self(value) } - pub(crate) fn should_encrypt(self) -> bool { + pub(crate) const fn should_encrypt(self) -> bool { (self.0 & bindings::ENCRYPTION_FLAG_ENCRYPT as i32) != 0 } - pub(crate) fn should_decrypt(self) -> bool { + pub(crate) const fn should_decrypt(self) -> bool { // (self.0 & bindings::ENCRYPTION_FLAG_DECRYPT as i32) != 0 !self.should_encrypt() } - pub fn nopad(&self) -> bool { + pub const fn nopad(&self) -> bool { (self.0 & bindings::ENCRYPTION_FLAG_NOPAD as i32) != 0 } } diff --git a/rust/mariadb/src/plugin/variables.rs b/rust/mariadb/src/plugin/variables.rs index f8e7840d645c6..c7e20037f9c49 100644 --- a/rust/mariadb/src/plugin/variables.rs +++ b/rust/mariadb/src/plugin/variables.rs @@ -81,15 +81,8 @@ pub unsafe trait SysVarInterface: Sized { /// The C struct representation, e.g. `sysvar_str_t` type CStructType; - /// Intermediate type, pointed to by the CStructType's `value` pointer - type Intermediate; - - /// Associated const with an optional function pointer to an update - /// function. - /// - /// If a sysvar type should use a custom update function, implmeent `update` - /// and set this value to `update_wrap`. - const UPDATE_FUNC: Option = None; + /// Intermediate type, pointed to by the `CStructType's` `value` pointer + type Intermediate: Copy; /// Options to implement by default const DEFAULT_OPTS: i32; @@ -111,11 +104,15 @@ pub unsafe trait SysVarInterface: Sized { save: *const c_void, ) { let new_save: *const Self::Intermediate = save.cast(); - Self::update(&*target.cast(), &*var.cast(), new_save.as_ref()); + assert!( + !new_save.is_null(), + "got a null pointer from the C interface" + ); + Self::update(&*target.cast(), &*var.cast(), *new_save); } /// Update function: override this if it is pointed to by `UPDATE_FUNC` - unsafe fn update(&self, var: &Self::CStructType, save: Option<&Self::Intermediate>) { + unsafe fn update(&self, var: &Self::CStructType, save: Self::Intermediate) { unimplemented!() } } @@ -137,6 +134,10 @@ impl SysVarConstString { /// Get the current value of this variable. This isn't very efficient since /// it copies the string, but fixes will come later + /// + /// # Panics + /// + /// Panics if it gets a non-UTF8 C string pub fn get(&self) -> String { let ptr = self.0.load(Ordering::SeqCst); let cs = unsafe { CStr::from_ptr(ptr) }; @@ -184,6 +185,10 @@ impl SysVarString { } /// Get the current value of this variable + /// + /// # Panics + /// + /// Panics if the mutex can't be locked pub fn get(&self) -> Option { let guard = &*self.mutex.lock().expect("failed to lock mutex"); let ptr = self.ptr.load(Ordering::SeqCst); @@ -194,11 +199,11 @@ impl SysVarString { ptr.cast_const() == cs.as_ptr(), "pointer and c string unsynchronized" ); - Some(cstr_to_string(&cs)) + Some(cstr_to_string(cs)) } else if ptr.is_null() && guard.is_none() { None } else { - warn!("pointer {ptr:p} mismatch with guard {guard:?}"); + trace!("pointer {ptr:p} mismatch with guard {guard:?} (likely init condition)"); // prefer the pointer, must have been set on the C side let cs = unsafe { CStr::from_ptr(ptr) }; Some(cstr_to_string(cs)) @@ -209,7 +214,6 @@ impl SysVarString { unsafe impl SysVarInterface for SysVarString { type CStructType = bindings::sysvar_str_t; type Intermediate = *mut c_char; - const UPDATE_FUNC: Option = Some(Self::update_wrap); const DEFAULT_OPTS: i32 = bindings::PLUGIN_VAR_STR as i32; const DEFAULT_C_STRUCT: Self::CStructType = Self::CStructType { flags: 0, @@ -221,14 +225,17 @@ unsafe impl SysVarInterface for SysVarString { def_val: cstr!("").as_ptr().cast_mut(), }; - unsafe fn update(&self, var: &Self::CStructType, save: Option<&Self::Intermediate>) { - let to_save = save.map(|ptr| unsafe { CStr::from_ptr(*ptr).to_owned() }); + unsafe fn update(&self, var: &Self::CStructType, save: Self::Intermediate) { + let to_save = save + .as_ref() + .map(|ptr| unsafe { CStr::from_ptr(ptr).to_owned() }); let guard = &mut *self.mutex.lock().expect("failed to lock mutex"); *guard = to_save; let new_ptr = guard .as_deref() .map_or(ptr::null_mut(), |cs| cs.as_ptr().cast_mut()); self.ptr.store(new_ptr, Ordering::SeqCst); + trace!("updated sysvar with inner: {guard:?}"); } } @@ -281,27 +288,24 @@ macro_rules! atomic_svinterface { type CStructType = $c_struct_type; type Intermediate = $inter_type; const DEFAULT_OPTS: i32 = ($default_options) as i32; - const UPDATE_FUNC: Option = Some(Self::update_wrap as SvUpdateFn); const DEFAULT_C_STRUCT: Self::CStructType = Self::CStructType { flags: 0, name: ptr::null(), comment: ptr::null(), check: None, - update: None, + update: Some(Self::update_wrap), value: ptr::null_mut(), $( $extra_struct_fields )* }; - unsafe fn update(&self, var: &Self::CStructType, save: Option<&Self::Intermediate>) { + unsafe fn update(&self, var: &Self::CStructType, save: Self::Intermediate) { trace!( "updated {} system variable to '{:?}'", std::any::type_name::<$atomic_type>(), save ); // based on sql_plugin.cc, seems like there are no null integers // (can't represent that anyway) - let new = save.expect("somehow got a null pointer"); - self.store(*new, Ordering::SeqCst); - trace!("updated system variable to '{}'", new); + self.store(save, Ordering::SeqCst); } } }; diff --git a/rust/mariadb/src/plugin/variables_parse.rs b/rust/mariadb/src/plugin/variables_parse.rs index a8293ce3a63cc..7c8149088e728 100644 --- a/rust/mariadb/src/plugin/variables_parse.rs +++ b/rust/mariadb/src/plugin/variables_parse.rs @@ -107,7 +107,7 @@ impl CliMysqlValue { } } - unsafe fn from_ptr<'a>(ptr: *const bindings::st_mysql_value) -> &'a Self { + const unsafe fn from_ptr<'a>(ptr: *const bindings::st_mysql_value) -> &'a Self { &*ptr.cast() } diff --git a/rust/mariadb/src/service_sql/raw.rs b/rust/mariadb/src/service_sql/raw.rs index eca03435d4a10..99647451be793 100644 --- a/rust/mariadb/src/service_sql/raw.rs +++ b/rust/mariadb/src/service_sql/raw.rs @@ -244,12 +244,12 @@ impl FetchedRow<'_> { } } - pub fn field_info(&self, index: usize) -> &Field { + pub const fn field_info(&self, index: usize) -> &Field { &self.fields[index] } /// Get the total number of fields - pub fn field_count(&self) -> usize { + pub const fn field_count(&self) -> usize { self.fields.len() } } diff --git a/rust/plugins/keymgt-clevis/src/lib.rs b/rust/plugins/keymgt-clevis/src/lib.rs index dca3f4481b7b6..f9c783859166a 100644 --- a/rust/plugins/keymgt-clevis/src/lib.rs +++ b/rust/plugins/keymgt-clevis/src/lib.rs @@ -48,7 +48,7 @@ fn make_new_key(conn: &MySqlConn) -> Result { ); // get the jws value - let jws: &str = todo!(); + let jws: &str; todo!() } @@ -117,7 +117,8 @@ impl KeyManager for KeyMgtClevis { // fuund! fetch result, parse to int // if let Some(row) = todo!() { if false { - return Ok(todo!()); + todo!() + // return Ok(); } // directly push format string @@ -129,7 +130,7 @@ impl KeyManager for KeyMgtClevis { let Ok(new_key) = make_new_key(&conn) else { run_execute(&mut conn, "ROLLBACK", key_id)?; - return todo!(); + todo!(); }; let q = format!( @@ -149,7 +150,8 @@ impl KeyManager for KeyMgtClevis { ); conn.query(&q).map_err(|_| KeyError::Other)?; // TODO: generate key with server - let key: &[u8] = todo!(); + let key: &[u8]; + todo!(); dst[..key.len()].copy_from_slice(key); Ok(()) }