Mybatis拦截器实现及原理

Mybatis拦截器的功能就是把对应的sql语句拦截下来然后进行修改,实现我们想实现的功能。

对于Mybatis拦截器,我们通过一个分页查询功能的例子来了解它。


文章目录

  • Mybatis拦截器实现及原理
  • 对比案例
  • 拦截器的相关事项
  • 拦截器的实现过程


对比案例

这里是一个没有使用拦截器的案例,通过案例来理解分页的实现过程,如果想直接看拦截器内容的话可以跳过。:
我们先创建一个分页实体类:

public class Page {
	/**
	 * 总条数
	 */
	private int totalNumber;
	/**
	 * 当前第几页
	 */
	private int currentPage;
	/**
	 * 总页数
	 */
	private int totalPage;
	/**
	 * 每页显示条数
	 */
	private int pageNumber = 5;
	/**
	 * 数据库中limit的参数,从第几条开始取
	 */
	private int dbIndex;
	/**
	 * 数据库中limit的参数,一共取多少条
	 */
	private int dbNumber;
	
	/**
	 * 根据当前对象中属性值计算并设置相关属性值
	 */
	public void count() {
		// 计算总页数
		int totalPageTemp = this.totalNumber / this.pageNumber;
		int plus = (this.totalNumber % this.pageNumber) == 0 ? 0 : 1;
		totalPageTemp = totalPageTemp + plus;
		if(totalPageTemp <= 0) {
			totalPageTemp = 1;
		}
		this.totalPage = totalPageTemp;
		
		// 设置当前页数
		// 总页数小于当前页数,应将当前页数设置为总页数
		if(this.totalPage < this.currentPage) {
			this.currentPage = this.totalPage;
		}
		// 当前页数小于1设置为1
		if(this.currentPage < 1) {
			this.currentPage = 1;
		}
		
		// 设置limit的参数
		this.dbIndex = (this.currentPage - 1) * this.pageNumber;
		this.dbNumber = this.pageNumber;
	}

	public int getTotalNumber() {
		return totalNumber;
	}

	public void setTotalNumber(int totalNumber) {
		this.totalNumber = totalNumber;
		this.count();
	}

	public int getCurrentPage() {
		return currentPage;
	}

	public void setCurrentPage(int currentPage) {
		this.currentPage = currentPage;
	}

	public int getTotalPage() {
		return totalPage;
	}

	public void setTotalPage(int totalPage) {
		this.totalPage = totalPage;
	}

	public int getPageNumber() {
		return pageNumber;
	}

	public void setPageNumber(int pageNumber) {
		this.pageNumber = pageNumber;
		this.count();
	}

	public int getDbIndex() {
		return dbIndex;
	}

	public void setDbIndex(int dbIndex) {
		this.dbIndex = dbIndex;
	}

	public int getDbNumber() {
		return dbNumber;
	}

	public void setDbNumber(int dbNumber) {
		this.dbNumber = dbNumber;
	}
}

接着,在servlet里创建page对象,并将配置对象(command、description、currentPage)传给service层进行查询,查询后再将配置对象传给页面。

protected void doGet(HttpServletRequest req, HttpServletResponse resp)
			throws ServletException, IOException {
		// 设置编码
		req.setCharacterEncoding("UTF-8");
		// 接受页面的值
		String command = req.getParameter("command");
		String description = req.getParameter("description");
		String currentPage = req.getParameter("currentPage");
		// 创建分页对象
		Page page = new Page();
		Pattern pattern = Pattern.compile("[0-9]{1,9}");
		if(currentPage == null ||  !pattern.matcher(currentPage).matches()) {
			page.setCurrentPage(1);
		} else {
			page.setCurrentPage(Integer.valueOf(currentPage));
		}
		QueryService listService = new QueryService();
		// 查询消息列表并传给页面
		req.setAttribute("messageList", listService.queryMessageList(command, description,page));
		// 向页面传值
		req.setAttribute("command", command);
		req.setAttribute("description", description);
		req.setAttribute("page", page);
		// 向页面跳转
		req.getRequestDispatcher("/WEB-INF/jsp/back/list.jsp").forward(req, resp);
	}

service层通过dao层进行数据总条数的查询,并将parameter对象传给dao层

public List<Message> queryMessageList(String command,String description,Page page) {
		// 组织消息对象
		Message message = new Message();
		message.setCommand(command);
		message.setDescription(description);
		MessageDao messageDao = new MessageDao();
		// 根据条件查询条数
		int totalNumber = messageDao.count(message);
		// 组织分页查询参数
		page.setTotalNumber(totalNumber);
		Map<String,Object> parameter = new HashMap<String, Object>();
		parameter.put("message", message);
		parameter.put("page", page);
		// 分页查询并返回结果
		return messageDao.queryMessageList(parameter);
	}
public int count(Message message) {
		DBAccess dbAccess = new DBAccess();
		SqlSession sqlSession = null;
		int result = 0;
		try {
			sqlSession = dbAccess.getSqlSession();
			// 通过sqlSession执行SQL语句
			IMessage imessage = sqlSession.getMapper(IMessage.class);
			result = imessage.count(message);
		} catch (Exception e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		} finally {
			if(sqlSession != null) {
				sqlSession.close();
			}
		}
		return result;
	}

dao层通过xml文件的对应关系进行查询

<select id="count"  parameterType="com.imooc.bean.Message" resultType="int">
  	select count(*) from MESSAGE
    <where>
    	<if test="command != null and !"".equals(command.trim())">
	    	and COMMAND=#{command}
	    </if>
	    <if test="description != null and !"".equals(description.trim())">
	    	and DESCRIPTION like '%' #{description} '%'
	    </if>
    </where>
  </select>

dao层实现分页查询的功能

public List<Message> queryMessageList(Map<String,Object> parameter) {
		DBAccess dbAccess = new DBAccess();
		List<Message> messageList = new ArrayList<Message>();
		SqlSession sqlSession = null;
		try {
			sqlSession = dbAccess.getSqlSession();
			// 通过sqlSession执行SQL语句
			IMessage imessage = sqlSession.getMapper(IMessage.class);
			messageList = imessage.queryMessageList(parameter);
		} catch (Exception e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		} finally {
			if(sqlSession != null) {
				sqlSession.close();
			}
		}
		return messageList;
	}
<select id="queryMessageList" parameterType="java.util.Map" resultMap="MessageResult">
    select <include refid="columns"/> from MESSAGE
    <where>
    	<if test="message.command != null and !"".equals(message.command.trim())">
	    	and COMMAND=#{message.command}
	    </if>
	    <if test="message.description != null and !"".equals(message.description.trim())">
	    	and DESCRIPTION like '%' #{message.description} '%'
	    </if>
    </where>
    order by ID limit #{page.dbIndex},#{page.dbNumber}
  </select>

拦截器的相关事项

从分页查询的例子来看,在实际的开发过程中,如果有多个列表页面的话,使用上面的方法我们需要不断地编写分页功能,拦截器能将共通的代码封装起来,再通过调用共通的代码就能减少开发的重复步骤。

首先要明确拦截器要做什么:
1、明确要拦截住什么功能
2、明确拦截器拦截后要做什么
3、拦截器完成功能实现后要交回主权

注意事项:
1、拦截什么对象
2、拦截对象的什么行为
3、什么时候拦截

拦截器的实现过程

把上面的例子进行修改,使其一开始不具有分页的功能,只能进行查询:

public List<Message> queryMessageListByPage(String command,String description,Page page) {
		Map<String,Object> parameter = new HashMap<String, Object>();
		// 组织消息对象
		Message message = new Message();
		message.setCommand(command);
		message.setDescription(description);
		parameter.put("message", message);
		parameter.put("page", page);
		MessageDao messageDao = new MessageDao();
		// 分页查询并返回结果
		return messageDao.queryMessageListByPage(parameter);
	}
<select id="queryMessageListByPage" parameterType="java.util.Map" resultMap="MessageResult">
    select <include refid="columns"/> from MESSAGE
    <where>
    	<if test="message.command != null and !"".equals(message.command.trim())">
	    	and COMMAND=#{message.command}
	    </if>
	    <if test="message.description != null and !"".equals(message.description.trim())">
	    	and DESCRIPTION like '%' #{message.description} '%'
	    </if>
    </where>
    order by ID
  </select>

接下来是实现步骤:
(这里面穿插了源码的讲解,若想直接了解实现过程可忽略)
1、创建拦截器类实现Interceptor接口

public class PageInterceptor implements Interceptor {
@Override
	public Object intercept(Invocation invocation) throws Throwable {
	    return null;
    }
	@Override
	public Object plugin(Object target) {
	    return null;
	}
	@Override
	public void setProperties(Properties properties) {
		return null;
		// TODO Auto-generated method stub
		
	}

}

通过jdbc的代码我们可以了解到,拦截器要拦截的时机应该是在获取statement之前。我们再通过源码分析:

public interface StatementHandler {
    Statement prepare(Connection var1) throws SQLException;
    ...
}

prepare方法返回的是statement,所以我们接着看prepare方法

statement = this.instantiateStatement(connection);

prepare里有这样一句代码,说明statement通过该方法获取
再看instantiateStatement()方法的实现

protected Statement instantiateStatement(Connection connection) throws SQLException {
        String sql = this.boundSql.getSql();
        if (this.mappedStatement.getKeyGenerator() instanceof Jdbc3KeyGenerator) {
            String[] keyColumnNames = this.mappedStatement.getKeyColumns();
            return keyColumnNames == null ? connection.prepareStatement(sql, 1) : connection.prepareStatement(sql, keyColumnNames);
        } else {
            return this.mappedStatement.getResultSetType() != null ? connection.prepareStatement(sql, this.mappedStatement.getResultSetType().getValue(), 1007) : connection.prepareStatement(sql);
        }
    }

Statement就在这里产生,从这几处来看我们拦截住prepare方法就能够在statement产生之前把它拦截,而且instantiateStatement方法里获取到了sql语句,所以,我们能够对SQL语句进行修改。所以第二步我们就是在该位置进行拦截,拦截可以通过注解来实现,我们拦截的是StatementHandler接口的prepare方法,参数是Connection类的。

2、在拦截器的定义前编写注解(确保拦截的是在prepareStatement方法执行之前的sql语句),如下所示:

@Intercepts({@Signature(type=StatementHandler.class,method="prepare",args={Connection.class})})
public class PageInterceptor implements Interceptor

3、编译plugin方法
Plugin.wrap方法进行了一个判断:传入的对象是否是我们要拦截的类型,即是否与注解相对应。

public Object plugin(Object target) {
		return Plugin.wrap(target, this);
	}

我们看看warp方法的源码

public static Object wrap(Object target, Interceptor interceptor) {
        Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
        Class<?> type = target.getClass();
        Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
        return interfaces.length > 0 ? Proxy.newProxyInstance(type.getClassLoader(), interfaces, new Plugin(target, interceptor, signatureMap)) : target;
    }

getSignatureMap的参数是第二个参数即warp里的this,getSignatureMap方法则是:

private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
        Intercepts interceptsAnnotation = (Intercepts)interceptor.getClass().getAnnotation(Intercepts.class);
        if (interceptsAnnotation == null) {
            throw new PluginException("No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());
        } else {
            Signature[] sigs = interceptsAnnotation.value();
            Map<Class<?>, Set<Method>> signatureMap = new HashMap();
            Signature[] arr$ = sigs;
            int len$ = sigs.length;

            for(int i$ = 0; i$ < len$; ++i$) {
                Signature sig = arr$[i$];
                Set<Method> methods = (Set)signatureMap.get(sig.type());
                if (methods == null) {
                    methods = new HashSet();
                    signatureMap.put(sig.type(), methods);
                }

                try {
                    Method method = sig.type().getMethod(sig.method(), sig.args());
                    ((Set)methods).add(method);
                } catch (NoSuchMethodException var10) {
                    throw new PluginException("Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + var10, var10);
                }
            }

            return signatureMap;
        }
    }

这里通过getAnnotation取到了注解,然后再通过后面的方法把我们要拦截的类型放在Map里,然后我们看回warp方法,它通过getAllInterfaces方法判断拦截的是否是我们注解的里的类型:

private static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
        HashSet interfaces;
        for(interfaces = new HashSet(); type != null; type = type.getSuperclass()) {
            Class[] arr$ = type.getInterfaces();
            int len$ = arr$.length;

            for(int i$ = 0; i$ < len$; ++i$) {
                Class<?> c = arr$[i$];
                if (signatureMap.containsKey(c)) {
                    interfaces.add(c);
                }
            }
        }

        return (Class[])interfaces.toArray(new Class[interfaces.size()]);
    }
}

如果target的的类型和要拦截的class对不上的话,interface的size为0
在warp里if语句进行了判断,如果interface的长度大于0则返回一个代理类,否则返回target。
讲了这么多其实就是plugin确保了我们要拦截的对象

4、获取拦截的对象statementHandler
5、用MetaObject.forObject()方法对拦截的对象进行封装
6、获取mappedStatement对象,通过mappedStatement对象获取sql语句的id
注:这里通过sql语句的id进行判断,确保了我们拦截的的确是我们要拦截的对象
7、if语句里statementHandler的getBoundSql()方法获取BoundSql类型对象再通过该对象的getSql语句获取原始的sql语句
通过源码来看看SQL语句的获取:
PreparedStatementHandler里有这样一个方法

String sql = this.boundSql.getSql();

而PreparedStatementHandler继承了BaseStatementHandler的boundsql

protected BoundSql boundSql;

BaseStatementHandler实现了StatementHandler接口,StatementHandler里有getBoundSql()方法获取boundsql,再通过boundsql的getsql()方法获取sql语句

BoundSql getBoundSql();

8、通过boundsql的getParameterObject()方法获取参数,并根据需要获取相应的数据(在案例中用到的是page参数),再对sql语句进行修改
9、用修改后的sql语句替换原来的sql语句(交回主权):metaObject.setValue(“delegate.boundSql.sql”, pageSql);

注:10和11步并不是必要的,在本例中需要
10、新增查询的sql语句:
String countSql = "select count(
*) from (" + sql + “)a”;
Connection connection = (Connection)invocation.getArgs()[0];
PreparedStatement countStatement = connection.prepareStatement(countSql);
再用metaObject.getValue(“delegate.parameterHandler”)方法获取参数
11、得到结果集并进行相应的赋值:
parameterHandler.setParameters(countStatement);
ResultSet rs = countStatement.executeQuery();
while(rs.next()) {
page.setTotalNumber(rs.getInt(1));
}

public Object intercept(Invocation invocation) throws Throwable {
		StatementHandler statementHandler = (StatementHandler)invocation.getTarget();
		MetaObject metaObject = MetaObject.forObject(statementHandler, SystemMetaObject.DEFAULT_OBJECT_FACTORY, SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY);
		MappedStatement mappedStatement = (MappedStatement)metaObject.getValue("delegate.mappedStatement");
		// 配置文件中SQL语句的ID
		String id = mappedStatement.getId();
		if(id.matches(".+ByPage$")) {
			BoundSql boundSql = statementHandler.getBoundSql();
			// 原始的SQL语句
			String sql = boundSql.getSql();
			// 查询总条数的SQL语句
			String countSql = "select count(*) from (" + sql + ")a";
			Connection connection = (Connection)invocation.getArgs()[0];
			PreparedStatement countStatement = connection.prepareStatement(countSql);
			ParameterHandler parameterHandler = (ParameterHandler)metaObject.getValue("delegate.parameterHandler");
			parameterHandler.setParameters(countStatement);
			ResultSet rs = countStatement.executeQuery();
			
			Map<?,?> parameter = (Map<?,?>)boundSql.getParameterObject();
			Page page = (Page)parameter.get("page");
			if(rs.next()) {
				page.setTotalNumber(rs.getInt(1));
			}
			// 改造后带分页查询的SQL语句
			String pageSql = sql + " limit " + page.getDbIndex() + "," + page.getDbNumber();
			metaObject.setValue("delegate.boundSql.sql", pageSql);
		}
		return invocation.proceed();
	}

最后在配置文件里要注册拦截器,即用plugins标签注册:
对于注册有两种方式:

<plugins>
  	<plugin interceptor="com.imooc.interceptor.PageInterceptor"/>
  </plugins>
<plugins>
  	<plugin interceptor="com.imooc.interceptor.PageInterceptor">
  		<property name="test" value="abc"/>
  	</plugin>
  </plugins>

第一种方式没什么可以解释的,我们来看第二种方式:
用property标签设置name和value,然后在拦截器的setproperty里用properties.getproperty()方法获取property

@Override
	public void setProperties(Properties properties) {
		this.test = properties.getProperty("test");
		// TODO Auto-generated method stub
		
	}