最近在做项目的时候,需要比对两个数据库的表结构差异,由于表数量比较多,人工比对的话需要大量时间,且不可复用,于是想到用 python 写一个脚本来达到诉求,下次有相同诉求的时候只需改 sql 文件名即可。
compare_diff.py:
import re
import json# 建表语句对象
class TableStmt:
    """One CREATE TABLE statement paired with the name of the table it creates.

    Both attributes default to an empty string and are meant to be assigned
    on the instance by whoever parses the SQL script.
    """

    # Name of the table the statement creates.
    table_name: str = ""
    # Full text of the CREATE TABLE statement.
    create_stmt: str = ""


# Table object
class Table(object):
    """Parsed description of one table: its name, its columns and its indexes.

    ``fields`` and ``indexes`` are created per instance in ``__init__``.
    Declaring them as class-level lists would make every ``Table`` share the
    same two list objects, so appending to one table would silently change
    all the others.
    """

    # Name of the table; empty until a parser assigns it.
    table_name = ""

    def __init__(self):
        # Field objects describing the table's columns.
        self.fields = []
        # Index objects describing the table's indexes.
        self.indexes = []


# Field object
class Field:
    """A single table column, described by its name and its type.

    Both attributes default to an empty string and are meant to be assigned
    on the instance.
    """

    # Column name.
    field_name: str = ""
    # Column type, stored as text.
    field_type: str = ""


# Index object
class Index:
    """A single table index: its name, its kind and the columns it covers.

    All three attributes default to an empty string and are meant to be
    assigned on the instance.
    """

    # Index name.
    name: str = ""
    # Kind of index, stored as text.
    type: str = ""
    # Indexed columns, stored as one string.
    columns: str = ""


# Custom JSON serializer; optional, handy when printing
def obj_2_dict(obj):
    """Convert a Field or Index into a plain dict for ``json.dumps(default=...)``.

    Args:
        obj: The object ``json`` could not serialize on its own.

    Returns:
        A dict mapping attribute names to the object's attribute values.

    Raises:
        TypeError: If ``obj`` is neither a Field nor an Index.
    """
    if isinstance(obj, Field):
        attr_names = ("field_name", "field_type")
    elif isinstance(obj, Index):
        attr_names = ("name", "type", "columns")
    else:
        raise TypeError(f"Type {type(obj)} is not serializable")
    return {attr: getattr(obj, attr) for attr in attr_names}


# Regex pattern matching a complete CREATE TABLE statement
# Matches one complete CREATE TABLE statement, from the CREATE keyword up to
# the semicolon that ends the table options, and captures the back-quoted
# table name as ``table_name``.
#
# The options after ENGINE are matched quote-aware: a quoted string (for
# example COMMENT='a-b, c; d') is consumed as a whole, so punctuation or a
# semicolon inside a comment no longer stops the match early. When the tail
# failed to match, the lazy ``.*?`` body used to run on into the next
# statement and merge two tables into one match.
create_table_pattern = re.compile(
    r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?`(?P<table_name>\w+)`"
    r".*?\)\s*ENGINE"
    r"(?:'(?:[^'\\]|\\.|'')*'|[^';])+;",
    re.DOTALL | re.IGNORECASE,
)

# Regex pattern matching column names and column types; only the base type is
# captured, everything else is ignored
# Matches one column definition per line: a back-quoted column name at the
# start of a line (leading whitespace allowed), followed by the type word and
# an optional "(n)" or "(n,m)" suffix. Anything after that is not captured.
table_pattern = re.compile(
    r"^\s*`(?P<field>\w+)`"  # back-quoted column name at the start of a line
    r"\s+"
    r"(?P<type>[a-zA-Z]+(?:\(\d+(?:,\d+)?\))?)",  # type word plus optional size
    re.MULTILINE,
)

# Regex pattern matching index definitions
# Matches one index definition. The seven capture groups are, in order:
#   1, 2  name and column list of a plain KEY
#   3     column list of the PRIMARY KEY
#   4, 5  name and column list of a UNIQUE KEY
#   6, 7  name and column list of a FULLTEXT KEY
#
# The column list accepts one level of nested parentheses so that prefix
# lengths such as `name`(10) do not cut the list short at the first ")".
# The plain-KEY branch must not be preceded by a back-quote or a word
# character; otherwise text such as "monkey varchar(8)" would be read as an
# index named "varchar".
index_pattern = re.compile(
    r'(?<![`\w])KEY\s+`?(\w+)`?\s*\(((?:[^()]|\([^()]*\))+)\)|'
    r'PRIMARY\s+KEY\s*\(((?:[^()]|\([^()]*\))+)\)|'
    r'UNIQUE\s+KEY\s+`?(\w+)`?\s*\(((?:[^()]|\([^()]*\))+)\)|'
    r'FULLTEXT\s+KEY\s+`?(\w+)`?\s*\(((?:[^()]|\([^()]*\))+)\)',
    re.IGNORECASE,
)

# Extract every table name together with its CREATE TABLE statement
def extract_create_table_statements(sql_script):
    """Collect every CREATE TABLE statement found in a SQL script.

    Args:
        sql_script: Text of the SQL script to scan.

    Returns:
        A list of TableStmt objects in the order the statements are matched.
        Each holds the lower-cased table name and the matched statement text
        with surrounding whitespace stripped.
    """

    def _to_stmt(found):
        stmt = TableStmt()
        # Table names are lower-cased so later comparisons ignore case.
        stmt.table_name = found.group('table_name').lower()
        # group(0) is the whole matched CREATE TABLE statement.
        stmt.create_stmt = found.group(0).strip()
        return stmt

    return [_to_stmt(found) for found in create_table_pattern.finditer(sql_script)]


# Extract indexes
def _normalize_index_columns(columns):
    """Lower-case a raw index column list and separate its items with ", "."""
    return ", ".join(part.strip() for part in columns.lower().split(","))


def extract_indexes(sql):
    """Collect the indexes defined in one CREATE TABLE statement.

    Args:
        sql: Text of a CREATE TABLE statement.

    Returns:
        A list of Index objects in the order the definitions are matched.
        ``type`` is one of 'index', 'primary key', 'unique index' or
        'fulltext index'. ``name`` is lower-cased, and is always 'primary'
        for the primary key. ``columns`` is the lower-cased column list with
        its items separated by ", ".
    """
    indexes = []
    for match in index_pattern.findall(sql):
        index = Index()
        # Column lists are normalised so that "`a`,`b`" and "`a`, `b`"
        # compare equal; formatting differences between two dumps are not
        # schema differences.
        if match[0]:  # plain index
            index.type = 'index'
            index.name = match[0].lower()
            index.columns = _normalize_index_columns(match[1])
        elif match[2]:  # primary key
            index.type = 'primary key'
            index.name = 'primary'
            index.columns = _normalize_index_columns(match[2])
        elif match[3]:  # unique index
            index.type = 'unique index'
            index.name = match[3].lower()
            index.columns = _normalize_index_columns(match[4])
        elif match[5]:  # fulltext index
            index.type = 'fulltext index'
            index.name = match[5].lower()
            index.columns = _normalize_index_columns(match[6])
        indexes.append(index)
    return indexes


# Extract columns
def extract_fields(sql):
    """Collect the columns defined in one CREATE TABLE statement.

    Args:
        sql: Text of a CREATE TABLE statement.

    Returns:
        A list of Field objects in the order the columns are matched, with
        both the column name and the column type lower-cased.
    """
    columns = []
    for found in table_pattern.finditer(sql):
        raw_name, raw_type = found.group('field', 'type')
        column = Field()
        # Names and types are lower-cased so later comparisons ignore case.
        column.field_name = raw_name.lower()
        column.field_type = raw_type.lower()
        columns.append(column)
    return columns


# Extract table information
def extract_table_info(tableStmt: TableStmt):
    """Build a Table from one CREATE TABLE statement.

    Args:
        tableStmt: The statement to parse, together with its table name.

    Returns:
        A Table whose name is the lower-cased ``tableStmt.table_name`` and
        whose fields and indexes are extracted from ``tableStmt.create_stmt``.
    """
    ddl = tableStmt.create_stmt
    parsed = Table()
    parsed.table_name = tableStmt.table_name.lower()
    # Columns first, then indexes, both read from the same statement text.
    parsed.fields = extract_fields(ddl)
    parsed.indexes = extract_indexes(ddl)
    return parsed


# Extract every table in the SQL script
def get_all_tables(sql_script):
    """Parse every table defined in a SQL script.

    Args:
        sql_script: Text of the SQL script.

    Returns:
        A dict mapping each table name to its parsed Table. If the same name
        occurs more than once, the later definition replaces the earlier one.
    """
    parsed_tables = map(extract_table_info, extract_create_table_statements(sql_script))
    return {table.table_name: table for table in parsed_tables}


# Compare the columns of two tables
def compare_fields(source: Table, target: Table):
    """Compare the columns of two tables by column name.

    Args:
        source: The source table.
        target: The target table.

    Returns:
        A 3-tuple of lists:
        names of columns only in ``source``;
        one message per column present in both tables with different types;
        names of columns only in ``target``.
    """
    target_by_name = {col.field_name: col for col in target.fields}
    source_names = {col.field_name for col in source.fields}

    only_in_source = []
    type_mismatches = []
    for col in source.fields:
        counterpart = target_by_name.get(col.field_name)
        if counterpart is None:
            # Column exists in source but not in target.
            only_in_source.append(col.field_name)
        elif col.field_type != counterpart.field_type:
            type_mismatches.append(
                f"field={col.field_name}, source type: {col.field_type}, "
                f"target type: {counterpart.field_type}"
            )

    # Types need no second check here: every column present in both tables
    # was already compared in the loop above.
    only_in_target = [
        col.field_name for col in target.fields if col.field_name not in source_names
    ]
    return only_in_source, type_mismatches, only_in_target


# Compare the indexes of two tables
def compare_indexes(source: Table, target: Table):
    """Compare the indexes of two tables by index name.

    Args:
        source: The source table.
        target: The target table.

    Returns:
        A 4-tuple of lists:
        names of indexes only in ``source``;
        one message per index with the same name and type but different columns;
        one message per index with the same name but a different type;
        names of indexes only in ``target``.
    """
    target_by_name = {idx.name: idx for idx in target.indexes}
    source_names = {idx.name for idx in source.indexes}

    only_in_source = []
    column_mismatches = []
    type_mismatches = []
    for idx in source.indexes:
        counterpart = target_by_name.get(idx.name)
        if counterpart is None:
            # Index exists in source but not in target.
            only_in_source.append(idx.name)
        elif idx.type != counterpart.type:
            # Same name, different type; the columns are not compared then.
            type_mismatches.append(
                f"name={idx.name}, source type: {idx.type}, "
                f"target type: {counterpart.type}"
            )
        elif idx.columns != counterpart.columns:
            # Same name and type, different columns.
            column_mismatches.append(
                f"name={idx.name}, source columns={idx.columns}, "
                f"target columns={counterpart.columns}"
            )

    only_in_target = [idx.name for idx in target.indexes if idx.name not in source_names]
    return only_in_source, column_mismatches, type_mismatches, only_in_target


# Print one comparison result; nothing is printed when the result is an empty
# list (meaning there is no difference)
def print_diff(desc, compare_result):
    """Print a labelled comparison result, or nothing when it is empty.

    Args:
        desc: Label printed in front of the result.
        compare_result: Sized collection of differences.
    """
    if len(compare_result) == 0:
        return
    line = f"{desc} {compare_result}"
    print(line)


# Compare every table in the two scripts
def _is_table_selected(table_name, allowed_tables, blocked_tables):
    """Return True when table_name passes the whitelist and blacklist filters."""
    if allowed_tables and table_name not in allowed_tables:
        return False
    return table_name not in blocked_tables


def compare_table(source_sql_script, target_sql_script):
    """Compare all tables of two SQL scripts and print the differences.

    For every selected table present in both scripts a header line is printed,
    followed by the non-empty column and index differences. Afterwards the
    tables found in only one of the two scripts are listed.

    The module-level ``white_list_tables`` and ``black_list_tables`` select
    which tables take part. Their entries are matched case-insensitively.

    Args:
        source_sql_script: Text of the source SQL script.
        target_sql_script: Text of the target SQL script.
    """
    source_table_map = get_all_tables(source_sql_script)
    target_table_map = get_all_tables(target_sql_script)

    # Parsed table names are lower-cased, so the configured names are
    # lower-cased too; otherwise a mixed-case entry would never match.
    allowed_tables = {name.lower() for name in white_list_tables}
    blocked_tables = {name.lower() for name in black_list_tables}

    source_table_not_in_target = []
    for key, source_table in source_table_map.items():
        if not _is_table_selected(key, allowed_tables, blocked_tables):
            continue
        if key not in target_table_map:
            # Table exists in source but not in target.
            source_table_not_in_target.append(key)
            continue
        target_table = target_table_map[key]

        # Compare columns.
        (source_fields_not_in_target,
         fields_type_not_match,
         target_fields_not_in_source) = compare_fields(source_table, target_table)
        # Compare indexes.
        (source_indexes_not_in_target,
         index_column_not_match,
         index_type_not_match,
         target_indexes_not_in_source) = compare_indexes(source_table, target_table)

        print(f"====== table = {key} ======")
        print_diff("source field not in target, fields:", source_fields_not_in_target)
        print_diff("target field not in source, fields:", target_fields_not_in_source)
        print_diff("field type not match:", fields_type_not_match)
        print_diff("source index not in target, indexes:", source_indexes_not_in_target)
        print_diff("target index not in source, indexes:", target_indexes_not_in_source)
        print_diff("index type not match:", index_type_not_match)
        print_diff("index column not match:", index_column_not_match)
        print("")

    # Tables that exist in target but not in source.
    target_table_not_in_source = [
        key
        for key in target_table_map
        if _is_table_selected(key, allowed_tables, blocked_tables)
        and key not in source_table_map
    ]

    print_diff("source table not in target, table list:", source_table_not_in_target)
    print_diff("target table not in source, table list:", target_table_not_in_source)


# Read a SQL file
def sql_read(file_name):
    """Return the whole content of ``file_name``, read as UTF-8 text."""
    with open(file_name, "r", encoding='utf-8') as file:
        return file.read()


def print_all_tables(file_name="sql1.sql"):
    """Print every table parsed from a SQL file, for debugging.

    For each table the name is printed, then its columns and its indexes as
    indented JSON, then an empty line.

    Args:
        file_name: Path of the SQL file to parse. Defaults to "sql1.sql".
    """
    table_map = get_all_tables(sql_read(file_name))
    for key, item in table_map.items():
        print(key)
        print(json.dumps(item.fields, default=obj_2_dict, ensure_ascii=False, indent=4))
        print(json.dumps(item.indexes, default=obj_2_dict, ensure_ascii=False, indent=4))
        print("")


# print_all_tables()

# Whitelist / blacklist settings, for when only some of the tables should be
# compared
# Whitelist: when it is not empty, only the tables listed here are compared
white_list_tables = []
# Blacklist: when it is not empty, the tables listed here are not compared
black_list_tables = []

if __name__ == '__main__':
    # NOTE: MySQL is case-insensitive by default. Every table, column and
    # index name is lower-cased before it is compared, so this script needs
    # changes if the database is configured to be case-sensitive.
    compare_table(sql_read("sql1.sql"), sql_read("sql2.sql"))
运行效果如下:
====== table = table1 ======
source field not in target, fields: ['age', 'email']
target field not in source, fields: ['name']
field type not match: ['field=created_at, source type: date, target type: bigint(20)', 'field=updated_at, source type: timestamp, target type: date']
source index not in target, indexes: ['unique_name']
target index not in source, indexes: ['idx_country_env']

====== table = table2 ======
index type not match: ['name=fulltext_index, source type: fulltext index, target type: index']
index column not match: ['name=index, source columns=`age`, target columns=`description`']

====== table = table3 ======
index column not match: ['name=primary, source columns=`id`, `value`, target columns=`value`, `id`']

source table not in target, table list: ['activity_instance']
target table not in source, table list: ['table5']
结果说明:
- 按照 table 来打印 source table 和 target table 的字段和索引差异,此时 table 在两个 sql 脚本里都存在
- 最后打印只在其中一个 sql 脚本里存在的 table list
sql1.sql:
CREATE TABLE `table1` (
    `id` INT(11) NOT NULL AUTO_INCREMENT,
    `age` INT(11) DEFAULT NULL,
    `email` varchar(32) DEFAULT NULL COMMENT '邮箱',
    `created_at` date DEFAULT NULL,
    `updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    PRIMARY KEY (`id`),
    UNIQUE KEY `unique_name` (`name`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT ='测试表';

CREATE TABLE `table2` (
    `id` INT(11) NOT NULL,
    `description` TEXT NOT NULL,
    `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    PRIMARY KEY (`id`),
    UNIQUE KEY `unique_name` (`name`),
    KEY `index` (`age`),
    FULLTEXT KEY `fulltext_index` (`name`, `age`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

CREATE TABLE `table3` (
    `id` INT(11) NOT NULL AUTO_INCREMENT,
    `value` DECIMAL(10,2) NOT NULL,
    `updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    PRIMARY KEY (`id`, `value`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

/******************************************/
/* DatabaseName = database */
/* TableName = activity_instance */
/******************************************/
CREATE TABLE `activity_instance`
(
    `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT '主键',
    `gmt_create` bigint(20) NOT NULL COMMENT '创建时间',
    `gmt_modified` bigint(20) NOT NULL COMMENT '修改时间',
    `activity_name` varchar(400) NOT NULL COMMENT '活动名称',
    `benefit_type` varchar(16) DEFAULT NULL,
    `benefit_id` varchar(32) DEFAULT NULL,
    PRIMARY KEY (`id`),
    KEY `idx_country_env` (`env`, `country_code`),
    KEY `idx_benefit_type_id` (`benefit_type`, `benefit_id`)
) ENGINE = InnoDB
  AUTO_INCREMENT = 139
  DEFAULT CHARSET = utf8mb4 COMMENT ='活动时间模板表'
;
sql2.sql:
CREATE TABLE `TABLE1` (
    `id` INT(11) NOT NULL AUTO_INCREMENT,
    `name` VARCHAR(255) NOT NULL,
    `created_at` bigint(20) DEFAULT NULL,
    `updated_at` date ON UPDATE CURRENT_TIMESTAMP,
    PRIMARY KEY (`id`),
    KEY `idx_country_env` (`env`, `country_code`),
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT ='测试表';

CREATE TABLE `table2` (
    `id` INT(11) NOT NULL,
    `description` TEXT NOT NULL,
    `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    PRIMARY KEY (`id`),
    UNIQUE KEY `unique_name` (`name`),
    KEY `index` (`description`),
    KEY `fulltext_index` (`name`, `age`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

CREATE TABLE `table3` (
    `id` INT(11) NOT NULL AUTO_INCREMENT,
    `value` DECIMAL(10,2) NOT NULL,
    `updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    PRIMARY KEY (`value`, `id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

CREATE TABLE `TABLE5` (
    `id` INT(11) NOT NULL AUTO_INCREMENT,
    `value` DECIMAL(10,2) NOT NULL,
    `updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
把上面的 Python 脚本和两段 SQL 分别保存为同一目录下的 compare_diff.py、sql1.sql、sql2.sql 三个文件,然后在该目录下运行 `python compare_diff.py` 即可;示例已在 Python 3.12 环境上成功运行。