diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index b8c2528d22d5..b2bba578fe8d 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1110,6 +1110,7 @@ dependencies = [
  "object_store",
  "parking_lot",
  "predicates",
+ "regex",
  "rstest",
  "rustyline",
  "tokio",
diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml
index 82c4a4b71e1d..04345ed112f8 100644
--- a/datafusion-cli/Cargo.toml
+++ b/datafusion-cli/Cargo.toml
@@ -40,6 +40,7 @@ env_logger = "0.9"
 mimalloc = { version = "0.1", default-features = false }
 object_store = { version = "0.7.0", features = ["aws", "gcp"] }
 parking_lot = { version = "0.12" }
+regex = "1.8"
 rustyline = "11.0"
 tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] }
 url = "2.2"
diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs
index 8429738a0953..25f412608453 100644
--- a/datafusion-cli/src/main.rs
+++ b/datafusion-cli/src/main.rs
@@ -26,10 +26,11 @@ use datafusion_cli::{
     exec, print_format::PrintFormat, print_options::PrintOptions, DATAFUSION_CLI_VERSION,
 };
 use mimalloc::MiMalloc;
+use std::collections::HashMap;
 use std::env;
 use std::path::Path;
 use std::str::FromStr;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 #[global_allocator]
 static GLOBAL: MiMalloc = MiMalloc;
@@ -252,26 +253,149 @@ fn is_valid_memory_pool_size(size: &str) -> Result<(), String> {
     }
 }
 
-fn extract_memory_pool_size(size: &str) -> Result<usize, String> {
-    let mut size = size;
-    let factor = if let Some(last_char) = size.chars().last() {
-        match last_char {
-            'm' | 'M' => {
-                size = &size[..size.len() - 1];
-                1024 * 1024
-            }
-            'g' | 'G' => {
-                size = &size[..size.len() - 1];
-                1024 * 1024 * 1024
-            }
-            _ => 1,
+/// Size units accepted as a suffix of a memory pool size (powers of 1024).
+#[derive(Debug, Clone, Copy)]
+enum ByteUnit {
+    Byte,
+    KiB,
+    MiB,
+    GiB,
+    TiB,
+}
+
+impl ByteUnit {
+    /// Number of bytes represented by one of this unit.
+    ///
+    /// Returned as `u64` because `1 << 40` does not fit in a 32-bit `usize`.
+    fn multiplier(&self) -> u64 {
+        match self {
+            ByteUnit::Byte => 1,
+            ByteUnit::KiB => 1 << 10,
+            ByteUnit::MiB => 1 << 20,
+            ByteUnit::GiB => 1 << 30,
+            ByteUnit::TiB => 1 << 40,
         }
+    }
+}
+
+/// Parses a memory pool size such as `512`, `20m` or `2GB` into a number of bytes.
+///
+/// The input is an integer optionally followed by a case-insensitive unit suffix
+/// (`b`, `k`/`kb`, `m`/`mb`, `g`/`gb`, `t`/`tb`); with no suffix the value is in bytes.
+///
+/// Returns an error message when the input is malformed, the number does not parse
+/// as a `usize`, the suffix is unknown, or the byte count does not fit in a `usize`.
+fn extract_memory_pool_size(size: &str) -> Result<usize, String> {
+    fn byte_suffixes() -> &'static HashMap<&'static str, ByteUnit> {
+        static BYTE_SUFFIXES: OnceLock<HashMap<&'static str, ByteUnit>> = OnceLock::new();
+        BYTE_SUFFIXES.get_or_init(|| {
+            let mut m = HashMap::new();
+            m.insert("b", ByteUnit::Byte);
+            m.insert("k", ByteUnit::KiB);
+            m.insert("kb", ByteUnit::KiB);
+            m.insert("m", ByteUnit::MiB);
+            m.insert("mb", ByteUnit::MiB);
+            m.insert("g", ByteUnit::GiB);
+            m.insert("gb", ByteUnit::GiB);
+            m.insert("t", ByteUnit::TiB);
+            m.insert("tb", ByteUnit::TiB);
+            m
+        })
+    }
+
+    fn suffix_re() -> &'static regex::Regex {
+        static SUFFIX_REGEX: OnceLock<regex::Regex> = OnceLock::new();
+        // The optional `-` is captured so that negative input reaches the numeric
+        // parse below and is reported as an invalid numeric value.
+        SUFFIX_REGEX.get_or_init(|| regex::Regex::new(r"^(-?[0-9]+)([a-z]+)?$").unwrap())
+    }
+
+    let lower = size.to_lowercase();
+    if let Some(caps) = suffix_re().captures(&lower) {
+        let num_str = caps.get(1).unwrap().as_str();
+        let num = num_str.parse::<usize>().map_err(|_| {
+            format!("Invalid numeric value in memory pool size '{}'", size)
+        })?;
+
+        let suffix = caps.get(2).map(|m| m.as_str()).unwrap_or("b");
+        let unit = byte_suffixes()
+            .get(suffix)
+            .ok_or_else(|| format!("Invalid memory pool size '{}'", size))?;
+
+        // Checked arithmetic: a plain `*` panics in debug builds and silently wraps
+        // in release builds for inputs such as `16777216t`.
+        usize::try_from(unit.multiplier())
+            .ok()
+            .and_then(|multiplier| num.checked_mul(multiplier))
+            .ok_or_else(|| format!("Memory pool size '{}' is too large", size))
     } else {
-        return Err(format!("Invalid memory pool size '{}'", size));
-    };
+        Err(format!("Invalid memory pool size '{}'", size))
+    }
+}
 
-    match size.parse::<usize>() {
-        Ok(size) if size > 0 => Ok(factor * size),
-        _ => Err(format!("Invalid memory pool size '{}'", size)),
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn assert_conversion(input: &str, expected: Result<usize, String>) {
+        let result = extract_memory_pool_size(input);
+        match expected {
+            Ok(v) => assert_eq!(result.unwrap(), v),
+            Err(e) => assert_eq!(result.unwrap_err(), e),
+        }
+    }
+
+    #[test]
+    fn memory_pool_size() -> Result<(), String> {
+        // Test basic sizes without suffix, assumed to be bytes
+        assert_conversion("5", Ok(5));
+        assert_conversion("100", Ok(100));
+
+        // Test various units
+        assert_conversion("5b", Ok(5));
+        assert_conversion("4k", Ok(4 * 1024));
+        assert_conversion("4kb", Ok(4 * 1024));
+        assert_conversion("20m", Ok(20 * 1024 * 1024));
+        assert_conversion("20mb", Ok(20 * 1024 * 1024));
+        assert_conversion("2g", Ok(2 * 1024 * 1024 * 1024));
+        assert_conversion("2gb", Ok(2 * 1024 * 1024 * 1024));
+        assert_conversion("3t", Ok(3 * 1024 * 1024 * 1024 * 1024));
+        assert_conversion("4tb", Ok(4 * 1024 * 1024 * 1024 * 1024));
+
+        // Test case insensitivity
+        assert_conversion("4K", Ok(4 * 1024));
+        assert_conversion("4KB", Ok(4 * 1024));
+        assert_conversion("20M", Ok(20 * 1024 * 1024));
+        assert_conversion("20MB", Ok(20 * 1024 * 1024));
+        assert_conversion("2G", Ok(2 * 1024 * 1024 * 1024));
+        assert_conversion("2GB", Ok(2 * 1024 * 1024 * 1024));
+        assert_conversion("2T", Ok(2 * 1024 * 1024 * 1024 * 1024));
+
+        // Test invalid input
+        assert_conversion(
+            "invalid",
+            Err("Invalid memory pool size 'invalid'".to_string()),
+        );
+        assert_conversion("4kbx", Err("Invalid memory pool size '4kbx'".to_string()));
+        assert_conversion(
+            "-20mb",
+            Err("Invalid numeric value in memory pool size '-20mb'".to_string()),
+        );
+        assert_conversion(
+            "-100",
+            Err("Invalid numeric value in memory pool size '-100'".to_string()),
+        );
+        assert_conversion(
+            "12k12k",
+            Err("Invalid memory pool size '12k12k'".to_string()),
+        );
+
+        // Test overflow: 2^24 TiB exceeds usize on both 32-bit and 64-bit targets
+        assert_conversion(
+            "16777216t",
+            Err("Memory pool size '16777216t' is too large".to_string()),
+        );
+
+        Ok(())
     }
 }