diff --git a/smdebug/core/tfevent/util.py b/smdebug/core/tfevent/util.py index 80776821c..f06bd4519 100644 --- a/smdebug/core/tfevent/util.py +++ b/smdebug/core/tfevent/util.py @@ -27,13 +27,22 @@ np.dtype(np.complex64): "DT_COMPLEX64", np.dtype(np.complex128): "DT_COMPLEX128", np.dtype(np.bool): "DT_BOOL", + np.dtype([("qint8", "i1")]): "DT_QINT8", + np.dtype([("quint8", "u1")]): "DT_QUINT8", + np.dtype([("qint16", "= version.parse("2.0.0"): + from tensorflow.python import _pywrap_bfloat16 + + # TF 2.x.x Implements a Custom Numpy Datatype for Brain Floating Type + # Which is currently only supported on TPUs + _np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type() + _NP_TO_TF.pop(_np_bfloat16) + except (ModuleNotFoundError, ValueError, ImportError): + pass + + for _type in _NP_TO_TF: + try: + _get_proto_dtype(np.dtype(_type)) + except Exception: + assert False + assert True