diff --git a/src/base/tests/test_util.py b/src/base/tests/test_util.py index 325368fe1..fa9f924d8 100644 --- a/src/base/tests/test_util.py +++ b/src/base/tests/test_util.py @@ -18,6 +18,11 @@ from odoo import modules from odoo.tools import mute_logger +try: + from odoo.sql_db import db_connect +except ImportError: + from openerp.sql_db import db_connect + from odoo.addons.base.maintenance.migrations import util from odoo.addons.base.maintenance.migrations.testing import UnitTestCase, parametrize from odoo.addons.base.maintenance.migrations.util import snippets @@ -1494,6 +1499,27 @@ def test_iter(self): self.assertEqual(result, expected) +class TestQueryIds(UnitTestCase): + def test_straight(self): + result = list(util.query_ids(self.env.cr, "SELECT * FROM (VALUES (1), (2)) AS x(x)", itersize=2)) + self.assertEqual(result, [1, 2]) + + def test_chunks(self): + with util.query_ids(self.env.cr, "SELECT * FROM (VALUES (1), (2)) AS x(x)") as ids: + result = list(util.chunks(ids, 100, fmt=list)) + self.assertEqual(result, [[1, 2]]) + + def test_destructor(self): + ids = util.query_ids(self.env.cr, "SELECT id from res_users") + del ids + + def test_pk_violation(self): + with db_connect(self.env.cr.dbname).cursor() as cr, mute_logger("odoo.sql_db"), self.assertRaises( + ValueError + ), util.query_ids(cr, "SELECT * FROM (VALUES (1), (1)) AS x(x)") as ids: + list(ids) + + class TestRecords(UnitTestCase): def test_rename_xmlid(self): cr = self.env.cr diff --git a/src/util/models.py b/src/util/models.py index c92d42a0b..d3c769b56 100644 --- a/src/util/models.py +++ b/src/util/models.py @@ -27,6 +27,7 @@ get_m2m_tables, get_value_or_en_translation, parallel_execute, + query_ids, table_exists, update_m2m_tables, view_exists, @@ -128,17 +129,17 @@ def remove_model(cr, model, drop_table=True, ignore_m2m=()): 'SELECT id FROM "{}" r WHERE {}'.format(ir.table, ir.model_filter(prefix="r.")), [model] ).decode() - cr.execute(query) - if ir.table == "ir_ui_view": - for (view_id,) in cr.fetchall(): - remove_view(cr, view_id=view_id, silent=True) - else: - # remove in batch - size = (cr.rowcount + chunk_size - 1) / chunk_size - it = chunks([id for (id,) in cr.fetchall()], chunk_size, fmt=tuple) - for sub_ids in log_progress(it, _logger, qualifier=ir.table, size=size): - remove_records(cr, ref_model, sub_ids) - _rm_refs(cr, ref_model, sub_ids) + with query_ids(cr, query, itersize=chunk_size) as ids_: + if ir.table == "ir_ui_view": + for view_id in ids_: + remove_view(cr, view_id=view_id, silent=True) + else: + # remove in batch + size = (len(ids_) + chunk_size - 1) / chunk_size + it = chunks(ids_, chunk_size, fmt=tuple) + for sub_ids in log_progress(it, _logger, qualifier=ir.table, size=size): + remove_records(cr, ref_model, sub_ids) + _rm_refs(cr, ref_model, sub_ids) if ir.set_unknown: # Link remaining records not linked to a XMLID diff --git a/src/util/orm.py b/src/util/orm.py index 228c572ca..8638f9a85 100644 --- a/src/util/orm.py +++ b/src/util/orm.py @@ -42,7 +42,7 @@ from .exceptions import MigrationError from .helpers import table_of_model from .misc import chunks, log_progress, version_between, version_gte -from .pg import SQLStr, column_exists, format_query, get_columns, named_cursor +from .pg import SQLStr, column_exists, format_query, get_columns, named_cursor, query_ids # python3 shims try: @@ -288,27 +288,16 @@ def recompute_fields(cr, model, fields, ids=None, logger=_logger, chunk_size=256 Model = env(cr)[model] if isinstance(model, basestring) else model model = Model._name - if ids is None: - query = format_query(cr, "SELECT id FROM {}", table_of_model(cr, model)) if query is None else SQLStr(query) - cr.execute( - format_query(cr, "CREATE UNLOGGED TABLE _upgrade_rf(id) AS (WITH query AS ({}) SELECT * FROM query)", query) + ids_ = ids + if ids_ is None: + ids_ = query_ids( + cr, + format_query(cr, "SELECT id FROM {}", table_of_model(cr, model)) if query is None else SQLStr(query), + itersize=2**20, ) - count = cr.rowcount - cr.execute("ALTER TABLE _upgrade_rf ADD CONSTRAINT pk_upgrade_rf_id PRIMARY KEY (id)") - - def get_ids(): - with named_cursor(cr, itersize=2**20) as ncr: - ncr.execute("SELECT id FROM _upgrade_rf ORDER BY id") - for (id_,) in ncr: - yield id_ - - ids_ = get_ids() - else: - count = len(ids) - ids_ = ids + count = len(ids_) if not count: - cr.execute("DROP TABLE IF EXISTS _upgrade_rf") return _logger.info("Computing fields %s of %r on %d records", fields, model, count) @@ -338,7 +327,6 @@ def get_ids(): else: flush(records) invalidate(records) - cr.execute("DROP TABLE IF EXISTS _upgrade_rf") class iter_browse(object): @@ -374,7 +362,9 @@ class iter_browse(object): :param model: the model to iterate :type model: :class:`odoo.model.Model` - :param list(int) ids: list of IDs of the records to iterate + :param iterable(int) ids: iterable of IDs of the records to iterate + :param str query: alternative to ids, SQL query that can produce them. + Can also be a DML statement with a RETURNING clause. :param int chunk_size: number of records to load in each iteration chunk, `200` by default :param bool yield_chunks: when iterating, yield records in chunks of `chunk_size` instead of one by one. @@ -389,14 +379,27 @@ class iter_browse(object): See also :func:`~odoo.upgrade.util.orm.env` """ - __slots__ = ("_chunk_size", "_cr_uid", "_it", "_logger", "_model", "_patch", "_size", "_strategy", "_yield_chunks") + __slots__ = ( + "_chunk_size", + "_cr_uid", + "_ids", + "_it", + "_logger", + "_model", + "_patch", + "_query", + "_size", + "_strategy", + "_yield_chunks", + ) def __init__(self, model, *args, **kw): assert len(args) in [1, 3] # either (cr, uid, ids) or (ids,) self._model = model self._cr_uid = args[:-1] - ids = args[-1] - self._size = len(ids) + self._ids = args[-1] + self._size = kw.pop("size", None) + self._query = kw.pop("query", None) self._chunk_size = kw.pop("chunk_size", 200) # keyword-only argument self._yield_chunks = kw.pop("yield_chunks", False) self._logger = kw.pop("logger", _logger) @@ -405,8 +408,32 @@ def __init__(self, model, *args, **kw): if kw: raise TypeError("Unknown arguments: %s" % ", ".join(kw)) + if not (self._ids is None) ^ (self._query is None): + raise TypeError("Must be initialized using exactly one of `ids` or `query`") + + if self._query: + self._ids = query_ids(self._model.env.cr, self._query, itersize=self._chunk_size) + + if not self._size: + try: + self._size = len(self._ids) + except TypeError: + raise ValueError("When passing ids as a generator, the size kwarg is mandatory") self._patch = None - self._it = chunks(ids, self._chunk_size, fmt=self._browse) + self._it = chunks(self._ids, self._chunk_size, fmt=self._browse) + + def _values_query(self, query): + cr = self._model.env.cr + cr.execute(format_query(cr, "WITH query AS ({}) SELECT count(*) FROM query", SQLStr(query))) + size = cr.fetchone()[0] + + def get_values(): + with named_cursor(cr, itersize=self._chunk_size) as ncr: + ncr.execute(SQLStr(query)) + for row in ncr.iterdict(): + yield row + + return size, get_values() def _browse(self, ids): next(self._end(), None) @@ -459,35 +486,47 @@ def caller(*args, **kwargs): self._it = None return caller - def create(self, values, **kw): + def create(self, values=None, query=None, **kw): """ Create records. An alternative to the default `create` method of the ORM that is safe to use to create millions of records. - :param list(dict) values: list of values of the records to create + :param iterable(dict) values: iterable of values of the records to create + :param int size: the no. of elements produced by values, required if values is a generator + :param str query: alternative to values, SQL query that can produce them. + *No* DML statements allowed. Only SELECT. :param bool multi: whether to use the multi version of `create`, by default is `True` from Odoo 12 and above """ multi = kw.pop("multi", version_gte("saas~11.5")) + size = kw.pop("size", None) if kw: raise TypeError("Unknown arguments: %s" % ", ".join(kw)) - if not values: - raise ValueError("`create` cannot be called with an empty `values` argument") + if not (values is None) ^ (query is None): + raise ValueError("`create` needs to be called using exactly one of `values` or `query` arguments") if self._size: raise ValueError("`create` can only called on empty `browse_record` objects.") - ids = [] - size = len(values) + if query: + size, values = self._values_query(query) + + if size is None: + try: + size = len(values) + except TypeError: + raise ValueError("When passing a generator of values, the size kwarg is mandatory") + it = chunks(values, self._chunk_size, fmt=list) if self._logger: sz = (size + self._chunk_size - 1) // self._chunk_size qualifier = "env[%r].create([:%d])" % (self._model._name, self._chunk_size) it = log_progress(it, self._logger, qualifier=qualifier, size=sz) + ids = [] self._patch = no_selection_cache_validation() for sub_values in it: self._patch.start() diff --git a/src/util/pg.py b/src/util/pg.py index 0ab92a9c2..ec2d47e27 100644 --- a/src/util/pg.py +++ b/src/util/pg.py @@ -1932,3 +1932,83 @@ def bulk_update_table(cr, table, columns, mapping, key_col="id"): key_col=key_col, ) cr.execute(query, [Json(mapping)]) + + +class query_ids(object): + """ + Iterator over ids returned by a query. + + This allows iteration over a potentially huge number of ids without exhausting memory. + + :param str query: the query that returns the ids. It can be DML, e.g. `UPDATE table WHERE ... RETURNING id`. + :param int itersize: determines the number of rows fetched from PG at once, see :func:`~odoo.upgrade.util.pg.named_cursor`. + """ + + def __init__(self, cr, query, itersize=None): + self._ncr = None + self._cr = cr + self._tmp_tbl = "_upgrade_query_ids_{}".format(uuid.uuid4().hex) + cr.execute( + format_query( + cr, + "CREATE UNLOGGED TABLE {}(id) AS (WITH query AS ({}) SELECT * FROM query)", + self._tmp_tbl, + SQLStr(query), + ) + ) + self._len = cr.rowcount + try: + cr.execute( + format_query( + cr, + "ALTER TABLE {} ADD CONSTRAINT {} PRIMARY KEY (id)", + self._tmp_tbl, + "pk_{}_id".format(self._tmp_tbl), + ) + ) + except psycopg2.IntegrityError as e: + if e.pgcode == errorcodes.UNIQUE_VIOLATION: + raise ValueError("The query for ids is producing duplicate values:\n{}".format(query)) + raise + self._ncr = named_cursor(cr, itersize) + self._ncr.execute(format_query(cr, "SELECT id FROM {} ORDER BY id", self._tmp_tbl)) + self._it = iter(self._ncr) + + def _close(self): + if self._ncr: + if self._ncr.closed: + return + self._ncr.close() + try: + self._cr.execute(format_query(self._cr, "DROP TABLE IF EXISTS {}", self._tmp_tbl)) + except psycopg2.InternalError as e: + if e.pgcode != errorcodes.IN_FAILED_SQL_TRANSACTION: + raise + + def __len__(self): + return self._len + + def __iter__(self): + return self + + def __next__(self): + if self._ncr.closed: + raise StopIteration + try: + return next(self._it)[0] + except StopIteration: + self._close() + raise + + def next(self): + return self.__next__() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._close() + return False + + def __del__(self): + self._close()