diff --git a/Cargo.lock b/Cargo.lock index 4f0eb08..1d4be09 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -273,6 +273,7 @@ dependencies = [ "pyo3", "pyo3-build-config", "rayon", + "struct_deep_getter_derive", "thiserror", ] @@ -301,9 +302,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.43" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a2ca2c61bc9f3d74d2886294ab7b9853abd9c1ad903a3ac7815c58989bb7bab" +checksum = "57a8eca9f9c4ffde41714334dee777596264c7825420f521abc92b5b5deb63a5" dependencies = [ "unicode-ident", ] @@ -380,9 +381,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.21" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179" +checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" dependencies = [ "proc-macro2", ] @@ -458,11 +459,25 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2fd0db749597d91ff862fd1d55ea87f7855a744a8425a64695b6fca237d1dad1" +[[package]] +name = "struct_deep_getter" +version = "0.1.0" + +[[package]] +name = "struct_deep_getter_derive" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "struct_deep_getter", + "syn", +] + [[package]] name = "syn" -version = "1.0.99" +version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58dbef6ec655055e20b86b15a8cc6d439cca19b667537ac6a1369572d151ab13" +checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 8705cdf..ad425df 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ numpy = "0.17.2" maxminddb = { version = "0.23.0", features = ["mmap"] } thiserror = "1.0.36" rayon = "1.5.3" +struct_deep_getter_derive = { path = "struct_deep_getter_derive" } [build-dependencies] pyo3-build-config = "0.17.1" diff --git a/struct_deep_getter/Cargo.toml b/struct_deep_getter/Cargo.toml new file mode 100644 index 0000000..c7e2ed5 --- /dev/null +++ b/struct_deep_getter/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "struct_deep_getter" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +#struct_deep_getter_derive = { path = "../struct_deep_getter_derive", optional = true } + +[features] +#derive = ["dep:struct_deep_getter_derive"] diff --git a/struct_deep_getter/src/lib.rs b/struct_deep_getter/src/lib.rs new file mode 100644 index 0000000..dc408d7 --- /dev/null +++ b/struct_deep_getter/src/lib.rs @@ -0,0 +1,4 @@ +pub trait StructDeepGetter { + fn deeper_structs() -> Vec; + fn get_path(&self, path: &str) -> T; +} diff --git a/struct_deep_getter_derive/Cargo.toml b/struct_deep_getter_derive/Cargo.toml new file mode 100644 index 0000000..f1e35f6 --- /dev/null +++ b/struct_deep_getter_derive/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "struct_deep_getter_derive" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +proc-macro = true + +[dependencies] +struct_deep_getter = { path = "../struct_deep_getter" } +syn = { version = "^1.0.107", features = ["full", "visit"]} +quote = "^1.0.23" +proc-macro2 = "^1.0.49" diff --git a/struct_deep_getter_derive/src/lib.rs b/struct_deep_getter_derive/src/lib.rs new file mode 100644 index 0000000..eba85c8 --- /dev/null +++ b/struct_deep_getter_derive/src/lib.rs @@ -0,0 +1,320 @@ +extern crate proc_macro; +extern crate core; + +use proc_macro::TokenStream; +use std::collections::{HashMap, VecDeque}; +use proc_macro2::Span; + +use quote::quote; +use syn; +use syn::{Ident, Fields, FieldsNamed, File, GenericArgument, ItemStruct, Lit, Meta, MetaList, NestedMeta, PathArguments, Type}; +use syn::visit::Visit; + + +// todo: generate enum with getters +// todo: return getter strings +// todo: return getter enum +// todo: pass through the original code, remove the macro attributes from it + + +#[derive(Debug)] +struct Struct { + ident: String, + fields: Vec, + generate: bool, + return_type: Option, + replacement_type: Option, +} + +fn get_nested_meta_params(meta: &MetaList) -> HashMap { + meta + .nested + .iter() + .map(|nm| match nm { + NestedMeta::Meta(m) => match m { + Meta::NameValue(nameval) => + match (nameval.path.get_ident(), &nameval.lit) { + (Some(ident), Lit::Str(lit)) => { + (ident.to_string(), lit.value()) + } + _ => panic!("struct_deep_getter attributes are expected to be strings") + } + _ => panic!("struct_deep_getter attributes should be named list") + } + NestedMeta::Lit(_) => panic!("struct_deep_getter attributes should be named list") + }) + .collect() +} + +fn get_meta(i: &ItemStruct) -> Option { + for attr in &i.attrs { + if let Ok(Meta::List(meta)) = attr.parse_meta() { + if let Some(macro_name) = meta.path.get_ident() { + if macro_name == "struct_deep_getter" { + return Some(meta); + } + } + } + } + + None +} + +#[derive(Debug)] +struct Field { + ident: String, + types: Vec, +} + +fn extract_types_deque(start_type: &Type) -> Vec { + let mut res = Vec::new(); + let mut rec_ty = VecDeque::new(); + rec_ty.push_front(start_type.clone()); + + while let Some(ty) = rec_ty.pop_front() { + match ty { + Type::Path(p) => { + for seg in p.path.segments { + res.push(seg.ident.to_string()); + + if !seg.arguments.is_empty() { + match seg.arguments { + PathArguments::AngleBracketed(ab_arg) => { + for arg in ab_arg.args.iter() { + match arg { + GenericArgument::Type(next_ty) => { + rec_ty.push_back(next_ty.clone()); + } + GenericArgument::Lifetime(_) => { + // ignore lifetimes + } + _ => panic!("only types as arguments are supported") + } + } + } + _ => panic!("only angle-bracketed type segments are supported") + } + } + } + } + Type::Reference(r) => { + // ignore refs push concrete types + rec_ty.push_back(r.elem.as_ref().clone()); + } + _ => panic!("only type paths are supported") + } + } + + res +} + +fn extract_fields(fields: &FieldsNamed) -> Vec { + fields + .named + .iter() + // .filter(|field| matches!(field.vis, Visibility::Public(_))) + .map(|field| { + let ident = field.ident.as_ref().unwrap().to_string(); + let types = extract_types_deque(&field.ty); + Field { ident, types } + }) + .collect() +} + +#[derive(Default)] +struct StructVisitor { + structs: HashMap, +} + +impl<'ast> Visit<'ast> for StructVisitor { + fn visit_item_struct(&mut self, i: &'ast ItemStruct) { + let ident = i.ident.to_string(); + + let mut generate = false; + let mut return_type = None; + let mut replacement_type = None; + if let Some(meta) = get_meta(i) { + generate = true; + + let params = get_nested_meta_params(&meta); + return_type = params.get("return_type").map(|v| v.clone()); + replacement_type = params.get("replacement_type").map(|v| v.clone()); + } + + let fields = match &i.fields { + Fields::Named(named) => extract_fields(&named), + _ => panic!("only named fields are supported") + }; + + self.structs.insert( + ident.clone(), + Struct { + ident, + fields, + generate, + return_type, + replacement_type, + }); + } +} + +fn fields_to_path(fields: &[&Field]) -> String { + fields + .iter() + .map(|field| { + if field.types.contains(&"Vec".to_owned()) { + format!("{}[0]", field.ident) + } else if field.types.contains(&"BTreeMap".to_owned()) { + format!("{}[\"en\"]", field.ident) + } else { + field.ident.to_string() + } + }) + .collect::>() + .join(".") +} + +fn fields_to_getter(fields: &[&Field]) -> proc_macro2::TokenStream { + let mut within_option = false; + let size = fields.len(); + + let mut tokens: Vec = Vec::new(); + + for (i, field) in fields.iter().enumerate() { + let ident = Ident::new(&field.ident, Span::call_site()); + let is_last = i == size - 1; + + let str_types: Vec<&str> = field.types.iter().map(|s| s.as_str()).collect(); + let res = match str_types.as_slice() { + ["Option", "Vec", _] => { + if !within_option { + within_option = true; + quote!(.#ident.as_ref().and_then(|x| x.first())) + } else { + quote!(.and_then(|x| x.#ident.as_ref()).and_then(|x| x.first())) + } + }, + ["Option", "BTreeMap", _, _] => { + if !within_option { + within_option = true; + quote!(.#ident.as_ref().and_then(|x| x.get("en").copied())) + } else { + quote!(.and_then(|x| x.#ident.as_ref()).and_then(|x| x.get("en").copied())) + } + }, + ["Option", _] => { + if !within_option { + within_option = true; + if is_last { + quote!(.#ident) + } else { + quote!(.#ident.as_ref()) + } + } else { + if is_last { + quote!(.and_then(|x| x.#ident)) + } else { + quote!(.and_then(|x| x.#ident.as_ref())) + } + } + }, + &[] | &[..] => todo!(), + }; + + tokens.push(res); + } + let tokens = proc_macro2::TokenStream::from_iter(tokens); + + let path = fields_to_path(fields); + let res = quote!( + #path => self #tokens.into(), + ); + + proc_macro2::TokenStream::from(res) +} + +struct GeneratorState<'a> { + current_struct: &'a Struct, + fields: Vec<&'a Field> +} + +fn generate_getters(target: String, structs: &HashMap) -> (Vec, Vec) { + let mut paths = Vec::new(); + let mut getters = Vec::new(); + + let mut to_visit = VecDeque::new(); + to_visit.push_back(GeneratorState { + current_struct: structs.get(&target).unwrap(), + fields: Vec::new() + }); + + while let Some(state) = to_visit.pop_front() { + for field in state.current_struct.fields.iter() { + let ty = field.types.last().unwrap(); + let mut fields = state.fields.clone(); + fields.push(field); + + if structs.contains_key(ty) { + to_visit.push_front(GeneratorState { + current_struct: structs.get(ty).unwrap(), + fields + }); + } else { + paths.push(fields_to_path(&fields)); + getters.push(fields_to_getter(&fields)); + } + } + } + + (paths, getters) +} + +fn generate_impl(target: String, structs: &HashMap) -> TokenStream { + let target_struct = structs.get(&target).unwrap(); + + let trgt = if let Some(rpl) = &target_struct.replacement_type { + Ident::new(rpl, Span::mixed_site()) + } else { + Ident::new(&target_struct.ident, Span::mixed_site()) + }; + + let res_type = Ident::new(&target_struct.return_type.as_ref().unwrap(), Span::mixed_site()); + + let (paths, getters) = generate_getters(target, structs); + let getters = proc_macro2::TokenStream::from_iter(getters); + let res = quote!( + impl<'a> struct_deep_getter::StructDeepGetter<#res_type> for #trgt<'a> { + fn deeper_structs() -> Vec { + let mut res = Vec::new(); + #(res.push(#paths.to_string());)* + res + } + + fn get_path(&self, path: &str) -> #res_type { + match path { + #getters + _ => panic!("error"), + } + } + } + ); + println!("{}", res); + TokenStream::from(res) +} + +#[proc_macro] +pub fn make_paths(input: TokenStream) -> TokenStream { + let ast: File = syn::parse(input).unwrap(); + let mut state = StructVisitor::default(); + state.visit_file(&ast); + + let mut impls = Vec::new(); + for (ident, strct) in state.structs.iter() { + println!("{:?}", strct); + if strct.generate { + impls.push(generate_impl(ident.clone(), &state.structs)); + } + } + + TokenStream::from_iter(impls) +} diff --git a/struct_deep_getter_derive/tests/tests.rs b/struct_deep_getter_derive/tests/tests.rs new file mode 100644 index 0000000..3a47d7c --- /dev/null +++ b/struct_deep_getter_derive/tests/tests.rs @@ -0,0 +1,169 @@ +#![allow(dead_code)] + +use std::collections::BTreeMap; +use struct_deep_getter::StructDeepGetter; +use struct_deep_getter_derive::make_paths; + +struct SuperType { + value: String +} + +impl From> for SuperType { + fn from(s: Option<&str>) -> Self { + SuperType { value: s.unwrap_or("None").to_owned() } + } +} + +impl From> for SuperType { + fn from(s: Option) -> Self { + SuperType { value: format!("{}", s.unwrap_or(0)) } + } +} + +impl From> for SuperType { + fn from(s: Option) -> Self { + SuperType { value: format!("{}", s.unwrap_or(0)) } + } +} + +impl From> for SuperType { + fn from(s: Option) -> Self { + SuperType { value: format!("{}", s.unwrap_or(0.0)) } + } +} + +impl From> for SuperType { + fn from(s: Option) -> Self { + SuperType { value: format!("{}", s.unwrap_or(false)) } + } +} + +pub struct MaxmindCity<'a> { + pub city: Option>, + pub continent: Option>, + pub country: Option>, + pub location: Option>, + pub postal: Option>, + pub registered_country: Option>, + pub represented_country: Option>, + pub subdivisions: Option>>, + pub traits: Option, +} + +pub struct City2<'a> { + pub geoname_id: Option, + pub names: Option>, +} + +pub struct Location<'a> { + pub accuracy_radius: Option, + pub latitude: Option, + pub longitude: Option, + pub metro_code: Option, + pub time_zone: Option<&'a str>, +} + +pub struct Postal<'a> { + pub code: Option<&'a str>, +} + +pub struct Subdivision<'a> { + pub geoname_id: Option, + pub iso_code: Option<&'a str>, + pub names: Option>, +} + +pub struct Continent<'a> { + pub code: Option<&'a str>, + pub geoname_id: Option, + pub names: Option>, +} + +pub struct Country<'a> { + pub geoname_id: Option, + pub is_in_european_union: Option, + pub iso_code: Option<&'a str>, + pub names: Option>, +} + +pub struct RepresentedCountry<'a> { + pub geoname_id: Option, + pub is_in_european_union: Option, + pub iso_code: Option<&'a str>, + pub names: Option>, + pub representation_type: Option<&'a str>, +} + +pub struct Traits { + pub is_anonymous_proxy: Option, + pub is_satellite_provider: Option, +} + +make_paths!( + #[struct_deep_getter(return_type = "SuperType", replacement_type = "MaxmindCity")] + pub struct City<'a> { + pub city: Option>, + pub continent: Option>, + pub country: Option>, + pub location: Option>, + pub postal: Option>, + pub registered_country: Option>, + pub represented_country: Option>, + pub subdivisions: Option>>, + pub traits: Option, + } + + pub struct City2<'a> { + pub geoname_id: Option, + pub names: Option>, + } + + pub struct Location<'a> { + pub accuracy_radius: Option, + pub latitude: Option, + pub longitude: Option, + pub metro_code: Option, + pub time_zone: Option<&'a str>, + } + + pub struct Postal<'a> { + pub code: Option<&'a str>, + } + + pub struct Subdivision<'a> { + pub geoname_id: Option, + pub iso_code: Option<&'a str>, + pub names: Option>, + } + + pub struct Continent<'a> { + pub code: Option<&'a str>, + pub geoname_id: Option, + pub names: Option>, + } + + pub struct Country<'a> { + pub geoname_id: Option, + pub is_in_european_union: Option, + pub iso_code: Option<&'a str>, + pub names: Option>, + } + + pub struct RepresentedCountry<'a> { + pub geoname_id: Option, + pub is_in_european_union: Option, + pub iso_code: Option<&'a str>, + pub names: Option>, + pub representation_type: Option<&'a str>, + } + + pub struct Traits { + pub is_anonymous_proxy: Option, + pub is_satellite_provider: Option, + } +); + +#[test] +fn test_paths() { + assert_eq!(MaxmindCity::deeper_structs(), vec!["lol"]) +}