@@ -861,6 +861,9 @@ bool RegisterFloatUFuncs(PyObject* numpy) {
861861 return ok;
862862}
863863
864+ // TODO(jakevdp): simplify the following. We no longer need the already_registered
865+ // check, and heap allocation is probably not important any longer.
866+ //
864867// Returns true if the numpy type for T is successfully registered, including if
865868// it was already registered (e.g. by a different library). If
866869// `already_registered` is non-null, it's set to true if the type was already
@@ -870,35 +873,9 @@ bool RegisterFloatDtype(PyObject* numpy, bool* already_registered = nullptr) {
870873 if (already_registered != nullptr ) {
871874 *already_registered = false ;
872875 }
873- // If another module (presumably either TF or JAX) has registered a bfloat16
874- // type, use it. We don't want two bfloat16 types if we can avoid it since it
875- // leads to confusion if we have two different types with the same name. This
876- // assumes that the other module has a sufficiently complete bfloat16
877- // implementation. The only known NumPy bfloat16 extension at the time of
878- // writing is this one (distributed in TF and JAX).
879- // TODO(phawkins): distribute the bfloat16 extension as its own pip package,
880- // so we can unambiguously refer to a single canonical definition of bfloat16.
881- int typenum =
882- PyArray_TypeNumFromName (const_cast <char *>(TypeDescriptor<T>::kTypeName ));
883- if (typenum != NPY_NOTYPE) {
884- PyArray_Descr* descr = PyArray_DescrFromType (typenum);
885- // The test for an argmax function here is to verify that the
886- // bfloat16 implementation is sufficiently new, and, say, not from
887- // an older version of TF or JAX.
888- if (descr && descr->f && descr->f ->argmax ) {
889- TypeDescriptor<T>::npy_type = typenum;
890- TypeDescriptor<T>::type_ptr = reinterpret_cast <PyObject*>(descr->typeobj );
891- if (already_registered != nullptr ) {
892- *already_registered = true ;
893- }
894- return true ;
895- }
896- }
897-
898876 // It's important that we heap-allocate our type. This is because tp_name
899- // is not a fully-qualified name for a heap-allocated type, and
900- // PyArray_TypeNumFromName() (above) looks at the tp_name field to find
901- // types. Existing implementations in JAX and TensorFlow look for "bfloat16",
877+ // is not a fully-qualified name for a heap-allocated type.
878+ // Existing implementations in JAX and TensorFlow look for "bfloat16",
902879 // not "ml_dtypes.bfloat16" when searching for an implementation.
903880 Safe_PyObjectPtr name =
904881 make_safe (PyUnicode_FromString (TypeDescriptor<T>::kTypeName ));
0 commit comments