From b91071cdb98e63a88f9cc6521febd661a5a872f9 Mon Sep 17 00:00:00 2001 From: Marius Wachtler Date: Wed, 9 Sep 2015 11:31:08 +0100 Subject: [PATCH] Fix set comparisons --- src/runtime/dict.cpp | 6 ++-- src/runtime/set.cpp | 77 ++++++++++++++++++++++++++++++++------------ test/tests/set.py | 8 +++-- 3 files changed, 66 insertions(+), 25 deletions(-) diff --git a/src/runtime/dict.cpp b/src/runtime/dict.cpp index 1940cfbc3..f42b1496f 100644 --- a/src/runtime/dict.cpp +++ b/src/runtime/dict.cpp @@ -85,7 +85,7 @@ Box* dictCopy(BoxedDict* self) { raiseExcHelper(TypeError, "descriptor 'copy' requires a 'dict' object but received a '%s'", getTypeName(self)); BoxedDict* r = new BoxedDict(); - r->d.insert(self->d.begin(), self->d.end()); + r->d = self->d; return r; } @@ -576,11 +576,11 @@ Box* dictEq(BoxedDict* self, Box* _rhs) { if (self->d.size() != rhs->d.size()) return False; - for (const auto& p : *self) { + for (const auto& p : self->d) { auto it = rhs->d.find(p.first); if (it == rhs->d.end()) return False; - if (!nonzero(compare(p.second, it->second, AST_TYPE::Eq))) + if (!PyEq()(p.second, it->second)) return False; } diff --git a/src/runtime/set.cpp b/src/runtime/set.cpp index c32506d6a..9214757ff 100644 --- a/src/runtime/set.cpp +++ b/src/runtime/set.cpp @@ -407,6 +407,9 @@ static Box* setIssubset(BoxedSet* self, Box* container) { assert(PyAnySet_Check(container)); BoxedSet* rhs = static_cast(container); + if (self->s.size() > rhs->s.size()) + return False; + for (auto e : self->s) { if (rhs->s.find(e) == rhs->s.end()) return False; @@ -421,13 +424,7 @@ static Box* setIssuperset(BoxedSet* self, Box* container) { container = makeNewSet(set_cls, container); } assert(PyAnySet_Check(container)); - - BoxedSet* rhs = static_cast(container); - for (auto e : rhs->s) { - if (self->s.find(e) == self->s.end()) - return False; - } - return True; + return setIssubset((BoxedSet*)container, self); } static Box* setIsdisjoint(BoxedSet* self, Box* container) { @@ -473,7 +470,7 @@ Box* setCopy(BoxedSet* self) { RELEASE_ASSERT(PyAnySet_Check(self), ""); BoxedSet* rtn = new BoxedSet(); - rtn->s.insert(self->s.begin(), self->s.end()); + rtn->s = self->s; return rtn; } @@ -497,24 +494,56 @@ Box* setContains(BoxedSet* self, Box* v) { Box* setEq(BoxedSet* self, BoxedSet* rhs) { RELEASE_ASSERT(PyAnySet_Check(self), ""); if (!PyAnySet_Check(rhs)) - return NotImplemented; + return False; if (self->s.size() != rhs->s.size()) return False; - for (auto e : self->s) { - if (!rhs->s.count(e)) - return False; - } - return True; + return setIssubset(self, rhs); } Box* setNe(BoxedSet* self, BoxedSet* rhs) { Box* r = setEq(self, rhs); - if (r->cls == bool_cls) - return boxBool(r == False); - assert(r == NotImplemented); - return r; + assert(r->cls == bool_cls); + return boxBool(r == False); +} + +Box* setLe(BoxedSet* self, BoxedSet* rhs) { + RELEASE_ASSERT(PyAnySet_Check(self), ""); + if (!PyAnySet_Check(rhs)) + raiseExcHelper(TypeError, "can only compare to a set"); + + return setIssubset(self, rhs); +} + +Box* setLt(BoxedSet* self, BoxedSet* rhs) { + RELEASE_ASSERT(PyAnySet_Check(self), ""); + if (!PyAnySet_Check(rhs)) + raiseExcHelper(TypeError, "can only compare to a set"); + + if (self->s.size() >= rhs->s.size()) + return False; + + return setIssubset(self, rhs); +} + +Box* setGe(BoxedSet* self, BoxedSet* rhs) { + RELEASE_ASSERT(PyAnySet_Check(self), ""); + if (!PyAnySet_Check(rhs)) + raiseExcHelper(TypeError, "can only compare to a set"); + + return setIssuperset(self, rhs); +} + +Box* setGt(BoxedSet* self, BoxedSet* rhs) { + RELEASE_ASSERT(PyAnySet_Check(self), ""); + if (!PyAnySet_Check(rhs)) + raiseExcHelper(TypeError, "can only compare to a set"); + + if (self->s.size() <= rhs->s.size()) + return False; + + return setIssuperset(self, rhs); } Box* setNonzero(BoxedSet* self) { @@ -627,10 +656,18 @@ void setupSet() { set_cls->giveAttr("__contains__", new BoxedFunction(boxRTFunction((void*)setContains, BOXED_BOOL, 2))); frozenset_cls->giveAttr("__contains__", set_cls->getattr(internStringMortal("__contains__"))); - set_cls->giveAttr("__eq__", new BoxedFunction(boxRTFunction((void*)setEq, UNKNOWN, 2))); + set_cls->giveAttr("__eq__", new BoxedFunction(boxRTFunction((void*)setEq, BOXED_BOOL, 2))); frozenset_cls->giveAttr("__eq__", set_cls->getattr(internStringMortal("__eq__"))); - set_cls->giveAttr("__ne__", new BoxedFunction(boxRTFunction((void*)setNe, UNKNOWN, 2))); + set_cls->giveAttr("__ne__", new BoxedFunction(boxRTFunction((void*)setNe, BOXED_BOOL, 2))); frozenset_cls->giveAttr("__ne__", set_cls->getattr(internStringMortal("__ne__"))); + set_cls->giveAttr("__le__", new BoxedFunction(boxRTFunction((void*)setLe, BOXED_BOOL, 2))); + frozenset_cls->giveAttr("__le__", set_cls->getattr(internStringMortal("__le__"))); + set_cls->giveAttr("__lt__", new BoxedFunction(boxRTFunction((void*)setLt, BOXED_BOOL, 2))); + frozenset_cls->giveAttr("__lt__", set_cls->getattr(internStringMortal("__lt__"))); + set_cls->giveAttr("__ge__", new BoxedFunction(boxRTFunction((void*)setGe, BOXED_BOOL, 2))); + frozenset_cls->giveAttr("__ge__", set_cls->getattr(internStringMortal("__ge__"))); + set_cls->giveAttr("__gt__", new BoxedFunction(boxRTFunction((void*)setGt, BOXED_BOOL, 2))); + frozenset_cls->giveAttr("__gt__", set_cls->getattr(internStringMortal("__gt__"))); set_cls->giveAttr("__nonzero__", new BoxedFunction(boxRTFunction((void*)setNonzero, BOXED_BOOL, 1))); frozenset_cls->giveAttr("__nonzero__", set_cls->getattr(internStringMortal("__nonzero__"))); diff --git a/test/tests/set.py b/test/tests/set.py index 6d08f9865..13f7f5bb9 100644 --- a/test/tests/set.py +++ b/test/tests/set.py @@ -128,8 +128,12 @@ class MyFrozenset(frozenset): for s1 in set(range(5)), frozenset(range(5)): for s2 in compare_to: - print type(s2), sorted(s2), s1.issubset(s2), s1.issuperset(s2), s1 == s2, s1 != s2, sorted(s1.difference(s2)), s1.isdisjoint(s2), sorted(s1.union(s2)), sorted(s1.intersection(s2)) - + print type(s2), sorted(s2), s1.issubset(s2), s1.issuperset(s2), sorted(s1.difference(s2)), s1.isdisjoint(s2), sorted(s1.union(s2)), sorted(s1.intersection(s2)) + print s1 == s2, s1 != s2 + try: + print s1 < s2, s1 <= s2, s1 > s2, s1 >= s2 + except Exception as e: + print e f = float('nan') s = set([f]) print f in s, f == list(s)[0]