diff --git a/src/libumf.c b/src/libumf.c index f96a764bc7..e9b78d4898 100644 --- a/src/libumf.c +++ b/src/libumf.c @@ -19,6 +19,7 @@ #include "provider_level_zero_internal.h" #include "provider_tracking.h" #include "utils_common.h" +#include "utils_concurrency.h" #include "utils_log.h" #if !defined(UMF_NO_HWLOC) #include "topology.h" @@ -27,6 +28,10 @@ umf_memory_tracker_handle_t TRACKER = NULL; static uint64_t umfRefCount = 0; +static utils_mutex_t initMutex; +static UTIL_ONCE_FLAG initMutexOnce = UTIL_ONCE_FLAG_INIT; + +static void initialize_init_mutex(void) { utils_mutex_init(&initMutex); } static umf_ctl_node_t CTL_NODE(umf)[] = {CTL_CHILD(provider), CTL_CHILD(pool), CTL_NODE_END}; @@ -34,11 +39,16 @@ static umf_ctl_node_t CTL_NODE(umf)[] = {CTL_CHILD(provider), CTL_CHILD(pool), void initialize_global_ctl(void) { CTL_REGISTER_MODULE(NULL, umf); } umf_result_t umfInit(void) { - if (utils_fetch_and_add_u64(&umfRefCount, 1) == 0) { + utils_init_once(&initMutexOnce, initialize_init_mutex); + + utils_mutex_lock(&initMutex); + + if (umfRefCount == 0) { utils_log_init(); umf_result_t umf_result = umfMemoryTrackerCreate(&TRACKER); if (umf_result != UMF_RESULT_SUCCESS) { LOG_ERR("Failed to create memory tracker"); + utils_mutex_unlock(&initMutex); return umf_result; } @@ -48,6 +58,7 @@ umf_result_t umfInit(void) { if (umf_result != UMF_RESULT_SUCCESS) { LOG_ERR("Failed to initialize IPC cache"); umfMemoryTrackerDestroy(TRACKER); + utils_mutex_unlock(&initMutex); return umf_result; } @@ -55,6 +66,9 @@ umf_result_t umfInit(void) { initialize_global_ctl(); } + umfRefCount++; + utils_mutex_unlock(&initMutex); + if (TRACKER) { LOG_DEBUG("UMF library initialized"); } @@ -63,7 +77,15 @@ umf_result_t umfInit(void) { } umf_result_t umfTearDown(void) { - if (utils_fetch_and_sub_u64(&umfRefCount, 1) == 1) { + utils_init_once(&initMutexOnce, initialize_init_mutex); + + utils_mutex_lock(&initMutex); + if (umfRefCount == 0) { + utils_mutex_unlock(&initMutex); + return UMF_RESULT_SUCCESS; + } + + if (--umfRefCount == 0) { #if !defined(_WIN32) && !defined(UMF_NO_HWLOC) umfMemspaceHostAllDestroy(); umfMemspaceHighestCapacityDestroy(); @@ -96,6 +118,7 @@ umf_result_t umfTearDown(void) { fini_tbb_global_state(); LOG_DEBUG("UMF library finalized"); } + utils_mutex_unlock(&initMutex); return UMF_RESULT_SUCCESS; }