diff --git a/gel/_internal/_codegen/_models/_pydantic.py b/gel/_internal/_codegen/_models/_pydantic.py index 58a69891..b61f4dc2 100644 --- a/gel/_internal/_codegen/_models/_pydantic.py +++ b/gel/_internal/_codegen/_models/_pydantic.py @@ -6170,6 +6170,45 @@ def resolve( f"# type: ignore [assignment, misc, unused-ignore]" ) + if function.schemapath in { + SchemaPath('std', 'UNION'), + SchemaPath('std', 'IF'), + SchemaPath('std', '??'), + }: + # Special case for the UNION, IF and ?? operators + # Produce a union type instead of just taking the first + # valid type. + # + # See gel: edb.compiler.func.compile_operator + create_union = self.import_name( + BASE_IMPL, "create_optional_union" + ) + + tvars: list[str] = [] + for param, path in sources: + if ( + param.name in required_generic_params + or param.name in optional_generic_params + ): + pn = param_vars[param.name] + tvar = f"__t_{pn}__" + + resolve(pn, path, tvar) + tvars.append(tvar) + + self.write( + f"{gtvar} = {tvars[0]} " + f"# type: ignore [assignment, misc, unused-ignore]" + ) + for tvar in tvars[1:]: + self.write( + f"{gtvar} = {create_union}({gtvar}, {tvar}) " + f"# type: ignore [" + f"assignment, misc, unused-ignore]" + ) + + continue + # Try to infer generic type from required params first for param, path in sources: if param.name in required_generic_params: diff --git a/gel/_internal/_qbmodel/_abstract/__init__.py b/gel/_internal/_qbmodel/_abstract/__init__.py index 41fb17a8..4500a0f1 100644 --- a/gel/_internal/_qbmodel/_abstract/__init__.py +++ b/gel/_internal/_qbmodel/_abstract/__init__.py @@ -68,6 +68,9 @@ from ._methods import ( BaseGelModel, BaseGelModelIntersection, + BaseGelModelUnion, + create_optional_union, + create_union, ) @@ -138,6 +141,7 @@ "ArrayMeta", "BaseGelModel", "BaseGelModelIntersection", + "BaseGelModelUnion", "ComputedLinkSet", "ComputedLinkWithPropsSet", "ComputedMultiLinkDescriptor", @@ -181,6 +185,8 @@ "TupleMeta", "UUIDImpl", "copy_or_ref_lprops", + "create_optional_union", + "create_union", "empty_set_if_none", "field_descriptor", "get_base_scalars_backed_by_py_type", diff --git a/gel/_internal/_qbmodel/_abstract/_methods.py b/gel/_internal/_qbmodel/_abstract/_methods.py index bcba7d83..fe413936 100644 --- a/gel/_internal/_qbmodel/_abstract/_methods.py +++ b/gel/_internal/_qbmodel/_abstract/_methods.py @@ -18,8 +18,9 @@ from gel._internal import _qb from gel._internal._schemapath import ( - TypeNameIntersection, TypeNameExpr, + TypeNameIntersection, + TypeNameUnion, ) from gel._internal import _type_expression from gel._internal._xmethod import classonlymethod @@ -270,6 +271,17 @@ class BaseGelModelIntersectionBacklinks( rhs: ClassVar[type[AbstractGelObjectBacklinksModel]] +class BaseGelModelUnion( + BaseGelModel, + _type_expression.Union, + Generic[_T_Lhs, _T_Rhs], +): + __gel_type_class__: ClassVar[type] + + lhs: ClassVar[type[AbstractGelModel]] + rhs: ClassVar[type[AbstractGelModel]] + + T = TypeVar('T') U = TypeVar('U') @@ -318,6 +330,17 @@ def combine_dicts( return result +def _order_base_types(lhs: type, rhs: type) -> tuple[type, ...]: + if lhs == rhs: + return (lhs,) + elif issubclass(lhs, rhs): + return (lhs, rhs) + elif issubclass(rhs, lhs): + return (rhs, lhs) + else: + return (lhs, rhs) + + _type_intersection_cache: weakref.WeakKeyDictionary[ type[AbstractGelModel], weakref.WeakKeyDictionary[ @@ -430,17 +453,6 @@ def object( return result -def _order_base_types(lhs: type, rhs: type) -> tuple[type, ...]: - if lhs == rhs: - return (lhs,) - elif issubclass(lhs, rhs): - return (lhs, rhs) - elif issubclass(rhs, lhs): - return (rhs, lhs) - else: - return (lhs, rhs) - - def create_intersection_backlinks( lhs_backlinks: type[AbstractGelObjectBacklinksModel], rhs_backlinks: type[AbstractGelObjectBacklinksModel], @@ -500,3 +512,106 @@ def create_intersection_backlinks( ) return backlinks + + +_type_union_cache: weakref.WeakKeyDictionary[ + type[AbstractGelModel], + weakref.WeakKeyDictionary[ + type[AbstractGelModel], + type[BaseGelModelUnion[AbstractGelModel, AbstractGelModel]], + ], +] = weakref.WeakKeyDictionary() + + +def create_optional_union( + lhs: type[_T_Lhs] | None, + rhs: type[_T_Rhs] | None, +) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs] | AbstractGelModel] | None: + if lhs is None: + return rhs + elif rhs is None: + return lhs + else: + return create_union(lhs, rhs) + + +def create_union( + lhs: type[_T_Lhs], + rhs: type[_T_Rhs], +) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs]]: + """Create a runtime union type which acts like a GelModel.""" + + if (lhs_entry := _type_union_cache.get(lhs)) and ( + rhs_entry := lhs_entry.get(rhs) + ): + return rhs_entry # type: ignore[return-value] + + # Combine pointer reflections from args + ptr_reflections: dict[str, _qb.GelPointerReflection] = { + p_name: p_refl + for p_name, p_refl in lhs.__gel_reflection__.pointers.items() + if p_name in rhs.__gel_reflection__.pointers + } + + # Create type reflection for union type + class __gel_reflection__(_qb.GelObjectTypeExprMetadata.__gel_reflection__): # noqa: N801 + expr_object_types: set[type[AbstractGelModel]] = getattr( + lhs.__gel_reflection__, 'expr_object_types', {lhs} + ) | getattr(rhs.__gel_reflection__, 'expr_object_types', {rhs}) + + type_name = TypeNameUnion( + args=( + lhs.__gel_reflection__.type_name, + rhs.__gel_reflection__.type_name, + ) + ) + + pointers = ptr_reflections + + @classmethod + def object( + cls, + ) -> Any: + raise NotImplementedError( + "Type expressions schema objects are inaccessible" + ) + + # Create the resulting union type + result = type( + f"({lhs.__name__} | {rhs.__name__})", + (BaseGelModelUnion,), + { + 'lhs': lhs, + 'rhs': rhs, + '__gel_reflection__': __gel_reflection__, + "__gel_proxied_dunders__": frozenset( + { + "__backlinks__", + } + ), + }, + ) + + # Generate field descriptors. + descriptors: dict[str, ModelFieldDescriptor] = { + p_name: field_descriptor(result, p_name, l_path_alias.__gel_origin__) + for p_name, p_refl in lhs.__gel_reflection__.pointers.items() + if ( + hasattr(lhs, p_name) + and (l_path_alias := getattr(lhs, p_name, None)) is not None + and isinstance(l_path_alias, _qb.PathAlias) + ) + if ( + hasattr(rhs, p_name) + and (r_path_alias := getattr(rhs, p_name, None)) is not None + and isinstance(r_path_alias, _qb.PathAlias) + ) + } + for p_name, descriptor in descriptors.items(): + setattr(result, p_name, descriptor) + + if lhs not in _type_union_cache: + _type_union_cache[lhs] = weakref.WeakKeyDictionary() + _type_union_cache[lhs][rhs] = result + + return result diff --git a/gel/_internal/_typing_dispatch.py b/gel/_internal/_typing_dispatch.py index d74bc23f..77ee9d22 100644 --- a/gel/_internal/_typing_dispatch.py +++ b/gel/_internal/_typing_dispatch.py @@ -70,6 +70,8 @@ def _issubclass(lhs: Any, tp: Any, fn: Any) -> bool: if issubclass(lhs, _type_expression.Intersection): return any(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs)) + elif issubclass(lhs, _type_expression.Union): + return all(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs)) if _typing_inspect.is_generic_alias(tp): origin = typing.get_origin(tp) diff --git a/gel/models/pydantic.py b/gel/models/pydantic.py index 2ce9801a..2d7db0e6 100644 --- a/gel/models/pydantic.py +++ b/gel/models/pydantic.py @@ -76,6 +76,8 @@ PyTypeScalarConstraint, RangeMeta, UUIDImpl, + create_optional_union, + create_union, empty_set_if_none, ) @@ -215,6 +217,8 @@ "classonlymethod", "computed_field", "construct_infix_op_chain", + "create_optional_union", + "create_union", "dispatch_overload", "empty_set_if_none", ) diff --git a/tests/test_qb.py b/tests/test_qb.py index 9cd5a0aa..eb25008b 100644 --- a/tests/test_qb.py +++ b/tests/test_qb.py @@ -2600,6 +2600,121 @@ def test_qb_backlinks_error_01(self): with self.assertRaisesRegex(ValueError, "unsupported query type"): self.client.query(query) + def test_qb_backlinks_error_02(self): + # Unions don't have backlinks + from models.orm_qb import default, std + + with self.assertRaisesRegex( + AttributeError, "has no attribute '__backlinks__'" + ): + std.union(default.Inh_ABC, default.Inh_AB_AC).__backlinks__ + + def test_qb_std_coalesce_scalar_01(self): + from models.orm_qb import std + + query = std.coalesce(1, 2) + result = self.client.query(query) + + self.assertEqual(result, [1]) + + def test_qb_std_coalesce_scalar_02(self): + from models.orm_qb import std + + query = std.coalesce(None, 2) + result = self.client.query(query) + + self.assertEqual(result, [2]) + + def test_qb_std_coalesce_object_01(self): + from models.orm_qb import default, std + + inh_a_objs = { + obj.a: obj + for obj in self.client.query(default.Inh_A.select(a=True)) + } + + query = std.coalesce(default.Inh_AB_AC, default.Inh_ABC) + result = self.client.query(query) + self._assertListEqualUnordered([inh_a_objs[17]], result) + + def test_qb_std_coalesce_object_02(self): + from models.orm_qb import default, std + + inh_a_objs = { + obj.a: obj + for obj in self.client.query(default.Inh_A.select(a=True)) + } + + query = std.coalesce(None, default.Inh_ABC) + result = self.client.query(query) + + self._assertListEqualUnordered([inh_a_objs[13]], result) + + def test_qb_std_coalesce_object_03(self): + from models.orm_qb import default, std + + inh_a_objs = { + obj.a: obj + for obj in self.client.query(default.Inh_A.select(a=True)) + } + + query = std.coalesce( + default.Inh_AB.is_(default.Inh_AC), default.Inh_ABC + ) + result = self.client.query(query) + self._assertListEqualUnordered([inh_a_objs[17]], result) + + def test_qb_std_union_scalar_01(self): + from models.orm_qb import std + + query = std.union(1, 2) + result = self.client.query(query) + self._assertListEqualUnordered(result, [1, 2]) + + def test_qb_std_union_scalar_02(self): + from models.orm_qb import std + + query = std.union(1, [2, 3]) + result = self.client.query(query) + self._assertListEqualUnordered(result, [1, 2, 3]) + + def test_qb_std_union_scalar_03(self): + from models.orm_qb import std + + query = std.union([1, 2], [2, 3]) + result = self.client.query(query) + self._assertListEqualUnordered(result, [1, 2, 2, 3]) + + def test_qb_std_union_object_01(self): + from models.orm_qb import default, std + + inh_a_objs = { + obj.a: obj + for obj in self.client.query(default.Inh_A.select(a=True)) + } + + query = std.union(default.Inh_ABC, default.Inh_AB_AC).select('*') + result = self.client.query(query) + self._assertListEqualUnordered( + [inh_a_objs[13], inh_a_objs[17]], result + ) + + def test_qb_std_union_object_02(self): + from models.orm_qb import default, std + + inh_a_objs = { + obj.a: obj + for obj in self.client.query(default.Inh_A.select(a=True)) + } + + query = std.union( + default.Inh_ABC, default.Inh_AB.is_(default.Inh_AC) + ).select('*') + result = self.client.query(query) + self._assertListEqualUnordered( + [inh_a_objs[13], inh_a_objs[17]], result + ) + class TestQueryBuilderModify(tb.ModelTestCase): """This test suite is for data manipulation using QB."""