分页插件

在 MyBatis 拦截器中,最常用的一种就是实现分页插件。如果不使用分页插件来实现分页功能,就需要自己在映射文件的 SQL 中增加分页条件,并且为了获得数据的总数还需要额外增加一个 count 查询的 SQL,写起来很麻烦。如果要兼容多种数据库,可能要根据 databaseId 来写不同的分页 SQL,不仅写起来麻烦,也会让 SQL 变得臃肿不堪。

为了解决上面所遇到的问题,可以使用 MyBatis 的拦截器很容易地实现通用分页功能,并且针对不同的数据进行不同的配置。

这一节中要展示的这个分页插件诞生于本书的写作过程中。虽然笔者从 2014 年就开源了 PageHelper 分页插件(地址是 https://github.com/pagehelper/Mybatis-PageHelper ),但是本书中的这个插件是以多年的经验为基础重新进行设计的。这个插件更轻量级,更易扩展,实现更优雅,理解起来更容易,同时提供了更丰富的接口,很容易根据个人的喜好进行修改。PageHelper 以本书中的分页插件原理为基础,重新实现了分页功能,并升级到了5.0版本。

分页插件的核心部分由两个类组成,PageInterceptor 拦截器类和数据库方言接口 Dialect。本节还提供了基于 MySQL 数据库的实现。

PageInterceptor拦截器类

package tk.mybatis.simple.plugin;

import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ResultMap;
import org.apache.ibatis.mapping.ResultMapping;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Properties;

/**
 * Mybatis - 通用分页拦截器
 *
 * @author liuzh
 * @version 1.0.0
 */
@SuppressWarnings({"rawtypes", "unchecked"})
@Intercepts(
	@Signature(
		type = Executor.class,
		method = "query",
		args = {MappedStatement.class, Object.class,
				RowBounds.class, ResultHandler.class}
	)
)
public class PageInterceptor implements Interceptor {
    private static final List<ResultMapping> EMPTY_RESULTMAPPING
    		= new ArrayList<ResultMapping>(0);
    private Dialect dialect;
    private Field additionalParametersField;

	@Override
    public Object intercept(Invocation invocation) throws Throwable {
        //获取拦截方法的参数
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object parameterObject = args[1];
        RowBounds rowBounds = (RowBounds) args[2];
        //调用方法判断是否需要进行分页,如果不需要,直接返回结果
        if (!dialect.skip(ms.getId(), parameterObject, rowBounds)) {
        	ResultHandler resultHandler = (ResultHandler) args[3];
            //当前的目标对象
            Executor executor = (Executor) invocation.getTarget();
            BoundSql boundSql = ms.getBoundSql(parameterObject);
            //反射获取动态参数
            Map<String, Object> additionalParameters =
            		(Map<String, Object>) additionalParametersField.get(boundSql);
            //判断是否需要进行 count 查询
            if (dialect.beforeCount(ms.getId(), parameterObject, rowBounds)){
            	//根据当前的 ms 创建一个返回值为 Long 类型的 ms
                MappedStatement countMs = newMappedStatement(ms, Long.class);
                //创建 count 查询的缓存 key
                CacheKey countKey = executor.createCacheKey(
                		countMs,
                		parameterObject,
                		RowBounds.DEFAULT,
                		boundSql);
                //调用方言获取 count sql
                String countSql = dialect.getCountSql(
                		boundSql,
                		parameterObject,
                		rowBounds,
                		countKey);
                BoundSql countBoundSql = new BoundSql(
                		ms.getConfiguration(),
                		countSql,
                		boundSql.getParameterMappings(),
                		parameterObject);
                //当使用动态 SQL 时,可能会产生临时的参数,这些参数需要手动设置到新的 BoundSql 中
                for (String key : additionalParameters.keySet()) {
                    countBoundSql.setAdditionalParameter(
                    		key, additionalParameters.get(key));
                }
                //执行 count 查询
                Object countResultList = executor.query(
                		countMs,
                		parameterObject,
                		RowBounds.DEFAULT,
                		resultHandler,
                		countKey,
                		countBoundSql);
                Long count = (Long) ((List) countResultList).get(0);
                //处理查询总数
                dialect.afterCount(count, parameterObject, rowBounds);
                if(count == 0L){
                	//当查询总数为 0 时,直接返回空的结果
                	return dialect.afterPage(
                			new ArrayList(),
                			parameterObject,
                			rowBounds);
                }
            }
            //判断是否需要进行分页查询
            if (dialect.beforePage(ms.getId(), parameterObject, rowBounds)){
            	//生成分页的缓存 key
                CacheKey pageKey = executor.createCacheKey(
                		ms,
                		parameterObject,
                		rowBounds,
                		boundSql);
                //调用方言获取分页 sql
                String pageSql = dialect.getPageSql(
                		boundSql,
                		parameterObject,
                		rowBounds,
                		pageKey);
                BoundSql pageBoundSql = new BoundSql(
                		ms.getConfiguration(),
                		pageSql,
                		boundSql.getParameterMappings(),
                		parameterObject);
                //设置动态参数
                for (String key : additionalParameters.keySet()) {
                    pageBoundSql.setAdditionalParameter(
                    		key, additionalParameters.get(key));
                }
                //执行分页查询
                List resultList = executor.query(
                		ms,
                		parameterObject,
                		RowBounds.DEFAULT,
                		resultHandler,
                		pageKey,
                		pageBoundSql);

                return dialect.afterPage(resultList, parameterObject, rowBounds);
            }
        }
        //返回默认查询
        return invocation.proceed();
    }

    /**
     * 根据现有的 ms 创建一个新的,使用新的返回值类型
     *
     * @param ms
     * @param resultType
     * @return
     */
    public MappedStatement newMappedStatement(
    		MappedStatement ms, Class<?> resultType) {
        MappedStatement.Builder builder = new MappedStatement.Builder(
        		ms.getConfiguration(),
        		ms.getId() + "_Count",
        		ms.getSqlSource(),
        		ms.getSqlCommandType()
        );
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        if (ms.getKeyProperties() != null
        		&& ms.getKeyProperties().length != 0) {
            StringBuilder keyProperties = new StringBuilder();
            for (String keyProperty : ms.getKeyProperties()) {
                keyProperties.append(keyProperty).append(",");
            }
            keyProperties.delete(
            		keyProperties.length() - 1, keyProperties.length());
            builder.keyProperty(keyProperties.toString());
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        //count查询返回值int
        List<ResultMap> resultMaps = new ArrayList<ResultMap>();
        ResultMap resultMap = new ResultMap.Builder(
        		ms.getConfiguration(),
        		ms.getId(),
        		resultType,
        		EMPTY_RESULTMAPPING).build();
        resultMaps.add(resultMap);
        builder.resultMaps(resultMaps);
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        String dialectClass = properties.getProperty("dialect");
        try {
            dialect = (Dialect) Class.forName(dialectClass).newInstance();
        } catch (Exception e) {
            throw new RuntimeException(
            		"使用 PageInterceptor 分页插件时,必须设置 dialect 属性");
        }
        dialect.setProperties(properties);
        try {
            //反射获取 BoundSql 中的 additionalParameters 属性
            additionalParametersField = BoundSql.class.getDeclaredField(
            		"additionalParameters");
            additionalParametersField.setAccessible(true);
        } catch (NoSuchFieldException e) {
            throw new RuntimeException(e);
        }
    }

}

拦截器拦截了 Executor 类的 query 接口,虽然 Executor 中有两个 query 接口,但是参数较多的 query 接口只在 MyBaits 内部被调用,该接口不能被拦截,所以拦截的 query 是参数较少的这个方法。

分页插件的主要逻辑可以看代码中的注释,这里仅对代码做一个简单的讲解。代码中和 Dialect 有关的方法都是根据这段逻辑设计的。按照这里的逻辑,先判断当前的 MyBatis 方法是否需要进行分页:如果不需要进行分页,就直接调用 invocation.proceed() 返回;如果需要进行分页,首先获取当前方法的 BoundSql,这个对象中包含了要执行的 SQL 和对应的参数。通过这个对象的 SQL 和参数生成一个 count 查询的 BoundSql,由于这种情况下的 MappedStatement 对象中的 resultMap 或 resultType 类型为当前查询结果的类型,并不适合返回 count 查询值,因此通过 newMappedStatement 方法根据当前的 MappedStatement 生成了一个返回值类型为 Long 的对象,然后通过 Executor 执行查询,得到了数据总数。得到总数后,根据 dialect.afterCount 判断是否继续进行分页查询,因为如果当前查询的结果为 0,就不必继续进行分页查询了(为了节省时间),而是可以直接返回空值。如果需要进行分页,就使用 dialect 获取分页查询 SQL,同 count 查询类似,得到分页数据的结果后,通过 dialect 对结果进行处理并返回。

一开始看这段代码可能会有些吃力,大家可以在学习完第 11 章的内容后再回过头来看这段代码,对 MyBatis 源码有一定了解后就会容易理解这里提到的各种类的作用。除了主要的逻辑部分外,在 setProperties 中还要求必须设置 dialect 参数,该参数的值为 Dialect 实现类的全限定名称。这里进行反射实例化后,又调用了 Dialect 的 setProperties,通过参数传递可以让 Dialect 实现更多可配置的功能。除了实例化 dialect,这段代码还初始化了 additionalParametersField,这是通过反射获取了 BoundSql 对象中的 additionalParameters 属性,在创建新的 BoundSql 对象中,通过这个属性反射获取了执行动态 SQL 时产生的动态参数。

Dialect接口

除了分页插件的拦截器外,还需要理解 Dialect 接口,代码如下。

package tk.mybatis.simple.plugin;

import java.util.List;
import java.util.Properties;

import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.session.RowBounds;

/**
 * 数据库方言,针对不同数据库进行实现
 *
 * @author liuzh
 */
@SuppressWarnings("rawtypes")
public interface Dialect {
	/**
	 * 跳过 count 和 分页查询
	 *
	 * @param msId 执行的  MyBatis 方法全名
	 * @param parameterObject 方法参数
	 * @param rowBounds 分页参数
	 * @return true 跳过,返回默认查询结果,false 执行分页查询
	 */
	boolean skip(String msId, Object parameterObject, RowBounds rowBounds);

	/**
	 * 执行分页前,返回 true 会进行 count 查询,false 会继续下面的 beforePage 判断
	 *
	 * @param msId 执行的  MyBatis 方法全名
	 * @param parameterObject 方法参数
	 * @param rowBounds 分页参数
	 * @return
	 */
	boolean beforeCount(String msId, Object parameterObject, RowBounds rowBounds);

	/**
	 * 生成 count 查询 sql
	 *
	 * @param boundSql 绑定 SQL 对象
	 * @param parameterObject 方法参数
	 * @param rowBounds 分页参数
	 * @param countKey count 缓存 key
	 * @return
	 */
	String getCountSql(BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey countKey);

	/**
	 * 执行完 count 查询后
	 *
	 * @param count 查询结果总数
	 * @param parameterObject 接口参数
	 * @param rowBounds 分页参数
	 */
	void afterCount(long count, Object parameterObject, RowBounds rowBounds);

	/**
	 * 执行分页前,返回 true 会进行分页查询,false 会返回默认查询结果
	 *
	 * @param msId 执行的 MyBatis 方法全名
	 * @param parameterObject 方法参数
	 * @param rowBounds 分页参数
	 * @return
	 */
	boolean beforePage(String msId, Object parameterObject, RowBounds rowBounds);

	/**
	 * 生成分页查询 sql
	 *
	 * @param boundSql 绑定 SQL 对象
	 * @param parameterObject 方法参数
	 * @param rowBounds 分页参数
	 * @param pageKey 分页缓存 key
	 * @return
	 */
	String getPageSql(BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey pageKey);

	/**
	 * 分页查询后,处理分页结果,拦截器中直接 return 该方法的返回值
	 *
	 * @param pageList 分页查询结果
	 * @param parameterObject 方法参数
	 * @param rowBounds 分页参数
	 * @return
	 */
	Object afterPage(List pageList, Object parameterObject, RowBounds rowBounds);

	/**
	 * 设置参数
	 *
	 * @param properties 插件属性
	 */
	void setProperties(Properties properties);
}

接口方法的含义请参看代码中的注释。具体的用法可以参考后面的 MySQL 实现类。Dialect 接口提供的方法可以控制分页逻辑及对分页结果的处理,不同的处理方式可以实现不同的效果,为了以最简单的方式实现一个可以使用的插件,代码中新增了一个 PageRowBounds 类,该类继承自 RowBounds 类,RowBounds 类包含 offset(偏移值)和 limit(限制数)。PageRowBounds 在此基础上额外增加了一个 total 属性用于记录查询总数。通过使用 PageRowBounds 方式可以很简单地处理分页参数和查询总数,这是一种最简单的实现,代码如下。

package tk.mybatis.simple.plugin;

import org.apache.ibatis.session.RowBounds;

/**
 * 可以记录 total 的分页参数
 *
 * @author liuzh
 */
public class PageRowBounds extends RowBounds{
	private long total;

	public PageRowBounds() {
		super();
	}

	public PageRowBounds(int offset, int limit) {
		super(offset, limit);
	}

	public long getTotal() {
		return total;
	}

	public void setTotal(long total) {
		this.total = total;
	}
}

MySqlDialect实现

有了 PageRowBounds,我们便可以以最简单的逻辑实现 MySQL 的分页,MySqlDialect 实现类代码如下。

package tk.mybatis.simple.plugin;

import java.util.List;
import java.util.Properties;

import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.session.RowBounds;

/**
 * MySql 实现
 *
 * @author liuzh
 */
@SuppressWarnings("rawtypes")
public class MySqlDialect implements Dialect {

	@Override
	public boolean skip(String msId, Object parameterObject, RowBounds rowBounds) {
		//这里使用 RowBounds 分页,默认没有 RowBounds 参数时,会使用 RowBounds.DEFAULT 作为默认值
		if(rowBounds != RowBounds.DEFAULT){
			return false;
		}
		return true;
	}

	@Override
	public boolean beforeCount(String msId, Object parameterObject, RowBounds rowBounds) {
		//只有使用 PageRowBounds 才能记录总数,否则查询了总数也没用
		if(rowBounds instanceof PageRowBounds){
    		return true;
    	}
		return false;
	}

	@Override
	public String getCountSql(BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey countKey) {
		//简单嵌套实现 MySql count 查询
		return "select count(*) from (" + boundSql.getSql() + ") temp";
	}

    @Override
    public void afterCount(long count, Object parameterObject, RowBounds rowBounds) {
    	//记录总数,按照 beforeCount 逻辑,只有 PageRowBounds 时才会查询 count,所以这里直接强制转换
    	((PageRowBounds)rowBounds).setTotal(count);
    }

    @Override
	public boolean beforePage(String msId, Object parameterObject, RowBounds rowBounds) {
		if(rowBounds != RowBounds.DEFAULT){
			return true;
		}
		return false;
	}

	@Override
	public String getPageSql(BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey pageKey) {
		//pageKey 会影响缓存,通过固定的 RowBounds 可以保证二级缓存有效
		pageKey.update("RowBounds");
		return boundSql.getSql() + " limit " + rowBounds.getOffset() + "," + rowBounds.getLimit();
	}

	@Override
	public Object afterPage(List pageList, Object parameterObject, RowBounds rowBounds) {
		return pageList;
	}

	@Override
	public void setProperties(Properties properties) {

	}
}

为了让例子更简单,MySQL 实现方法中使用了 MyBatis 内存分页参数 RowBounds 对象,通过插件可以将内存分页转换为物理分页。同时为了支持获取查询总数,这里提供了一个 PageRowBounds 类,使用 PageRowBounds 类会查询 count 结果,并将结果保存到 PageRowBounds 对象中。

有了上面这些代码后,想要使用拦截器,还需要在 mybatis-config.xml 中进行如下配置。

<plugins>
    <plugin interceptor="tk.mybatis.simple.plugin.PageInterceptor">
        <property name="dialect" value="tk.mybatis.simple.plugin.MySqlDialect"/>
    </plugin>
</plugins>

配置好后,增加一个测试方法进行测试,在 3.1.2 节介绍 RoleMapper 接口时讲过一个 selectAll 方法,现在在 selectAll 方法基础上增加方法名字相同但参数不同的接口,原 selectAll 以及新增加的 selectAll 方法如下。

@ResultMap("roleResultMap")
@Select("select * from sys_role")
List<SysRole> selectAll();

List<SysRole> selectAll(RowBounds rowBounds);

不管接口使用注解实现还是在 XML 映射文件中实现,需要做的都是在接口方法中增加 RowBounds 参数。MyBatis 会对这个类型的参数进行特殊处理,这个参数可以选择 RowBounds 或者 PageRowBounds 类型。在 RoleMapperTest 测试类中添加如下的测试方法。

@Test
public void testSelectAllByRowBounds(){
    SqlSession sqlSession = getSqlSession();
    try {
        RoleMapper roleMapper = sqlSession.getMapper(RoleMapper.class);
        //查询前两个,使用 RowBounds 类型不会查询总数
        RowBounds rowBounds = new RowBounds(0, 1);
        List<SysRole> list = roleMapper.selectAll(rowBounds);
        for(SysRole role : list){
            System.out.println("角色名:" + role.getRoleName());
        }
        //使用 PageRowBounds 会查询总数
        PageRowBounds pageRowBounds = new PageRowBounds(0, 1);
        list = roleMapper.selectAll(pageRowBounds);
        //获取总数
        System.out.println("查询总数:" + pageRowBounds.getTotal());
        for(SysRole role : list){
            System.out.println("角色名:" + role.getRoleName());
        }
        //再次查询
        pageRowBounds = new PageRowBounds(1, 1);
        list = roleMapper.selectAll(pageRowBounds);
        //获取总数
        System.out.println("查询总数:" + pageRowBounds.getTotal());
        for(SysRole role : list){
            System.out.println("角色名:" + role.getRoleName());
        }
    } finally {
        sqlSession.close();
    }
}

执行该测试,输出日志如下。

从输出的日志中可以看到,第一次执行时,因为使用的是 RowBounds 类型的参数,所以只有分页查询。在第二次执行方法时,由于使用的是 PageRowBounds 参数,因此日志中额外输出了一次 count 查询。第三次使用的还是 PageRowBounds 查询,并没有输出 count 查询,但是得到了查询总数,这是因为分页插件可以支持一级和二级缓存。count 查询的缓存支持查看如下的拦截中的代码。

// 创建 count 查询的缓存 key
CacheKey countKey = executor.createCacheKey(countMs, parameterObject, RowBounds.DEFAULT, boundSql);

注意参数 RowBounds.DEFAULT,因为 count 查询和分页参数无关,只和查询条件有关,因此不管分页参数如何,这里都使用 RowBounds.DEFAULT 参数,这就保证了在分页参数不同时,count 查询总是可以使用缓存的结果,除非当 count 查询的参数(parameterObject)发生变化,才会重新执行 count 查询。在一个 SqlSession 中只能看到一级缓存的效果,要想查看二级缓存的效果,可以参考第 7 章缓存配置一节的内容,然后在不同的 SqlSession 中进行查看。

这个插件主要是用于让开发者学习,虽然也能用于生产环境,但是其中还有很多值得优化的地方,比如 count 查询不一定要用嵌套的方式,而且在 count 查询时,如果 SQL 中包含排序,去掉排序就可以提高效率等。

如果想要在生产环境使用分页插件,推荐使用 PageHelper 分页插件,这个插件支持十几种数据库,并且在很多方面都进行了优化,插件地址是 https://github.com/pagehelper/Mybatis-PageHelper