前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >mybatis插件拦截原理学习

mybatis插件拦截原理学习

作者头像
路行的亚洲
发布2021-03-04 15:46:59
3180
发布2021-03-04 15:46:59
举报
文章被收录于专栏:后端技术学习后端技术学习

mybatis中,我们知道如果需要对分页或者排序进行增强时,可以采用拦截来实现增强,那它的增强原理又是怎样的呢?

做拦截操作:执行插件包装方法(重要)。可以看到它里面有三个参数是我们需要关注的,一个是签名方法是通过拦截器拿到的,一个是目标方法的字节码对象,一个是接口列表,接口列表是代理的主要对象,采用动态代理,因此这里做了一个判断,判断接口列表的长度是否大于0,如果大于0,则执行动态代理,可以看到代理过程是在Plugin类中的,同时可以看到此时必然会有一个invoke方法执行动态代理,实现拦截增强。执行拦截增强的前提是有方法可以拦截,因此此时会判断method是否为空,或者是否包含增强方法,如果有,则执行interceptor的intercept方法,否者执行method.invoke方法。可以看到签名map可以有多个method,因为其具有的签名方法有一对多或者一对一的方式,同时接口是一个列表,而目标字节码为type。

我们还是按照原来的方式,在mybatis的test中找到pluginTest这个测试方法:

代码语言:javascript
复制
class PluginTest {

  @Test
  void mapPluginShouldInterceptGet() {
    Map map = new HashMap();
    //调用拦截方法AlwaysMapPlugin,执行增强plugin的代理
    map = (Map) new AlwaysMapPlugin().plugin(map);
    assertEquals("Always", map.get("Anything"));
  }

  @Test
  void shouldNotInterceptToString() {
    Map map = new HashMap();
    map = (Map) new AlwaysMapPlugin().plugin(map);
    assertNotEquals("Always", map.toString());
  }

  //执行拦截操作
  @Intercepts({
      @Signature(type = Map.class, method = "get", args = {Object.class})})
  public static class AlwaysMapPlugin implements Interceptor {
    @Override
    public Object intercept(Invocation invocation) {
      return "Always";
    }

  }

}

插件方法:

代码语言:javascript
复制
1.getSignatureMap:获取签名map 获取注解intercepts,如果是插件增强,此时必然可以看到注解Intercepts,拿到签名的值进行遍历放入到methods里面,而methods的数据结构可以看到是Set<Method> methods = signatureMap.computeIfAbsent(sig.type(), k -> new HashSet<>());
2.getClass 获取目标类
3.getAllInterfaces:获取所有的接口列表 获取所有接口列表,如果类型不为空,则通过类型拿到所有的接口,如果签名map中包含, 则再接口列表中添加,否者获取父类字节码类型,返回接口列表数组,如果接口列表的长度>0,则说明需要进行代理,而代理的过程则是plugin
代码语言:javascript
复制
//执行插件
default Object plugin(Object target) {
  return Plugin.wrap(target, this);
}

//执行包装
public static Object wrap(Object target, Interceptor interceptor) {
    Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
    Class<?> type = target.getClass();
    Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
    //如果接口列表的长度>0,则说明需要进行代理,而代理的过程则是plugin
    if (interfaces.length > 0) {
        return Proxy.newProxyInstance(
            type.getClassLoader(),
            interfaces,
            new Plugin(target, interceptor, signatureMap));
    }
    return target;
}

可以看到如果接口列表不为空时,会执行动态代理,执行invoke方法。可以看到执行过程会在Plugin类中的invoke中。

代码语言:javascript
复制
//执行invoke方法,获取需要代理的方法,进行拦截
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
  try {
    Set<Method> methods = signatureMap.get(method.getDeclaringClass());
    if (methods != null && methods.contains(method)) {
      return interceptor.intercept(new Invocation(target, method, args));
    }
    //否者执行方法的invoke
    return method.invoke(target, args);
  } catch (Exception e) {
    throw ExceptionUtil.unwrapThrowable(e);
  }
}

也即此时会执行interceptor的intercept方法:

代码语言:javascript
复制
//执行拦截操作
@Intercepts({
    @Signature(type = Map.class, method = "get", args = {Object.class})})
public static class AlwaysMapPlugin implements Interceptor {
  @Override
  public Object intercept(Invocation invocation) {
    return "Always";
  }

}

此时会执行到拦截方法,调用拦截方法执行增强,可以看到果然是执行了增强方法,我的key是随便输入,此时返回的value值就是增强后的值。

增强的结果

比如拦截sql语句执行打印时,此时的method就会有多个。比如我们想对sql打印执行拦截操作,此时就会写如下:

代码语言:javascript
复制
写好拦截注解,同时签名方法,同时里面的参数
重写三个方法:intercept方法(通常这里会写需要拦截方法的具体的逻辑,而对于sql,我们需要拿到boundSql,此时我们可以拿到mappedStatement,然后通过映射语句拿到绑定sql和sqlId和配置信息,结合时间,从而计算sql执行时间等信息)、plugin方法、setProperties方法

进行拦截具体实现:

代码语言:javascript
复制
//写好拦截注解,同时签名方法,同时里面的参数
@Intercepts({
        @Signature(type = Executor.class, method = "update", args = { MappedStatement.class, Object.class }),
        @Signature(type = Executor.class, method = "query", args = { MappedStatement.class, Object.class,
                RowBounds.class, ResultHandler.class }) })
public class SqlInterceptor implements Interceptor {
    private Properties properties;

    //重写intercept方法
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        //获取映射语句
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        Object parameter = null;
        if (invocation.getArgs().length > 1) {
            parameter = invocation.getArgs()[1];
        }
        //通过映射语句拿到id、绑定sql、配置信息,然后计算执行时间
        String sqlId = mappedStatement.getId();
        BoundSql boundSql = mappedStatement.getBoundSql(parameter);
        Configuration configuration = mappedStatement.getConfiguration();
        Object returnValue = null;
        long start = System.currentTimeMillis();
        returnValue = invocation.proceed();
        long end = System.currentTimeMillis();
        long time = (end - start);
        if (time > 1) {
            String sql = getSql(configuration, boundSql, sqlId, time);
            System.err.println(sql);
        }
        return returnValue;
    }

    //重写plugin方法
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    //重写setProperties方法
    @Override
    public void setProperties(Properties properties) {
        this.properties = properties;
    }

    //获取sql
    public static String getSql(Configuration configuration, BoundSql boundSql, String sqlId, long time) {
        String sql = showSql(configuration, boundSql);
        StringBuilder str = new StringBuilder(100);
        str.append(sqlId);
        str.append(":");
        str.append(sql);
        str.append(">>>>>>");
        str.append(time);
        str.append("ms");
        return str.toString();
    }

    //获取参数信息
    private static String getParameterValue(Object obj) {
        String value = null;
        if (obj instanceof String) {
            value = "'" + obj.toString() + "'";
        } else if (obj instanceof Date) {
            DateFormat formatter = DateFormat.getDateTimeInstance(DateFormat.DEFAULT, DateFormat.DEFAULT, Locale.CHINA);
            value = "'" + formatter.format(new Date()) + "'";
        } else {
            if (obj != null) {
                value = obj.toString();
            } else {
                value = "";
            }

        }
        return value;
    }

    //获取sql
    public static String showSql(Configuration configuration, BoundSql boundSql) {
        //通过绑定sql拿到参数对象、参数映射、sql,将\\进行替换
        Object parameterObject = boundSql.getParameterObject();
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        String sql = boundSql.getSql().replaceAll("[\\s]+", " ");
        if (parameterMappings.size() > 0 && parameterObject != null) {
            TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
            if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
                sql = sql.replaceFirst("\\?", getParameterValue(parameterObject));

            } else {
                MetaObject metaObject = configuration.newMetaObject(parameterObject);
                for (ParameterMapping parameterMapping : parameterMappings) {
                    String propertyName = parameterMapping.getProperty();
                    if (metaObject.hasGetter(propertyName)) {
                        Object obj = metaObject.getValue(propertyName);
                        sql = sql.replaceFirst("\\?", getParameterValue(obj));
                    } else if (boundSql.hasAdditionalParameter(propertyName)) {
                        Object obj = boundSql.getAdditionalParameter(propertyName);
                        sql = sql.replaceFirst("\\?", getParameterValue(obj));
                    }
                }
            }
        }
        return sql;
    }

}
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-02-13,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 后端技术学习 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档