diff --git a/mysql2pgsql/lib/__init__.py b/mysql2pgsql/lib/__init__.py index 27dbf36..bbf396b 100644 --- a/mysql2pgsql/lib/__init__.py +++ b/mysql2pgsql/lib/__init__.py @@ -51,6 +51,7 @@ def status_logger(f): constraints_template = 'ADDING CONSTRAINTS ON %s' write_contents_template = 'WRITING DATA TO %s' index_template = 'ADDING INDEXES TO %s' + trigger_template = 'ADDING TRIGGERS TO %s' statuses = { 'truncate': { 'start': start_template % truncate_template, @@ -72,6 +73,10 @@ def status_logger(f): 'start': start_template % index_template, 'finish': finish_template % index_template, }, + 'write_triggers': { + 'start': start_template % trigger_template, + 'finish': finish_template % trigger_template, + }, } @wraps(f) diff --git a/mysql2pgsql/lib/converter.py b/mysql2pgsql/lib/converter.py index 72251ad..02d5d41 100644 --- a/mysql2pgsql/lib/converter.py +++ b/mysql2pgsql/lib/converter.py @@ -55,7 +55,7 @@ def convert(self): if not self.supress_ddl: if self.verbose: - print_start_table('START CREATING INDEXES AND CONSTRAINTS') + print_start_table('START CREATING INDEXES, CONSTRAINTS, AND TRIGGERS') for table in tables: self.writer.write_indexes(table) @@ -63,8 +63,11 @@ def convert(self): for table in tables: self.writer.write_constraints(table) + for table in tables: + self.writer.write_triggers(table) + if self.verbose: - print_start_table('DONE CREATING INDEXES AND CONSTRAINTS') + print_start_table('DONE CREATING INDEXES, CONSTRAINTS, AND TRIGGERS') if self.verbose: print_start_table('\n\n>>>>>>>>>> FINISHED <<<<<<<<<<') diff --git a/mysql2pgsql/lib/mysql_reader.py b/mysql2pgsql/lib/mysql_reader.py index d74a27e..d75b78d 100644 --- a/mysql2pgsql/lib/mysql_reader.py +++ b/mysql2pgsql/lib/mysql_reader.py @@ -83,8 +83,10 @@ def __init__(self, reader, name): self._name = name self._indexes = [] self._foreign_keys = [] + self._triggers = [] self._columns = self._load_columns() self._load_indexes() + self._load_triggers() def _convert_type(self, data_type): """Normalize MySQL `data_type`""" @@ -181,6 +183,22 @@ def _load_indexes(self): self._indexes.append(index) continue + def _load_triggers(self): + explain = self.reader.db.query('SHOW TRIGGERS WHERE `table` = \'%s\'' % self.name) + for row in explain: + if type(row) is tuple: + trigger = {} + trigger['name'] = row[0] + trigger['event'] = row[1] + trigger['statement'] = row[3] + trigger['timing'] = row[4] + + trigger['statement'] = re.sub('^BEGIN', '', trigger['statement']) + trigger['statement'] = re.sub('^END', '', trigger['statement'], flags=re.MULTILINE) + trigger['statement'] = re.sub('`', '', trigger['statement']) + + self._triggers.append(trigger) + @property def name(self): return self._name @@ -197,6 +215,10 @@ def indexes(self): def foreign_keys(self): return self._foreign_keys + @property + def triggers(self): + return self._triggers + @property def query_for(self): return 'SELECT %(column_names)s FROM `%(table_name)s`' % { diff --git a/mysql2pgsql/lib/postgres_db_writer.py b/mysql2pgsql/lib/postgres_db_writer.py index 2e19170..fc37633 100644 --- a/mysql2pgsql/lib/postgres_db_writer.py +++ b/mysql2pgsql/lib/postgres_db_writer.py @@ -166,6 +166,19 @@ def write_indexes(self, table): for sql in index_sql: self.execute(sql) + @status_logger + def write_triggers(self, table): + """Send DDL to create the specified `table` triggers + + :Parameters: + - `table`: an instance of a :py:class:`mysql2pgsql.lib.mysql_reader.MysqlReader.Table` object that represents the table to read/write. + + Returns None + """ + index_sql = super(PostgresDbWriter, self).write_triggers(table) + for sql in index_sql: + self.execute(sql) + @status_logger def write_constraints(self, table): """Send DDL to create the specified `table` constraints diff --git a/mysql2pgsql/lib/postgres_file_writer.py b/mysql2pgsql/lib/postgres_file_writer.py index 7205f6f..ba5b209 100644 --- a/mysql2pgsql/lib/postgres_file_writer.py +++ b/mysql2pgsql/lib/postgres_file_writer.py @@ -100,6 +100,17 @@ def write_constraints(self, table): """ self.f.write('\n'.join(super(PostgresFileWriter, self).write_constraints(table))) + @status_logger + def write_triggers(self, table): + """Write TRIGGERs existing on `table` to the output file + + :Parameters: + - `table`: an instance of a :py:class:`mysql2pgsql.lib.mysql_reader.MysqlReader.Table` object that represents the table to read/write. + + Returns None + """ + self.f.write('\n'.join(super(PostgresFileWriter, self).write_triggers(table))) + @status_logger def write_contents(self, table, reader): """Write the data contents of `table` to the output file. diff --git a/mysql2pgsql/lib/postgres_writer.py b/mysql2pgsql/lib/postgres_writer.py index 3165520..86c3adb 100644 --- a/mysql2pgsql/lib/postgres_writer.py +++ b/mysql2pgsql/lib/postgres_writer.py @@ -7,7 +7,6 @@ from psycopg2.extensions import QuotedString, Binary, AsIs from pytz import timezone - class PostgresWriter(object): """Base class for :py:class:`mysql2pgsql.lib.postgres_file_writer.PostgresFileWriter` and :py:class:`mysql2pgsql.lib.postgres_db_writer.PostgresDbWriter`. @@ -148,6 +147,9 @@ def process_row(self, table, row): row[index] = '1970-01-01 00:00:00' elif 'bit' in column_type: row[index] = bin(ord(row[index]))[2:] + elif column_type == 'boolean': + # We got here because you used a tinyint(1), if you didn't want a bool, don't use that type + row[index] = 't' if row[index] not in (None, 0) else 'f' if row[index] == 0 else row[index] elif isinstance(row[index], (str, unicode, basestring)): if column_type == 'bytea': row[index] = Binary(row[index]).getquoted()[1:-8] if row[index] else row[index] @@ -155,9 +157,6 @@ def process_row(self, table, row): row[index] = '{%s}' % ','.join('"%s"' % v.replace('"', r'\"') for v in row[index].split(',')) else: row[index] = row[index].replace('\\', r'\\').replace('\n', r'\n').replace('\t', r'\t').replace('\r', r'\r').replace('\0', '') - elif column_type == 'boolean': - # We got here because you used a tinyint(1), if you didn't want a bool, don't use that type - row[index] = 't' if row[index] not in (None, 0) else 'f' if row[index] == 0 else row[index] elif isinstance(row[index], (date, datetime)): if isinstance(row[index], datetime) and self.tz: try: @@ -259,6 +258,32 @@ def write_constraints(self, table): 'ref_column_name': key['ref_column']}) return constraint_sql + def write_triggers(self, table): + trigger_sql = [] + for key in table.triggers: + trigger_sql.append("""CREATE OR REPLACE FUNCTION %(fn_trigger_name)s RETURNS TRIGGER AS $%(trigger_name)s$ + BEGIN + %(trigger_statement)s + RETURN NULL; + END; + $%(trigger_name)s$ LANGUAGE plpgsql;""" % { + 'table_name': table.name, + 'trigger_time': key['timing'], + 'trigger_event': key['event'], + 'trigger_name': key['name'], + 'fn_trigger_name': 'fn_' + key['name'] + '()', + 'trigger_statement': key['statement']}) + + trigger_sql.append("""CREATE TRIGGER %(trigger_name)s %(trigger_time)s %(trigger_event)s ON %(table_name)s + FOR EACH ROW + EXECUTE PROCEDURE fn_%(trigger_name)s();""" % { + 'table_name': table.name, + 'trigger_time': key['timing'], + 'trigger_event': key['event'], + 'trigger_name': key['name']}) + + return trigger_sql + def close(self): raise NotImplementedError