diff --git a/psqlextra/introspect/models.py b/psqlextra/introspect/models.py index 61a478d..e160bca 100644 --- a/psqlextra/introspect/models.py +++ b/psqlextra/introspect/models.py @@ -7,6 +7,7 @@ Optional, Type, TypeVar, + Union, cast, ) @@ -115,9 +116,10 @@ def models_from_cursor( ) for index, related_field_name in enumerate(related_fields): - related_model = model._meta.get_field( - related_field_name - ).related_model + related_model = cast( + Union[Type[Model], None], + model._meta.get_field(related_field_name).related_model, + ) if not related_model: continue diff --git a/psqlextra/sql.py b/psqlextra/sql.py index b265508..cf12d8c 100644 --- a/psqlextra/sql.py +++ b/psqlextra/sql.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from collections.abc import Iterable from typing import Any, Dict, List, Optional, Tuple, Union import django @@ -7,6 +8,7 @@ from django.db import connections, models from django.db.models import Expression, sql from django.db.models.constants import LOOKUP_SEP +from django.db.models.expressions import Ref from .compiler import PostgresInsertOnConflictCompiler from .compiler import SQLUpdateCompiler as PostgresUpdateCompiler @@ -74,6 +76,14 @@ def rename_annotations(self, annotations) -> None: self.annotation_select_mask.remove(old_name) self.annotation_select_mask.append(new_name) + if isinstance(self.group_by, Iterable): + for statement in self.group_by: + if not isinstance(statement, Ref): + continue + + if statement.refs in annotations: # type: ignore[attr-defined] + statement.refs = annotations[statement.refs] # type: ignore[attr-defined] + self.annotations.clear() self.annotations.update(new_annotations) diff --git a/settings.py b/settings.py index ed0d0f9..7266ccb 100644 --- a/settings.py +++ b/settings.py @@ -11,7 +11,7 @@ 'default': dj_database_url.config(default='postgres:///psqlextra'), } -DATABASES['default']['ENGINE'] = 'psqlextra.backend' +DATABASES['default']['ENGINE'] = 'tests.psqlextra_test_backend' LANGUAGE_CODE = 'en' LANGUAGES = ( @@ -24,3 +24,6 @@ 'psqlextra', 'tests', ) + +USE_TZ = True +TIME_ZONE = 'UTC' diff --git a/setup.py b/setup.py index c3431e2..918beb8 100644 --- a/setup.py +++ b/setup.py @@ -97,7 +97,7 @@ def run(self): "docformatter==1.4", "mypy==1.2.0; python_version > '3.6'", "mypy==0.971; python_version <= '3.6'", - "django-stubs==1.16.0; python_version > '3.6'", + "django-stubs==4.2.7; python_version > '3.6'", "django-stubs==1.9.0; python_version <= '3.6'", "typing-extensions==4.5.0; python_version > '3.6'", "typing-extensions==4.1.0; python_version <= '3.6'", diff --git a/tests/psqlextra_test_backend/__init__.py b/tests/psqlextra_test_backend/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/psqlextra_test_backend/base.py b/tests/psqlextra_test_backend/base.py new file mode 100644 index 0000000..0961a2b --- /dev/null +++ b/tests/psqlextra_test_backend/base.py @@ -0,0 +1,23 @@ +from datetime import timezone + +import django + +from django.conf import settings + +from psqlextra.backend.base import DatabaseWrapper as PSQLExtraDatabaseWrapper + + +class DatabaseWrapper(PSQLExtraDatabaseWrapper): + # Works around the compatibility issue of Django <3.0 and psycopg2.9 + # in combination with USE_TZ + # + # See: https://github.com/psycopg/psycopg2/issues/1293#issuecomment-862835147 + if django.VERSION < (3, 1): + + def create_cursor(self, name=None): + cursor = super().create_cursor(name) + cursor.tzinfo_factory = ( + lambda offset: timezone.utc if settings.USE_TZ else None + ) + + return cursor diff --git a/tests/test_query.py b/tests/test_query.py index 7db4bea..38d6b3c 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,5 +1,8 @@ +from datetime import datetime, timezone + from django.db import connection, models -from django.db.models import Case, F, Q, Value, When +from django.db.models import Case, F, Min, Q, Value, When +from django.db.models.functions.datetime import TruncSecond from django.test.utils import CaptureQueriesContext, override_settings from psqlextra.expressions import HStoreRef @@ -96,6 +99,40 @@ def test_query_annotate_in_expression(): assert result.is_he_henk == "really henk" +def test_query_annotate_group_by(): + """Tests whether annotations with GROUP BY clauses are properly renamed + when the annotation overwrites a field name.""" + + model = get_fake_model( + { + "name": models.TextField(), + "timestamp": models.DateTimeField(null=False), + "value": models.IntegerField(), + } + ) + + timestamp = datetime(2024, 1, 1, 0, 0, 0, 0, tzinfo=timezone.utc) + + model.objects.create(name="me", timestamp=timestamp, value=1) + + result = ( + model.objects.values("name") + .annotate( + timestamp=TruncSecond("timestamp", tzinfo=timezone.utc), + value=Min("value"), + ) + .values_list( + "name", + "value", + "timestamp", + ) + .order_by("name") + .first() + ) + + assert result == ("me", 1, timestamp) + + def test_query_hstore_value_update_f_ref(): """Tests whether F(..) expressions can be used in hstore values when performing update queries."""