From d4a1d200051980668a8c25d4d349eb54d167a8c6 Mon Sep 17 00:00:00 2001 From: Victor Petrovykh Date: Fri, 21 Mar 2025 06:49:12 -0400 Subject: [PATCH 1/3] WIP: First pass at adding bulk import and pydantic models. --- gel/codegen/cli.py | 13 +- gel/codegen/models.py | 395 +++++++++++++++++++++++++++++++++ gel/compatibility/pydmodels.py | 229 +++++++++++++++++++ 3 files changed, 636 insertions(+), 1 deletion(-) create mode 100644 gel/codegen/models.py create mode 100644 gel/compatibility/pydmodels.py diff --git a/gel/codegen/cli.py b/gel/codegen/cli.py index e8149ae60..2f1f4fe2a 100644 --- a/gel/codegen/cli.py +++ b/gel/codegen/cli.py @@ -20,7 +20,7 @@ import argparse import sys -from . import generator +from . import generator, models class ColoredArgumentParser(argparse.ArgumentParser): @@ -68,6 +68,13 @@ def error(self, message): default=["async"], help="Choose one or more targets to generate code (default is async)." ) +parser.add_argument( + "--models", + action="store_true", + default=False, + help="Using the schema generate Pydantic models that can be used for " + "bulk inserts.", +) if sys.version_info[:2] >= (3, 9): parser.add_argument( "--skip-pydantic-validation", @@ -94,6 +101,10 @@ def error(self, message): def main(): args = parser.parse_args() + if args.models: + models.Generator(args).run() + return + if not hasattr(args, "skip_pydantic_validation"): args.skip_pydantic_validation = True generator.Generator(args).run() diff --git a/gel/codegen/models.py b/gel/codegen/models.py new file mode 100644 index 000000000..2c4d08eae --- /dev/null +++ b/gel/codegen/models.py @@ -0,0 +1,395 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2025-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +import getpass +import io +import os +import pathlib +import sys +import textwrap +import typing + +from collections import defaultdict +from contextlib import contextmanager +from pydantic import BaseModel + +import gel +from gel import abstract +from gel import describe +from gel.con_utils import find_gel_project_dir +from gel.color import get_color + +from gel.orm.introspection import FilePrinter, get_mod_and_name + + +C = get_color() +SYS_VERSION_INFO = os.getenv("EDGEDB_PYTHON_CODEGEN_PY_VER") +if SYS_VERSION_INFO: + SYS_VERSION_INFO = tuple(map(int, SYS_VERSION_INFO.split(".")))[:2] +else: + SYS_VERSION_INFO = sys.version_info[:2] + +TYPE_MAPPING = { + "std::str": "str", + "std::float32": "float", + "std::float64": "float", + "std::int16": "int", + "std::int32": "int", + "std::int64": "int", + "std::bigint": "int", + "std::bool": "bool", + "std::uuid": "uuid.UUID", + "std::bytes": "bytes", + "std::decimal": "decimal.Decimal", + "std::datetime": "datetime.datetime", + "std::duration": "datetime.timedelta", + "std::json": "str", + "cal::local_date": "datetime.date", + "cal::local_time": "datetime.time", + "cal::local_datetime": "datetime.datetime", + "cal::relative_duration": "gel.RelativeDuration", + "cal::date_duration": "gel.DateDuration", + "cfg::memory": "gel.ConfigMemory", + "ext::pgvector::vector": "array.array", +} + +TYPE_IMPORTS = { + "std::uuid": "uuid", + "std::decimal": "decimal", + "std::datetime": "datetime", + "std::duration": "datetime", + "cal::local_date": "datetime", + "cal::local_time": "datetime", + "cal::local_datetime": "datetime", + "ext::pgvector::vector": "array", +} + +INPUT_TYPE_MAPPING = TYPE_MAPPING.copy() +INPUT_TYPE_MAPPING.update( + { + "ext::pgvector::vector": "typing.Sequence[float]", + } +) + +INPUT_TYPE_IMPORTS = TYPE_IMPORTS.copy() +INPUT_TYPE_IMPORTS.update( + { + "ext::pgvector::vector": "typing", + } +) + + +def print_msg(msg): + print(msg, file=sys.stderr) + + +def print_error(msg): + print_msg(f"{C.BOLD}{C.FAIL}error: {C.ENDC}{C.BOLD}{msg}{C.ENDC}") + + +def _get_conn_args(args: argparse.Namespace): + if args.password_from_stdin: + if args.password: + print_error( + "--password and --password-from-stdin are " + "mutually exclusive", + ) + sys.exit(22) + if sys.stdin.isatty(): + password = getpass.getpass() + else: + password = sys.stdin.read().strip() + else: + password = args.password + if args.dsn and args.instance: + print_error("--dsn and --instance are mutually exclusive") + sys.exit(22) + return dict( + dsn=args.dsn or args.instance, + credentials_file=args.credentials_file, + host=args.host, + port=args.port, + database=args.database, + user=args.user, + password=password, + tls_ca_file=args.tls_ca_file, + tls_security=args.tls_security, + ) + +INTRO_QUERY = ''' +with module schema +select ObjectType { + name, + links: { + name, + readonly, + required, + cardinality, + exclusive := exists ( + select .constraints + filter .name = 'std::exclusive' + ), + target: {name}, + constraints: { + name, + params: {name, @value}, + }, + + properties: { + name, + readonly, + required, + cardinality, + exclusive := exists ( + select .constraints + filter .name = 'std::exclusive' + ), + target: {name}, + constraints: { + name, + params: {name, @value}, + }, + }, + } filter .name != '__type__' and not exists .expr, + properties: { + name, + readonly, + required, + cardinality, + exclusive := exists ( + select .constraints + filter .name = 'std::exclusive' + ), + target: {name}, + constraints: { + name, + params: {name, @value}, + }, + } filter .name != 'id' and not exists .expr, + backlinks := >[], +} +filter + not .builtin + and + not .internal + and + not .from_alias + and + not re_test('^(std|cfg|sys|schema)::', .name) + and + not any(re_test('^(cfg|sys|schema)::', .ancestors.name)); +''' + +MODULE_QUERY = ''' +with + module schema, + m := (select `Module` filter not .builtin) +select _ := m.name order by _; +''' + +COMMENT = '''\ +# +# Automatically generated from Gel schema. +# +# Do not edit directly as re-generating this file will overwrite any changes. +#\ +''' + +class Generator(FilePrinter): + def __init__(self, args: argparse.Namespace): + self._default_module = "default" + self._targets = args.target + self._async = False + try: + self._project_dir = pathlib.Path(find_gel_project_dir()) + except gel.ClientConnectionError: + print( + "Cannot find gel.toml: " + "codegen must be run under an EdgeDB project dir" + ) + sys.exit(2) + print_msg(f"Found EdgeDB project: {C.BOLD}{self._project_dir}{C.ENDC}") + self._client = gel.create_client(**_get_conn_args(args)) + self._describe_results = [] + + self._cache = {} + self._imports = set() + self._aliases = {} + self._defs = {} + self._names = set() + + self._basemodule = 'models' + self._outdir = pathlib.Path('models') + self._modules = {} + self._types = {} + + super().__init__() + + def run(self): + try: + self._client.ensure_connected() + except gel.EdgeDBError as e: + print(f"Failed to connect to EdgeDB instance: {e}") + sys.exit(61) + + self.get_schema() + + with self._client: + for mod, maps in self._modules.items(): + if not maps: + # skip apparently empty modules + continue + + with self.init_module(mod): + self.write_types(maps) + + print_msg(f"{C.GREEN}{C.BOLD}Done.{C.ENDC}") + + def get_schema(self): + for mod in self._client.query(MODULE_QUERY): + self._modules[mod] = { + 'object_types': {}, + 'scalar_types': {}, + } + + for t in self._client.query(INTRO_QUERY): + mod, name = get_mod_and_name(t.name) + self._types[t.name] = t + self._modules[mod]['object_types'][t.name] = t + + def init_dir(self, dirpath): + if not dirpath: + # nothing to initialize + return + + path = pathlib.Path(dirpath).resolve() + + # ensure `path` directory exists + if not path.exists(): + path.mkdir() + elif not path.is_dir(): + raise NotADirectoryError( + f'{path!r} exists, but it is not a directory') + + # ensure `path` directory contains `__init__.py` + (path / '__init__.py').touch() + + @contextmanager + def init_module(self, mod): + if any(m.startswith(f'{mod}::') for m in self._modules): + # This is a prefix in another module, thus it is part of a nested + # module structure. + dirpath = mod.split('::') + filename = '__init__.py' + else: + # This is a leaf module, so we just need to create a corresponding + # .py file. + *dirpath, filename = mod.split('::') + filename = f'{filename}.py' + + # Along the dirpath we need to ensure that all packages are created + path = self._outdir + for el in dirpath: + path = path / el + self.init_dir(path) + + with open(path / filename, 'wt') as f: + try: + self.out = f + self.write(f'{COMMENT}\n') + yield f + finally: + self.out = None + + def write_types(self, maps): + object_types = maps['object_types'] + scalar_types = maps['scalar_types'] + + if object_types: + self.write(f'from typing import Optional, Any, Annotated') + self.write(f'from gel.compatibility import pydmodels as gm') + + objects = sorted( + object_types.values(), key=lambda x: x.name + ) + for obj in objects: + self.render_type(obj) + + def render_type(self, objtype): + mod, name = get_mod_and_name(objtype.name) + + self.write() + self.write() + self.write(f'class {name}(gm.BaseGelModel):') + self.indent() + self.write(f'__gel_name__ = {objtype.name!r}') + + if len(objtype.properties) > 0: + self.write() + self.write('# Properties:') + for prop in objtype.properties: + self.render_prop(prop, mod) + + if len(objtype.links) > 0: + self.write() + self.write('# Properties:') + for link in objtype.links: + self.render_link(link, mod) + + self.dedent() + + def render_prop(self, prop, curmod): + pytype = TYPE_MAPPING.get(prop.target.name) + defval = '' + if not pytype: + # skip + return + + # FIXME: need to also handle multi + + if not prop.required: + pytype = f'Optional[{pytype}]' + # A value does not need to be supplied + defval = ' = None' + + if prop.exclusive: + pytype = f'Annotated[{pytype}, gm.Exclusive]' + + self.write( + f'{prop.name}: {pytype}{defval}' + ) + + def render_link(self, link, curmod): + mod, name = get_mod_and_name(link.target.name) + if curmod == mod: + pytype = name + else: + pytype = link.target.name.replace('::', '.') + + # FIXME: need to also handle multi + + if link.required: + self.write( + f'{link.name}: {pytype!r}' + ) + else: + # A value does not need to be supplied + self.write( + f'{link.name}: Optional[{pytype!r}] = None' + ) diff --git a/gel/compatibility/pydmodels.py b/gel/compatibility/pydmodels.py new file mode 100644 index 000000000..6913eaf18 --- /dev/null +++ b/gel/compatibility/pydmodels.py @@ -0,0 +1,229 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2025-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import typing +import uuid + +from collections import defaultdict +from contextlib import contextmanager +from pydantic import BaseModel + +import gel + +from gel.orm.introspection import FilePrinter, get_mod_and_name + + +# FIXME: this should be replaced with a special Annotated value using the +# exact type from the schema. No reason to guess. +GEL_TYPE_MAPPING = { + str: "std::str", + float: "std::float64", + int: "std::int64", + bool: "std::bool", + # "uuid.UUID": "std::uuid", + bytes: "std::bytes", + # "decimal.Decimal": "std::decimal", + # "datetime.datetime": "std::datetime", + # "datetime.timedelta": "std::duration", + # "datetime.date": "cal::local_date", + # "datetime.time": "cal::local_time", + # "gel.RelativeDuration": "cal::relative_duration", + # "gel.DateDuration": "cal::date_duration", + # "gel.ConfigMemory": "cfg::memory", + # "array.array": "ext::pgvector::vector", +} + + +class Exclusive: + pass + + +class BaseGelModel(BaseModel): + def exclusive_fields(self): + results = [] + + for name, info in self.model_fields.items(): + for meta in info.metadata: + if meta is Exclusive: + results.append(name) + + return results + + +class ObjData(BaseModel): + obj: BaseGelModel + rank: int | None = None + gelid: uuid.UUID | None = None + + +def is_optional(field): + return ( + typing.get_origin(field) is typing.Union and + type(None) in typing.get_args(field) + ) + + +class Session: + def __init__(self, data, client): + self._data = list(data) + self._client = client + # insert order will come in tiers, where each tier is itself a list of + # objects that have the same insert precedence and can be inserted in + # parallel. + self._insert_order = [[]] + self._idmap = {} + + self.compute_insert_order() + + def commit(self): + for tx in self._client.transaction(): + with tx: + for objs in self._insert_order: + for item in objs: + query, args = self.generate_insert(item) + gelobj = tx.query_single(query, *args) + self._idmap[id(item)].gelid = gelobj.id + self.clear() + + def clear(self): + self._data = [] + self._insert_order = [[]] + self._idmap = {} + + def generate_insert(self, item, *, arg_start=0): + args = [] + arg = arg_start + query = f'insert {item.__gel_name__} {{' + + for name, info in item.model_fields.items(): + val = getattr(item, name) + + if val is None: + # skip empty values + continue + + if isinstance(val, BaseModel): + subquery, subargs = self.generate_select( + val, arg_start=arg) + arg += len(subargs) + args += subargs + query += f'{name} := ({subquery}), ' + + else: + geltype = GEL_TYPE_MAPPING[type(val)] + query += f'{name} := <{geltype}>${arg}, ' + arg += 1 + args.append(val) + + query += '}' + + return query, args + + def generate_select(self, item, *, arg_start=0): + gelid = self._idmap[id(item)].gelid + args = [] + arg = arg_start + query = f'select detached {item.__gel_name__} filter ' + fquery = [] + + if gelid is not None: + fquery.append(f'.id = ${arg}') + arg += 1 + args.append(gelid) + else: + for name in item.exclusive_fields(): + val = getattr(item, name) + geltype = GEL_TYPE_MAPPING[type(val)] + fquery.append(f'.{name} = <{geltype}>${arg}') + arg += 1 + args.append(val) + + query += ' and '.join(fquery) + + return query, args + + def compute_insert_order(self): + # We traverse all the distinct objects to be inserted and follow up on + # their links, recursively constructing a forest. All leaves get rank + # 0, i.e. they can be inserted first without dependencies on other + # objects. Any objects that have non-empty links have a rank that's + # the maximum rank of all their links + 1. + for obj in self._data: + self.rank_object(obj) + + def rank_object(self, obj): + oid = id(obj) + if self._idmap.get(oid) is not None: + return + + rank = 0 + + for name, info in obj.model_fields.items(): + val = getattr(obj, name) + # We care about actual link value, because it can be None + if isinstance(val, BaseModel): + self.rank_object(val) + linked = self._idmap[id(val)] + rank = max(rank, linked.rank + 1) + + if rank >= len(self._insert_order): + # We only need to grow the _insert_order by 1 more rank since we + # are guaranteed that the rank can only increase by 1 at each + # step. + self._insert_order.append([]) + + self._insert_order[rank].append(obj) + self._idmap[oid] = ObjData(obj=obj, rank=rank) + + # def break_cycle(self): + # # We have a set of objects that form a link cycle. We need to find one + # # of them with optional links to this cycle and set that link to be + # # empty instead. The object's link will be updated separately. + # link_found = False + # for oid in self._cycle: + # obj = self._idmap[oid].obj + + # for name, info in obj.model_fields.items(): + # val = self.getmodelattr(obj, name) + # # We care about actual link value, because it can be None + # if ( + # isinstance(val, BaseModel) and + # is_optional(info.annotation) + # ): + # self._updates[oid].add(name) + # link_found = True + # break + + # if not link_found: + # cycle = ', '.join(self._idmap[oid].obj for oid in self._cycle) + # raise Exception('Cycle detected: {cycle}') + + # self._cycle.clear() + + # def getmodelattr(self, obj, name): + # links = self._updates.get(id(obj), set()) + # if name in links: + # # skip this link + # return None + # else: + # return getattr(obj, name) + + +def commit(client, data, *, identity_merge=False): + sess = Session(data, client) + sess.commit() From 935e04d077fba6bc3a513aa70b7498b6142a078d Mon Sep 17 00:00:00 2001 From: Victor Petrovykh Date: Mon, 24 Mar 2025 05:45:17 -0400 Subject: [PATCH 2/3] Consolidate CLI. Combine all the various code generators into a single CLI tool: gel-gen. It now has a required positional argument which determines the specific mode for the genrators: `edgeql`, `pydantic`, `sqlalchemy`, `sqlmodel`, `django`. --- gel/_testbase.py | 2 +- gel/codegen/cli.py | 102 +++++++++++++++--- gel/codegen/generator.py | 40 +------ gel/codegen/models.py | 41 +------ gel/compatibility/__init__.py | 17 +++ gel/compatibility/clihelper.py | 64 +++++++++++ gel/{orm => compatibility}/introspection.py | 0 gel/compatibility/pydmodels.py | 2 +- gel/orm/cli.py | 4 +- gel/orm/django/generator.py | 3 +- gel/orm/sqla.py | 4 +- gel/orm/sqlmodel.py | 4 +- setup.py | 1 + .../linked/test_linked_async_edgeql.py.assert | 2 +- .../linked/test_linked_edgeql.py.assert | 2 +- .../generated_async_edgeql.py.assert | 2 +- .../generated_async_edgeql.py.assert5 | 2 +- ...elect_optional_json_async_edgeql.py.assert | 2 +- .../select_optional_json_edgeql.py.assert | 2 +- .../select_scalar_async_edgeql.py.assert | 2 +- .../select_scalar_edgeql.py.assert | 2 +- .../argnames/query_one_async_edgeql.py.assert | 2 +- .../argnames/query_one_edgeql.py.assert | 2 +- .../generated_async_edgeql.py.assert | 2 +- .../generated_async_edgeql.py.assert3 | 2 +- .../generated_async_edgeql.py.assert5 | 2 +- .../object/link_prop_async_edgeql.py.assert | 2 +- .../object/link_prop_edgeql.py.assert | 2 +- .../select_object_async_edgeql.py.assert | 2 +- .../object/select_object_edgeql.py.assert | 2 +- .../select_objects_async_edgeql.py.assert | 2 +- .../object/select_objects_edgeql.py.assert | 2 +- .../parpkg/select_args_async_edgeql.py.assert | 2 +- .../select_args_async_edgeql.py.assert5 | 2 +- .../parpkg/select_args_edgeql.py.assert | 2 +- .../parpkg/select_args_edgeql.py.assert5 | 2 +- .../subpkg/my_query_async_edgeql.py.assert | 2 +- .../subpkg/my_query_async_edgeql.py.assert5 | 2 +- .../parpkg/subpkg/my_query_edgeql.py.assert | 2 +- .../parpkg/subpkg/my_query_edgeql.py.assert5 | 2 +- ...custom_vector_input_async_edgeql.py.assert | 2 +- ...ustom_vector_input_async_edgeql.py.assert3 | 2 +- .../custom_vector_input_edgeql.py.assert | 2 +- .../custom_vector_input_edgeql.py.assert3 | 2 +- .../select_scalar_async_edgeql.py.assert | 2 +- .../scalar/select_scalar_edgeql.py.assert | 2 +- .../select_scalars_async_edgeql.py.assert | 2 +- .../scalar/select_scalars_edgeql.py.assert | 2 +- tests/test_codegen.py | 6 +- 49 files changed, 223 insertions(+), 137 deletions(-) create mode 100644 gel/compatibility/__init__.py create mode 100644 gel/compatibility/clihelper.py rename gel/{orm => compatibility}/introspection.py (100%) diff --git a/gel/_testbase.py b/gel/_testbase.py index b581ac907..dab21cdba 100644 --- a/gel/_testbase.py +++ b/gel/_testbase.py @@ -38,7 +38,7 @@ import gel from gel import asyncio_client from gel import blocking_client -from gel.orm.introspection import get_schema_json, GelORMWarning +from gel.compatibility.introspection import get_schema_json, GelORMWarning from gel.orm.sqla import ModelGenerator as SQLAModGen from gel.orm.sqlmodel import ModelGenerator as SQLModGen from gel.orm.django.generator import ModelGenerator as DjangoModGen diff --git a/gel/codegen/cli.py b/gel/codegen/cli.py index 2f1f4fe2a..bba8387f0 100644 --- a/gel/codegen/cli.py +++ b/gel/codegen/cli.py @@ -19,8 +19,16 @@ import argparse import sys +import warnings -from . import generator, models +import gel + +from gel.codegen import generator, models +from gel.compatibility.introspection import get_schema_json, GelORMWarning +from gel.compatibility.clihelper import _get_conn_args +from gel.orm.sqla import ModelGenerator as SQLAModGen +from gel.orm.sqlmodel import ModelGenerator as SQLModGen +from gel.orm.django.generator import ModelGenerator as DjangoModGen class ColoredArgumentParser(argparse.ArgumentParser): @@ -34,7 +42,15 @@ def error(self, message): parser = ColoredArgumentParser( - description="Generate Python code for .edgeql files." + description="Generate Python code for various Gel interfaces." +) +parser.add_argument( + "mode", + choices=['edgeql', 'pydantic', 'sqlalchemy', 'sqlmodel', 'django'], + help="Pick which mode to generate code for: " + "`edgeql` makes Python code for .edgeql files; " + "`pydantic` makes Pydantic models based on the database schema; " + "`sqlalchemy`, `sqlmodel`, `django` makes Python ORM code for a database", ) parser.add_argument("--dsn") parser.add_argument("--credentials-file", metavar="PATH") @@ -69,11 +85,13 @@ def error(self, message): help="Choose one or more targets to generate code (default is async)." ) parser.add_argument( - "--models", - action="store_true", - default=False, - help="Using the schema generate Pydantic models that can be used for " - "bulk inserts.", + "--out", + help="The output directory for the generated files.", +) +parser.add_argument( + "--mod", + help="The fullname of the Python module corresponding to the output " + "directory.", ) if sys.version_info[:2] >= (3, 9): parser.add_argument( @@ -101,10 +119,68 @@ def error(self, message): def main(): args = parser.parse_args() - if args.models: - models.Generator(args).run() - return - if not hasattr(args, "skip_pydantic_validation"): - args.skip_pydantic_validation = True - generator.Generator(args).run() + match args.mode: + case 'edgeql': + if not hasattr(args, "skip_pydantic_validation"): + args.skip_pydantic_validation = True + generator.Generator(args).run() + + case 'pydantic': + if args.mod is None: + parser.error('pydantic requires to specify --mod') + if args.out is None: + parser.error('pydantic requires to specify --out') + + models.Generator(args).run() + + case 'sqlalchemy': + if args.mod is None: + parser.error('sqlalchemy requires to specify --mod') + if args.out is None: + parser.error('sqlalchemy requires to specify --out') + + with warnings.catch_warnings(record=True) as wlist: + warnings.simplefilter("always", GelORMWarning) + spec = get_schema_json( + gel.create_client(**generator._get_conn_args(args))) + gen = SQLAModGen( + outdir=args.out, + basemodule=args.mod, + ) + gen.render_models(spec) + + for w in wlist: + print(w.message) + + case 'sqlmodel': + if args.mod is None: + parser.error('sqlmodel requires to specify --mod') + if args.out is None: + parser.error('sqlmodel requires to specify --out') + + with warnings.catch_warnings(record=True) as wlist: + warnings.simplefilter("always", GelORMWarning) + spec = get_schema_json( + gel.create_client(**_get_conn_args(args))) + gen = SQLModGen( + outdir=args.out, + basemodule=args.mod, + ) + gen.render_models(spec) + + for w in wlist: + print(w.message) + + case 'django': + with warnings.catch_warnings(record=True) as wlist: + warnings.simplefilter("always", GelORMWarning) + spec = get_schema_json( + gel.create_client(**_get_conn_args(args))) + gen = DjangoModGen( + out=args.out, + ) + gen.render_models(spec) + + for w in wlist: + print(w.message) diff --git a/gel/codegen/generator.py b/gel/codegen/generator.py index b5c70ddb3..39c00426c 100644 --- a/gel/codegen/generator.py +++ b/gel/codegen/generator.py @@ -17,7 +17,6 @@ # import argparse -import getpass import io import os import pathlib @@ -30,6 +29,7 @@ from gel import describe from gel.con_utils import find_gel_project_dir from gel.color import get_color +from gel.compatibility.clihelper import print_msg, print_error, _get_conn_args C = get_color() @@ -113,44 +113,6 @@ def __get_validators__(cls): """ -def print_msg(msg): - print(msg, file=sys.stderr) - - -def print_error(msg): - print_msg(f"{C.BOLD}{C.FAIL}error: {C.ENDC}{C.BOLD}{msg}{C.ENDC}") - - -def _get_conn_args(args: argparse.Namespace): - if args.password_from_stdin: - if args.password: - print_error( - "--password and --password-from-stdin are " - "mutually exclusive", - ) - sys.exit(22) - if sys.stdin.isatty(): - password = getpass.getpass() - else: - password = sys.stdin.read().strip() - else: - password = args.password - if args.dsn and args.instance: - print_error("--dsn and --instance are mutually exclusive") - sys.exit(22) - return dict( - dsn=args.dsn or args.instance, - credentials_file=args.credentials_file, - host=args.host, - port=args.port, - database=args.database, - user=args.user, - password=password, - tls_ca_file=args.tls_ca_file, - tls_security=args.tls_security, - ) - - class Generator: def __init__(self, args: argparse.Namespace): self._default_module = "default" diff --git a/gel/codegen/models.py b/gel/codegen/models.py index 2c4d08eae..f58fbe521 100644 --- a/gel/codegen/models.py +++ b/gel/codegen/models.py @@ -17,7 +17,6 @@ # import argparse -import getpass import io import os import pathlib @@ -35,7 +34,8 @@ from gel.con_utils import find_gel_project_dir from gel.color import get_color -from gel.orm.introspection import FilePrinter, get_mod_and_name +from gel.compatibility.introspection import FilePrinter, get_mod_and_name +from gel.compatibility.clihelper import print_msg, print_error, _get_conn_args C = get_color() @@ -95,43 +95,6 @@ ) -def print_msg(msg): - print(msg, file=sys.stderr) - - -def print_error(msg): - print_msg(f"{C.BOLD}{C.FAIL}error: {C.ENDC}{C.BOLD}{msg}{C.ENDC}") - - -def _get_conn_args(args: argparse.Namespace): - if args.password_from_stdin: - if args.password: - print_error( - "--password and --password-from-stdin are " - "mutually exclusive", - ) - sys.exit(22) - if sys.stdin.isatty(): - password = getpass.getpass() - else: - password = sys.stdin.read().strip() - else: - password = args.password - if args.dsn and args.instance: - print_error("--dsn and --instance are mutually exclusive") - sys.exit(22) - return dict( - dsn=args.dsn or args.instance, - credentials_file=args.credentials_file, - host=args.host, - port=args.port, - database=args.database, - user=args.user, - password=password, - tls_ca_file=args.tls_ca_file, - tls_security=args.tls_security, - ) - INTRO_QUERY = ''' with module schema select ObjectType { diff --git a/gel/compatibility/__init__.py b/gel/compatibility/__init__.py new file mode 100644 index 000000000..d58764077 --- /dev/null +++ b/gel/compatibility/__init__.py @@ -0,0 +1,17 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2025-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/gel/compatibility/clihelper.py b/gel/compatibility/clihelper.py new file mode 100644 index 000000000..d6047c589 --- /dev/null +++ b/gel/compatibility/clihelper.py @@ -0,0 +1,64 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2025-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +import getpass +import sys + +from gel.color import get_color + + +C = get_color() + + +def print_msg(msg): + print(msg, file=sys.stderr) + + +def print_error(msg): + print_msg(f"{C.BOLD}{C.FAIL}error: {C.ENDC}{C.BOLD}{msg}{C.ENDC}") + + +def _get_conn_args(args: argparse.Namespace): + if args.password_from_stdin: + if args.password: + print_error( + "--password and --password-from-stdin are " + "mutually exclusive", + ) + sys.exit(22) + if sys.stdin.isatty(): + password = getpass.getpass() + else: + password = sys.stdin.read().strip() + else: + password = args.password + if args.dsn and args.instance: + print_error("--dsn and --instance are mutually exclusive") + sys.exit(22) + return dict( + dsn=args.dsn or args.instance, + credentials_file=args.credentials_file, + host=args.host, + port=args.port, + database=args.database, + user=args.user, + password=password, + tls_ca_file=args.tls_ca_file, + tls_security=args.tls_security, + ) diff --git a/gel/orm/introspection.py b/gel/compatibility/introspection.py similarity index 100% rename from gel/orm/introspection.py rename to gel/compatibility/introspection.py diff --git a/gel/compatibility/pydmodels.py b/gel/compatibility/pydmodels.py index 6913eaf18..e59dee67c 100644 --- a/gel/compatibility/pydmodels.py +++ b/gel/compatibility/pydmodels.py @@ -25,7 +25,7 @@ import gel -from gel.orm.introspection import FilePrinter, get_mod_and_name +from gel.compatibility.introspection import FilePrinter, get_mod_and_name # FIXME: this should be replaced with a special Annotated value using the diff --git a/gel/orm/cli.py b/gel/orm/cli.py index cf07797af..d453f96ec 100644 --- a/gel/orm/cli.py +++ b/gel/orm/cli.py @@ -22,8 +22,8 @@ import gel -from gel.codegen.generator import _get_conn_args -from .introspection import get_schema_json, GelORMWarning +from gel.compatibility.introspection import get_schema_json, GelORMWarning +from gel.compatibility.clihelper import _get_conn_args from .sqla import ModelGenerator as SQLAModGen from .sqlmodel import ModelGenerator as SQLModGen from .django.generator import ModelGenerator as DjangoModGen diff --git a/gel/orm/django/generator.py b/gel/orm/django/generator.py index 74e8ac9fe..8896e7799 100644 --- a/gel/orm/django/generator.py +++ b/gel/orm/django/generator.py @@ -2,7 +2,8 @@ import re import warnings -from ..introspection import get_mod_and_name, GelORMWarning, FilePrinter +from gel.compatibility.introspection import get_mod_and_name +from gel.compatibility.introspection import GelORMWarning, FilePrinter GEL_SCALAR_MAP = { diff --git a/gel/orm/sqla.py b/gel/orm/sqla.py index 21f89d912..d9faef71b 100644 --- a/gel/orm/sqla.py +++ b/gel/orm/sqla.py @@ -4,8 +4,8 @@ from contextlib import contextmanager -from .introspection import get_sql_name, get_mod_and_name -from .introspection import GelORMWarning, FilePrinter +from gel.compatibility.introspection import get_sql_name, get_mod_and_name +from gel.compatibility.introspection import GelORMWarning, FilePrinter GEL_SCALAR_MAP = { diff --git a/gel/orm/sqlmodel.py b/gel/orm/sqlmodel.py index 3cdcdd44a..37baad65e 100644 --- a/gel/orm/sqlmodel.py +++ b/gel/orm/sqlmodel.py @@ -4,8 +4,8 @@ from contextlib import contextmanager -from .introspection import get_sql_name, get_mod_and_name -from .introspection import GelORMWarning, FilePrinter +from gel.compatibility.introspection import get_sql_name, get_mod_and_name +from gel.compatibility.introspection import GelORMWarning, FilePrinter GEL_SCALAR_MAP = { diff --git a/setup.py b/setup.py index 64d56c0b1..b14bddb36 100644 --- a/setup.py +++ b/setup.py @@ -366,6 +366,7 @@ def finalize_options(self): "console_scripts": [ "edgedb-py=gel.codegen.cli:main", "gel-py=gel.codegen.cli:main", + "gel-gen=gel.codegen.cli:main", "gel-orm=gel.orm.cli:main", "gel=gel.cli:main", ] diff --git a/tests/codegen/linked/test_linked_async_edgeql.py.assert b/tests/codegen/linked/test_linked_async_edgeql.py.assert index 33058af62..e6a75d988 100644 --- a/tests/codegen/linked/test_linked_async_edgeql.py.assert +++ b/tests/codegen/linked/test_linked_async_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'linked/test_linked.edgeql' WITH: -# $ gel-py +# $ gel-gen edgeql from __future__ import annotations diff --git a/tests/codegen/linked/test_linked_edgeql.py.assert b/tests/codegen/linked/test_linked_edgeql.py.assert index d8d7aa04b..0a75c18f4 100644 --- a/tests/codegen/linked/test_linked_edgeql.py.assert +++ b/tests/codegen/linked/test_linked_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'linked/test_linked.edgeql' WITH: -# $ gel-py --target blocking --no-skip-pydantic-validation +# $ gel-gen edgeql --target blocking --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project1/generated_async_edgeql.py.assert b/tests/codegen/test-project1/generated_async_edgeql.py.assert index 22ba6c8d4..ba090fa52 100644 --- a/tests/codegen/test-project1/generated_async_edgeql.py.assert +++ b/tests/codegen/test-project1/generated_async_edgeql.py.assert @@ -3,7 +3,7 @@ # 'select_scalar.edgeql' # 'linked/test_linked.edgeql' # WITH: -# $ gel-py --target async --file --no-skip-pydantic-validation +# $ gel-gen edgeql --target async --file --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project1/generated_async_edgeql.py.assert5 b/tests/codegen/test-project1/generated_async_edgeql.py.assert5 index 22ba6c8d4..ba090fa52 100644 --- a/tests/codegen/test-project1/generated_async_edgeql.py.assert5 +++ b/tests/codegen/test-project1/generated_async_edgeql.py.assert5 @@ -3,7 +3,7 @@ # 'select_scalar.edgeql' # 'linked/test_linked.edgeql' # WITH: -# $ gel-py --target async --file --no-skip-pydantic-validation +# $ gel-gen edgeql --target async --file --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project1/select_optional_json_async_edgeql.py.assert b/tests/codegen/test-project1/select_optional_json_async_edgeql.py.assert index f294ffd27..b34803b1a 100644 --- a/tests/codegen/test-project1/select_optional_json_async_edgeql.py.assert +++ b/tests/codegen/test-project1/select_optional_json_async_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'select_optional_json.edgeql' WITH: -# $ gel-py +# $ gel-gen edgeql from __future__ import annotations diff --git a/tests/codegen/test-project1/select_optional_json_edgeql.py.assert b/tests/codegen/test-project1/select_optional_json_edgeql.py.assert index 61d2cd0d4..b54bd2bd5 100644 --- a/tests/codegen/test-project1/select_optional_json_edgeql.py.assert +++ b/tests/codegen/test-project1/select_optional_json_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'select_optional_json.edgeql' WITH: -# $ gel-py --target blocking --no-skip-pydantic-validation +# $ gel-gen edgeql --target blocking --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project1/select_scalar_async_edgeql.py.assert b/tests/codegen/test-project1/select_scalar_async_edgeql.py.assert index 2a6dc130e..f4680a28e 100644 --- a/tests/codegen/test-project1/select_scalar_async_edgeql.py.assert +++ b/tests/codegen/test-project1/select_scalar_async_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'select_scalar.edgeql' WITH: -# $ gel-py +# $ gel-gen edgeql from __future__ import annotations diff --git a/tests/codegen/test-project1/select_scalar_edgeql.py.assert b/tests/codegen/test-project1/select_scalar_edgeql.py.assert index d8b16a531..4ad270b26 100644 --- a/tests/codegen/test-project1/select_scalar_edgeql.py.assert +++ b/tests/codegen/test-project1/select_scalar_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'select_scalar.edgeql' WITH: -# $ gel-py --target blocking --no-skip-pydantic-validation +# $ gel-gen edgeql --target blocking --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project2/argnames/query_one_async_edgeql.py.assert b/tests/codegen/test-project2/argnames/query_one_async_edgeql.py.assert index ecf680f73..766472689 100644 --- a/tests/codegen/test-project2/argnames/query_one_async_edgeql.py.assert +++ b/tests/codegen/test-project2/argnames/query_one_async_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'argnames/query_one.edgeql' WITH: -# $ gel-py +# $ gel-gen edgeql from __future__ import annotations diff --git a/tests/codegen/test-project2/argnames/query_one_edgeql.py.assert b/tests/codegen/test-project2/argnames/query_one_edgeql.py.assert index 56701a68a..43fd9fbdf 100644 --- a/tests/codegen/test-project2/argnames/query_one_edgeql.py.assert +++ b/tests/codegen/test-project2/argnames/query_one_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'argnames/query_one.edgeql' WITH: -# $ gel-py --target blocking --no-skip-pydantic-validation +# $ gel-gen edgeql --target blocking --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project2/generated_async_edgeql.py.assert b/tests/codegen/test-project2/generated_async_edgeql.py.assert index 4efdfad8e..a3ff54744 100644 --- a/tests/codegen/test-project2/generated_async_edgeql.py.assert +++ b/tests/codegen/test-project2/generated_async_edgeql.py.assert @@ -9,7 +9,7 @@ # 'scalar/select_scalar.edgeql' # 'scalar/select_scalars.edgeql' # WITH: -# $ gel-py --target async --file --no-skip-pydantic-validation +# $ gel-gen edgeql --target async --file --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project2/generated_async_edgeql.py.assert3 b/tests/codegen/test-project2/generated_async_edgeql.py.assert3 index 4a1a74444..76605bed7 100644 --- a/tests/codegen/test-project2/generated_async_edgeql.py.assert3 +++ b/tests/codegen/test-project2/generated_async_edgeql.py.assert3 @@ -9,7 +9,7 @@ # 'scalar/select_scalar.edgeql' # 'scalar/select_scalars.edgeql' # WITH: -# $ gel-py --target async --file --no-skip-pydantic-validation +# $ gel-gen edgeql --target async --file --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project2/generated_async_edgeql.py.assert5 b/tests/codegen/test-project2/generated_async_edgeql.py.assert5 index 8907bf733..126198b8e 100644 --- a/tests/codegen/test-project2/generated_async_edgeql.py.assert5 +++ b/tests/codegen/test-project2/generated_async_edgeql.py.assert5 @@ -9,7 +9,7 @@ # 'scalar/select_scalar.edgeql' # 'scalar/select_scalars.edgeql' # WITH: -# $ gel-py --target async --file --no-skip-pydantic-validation +# $ gel-gen edgeql --target async --file --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project2/object/link_prop_async_edgeql.py.assert b/tests/codegen/test-project2/object/link_prop_async_edgeql.py.assert index bc323374e..ed009d80c 100644 --- a/tests/codegen/test-project2/object/link_prop_async_edgeql.py.assert +++ b/tests/codegen/test-project2/object/link_prop_async_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'object/link_prop.edgeql' WITH: -# $ gel-py +# $ gel-gen edgeql from __future__ import annotations diff --git a/tests/codegen/test-project2/object/link_prop_edgeql.py.assert b/tests/codegen/test-project2/object/link_prop_edgeql.py.assert index d3912dacc..43aae9953 100644 --- a/tests/codegen/test-project2/object/link_prop_edgeql.py.assert +++ b/tests/codegen/test-project2/object/link_prop_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'object/link_prop.edgeql' WITH: -# $ gel-py --target blocking --no-skip-pydantic-validation +# $ gel-gen edgeql --target blocking --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project2/object/select_object_async_edgeql.py.assert b/tests/codegen/test-project2/object/select_object_async_edgeql.py.assert index 37e1ec36b..2afac2d80 100644 --- a/tests/codegen/test-project2/object/select_object_async_edgeql.py.assert +++ b/tests/codegen/test-project2/object/select_object_async_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'object/select_object.edgeql' WITH: -# $ gel-py +# $ gel-gen edgeql from __future__ import annotations diff --git a/tests/codegen/test-project2/object/select_object_edgeql.py.assert b/tests/codegen/test-project2/object/select_object_edgeql.py.assert index 2fa0df0ac..c1144fcfa 100644 --- a/tests/codegen/test-project2/object/select_object_edgeql.py.assert +++ b/tests/codegen/test-project2/object/select_object_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'object/select_object.edgeql' WITH: -# $ gel-py --target blocking --no-skip-pydantic-validation +# $ gel-gen edgeql --target blocking --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project2/object/select_objects_async_edgeql.py.assert b/tests/codegen/test-project2/object/select_objects_async_edgeql.py.assert index f4e97abac..f3138b0b6 100644 --- a/tests/codegen/test-project2/object/select_objects_async_edgeql.py.assert +++ b/tests/codegen/test-project2/object/select_objects_async_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'object/select_objects.edgeql' WITH: -# $ gel-py +# $ gel-gen edgeql from __future__ import annotations diff --git a/tests/codegen/test-project2/object/select_objects_edgeql.py.assert b/tests/codegen/test-project2/object/select_objects_edgeql.py.assert index 875b3cfa9..d11e8e77c 100644 --- a/tests/codegen/test-project2/object/select_objects_edgeql.py.assert +++ b/tests/codegen/test-project2/object/select_objects_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'object/select_objects.edgeql' WITH: -# $ gel-py --target blocking --no-skip-pydantic-validation +# $ gel-gen edgeql --target blocking --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert b/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert index 4f3ece14d..7ed3b1539 100644 --- a/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert +++ b/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'parpkg/select_args.edgeql' WITH: -# $ gel-py +# $ gel-gen edgeql from __future__ import annotations diff --git a/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert5 b/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert5 index 0e48f41dc..e8d86d2be 100644 --- a/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert5 +++ b/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert5 @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'parpkg/select_args.edgeql' WITH: -# $ gel-py +# $ gel-gen edgeql from __future__ import annotations diff --git a/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert b/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert index 5adeb9222..38bc69d33 100644 --- a/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert +++ b/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'parpkg/select_args.edgeql' WITH: -# $ gel-py --target blocking --no-skip-pydantic-validation +# $ gel-gen edgeql --target blocking --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert5 b/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert5 index ca9cfddbb..991d144aa 100644 --- a/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert5 +++ b/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert5 @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'parpkg/select_args.edgeql' WITH: -# $ gel-py --target blocking --no-skip-pydantic-validation +# $ gel-gen edgeql --target blocking --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert b/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert index ddb0d6952..ea7e14fa5 100644 --- a/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert +++ b/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'parpkg/subpkg/my_query.edgeql' WITH: -# $ gel-py +# $ gel-gen edgeql from __future__ import annotations diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert5 b/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert5 index 7b7138cbf..2b4d509bd 100644 --- a/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert5 +++ b/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert5 @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'parpkg/subpkg/my_query.edgeql' WITH: -# $ gel-py +# $ gel-gen edgeql from __future__ import annotations diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert b/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert index b22baa497..032352c4c 100644 --- a/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert +++ b/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'parpkg/subpkg/my_query.edgeql' WITH: -# $ gel-py --target blocking --no-skip-pydantic-validation +# $ gel-gen edgeql --target blocking --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert5 b/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert5 index 2d60d18a5..bbf7baa9b 100644 --- a/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert5 +++ b/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert5 @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'parpkg/subpkg/my_query.edgeql' WITH: -# $ gel-py --target blocking --no-skip-pydantic-validation +# $ gel-gen edgeql --target blocking --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert b/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert index 277e472a7..23054ffac 100644 --- a/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert +++ b/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'scalar/custom_vector_input.edgeql' WITH: -# $ gel-py +# $ gel-gen edgeql from __future__ import annotations diff --git a/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert3 b/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert3 index 1a1f0e0d1..5b0d1d4d3 100644 --- a/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert3 +++ b/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert3 @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'scalar/custom_vector_input.edgeql' WITH: -# $ gel-py +# $ gel-gen edgeql from __future__ import annotations diff --git a/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert b/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert index 6ccca1a6c..eab8d5dec 100644 --- a/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert +++ b/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'scalar/custom_vector_input.edgeql' WITH: -# $ gel-py --target blocking --no-skip-pydantic-validation +# $ gel-gen edgeql --target blocking --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert3 b/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert3 index 509bd484d..c9fb67ec8 100644 --- a/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert3 +++ b/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert3 @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'scalar/custom_vector_input.edgeql' WITH: -# $ gel-py --target blocking --no-skip-pydantic-validation +# $ gel-gen edgeql --target blocking --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project2/scalar/select_scalar_async_edgeql.py.assert b/tests/codegen/test-project2/scalar/select_scalar_async_edgeql.py.assert index 327be5e31..3eaba6aa4 100644 --- a/tests/codegen/test-project2/scalar/select_scalar_async_edgeql.py.assert +++ b/tests/codegen/test-project2/scalar/select_scalar_async_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'scalar/select_scalar.edgeql' WITH: -# $ gel-py +# $ gel-gen edgeql from __future__ import annotations diff --git a/tests/codegen/test-project2/scalar/select_scalar_edgeql.py.assert b/tests/codegen/test-project2/scalar/select_scalar_edgeql.py.assert index 0337455b5..59604c9ad 100644 --- a/tests/codegen/test-project2/scalar/select_scalar_edgeql.py.assert +++ b/tests/codegen/test-project2/scalar/select_scalar_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'scalar/select_scalar.edgeql' WITH: -# $ gel-py --target blocking --no-skip-pydantic-validation +# $ gel-gen edgeql --target blocking --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/codegen/test-project2/scalar/select_scalars_async_edgeql.py.assert b/tests/codegen/test-project2/scalar/select_scalars_async_edgeql.py.assert index a72918509..ad16a4708 100644 --- a/tests/codegen/test-project2/scalar/select_scalars_async_edgeql.py.assert +++ b/tests/codegen/test-project2/scalar/select_scalars_async_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'scalar/select_scalars.edgeql' WITH: -# $ gel-py +# $ gel-gen edgeql from __future__ import annotations diff --git a/tests/codegen/test-project2/scalar/select_scalars_edgeql.py.assert b/tests/codegen/test-project2/scalar/select_scalars_edgeql.py.assert index 9cf2ebb42..67bf58eec 100644 --- a/tests/codegen/test-project2/scalar/select_scalars_edgeql.py.assert +++ b/tests/codegen/test-project2/scalar/select_scalars_edgeql.py.assert @@ -1,5 +1,5 @@ # AUTOGENERATED FROM 'scalar/select_scalars.edgeql' WITH: -# $ gel-py --target blocking --no-skip-pydantic-validation +# $ gel-gen edgeql --target blocking --no-skip-pydantic-validation from __future__ import annotations diff --git a/tests/test_codegen.py b/tests/test_codegen.py index 94af24932..4cad1ae7d 100644 --- a/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -109,12 +109,13 @@ async def run(*args, extra_env=None): p.returncode, args, output=await p.stdout.read(), ) - cmd = env.get("EDGEDB_PYTHON_TEST_CODEGEN_CMD", "gel-py") + cmd = env.get("EDGEDB_PYTHON_TEST_CODEGEN_CMD", "gel-gen") await run( - cmd, extra_env={"EDGEDB_PYTHON_CODEGEN_PY_VER": "3.8.5"} + cmd, "edgeql", extra_env={"EDGEDB_PYTHON_CODEGEN_PY_VER": "3.8.5"} ) await run( cmd, + "edgeql", "--target", "blocking", "--no-skip-pydantic-validation", @@ -122,6 +123,7 @@ async def run(*args, extra_env=None): ) await run( cmd, + "edgeql", "--target", "async", "--file", From 5c3b710915e29aff75c56013292a0d746bd7fa8e Mon Sep 17 00:00:00 2001 From: Victor Petrovykh Date: Fri, 28 Mar 2025 12:43:53 -0400 Subject: [PATCH 3/3] Add bulk update. Now we have bulk insert and bulk update functionality. Add tests for both modes. --- gel/_testbase.py | 15 ++ gel/codegen/models.py | 119 ++++++--- gel/compatibility/pydmodels.py | 384 +++++++++++++++++++++------ tests/dbsetup/pydantic.gel | 28 ++ tests/test_pydantic_orm.py | 465 +++++++++++++++++++++++++++++++++ 5 files changed, 896 insertions(+), 115 deletions(-) create mode 100644 tests/dbsetup/pydantic.gel create mode 100644 tests/test_pydantic_orm.py diff --git a/gel/_testbase.py b/gel/_testbase.py index dab21cdba..efce9a993 100644 --- a/gel/_testbase.py +++ b/gel/_testbase.py @@ -38,6 +38,8 @@ import gel from gel import asyncio_client from gel import blocking_client +from gel.codegen import cli +from gel.codegen.models import Generator as PydanticGenerator from gel.compatibility.introspection import get_schema_json, GelORMWarning from gel.orm.sqla import ModelGenerator as SQLAModGen from gel.orm.sqlmodel import ModelGenerator as SQLModGen @@ -778,6 +780,19 @@ def setupORM(cls): gen.render_models(cls.spec) +class PydanticTestCase(ORMTestCase): + @classmethod + def setupORM(cls): + cargs = cls.get_connect_args(database=cls.get_database_name()) + args = cli.parser.parse_args([ + 'pydantic', + '--out', str(os.path.join(cls.tmpormdir.name, cls.MODEL_PACKAGE)), + '--mod', cls.MODEL_PACKAGE, + ]) + gen = PydanticGenerator(args, client=cls.client) + gen.run() + + _lock_cnt = 0 diff --git a/gel/codegen/models.py b/gel/codegen/models.py index f58fbe521..d1c0a34be 100644 --- a/gel/codegen/models.py +++ b/gel/codegen/models.py @@ -175,20 +175,14 @@ ''' class Generator(FilePrinter): - def __init__(self, args: argparse.Namespace): + def __init__(self, args: argparse.Namespace, client=None): self._default_module = "default" self._targets = args.target self._async = False - try: - self._project_dir = pathlib.Path(find_gel_project_dir()) - except gel.ClientConnectionError: - print( - "Cannot find gel.toml: " - "codegen must be run under an EdgeDB project dir" - ) - sys.exit(2) - print_msg(f"Found EdgeDB project: {C.BOLD}{self._project_dir}{C.ENDC}") - self._client = gel.create_client(**_get_conn_args(args)) + if client is not None: + self._client = client + else: + self._client = gel.create_client(**_get_conn_args(args)) self._describe_results = [] self._cache = {} @@ -197,11 +191,13 @@ def __init__(self, args: argparse.Namespace): self._defs = {} self._names = set() - self._basemodule = 'models' - self._outdir = pathlib.Path('models') + self._basemodule = args.mod + self._outdir = pathlib.Path(args.out) self._modules = {} self._types = {} + self.init_dir(self._outdir) + super().__init__() def run(self): @@ -285,74 +281,117 @@ def write_types(self, maps): scalar_types = maps['scalar_types'] if object_types: - self.write(f'from typing import Optional, Any, Annotated') + self.write(f'import pydantic') + self.write(f'import typing as pt') + self.write(f'import uuid') self.write(f'from gel.compatibility import pydmodels as gm') objects = sorted( object_types.values(), key=lambda x: x.name ) for obj in objects: + self.render_type(obj, variant='Base') + self.render_type(obj, variant='Update') self.render_type(obj) - def render_type(self, objtype): + def render_type(self, objtype, *, variant=None): mod, name = get_mod_and_name(objtype.name) + is_empty = True self.write() self.write() - self.write(f'class {name}(gm.BaseGelModel):') - self.indent() - self.write(f'__gel_name__ = {objtype.name!r}') - - if len(objtype.properties) > 0: + match variant: + case 'Base': + self.write(f'class _{variant}{name}(gm.BaseGelModel):') + self.indent() + self.write(f'__gel_name__ = {objtype.name!r}') + case 'Update': + self.write(f'class _{variant}{name}(gm.UpdateGelModel):') + self.indent() + self.write(f'__gel_name__ = {objtype.name!r}') + self.write( + f"id: pt.Annotated[uuid.UUID, gm.GelType('std::uuid'), " + f"gm.Exclusive]" + ) + case _: + self.write(f'class {name}(_Base{name}):') + self.indent() + + if variant and len(objtype.properties) > 0: + is_empty = False self.write() self.write('# Properties:') for prop in objtype.properties: - self.render_prop(prop, mod) + self.render_prop(prop, mod, variant=variant) - if len(objtype.links) > 0: - self.write() - self.write('# Properties:') + if variant != 'Base' and len(objtype.links) > 0: + if variant or not is_empty: + self.write() + is_empty = False + self.write('# Links:') for link in objtype.links: - self.render_link(link, mod) + self.render_link(link, mod, variant=variant) + + if not variant: + if not is_empty: + self.write() + self.write('# Class variants:') + self.write(f'base: pt.ClassVar = _Base{name}') + self.write(f'update: pt.ClassVar = _Update{name}') self.dedent() - def render_prop(self, prop, curmod): + def render_prop(self, prop, curmod, *, variant=None): pytype = TYPE_MAPPING.get(prop.target.name) + annotated = [f'gm.GelType({prop.target.name!r})'] defval = '' if not pytype: # skip return - # FIXME: need to also handle multi + if str(prop.cardinality) == 'Many': + annotated.append('gm.Multi') + pytype = f'pt.List[{pytype}]' + defval = ' = []' - if not prop.required: - pytype = f'Optional[{pytype}]' + if variant == 'Update' or not prop.required: + pytype = f'pt.Optional[{pytype}]' # A value does not need to be supplied defval = ' = None' if prop.exclusive: - pytype = f'Annotated[{pytype}, gm.Exclusive]' + annotated.append('gm.Exclusive') + + anno = ', '.join([pytype] + annotated) + pytype = f'pt.Annotated[{anno}]' self.write( f'{prop.name}: {pytype}{defval}' ) - def render_link(self, link, curmod): + def render_link(self, link, curmod, *, variant=None): mod, name = get_mod_and_name(link.target.name) + annotated = [f'gm.GelType({link.target.name!r})', 'gm.Link'] + defval = '' if curmod == mod: pytype = name else: pytype = link.target.name.replace('::', '.') + pytype = repr(pytype) - # FIXME: need to also handle multi + if str(link.cardinality) == 'Many': + annotated.append('gm.Multi') + pytype = f'pt.List[{pytype}]' + defval = ' = []' - if link.required: - self.write( - f'{link.name}: {pytype!r}' - ) - else: + if variant == 'Update' or not link.required: + pytype = f'pt.Optional[{pytype}]' # A value does not need to be supplied - self.write( - f'{link.name}: Optional[{pytype!r}] = None' - ) + defval = ' = None' + + anno = ', '.join([pytype] + annotated) + pytype = f'pt.Annotated[{anno}]' + + self.write( + f'{link.name}: {pytype}{defval}' + ) diff --git a/gel/compatibility/pydmodels.py b/gel/compatibility/pydmodels.py index e59dee67c..7cd6eb6b7 100644 --- a/gel/compatibility/pydmodels.py +++ b/gel/compatibility/pydmodels.py @@ -28,28 +28,20 @@ from gel.compatibility.introspection import FilePrinter, get_mod_and_name -# FIXME: this should be replaced with a special Annotated value using the -# exact type from the schema. No reason to guess. -GEL_TYPE_MAPPING = { - str: "std::str", - float: "std::float64", - int: "std::int64", - bool: "std::bool", - # "uuid.UUID": "std::uuid", - bytes: "std::bytes", - # "decimal.Decimal": "std::decimal", - # "datetime.datetime": "std::datetime", - # "datetime.timedelta": "std::duration", - # "datetime.date": "cal::local_date", - # "datetime.time": "cal::local_time", - # "gel.RelativeDuration": "cal::relative_duration", - # "gel.DateDuration": "cal::date_duration", - # "gel.ConfigMemory": "cfg::memory", - # "array.array": "ext::pgvector::vector", -} +class Exclusive: + pass -class Exclusive: +class GelType: + def __init__(self, name): + self.name = name + + +class Link: + pass + + +class Multi: pass @@ -58,17 +50,57 @@ def exclusive_fields(self): results = [] for name, info in self.model_fields.items(): - for meta in info.metadata: - if meta is Exclusive: - results.append(name) + if Exclusive in info.metadata: + results.append(name) return results + def prop_fields(self): + results = [] + + for name, info in self.model_fields.items(): + if Link not in info.metadata: + results.append(name) + + return results + + def link_fields(self): + results = [] + + for name, info in self.model_fields.items(): + if Link in info.metadata: + results.append(name) + + return results + + def eq_props(self, other): + if other.__class__ is not self.__class__: + return False + + for name in self.prop_fields(): + if getattr(self, name) != getattr(other, name): + return False + + return True + + def get_field_gel_type(self, name): + info = self.model_fields[name] + for anno in info.metadata: + if isinstance(anno, GelType): + return anno.name + + return None + + +class UpdateGelModel(BaseGelModel): + pass + class ObjData(BaseModel): obj: BaseGelModel rank: int | None = None gelid: uuid.UUID | None = None + exval: tuple | None = None def is_optional(field): @@ -79,15 +111,20 @@ def is_optional(field): class Session: - def __init__(self, data, client): + def __init__(self, data, client, *, identity_merge=False): self._data = list(data) self._client = client + self._identity_merge = identity_merge # insert order will come in tiers, where each tier is itself a list of # objects that have the same insert precedence and can be inserted in # parallel. self._insert_order = [[]] self._idmap = {} + # map based on exclusive properties, once an object is inserted, all + # other copies will be updated with the gelid + self._exmap = defaultdict(list) + self.process_exclusive() self.compute_insert_order() def commit(self): @@ -95,15 +132,30 @@ def commit(self): with tx: for objs in self._insert_order: for item in objs: - query, args = self.generate_insert(item) + objdata = self._idmap[id(item)] + if objdata.gelid is not None: + query, args = self.generate_update_new(item) + elif isinstance(item, UpdateGelModel): + query, args = self.generate_update(item) + else: + query, args = self.generate_insert(item) + gelobj = tx.query_single(query, *args) - self._idmap[id(item)].gelid = gelobj.id + objdata.gelid = gelobj.id + if self._identity_merge: + # Update all identical copies of this object with + # the same gelid + exlist = self._exmap[objdata.exval] + for val in exlist: + self._idmap[id(val)].gelid = gelobj.id + self.clear() def clear(self): self._data = [] self._insert_order = [[]] self._idmap = {} + self._exmap = {} def generate_insert(self, item, *, arg_start=0): args = [] @@ -117,16 +169,36 @@ def generate_insert(self, item, *, arg_start=0): # skip empty values continue - if isinstance(val, BaseModel): - subquery, subargs = self.generate_select( - val, arg_start=arg) - arg += len(subargs) - args += subargs - query += f'{name} := ({subquery}), ' + if Link in info.metadata: + subqueries = [] + if Multi in info.metadata: + links = val + else: + links = [val] + + # multi link potentially needs several subqueries + for el in links: + subquery, subargs = self.generate_select( + el, arg_start=arg) + arg += len(subargs) + args += subargs + subqueries.append(f'({subquery})') + + query += f'{name} := ' + if len(subqueries) > 1: + subq = ", ".join(subqueries) + query += f'assert_distinct({{ {subq} }}), ' + else: + query += f'{subqueries[0]}, ' else: - geltype = GEL_TYPE_MAPPING[type(val)] - query += f'{name} := <{geltype}>${arg}, ' + query += f'{name} := ' + geltype = item.get_field_gel_type(name) + if Multi in info.metadata: + query += f'array_unpack(>${arg}), ' + else: + query += f'<{geltype}>${arg}, ' + arg += 1 args.append(val) @@ -134,6 +206,49 @@ def generate_insert(self, item, *, arg_start=0): return query, args + def generate_update_new(self, item, *, arg_start=0): + gelid = self._idmap[id(item)].gelid + args = [] + arg = arg_start + query = f'update detached {item.__gel_name__} ' + query += f'filter .id = ${arg} set {{' + arg += 1 + args.append(gelid) + + for name, info in item.model_fields.items(): + val = getattr(item, name) + + if val is None: + # skip empty values + continue + + # only update links + if Link in info.metadata: + subqueries = [] + if Multi in info.metadata: + links = val + else: + links = [val] + + # multi link potentially needs several subqueries + for el in links: + subquery, subargs = self.generate_select( + el, arg_start=arg) + arg += len(subargs) + args += subargs + subqueries.append(f'({subquery})') + + query += f'{name} := ' + if len(subqueries) > 1: + subq = ", ".join(subqueries) + query += f'assert_distinct({{ {subq} }}), ' + else: + query += f'{subqueries[0]}, ' + + query += '}' + + return query, args + def generate_select(self, item, *, arg_start=0): gelid = self._idmap[id(item)].gelid args = [] @@ -148,7 +263,7 @@ def generate_select(self, item, *, arg_start=0): else: for name in item.exclusive_fields(): val = getattr(item, name) - geltype = GEL_TYPE_MAPPING[type(val)] + geltype = item.get_field_gel_type(name) fquery.append(f'.{name} = <{geltype}>${arg}') arg += 1 args.append(val) @@ -157,6 +272,65 @@ def generate_select(self, item, *, arg_start=0): return query, args + def generate_update(self, item, *, arg_start=0): + # This is an update query for a pre-existing object, so all fields are + # optional except for id. + gelid = item.id + args = [] + arg = arg_start + query = f'update detached {item.__gel_name__} ' + query += f'filter .id = ${arg} set {{' + arg += 1 + args.append(gelid) + + for name, info in item.model_fields.items(): + if name == 'id': + continue + + val = getattr(item, name) + + # FIXME: instead we need to track the modified fields + if val is None: + # skip empty values + continue + + if Link in info.metadata: + subqueries = [] + if Multi in info.metadata: + links = val + else: + links = [val] + + # multi link potentially needs several subqueries + for el in links: + subquery, subargs = self.generate_select( + el, arg_start=arg) + arg += len(subargs) + args += subargs + subqueries.append(f'({subquery})') + + query += f'{name} := ' + if len(subqueries) > 1: + subq = ", ".join(subqueries) + query += f'assert_distinct({{ {subq} }}), ' + else: + query += f'{subqueries[0]}, ' + + else: + query += f'{name} := ' + geltype = item.get_field_gel_type(name) + if Multi in info.metadata: + query += f'array_unpack(>${arg}), ' + else: + query += f'<{geltype}>${arg}, ' + + arg += 1 + args.append(val) + + query += '}' + + return query, args + def compute_insert_order(self): # We traverse all the distinct objects to be inserted and follow up on # their links, recursively constructing a forest. All leaves get rank @@ -168,7 +342,9 @@ def compute_insert_order(self): def rank_object(self, obj): oid = id(obj) - if self._idmap.get(oid) is not None: + # Check if this object has been ranked already. + objdata = self._idmap[oid] + if objdata.rank is not None: return rank = 0 @@ -176,10 +352,16 @@ def rank_object(self, obj): for name, info in obj.model_fields.items(): val = getattr(obj, name) # We care about actual link value, because it can be None - if isinstance(val, BaseModel): - self.rank_object(val) - linked = self._idmap[id(val)] - rank = max(rank, linked.rank + 1) + if val is not None and Link in info.metadata: + if Multi in info.metadata: + links = val + else: + links = [val] + + for el in links: + self.rank_object(el) + linked = self._idmap[id(el)] + rank = max(rank, linked.rank + 1) if rank >= len(self._insert_order): # We only need to grow the _insert_order by 1 more rank since we @@ -188,42 +370,94 @@ def rank_object(self, obj): self._insert_order.append([]) self._insert_order[rank].append(obj) - self._idmap[oid] = ObjData(obj=obj, rank=rank) - - # def break_cycle(self): - # # We have a set of objects that form a link cycle. We need to find one - # # of them with optional links to this cycle and set that link to be - # # empty instead. The object's link will be updated separately. - # link_found = False - # for oid in self._cycle: - # obj = self._idmap[oid].obj - - # for name, info in obj.model_fields.items(): - # val = self.getmodelattr(obj, name) - # # We care about actual link value, because it can be None - # if ( - # isinstance(val, BaseModel) and - # is_optional(info.annotation) - # ): - # self._updates[oid].add(name) - # link_found = True - # break - - # if not link_found: - # cycle = ', '.join(self._idmap[oid].obj for oid in self._cycle) - # raise Exception('Cycle detected: {cycle}') - - # self._cycle.clear() - - # def getmodelattr(self, obj, name): - # links = self._updates.get(id(obj), set()) - # if name in links: - # # skip this link - # return None - # else: - # return getattr(obj, name) + objdata.rank = rank + + def _get_ex_values(self, obj): + # Make a tuple out of the Gel type name and all of the exclusive + # values. + vals = [obj.__gel_name__] + for name in sorted(obj.exclusive_fields()): + val = getattr(obj, name) + vals.append(val) + + if len(vals) > 1: + return tuple(vals) + else: + return None + + def process_exclusive(self): + errors = defaultdict(list) + + for obj in self._data: + self.map_exclusive(obj, errors) + + if len(errors) > 0: + num_err = 0 + msg = 'The following objects have clashing exclusive fields:\n' + for key, val in errors.items(): + first = self._exmap[key][:1] + msg += f'{key[0]}: ' + if num_err + len(val) < 100: + # include all objects + msg += ', '.join( + str(obj) for obj in first + val) + msg += '\n' + else: + # clip objects in error message + msg += ', '.join( + str(obj) for obj in (first + val)[:100 - num_err]) + break + + raise Exception(msg) + + def map_exclusive(self, obj, errors): + oid = id(obj) + if self._idmap.get(oid) is not None: + return + + exval = self._get_ex_values(obj) + self._idmap[oid] = ObjData(obj=obj, exval=exval) + + if exval is not None: + # has exclusive fields + exlist = self._exmap[exval] + if exlist: + other = exlist[0] + else: + other = None + + if other: + if self._identity_merge: + # Objects with the same exclusive fields and the same values + # for other properties are asssumed to be the same object. + if other.eq_props(obj): + self._exmap[exval].append(obj) + else: + errors[exval].append(obj) + + else: + # Objects with the same exclusive fields cannot exist and + # should be flagged as an error. But we want to collect them + # all first. + errors[exval].append(obj) + else: + # No pre-existing copy + self._exmap[exval].append(obj) + + # recurse into links + for name, info in obj.model_fields.items(): + val = getattr(obj, name) + # We care about actual link value, because it can be None + if val is not None and Link in info.metadata: + if Multi in info.metadata: + links = val + else: + links = [val] + + for el in links: + self.map_exclusive(el, errors) def commit(client, data, *, identity_merge=False): - sess = Session(data, client) + sess = Session(data, client, identity_merge=identity_merge) sess.commit() diff --git a/tests/dbsetup/pydantic.gel b/tests/dbsetup/pydantic.gel new file mode 100644 index 000000000..e7c13b78d --- /dev/null +++ b/tests/dbsetup/pydantic.gel @@ -0,0 +1,28 @@ +type ToDoList { + required name: str { + constraint exclusive() + } +} + +type Item { + required num: int64; + required text: str; + required done: bool; + required list: ToDoList; +} + +type LinkedList { + required data: str { + constraint exclusive() + } + multi ints: int64; + next: LinkedList; +} + +type Tree { + required data: str { + constraint exclusive() + } + + multi branches: Tree; +} \ No newline at end of file diff --git a/tests/test_pydantic_orm.py b/tests/test_pydantic_orm.py new file mode 100644 index 000000000..baf905521 --- /dev/null +++ b/tests/test_pydantic_orm.py @@ -0,0 +1,465 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2025-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import os + +from gel import _testbase as tb +from gel.compatibility.pydmodels import commit + + +class TestPydantic(tb.PydanticTestCase): + SCHEMA = os.path.join(os.path.dirname(__file__), 'dbsetup', + 'pydantic.gel') + + MODEL_PACKAGE = 'pymodels' + + @classmethod + def setUpClass(cls): + super().setUpClass() + from pymodels import default + cls.m = default + + def tearDown(self): + self.client.query('delete Object') + super().tearDown() + + def test_pydantic_insert_models_01(self): + # insert a bunch of disconnected objects + data = [ + self.m.ToDoList(name='1st'), + self.m.ToDoList(name='2nd'), + self.m.ToDoList(name='3rd'), + self.m.ToDoList(name='last'), + ] + + commit(self.client, data) + vals = json.loads(self.client.query_json(''' + select ToDoList.name; + ''')) + + self.assertEqual( + set(vals), + {'1st', '2nd', '3rd', 'last'}, + ) + + def test_pydantic_insert_models_02(self): + # insert a bunch of object structures + l = self.m.ToDoList(name='mylist') + data = [ + self.m.Item(num=0, text='first!!!', done=True, list=l), + self.m.Item(num=1, text='do something', done=True, list=l), + self.m.Item(num=2, text='coffee', done=False, list=l), + self.m.Item(num=10, text='last', done=False, + list=self.m.ToDoList(name='otherlist')), + ] + + commit(self.client, data) + vals = json.loads(self.client.query_json(''' + select ToDoList { + name, + items := ( + select .