Skip to content

Commit

Permalink
add extras execute_batch
Browse files Browse the repository at this point in the history
  • Loading branch information
vir-mir committed Dec 5, 2019
1 parent 2911c10 commit 17b5e4c
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
43 changes: 43 additions & 0 deletions aiopg/extras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
def _paginate(seq, page_size):
"""Consume an iterable and return it in chunks.
Every chunk is at most `page_size`. Never return an empty chunk.
"""
page = []
count = len(seq)
it = iter(seq)
for s in range(count + 1):
try:
for i in range(page_size):
page.append(next(it))
yield page
page = []
except StopIteration:
if page:
yield page
return


async def execute_batch(cur, sql, argslist, page_size=100):
r"""Execute groups of statements in fewer server roundtrips.
Execute *sql* several times, against all parameters set (sequences or
mappings) found in *argslist*.
The function is semantically similar to
.. parsed-literal::
*cur*\.\ `~cursor.executemany`\ (\ *sql*\ , *argslist*\ )
but has a different implementation: Psycopg will join the statements into
fewer multi-statement commands, each one containing at most *page_size*
statements, resulting in a reduced number of server roundtrips.
After the execution of the function the `cursor.rowcount` property will
**not** contain a total result.
"""
for page in _paginate(argslist, page_size=page_size):
sqls = [cur.mogrify(sql, args) for args in page]
await cur.execute(b";".join(sqls))
45 changes: 45 additions & 0 deletions tests/test_extras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest

from aiopg.extras import _paginate, execute_batch


@pytest.fixture
def connect(make_connection):
async def go(**kwargs):
conn = await make_connection(**kwargs)
async with conn.cursor() as cur:
await cur.execute("DROP TABLE IF EXISTS tbl_extras")
await cur.execute("CREATE TABLE tbl_extras (id int)")
return conn

return go


@pytest.fixture
def cursor(connect, loop):
async def go():
return await (await connect()).cursor()

cur = loop.run_until_complete(go())
yield cur
cur.close()


def test__paginate():
data = [
[1, 2, 3],
[4, 5, 6],
[7],
]
for index, val in enumerate(_paginate((1, 2, 3, 4, 5, 6, 7), page_size=3)):
assert data[index] == list(val)


async def test_execute_batch(cursor):
args = [(1,), (2,), (3,), (4,)]
sql = 'insert into tbl_extras values(%s)'
await execute_batch(cursor, sql, argslist=args, page_size=3)

await cursor.execute('SELECT * from tbl_extras')
ret = await cursor.fetchall()
assert list(ret) == args

0 comments on commit 17b5e4c

Please sign in to comment.