3 changes: 2 additions & 1 deletion pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
"cloud-files >= 5.3.2",
"packaging >= 23.2",
"pdbp >= 1.5.3",
"psutil >= 7.1.0",
"boto3 >= 1.38.15",
"slack_sdk >= 3.31.0",
"tabulate >= 0.9.0",
@@ -93,7 +94,7 @@ chunkedgraph = [
# Note: graph_tool must be installed separately via conda:
# conda install -c conda-forge graph-tool-base
]
convnet = ["torch >= 2.0", "artificery >= 0.0.3.3", "onnx2torch"]
convnet = ["torch >= 2.0", "artificery >= 0.0.3.3", "onnx2torch", "nvidia-ml-py >= 13.580"]
databackends = ["google-cloud-datastore", "google-cloud-firestore"]
docs = [
"piccolo_theme >= 0.24.0",
42 changes: 42 additions & 0 deletions tests/unit/common/test_resource_monitor.py
@@ -0,0 +1,42 @@
import time

from zetta_utils.common.resource_monitor import ResourceMonitor


def test_resource_monitor_basic():
monitor = ResourceMonitor(log_interval_seconds=0.1)

assert monitor.log_interval_seconds == 0.1
assert len(monitor.samples) == 0
assert monitor.prev_time is None

usage = monitor.get_all_usage()
assert "timestamp" in usage
assert "cpu_percent" in usage
assert "memory" in usage
assert "disk_io" in usage
assert "network" in usage
assert "gpus" in usage

monitor.log_usage()
time.sleep(0.1)
monitor.log_usage()
time.sleep(0.1)
monitor.log_usage()

summary = monitor.get_summary_stats()

assert len(monitor.samples) == 3
assert "sample_count" in summary
assert summary["sample_count"] == 3
assert "duration_seconds" in summary
assert "cpu" in summary
assert "memory" in summary

monitor.log_summary()


def test_resource_monitor_empty():
monitor = ResourceMonitor(log_interval_seconds=0.1)
summary = monitor.get_summary_stats()
assert not summary
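
For orientation, a minimal sketch of how the ResourceMonitor interface exercised above might be driven from a long-running process — the loop length and sleep durations are illustrative only, not taken from this PR:

import time

from zetta_utils.common.resource_monitor import ResourceMonitor

monitor = ResourceMonitor(log_interval_seconds=0.5)
for _ in range(5):
    time.sleep(0.5)  # stand-in for a unit of work
    monitor.log_usage()  # appends one sample (CPU, memory, disk I/O, network, GPUs)
print(monitor.get_summary_stats())  # per-resource aggregates plus sample_count / duration_seconds
monitor.log_summary()  # presumably emits the same summary through the logger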
7 changes: 4 additions & 3 deletions tests/unit/geometry/test_bbox_strider.py
@@ -1,8 +1,8 @@
# pylint: disable=missing-docstring,redefined-outer-name,unused-argument,pointless-statement,line-too-long,protected-access,unsubscriptable-object,unused-variable
import multiprocessing

import pytest

from zetta_utils import MULTIPROCESSING_NUM_TASKS_THRESHOLD
from zetta_utils.geometry import BBox3D, BBoxStrider, IntVec3D, Vec3D


@@ -57,10 +57,11 @@ def test_bbox_strider_get_all_chunks(mocker):


def test_bbox_strider_get_all_chunks_parallel(mocker):
num_cores = multiprocessing.cpu_count()
strider = BBoxStrider(
bbox=BBox3D.from_coords(
start_coord=Vec3D(0, 0, 0), end_coord=Vec3D(2, 1, num_cores), resolution=Vec3D(1, 1, 1)
start_coord=Vec3D(0, 0, 0),
end_coord=Vec3D(2, 1, MULTIPROCESSING_NUM_TASKS_THRESHOLD + 1),
resolution=Vec3D(1, 1, 1),
),
chunk_size=IntVec3D(1, 1, 1),
stride=IntVec3D(1, 1, 1),
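
For context on the change above: the test previously sized the bbox by multiprocessing.cpu_count(), so the chunk count crossed the parallelization cutoff only on machines with enough cores. Tying it to MULTIPROCESSING_NUM_TASKS_THRESHOLD + 1 exercises the parallel path on any machine. A quick check of the arithmetic — the inference that BBoxStrider parallelizes above this threshold is ours, not stated in the diff:

from zetta_utils import MULTIPROCESSING_NUM_TASKS_THRESHOLD

# 2 x 1 x (threshold + 1) unit chunks along (x, y, z)
num_chunks = 2 * 1 * (MULTIPROCESSING_NUM_TASKS_THRESHOLD + 1)
assert num_chunks > MULTIPROCESSING_NUM_TASKS_THRESHOLD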
133 changes: 131 additions & 2 deletions tests/unit/mazepa/test_semaphores.py
@@ -1,16 +1,22 @@
# pylint: disable=bare-except, c-extension-no-member

import os
from typing import List
from multiprocessing import shared_memory
from typing import List, get_args

import posix_ipc
import pytest

from zetta_utils.mazepa.semaphores import (
DummySemaphore,
SemaphoreType,
TimingTracker,
_log_timing_summary,
configure_semaphores,
get_semaphore_stats,
name_to_posix_name,
reset_all_timing_data,
reset_timing_data,
semaphore,
)

@@ -72,7 +78,7 @@ def test_get_parent_semaphore():
except:
pass
sema = posix_ipc.Semaphore(name_to_posix_name("read", os.getppid()), flags=posix_ipc.O_CREX)
assert sema.name == semaphore("read").name
assert sema.name == semaphore("read").semaphore.name
sema.unlink()


@@ -154,3 +160,126 @@ def test_missing_semaphore_type_exc(semaphore_spec):
with pytest.raises(ValueError):
with configure_semaphores(semaphore_spec):
pass


def test_timingtracker_reset():
tracker = TimingTracker("test")
tracker.create_shared_memory().close()
tracker.reset_timing_data()
tracker.unlink()


def test_timingtracker_duplicate_shm_exc():
tracker = TimingTracker("test")
tracker.create_shared_memory().close()
with pytest.raises(RuntimeError):
tracker.create_shared_memory()
tracker.unlink()


def test_timingtracker_duplicate_unlink():
tracker = TimingTracker("test")
tracker.create_shared_memory().close()
tracker.unlink()
tracker.unlink()


def test_timingtracker_add_wait_time_noshm_exc():
tracker = TimingTracker("read")
with pytest.raises(RuntimeError):
tracker.add_wait_time(1.0)


def test_timingtracker_add_lease_time_noshm_exc():
tracker = TimingTracker("read")
with pytest.raises(RuntimeError):
tracker.add_lease_time(1.0)


def test_timingtracker_get_timing_data_noshm_exc():
tracker = TimingTracker("read")
with pytest.raises(RuntimeError):
tracker.get_timing_data()


def test_timingtracker_reset_timing_data_noshm_exc():
tracker = TimingTracker("read")
with pytest.raises(RuntimeError):
tracker.reset_timing_data()


def test_get_semaphore_stats_exc():
with pytest.raises(ValueError):
get_semaphore_stats("exc") # type: ignore


def test_reset_all_timing_data():
for name in get_args(SemaphoreType):
tracker = TimingTracker(name)
tracker.create_shared_memory().close()
reset_all_timing_data()
for name in get_args(SemaphoreType):
tracker = TimingTracker(name)
tracker.unlink()


def test_reset_timing_data():
tracker = TimingTracker("read")
tracker.create_shared_memory().close()
reset_timing_data("read")
tracker.unlink()


def test_reset_timing_data_wrongname_exc():
with pytest.raises(ValueError):
reset_timing_data("exc") # type: ignore


def test_reset_timing_data_function_exc():
with pytest.raises(RuntimeError):
reset_timing_data("read")


def test_log_timing_summary_exc():
with pytest.raises(RuntimeError):
_log_timing_summary({"read": 1})


def test_tracker_cleanup_exc(mocker):

original_unlink = shared_memory.SharedMemory.unlink

def failing_unlink(self):
raise PermissionError("Cannot unlink shared memory")

mocker.patch.object(shared_memory.SharedMemory, "unlink", failing_unlink)

tracker = TimingTracker("read")
tracker.create_shared_memory().close()

with pytest.raises(RuntimeError):
tracker.unlink()

mocker.patch.object(shared_memory.SharedMemory, "unlink", original_unlink)

tracker.unlink()


def test_configure_tracker_cleanup_exc(mocker):

original_unlink = shared_memory.SharedMemory.unlink

def failing_unlink(self):
raise PermissionError("Cannot unlink shared memory")

mocker.patch.object(shared_memory.SharedMemory, "unlink", failing_unlink)

with pytest.raises(RuntimeError):
with configure_semaphores():
pass

mocker.patch.object(shared_memory.SharedMemory, "unlink", original_unlink)

for name in get_args(SemaphoreType):
semaphore(name).unlink()
TimingTracker(name).unlink()
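
Taken together, these tests imply roughly the following TimingTracker lifecycle — a sketch only; the shape of the returned timing data is an assumption on our part:

from zetta_utils.mazepa.semaphores import TimingTracker, get_semaphore_stats

tracker = TimingTracker("read")
tracker.create_shared_memory().close()  # must exist before timings can be recorded
tracker.add_wait_time(0.25)  # time spent waiting to acquire the semaphore
tracker.add_lease_time(1.5)  # time spent holding it
print(tracker.get_timing_data())  # raises RuntimeError if the shared memory is missing
# get_semaphore_stats("read")  # stats accessor; raises ValueError for unrecognized names
tracker.unlink()  # idempotent: a second call is a no-op per test_timingtracker_duplicate_unlink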
4 changes: 3 additions & 1 deletion zetta_utils/__init__.py
@@ -6,6 +6,9 @@

from .log import get_logger

# Set global multiprocessing threshold
MULTIPROCESSING_NUM_TASKS_THRESHOLD = 128

if "sphinx" not in sys.modules: # pragma: no cover
import pdbp # noqa

@@ -26,7 +29,6 @@
warnings.filterwarnings("ignore", category=DeprecationWarning)



def _load_core_modules():
"""Load core modules that were previously imported at package level."""
from . import log, typing, parsing, builder, common, constants
12 changes: 12 additions & 0 deletions zetta_utils/cloud_management/resource_allocation/k8s/common.py
@@ -67,20 +67,30 @@ def get_mazepa_worker_command(
num_procs: int = 1,
semaphores_spec: dict[SemaphoreType, int] | None = None,
idle_timeout: int | None = None,
suppress_worker_logs: bool = False,
resource_monitor_interval: float | None = 1.0,
):
if num_procs == 1 and semaphores_spec is None:
command = "mazepa.run_worker"
num_procs_line = ""
semaphores_line = ""
suppress_worker_logs_line = ""
else:
command = "mazepa.run_worker_manager"
num_procs_line = f"num_procs: {num_procs}\n"
semaphores_line = f"semaphores_spec: {json.dumps(semaphores_spec)}\n"
suppress_worker_logs_line = f"suppress_worker_logs: {json.dumps(suppress_worker_logs)}\n"

idle_timeout_line = ""
if idle_timeout:
idle_timeout_line = f"idle_timeout: {idle_timeout}\n"

resource_monitor_interval_line = ""
if resource_monitor_interval is not None:
resource_monitor_interval_line = (
f"resource_monitor_interval: {resource_monitor_interval}\n"
)

result = f"zetta -vv -l try run -r {run.RUN_ID} --no-main-run-process -p -s '{{"
result += (
f'"@type": "{command}"\n'
Expand All @@ -89,6 +99,8 @@ def get_mazepa_worker_command(
+ num_procs_line
+ semaphores_line
+ idle_timeout_line
+ suppress_worker_logs_line
+ resource_monitor_interval_line
+ """
sleep_sec: 5
}'
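
To illustrate the effect of the two new parameters, a hypothetical call to get_mazepa_worker_command — the queue specs are placeholders, and the keyword names for the two queue arguments are inferred from the call site in the deployment module rather than shown in this hunk:

from zetta_utils.cloud_management.resource_allocation.k8s.common import (
    get_mazepa_worker_command,
)

cmd = get_mazepa_worker_command(
    task_queue_spec={"@type": "SQSQueue", "name": "tasks"},  # placeholder
    outcome_queue_spec={"@type": "SQSQueue", "name": "outcomes"},  # placeholder
    num_procs=4,  # > 1, so the run_worker_manager branch is taken
    semaphores_spec={"read": 4},
    suppress_worker_logs=True,
    resource_monitor_interval=1.0,
)
# The generated spec string now includes, among the other lines:
#   suppress_worker_logs: true
#   resource_monitor_interval: 1.0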
@@ -93,14 +93,21 @@ def get_mazepa_worker_deployment( # pylint: disable=too-many-locals
gpu_accelerator_type: str | None = None,
adc_available: bool = False,
cave_secret_available: bool = False,
suppress_worker_logs: bool = False,
resource_monitor_interval: float | None = 1.0,
):
if labels is None:
labels_final = {"run_id": run_id}
else:
labels_final = labels

worker_command = get_mazepa_worker_command(
task_queue_spec, outcome_queue_spec, num_procs, semaphores_spec
task_queue_spec,
outcome_queue_spec,
num_procs,
semaphores_spec,
suppress_worker_logs=suppress_worker_logs,
resource_monitor_interval=resource_monitor_interval,
)
logger.debug(f"Making a deployment with worker command: '{worker_command}'")

1 change: 1 addition & 0 deletions zetta_utils/common/__init__.py
@@ -9,5 +9,6 @@
from .misc import get_unique_id
from .path import abspath, is_local
from .pprint import lrpad
from .resource_monitor import ResourceMonitor
from .signal_handlers import custom_signal_handler_ctx
from .timer import RepeatTimer, Timer