Skip to content
3 changes: 2 additions & 1 deletion google/cloud/spanner_v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
from .types.type import Type
from .types.type import TypeAnnotationCode
from .types.type import TypeCode
from .data_types import JsonObject
from .data_types import JsonObject, Interval
from .transaction import BatchTransactionId, DefaultTransactionOptions

from google.cloud.spanner_v1 import param_types
Expand Down Expand Up @@ -145,6 +145,7 @@
"TypeCode",
# Custom spanner related data types
"JsonObject",
"Interval",
# google.cloud.spanner_v1.services
"SpannerClient",
"SpannerAsyncClient",
Expand Down
13 changes: 12 additions & 1 deletion google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from google.cloud._helpers import _date_from_iso8601_date
from google.cloud.spanner_v1 import TypeCode
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import JsonObject
from google.cloud.spanner_v1 import JsonObject, Interval
from google.cloud.spanner_v1 import TransactionOptions
from google.cloud.spanner_v1.request_id_header import with_request_id
from google.rpc.error_details_pb2 import RetryInfo
Expand Down Expand Up @@ -251,6 +251,8 @@ def _make_value_pb(value):
return Value(null_value="NULL_VALUE")
else:
return Value(string_value=base64.b64encode(value))
if isinstance(value, Interval):
return Value(string_value=str(value))

raise ValueError("Unknown type: %s" % (value,))

Expand Down Expand Up @@ -367,6 +369,8 @@ def _get_type_decoder(field_type, field_name, column_info=None):
for item_field in field_type.struct_type.fields
]
return lambda value_pb: _parse_struct(value_pb, element_decoders)
elif type_code == TypeCode.INTERVAL:
return _parse_interval
else:
raise ValueError("Unknown type: %s" % (field_type,))

Expand Down Expand Up @@ -473,6 +477,13 @@ def _parse_nullable(value_pb, decoder):
return decoder(value_pb)


def _parse_interval(value_pb):
"""Parse a Value protobuf containing an interval."""
if hasattr(value_pb, "string_value"):
return Interval.from_str(value_pb.string_value)
return Interval.from_str(value_pb)


class _SessionWrapper(object):
"""Base class for objects wrapping a session.

Expand Down
149 changes: 148 additions & 1 deletion google/cloud/spanner_v1/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

import json
import types

import re
from dataclasses import dataclass
from google.protobuf.message import Message
from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper

Expand Down Expand Up @@ -97,6 +98,152 @@ def serialize(self):
return json.dumps(self, sort_keys=True, separators=(",", ":"))


@dataclass
class Interval:
"""Represents a Spanner INTERVAL type.

An interval is a combination of months, days and nanoseconds.
Internally, Spanner supports Interval value with the following range of individual fields:
months: [-120000, 120000]
days: [-3660000, 3660000]
nanoseconds: [-316224000000000000000, 316224000000000000000]
"""

months: int = 0
days: int = 0
nanos: int = 0

def __str__(self) -> str:
"""Returns the ISO8601 duration format string representation."""
result = ["P"]

# Handle years and months
if self.months:
is_negative = self.months < 0
abs_months = abs(self.months)
years, months = divmod(abs_months, 12)
if years:
result.append(f"{'-' if is_negative else ''}{years}Y")
if months:
result.append(f"{'-' if is_negative else ''}{months}M")

# Handle days
if self.days:
result.append(f"{self.days}D")

# Handle time components
if self.nanos:
result.append("T")
nanos = abs(self.nanos)
is_negative = self.nanos < 0

# Convert to hours, minutes, seconds
nanos_per_hour = 3600000000000
hours, nanos = divmod(nanos, nanos_per_hour)
if hours:
if is_negative:
result.append("-")
result.append(f"{hours}H")

nanos_per_minute = 60000000000
minutes, nanos = divmod(nanos, nanos_per_minute)
if minutes:
if is_negative:
result.append("-")
result.append(f"{minutes}M")

nanos_per_second = 1000000000
seconds, nanos_fraction = divmod(nanos, nanos_per_second)

if seconds or nanos_fraction:
if is_negative:
result.append("-")
if seconds:
result.append(str(seconds))
elif nanos_fraction:
result.append("0")

if nanos_fraction:
nano_str = f"{nanos_fraction:09d}"
trimmed = nano_str.rstrip("0")
if len(trimmed) <= 3:
while len(trimmed) < 3:
trimmed += "0"
elif len(trimmed) <= 6:
while len(trimmed) < 6:
trimmed += "0"
else:
while len(trimmed) < 9:
trimmed += "0"
result.append(f".{trimmed}")
result.append("S")

if len(result) == 1:
result.append("0Y") # Special case for zero interval

return "".join(result)

@classmethod
def from_str(cls, s: str) -> "Interval":
"""Parse an ISO8601 duration format string into an Interval."""
pattern = r"^P(-?\d+Y)?(-?\d+M)?(-?\d+D)?(T(-?\d+H)?(-?\d+M)?(-?((\d+([.,]\d{1,9})?)|([.,]\d{1,9}))S)?)?$"
match = re.match(pattern, s)
if not match or len(s) == 1:
raise ValueError(f"Invalid interval format: {s}")

parts = match.groups()
if not any(parts[:3]) and not parts[3]:
raise ValueError(
f"Invalid interval format: at least one component (Y/M/D/H/M/S) is required: {s}"
)

if parts[3] == "T" and not any(parts[4:7]):
raise ValueError(
f"Invalid interval format: time designator 'T' present but no time components specified: {s}"
)

def parse_num(s: str, suffix: str) -> int:
if not s:
return 0
return int(s.rstrip(suffix))

years = parse_num(parts[0], "Y")
months = parse_num(parts[1], "M")
total_months = years * 12 + months

days = parse_num(parts[2], "D")

nanos = 0
if parts[3]: # Has time component
# Convert hours to nanoseconds
hours = parse_num(parts[4], "H")
nanos += hours * 3600000000000

# Convert minutes to nanoseconds
minutes = parse_num(parts[5], "M")
nanos += minutes * 60000000000

# Handle seconds and fractional seconds
if parts[6]:
seconds = parts[6].rstrip("S")
if "," in seconds:
seconds = seconds.replace(",", ".")

if "." in seconds:
sec_parts = seconds.split(".")
whole_seconds = sec_parts[0] if sec_parts[0] else "0"
nanos += int(whole_seconds) * 1000000000
frac = sec_parts[1][:9].ljust(9, "0")
frac_nanos = int(frac)
if seconds.startswith("-"):
frac_nanos = -frac_nanos
nanos += frac_nanos
else:
nanos += int(seconds) * 1000000000

return cls(months=total_months, days=days, nanos=nanos)


def _proto_message(bytes_val, proto_message_object):
"""Helper for :func:`get_proto_message`.
parses serialized protocol buffer bytes data into proto message.
Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_v1/param_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
PG_NUMERIC = Type(code=TypeCode.NUMERIC, type_annotation=TypeAnnotationCode.PG_NUMERIC)
PG_JSONB = Type(code=TypeCode.JSON, type_annotation=TypeAnnotationCode.PG_JSONB)
PG_OID = Type(code=TypeCode.INT64, type_annotation=TypeAnnotationCode.PG_OID)
INTERVAL = Type(code=TypeCode.INTERVAL)


def Array(element_type):
Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_v1/streamed.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def _merge_struct(lhs, rhs, type_):
TypeCode.NUMERIC: _merge_string,
TypeCode.JSON: _merge_string,
TypeCode.PROTO: _merge_string,
TypeCode.INTERVAL: _merge_string,
TypeCode.ENUM: _merge_string,
}

Expand Down
13 changes: 12 additions & 1 deletion tests/system/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,20 @@ def scrub_instance_ignore_not_found(to_scrub):
"""Helper for func:`cleanup_old_instances`"""
scrub_instance_backups(to_scrub)

for database_pb in to_scrub.list_databases():
db = to_scrub.database(database_pb.name.split("/")[-1])
db.reload()
try:
if db.enable_drop_protection:
db.enable_drop_protection = False
operation = db.update(["enable_drop_protection"])
operation.result(DATABASE_OPERATION_TIMEOUT_IN_SECONDS)
except exceptions.NotFound:
pass

try:
retry_429_503(to_scrub.delete)()
except exceptions.NotFound: # lost the race
except exceptions.NotFound:
pass


Expand Down
13 changes: 10 additions & 3 deletions tests/system/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,17 @@ def instance_config(instance_configs):
if not instance_configs:
raise ValueError("No instance configs found.")

us_west1_config = [
config for config in instance_configs if config.display_name == "us-west1"
import random

us_configs = [
config
for config in instance_configs
if config.display_name in ["us-south1", "us-east4"]
]
config = us_west1_config[0] if len(us_west1_config) > 0 else instance_configs[0]

config = (
random.choice(us_configs) if us_configs else random.choice(instance_configs)
)
yield config


Expand Down
Loading
Loading