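Recently I have been trying to extract parameter information from MyBatis SQL templates, and along the way I learned about some of MyBatis's internal structures. In this post I'll share that MyBatis knowledge together with the concrete code.
In the MyBatis getting-started guide, the official documentation shows how to quickly bootstrap a MyBatis demo. The sample code is as follows: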
String resource = "org/mybatis/example/mybatis-config.xml";
InputStream inputStream = Resources.getResourceAsStream(resource);
SqlSessionFactory sqlSessionFactory = new SqlSessionFactoryBuilder().build(inputStream);
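In the code above, MyBatis uses SqlSessionFactoryBuilder to create the sqlSessionFactory for us. If we open the SqlSessionFactoryBuilder class, we find that its overloads all end up in a three-argument build method. The demo passes an InputStream, so strictly speaking the build(InputStream, String, Properties) overload is the one that runs. It is identical to the Reader version shown here except for the stream type: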
public SqlSessionFactory build(Reader reader, String environment, Properties properties) {
    try {
        XMLConfigBuilder parser = new XMLConfigBuilder(reader, environment, properties);
        return build(parser.parse());
    } catch (Exception e) {
        throw ExceptionFactory.wrapException("Error building SqlSession.", e);
    } finally {
        ErrorContext.instance().reset();
        try {
            reader.close();
        } catch (IOException e) {
            // Intentionally ignore. Prefer previous error.
        }
    }
}
If we follow the parse method in the XMLConfigBuilder class, we find that it delegates the actual configuration parsing to the internal parseConfiguration method:
private void parseConfiguration(XNode root) {
    try {
        // issue #117 read properties first
        propertiesElement(root.evalNode("properties"));
        Properties settings = settingsAsProperties(root.evalNode("settings"));
        loadCustomVfs(settings);
        loadCustomLogImpl(settings);
        typeAliasesElement(root.evalNode("typeAliases"));
        pluginElement(root.evalNode("plugins"));
        objectFactoryElement(root.evalNode("objectFactory"));
        objectWrapperFactoryElement(root.evalNode("objectWrapperFactory"));
        reflectorFactoryElement(root.evalNode("reflectorFactory"));
        settingsElement(settings);
        // read it after objectFactory and objectWrapperFactory issue #631
        environmentsElement(root.evalNode("environments"));
        databaseIdProviderElement(root.evalNode("databaseIdProvider"));
        typeHandlerElement(root.evalNode("typeHandlers"));
        mapperElement(root.evalNode("mappers"));
    } catch (Exception e) {
        throw new BuilderException("Error parsing SQL Mapper Configuration. Cause: " + e, e);
    }
}
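You don't need to read this logic closely to see what it does. It hands each section of the XML configuration to a dedicated method, and that method parses and processes the section. The part we care about most, loading the mappers, happens in the mapperElement method.
In the first section we reached the mapperElement method. It contains quite a few branches. However, if you configure mappers the way the official documentation does (<mapper resource="..."/>), it only ever runs the following code: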
ErrorContext.instance().resource(resource);
InputStream inputStream = Resources.getResourceAsStream(resource);
XMLMapperBuilder mapperParser = new XMLMapperBuilder(inputStream, configuration, resource, configuration.getSqlFragments());
mapperParser.parse();
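The other branches of mapperElement handle the other forms a <mappers> child can take: a <package name="..."/> element, or a <mapper> with a url or class attribute instead of resource. The hypothetical MapperEntrySketch below is a stdlib-only illustration of that branching. It reads a <mappers> element with the JDK's DOM parser and reports which form each entry uses. It does not call MyBatis. Unlike mapperElement, it also does not reject a <mapper> that sets more than one of these attributes.

```java
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import javax.xml.parsers.DocumentBuilderFactory;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

public class MapperEntrySketch {
    // Returns one "kind:value" string per element child of the <mappers> root.
    static List<String> classify(String mappersXml) {
        Element root;
        try {
            root = DocumentBuilderFactory.newInstance().newDocumentBuilder()
                    .parse(new ByteArrayInputStream(mappersXml.getBytes(StandardCharsets.UTF_8)))
                    .getDocumentElement();
        } catch (Exception e) {
            throw new IllegalArgumentException("not well-formed XML", e);
        }
        List<String> result = new ArrayList<>();
        NodeList children = root.getChildNodes();
        for (int i = 0; i < children.getLength(); i++) {
            Node node = children.item(i);
            if (node.getNodeType() != Node.ELEMENT_NODE) {
                continue; // skip whitespace text and comments
            }
            Element e = (Element) node;
            if ("package".equals(e.getTagName())) {
                result.add("package:" + e.getAttribute("name"));
            } else if (e.hasAttribute("resource")) {
                // the branch taken by the resource-based snippet above
                result.add("resource:" + e.getAttribute("resource"));
            } else if (e.hasAttribute("url")) {
                result.add("url:" + e.getAttribute("url"));
            } else if (e.hasAttribute("class")) {
                result.add("class:" + e.getAttribute("class"));
            }
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(classify("<mappers>"
                + "<mapper resource=\"org/mybatis/example/BlogMapper.xml\"/>"
                + "<package name=\"org.mybatis.builder\"/>"
                + "</mappers>"));
    }
}
```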
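You can debug or read through mapperParser's parse method to see where the parsed SQL template ends up. After debugging and reading the code, I traced the flow as follows:
XMLStatementBuilder parses each individual statement into a MappedStatement. The MappedStatement keeps the SQL template in its SqlSource property, and the SqlSource implementations are built from SqlNode objects that hold the parsed template.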
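MyBatis divides the contents of a SQL template into the following SqlNode types:
- MixedSqlNode: an ordered list of child nodes.
- StaticTextSqlNode: plain text without ${}.
- TextSqlNode: text containing ${}.
- IfSqlNode: <if>, also used for <when>.
- ChooseSqlNode: <choose>.
- ForEachSqlNode: <foreach>.
- TrimSqlNode: <trim>. <where> and <set> are handled by its subclasses WhereSqlNode and SetSqlNode.
- VarDeclSqlNode: <bind>.

I won't go into these 8 node types in detail here. You can browse the source code, where their fields are the tag's attributes plus other related information.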
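These node types form a composite tree: every node renders its own piece of SQL and delegates to its children. The following is a toy model of that idea, not MyBatis code. Node, Text, If, Mixed and render are hypothetical names. The real SqlNode interface is `boolean apply(DynamicContext context)`, and real conditions are OGNL expressions rather than Java predicates. The model shows how the same tree produces different SQL depending on the input, which is exactly why a template containing tags cannot be fully rendered ahead of time.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

public class MiniSqlNodeSketch {
    // Toy counterpart of SqlNode: each node appends its piece of SQL.
    interface Node {
        void apply(StringBuilder sql, Map<String, Object> params);
    }

    // Like StaticTextSqlNode: always appends the same text.
    static final class Text implements Node {
        private final String text;
        Text(String text) { this.text = text; }
        public void apply(StringBuilder sql, Map<String, Object> params) {
            sql.append(text);
        }
    }

    // Like IfSqlNode: renders its contents only when the test holds.
    static final class If implements Node {
        private final Predicate<Map<String, Object>> test;
        private final Node contents;
        If(Predicate<Map<String, Object>> test, Node contents) {
            this.test = test;
            this.contents = contents;
        }
        public void apply(StringBuilder sql, Map<String, Object> params) {
            if (test.test(params)) {
                contents.apply(sql, params);
            }
        }
    }

    // Like MixedSqlNode: renders its children in order.
    static final class Mixed implements Node {
        private final List<Node> contents;
        Mixed(Node... contents) { this.contents = Arrays.asList(contents); }
        public void apply(StringBuilder sql, Map<String, Object> params) {
            for (Node child : contents) {
                child.apply(sql, params);
            }
        }
    }

    static String render(Node root, Map<String, Object> params) {
        StringBuilder sql = new StringBuilder();
        root.apply(sql, params);
        return sql.toString();
    }

    public static void main(String[] args) {
        Node root = new Mixed(
                new Text("SELECT * FROM BLOG WHERE state = 'ACTIVE'"),
                new If(p -> p.get("title") != null, new Text(" AND title like #{title}")));
        System.out.println(render(root, Map.of()));
        System.out.println(render(root, Map.of("title", "%mybatis%")));
    }
}
```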
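Continuing from section two: inside XMLStatementBuilder, XMLLanguageDriver's createSqlSource method is called to generate the SqlSource, which is then stored as a property of the MappedStatement.
XMLLanguageDriver's createSqlSource method in turn calls XMLScriptBuilder's parseScriptNode method to create the SqlSource: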
public SqlSource parseScriptNode() {
    MixedSqlNode rootSqlNode = parseDynamicTags(context);
    SqlSource sqlSource;
    if (isDynamic) {
        sqlSource = new DynamicSqlSource(configuration, rootSqlNode);
    } else {
        sqlSource = new RawSqlSource(configuration, rootSqlNode, parameterType);
    }
    return sqlSource;
}
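Digging into isDynamic, I found that it becomes true whenever the SQL template contains ${} content or any XML tag.
What is the difference between DynamicSqlSource and RawSqlSource? Let's start with the relevant methods of RawSqlSource: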
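The rule can be approximated with the JDK's DOM API. The hypothetical DynamicCheckSketch below treats a statement as dynamic if it has any child element, or any text/CDATA child containing a complete ${...} token. It is a simplification: MyBatis itself uses GenericTokenParser, which also honours backslash escapes and matches each ${ with its own closing brace.

```java
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import javax.xml.parsers.DocumentBuilderFactory;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

public class DynamicCheckSketch {
    // Approximates the isDynamic decision for one statement element such as <select>.
    static boolean isDynamic(String statementXml) {
        Element statement;
        try {
            statement = DocumentBuilderFactory.newInstance().newDocumentBuilder()
                    .parse(new ByteArrayInputStream(statementXml.getBytes(StandardCharsets.UTF_8)))
                    .getDocumentElement();
        } catch (Exception e) {
            throw new IllegalArgumentException("not well-formed XML", e);
        }
        NodeList children = statement.getChildNodes();
        for (int i = 0; i < children.getLength(); i++) {
            Node child = children.item(i);
            short type = child.getNodeType();
            if (type == Node.ELEMENT_NODE) {
                return true; // any tag: <if>, <where>, <foreach>, <bind>, ...
            }
            if ((type == Node.TEXT_NODE || type == Node.CDATA_SECTION_NODE)
                    && containsDollarToken(child.getNodeValue())) {
                return true;
            }
        }
        return false;
    }

    // True if "${" appears and is followed later by a closing brace.
    private static boolean containsDollarToken(String text) {
        int open = text.indexOf("${");
        return open >= 0 && text.indexOf('}', open + 2) >= 0;
    }

    public static void main(String[] args) {
        System.out.println(isDynamic("<select>SELECT * FROM BLOG WHERE id = #{id}</select>"));
        System.out.println(isDynamic("<select>SELECT * FROM BLOG ORDER BY ${column}</select>"));
        System.out.println(isDynamic(
                "<select>SELECT * FROM BLOG <where><if test=\"id != null\">id = #{id}</if></where></select>"));
    }
}
```

With only #{} parameters the first statement stays static, while the ${} and tag-based statements are classified as dynamic.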
RawSqlSource
public RawSqlSource(Configuration configuration, SqlNode rootSqlNode, Class<?> parameterType) {
    this(configuration, getSql(configuration, rootSqlNode), parameterType);
}

public RawSqlSource(Configuration configuration, String sql, Class<?> parameterType) {
    SqlSourceBuilder sqlSourceParser = new SqlSourceBuilder(configuration);
    Class<?> clazz = parameterType == null ? Object.class : parameterType;
    sqlSource = sqlSourceParser.parse(sql, clazz, new HashMap<>());
}

private static String getSql(Configuration configuration, SqlNode rootSqlNode) {
    DynamicContext context = new DynamicContext(configuration, null);
    rootSqlNode.apply(context);
    return context.getSql();
}
SqlSourceBuilder
public SqlSource parse(String originalSql, Class<?> parameterType, Map<String, Object> additionalParameters) {
    SqlSourceBuilder.ParameterMappingTokenHandler handler = new SqlSourceBuilder.ParameterMappingTokenHandler(configuration, parameterType, additionalParameters);
    GenericTokenParser parser = new GenericTokenParser("#{", "}", handler);
    String sql;
    if (configuration.isShrinkWhitespacesInSql()) {
        sql = parser.parse(removeExtraWhitespaces(originalSql));
    } else {
        sql = parser.parse(originalSql);
    }
    return new StaticSqlSource(configuration, sql, handler.getParameterMappings());
}
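While debugging, I found that when a SQL template contains only # parameters (no ${} and no tags), MyBatis handles them at initialization time. It replaces each # parameter with ? and adds a corresponding ParameterMapping to the parameterMappings list. As a result, MyBatis does not have to extract the # parameters and build the parameterized SQL (paramedSql) again on every query. That saves per-query work, which speeds up queries and reduces memory allocation.
Note that what RawSqlSource ultimately holds is a StaticSqlSource, and that StaticSqlSource contains the parameterMappings list with the parameters already extracted.
DynamicSqlSource is the opposite case. The final form of the SQL depends on the input parameters, so MyBatis cannot preprocess such templates and has to render the parameterized SQL at runtime.
That part is not the focus of this article, so I won't go into detail. All you need to know is that the rootSqlNode held by a DynamicSqlSource is a MixedSqlNode.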
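To make this preprocessing concrete, here is a stdlib-only sketch of the effect SqlSourceBuilder has on a static template. Each #{...} becomes ?, and the property names are recorded in order, standing in for the ParameterMapping list. It is a simplification, not MyBatis's implementation. It ignores the backslash escapes that GenericTokenParser handles and the richer syntax that MyBatis parses inside the braces, and it records only names, not javaType or jdbcType. HashParamSketch and Parsed are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

public class HashParamSketch {
    // SQL with every #{...} replaced by '?', plus the property names in order of appearance.
    static final class Parsed {
        final String sql;
        final List<String> properties;
        Parsed(String sql, List<String> properties) {
            this.sql = sql;
            this.properties = properties;
        }
    }

    static Parsed parse(String template) {
        StringBuilder sql = new StringBuilder();
        List<String> properties = new ArrayList<>();
        int from = 0;
        while (true) {
            int start = template.indexOf("#{", from);
            if (start < 0) {
                break;
            }
            int end = template.indexOf('}', start + 2);
            if (end < 0) {
                break; // unterminated token: keep the rest as plain text
            }
            sql.append(template, from, start).append('?');
            String content = template.substring(start + 2, end);
            // "title,jdbcType=VARCHAR": only the part before the first comma is the property
            int comma = content.indexOf(',');
            properties.add((comma < 0 ? content : content.substring(0, comma)).trim());
            from = end + 1;
        }
        sql.append(template, from, template.length());
        return new Parsed(sql.toString(), properties);
    }

    public static void main(String[] args) {
        Parsed p = parse("SELECT * FROM BLOG WHERE id = #{id} AND title = #{title,jdbcType=VARCHAR}");
        System.out.println(p.sql);        // SELECT * FROM BLOG WHERE id = ? AND title = ?
        System.out.println(p.properties); // [id, title]
    }
}
```

Because this runs once per statement instead of once per query, every later execution can reuse the same SQL string and mapping list.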
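From the analysis in the first three sections, we now know that the SQL template ends up stored along the path Configuration -> MappedStatement -> SqlSource. Next, we can simulate MyBatis initialization and then extract the parameter information from the SqlSource.
Here I define an enum, ParamType, to distinguish the kinds of parameter: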
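package com.gavinzh.learn.mybatis;

public enum ParamType {
    // #{...}: bound as a prepared-statement placeholder
    PRE_COMPILE,
    // ${...}: substituted directly into the SQL text
    REPLACE,
    // variable produced inside <foreach> (item / index)
    INTERNAL,
    // variable produced by a <bind> tag
    BIND;
}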
Next, define a bean, InputParam, to hold each parsed parameter:
package com.gavinzh.learn.mybatis;

public class InputParam {
    // parameter name as written in the template
    private String name;
    private ParamType type;
    // where the value comes from: the <bind> expression or the <foreach> collection; otherwise null
    private String source;
    // whether the caller must supply it; false inside <if>/<choose> branches
    private boolean required;

    public InputParam(String name, ParamType type, String source, boolean required) {
        this.name = name;
        this.type = type;
        this.source = source;
        this.required = required;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public ParamType getType() {
        return type;
    }

    public void setType(ParamType type) {
        this.type = type;
    }

    public String getSource() {
        return source;
    }

    public void setSource(String source) {
        this.source = source;
    }

    public boolean isRequired() {
        return required;
    }

    public void setRequired(boolean required) {
        this.required = required;
    }
}
The most important piece is the utility class ParamUtils. For it, I borrowed MyBatis's GenericTokenParser to find the # and $ parameters:
package com.gavinzh.learn.mybatis;

import org.apache.ibatis.builder.StaticSqlSource;
import org.apache.ibatis.builder.xml.XMLMapperBuilder;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.parsing.GenericTokenParser;
import org.apache.ibatis.parsing.TokenHandler;
import org.apache.ibatis.parsing.XNode;
import org.apache.ibatis.scripting.defaults.RawSqlSource;
import org.apache.ibatis.scripting.xmltags.*;
import org.apache.ibatis.session.Configuration;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.charset.StandardCharsets;
import java.util.*;

import static com.gavinzh.learn.mybatis.ParamType.*;
import static java.util.stream.Collectors.toList;

/**
 * MyBatis parameter extractor
 *
 * @author zhangheng
 * @date 2021-01-09 17:15
 */
public class ParamUtils {
    // Minimal mapper document: the SQL fragment is placed inside a single <select id="verify">
    private static final String VERIFY_TEMPLATE = "<?xml version=\"1.0\" encoding=\"UTF-8\" ?>\n"
            + "<!DOCTYPE mapper\n"
            + " PUBLIC \"-//mybatis.org//DTD Mapper 3.0//EN\"\n"
            + " \"http://mybatis.org/dtd/mybatis-3-mapper.dtd\">\n"
            + "<mapper namespace=\"Verify\">\n"
            + " <select id=\"verify\">\n"
            + " %s\n"
            + " </select>\n"
            + "</mapper>";

    // Parses the fragment the same way MyBatis parses a mapper file and returns the statement's SqlSource
    public static SqlSource xmlVerify(String xml) {
        xml = String.format(VERIFY_TEMPLATE, xml);
        // encode as UTF-8 to match the encoding declared in the XML header
        InputStream inputStream = new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8));
        Configuration configuration = new Configuration();
        Map<String, XNode> sqlFragments = configuration.getSqlFragments();
        XMLMapperBuilder xmlMapperBuilder =
                new XMLMapperBuilder(inputStream, configuration, "Verify", sqlFragments);
        xmlMapperBuilder.parse();
        MappedStatement mappedStatement = configuration.getMappedStatement("verify", false);
        return mappedStatement.getSqlSource();
    }

    public static List<InputParam> parseInputParam(String xml) {
        SqlSource sqlSource = xmlVerify(xml);
        return parseInputParam(sqlSource);
    }

    public static List<InputParam> parseInputParam(SqlSource sqlSource) {
        if (sqlSource instanceof RawSqlSource) {
            return parseRawParam((RawSqlSource) sqlSource);
        }
        if (sqlSource instanceof DynamicSqlSource) {
            return parseDynamicParam((DynamicSqlSource) sqlSource);
        }
        return Collections.emptyList();
    }
    private static List<InputParam> parseDynamicParam(DynamicSqlSource sqlSource) {
        SqlNode sqlNode = getFieldValue(sqlSource, "rootSqlNode");
        return parseSqlNode(sqlNode, true);
    }

    // Dispatches on the node type; <if> and <choose> switch required to false for everything beneath them
    private static List<InputParam> parseSqlNode(SqlNode sqlNode, boolean required) {
        if (sqlNode instanceof MixedSqlNode) {
            return parseMixedSqlNodeParam((MixedSqlNode) sqlNode, required);
        }
        if (sqlNode instanceof TextSqlNode) {
            return parseTextSqlNodeParam((TextSqlNode) sqlNode, required);
        }
        if (sqlNode instanceof StaticTextSqlNode) {
            return parseStaticTextSqlNodeParam((StaticTextSqlNode) sqlNode, required);
        }
        if (sqlNode instanceof IfSqlNode) {
            return parseIfSqlNodeParam((IfSqlNode) sqlNode);
        }
        if (sqlNode instanceof ForEachSqlNode) {
            return parseForEachSqlNodeParam((ForEachSqlNode) sqlNode, required);
        }
        if (sqlNode instanceof ChooseSqlNode) {
            return parseChooseSqlNodeParam((ChooseSqlNode) sqlNode);
        }
        // also matches WhereSqlNode and SetSqlNode, which extend TrimSqlNode
        if (sqlNode instanceof TrimSqlNode) {
            return parseTrimSqlNodeParam((TrimSqlNode) sqlNode, required);
        }
        if (sqlNode instanceof VarDeclSqlNode) {
            return parseVarDeclSqlNodeParam((VarDeclSqlNode) sqlNode, required);
        }
        return Collections.emptyList();
    }
    private static List<InputParam> parseVarDeclSqlNodeParam(VarDeclSqlNode sqlNode, boolean required) {
        String name = getFieldValue(sqlNode, "name");
        String expression = getFieldValue(sqlNode, "expression");
        return Collections.singletonList(new InputParam(name, BIND, expression, required));
    }

    private static List<InputParam> parseTrimSqlNodeParam(TrimSqlNode sqlNode, boolean required) {
        SqlNode contents = getFieldValue(sqlNode, "contents");
        return parseSqlNode(contents, required);
    }

    private static List<InputParam> parseChooseSqlNodeParam(ChooseSqlNode sqlNode) {
        List<InputParam> chooseParamList = new ArrayList<>();
        List<SqlNode> ifSqlNodes = getFieldValue(sqlNode, "ifSqlNodes");
        ifSqlNodes.forEach(content -> chooseParamList.addAll(parseSqlNode(content, false)));
        SqlNode defaultSqlNode = getFieldValue(sqlNode, "defaultSqlNode");
        if (defaultSqlNode != null) {
            chooseParamList.addAll(parseSqlNode(defaultSqlNode, false));
        }
        return chooseParamList;
    }

    private static List<InputParam> parseForEachSqlNodeParam(ForEachSqlNode sqlNode, boolean required) {
        List<InputParam> forEachParamList = new ArrayList<>();
        // TODO collectionExpression may be an expression, but in most cases it is just a variable name
        String collectionExpression = getFieldValue(sqlNode, "collectionExpression");
        forEachParamList.add(new InputParam(collectionExpression, PRE_COMPILE, null, required));
        String item = getFieldValue(sqlNode, "item");
        if (item != null) {
            forEachParamList.add(new InputParam(item, INTERNAL, collectionExpression, false));
        }
        String index = getFieldValue(sqlNode, "index");
        if (index != null) {
            forEachParamList.add(new InputParam(index, INTERNAL, collectionExpression, false));
        }
        return forEachParamList;
    }

    private static List<InputParam> parseIfSqlNodeParam(IfSqlNode sqlNode) {
        // TODO variables could also be extracted from the test attribute
        SqlNode contents = getFieldValue(sqlNode, "contents");
        return parseSqlNode(contents, false);
    }

    private static List<InputParam> parseStaticTextSqlNodeParam(StaticTextSqlNode sqlNode, boolean required) {
        TextSqlNodeTokenHandler handler = new TextSqlNodeTokenHandler();
        GenericTokenParser parser = new GenericTokenParser("#{", "}", handler);
        parser.parse(getFieldValue(sqlNode, "text"));
        // TODO MyBatis allows a type inside the braces (e.g. #{id,jdbcType=INTEGER}), so the type could be read from there
        return handler.getParamSet().stream()
                .map(param -> new InputParam(param, PRE_COMPILE, null, required))
                .collect(toList());
    }
    private static List<InputParam> parseTextSqlNodeParam(TextSqlNode sqlNode, boolean required) {
        String text = getFieldValue(sqlNode, "text");
        // ${...} parameters are substituted into the SQL text
        TextSqlNodeTokenHandler replaceHandler = new TextSqlNodeTokenHandler();
        new GenericTokenParser("${", "}", replaceHandler).parse(text);
        // #{...} parameters become prepared-statement placeholders
        TextSqlNodeTokenHandler preCompileHandler = new TextSqlNodeTokenHandler();
        new GenericTokenParser("#{", "}", preCompileHandler).parse(text);
        // TODO MyBatis allows a type inside the braces, so the type could be read from there
        List<InputParam> all = new ArrayList<>();
        all.addAll(replaceHandler.getParamSet().stream()
                .map(param -> new InputParam(param, REPLACE, null, required))
                .collect(toList()));
        all.addAll(preCompileHandler.getParamSet().stream()
                .map(param -> new InputParam(param, PRE_COMPILE, null, required))
                .collect(toList()));
        return all;
    }

    private static List<InputParam> parseMixedSqlNodeParam(MixedSqlNode sqlNode, boolean required) {
        List<SqlNode> contents = getFieldValue(sqlNode, "contents");
        return contents.stream()
                .map(node -> parseSqlNode(node, required))
                .flatMap(Collection::stream)
                .collect(toList());
    }
    // A RawSqlSource already holds a StaticSqlSource whose parameterMappings list has all #{} parameters
    private static List<InputParam> parseRawParam(RawSqlSource sqlSource) {
        StaticSqlSource staticSqlSource = getFieldValue(sqlSource, "sqlSource");
        List<ParameterMapping> parameterMappings = getFieldValue(staticSqlSource, "parameterMappings");
        return parameterMappings.stream()
                .map(parameterMapping ->
                        new InputParam(parameterMapping.getProperty(), PRE_COMPILE, null, true))
                .collect(toList());
    }
    // Reads a private field by reflection, searching superclasses (e.g. WhereSqlNode -> TrimSqlNode)
    private static <T> T getFieldValue(Object o, String fieldName) {
        return getFieldValue(o.getClass(), o, fieldName);
    }

    @SuppressWarnings("unchecked")
    private static <T> T getFieldValue(Class<?> clazz, Object o, String fieldName) {
        try {
            Field field = clazz.getDeclaredField(fieldName);
            field.setAccessible(true);
            return (T) field.get(o);
        } catch (Exception e) {
            if (clazz.getSuperclass() != null) {
                return getFieldValue(clazz.getSuperclass(), o, fieldName);
            } else {
                return null;
            }
        }
    }
    // Collects the content of every token the parser finds and leaves the text unchanged
    static class TextSqlNodeTokenHandler implements TokenHandler {
        private Set<String> paramSet = new HashSet<>();

        public Set<String> getParamSet() {
            return paramSet;
        }

        @Override
        public String handleToken(String content) {
            paramSet.add(content);
            return content;
        }
    }
}
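The key idea in ParamUtils is the required flag. It starts as true at the root and is forced to false once the walk enters an <if> or <choose> branch, while <trim>/<where> and MixedSqlNode simply pass it down. The hypothetical toy tree below isolates that rule without MyBatis; Text, Conditional and Group are illustrative names, not MyBatis classes. There is one deliberate difference. ParamUtils may return the same name more than once, whereas this sketch merges repeats and keeps a name required if any occurrence is required.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class RequiredFlagSketch {
    interface Node {}

    // Plain SQL text that may contain #{...} parameters.
    static final class Text implements Node {
        final String text;
        Text(String text) { this.text = text; }
    }

    // Stands in for <if>/<when>/<otherwise>: may be skipped at runtime.
    static final class Conditional implements Node {
        final List<Node> contents;
        Conditional(Node... contents) { this.contents = Arrays.asList(contents); }
    }

    // Stands in for MixedSqlNode and <trim>/<where>: always rendered.
    static final class Group implements Node {
        final List<Node> contents;
        Group(Node... contents) { this.contents = Arrays.asList(contents); }
    }

    private static final Pattern HASH_PARAM = Pattern.compile("#\\{([^}]*)\\}");

    // Maps each parameter name to whether it is required; insertion order is kept.
    static Map<String, Boolean> collect(Node root) {
        Map<String, Boolean> out = new LinkedHashMap<>();
        walk(root, true, out);
        return out;
    }

    private static void walk(Node node, boolean required, Map<String, Boolean> out) {
        if (node instanceof Text) {
            Matcher m = HASH_PARAM.matcher(((Text) node).text);
            while (m.find()) {
                // a name that is required anywhere stays required
                out.merge(m.group(1).trim(), required, Boolean::logicalOr);
            }
        } else if (node instanceof Conditional) {
            for (Node child : ((Conditional) node).contents) {
                walk(child, false, out); // everything under a condition is optional
            }
        } else if (node instanceof Group) {
            for (Node child : ((Group) node).contents) {
                walk(child, required, out); // inherit the parent's flag
            }
        }
    }

    public static void main(String[] args) {
        Node root = new Group(
                new Text("SELECT * FROM BLOG WHERE id = #{id}"),
                new Conditional(new Text(" AND title like #{title}")),
                new Text(" LIMIT #{limit}"));
        System.out.println(collect(root)); // {id=true, title=false, limit=true}
    }
}
```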
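The utility class contains a few TODO items. They mark further improvements that could be made, and interested readers can implement them on their own.
Finally, let's run a demonstration using an example from the official documentation: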
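As one example of such an extension, consider the TODO about types inside the braces. MyBatis lets a #{} parameter carry attributes, as in #{title,jdbcType=VARCHAR} or #{age,javaType=int}. The sketch below splits the token content that the handler receives into a name and an attribute map. It assumes only the comma-separated key=value form; MyBatis's own parser also accepts other forms, which this sketch ignores. ParamContentSketch, ParamSpec and parse are hypothetical names.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class ParamContentSketch {
    // Property name plus any key=value attributes found after it.
    static final class ParamSpec {
        final String name;
        final Map<String, String> attributes;
        ParamSpec(String name, Map<String, String> attributes) {
            this.name = name;
            this.attributes = attributes;
        }
    }

    // content is what the TokenHandler receives, e.g. "title,jdbcType=VARCHAR"
    static ParamSpec parse(String content) {
        String[] parts = content.split(",");
        String name = parts[0].trim();
        Map<String, String> attributes = new LinkedHashMap<>();
        for (int i = 1; i < parts.length; i++) {
            int eq = parts[i].indexOf('=');
            if (eq < 0) {
                continue; // not key=value: ignored in this sketch
            }
            attributes.put(parts[i].substring(0, eq).trim(), parts[i].substring(eq + 1).trim());
        }
        return new ParamSpec(name, attributes);
    }

    public static void main(String[] args) {
        ParamSpec spec = parse("title,jdbcType=VARCHAR,javaType=String");
        System.out.println(spec.name + " " + spec.attributes);
    }
}
```

Inside ParamUtils, the name would go into InputParam as before. The jdbcType or javaType attribute could then fill in a type field, should InputParam be extended with one.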
package com.gavinzh.learn.mybatis;

/**
 * @author zhangheng
 * @date 2021-01-09 17:32
 */
public class Main {
    private static final String mybatisSql = "<bind name=\"likeStr\" value=\"'%' + like + '%'\" />"
            + "SELECT * FROM BLOG\n"
            + "<where>\n"
            + " <choose>\n"
            + " <when test=\"title != null\">\n"
            + " AND title like #{title}\n"
            + " </when>\n"
            + " <when test=\"author != null and author_name != null\">\n"
            + " AND author_name like #{author_name}\n"
            + " </when>\n"
            + " <otherwise>\n"
            + " AND featured = 1\n"
            + " </otherwise>\n"
            + " </choose>\n"
            + " <foreach item=\"item\" index=\"index\" collection=\"list\"\n"
            + " open=\"(\" separator=\",\" close=\")\">\n"
            + " #{item}\n"
            + " </foreach>\n"
            + " <if test=\"state != null\">\n"
            + " state = #{state}\n"
            + " </if>\n"
            + "</where>\n"
            + "limit ${pageSize} offset ${offset} #{test}";

    public static void main(String[] args) {
        ParamUtils.parseInputParam(mybatisSql)
                .forEach(inputParam -> {
                    if (inputParam.getSource() != null) {
                        System.out.println(String.format("%s type: %s required: %s source: %s", inputParam.getName(),
                                inputParam.getType(), inputParam.isRequired(), inputParam.getSource()));
                    } else {
                        System.out.println(String.format("%s type: %s required: %s", inputParam.getName(),
                                inputParam.getType(), inputParam.isRequired()));
                    }
                });
    }
}
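The output is as follows:
I basically followed the trail step by step and used MyBatis's own classes to implement the parameter parsing. Looking back, a plain regular expression could probably extract the parameters too. However, parsing the foreach tag would likely go wrong, and a regex has no way to tell whether a parameter must be supplied.
To sum up, walking MyBatis's SqlNode structure is the best way to obtain parameter information.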
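To illustrate that limitation, here is a deliberately naive regex extractor. It is hypothetical and not part of ParamUtils. It runs on a fragment shaped like the foreach and if parts of the demo. The regex reports the loop variable item as if the caller had to pass it, misses list (which appears only in an attribute), and cannot tell that state is optional.

```java
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class RegexParamSketch {
    // Matches #{...} and ${...}; group 1 is the marker, group 2 the content.
    private static final Pattern PARAM = Pattern.compile("([#$])\\{([^}]*)\\}");

    static Set<String> naiveParams(String template) {
        Set<String> names = new LinkedHashSet<>();
        Matcher m = PARAM.matcher(template);
        while (m.find()) {
            names.add(m.group(1) + m.group(2).trim());
        }
        return names;
    }

    public static void main(String[] args) {
        String template = "SELECT * FROM BLOG WHERE id in "
                + "<foreach item=\"item\" collection=\"list\" open=\"(\" separator=\",\" close=\")\">#{item}</foreach>"
                + "<if test=\"state != null\"> AND state = #{state}</if>";
        // Prints [#item, #state]: "item" is really a loop variable bound from "list",
        // "list" itself is missed, and nothing says that "state" is optional.
        System.out.println(naiveParams(template));
    }
}
```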