@@ -495,6 +495,8 @@ _PyRuntimeState_Init(_PyRuntimeState *runtime)
495495 return _PyStatus_OK ();
496496}
497497
498+ static void _xidregistry_clear (struct _xidregistry * );
499+
498500void
499501_PyRuntimeState_Fini (_PyRuntimeState * runtime )
500502{
@@ -503,6 +505,8 @@ _PyRuntimeState_Fini(_PyRuntimeState *runtime)
503505 assert (runtime -> object_state .interpreter_leaks == 0 );
504506#endif
505507
508+ _xidregistry_clear (& runtime -> xidregistry );
509+
506510 if (gilstate_tss_initialized (runtime )) {
507511 gilstate_tss_fini (runtime );
508512 }
@@ -548,6 +552,11 @@ _PyRuntimeState_ReInitThreads(_PyRuntimeState *runtime)
548552 for (int i = 0 ; i < NUMLOCKS ; i ++ ) {
549553 reinit_err += _PyThread_at_fork_reinit (lockptrs [i ]);
550554 }
555+ /* PyOS_AfterFork_Child(), which calls this function, later calls
556+ _PyInterpreterState_DeleteExceptMain(), so we only need to update
557+ the main interpreter here. */
558+ assert (runtime -> interpreters .main != NULL );
559+ runtime -> interpreters .main -> xidregistry .mutex = runtime -> xidregistry .mutex ;
551560
552561 PyMem_SetAllocator (PYMEM_DOMAIN_RAW , & old_alloc );
553562
@@ -709,6 +718,10 @@ init_interpreter(PyInterpreterState *interp,
709718 interp -> dtoa = (struct _dtoa_state )_dtoa_state_INIT (interp );
710719 }
711720 interp -> f_opcode_trace_set = false;
721+
722+ assert (runtime -> xidregistry .mutex != NULL );
723+ interp -> xidregistry .mutex = runtime -> xidregistry .mutex ;
724+
712725 interp -> _initialized = 1 ;
713726 return _PyStatus_OK ();
714727}
@@ -930,6 +943,10 @@ interpreter_clear(PyInterpreterState *interp, PyThreadState *tstate)
930943 Py_CLEAR (interp -> sysdict );
931944 Py_CLEAR (interp -> builtins );
932945
946+ _xidregistry_clear (& interp -> xidregistry );
947+ /* The lock is owned by the runtime, so we don't free it here. */
948+ interp -> xidregistry .mutex = NULL ;
949+
933950 if (tstate -> interp == interp ) {
934951 /* We are now safe to fix tstate->_status.cleared. */
935952 // XXX Do this (much) earlier?
@@ -2613,23 +2630,27 @@ _PyCrossInterpreterData_ReleaseAndRawFree(_PyCrossInterpreterData *data)
26132630 crossinterpdatafunc. It would be simpler and more efficient. */
26142631
26152632static int
2616- _xidregistry_add_type (struct _xidregistry * xidregistry , PyTypeObject * cls ,
2617- crossinterpdatafunc getdata )
2633+ _xidregistry_add_type (struct _xidregistry * xidregistry ,
2634+ PyTypeObject * cls , crossinterpdatafunc getdata )
26182635{
2619- // Note that we effectively replace already registered classes
2620- // rather than failing.
26212636 struct _xidregitem * newhead = PyMem_RawMalloc (sizeof (struct _xidregitem ));
26222637 if (newhead == NULL ) {
26232638 return -1 ;
26242639 }
2625- // XXX Assign a callback to clear the entry from the registry?
2626- newhead -> cls = PyWeakref_NewRef ((PyObject * )cls , NULL );
2627- if (newhead -> cls == NULL ) {
2628- PyMem_RawFree (newhead );
2629- return -1 ;
2640+ * newhead = (struct _xidregitem ){
2641+ // We do not keep a reference, to avoid keeping the class alive.
2642+ .cls = cls ,
2643+ .refcount = 1 ,
2644+ .getdata = getdata ,
2645+ };
2646+ if (cls -> tp_flags & Py_TPFLAGS_HEAPTYPE ) {
2647+ // XXX Assign a callback to clear the entry from the registry?
2648+ newhead -> weakref = PyWeakref_NewRef ((PyObject * )cls , NULL );
2649+ if (newhead -> weakref == NULL ) {
2650+ PyMem_RawFree (newhead );
2651+ return -1 ;
2652+ }
26302653 }
2631- newhead -> getdata = getdata ;
2632- newhead -> prev = NULL ;
26332654 newhead -> next = xidregistry -> head ;
26342655 if (newhead -> next != NULL ) {
26352656 newhead -> next -> prev = newhead ;
@@ -2654,39 +2675,77 @@ _xidregistry_remove_entry(struct _xidregistry *xidregistry,
26542675 if (next != NULL ) {
26552676 next -> prev = entry -> prev ;
26562677 }
2657- Py_DECREF (entry -> cls );
2678+ Py_XDECREF (entry -> weakref );
26582679 PyMem_RawFree (entry );
26592680 return next ;
26602681}
26612682
2683+ static void
2684+ _xidregistry_clear (struct _xidregistry * xidregistry )
2685+ {
2686+ struct _xidregitem * cur = xidregistry -> head ;
2687+ xidregistry -> head = NULL ;
2688+ while (cur != NULL ) {
2689+ struct _xidregitem * next = cur -> next ;
2690+ Py_XDECREF (cur -> weakref );
2691+ PyMem_RawFree (cur );
2692+ cur = next ;
2693+ }
2694+ }
2695+
26622696static struct _xidregitem *
26632697_xidregistry_find_type (struct _xidregistry * xidregistry , PyTypeObject * cls )
26642698{
26652699 struct _xidregitem * cur = xidregistry -> head ;
26662700 while (cur != NULL ) {
2667- PyObject * registered = _PyWeakref_GET_REF (cur -> cls );
2668- if (registered == NULL ) {
2669- // The weakly ref'ed object was freed.
2670- cur = _xidregistry_remove_entry (xidregistry , cur );
2671- }
2672- else {
2673- assert (PyType_Check (registered ));
2674- if (registered == (PyObject * )cls ) {
2675- Py_DECREF (registered );
2676- return cur ;
2701+ if (cur -> weakref != NULL ) {
2702+ // cur is/was a heap type.
2703+ PyObject * registered = _PyWeakref_GET_REF (cur -> weakref );
2704+ if (registered == NULL ) {
2705+ // The weakly ref'ed object was freed.
2706+ cur = _xidregistry_remove_entry (xidregistry , cur );
2707+ continue ;
26772708 }
2709+ assert (PyType_Check (registered ));
2710+ assert (cur -> cls == (PyTypeObject * )registered );
2711+ assert (cur -> cls -> tp_flags & Py_TPFLAGS_HEAPTYPE );
26782712 Py_DECREF (registered );
2679- cur = cur -> next ;
26802713 }
2714+ if (cur -> cls == cls ) {
2715+ return cur ;
2716+ }
2717+ cur = cur -> next ;
26812718 }
26822719 return NULL ;
26832720}
26842721
2722+ static inline struct _xidregistry *
2723+ _get_xidregistry (PyInterpreterState * interp , PyTypeObject * cls )
2724+ {
2725+ struct _xidregistry * xidregistry = & interp -> runtime -> xidregistry ;
2726+ if (cls -> tp_flags & Py_TPFLAGS_HEAPTYPE ) {
2727+ assert (interp -> xidregistry .mutex == xidregistry -> mutex );
2728+ xidregistry = & interp -> xidregistry ;
2729+ }
2730+ return xidregistry ;
2731+ }
2732+
26852733static void _register_builtins_for_crossinterpreter_data (struct _xidregistry * xidregistry );
26862734
2735+ static inline void
2736+ _ensure_builtins_xid (PyInterpreterState * interp , struct _xidregistry * xidregistry )
2737+ {
2738+ if (xidregistry != & interp -> xidregistry ) {
2739+ assert (xidregistry == & interp -> runtime -> xidregistry );
2740+ if (xidregistry -> head == NULL ) {
2741+ _register_builtins_for_crossinterpreter_data (xidregistry );
2742+ }
2743+ }
2744+ }
2745+
26872746int
26882747_PyCrossInterpreterData_RegisterClass (PyTypeObject * cls ,
2689- crossinterpdatafunc getdata )
2748+ crossinterpdatafunc getdata )
26902749{
26912750 if (!PyType_Check (cls )) {
26922751 PyErr_Format (PyExc_ValueError , "only classes may be registered" );
@@ -2697,12 +2756,23 @@ _PyCrossInterpreterData_RegisterClass(PyTypeObject *cls,
26972756 return -1 ;
26982757 }
26992758
2700- struct _xidregistry * xidregistry = & _PyRuntime .xidregistry ;
2759+ int res = 0 ;
2760+ PyInterpreterState * interp = _PyInterpreterState_GET ();
2761+ struct _xidregistry * xidregistry = _get_xidregistry (interp , cls );
27012762 PyThread_acquire_lock (xidregistry -> mutex , WAIT_LOCK );
2702- if (xidregistry -> head == NULL ) {
2703- _register_builtins_for_crossinterpreter_data (xidregistry );
2763+
2764+ _ensure_builtins_xid (interp , xidregistry );
2765+
2766+ struct _xidregitem * matched = _xidregistry_find_type (xidregistry , cls );
2767+ if (matched != NULL ) {
2768+ assert (matched -> getdata == getdata );
2769+ matched -> refcount += 1 ;
2770+ goto finally ;
27042771 }
2705- int res = _xidregistry_add_type (xidregistry , cls , getdata );
2772+
2773+ res = _xidregistry_add_type (xidregistry , cls , getdata );
2774+
2775+ finally :
27062776 PyThread_release_lock (xidregistry -> mutex );
27072777 return res ;
27082778}
@@ -2711,13 +2781,20 @@ int
27112781_PyCrossInterpreterData_UnregisterClass (PyTypeObject * cls )
27122782{
27132783 int res = 0 ;
2714- struct _xidregistry * xidregistry = & _PyRuntime .xidregistry ;
2784+ PyInterpreterState * interp = _PyInterpreterState_GET ();
2785+ struct _xidregistry * xidregistry = _get_xidregistry (interp , cls );
27152786 PyThread_acquire_lock (xidregistry -> mutex , WAIT_LOCK );
2787+
27162788 struct _xidregitem * matched = _xidregistry_find_type (xidregistry , cls );
27172789 if (matched != NULL ) {
2718- (void )_xidregistry_remove_entry (xidregistry , matched );
2790+ assert (matched -> refcount > 0 );
2791+ matched -> refcount -= 1 ;
2792+ if (matched -> refcount == 0 ) {
2793+ (void )_xidregistry_remove_entry (xidregistry , matched );
2794+ }
27192795 res = 1 ;
27202796 }
2797+
27212798 PyThread_release_lock (xidregistry -> mutex );
27222799 return res ;
27232800}
@@ -2730,17 +2807,19 @@ _PyCrossInterpreterData_UnregisterClass(PyTypeObject *cls)
27302807crossinterpdatafunc
27312808_PyCrossInterpreterData_Lookup (PyObject * obj )
27322809{
2733- struct _xidregistry * xidregistry = & _PyRuntime .xidregistry ;
2734- PyObject * cls = PyObject_Type (obj );
2810+ PyTypeObject * cls = Py_TYPE (obj );
2811+
2812+ PyInterpreterState * interp = _PyInterpreterState_GET ();
2813+ struct _xidregistry * xidregistry = _get_xidregistry (interp , cls );
27352814 PyThread_acquire_lock (xidregistry -> mutex , WAIT_LOCK );
2736- if ( xidregistry -> head == NULL ) {
2737- _register_builtins_for_crossinterpreter_data ( xidregistry );
2738- }
2739- struct _xidregitem * matched = _xidregistry_find_type (xidregistry ,
2740- ( PyTypeObject * ) cls ) ;
2741- Py_DECREF ( cls );
2815+
2816+ _ensure_builtins_xid ( interp , xidregistry );
2817+
2818+ struct _xidregitem * matched = _xidregistry_find_type (xidregistry , cls );
2819+ crossinterpdatafunc func = matched != NULL ? matched -> getdata : NULL ;
2820+
27422821 PyThread_release_lock (xidregistry -> mutex );
2743- return matched != NULL ? matched -> getdata : NULL ;
2822+ return func ;
27442823}
27452824
27462825/* cross-interpreter data for builtin types */
0 commit comments