Skip to content

Commit 54ffcfe

Browse files
committed
Remove code checking for other bfloat16 registrations
1 parent c83659f commit 54ffcfe

File tree

1 file changed

+5
-28
lines changed

1 file changed

+5
-28
lines changed

ml_dtypes/_src/custom_float.h

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,9 @@ bool RegisterFloatUFuncs(PyObject* numpy) {
861861
return ok;
862862
}
863863

864+
// TODO(jakevdp): simplify the following. We no longer need the already_registered
865+
// check, and heap allocation is probably not important any longer.
866+
//
864867
// Returns true if the numpy type for T is successfully registered, including if
865868
// it was already registered (e.g. by a different library). If
866869
// `already_registered` is non-null, it's set to true if the type was already
@@ -870,35 +873,9 @@ bool RegisterFloatDtype(PyObject* numpy, bool* already_registered = nullptr) {
870873
if (already_registered != nullptr) {
871874
*already_registered = false;
872875
}
873-
// If another module (presumably either TF or JAX) has registered a bfloat16
874-
// type, use it. We don't want two bfloat16 types if we can avoid it since it
875-
// leads to confusion if we have two different types with the same name. This
876-
// assumes that the other module has a sufficiently complete bfloat16
877-
// implementation. The only known NumPy bfloat16 extension at the time of
878-
// writing is this one (distributed in TF and JAX).
879-
// TODO(phawkins): distribute the bfloat16 extension as its own pip package,
880-
// so we can unambiguously refer to a single canonical definition of bfloat16.
881-
int typenum =
882-
PyArray_TypeNumFromName(const_cast<char*>(TypeDescriptor<T>::kTypeName));
883-
if (typenum != NPY_NOTYPE) {
884-
PyArray_Descr* descr = PyArray_DescrFromType(typenum);
885-
// The test for an argmax function here is to verify that the
886-
// bfloat16 implementation is sufficiently new, and, say, not from
887-
// an older version of TF or JAX.
888-
if (descr && descr->f && descr->f->argmax) {
889-
TypeDescriptor<T>::npy_type = typenum;
890-
TypeDescriptor<T>::type_ptr = reinterpret_cast<PyObject*>(descr->typeobj);
891-
if (already_registered != nullptr) {
892-
*already_registered = true;
893-
}
894-
return true;
895-
}
896-
}
897-
898876
// It's important that we heap-allocate our type. This is because tp_name
899-
// is not a fully-qualified name for a heap-allocated type, and
900-
// PyArray_TypeNumFromName() (above) looks at the tp_name field to find
901-
// types. Existing implementations in JAX and TensorFlow look for "bfloat16",
877+
// is not a fully-qualified name for a heap-allocated type.
878+
// Existing implementations in JAX and TensorFlow look for "bfloat16",
902879
// not "ml_dtypes.bfloat16" when searching for an implementation.
903880
Safe_PyObjectPtr name =
904881
make_safe(PyUnicode_FromString(TypeDescriptor<T>::kTypeName));

0 commit comments

Comments
 (0)