Skip to content

Commit 7c91033

Browse files
committed
Reintroduce cdef _union & _find
This avoids potential SEGFAULTs from Python calls
1 parent be8144d commit 7c91033

File tree

2 files changed

+69
-10
lines changed

2 files changed

+69
-10
lines changed

src/sage/sets/disjoint_set.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ cdef class DisjointSet_class(SageObject):
2222
cdef class DisjointSet_of_integers(DisjointSet_class):
2323
cpdef int find(self, int i)
2424
cpdef void union(self, int i, int j)
25+
cdef inline int _find(self, int i)
26+
cdef inline void _union(self, int i, int j)
2527
cpdef root_to_elements_dict(self)
2628
cpdef element_to_root_dict(self)
2729
cpdef to_digraph(self)

src/sage/sets/disjoint_set.pyx

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ AUTHORS:
1414
- Sébastien Labbé (2008) - Initial version.
1515
- Sébastien Labbé (2009-11-24) - Pickling support
1616
- Sébastien Labbé (2010-01) - Inclusion into sage (:issue:`6775`).
17+
- Giorgos Mousa (2024-04-22): Optimize
1718
1819
EXAMPLES:
1920
@@ -447,7 +448,7 @@ cdef class DisjointSet_of_integers(DisjointSet_class):
447448
448449
INPUT:
449450
450-
- ``i`` -- element in ``self`` (no input checking)
451+
- ``i`` -- element in ``self``
451452
452453
EXAMPLES::
453454
@@ -468,8 +469,31 @@ cdef class DisjointSet_of_integers(DisjointSet_class):
468469
{{0}, {1, 2, 3, 4}}
469470
sage: [e.find(i) for i in range(5)]
470471
[0, 1, 1, 1, 1]
471-
sage: e.find(5) # no input checking
472-
0
472+
sage: e.find(2**10)
473+
ValueError: i(=1024) must be between 0 and 4
474+
...
475+
"""
476+
card = self.cardinality()
477+
if i < 0 or i>= card:
478+
raise ValueError('i(=%s) must be between 0 and %s' % (i, card - 1))
479+
return OP_find(self._nodes, i)
480+
481+
cdef inline int _find(self, int i):
482+
r"""
483+
Return the representative of the set that ``i`` currently belongs to.
484+
485+
INPUT:
486+
487+
- ``i`` -- element in ``self``
488+
489+
EXAMPLES::
490+
491+
sage: e = DisjointSet(5)
492+
sage: e._find(5) # only C-callable
493+
Traceback (most recent call last):
494+
...
495+
AttributeError: 'sage.sets.disjoint_set.DisjointSet_of_integers'
496+
object has no attribute '_find'. Did you mean: 'find'?
473497
"""
474498
return OP_find(self._nodes, i)
475499

@@ -495,8 +519,38 @@ cdef class DisjointSet_of_integers(DisjointSet_class):
495519
{{0, 1}, {2, 4}, {3}}
496520
sage: d.union(1, 4); d
497521
{{0, 1, 2, 4}, {3}}
498-
sage: d.union(1, 5); d # no input checking
499-
{{0, 1, 2, 4}, {3}}
522+
sage: d.union(1, 5)
523+
ValueError: j(=5) must be between 0 and 4
524+
...
525+
"""
526+
cdef int card = self._nodes.degree
527+
if i < 0 or i >= card:
528+
raise ValueError('i(=%s) must be between 0 and %s' % (i, card - 1))
529+
if j < 0 or j >= card:
530+
raise ValueError('j(=%s) must be between 0 and %s' % (j, card - 1))
531+
OP_join(self._nodes, i, j)
532+
533+
cdef inline void _union(self, int i, int j):
534+
r"""
535+
Combine the set of ``i`` and the set of ``j`` into one.
536+
537+
All elements in those two sets will share the same representative
538+
that can be gotten using find.
539+
540+
INPUT:
541+
542+
- ``i`` -- element in ``self``
543+
- ``j`` -- element in ``self``
544+
545+
EXAMPLES::
546+
547+
sage: d = DisjointSet(5); d
548+
{{0}, {1}, {2}, {3}, {4}}
549+
sage: d._union(0, 1) # only C-callable
550+
Traceback (most recent call last):
551+
...
552+
AttributeError: 'sage.sets.disjoint_set.DisjointSet_of_integers'
553+
object has no attribute '_union'. Did you mean: 'union'?
500554
"""
501555
OP_join(self._nodes, i, j)
502556

@@ -727,7 +781,7 @@ cdef class DisjointSet_of_hashables(DisjointSet_class):
727781
728782
INPUT:
729783
730-
- ``e`` -- element in ``self`` (no input checking)
784+
- ``e`` -- element in ``self``
731785
732786
EXAMPLES::
733787
@@ -754,7 +808,7 @@ cdef class DisjointSet_of_hashables(DisjointSet_class):
754808
KeyError: 5
755809
"""
756810
cdef int i = <int> self._el_to_int[e]
757-
cdef int r = <int> self._d.find(i)
811+
cdef int r = <int> self._d._find(i)
758812
return self._int_to_el[r]
759813

760814
cpdef void union(self, e, f):
@@ -766,8 +820,8 @@ cdef class DisjointSet_of_hashables(DisjointSet_class):
766820
767821
INPUT:
768822
769-
- ``e`` -- element in ``self`` (no input checking)
770-
- ``f`` -- element in ``self`` (no input checking)
823+
- ``e`` -- element in ``self``
824+
- ``f`` -- element in ``self``
771825
772826
EXAMPLES::
773827
@@ -779,10 +833,13 @@ cdef class DisjointSet_of_hashables(DisjointSet_class):
779833
{{'a', 'b'}, {'c', 'e'}, {'d'}}
780834
sage: e.union('b', 'e'); e
781835
{{'a', 'b', 'c', 'e'}, {'d'}}
836+
sage: e.union('a', 2**10)
837+
KeyError: 1024
838+
...
782839
"""
783840
cdef int i = <int> self._el_to_int[e]
784841
cdef int j = <int> self._el_to_int[f]
785-
self._d.union(i, j)
842+
self._d._union(i, j)
786843

787844
cpdef root_to_elements_dict(self):
788845
r"""

0 commit comments

Comments
 (0)