Skip to content

Commit 2e932c6

Browse files
authored
add type inference for serializing ip address types (#1868)
1 parent 4c9bec8 commit 2e932c6

File tree

3 files changed

+109
-16
lines changed

3 files changed

+109
-16
lines changed

src/serializers/infer.rs

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ use crate::serializers::type_serializers;
1919
use crate::serializers::type_serializers::format::serialize_via_str;
2020
use crate::serializers::SerializationState;
2121
use crate::tools::{extract_int, py_err, safe_repr};
22-
use crate::url::{PyMultiHostUrl, PyUrl};
2322

2423
use super::config::InfNanMode;
2524
use super::errors::SERIALIZATION_ERR_MARKER;
@@ -168,7 +167,13 @@ pub(crate) fn infer_to_python_known<'py>(
168167
let either_delta = EitherTimedelta::try_from(value)?;
169168
state.config.temporal_mode.timedelta_to_json(value.py(), either_delta)?
170169
}
171-
ObType::Url | ObType::MultiHostUrl | ObType::Path => serialize_via_str(value, serialize_to_python())?,
170+
ObType::Url
171+
| ObType::MultiHostUrl
172+
| ObType::Path
173+
| ObType::Ipv4Address
174+
| ObType::Ipv6Address
175+
| ObType::Ipv4Network
176+
| ObType::Ipv6Network => serialize_via_str(value, serialize_to_python())?,
172177
ObType::Uuid => {
173178
let uuid = super::type_serializers::uuid::uuid_to_string(value)?;
174179
uuid.into_py_any(py)?
@@ -413,9 +418,13 @@ pub(crate) fn infer_serialize_known<'py, S: Serializer>(
413418
let either_delta = EitherTimedelta::try_from(value).map_err(py_err_se_err)?;
414419
state.config.temporal_mode.timedelta_serialize(either_delta, serializer)
415420
}
416-
ObType::Url | ObType::MultiHostUrl | ObType::Path => {
417-
serialize_via_str(value, serialize_to_json(serializer)).map_err(unwrap_ser_error)
418-
}
421+
ObType::Url
422+
| ObType::MultiHostUrl
423+
| ObType::Path
424+
| ObType::Ipv4Address
425+
| ObType::Ipv6Address
426+
| ObType::Ipv4Network
427+
| ObType::Ipv6Network => serialize_via_str(value, serialize_to_json(serializer)).map_err(unwrap_ser_error),
419428
ObType::PydanticSerializable => {
420429
call_pydantic_serializer(value, state, serialize_to_json(serializer)).map_err(unwrap_ser_error)
421430
}
@@ -546,13 +555,15 @@ pub(crate) fn infer_json_key_known<'a, 'py>(
546555
let either_delta = EitherTimedelta::try_from(key)?;
547556
state.config.temporal_mode.timedelta_json_key(&either_delta)
548557
}
549-
ObType::Url => {
550-
let py_url: PyUrl = key.extract()?;
551-
Ok(Cow::Owned(py_url.__str__(key.py()).to_string()))
552-
}
553-
ObType::MultiHostUrl => {
554-
let py_url: PyMultiHostUrl = key.extract()?;
555-
Ok(Cow::Owned(py_url.__str__(key.py())))
558+
ObType::Url
559+
| ObType::MultiHostUrl
560+
| ObType::Path
561+
| ObType::Ipv4Address
562+
| ObType::Ipv6Address
563+
| ObType::Ipv4Network
564+
| ObType::Ipv6Network => {
565+
// FIXME it would be nice to have a "PyCow" which carries ownership of the Python type too
566+
Ok(Cow::Owned(key.str()?.to_string_lossy().into_owned()))
556567
}
557568
ObType::Tuple => {
558569
let mut key_build = super::type_serializers::tuple::KeyBuilder::new();
@@ -574,10 +585,6 @@ pub(crate) fn infer_json_key_known<'a, 'py>(
574585
let k = key.getattr(intern!(key.py(), "value"))?;
575586
infer_json_key(&k, state).map(|cow| Cow::Owned(cow.into_owned()))
576587
}
577-
ObType::Path => {
578-
// FIXME it would be nice to have a "PyCow" which carries ownership of the Python type too
579-
Ok(Cow::Owned(key.str()?.to_string_lossy().into_owned()))
580-
}
581588
ObType::Complex => {
582589
let v = key.downcast::<PyComplex>()?;
583590
Ok(type_serializers::complex::complex_to_str(v).into())

src/serializers/ob_type.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ pub struct ObTypeLookup {
5050
uuid_object: Py<PyAny>,
5151
// `complex` builtin
5252
complex: usize,
53+
// ip address types
54+
ipv4_address: Py<PyAny>,
55+
ipv6_address: Py<PyAny>,
56+
ipv4_network: Py<PyAny>,
57+
ipv6_network: Py<PyAny>,
5358
}
5459

5560
static TYPE_LOOKUP: PyOnceLock<ObTypeLookup> = PyOnceLock::new();
@@ -89,6 +94,10 @@ impl ObTypeLookup {
8994
pattern_object: py.import("re").unwrap().getattr("Pattern").unwrap().unbind(),
9095
uuid_object: py.import("uuid").unwrap().getattr("UUID").unwrap().unbind(),
9196
complex: PyComplex::type_object_raw(py) as usize,
97+
ipv4_address: py.import("ipaddress").unwrap().getattr("IPv4Address").unwrap().unbind(),
98+
ipv6_address: py.import("ipaddress").unwrap().getattr("IPv6Address").unwrap().unbind(),
99+
ipv4_network: py.import("ipaddress").unwrap().getattr("IPv4Network").unwrap().unbind(),
100+
ipv6_network: py.import("ipaddress").unwrap().getattr("IPv6Network").unwrap().unbind(),
92101
}
93102
}
94103

@@ -159,6 +168,10 @@ impl ObTypeLookup {
159168
ObType::Pattern => self.pattern_object.as_ptr() as usize == ob_type,
160169
ObType::Uuid => self.uuid_object.as_ptr() as usize == ob_type,
161170
ObType::Complex => self.complex == ob_type,
171+
ObType::Ipv4Address => self.ipv4_address.as_ptr() as usize == ob_type,
172+
ObType::Ipv6Address => self.ipv6_address.as_ptr() as usize == ob_type,
173+
ObType::Ipv4Network => self.ipv4_network.as_ptr() as usize == ob_type,
174+
ObType::Ipv6Network => self.ipv6_network.as_ptr() as usize == ob_type,
162175
ObType::Unknown => false,
163176
};
164177

@@ -254,6 +267,10 @@ impl ObTypeLookup {
254267
ObType::Path
255268
} else if ob_type == self.pattern_object.as_ptr() as usize {
256269
ObType::Pattern
270+
} else if ob_type == self.ipv4_address.as_ptr() as usize {
271+
ObType::Ipv4Address
272+
} else if ob_type == self.ipv6_address.as_ptr() as usize {
273+
ObType::Ipv6Address
257274
} else {
258275
// this allows for subtypes of the supported class types,
259276
// if `ob_type` didn't match any member of self, we try again with the next base type pointer
@@ -334,6 +351,16 @@ impl ObTypeLookup {
334351
ObType::Path
335352
} else if value.is_instance(self.pattern_object.bind(py)).unwrap_or(false) {
336353
ObType::Pattern
354+
} else if value.is_instance_of::<PyComplex>() {
355+
ObType::Complex
356+
} else if value.is_instance(self.ipv4_address.bind(py)).unwrap_or(false) {
357+
ObType::Ipv4Address
358+
} else if value.is_instance(self.ipv6_address.bind(py)).unwrap_or(false) {
359+
ObType::Ipv6Address
360+
} else if value.is_instance(self.ipv4_network.bind(py)).unwrap_or(false) {
361+
ObType::Ipv4Network
362+
} else if value.is_instance(self.ipv6_network.bind(py)).unwrap_or(false) {
363+
ObType::Ipv6Network
337364
} else {
338365
ObType::Unknown
339366
}
@@ -417,6 +444,11 @@ pub enum ObType {
417444
Uuid,
418445
// complex builtin
419446
Complex,
447+
// ip address types
448+
Ipv4Address,
449+
Ipv6Address,
450+
Ipv4Network,
451+
Ipv6Network,
420452
// unknown type
421453
Unknown,
422454
}

tests/serializers/test_any.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import dataclasses
2+
import ipaddress
23
import json
34
import platform
45
import re
@@ -722,3 +723,56 @@ class MyEnum(Enum):
722723
assert v.to_json({MyEnum.A: 'x'}) == b'{"1":"x"}'
723724
assert v.to_python(1) == 1
724725
assert v.to_json(1) == b'1'
726+
727+
728+
class SubIpV4(ipaddress.IPv4Address):
729+
def __str__(self):
730+
return super().__str__() + '_subclassed'
731+
732+
733+
class SubIpV6(ipaddress.IPv6Address):
734+
def __str__(self):
735+
return super().__str__() + '_subclassed'
736+
737+
738+
class SubNetV4(ipaddress.IPv4Network):
739+
def __str__(self):
740+
return super().__str__() + '_subclassed'
741+
742+
743+
class SubNetV6(ipaddress.IPv6Network):
744+
def __str__(self):
745+
return super().__str__() + '_subclassed'
746+
747+
748+
class SubInterfaceV4(ipaddress.IPv4Interface):
749+
def __str__(self):
750+
return super().__str__() + '_subclassed'
751+
752+
753+
class SubInterfaceV6(ipaddress.IPv6Interface):
754+
def __str__(self):
755+
return super().__str__() + '_subclassed'
756+
757+
758+
@pytest.mark.parametrize(
759+
('value', 'expected_json'),
760+
[
761+
(ipaddress.IPv4Address('192.168.1.1'), '192.168.1.1'),
762+
(ipaddress.IPv6Address('2001:0db8:85a3:0000:0000:8a2e:0370:7334'), '2001:db8:85a3::8a2e:370:7334'),
763+
(SubIpV4('192.168.1.1'), '192.168.1.1_subclassed'),
764+
(SubIpV6('2001:0db8:85a3:0000:0000:8a2e:0370:7334'), '2001:db8:85a3::8a2e:370:7334_subclassed'),
765+
(ipaddress.IPv4Network('192.168.1.0/24'), '192.168.1.0/24'),
766+
(ipaddress.IPv6Network('2001:0db8:85a3:0000:0000:8a2e:0370:7334'), '2001:db8:85a3::8a2e:370:7334/128'),
767+
(SubNetV4('192.168.1.0/24'), '192.168.1.0/24_subclassed'),
768+
(SubNetV6('2001:0db8:85a3:0000:0000:8a2e:0370:7334'), '2001:db8:85a3::8a2e:370:7334/128_subclassed'),
769+
(ipaddress.IPv4Interface('192.168.1.1/24'), '192.168.1.1/24'),
770+
(ipaddress.IPv6Interface('2001:0db8:85a3:0000:0000:8a2e:0370:7334'), '2001:db8:85a3::8a2e:370:7334/128'),
771+
(SubInterfaceV4('192.168.1.1/24'), '192.168.1.1/24_subclassed'),
772+
(SubInterfaceV6('2001:0db8:85a3:0000:0000:8a2e:0370:7334'), '2001:db8:85a3::8a2e:370:7334/128_subclassed'),
773+
],
774+
)
775+
def test_ipaddress_type_inference(any_serializer, value, expected_json):
776+
assert any_serializer.to_python(value) == value
777+
assert any_serializer.to_python(value, mode='json') == expected_json
778+
assert any_serializer.to_json(value) == f'"{expected_json}"'.encode()

0 commit comments

Comments
 (0)