diff --git a/src/flint/test/test.py b/src/flint/test/test.py index 13a3481f..db19f623 100644 --- a/src/flint/test/test.py +++ b/src/flint/test/test.py @@ -90,6 +90,7 @@ def test_fmpz(): assert int(f) == i assert flint.fmpz(f) == f assert flint.fmpz(str(i)) == f + assert raises(lambda: flint.fmpz(1,2), TypeError) assert raises(lambda: flint.fmpz("qwe"), ValueError) assert raises(lambda: flint.fmpz([]), TypeError) for s in L: @@ -162,6 +163,9 @@ def test_fmpz(): # XXX: Handle negative modulus like int? assert raises(lambda: pow(flint.fmpz(2), 2, -1), ValueError) + assert raises(lambda: pow(flint.fmpz(2), "asd", 2), TypeError) + assert raises(lambda: pow(flint.fmpz(2), 2, "asd"), TypeError) + f = flint.fmpz(2) assert f.numerator == f assert type(f.numerator) is flint.fmpz @@ -543,7 +547,8 @@ def test_fmpz_mat(): assert str(M(2,2,[1,2,3,4])) == '[1, 2]\n[3, 4]' assert M(1,2,[3,4]) * flint.fmpq(1,3) == flint.fmpq_mat(1, 2, [1, flint.fmpq(4,3)]) assert flint.fmpq(1,3) * M(1,2,[3,4]) == flint.fmpq_mat(1, 2, [1, flint.fmpq(4,3)]) - assert M(1,2,[3,4]) / 3 == flint.fmpq_mat(1, 2, [1, flint.fmpq(4,3)]) + assert raises(lambda: M(1,2,[3,4]) / 3, DomainError) + assert M(1,2,[2,4]) / 2 == M(1,2,[1,2]) assert M(2,2,[1,2,3,4]).inv().det() == flint.fmpq(1) / M(2,2,[1,2,3,4]).det() assert M(2,2,[1,2,3,4]).inv().inv() == M(2,2,[1,2,3,4]) assert raises(lambda: M.randrank(4,3,4,1), ValueError) @@ -965,7 +970,8 @@ def test_fmpq_poly(): assert 3 * Q([1,2,3]) == Q([3,6,9]) assert Q([1,2,3]) * flint.fmpq(2,3) == (Q([1,2,3]) * 2) / 3 assert flint.fmpq(2,3) * Q([1,2,3]) == (Q([1,2,3]) * 2) / 3 - assert raises(lambda: Q([1,2]) / Q([1,2]), TypeError) + assert Q([1,2]) / Q([1,2]) == Q([1]) + assert raises(lambda: Q([1,2]) / Q([2,2]), DomainError) assert Q([1,2,3]) / flint.fmpq(2,3) == Q([1,2,3]) * flint.fmpq(3,2) assert Q([1,2,3]) ** 2 == Q([1,2,3]) * Q([1,2,3]) assert raises(lambda: pow(Q([1,2]), 3, 5), NotImplementedError) @@ -2056,7 +2062,7 @@ def test_fmpz_mod_poly(): assert raises(lambda: f.exact_division(0), ZeroDivisionError) assert (f * g).exact_division(g) == f - assert raises(lambda: f.exact_division(g), ValueError) + assert raises(lambda: f.exact_division(g), DomainError) # true div assert raises(lambda: f / "AAA", TypeError) @@ -2276,6 +2282,151 @@ def test_fmpz_mod_mat(): assert raises(lambda: flint.fmpz_mod_mat(A, c11), TypeError) +def test_division_scalar(): + Z = flint.fmpz + Q = flint.fmpq + F17 = lambda x: flint.nmod(x, 17) + ctx = flint.fmpz_mod_ctx(163) + F163 = lambda a: flint.fmpz_mod(a, ctx) + # fmpz exact division + for (a, b) in [(Z(4), Z(2)), (Z(4), 2), (4, Z(2))]: + assert a / b == Z(2) + for (a, b) in [(Z(5), Z(2)), (Z(5), 2), (5, Z(2))]: + assert raises(lambda: a / b, DomainError) + # fmpz Euclidean division + for (a, b) in [(Z(5), Z(2)), (Z(5), 2), (5, Z(2))]: + assert a // b == 2 + assert a % b == 1 + assert divmod(a, b) == (2, 1) + # field division + for (a, b) in [(Q(5), Q(2)), (Q(5), 2), (5, Q(2))]: + assert a / b == Q(5,2) + for (a, b) in [(F17(5), F17(2)), (F17(5), 2), (5, F17(2))]: + assert a / b == F17(11) + for (a, b) in [(F163(5), F163(2)), (F163(5), 2), (5, F163(2))]: + assert a / b == F163(84) + # divmod with fields - should this give remainder zero instead of error? + for K in [Q, F17, F163]: + for (a, b) in [(K(5), K(2)), (K(5), 2), (5, K(2))]: + assert raises(lambda: divmod(a, b), TypeError) + # Zero division + for R in [Z, Q, F17, F163]: + assert raises(lambda: R(5) / 0, ZeroDivisionError) + assert raises(lambda: R(5) / R(0), ZeroDivisionError) + assert raises(lambda: 5 / R(0), ZeroDivisionError) + # Bad types + for R in [Z, Q, F17, F163]: + assert raises(lambda: R(5) / "AAA", TypeError) + assert raises(lambda: "AAA" / R(5), TypeError) + + +def test_division_poly(): + Z = flint.fmpz + Q = flint.fmpq + F17 = lambda x: flint.nmod(x, 17) + ctx = flint.fmpz_mod_ctx(163) + F163 = lambda a: flint.fmpz_mod(a, ctx) + PZ = lambda x: flint.fmpz_poly(x) + PQ = lambda x: flint.fmpq_poly(x) + PF17 = lambda x: flint.nmod_poly(x, 17) + PF163 = lambda x: flint.fmpz_mod_poly(x, flint.fmpz_mod_poly_ctx(163)) + # fmpz exact scalar division + assert PZ([2, 4]) / Z(2) == PZ([1, 2]) + assert PZ([2, 4]) / 2 == PZ([1, 2]) + assert raises(lambda: PZ([2, 5]) / Z(2), DomainError) + assert raises(lambda: PZ([2, 5]) / 2, DomainError) + # field division by scalar + for (K, PK) in [(Q, PQ), (F17, PF17), (F163, PF163)]: + assert PK([2, 5]) / K(2) == PK([K(2)/K(2), K(5)/K(2)]) + assert PK([2, 5]) / 2 == PK([K(2)/K(2), K(5)/K(2)]) + # No other scalar division is allowed + for (R, PR) in [(Z, PZ), (Q, PQ), (F17, PF17), (F163, PF163)]: + assert raises(lambda: R(2) / PR([2, 5]), DomainError) + assert raises(lambda: 2 / PR([2, 5]), DomainError) + assert raises(lambda: PR([2, 5]) / 0, ZeroDivisionError) + assert raises(lambda: PR([2, 5]) / R(0), ZeroDivisionError) + # exact polynomial division + for (R, PR) in [(Z, PZ), (Q, PQ), (F17, PF17), (F163, PF163)]: + assert PR([2, 4]) / PR([1, 2]) == PR([2]) + assert PR([2, -3, 1]) / PR([-1, 1]) == PR([-2, 1]) + assert raises(lambda: PR([2, 4]) / PR([1, 3]), DomainError) + assert PR([2]) / PR([2]) == 2 / PR([2]) == PR([1]) + assert PR([0]) / PR([1, 2]) == 0 / PR([1, 2]) == PR([0]) + if R is Z: + assert raises(lambda: PR([1, 2]) / PR([2, 4]), DomainError) + assert raises(lambda: 1 / PR([2]), DomainError) + else: + assert PR([1, 2]) / PR([2, 4]) == PR([R(1)/R(2)]) + assert 1 / PR([2]) == PR([R(1)/R(2)]) + assert raises(lambda: PR([1, 2]) / PR([0]), ZeroDivisionError) + # Euclidean polynomial division + for (R, PR) in [(Z, PZ), (Q, PQ), (F17, PF17), (F163, PF163)]: + assert PR([2, 4]) // PR([1, 2]) == PR([2]) + assert PR([2, 4]) % PR([1, 2]) == PR([0]) + assert divmod(PR([2, 4]), PR([1, 2])) == (PR([2]), PR([0])) + assert PR([3, -3, 1]) // PR([-1, 1]) == PR([-2, 1]) + assert PR([3, -3, 1]) % PR([-1, 1]) == PR([1]) + assert divmod(PR([3, -3, 1]), PR([-1, 1])) == (PR([-2, 1]), PR([1])) + assert PR([2]) // PR([2]) == 2 // PR([2]) == PR([1]) + assert PR([2]) % PR([2]) == 2 % PR([2]) == PR([0]) + assert divmod(PR([2]), PR([2])) == (PR([1]), PR([0])) + assert PR([0]) // PR([1, 2]) == 0 // PR([1, 2]) == PR([0]) + assert PR([0]) % PR([1, 2]) == 0 % PR([1, 2]) == PR([0]) + assert divmod(PR([0]), PR([1, 2])) == (PR([0]), PR([0])) + if R is Z: + assert PR([2, 2]) // PR([2, 4]) == PR([2, 2]) // PR([2, 4]) == PR([0]) + assert PR([2, 2]) % PR([2, 4]) == PR([2, 2]) % PR([2, 4]) == PR([2, 2]) + assert divmod(PR([2, 2]), PR([2, 4])) == (PR([0]), PR([2, 2])) + assert 1 // PR([2]) == PR([1]) // PR([2]) == PR([0]) + assert 1 % PR([2]) == PR([1]) % PR([2]) == PR([1]) + assert divmod(1, PR([2])) == (PR([0]), PR([1])) + else: + assert PR([2, 2]) // PR([2, 4]) == PR([R(1)/R(2)]) + assert PR([2, 2]) % PR([2, 4]) == PR([1]) + assert divmod(PR([2, 2]), PR([2, 4])) == (PR([R(1)/R(2)]), PR([1])) + assert 1 // PR([2]) == PR([R(1)/R(2)]) + assert 1 % PR([2]) == PR([0]) + assert divmod(1, PR([2])) == (PR([R(1)/R(2)]), PR([0])) + + +def test_division_matrix(): + Z = flint.fmpz + Q = flint.fmpq + F17 = lambda x: flint.nmod(x, 17) + ctx = flint.fmpz_mod_ctx(163) + F163 = lambda a: flint.fmpz_mod(a, ctx) + MZ = lambda x: flint.fmpz_mat(x) + MQ = lambda x: flint.fmpq_mat(x) + MF17 = lambda x: flint.nmod_mat(x, 17) + MF163 = lambda x: flint.fmpz_mod_mat(x, ctx) + # fmpz exact division + assert MZ([[2, 4]]) / Z(2) == MZ([[1, 2]]) + assert MZ([[2, 4]]) / 2 == MZ([[1, 2]]) + assert raises(lambda: MZ([[2, 5]]) / Z(2), DomainError) + assert raises(lambda: MZ([[2, 5]]) / 2, DomainError) + # field division by scalar + for (K, MK) in [(Q, MQ), (F17, MF17), (F163, MF163)]: + assert MK([[2, 5]]) / K(2) == MK([[K(2)/K(2), K(5)/K(2)]]) + assert MK([[2, 5]]) / 2 == MK([[K(2)/K(2), K(5)/K(2)]]) + # No other division is allowed + for (R, MR) in [(Z, MZ), (Q, MQ), (F17, MF17), (F163, MF163)]: + M = MR([[2, 5]]) + for s in (2, R(2)): + assert raises(lambda: s / M, TypeError) + assert raises(lambda: M // s, TypeError) + assert raises(lambda: s // M, TypeError) + assert raises(lambda: M % s, TypeError) + assert raises(lambda: s % M, TypeError) + assert raises(lambda: divmod(s, M), TypeError) + assert raises(lambda: divmod(M, s), TypeError) + assert raises(lambda: M / M, TypeError) + assert raises(lambda: M // M, TypeError) + assert raises(lambda: M % M, TypeError) + assert raises(lambda: divmod(M, M), TypeError) + assert raises(lambda: M / 0, ZeroDivisionError) + assert raises(lambda: M / R(0), ZeroDivisionError) + + def _all_polys(): return [ # (poly_type, scalar_type, is_field) @@ -2436,16 +2587,18 @@ def setbad(obj, i, val): assert raises(lambda: P([1, 2, 1]) % P([0]), ZeroDivisionError) assert raises(lambda: divmod(P([1, 2, 1]), P([0])), ZeroDivisionError) + # Exact/field scalar division if is_field: assert P([2, 2]) / 2 == P([1, 1]) assert P([1, 2]) / 2 == P([S(1)/2, 1]) - assert raises(lambda: P([1, 2]) / 0, ZeroDivisionError) else: - assert raises(lambda: P([2, 2]) / 2, TypeError) + assert P([2, 2]) / 2 == P([1, 1]) + assert raises(lambda: P([1, 2]) / 2, DomainError) + assert raises(lambda: P([1, 2]) / 0, ZeroDivisionError) - assert raises(lambda: 1 / P([1, 1]), TypeError) - assert raises(lambda: P([1, 2, 1]) / P([1, 1]), TypeError) - assert raises(lambda: P([1, 2, 1]) / P([1, 2]), TypeError) + assert P([1, 2, 1]) / P([1, 1]) == P([1, 1]) + assert raises(lambda: 1 / P([1, 1]), DomainError) + assert raises(lambda: P([1, 2, 1]) / P([1, 2]), DomainError) assert P([1, 1]) ** 0 == P([1]) assert P([1, 1]) ** 1 == P([1, 1]) @@ -3023,7 +3176,9 @@ def test_all_tests(): test_fmpz_mod_poly, test_fmpz_mod_mat, - test_arb, + test_division_scalar, + test_division_poly, + test_division_matrix, test_polys, @@ -3048,7 +3203,9 @@ def test_all_tests(): test_matrices_rref, test_matrices_solve, + test_arb, + test_pickling, - test_all_tests, + test_all_tests, ] diff --git a/src/flint/types/fmpq_poly.pyx b/src/flint/types/fmpq_poly.pyx index 19d4180c..dddac1d8 100644 --- a/src/flint/types/fmpq_poly.pyx +++ b/src/flint/types/fmpq_poly.pyx @@ -16,6 +16,9 @@ from flint.flintlib.arith cimport arith_bernoulli_polynomial from flint.flintlib.arith cimport arith_euler_polynomial from flint.flintlib.arith cimport arith_legendre_polynomial +from flint.utils.flint_exceptions import DomainError + + cdef any_as_fmpq_poly(obj): if typecheck(obj, fmpq_poly): return obj @@ -295,23 +298,35 @@ cdef class fmpq_poly(flint_poly): return t return t._mod_(s) - @staticmethod - def _div_(fmpq_poly s, t): - cdef fmpq_poly r - t = any_as_fmpq(t) + def __truediv__(fmpq_poly s, t): + cdef fmpq_poly res + cdef fmpq_poly_t r + t2 = any_as_fmpq(t) + if t2 is NotImplemented: + t2 = any_as_fmpq_poly(t) + if t2 is NotImplemented: + return t2 + if fmpq_poly_is_zero((t2).val): + raise ZeroDivisionError("fmpq_poly division by 0") + res = fmpq_poly.__new__(fmpq_poly) + fmpq_poly_init(r) + fmpq_poly_divrem(res.val, r, (s).val, (t2).val) + exact = fmpq_poly_is_zero(r) + fmpq_poly_clear(r) + if not exact: + raise DomainError("fmpq_poly inexact division") + else: + if fmpq_is_zero((t2).val): + raise ZeroDivisionError("fmpq_poly scalar division by 0") + res = fmpq_poly.__new__(fmpq_poly) + fmpq_poly_scalar_div_fmpq(res.val, (s).val, (t2).val) + return res + + def __rtruediv__(fmpq_poly s, t): + t = any_as_fmpq_poly(t) if t is NotImplemented: return t - if fmpq_is_zero((t).val): - raise ZeroDivisionError("fmpq_poly scalar division by 0") - r = fmpq_poly.__new__(fmpq_poly) - fmpq_poly_scalar_div_fmpq(r.val, (s).val, (t).val) - return r - - def __div__(s, t): - return fmpq_poly._div_(s, t) - - def __truediv__(s, t): - return fmpq_poly._div_(s, t) + return t / s def _divmod_(s, t): cdef fmpq_poly P, Q diff --git a/src/flint/types/fmpz.pyx b/src/flint/types/fmpz.pyx index 1d27bc70..a1aa7e05 100644 --- a/src/flint/types/fmpz.pyx +++ b/src/flint/types/fmpz.pyx @@ -1,5 +1,3 @@ -from cpython.version cimport PY_MAJOR_VERSION - from flint.flint_base.flint_base cimport flint_scalar from flint.utils.typecheck cimport typecheck from flint.utils.conversion cimport chars_from_str @@ -12,6 +10,9 @@ from flint.flintlib.fmpz_factor cimport * from flint.flintlib.arith cimport * from flint.flintlib.partitions cimport * +from flint.utils.flint_exceptions import DomainError + + cdef fmpz_get_intlong(fmpz_t x): """ Convert fmpz_t to a Python int or long. @@ -29,10 +30,6 @@ cdef int fmpz_set_any_ref(fmpz_t x, obj): if typecheck(obj, fmpz): x[0] = (obj).val[0] return FMPZ_REF - if PY_MAJOR_VERSION < 3 and PyInt_Check(obj): - fmpz_init(x) - fmpz_set_si(x, PyInt_AS_LONG(obj)) - return FMPZ_TMP if PyLong_Check(obj): fmpz_init(x) fmpz_set_pylong(x, obj) @@ -103,9 +100,6 @@ cdef class fmpz(flint_scalar): def __int__(self): return fmpz_get_intlong(self.val) - def __long__(self): - return long(fmpz_get_intlong(self.val)) - def __index__(self): return fmpz_get_intlong(self.val) @@ -134,27 +128,18 @@ cdef class fmpz(flint_scalar): cdef fmpz_struct * sval cdef int ttype sval = &((s).val[0]) - if PY_MAJOR_VERSION < 3 and PyInt_Check(t): - tl = PyInt_AS_LONG(t) - if op == 2: res = fmpz_cmp_si(sval, tl) == 0 - elif op == 3: res = fmpz_cmp_si(sval, tl) != 0 - elif op == 0: res = fmpz_cmp_si(sval, tl) < 0 - elif op == 1: res = fmpz_cmp_si(sval, tl) <= 0 - elif op == 4: res = fmpz_cmp_si(sval, tl) > 0 - elif op == 5: res = fmpz_cmp_si(sval, tl) >= 0 - else: - ttype = fmpz_set_any_ref(tval, t) - if ttype != FMPZ_UNKNOWN: - if op == 2: res = fmpz_equal(sval, tval) - elif op == 3: res = not fmpz_equal(sval, tval) - elif op == 0: res = fmpz_cmp(sval, tval) < 0 - elif op == 1: res = fmpz_cmp(sval, tval) <= 0 - elif op == 4: res = fmpz_cmp(sval, tval) > 0 - elif op == 5: res = fmpz_cmp(sval, tval) >= 0 - if ttype == FMPZ_TMP: - fmpz_clear(tval) - if ttype == FMPZ_UNKNOWN: - return NotImplemented + ttype = fmpz_set_any_ref(tval, t) + if ttype != FMPZ_UNKNOWN: + if op == 2: res = fmpz_equal(sval, tval) + elif op == 3: res = not fmpz_equal(sval, tval) + elif op == 0: res = fmpz_cmp(sval, tval) < 0 + elif op == 1: res = fmpz_cmp(sval, tval) <= 0 + elif op == 4: res = fmpz_cmp(sval, tval) > 0 + elif op == 5: res = fmpz_cmp(sval, tval) >= 0 + if ttype == FMPZ_TMP: + fmpz_clear(tval) + if ttype == FMPZ_UNKNOWN: + return NotImplemented return res def bit_length(self): @@ -265,6 +250,39 @@ cdef class fmpz(flint_scalar): if ttype == FMPZ_TMP: fmpz_clear(tval) return u + def __truediv__(s, t): + cdef fmpz_struct tval[1] + cdef fmpz_struct rval[1] + cdef int ttype + + ttype = fmpz_set_any_ref(tval, t) + if ttype == FMPZ_UNKNOWN: + return NotImplemented + + if fmpz_is_zero(tval): + if ttype == FMPZ_TMP: + fmpz_clear(tval) + raise ZeroDivisionError("fmpz division by zero") + + q = fmpz.__new__(fmpz) + fmpz_init(rval) + fmpz_fdiv_qr((q).val, rval, (s).val, tval) + exact = fmpz_is_zero(rval) + fmpz_clear(rval) + + if ttype == FMPZ_TMP: fmpz_clear(tval) + + if exact: + return q + else: + raise DomainError("fmpz division is not exact") + + def __rtruediv__(s, t): + t = any_as_fmpz(t) + if t is NotImplemented: + return t + return t.__truediv__(s) + def __floordiv__(s, t): cdef fmpz_struct tval[1] cdef int ttype = FMPZ_UNKNOWN diff --git a/src/flint/types/fmpz_mat.pyx b/src/flint/types/fmpz_mat.pyx index a2609c2a..4012f8ca 100644 --- a/src/flint/types/fmpz_mat.pyx +++ b/src/flint/types/fmpz_mat.pyx @@ -18,6 +18,9 @@ from flint.flintlib.fmpq_mat cimport fmpq_mat_init from flint.flintlib.fmpq_mat cimport fmpq_mat_set_fmpz_mat_div_fmpz from flint.flintlib.fmpq_mat cimport fmpq_mat_solve_fmpz_mat +from flint.utils.flint_exceptions import DomainError + + cdef any_as_fmpz_mat(obj): if typecheck(obj, fmpz_mat): return obj @@ -131,13 +134,10 @@ cdef class fmpz_mat(flint_mat): def __nonzero__(self): return not fmpz_mat_is_zero(self.val) - def __richcmp__(s, t, int op): + def __richcmp__(fmpz_mat s, t, int op): cdef bint r if op != 2 and op != 3: raise TypeError("matrices cannot be ordered") - s = any_as_fmpz_mat(s) - if t is NotImplemented: - return s t = any_as_fmpz_mat(t) if t is NotImplemented: return t @@ -282,15 +282,22 @@ cdef class fmpz_mat(flint_mat): return fmpq_mat(s) * t return NotImplemented - @staticmethod - def _div_(fmpz_mat s, t): - return s * (1 / fmpq(t)) - - def __truediv__(s, t): - return fmpz_mat._div_(s, t) - - def __div__(s, t): - return fmpz_mat._div_(s, t) + def __truediv__(fmpz_mat s, t): + cdef fmpz_mat u + cdef fmpz_mat_struct *sval + t = any_as_fmpz(t) + if t is NotImplemented: + return t + if fmpz_is_zero((t).val): + raise ZeroDivisionError("division by zero") + sval = &(s).val[0] + u = fmpz_mat.__new__(fmpz_mat) + fmpz_mat_init(u.val, fmpz_mat_nrows(sval), fmpz_mat_ncols(sval)) + fmpz_mat_scalar_divexact_fmpz(u.val, sval, (t).val) + # XXX: check for exact division - there should be a better way! + if u * t != s: + raise DomainError("fmpz_mat division is not exact") + return u def __pow__(self, e, m): cdef fmpz_mat t diff --git a/src/flint/types/fmpz_mod_poly.pyx b/src/flint/types/fmpz_mod_poly.pyx index 8502b001..f5b5d60b 100644 --- a/src/flint/types/fmpz_mod_poly.pyx +++ b/src/flint/types/fmpz_mod_poly.pyx @@ -452,7 +452,19 @@ cdef class fmpz_mod_poly(flint_poly): return res def __truediv__(s, t): - return fmpz_mod_poly._div_(s, t) + t2 = s.ctx.mod.any_as_fmpz_mod(t) + if t2 is not NotImplemented: + return s._div_(t2) + t2 = s.ctx.any_as_fmpz_mod_poly(t) + if t2 is NotImplemented: + return NotImplemented + return s.exact_division(t2) + + def __rtruediv__(s, t): + t = s.ctx.any_as_fmpz_mod_poly(t) + if t is NotImplemented: + return NotImplemented + return t.exact_division(s) def exact_division(self, right): """ @@ -482,7 +494,7 @@ cdef class fmpz_mod_poly(flint_poly): res.val, self.val, (right).val, res.ctx.mod.val ) if check == 0: - raise ValueError( + raise DomainError( f"{right} does not divide {self}" ) diff --git a/src/flint/types/fmpz_poly.pyx b/src/flint/types/fmpz_poly.pyx index 27253e48..4630e420 100644 --- a/src/flint/types/fmpz_poly.pyx +++ b/src/flint/types/fmpz_poly.pyx @@ -19,7 +19,7 @@ from flint.types.arb cimport arb from flint.types.acb cimport any_as_acb_or_notimplemented cimport libc.stdlib from flint.flintlib.fmpz cimport fmpz_init, fmpz_clear, fmpz_set -from flint.flintlib.fmpz cimport fmpz_is_one, fmpz_equal_si, fmpz_equal +from flint.flintlib.fmpz cimport fmpz_is_zero, fmpz_is_one, fmpz_equal_si, fmpz_equal from flint.flintlib.acb_modular cimport * from flint.flintlib.ulong_extras cimport n_is_prime from flint.flintlib.fmpz_poly cimport * @@ -29,6 +29,9 @@ from flint.flintlib.acb cimport * from flint.flintlib.arb_poly cimport * from flint.flintlib.arb_fmpz_poly cimport * +from flint.utils.flint_exceptions import DomainError + + cdef any_as_fmpz_poly(x): cdef fmpz_poly res if typecheck(x, fmpz_poly): @@ -227,6 +230,34 @@ cdef class fmpz_poly(flint_poly): def __rmul__(self, other): return self._mul_(other) + def __truediv__(fmpz_poly self, other): + cdef fmpz_poly res + o = any_as_fmpz(other) + if o is NotImplemented: + o = any_as_fmpz_poly(other) + if o is NotImplemented: + return NotImplemented + if fmpz_poly_is_zero((o).val): + raise ZeroDivisionError("fmpz_poly division by 0") + res, r = self._divmod_(o) + if r: + raise DomainError("fmpz_poly division is not exact") + else: + if fmpz_is_zero((o).val): + raise ZeroDivisionError("fmpz_poly division by 0") + res = fmpz_poly.__new__(fmpz_poly) + fmpz_poly_scalar_divexact_fmpz(res.val, self.val, (o).val) + # Check division is exact - there should be a better way to do this + if res * o != self: + raise DomainError("fmpz_poly division is not exact") + return res + + def __rtruediv__(fmpz_poly self, other): + o = any_as_fmpz_poly(other) + if o is NotImplemented: + return NotImplemented + return o / self + def _floordiv_(self, other): cdef fmpz_poly res if fmpz_poly_is_zero((other).val): diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index 0ddaf66b..637001af 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -12,6 +12,9 @@ from flint.flintlib.nmod_poly_factor cimport * from flint.flintlib.fmpz_poly cimport fmpz_poly_get_nmod_poly from flint.flintlib.ulong_extras cimport n_gcdinv +from flint.utils.flint_exceptions import DomainError + + cdef any_as_nmod_poly(obj, nmod_t mod): cdef nmod_poly r cdef mp_limb_t v @@ -258,7 +261,23 @@ cdef class nmod_poly(flint_poly): def __rmul__(s, t): return s._mul_(t) - # TODO: __div__, __truediv__ + def __truediv__(s, t): + t = any_as_nmod_poly(t, (s).val.mod) + if t is NotImplemented: + return t + res, r = s._divmod_(t) + if not nmod_poly_is_zero((r).val): + raise DomainError("nmod_poly inexact division") + return res + + def __rtruediv__(s, t): + t = any_as_nmod_poly(t, (s).val.mod) + if t is NotImplemented: + return t + res, r = t._divmod_(s) + if not nmod_poly_is_zero((r).val): + raise DomainError("nmod_poly inexact division") + return res def _floordiv_(s, t): cdef nmod_poly r @@ -308,13 +327,6 @@ cdef class nmod_poly(flint_poly): return t return t._divmod_(s) - def __truediv__(s, t): - try: - t = nmod(t, (s).val.mod.n) - except TypeError: - return NotImplemented - return s * t ** -1 - def __mod__(s, t): return divmod(s, t)[1] # XXX