构建支持多种数据库类型的代码自动生成工具
背景:
一般的业务代码中写来写去,无外乎是先建好model,然后针对这个model做些CRUD的操作。(主要针对单表的业务操作)针对于数据库dao、mapper等的代码自动生成已经有了mybatisGenerator这种工具,但是针对于controller、service这些我们现在的接口api一般遵循的是restful风格,因此这些也是有规则可循的。举例有个goodsInfo 的model,针对于他的操作,肯定有 单个查询、list查询、修改、删除等。而这些代码没必要复制粘贴来一遍,完全可以由工具自动生成,若有特殊业务场景重写即可。本工具就算解决这类问题的。
效果截图
运行生成示例结果:
表选择界面:
思路:
代码自动生成说起来很神秘,其实无外乎两个方面:
从数据库拿到需要自动生成的代码对应表。
从表结构、字段名生成对应的mapper、model、及controller、service等
如何拿到需要自动生成的代码对应表
sqlservr、mysql、oracle等这些主流数据库中都存在系统表结构的表,存储的是所有用户自己建立表的名称、字段等,所以直接查询这些系统表即可罗列出所有业务表。然后做个可视化界面供用户选择即可。(这里做一下更新,我实际项目中没有用sql查询的方式,因为不同数据库对于系统表的存储方式各不相同,查询语句写的太蛋疼了,实际采用的是 conn.getMetaData() 的方式,采用元数据来拿到指定数据库中各种表结构信息)
如何自动生成代码
有了表结构、字段名等如何自动生成代码呢,这个时候就需要模板引擎了。简单来讲可以理解为把固定的地方写死,变化的地方按照规则替换。
可以用我们小时候写作文的例子来说明。我们(作文厉害的请自动忽略 “们” 😃)小时候写作文,一般是3段式,开头、结尾、和中间流水账。 开头一般是描写环境心情、中间讲述具体故事,结尾总结赞美。
今天天气不错,风和日丽的,我们早早就来到了学些,大家都很开心。(开头)
小明,突然在地上捡到了一个钱包……(一顿思想斗争,最后交给了警察叔叔)
最后,这个故事告诉了我们……(结尾)
从上面的示例范文中是不是很熟悉,。基本上都是这个结构,中间基本可以随意替换,最后都能凑成一篇基本合格的小学作文。而我们现在要做的就是把一些表名称、字段名称当做需要填充的内容填充到指定的代码段中去。
具体实现
获取数据库表、字段等信息
好了,上面讲了一大堆废话(背景和思路,个人觉得还是有必要的),下面到具体实现中来。
获取数据库表结构(表、字段)信息关键代码如下
@Service
public class DbServiceImpl implements IDbService {
private Logger logger = LoggerFactory.getLogger(this.getClass());
@Value("${spring.datasource.driverClassName}")
private String driverClassName;
@Value("${spring.datasource.url}")
private String url;
@Value("${spring.datasource.username}")
private String user;
@Value("${spring.datasource.password}")
private String pwd;
@Override
public List getTables(String tableName) {
List tables = new ArrayList<>();
try {
Class.forName(driverClassName);// 动态加载mysql驱动
Connection connection = DriverManager.getConnection(url, user, pwd);
DatabaseMetaData metaData = connection.getMetaData();
ResultSet resultSet = metaData.getTables(null, null, "%", new String[]{"TABLE"});
// metaData.getTables("yjc", "", "%", new String[]{"TABLE"})
while (resultSet.next()) {
if (resultSet.getString("TABLE_NAME").contains(tableName)) {
TableEntity tmpTable = new TableEntity();
tmpTable.setTbName(resultSet.getString("TABLE_NAME"));
tmpTable.setComments(resultSet.getString("REMARKS"));
tmpTable.setCatalog(resultSet.getString("TABLE_CAT"));
tmpTable.setSchema(resultSet.getString("TABLE_SCHEM"));
tables.add(tmpTable);
}
}
} catch (Exception e) {
logger.error("获取数据库表列表失败", e);
}
return tables;
}
@Override
public List getColumns(String tableName) {
List columnEntityList = new ArrayList<>();
try {
Class.forName(driverClassName);// 动态加载mysql驱动
Connection connection = DriverManager.getConnection(url, user, pwd);
DatabaseMetaData metaData = connection.getMetaData();
ResultSet resultSet = metaData.getColumns(null, null, tableName, "%");
while (resultSet.next()) {
ColumnEntity tmpColumnEntity = new ColumnEntity();
tmpColumnEntity.setColumnName(resultSet.getString("COLUMN_NAME"));
tmpColumnEntity.setBufferLength(resultSet.getInt("BUFFER_LENGTH"));
tmpColumnEntity.setColumnSize(resultSet.getInt("COLUMN_SIZE"));
tmpColumnEntity.setComments(resultSet.getString("REMARKS"));
tmpColumnEntity.setDecimalDigits(resultSet.getInt("DECIMAL_DIGITS"));
tmpColumnEntity.setDataType(resultSet.getInt("DATA_TYPE"));
tmpColumnEntity.setTypeName(resultSet.getString("TYPE_NAME"));
tmpColumnEntity.setIsNullAble(resultSet.getString("IS_NULLABLE"));
tmpColumnEntity.setIsAutoIncrement(resultSet.getString("IS_AUTOINCREMENT"));
columnEntityList.add(tmpColumnEntity);
}
} catch (Exception e) {
logger.error("查询表的列发生异常", e);
}
return columnEntityList;
}
@Override
public TableEntity getTableEntity(String tableName) {
TableEntity tableEntity = new TableEntity();
try {
Class.forName(driverClassName);// 动态加载mysql驱动
Connection connection = DriverManager.getConnection(url, user, pwd);
DatabaseMetaData metaData = connection.getMetaData();
ResultSet resultSet = metaData.getTables(null, null, tableName, new String[]{"TABLE"});
while (resultSet.next()) {
if(tableName.equals(resultSet.getString("TABLE_NAME")))
{
tableEntity.setTbName(resultSet.getString("TABLE_NAME"));
tableEntity.setComments(resultSet.getString("REMARKS"));
tableEntity.setCatalog(resultSet.getString("TABLE_CAT"));
tableEntity.setSchema(resultSet.getString("TABLE_SCHEM"));
tableEntity.setPk(this.getPrimaryKeyColumnName(metaData,tableEntity.getCatalog(),tableEntity.getSchema(),tableEntity.getTbName()));
}
}
} catch (Exception e) {
logger.error("获取表对象失败", e);
}
return tableEntity;
}
private String getPrimaryKeyColumnName(DatabaseMetaData metaData,String catalog,String schema,String tableName)
{
String primaryKeyColumnName="";
try {
ResultSet resultSet = metaData.getPrimaryKeys(catalog, schema, tableName);
while (resultSet.next())
{
primaryKeyColumnName= resultSet.getString("COLUMN_NAME");
}
} catch (SQLException e) {
logger.error("获取主键发生异常",e);
}
return primaryKeyColumnName;
}
}
另外在使用 DatabaseMetaData获取表、列信息的时候,如
DatabaseMetaData metaData = connection.getMetaData();
ResultSet resultSet = metaData.getColumns(null, null, tableName, "%");
DatabaseMetaData metaData = connection.getMetaData();
ResultSet resultSet = metaData.getTables(null, null, "%", new String[]{"TABLE"});
获取表格信息、获取列信息都是返回的 ResultSet,这个ResultSet 有点蛋疼,需要按照字段来查询,或者指定索引顺序来获取想要的结果,对照关系如下面截图
使用模板生成代码
使用的是velocity引擎(当然也可以使用freemarker等,这个不重要)
模板代码示例如下:
package ${package}.${moduleName}.entity;
import com.baomidou.mybatisplus.annotations.TableId;
import com.baomidou.mybatisplus.annotations.TableName;
#if(${hasBigDecimal})
import java.math.BigDecimal;
#end
import java.io.Serializable;
import java.util.Date;
/**
* ${comments}
*
* @author ${author}
* @email ${email}
* @date ${datetime}
*/
@TableName("${tableName}")
public class ${className}Entity implements Serializable {
private static final long serialVersionUID = 1L;
#foreach ($column in $columns)
/**
* $column.comments
*/
#if($column.columnName == $pk.columnName)
@TableId
#end
private $column.attrType $column.attrname;
#end
#foreach ($column in $columns)
/**
* 设置:${column.comments}
*/
public void set${column.attrName}($column.attrType $column.attrname) {
this.$column.attrname = $column.attrname;
}
/**
* 获取:${column.comments}
*/
public $column.attrType get${column.attrName}() {
return $column.attrname;
}
#end
}
package ${package}.${moduleName}.controller;
import java.util.Arrays;
import java.util.Map;
import org.apache.shiro.authz.annotation.RequiresPermissions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import ${package}.${moduleName}.entity.${className}Entity;
import ${package}.${moduleName}.service.${className}Service;
import ${mainPath}.common.utils.PageUtils;
import ${mainPath}.common.utils.R;
/**
* ${comments}
*
* @author ${author}
* @email ${email}
* @date ${datetime}
*/
@RestController
@RequestMapping("${moduleName}/${pathName}")
public class ${className}Controller {
@Autowired
private ${className}Service ${classname}Service;
/**
* 列表
*/
@GetMapping("/list")
public ResponseEntity>> list(@PageableDefault(value = 15, sort = { "${pk}" }, direction = Sort.Direction.DESC) Pageable pageable)
{
BaseResponse > baseResponse=new BaseResponse<>();
Page all = ${classname}Repository.findAll(pageable)};
if(all!=null && !all.isEmpty())
{
return new ResponseEntity>>(BaseResponseFactory.success(all),HttpStatus.OK);
}
else
{
return new ResponseEntity<>(HttpStatus.BAD_REQUEST);
}
}
/**
* 单个查询
*/
@GetMapping("/{${pk}}")
@ApiOperation(value = "/{${pk}}", httpMethod = "GET", notes = "查询单个${className}信息}")
public ResponseEntity> info(@PathVariable long ${pk}) {
Optional optional = ${classname}Repository.findById(${pk});
if(optional.isPresent())
{
return new ResponseEntity<>(BaseResponseFactory.success(optional.get()), HttpStatus.OK);
}else
{
return new ResponseEntity(HttpStatus.NOT_FOUND);
}
}
/**
* 保存
*/
@PostMapping("/add")
public ResponseEntity> save(@Validated @RequestBody ${className} goodsInfo, BindingResult bindingResult)
{
ResponseEntity> responseEntity;
BaseResponse baseResponse=new BaseResponse<>();
if(bindingResult.hasErrors())
{
StringBuilder sb=new StringBuilder();
for (FieldError fieldError : bindingResult.getFieldErrors()) {
sb.append(fieldError.getDefaultMessage());
sb.append(" ");
}
baseResponse.setCode(400);
baseResponse.setMessage(sb.toString());
responseEntity=new ResponseEntity<>(baseResponse,HttpStatus.BAD_REQUEST);
}
else
{
${className} save = ${classname}Repository.save(${classname});
baseResponse.setCode(200);
baseResponse.setMessage("保存成功");
baseResponse.setData(save);
responseEntity=new ResponseEntity<>(baseResponse, HttpStatus.OK);
}
return responseEntity;
}
}
字段对应转换
如何将数据库的字段类型对应到java代码上,比如数据库中的varchar,需要对应到java的String,本例是参考了一个自动生成工具的方式,使用了对应配置表,内容如下。
#代码生成器,配置信息
mainPath=com.
#包名
package=redheart
moduleName=erp
#作者
author=pf
email=103868365@qq.com
#表前缀(类名不会包含表前缀)
tablePrefix=yjc_
#类型转换,配置信息
TINYINT=Integer
SMALLINT=Integer
MEDIUMINT=Integer
INT=Integer
INTEGER=Integer
BIGINT=Long
FLOAT=Float
DOUBLE=Double
DECIMAL=BigDecimal
BIT=Boolea
CHAR=String
VARCHAR=String
TINYTEXT=String
TEXT=String
MEDIUMTEXT=String
LONGTEXT=String
DATE=Date
DATETIME=Date
TIMESTAMP=Date