@@ -5,12 +5,14 @@ from flint.types.fmpz cimport any_as_fmpz
55from flint.types.fmpz cimport fmpz
66from flint.types.fmpq cimport fmpq
77
8+ from flint.flintlib.flint cimport ulong
89from flint.flintlib.fmpz cimport fmpz_t
910from flint.flintlib.nmod cimport nmod_pow_fmpz, nmod_inv
1011from flint.flintlib.nmod_vec cimport *
1112from flint.flintlib.fmpz cimport fmpz_fdiv_ui, fmpz_init, fmpz_clear
1213from flint.flintlib.fmpz cimport fmpz_set_ui, fmpz_get_ui
1314from flint.flintlib.fmpq cimport fmpq_mod_fmpz
15+ from flint.flintlib.ulong_extras cimport n_gcdinv
1416
1517cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except - 1 :
1618 cdef int success
@@ -64,9 +66,6 @@ cdef class nmod(flint_scalar):
6466 def __int__ (self ):
6567 return int (self .val)
6668
67- def __long__ (self ):
68- return self .val
69-
7069 def modulus (self ):
7170 return self .mod.n
7271
@@ -170,6 +169,8 @@ cdef class nmod(flint_scalar):
170169 cdef nmod r
171170 cdef mp_limb_t sval, tval, x
172171 cdef nmod_t mod
172+ cdef ulong tinvval
173+
173174 if typecheck(s, nmod):
174175 mod = (< nmod> s).mod
175176 sval = (< nmod> s).val
@@ -180,17 +181,19 @@ cdef class nmod(flint_scalar):
180181 tval = (< nmod> t).val
181182 if not any_as_nmod(& sval, s, mod):
182183 return NotImplemented
184+
183185 if tval == 0 :
184186 raise ZeroDivisionError (" %s is not invertible mod %s " % (tval, mod.n))
185187 if not s:
186188 return s
187- # XXX: check invertibility?
188- x = nmod_div(sval, tval, mod)
189- if x == 0 :
189+
190+ g = n_gcdinv( & tinvval, < ulong > tval, < ulong > mod.n )
191+ if g ! = 1 :
190192 raise ZeroDivisionError (" %s is not invertible mod %s " % (tval, mod.n))
193+
191194 r = nmod.__new__ (nmod)
192195 r.mod = mod
193- r.val = x
196+ r.val = nmod_mul(sval, < mp_limb_t > tinvval, mod)
194197 return r
195198
196199 def __truediv__ (s , t ):
@@ -200,18 +203,43 @@ cdef class nmod(flint_scalar):
200203 return nmod._div_(t, s)
201204
202205 def __invert__ (self ):
203- return (1 / self ) # XXX: speed up
206+ cdef nmod r
207+ cdef ulong g, inv, sval
208+ sval = < ulong> (< nmod> self ).val
209+ g = n_gcdinv(& inv, sval, self .mod.n)
210+ if g != 1 :
211+ raise ZeroDivisionError (" %s is not invertible mod %s " % (sval, self .mod.n))
212+ r = nmod.__new__ (nmod)
213+ r.mod = self .mod
214+ r.val = < mp_limb_t> inv
215+ return r
204216
205- def __pow__ (self , exp ):
217+ def __pow__ (self , exp , modulus = None ):
206218 cdef nmod r
219+ cdef mp_limb_t rval, mod
220+ cdef ulong g, rinv
221+
222+ if modulus is not None :
223+ raise TypeError (" three-argument pow() not supported by nmod" )
224+
207225 e = any_as_fmpz(exp)
208226 if e is NotImplemented :
209227 return NotImplemented
210- r = nmod.__new__ (nmod)
211- r.mod = self .mod
212- r.val = self .val
228+
229+ rval = (< nmod> self ).val
230+ mod = (< nmod> self ).mod.n
231+
232+ # XXX: It is not clear that it is necessary to special case negative
233+ # exponents here. The nmod_pow_fmpz function seems to handle this fine
234+ # but the Flint docs say that the exponent must be nonnegative.
213235 if e < 0 :
214- r.val = nmod_inv(r.val, self .mod)
236+ g = n_gcdinv(& rinv, < ulong> rval, < ulong> mod)
237+ if g != 1 :
238+ raise ZeroDivisionError (" %s is not invertible mod %s " % (rval, mod))
239+ rval = < mp_limb_t> rinv
215240 e = - e
216- r.val = nmod_pow_fmpz(r.val, (< fmpz> e).val, self .mod)
241+
242+ r = nmod.__new__ (nmod)
243+ r.mod = self .mod
244+ r.val = nmod_pow_fmpz(rval, (< fmpz> e).val, self .mod)
217245 return r
0 commit comments