diff --git a/.travis.yml b/.travis.yml index 680f91f1..830b0924 100644 --- a/.travis.yml +++ b/.travis.yml @@ -197,18 +197,22 @@ jobs: - name: "OSX py 3.5" os: osx + osx_image: xcode10.2 env: BUILD=tests,wheels PYTHON_VERSION=3.5.9 PGVERSION=12 - name: "OSX py 3.6" os: osx + osx_image: xcode10.2 env: BUILD=tests,wheels PYTHON_VERSION=3.6.10 PGVERSION=12 - name: "OSX py 3.7" os: osx + osx_image: xcode10.2 env: BUILD=tests,wheels PYTHON_VERSION=3.7.7 PGVERSION=12 - name: "OSX py 3.8" os: osx + osx_image: xcode10.2 env: BUILD=tests,wheels PYTHON_VERSION=3.8.3 PGVERSION=12 cache: diff --git a/asyncpg/connection.py b/asyncpg/connection.py index b7266471..aedd7139 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -331,6 +331,13 @@ async def executemany(self, command: str, args, *, timeout: float=None): .. versionchanged:: 0.11.0 `timeout` became a keyword-only parameter. + + .. versionchanged:: 0.22.0 + The execution was changed to be in an implicit transaction if there + was no explicit transaction, so that it will no longer end up with + partial success. If you still need the previous behavior to + progressively execute many args, please use a loop with prepared + statement instead. """ self._check_open() return await self._executemany(command, args, timeout) @@ -1010,6 +1017,9 @@ async def _copy_in(self, copy_stmt, source, timeout): f = source elif isinstance(source, collections.abc.AsyncIterable): # assuming calling output returns an awaitable. + # copy_in() is designed to handle very large amounts of data, and + # the source async iterable is allowed to return an arbitrary + # amount of data on every iteration. reader = source else: # assuming source is an instance supporting the buffer protocol. diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py index 5df6b674..eeb45367 100644 --- a/asyncpg/prepared_stmt.py +++ b/asyncpg/prepared_stmt.py @@ -202,11 +202,24 @@ async def fetchrow(self, *args, timeout=None): return None return data[0] - async def __bind_execute(self, args, limit, timeout): + @connresource.guarded + async def executemany(self, args, *, timeout: float=None): + """Execute the statement for each sequence of arguments in *args*. + + :param args: An iterable containing sequences of arguments. + :param float timeout: Optional timeout value in seconds. + :return None: This method discards the results of the operations. + + .. versionadded:: 0.22.0 + """ + return await self.__do_execute( + lambda protocol: protocol.bind_execute_many( + self._state, args, '', timeout)) + + async def __do_execute(self, executor): protocol = self._connection._protocol try: - data, status, _ = await protocol.bind_execute( - self._state, args, '', limit, True, timeout) + return await executor(protocol) except exceptions.OutdatedSchemaCacheError: await self._connection.reload_schema_state() # We can not find all manually created prepared statements, so just @@ -215,6 +228,11 @@ async def __bind_execute(self, args, limit, timeout): # invalidate themselves (unfortunately, clearing caches again). self._state.mark_closed() raise + + async def __bind_execute(self, args, limit, timeout): + data, status, _ = await self.__do_execute( + lambda protocol: protocol.bind_execute( + self._state, args, '', limit, True, timeout)) self._last_status = status return data diff --git a/asyncpg/protocol/consts.pxi b/asyncpg/protocol/consts.pxi index 97cbbf35..e1f8726e 100644 --- a/asyncpg/protocol/consts.pxi +++ b/asyncpg/protocol/consts.pxi @@ -8,3 +8,5 @@ DEF _MAXINT32 = 2**31 - 1 DEF _COPY_BUFFER_SIZE = 524288 DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0" +DEF _EXECUTE_MANY_BUF_NUM = 4 +DEF _EXECUTE_MANY_BUF_SIZE = 32768 diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index c96b1fa5..f21559b4 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -114,6 +114,7 @@ cdef class CoreProtocol: # True - completed, False - suspended bint result_execute_completed + cpdef is_in_transaction(self) cdef _process__auth(self, char mtype) cdef _process__prepare(self, char mtype) cdef _process__bind_execute(self, char mtype) @@ -146,6 +147,7 @@ cdef class CoreProtocol: cdef _auth_password_message_sasl_continue(self, bytes server_response) cdef _write(self, buf) + cdef _writelines(self, list buffers) cdef _read_server_messages(self) @@ -155,9 +157,13 @@ cdef class CoreProtocol: cdef _ensure_connected(self) + cdef WriteBuffer _build_parse_message(self, str stmt_name, str query) cdef WriteBuffer _build_bind_message(self, str portal_name, str stmt_name, WriteBuffer bind_data) + cdef WriteBuffer _build_empty_bind_data(self) + cdef WriteBuffer _build_execute_message(self, str portal_name, + int32_t limit) cdef _connect(self) @@ -166,8 +172,10 @@ cdef class CoreProtocol: WriteBuffer bind_data, int32_t limit) cdef _bind_execute(self, str portal_name, str stmt_name, WriteBuffer bind_data, int32_t limit) - cdef _bind_execute_many(self, str portal_name, str stmt_name, - object bind_data) + cdef bint _bind_execute_many(self, str portal_name, str stmt_name, + object bind_data) + cdef bint _bind_execute_many_more(self, bint first=*) + cdef _bind_execute_many_fail(self, object error, bint first=*) cdef _bind(self, str portal_name, str stmt_name, WriteBuffer bind_data) cdef _execute(self, str portal_name, int32_t limit) diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index fdc26ec6..12ebf6c6 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -27,13 +27,13 @@ cdef class CoreProtocol: # type of `scram` is `SCRAMAuthentcation` self.scram = None - # executemany support data - self._execute_iter = None - self._execute_portal_name = None - self._execute_stmt_name = None - self._reset_result() + cpdef is_in_transaction(self): + # PQTRANS_INTRANS = idle, within transaction block + # PQTRANS_INERROR = idle, within failed transaction + return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR) + cdef _read_server_messages(self): cdef: char mtype @@ -263,27 +263,16 @@ cdef class CoreProtocol: elif mtype == b'Z': # ReadyForQuery self._parse_msg_ready_for_query() - if self.result_type == RESULT_FAILED: - self._push_result() - else: - try: - buf = next(self._execute_iter) - except StopIteration: - self._push_result() - except Exception as e: - self.result_type = RESULT_FAILED - self.result = e - self._push_result() - else: - # Next iteration over the executemany() arg sequence - self._send_bind_message( - self._execute_portal_name, self._execute_stmt_name, - buf, 0) + self._push_result() elif mtype == b'I': # EmptyQueryResponse self.buffer.discard_message() + elif mtype == b'1': + # ParseComplete + self.buffer.discard_message() + cdef _process__bind(self, char mtype): if mtype == b'E': # ErrorResponse @@ -730,6 +719,11 @@ cdef class CoreProtocol: self.result_execute_completed = False self._discard_data = False + # executemany support data + self._execute_iter = None + self._execute_portal_name = None + self._execute_stmt_name = None + cdef _set_state(self, ProtocolState new_state): if new_state == PROTOCOL_IDLE: if self.state == PROTOCOL_FAILED: @@ -780,6 +774,17 @@ cdef class CoreProtocol: if self.con_status != CONNECTION_OK: raise apg_exc.InternalClientError('not connected') + cdef WriteBuffer _build_parse_message(self, str stmt_name, str query): + cdef WriteBuffer buf + + buf = WriteBuffer.new_message(b'P') + buf.write_str(stmt_name, self.encoding) + buf.write_str(query, self.encoding) + buf.write_int16(0) + + buf.end_message() + return buf + cdef WriteBuffer _build_bind_message(self, str portal_name, str stmt_name, WriteBuffer bind_data): @@ -795,6 +800,25 @@ cdef class CoreProtocol: buf.end_message() return buf + cdef WriteBuffer _build_empty_bind_data(self): + cdef WriteBuffer buf + buf = WriteBuffer.new() + buf.write_int16(0) # The number of parameter format codes + buf.write_int16(0) # The number of parameter values + buf.write_int16(0) # The number of result-column format codes + return buf + + cdef WriteBuffer _build_execute_message(self, str portal_name, + int32_t limit): + cdef WriteBuffer buf + + buf = WriteBuffer.new_message(b'E') + buf.write_str(portal_name, self.encoding) # name of the portal + buf.write_int32(limit) # number of rows to return; 0 - all + + buf.end_message() + return buf + # API for subclasses cdef _connect(self): @@ -845,12 +869,7 @@ cdef class CoreProtocol: self._ensure_connected() self._set_state(PROTOCOL_PREPARE) - buf = WriteBuffer.new_message(b'P') - buf.write_str(stmt_name, self.encoding) - buf.write_str(query, self.encoding) - buf.write_int16(0) - buf.end_message() - packet = buf + packet = self._build_parse_message(stmt_name, query) buf = WriteBuffer.new_message(b'D') buf.write_byte(b'S') @@ -872,10 +891,7 @@ cdef class CoreProtocol: buf = self._build_bind_message(portal_name, stmt_name, bind_data) packet = buf - buf = WriteBuffer.new_message(b'E') - buf.write_str(portal_name, self.encoding) # name of the portal - buf.write_int32(limit) # number of rows to return; 0 - all - buf.end_message() + buf = self._build_execute_message(portal_name, limit) packet.write_buffer(buf) packet.write_bytes(SYNC_MESSAGE) @@ -894,11 +910,8 @@ cdef class CoreProtocol: self._send_bind_message(portal_name, stmt_name, bind_data, limit) - cdef _bind_execute_many(self, str portal_name, str stmt_name, - object bind_data): - - cdef WriteBuffer buf - + cdef bint _bind_execute_many(self, str portal_name, str stmt_name, + object bind_data): self._ensure_connected() self._set_state(PROTOCOL_BIND_EXECUTE_MANY) @@ -907,17 +920,92 @@ cdef class CoreProtocol: self._execute_iter = bind_data self._execute_portal_name = portal_name self._execute_stmt_name = stmt_name + return self._bind_execute_many_more(True) - try: - buf = next(bind_data) - except StopIteration: - self._push_result() - except Exception as e: - self.result_type = RESULT_FAILED - self.result = e + cdef bint _bind_execute_many_more(self, bint first=False): + cdef: + WriteBuffer packet + WriteBuffer buf + list buffers = [] + + # as we keep sending, the server may return an error early + if self.result_type == RESULT_FAILED: + self._write(SYNC_MESSAGE) + return False + + # collect up to four 32KB buffers to send + # https://github.com/MagicStack/asyncpg/pull/289#issuecomment-391215051 + while len(buffers) < _EXECUTE_MANY_BUF_NUM: + packet = WriteBuffer.new() + + # fill one 32KB buffer + while packet.len() < _EXECUTE_MANY_BUF_SIZE: + try: + # grab one item from the input + buf = next(self._execute_iter) + + # reached the end of the input + except StopIteration: + if first: + # if we never send anything, simply set the result + self._push_result() + else: + # otherwise, append SYNC and send the buffers + packet.write_bytes(SYNC_MESSAGE) + buffers.append(packet) + self._writelines(buffers) + return False + + # error in input, give up the buffers and cleanup + except Exception as ex: + self._bind_execute_many_fail(ex, first) + return False + + # all good, write to the buffer + first = False + packet.write_buffer( + self._build_bind_message( + self._execute_portal_name, + self._execute_stmt_name, + buf, + ) + ) + packet.write_buffer( + self._build_execute_message(self._execute_portal_name, 0, + ) + ) + + # collected one buffer + buffers.append(packet) + + # write to the wire, and signal the caller for more to send + self._writelines(buffers) + return True + + cdef _bind_execute_many_fail(self, object error, bint first=False): + cdef WriteBuffer buf + + self.result_type = RESULT_FAILED + self.result = error + if first: self._push_result() + elif self.is_in_transaction(): + # we're in an explicit transaction, just SYNC + self._write(SYNC_MESSAGE) else: - self._send_bind_message(portal_name, stmt_name, buf, 0) + # In an implicit transaction, if `ignore_till_sync` is set, + # `ROLLBACK` will be ignored and `Sync` will restore the state; + # or the transaction will be rolled back with a warning saying + # that there was no transaction, but rollback is done anyway, + # so we could safely ignore this warning. + # GOTCHA: cannot use simple query message here, because it is + # ignored if `ignore_till_sync` is set. + buf = self._build_parse_message('', 'ROLLBACK') + buf.write_buffer(self._build_bind_message( + '', '', self._build_empty_bind_data())) + buf.write_buffer(self._build_execute_message('', 0)) + buf.write_bytes(SYNC_MESSAGE) + self._write(buf) cdef _execute(self, str portal_name, int32_t limit): cdef WriteBuffer buf @@ -927,10 +1015,7 @@ cdef class CoreProtocol: self.result = [] - buf = WriteBuffer.new_message(b'E') - buf.write_str(portal_name, self.encoding) # name of the portal - buf.write_int32(limit) # number of rows to return; 0 - all - buf.end_message() + buf = self._build_execute_message(portal_name, limit) buf.write_bytes(SYNC_MESSAGE) @@ -1013,6 +1098,9 @@ cdef class CoreProtocol: cdef _write(self, buf): raise NotImplementedError + cdef _writelines(self, list buffers): + raise NotImplementedError + cdef _decode_row(self, const char* buf, ssize_t buf_len): pass diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index a6d9ad5d..4df256e6 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -126,11 +126,6 @@ cdef class BaseProtocol(CoreProtocol): def get_record_class(self): return self.record_class - def is_in_transaction(self): - # PQTRANS_INTRANS = idle, within transaction block - # PQTRANS_INERROR = idle, within failed transaction - return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR) - cdef inline resume_reading(self): if not self.is_reading: self.is_reading = True @@ -215,6 +210,7 @@ cdef class BaseProtocol(CoreProtocol): self._check_state() timeout = self._get_timeout_impl(timeout) + timer = Timer(timeout) # Make sure the argument sequence is encoded lazily with # this generator expression to keep the memory pressure under @@ -224,7 +220,7 @@ cdef class BaseProtocol(CoreProtocol): waiter = self._new_waiter(timeout) try: - self._bind_execute_many( + more = self._bind_execute_many( portal_name, state.name, arg_bufs) # network op @@ -233,6 +229,22 @@ cdef class BaseProtocol(CoreProtocol): self.statement = state self.return_extra = False self.queries_count += 1 + + while more: + with timer: + await asyncio.wait_for( + self.writing_allowed.wait(), + timeout=timer.get_remaining_budget()) + # On Windows the above event somehow won't allow context + # switch, so forcing one with sleep(0) here + await asyncio.sleep(0) + if not timer.has_budget_greater_than(0): + raise asyncio.TimeoutError + more = self._bind_execute_many_more() # network op + + except asyncio.TimeoutError as e: + self._bind_execute_many_fail(e) # network op + except Exception as ex: waiter.set_exception(ex) self._coreproto_error() @@ -893,6 +905,9 @@ cdef class BaseProtocol(CoreProtocol): cdef _write(self, buf): self.transport.write(memoryview(buf)) + cdef _writelines(self, list buffers): + self.transport.writelines(buffers) + # asyncio callbacks: def data_received(self, data): @@ -945,6 +960,13 @@ class Timer: def get_remaining_budget(self): return self._budget + def has_budget_greater_than(self, amount): + if self._budget is None: + # Unlimited budget. + return True + else: + return self._budget > amount + class Protocol(BaseProtocol, asyncio.Protocol): pass diff --git a/tests/test_execute.py b/tests/test_execute.py index 5ecc100f..8cf0d2f2 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -9,6 +9,7 @@ import asyncpg from asyncpg import _testbase as tb +from asyncpg.exceptions import UniqueViolationError class TestExecuteScript(tb.ConnectedTestCase): @@ -97,57 +98,194 @@ async def test_execute_script_interrupted_terminate(self): self.con.terminate() - async def test_execute_many_1(self): - await self.con.execute('CREATE TEMP TABLE exmany (a text, b int)') - try: - result = await self.con.executemany(''' - INSERT INTO exmany VALUES($1, $2) - ''', [ - ('a', 1), ('b', 2), ('c', 3), ('d', 4) - ]) +class TestExecuteMany(tb.ConnectedTestCase): + def setUp(self): + super().setUp() + self.loop.run_until_complete(self.con.execute( + 'CREATE TABLE exmany (a text, b int PRIMARY KEY)')) - self.assertIsNone(result) + def tearDown(self): + self.loop.run_until_complete(self.con.execute('DROP TABLE exmany')) + super().tearDown() - result = await self.con.fetch(''' - SELECT * FROM exmany - ''') + async def test_executemany_basic(self): + result = await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) - self.assertEqual(result, [ - ('a', 1), ('b', 2), ('c', 3), ('d', 4) - ]) + self.assertIsNone(result) - # Empty set - result = await self.con.executemany(''' - INSERT INTO exmany VALUES($1, $2) - ''', ()) + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') - result = await self.con.fetch(''' - SELECT * FROM exmany - ''') + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) - self.assertEqual(result, [ - ('a', 1), ('b', 2), ('c', 3), ('d', 4) - ]) - finally: - await self.con.execute('DROP TABLE exmany') + # Empty set + await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', ()) - async def test_execute_many_2(self): - await self.con.execute('CREATE TEMP TABLE exmany (b int)') + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') - try: - bad_data = ([1 / 0] for v in range(10)) + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) - with self.assertRaises(ZeroDivisionError): - async with self.con.transaction(): - await self.con.executemany(''' - INSERT INTO exmany VALUES($1) - ''', bad_data) + async def test_executemany_bad_input(self): + bad_data = ([1 / 0] for v in range(10)) - good_data = ([v] for v in range(10)) + with self.assertRaises(ZeroDivisionError): async with self.con.transaction(): await self.con.executemany(''' - INSERT INTO exmany VALUES($1) - ''', good_data) - finally: - await self.con.execute('DROP TABLE exmany') + INSERT INTO exmany (b)VALUES($1) + ''', bad_data) + + good_data = ([v] for v in range(10)) + async with self.con.transaction(): + await self.con.executemany(''' + INSERT INTO exmany (b)VALUES($1) + ''', good_data) + + async def test_executemany_server_failure(self): + with self.assertRaises(UniqueViolationError): + await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', [ + ('a', 1), ('b', 2), ('c', 2), ('d', 4) + ]) + result = await self.con.fetch('SELECT * FROM exmany') + self.assertEqual(result, []) + + async def test_executemany_server_failure_after_writes(self): + with self.assertRaises(UniqueViolationError): + await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', [('a' * 32768, x) for x in range(10)] + [ + ('b', 12), ('c', 12), ('d', 14) + ]) + result = await self.con.fetch('SELECT b FROM exmany') + self.assertEqual(result, []) + + async def test_executemany_server_failure_during_writes(self): + # failure at the beginning, server error detected in the middle + pos = 0 + + def gen(): + nonlocal pos + while pos < 128: + pos += 1 + if pos < 3: + yield ('a', 0) + else: + yield 'a' * 32768, pos + + with self.assertRaises(UniqueViolationError): + await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', gen()) + result = await self.con.fetch('SELECT b FROM exmany') + self.assertEqual(result, []) + self.assertLess(pos, 128, 'should stop early') + + async def test_executemany_client_failure_after_writes(self): + with self.assertRaises(ZeroDivisionError): + await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', (('a' * 32768, y + y / y) for y in range(10, -1, -1))) + result = await self.con.fetch('SELECT b FROM exmany') + self.assertEqual(result, []) + + async def test_executemany_timeout(self): + with self.assertRaises(asyncio.TimeoutError): + await self.con.executemany(''' + INSERT INTO exmany VALUES(pg_sleep(0.1) || $1, $2) + ''', [('a' * 32768, x) for x in range(128)], timeout=0.5) + result = await self.con.fetch('SELECT * FROM exmany') + self.assertEqual(result, []) + + async def test_executemany_timeout_flow_control(self): + event = asyncio.Event() + + async def locker(): + test_func = getattr(self, self._testMethodName).__func__ + opts = getattr(test_func, '__connect_options__', {}) + conn = await self.connect(**opts) + try: + tx = conn.transaction() + await tx.start() + await conn.execute("UPDATE exmany SET a = '1' WHERE b = 10") + event.set() + await asyncio.sleep(1) + await tx.rollback() + finally: + event.set() + await conn.close() + + await self.con.executemany(''' + INSERT INTO exmany VALUES(NULL, $1) + ''', [(x,) for x in range(128)]) + fut = asyncio.ensure_future(locker()) + await event.wait() + with self.assertRaises(asyncio.TimeoutError): + await self.con.executemany(''' + UPDATE exmany SET a = $1 WHERE b = $2 + ''', [('a' * 32768, x) for x in range(128)], timeout=0.5) + await fut + result = await self.con.fetch( + 'SELECT * FROM exmany WHERE a IS NOT NULL') + self.assertEqual(result, []) + + async def test_executemany_client_failure_in_transaction(self): + tx = self.con.transaction() + await tx.start() + with self.assertRaises(ZeroDivisionError): + await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', (('a' * 32768, y + y / y) for y in range(10, -1, -1))) + result = await self.con.fetch('SELECT b FROM exmany') + # only 2 batches executed (2 x 4) + self.assertEqual( + [x[0] for x in result], [y + 1 for y in range(10, 2, -1)]) + await tx.rollback() + result = await self.con.fetch('SELECT b FROM exmany') + self.assertEqual(result, []) + + async def test_executemany_client_server_failure_conflict(self): + self.con._transport.set_write_buffer_limits(65536 * 64, 16384 * 64) + with self.assertRaises(UniqueViolationError): + await self.con.executemany(''' + INSERT INTO exmany VALUES($1, 0) + ''', (('a' * 32768,) for y in range(4, -1, -1) if y / y)) + result = await self.con.fetch('SELECT b FROM exmany') + self.assertEqual(result, []) + + async def test_executemany_prepare(self): + stmt = await self.con.prepare(''' + INSERT INTO exmany VALUES($1, $2) + ''') + result = await stmt.executemany([ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + self.assertIsNone(result) + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + # Empty set + await stmt.executemany(()) + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ])