From 495abba37e21e4d6b119a2441c41e9addb74a4eb Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 28 Dec 2022 19:12:26 -0500 Subject: [PATCH] [CONTAINER] Struct Hash/Equal and JSON support for ShapeTuple This PR add struct equal/hash and json serialization support for shape tuple. Testcases added. --- src/node/structural_hash.cc | 44 +++++++++++++++++++ src/support/base64.h | 9 +++- .../test_container_structural_equal.py | 14 ++++++ .../python/unittest/test_runtime_container.py | 5 +++ 4 files changed, 70 insertions(+), 2 deletions(-) diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 1d1185cddc3d..0426b8454dce 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -484,6 +484,50 @@ TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait) return ::tvm::runtime::make_object(); }); +struct ShapeTupleObjTrait { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static void SHashReduce(const ShapeTupleObj* self, SHashReducer hash_reduce) { + hash_reduce(self->size); + for (size_t i = 0; i < self->size; ++i) { + hash_reduce(self->data[i]); + } + } + + static bool SEqualReduce(const ShapeTupleObj* lhs, const ShapeTupleObj* rhs, + SEqualReducer equal) { + if (lhs->size != rhs->size) return false; + for (size_t i = 0; i < lhs->size; ++i) { + if (!equal(lhs->data[i], rhs->data[i])) return false; + } + return true; + } +}; + +TVM_REGISTER_REFLECTION_VTABLE(ShapeTupleObj, ShapeTupleObjTrait) + .set_creator([](const std::string& blob) { + // Store shape tuple in blob to avoid large integer overflow in JSON. + dmlc::MemoryStringStream mstrm(const_cast(&blob)); + support::Base64InStream b64strm(&mstrm); + b64strm.InitPosition(); + uint64_t size; + b64strm.Read(&size); + std::vector data(size); + b64strm.ReadArray(data.data(), size); + ShapeTuple shape(data); + return RefToObjectPtr::Get(shape); + }) + .set_repr_bytes([](const Object* n) -> std::string { + std::string blob; + dmlc::MemoryStringStream mstrm(&blob); + support::Base64OutStream b64strm(&mstrm); + const auto* shape = static_cast(n); + b64strm.Write(shape->size); + b64strm.WriteArray(shape->data, shape->size); + b64strm.Finish(); + return blob; + }); + struct MapNodeTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; diff --git a/src/support/base64.h b/src/support/base64.h index 7b37afce66cc..aba4197bce20 100644 --- a/src/support/base64.h +++ b/src/support/base64.h @@ -115,8 +115,10 @@ class Base64InStream : public dmlc::Stream { } /*! \brief whether current position is end of a base64 stream */ bool IsEOF(void) const { return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); } + + using dmlc::Stream::Read; // override read function. - virtual size_t Read(void* ptr, size_t size) { + size_t Read(void* ptr, size_t size) final { using base64::DecodeTable; if (size == 0) return 0; // use tlen to record left size @@ -224,7 +226,10 @@ class Base64InStream : public dmlc::Stream { class Base64OutStream : public dmlc::Stream { public: explicit Base64OutStream(dmlc::Stream* fp) : fp_(fp) {} - virtual void Write(const void* ptr, size_t size) { + + using dmlc::Stream::Write; + + void Write(const void* ptr, size_t size) final { using base64::EncodeTable; size_t tlen = size; const unsigned char* cptr = static_cast(ptr); diff --git a/tests/python/unittest/test_container_structural_equal.py b/tests/python/unittest/test_container_structural_equal.py index cdd9ffb7af53..61511c609ca4 100644 --- a/tests/python/unittest/test_container_structural_equal.py +++ b/tests/python/unittest/test_container_structural_equal.py @@ -107,6 +107,20 @@ def test_array_structural_equal_to_self(contents): assert get_first_mismatch_ensure_symmetry(a, b) is None +@pytest.mark.parametrize( + "contents", + [ + [], + [1], + [1, 2, 3], + ], +) +def test_shape_tuple_structural_equal_to_self(contents): + a = tvm.runtime.ShapeTuple(list(contents)) + b = tvm.runtime.ShapeTuple(list(contents)) + assert get_first_mismatch_ensure_symmetry(a, b) is None + + @pytest.mark.parametrize( "a, b, expected_a_path, expected_b_path", [ diff --git a/tests/python/unittest/test_runtime_container.py b/tests/python/unittest/test_runtime_container.py index 8c302e920577..7538075ae7f8 100644 --- a/tests/python/unittest/test_runtime_container.py +++ b/tests/python/unittest/test_runtime_container.py @@ -90,6 +90,11 @@ def test_shape_tuple(): # ShapleTuple vs. ShapeTuple assert stuple == _container.ShapeTuple(shape) + # test pickle + z = pickle.loads(pickle.dumps(stuple)) + assert isinstance(z, tvm.runtime.ShapeTuple) + assert stuple == z + if __name__ == "__main__": test_string()