Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions temporalio/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
from typing import (
Any,
Callable,
List,
Mapping,
MutableMapping,
NoReturn,
Optional,
Sequence,
Tuple,
Type,
Union,
overload,
)
Expand Down Expand Up @@ -351,6 +353,9 @@ class _Definition:
name: str
fn: Callable
is_async: bool
# Types loaded on post init if both are None
arg_types: Optional[List[Type]] = None
ret_type: Optional[Type] = None

@staticmethod
def from_callable(fn: Callable) -> Optional[_Definition]:
Expand Down Expand Up @@ -396,3 +401,9 @@ def _apply_to_callable(fn: Callable, activity_name: str) -> None:
is_async=inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(fn.__call__), # type: ignore
),
)

def __post_init__(self) -> None:
if self.arg_types is None and self.ret_type is None:
arg_types, ret_type = temporalio.common._type_hints_from_func(self.fn)
object.__setattr__(self, "arg_types", arg_types)
object.__setattr__(self, "ret_type", ret_type)
8 changes: 3 additions & 5 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ def __init__(

See :py:meth:`connect` for details on the parameters.
"""
self._type_lookup = temporalio.converter._FunctionTypeLookup()
# Iterate over interceptors in reverse building the impl
self._impl: OutboundInterceptor = _ClientImpl(self)
for interceptor in reversed(list(interceptors)):
Expand Down Expand Up @@ -359,7 +358,7 @@ async def start_workflow(
elif callable(workflow):
defn = temporalio.workflow._Definition.must_from_run_fn(workflow)
name = defn.name
_, ret_type = self._type_lookup.get_type_hints(defn.run_fn)
ret_type = defn.ret_type
else:
raise TypeError("Workflow must be a string or callable")

Expand Down Expand Up @@ -589,12 +588,11 @@ def get_workflow_handle_for(
The workflow handle.
"""
defn = temporalio.workflow._Definition.must_from_run_fn(workflow)
_, ret_type = self._type_lookup.get_type_hints(defn.run_fn)
return self.get_workflow_handle(
workflow_id,
run_id=run_id,
first_execution_run_id=first_execution_run_id,
result_type=ret_type,
result_type=defn.ret_type,
)

@overload
Expand Down Expand Up @@ -1053,7 +1051,7 @@ async def query(
raise RuntimeError("Cannot invoke dynamic query definition")
# TODO(cretz): Check count/type of args at runtime?
query_name = defn.name
_, ret_type = self._client._type_lookup.get_type_hints(defn.fn)
ret_type = defn.ret_type
else:
query_name = str(query)

Expand Down
90 changes: 89 additions & 1 deletion temporalio/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,24 @@

from __future__ import annotations

import inspect
import types
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import IntEnum
from typing import Any, List, Mapping, Optional, Sequence, Text, Union
from typing import (
Any,
Callable,
List,
Mapping,
Optional,
Sequence,
Text,
Tuple,
Type,
Union,
get_type_hints,
)

import google.protobuf.internal.containers
from typing_extensions import TypeAlias
Expand Down Expand Up @@ -160,3 +174,77 @@ def _apply_headers(
for k, v in source.items():
# This does not copy bytes, just messages
dest[k].CopyFrom(v)


# Same as inspect._NonUserDefinedCallables
_non_user_defined_callables = (
type(type.__call__),
type(all.__call__), # type: ignore
type(int.__dict__["from_bytes"]),
types.BuiltinFunctionType,
)


def _type_hints_from_func(
func: Callable,
) -> Tuple[Optional[List[Type]], Optional[Type]]:
"""Extracts the type hints from the function.

Args:
func: Function to extract hints from.

Returns:
Tuple containing parameter types and return type. The parameter types
will be None if there are any non-positional parameters or if any of the
parameters to not have an annotation that represents a class. If the
first parameter is "self" with no attribute, it is not included.
"""
# If this is a class instance with user-defined __call__, then use that as
# the func. This mimics inspect logic inside Python.
if (
not inspect.isfunction(func)
and not isinstance(func, _non_user_defined_callables)
and not isinstance(func, types.MethodType)
):
# Class type or Callable instance
tmp_func = func if isinstance(func, type) else type(func)
call_func = getattr(tmp_func, "__call__", None)
if call_func is not None and not isinstance(
tmp_func, _non_user_defined_callables
):
func = call_func

# We use inspect.signature for the parameter names and kinds, but we cannot
# use it for annotations because those that are using deferred hinting (i.e.
# from __future__ import annotations) only work with the eval_str parameter
# which is only supported in >= 3.10. But typing.get_type_hints is supported
# in >= 3.7.
sig = inspect.signature(func)
hints = get_type_hints(func)
ret_hint = hints.get("return")
ret = (
ret_hint
if inspect.isclass(ret_hint) and ret_hint is not inspect.Signature.empty
else None
)
args: List[Type] = []
for index, value in enumerate(sig.parameters.values()):
# Ignore self on methods
if (
index == 0
and value.name == "self"
and value.annotation is inspect.Parameter.empty
):
continue
# Stop if non-positional or not a class
if (
value.kind is not inspect.Parameter.POSITIONAL_ONLY
and value.kind is not inspect.Parameter.POSITIONAL_OR_KEYWORD
):
return (None, ret)
# All params must have annotations or we consider none to have them
arg_hint = hints.get(value.name)
if not inspect.isclass(arg_hint) or arg_hint is inspect.Parameter.empty:
return (None, ret)
args.append(arg_hint)
return args, ret
102 changes: 0 additions & 102 deletions temporalio/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,12 @@
import dataclasses
import inspect
import json
import types
import typing
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from enum import IntEnum
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Expand Down Expand Up @@ -705,105 +702,6 @@ def decode_search_attributes(
return ret


class _FunctionTypeLookup:
def __init__(self) -> None:
# Keyed by callable __qualname__, value is optional arg types and
# optional ret type
self._cache: Dict[str, Tuple[Optional[List[Type]], Optional[Type]]] = {}

def get_type_hints(self, fn: Any) -> Tuple[Optional[List[Type]], Optional[Type]]:
# Due to MyPy issues, we cannot type "fn" as callable
if not callable(fn):
return (None, None)
# We base the cache key on the qualified name of the function. However,
# since some callables are not functions, we assume we can never cache
# these just in case the type hints are dynamic for some strange reason.
cache_key = getattr(fn, "__qualname__", None)
if cache_key:
ret = self._cache.get(cache_key)
if ret:
return ret
# TODO(cretz): Do we even need to cache?
ret = _type_hints_from_func(fn)
if cache_key:
self._cache[cache_key] = ret
return ret


# Same as inspect._NonUserDefinedCallables
_non_user_defined_callables = (
type(type.__call__),
type(all.__call__), # type: ignore
type(int.__dict__["from_bytes"]),
types.BuiltinFunctionType,
)


def _type_hints_from_func(
func: Callable,
) -> Tuple[Optional[List[Type]], Optional[Type]]:
"""Extracts the type hints from the function.

Args:
func: Function to extract hints from.

Returns:
Tuple containing parameter types and return type. The parameter types
will be None if there are any non-positional parameters or if any of the
parameters to not have an annotation that represents a class. If the
first parameter is "self" with no attribute, it is not included.
"""
# If this is a class instance with user-defined __call__, then use that as
# the func. This mimics inspect logic inside Python.
if (
not inspect.isfunction(func)
and not isinstance(func, _non_user_defined_callables)
and not isinstance(func, types.MethodType)
):
# Class type or Callable instance
tmp_func = func if isinstance(func, type) else type(func)
call_func = getattr(tmp_func, "__call__", None)
if call_func is not None and not isinstance(
tmp_func, _non_user_defined_callables
):
func = call_func

# We use inspect.signature for the parameter names and kinds, but we cannot
# use it for annotations because those that are using deferred hinting (i.e.
# from __future__ import annotations) only work with the eval_str parameter
# which is only supported in >= 3.10. But typing.get_type_hints is supported
# in >= 3.7.
sig = inspect.signature(func)
hints = typing.get_type_hints(func)
ret_hint = hints.get("return")
ret = (
ret_hint
if inspect.isclass(ret_hint) and ret_hint is not inspect.Signature.empty
else None
)
args: List[Type] = []
for index, value in enumerate(sig.parameters.values()):
# Ignore self on methods
if (
index == 0
and value.name == "self"
and value.annotation is inspect.Parameter.empty
):
continue
# Stop if non-positional or not a class
if (
value.kind is not inspect.Parameter.POSITIONAL_ONLY
and value.kind is not inspect.Parameter.POSITIONAL_OR_KEYWORD
):
return (None, ret)
# All params must have annotations or we consider none to have them
arg_hint = hints.get(value.name)
if not inspect.isclass(arg_hint) or arg_hint is inspect.Parameter.empty:
return (None, ret)
args.append(arg_hint)
return args, ret


def value_to_type(hint: Type, value: Any) -> Any:
"""Convert a given value to the given type hint.

Expand Down
4 changes: 1 addition & 3 deletions temporalio/worker/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def __init__(
shared_state_manager: Optional[SharedStateManager],
data_converter: temporalio.converter.DataConverter,
interceptors: Sequence[Interceptor],
type_lookup: temporalio.converter._FunctionTypeLookup,
) -> None:
self._bridge_worker = bridge_worker
self._task_queue = task_queue
Expand All @@ -63,7 +62,6 @@ def __init__(
self._running_activities: Dict[bytes, _RunningActivity] = {}
self._data_converter = data_converter
self._interceptors = interceptors
self._type_lookup = type_lookup
# Lazily created on first activity
self._worker_shutdown_event: Optional[
temporalio.activity._CompositeEvent
Expand Down Expand Up @@ -303,7 +301,7 @@ async def _run_activity(

# Convert arguments. We only use arg type hints if they match the
# input count.
arg_types, _ = self._type_lookup.get_type_hints(activity_def.fn)
arg_types = activity_def.arg_types
if arg_types is not None and len(arg_types) != len(start.input):
arg_types = None
try:
Expand Down
5 changes: 0 additions & 5 deletions temporalio/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,6 @@ def __init__(
)
interceptors = interceptors_from_client + list(interceptors)

# Instead of using the _type_lookup on the client, we create a separate
# one here so we can continue to only use the public API of the client
type_lookup = temporalio.converter._FunctionTypeLookup()

# Extract the bridge service client. We try the service on the client
# first, then we support a worker_service_client on the client's service
# to return underlying service client we can use.
Expand Down Expand Up @@ -240,7 +236,6 @@ def __init__(
shared_state_manager=shared_state_manager,
data_converter=client_config["data_converter"],
interceptors=interceptors,
type_lookup=type_lookup,
)
self._workflow_worker: Optional[_WorkflowWorker] = None
if workflows:
Expand Down
Loading