diff --git a/cytoolz/itertoolz.pxd b/cytoolz/itertoolz.pxd index 57e0609..ba4872c 100644 --- a/cytoolz/itertoolz.pxd +++ b/cytoolz/itertoolz.pxd @@ -36,9 +36,8 @@ cdef object c_merge_sorted(object seqs, object key=*) cdef class interleave: cdef list iters - cdef list newiters cdef Py_ssize_t i - cdef Py_ssize_t n + cdef Py_ssize_t active cdef class _unique_key: diff --git a/cytoolz/itertoolz.pyx b/cytoolz/itertoolz.pyx index 0cce8d5..0b626c7 100644 --- a/cytoolz/itertoolz.pyx +++ b/cytoolz/itertoolz.pyx @@ -328,54 +328,46 @@ cdef class interleave: Returns a lazy iterator """ def __cinit__(self, seqs): - self.iters = [iter(seq) for seq in seqs] - self.newiters = [] + self.iters = list(map(iter, seqs)) self.i = 0 - self.n = PyList_GET_SIZE(self.iters) + self.active = PyList_GET_SIZE(self.iters) def __iter__(self): return self def __next__(self): - # This implementation is similar to what is done in `toolz` in that we - # construct a new list of iterators, `self.newiters`, when a value is - # successfully retrieved from an iterator from `self.iters`. + cdef object itrobj, val + cdef list iters cdef PyObject *obj - cdef object val + cdef Py_ssize_t _len - if self.i == self.n: - self.n = PyList_GET_SIZE(self.newiters) - self.i = 0 - if self.n == 0: - raise StopIteration - self.iters = self.newiters - self.newiters = [] - val = PyList_GET_ITEM(self.iters, self.i) + iters = self.iters + _len = PyList_GET_SIZE(iters) + + itrobj = PyList_GET_ITEM(iters, self.i) self.i += 1 - obj = PtrIter_Next(val) + if self.i == _len: + self.i = 0 + obj = PtrIter_Next(itrobj) - # TODO: optimization opportunity. Previously, it was possible to - # continue on given exceptions, `self.pass_exceptions`, which is - # why this code is structured this way. Time to clean up? while obj is NULL: + # Check if error occurred obj = PyErr_Occurred() if obj is not NULL: + # Iterator raised an exception val = obj PyErr_Clear() raise val - if self.i == self.n: - self.n = PyList_GET_SIZE(self.newiters) - self.i = 0 - if self.n == 0: - raise StopIteration - self.iters = self.newiters - self.newiters = [] - val = PyList_GET_ITEM(self.iters, self.i) - self.i += 1 - obj = PtrIter_Next(val) + self.active = max(self.active - 1, 0) + if self.active == 0: + raise StopIteration - PyList_Append(self.newiters, val) + itrobj = PyList_GET_ITEM(iters, self.i) + self.i += 1 + if self.i == _len: + self.i = 0 + obj = PtrIter_Next(itrobj) val = obj Py_XDECREF(obj) return val