Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions gel/_internal/_codegen/_models/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6170,6 +6170,45 @@ def resolve(
f"# type: ignore [assignment, misc, unused-ignore]"
)

if function.schemapath in {
SchemaPath('std', 'UNION'),
SchemaPath('std', 'IF'),
SchemaPath('std', '??'),
}:
# Special case for the UNION, IF and ?? operators
# Produce a union type instead of just taking the first
# valid type.
#
# See gel: edb.compiler.func.compile_operator
create_union = self.import_name(
BASE_IMPL, "create_optional_union"
)

tvars: list[str] = []
for param, path in sources:
if (
param.name in required_generic_params
or param.name in optional_generic_params
):
pn = param_vars[param.name]
tvar = f"__t_{pn}__"

resolve(pn, path, tvar)
tvars.append(tvar)

self.write(
f"{gtvar} = {tvars[0]} "
f"# type: ignore [assignment, misc, unused-ignore]"
)
for tvar in tvars[1:]:
self.write(
f"{gtvar} = {create_union}({gtvar}, {tvar}) "
f"# type: ignore ["
f"assignment, misc, unused-ignore]"
)

continue

# Try to infer generic type from required params first
for param, path in sources:
if param.name in required_generic_params:
Expand Down
6 changes: 6 additions & 0 deletions gel/_internal/_qbmodel/_abstract/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@
from ._methods import (
BaseGelModel,
BaseGelModelIntersection,
BaseGelModelUnion,
create_optional_union,
create_union,
)


Expand Down Expand Up @@ -138,6 +141,7 @@
"ArrayMeta",
"BaseGelModel",
"BaseGelModelIntersection",
"BaseGelModelUnion",
"ComputedLinkSet",
"ComputedLinkWithPropsSet",
"ComputedMultiLinkDescriptor",
Expand Down Expand Up @@ -181,6 +185,8 @@
"TupleMeta",
"UUIDImpl",
"copy_or_ref_lprops",
"create_optional_union",
"create_union",
"empty_set_if_none",
"field_descriptor",
"get_base_scalars_backed_by_py_type",
Expand Down
139 changes: 127 additions & 12 deletions gel/_internal/_qbmodel/_abstract/_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@

from gel._internal import _qb
from gel._internal._schemapath import (
TypeNameIntersection,
TypeNameExpr,
TypeNameIntersection,
TypeNameUnion,
)
from gel._internal import _type_expression
from gel._internal._xmethod import classonlymethod
Expand Down Expand Up @@ -270,6 +271,17 @@ class BaseGelModelIntersectionBacklinks(
rhs: ClassVar[type[AbstractGelObjectBacklinksModel]]


class BaseGelModelUnion(
BaseGelModel,
_type_expression.Union,
Generic[_T_Lhs, _T_Rhs],
):
__gel_type_class__: ClassVar[type]

lhs: ClassVar[type[AbstractGelModel]]
rhs: ClassVar[type[AbstractGelModel]]


T = TypeVar('T')
U = TypeVar('U')

Expand Down Expand Up @@ -318,6 +330,17 @@ def combine_dicts(
return result


def _order_base_types(lhs: type, rhs: type) -> tuple[type, ...]:
if lhs == rhs:
return (lhs,)
elif issubclass(lhs, rhs):
return (lhs, rhs)
elif issubclass(rhs, lhs):
return (rhs, lhs)
else:
return (lhs, rhs)


_type_intersection_cache: weakref.WeakKeyDictionary[
type[AbstractGelModel],
weakref.WeakKeyDictionary[
Expand Down Expand Up @@ -430,17 +453,6 @@ def object(
return result


def _order_base_types(lhs: type, rhs: type) -> tuple[type, ...]:
if lhs == rhs:
return (lhs,)
elif issubclass(lhs, rhs):
return (lhs, rhs)
elif issubclass(rhs, lhs):
return (rhs, lhs)
else:
return (lhs, rhs)


def create_intersection_backlinks(
lhs_backlinks: type[AbstractGelObjectBacklinksModel],
rhs_backlinks: type[AbstractGelObjectBacklinksModel],
Expand Down Expand Up @@ -500,3 +512,106 @@ def create_intersection_backlinks(
)

return backlinks


_type_union_cache: weakref.WeakKeyDictionary[
type[AbstractGelModel],
weakref.WeakKeyDictionary[
type[AbstractGelModel],
type[BaseGelModelUnion[AbstractGelModel, AbstractGelModel]],
],
] = weakref.WeakKeyDictionary()


def create_optional_union(
lhs: type[_T_Lhs] | None,
rhs: type[_T_Rhs] | None,
) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs] | AbstractGelModel] | None:
if lhs is None:
return rhs
elif rhs is None:
return lhs
else:
return create_union(lhs, rhs)


def create_union(
lhs: type[_T_Lhs],
rhs: type[_T_Rhs],
) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs]]:
"""Create a runtime union type which acts like a GelModel."""

if (lhs_entry := _type_union_cache.get(lhs)) and (
rhs_entry := lhs_entry.get(rhs)
):
return rhs_entry # type: ignore[return-value]

# Combine pointer reflections from args
ptr_reflections: dict[str, _qb.GelPointerReflection] = {
p_name: p_refl
for p_name, p_refl in lhs.__gel_reflection__.pointers.items()
if p_name in rhs.__gel_reflection__.pointers
}

# Create type reflection for union type
class __gel_reflection__(_qb.GelObjectTypeExprMetadata.__gel_reflection__): # noqa: N801
expr_object_types: set[type[AbstractGelModel]] = getattr(
lhs.__gel_reflection__, 'expr_object_types', {lhs}
) | getattr(rhs.__gel_reflection__, 'expr_object_types', {rhs})

type_name = TypeNameUnion(
args=(
lhs.__gel_reflection__.type_name,
rhs.__gel_reflection__.type_name,
)
)

pointers = ptr_reflections

@classmethod
def object(
cls,
) -> Any:
raise NotImplementedError(
"Type expressions schema objects are inaccessible"
)

# Create the resulting union type
result = type(
f"({lhs.__name__} | {rhs.__name__})",
(BaseGelModelUnion,),
{
'lhs': lhs,
'rhs': rhs,
'__gel_reflection__': __gel_reflection__,
"__gel_proxied_dunders__": frozenset(
{
"__backlinks__",
}
),
},
)

# Generate field descriptors.
descriptors: dict[str, ModelFieldDescriptor] = {
p_name: field_descriptor(result, p_name, l_path_alias.__gel_origin__)
for p_name, p_refl in lhs.__gel_reflection__.pointers.items()
if (
hasattr(lhs, p_name)
and (l_path_alias := getattr(lhs, p_name, None)) is not None
and isinstance(l_path_alias, _qb.PathAlias)
)
if (
hasattr(rhs, p_name)
and (r_path_alias := getattr(rhs, p_name, None)) is not None
and isinstance(r_path_alias, _qb.PathAlias)
)
}
for p_name, descriptor in descriptors.items():
setattr(result, p_name, descriptor)

if lhs not in _type_union_cache:
_type_union_cache[lhs] = weakref.WeakKeyDictionary()
_type_union_cache[lhs][rhs] = result

return result
2 changes: 2 additions & 0 deletions gel/_internal/_typing_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def _issubclass(lhs: Any, tp: Any, fn: Any) -> bool:

if issubclass(lhs, _type_expression.Intersection):
return any(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs))
elif issubclass(lhs, _type_expression.Union):
return all(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs))

if _typing_inspect.is_generic_alias(tp):
origin = typing.get_origin(tp)
Expand Down
4 changes: 4 additions & 0 deletions gel/models/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@
PyTypeScalarConstraint,
RangeMeta,
UUIDImpl,
create_optional_union,
create_union,
empty_set_if_none,
)

Expand Down Expand Up @@ -215,6 +217,8 @@
"classonlymethod",
"computed_field",
"construct_infix_op_chain",
"create_optional_union",
"create_union",
"dispatch_overload",
"empty_set_if_none",
)
115 changes: 115 additions & 0 deletions tests/test_qb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2600,6 +2600,121 @@ def test_qb_backlinks_error_01(self):
with self.assertRaisesRegex(ValueError, "unsupported query type"):
self.client.query(query)

def test_qb_backlinks_error_02(self):
# Unions don't have backlinks
from models.orm_qb import default, std

with self.assertRaisesRegex(
AttributeError, "has no attribute '__backlinks__'"
):
std.union(default.Inh_ABC, default.Inh_AB_AC).__backlinks__

def test_qb_std_coalesce_scalar_01(self):
from models.orm_qb import std

query = std.coalesce(1, 2)
result = self.client.query(query)

self.assertEqual(result, [1])

def test_qb_std_coalesce_scalar_02(self):
from models.orm_qb import std

query = std.coalesce(None, 2)
result = self.client.query(query)

self.assertEqual(result, [2])

def test_qb_std_coalesce_object_01(self):
from models.orm_qb import default, std

inh_a_objs = {
obj.a: obj
for obj in self.client.query(default.Inh_A.select(a=True))
}

query = std.coalesce(default.Inh_AB_AC, default.Inh_ABC)
result = self.client.query(query)
self._assertListEqualUnordered([inh_a_objs[17]], result)

def test_qb_std_coalesce_object_02(self):
from models.orm_qb import default, std

inh_a_objs = {
obj.a: obj
for obj in self.client.query(default.Inh_A.select(a=True))
}

query = std.coalesce(None, default.Inh_ABC)
result = self.client.query(query)

self._assertListEqualUnordered([inh_a_objs[13]], result)

def test_qb_std_coalesce_object_03(self):
from models.orm_qb import default, std

inh_a_objs = {
obj.a: obj
for obj in self.client.query(default.Inh_A.select(a=True))
}

query = std.coalesce(
default.Inh_AB.is_(default.Inh_AC), default.Inh_ABC
)
result = self.client.query(query)
self._assertListEqualUnordered([inh_a_objs[17]], result)

def test_qb_std_union_scalar_01(self):
from models.orm_qb import std

query = std.union(1, 2)
result = self.client.query(query)
self._assertListEqualUnordered(result, [1, 2])

def test_qb_std_union_scalar_02(self):
from models.orm_qb import std

query = std.union(1, [2, 3])
result = self.client.query(query)
self._assertListEqualUnordered(result, [1, 2, 3])

def test_qb_std_union_scalar_03(self):
from models.orm_qb import std

query = std.union([1, 2], [2, 3])
result = self.client.query(query)
self._assertListEqualUnordered(result, [1, 2, 2, 3])

def test_qb_std_union_object_01(self):
from models.orm_qb import default, std

inh_a_objs = {
obj.a: obj
for obj in self.client.query(default.Inh_A.select(a=True))
}

query = std.union(default.Inh_ABC, default.Inh_AB_AC).select('*')
result = self.client.query(query)
self._assertListEqualUnordered(
[inh_a_objs[13], inh_a_objs[17]], result
)

def test_qb_std_union_object_02(self):
from models.orm_qb import default, std

inh_a_objs = {
obj.a: obj
for obj in self.client.query(default.Inh_A.select(a=True))
}

query = std.union(
default.Inh_ABC, default.Inh_AB.is_(default.Inh_AC)
).select('*')
result = self.client.query(query)
self._assertListEqualUnordered(
[inh_a_objs[13], inh_a_objs[17]], result
)


class TestQueryBuilderModify(tb.ModelTestCase):
"""This test suite is for data manipulation using QB."""
Expand Down