脑洞的由来

场景一：项目改用 JWT 做权限认证。起初选用了 Shiro + JWT，但发现对无状态认证来说 Shiro 太重了，它原本便捷的功能反而显得多余。

功能需求

  1. 从请求中获得token
  2. 校验token合法性和时效性
  3. 拦截请求校验权限

功能实现

整体项目结构

[图：整体项目结构]

可以看到用到的类很少，下面列举几个关键类。

权限校验注解

/**
 * Marks a controller handler method as requiring a logged-in caller.
 *
 * <p>{@link #value()} optionally lists the numeric IDs of permissions the caller must also
 * hold. With the default empty array, only a valid login is required.
 */
@Target({METHOD})
@Retention(value = RUNTIME)
public @interface NeedLogin {

    /**
     * Numeric IDs of the permissions required to invoke the annotated method.
     *
     * @return required permission IDs; empty means "login only"
     */
    int[] value() default {};
}

web拦截器

/**
 * Base web interceptor for login and permission checks on controller methods.
 *
 * <p>NOTE(review): the article omits this class's body. Judging from the {@code @Override}
 * methods in {@code CommonWebAuthInterceptor}, subclasses supply hooks such as
 * {@code verifyAndGetUser}, {@code unAuthorizedProcess}, {@code getNeedPermissions},
 * {@code verifyPermissions} and {@code forbiddenProcess}. Confirm this against the full source.
 */
public abstract class WebAuthInterceptor implements HandlerInterceptor {

    /**
     * Body omitted in the original article.
     */

}

账户抽象接口

[图：账户抽象接口]

实际运用

接下来看看实际运用时还需要哪些东西。

  1. 首先是拦截器
/**
 * Basic implementation of the web auth interceptor.
 *
 * <p>It resolves the account from a JWT, collects the permissions declared by
 * {@link NeedLogin}, and rejects requests by throwing {@link AuthException}.
 * Created by qirong on 2019/5/31.
 */
@Component
public class CommonWebAuthInterceptor extends WebAuthInterceptor {

    @Autowired
    private AccountManager accountManager;

    /**
     * Verifies the token and loads the account it belongs to.
     *
     * @param token    token taken from the request
     * @param request  current request (not used here)
     * @param response current response (not used here)
     * @param handler  target handler method (not used here)
     * @return the account, if three conditions hold: the token yields a uid, an account
     *         exists for that uid, and {@code JwtUtils.verifyToken} accepts the token with
     *         the account's credentials salt. Otherwise empty.
     */
    @Override
    protected Optional<AccountInfo> verifyAndGetUser(String token, HttpServletRequest request, HttpServletResponse response, HandlerMethod handler) {
        Long uid = JwtUtils.getUid(token);
        if (uid == null) {
            return Optional.empty();
        }
        Account account = accountManager.getAccount(uid);
        // A token can outlive its account (unknown or deleted uid). Treat that as
        // "not logged in" rather than failing with a NullPointerException.
        if (account == null) {
            return Optional.empty();
        }
        if (JwtUtils.verifyToken(token, account.getCredentialsSalt())) {
            return Optional.of(account);
        }
        return Optional.empty();
    }

    /**
     * Handles a request without a valid login by throwing {@link AuthException}.
     *
     * @param token    token taken from the request (possibly invalid)
     * @param request  current request
     * @param response current response
     * @param handler  target handler method
     * @throws AuthException always
     */
    @Override
    protected void unAuthorizedProcess(String token, HttpServletRequest request, HttpServletResponse response, HandlerMethod handler) {
        throw new AuthException("未登录");
    }

    /**
     * Collects the permissions declared by the handler's {@link NeedLogin} annotation.
     *
     * @param accountInfo the authenticated account (not used here)
     * @param handler     target handler method; expected to carry {@link NeedLogin}
     * @return the required permissions, in annotation order; empty if none are declared
     * @throws IllegalArgumentException if the annotation is missing or declares an unknown ID
     */
    @Override
    protected List<Permission> getNeedPermissions(AccountInfo accountInfo, HandlerMethod handler) {
        NeedLogin needLogin = handler.getMethodAnnotation(NeedLogin.class);
        Assert.notNull(needLogin, () -> "@NeedLogin not present on method " + handler.getMethod().getName());
        int[] ids = needLogin.value();
        List<Permission> permissionList = new ArrayList<>(ids.length);
        for (int id : ids) {
            // Map the numeric ID onto the permission enum.
            AccountPermission permission = AccountPermission.getByID(id);
            // Supplier overload: the message is only built when the assertion fails.
            Assert.notNull(permission, () -> "权限转换错误, 权限ID :" + id + ",方法 : " + handler.getMethod().getName());
            permissionList.add(permission);
        }
        return permissionList;
    }

    /**
     * Checks that the account holds every required permission.
     *
     * @param accountInfo     the authenticated account
     * @param needPermissions permissions required by the handler
     * @param request         current request (not used here)
     * @param response        current response (not used here)
     * @return {@code true} if all required permissions are held (always true when none are required)
     */
    @Override
    protected boolean verifyPermissions(AccountInfo accountInfo, List<Permission> needPermissions, HttpServletRequest request, HttpServletResponse response) {
        List<Permission> hadPermissionList = accountInfo.getPermissions();
        if (hadPermissionList == null) {
            // An account without a permission list only passes handlers that require none.
            return needPermissions.isEmpty();
        }
        return hadPermissionList.containsAll(needPermissions);
    }

    /**
     * Handles a logged-in caller that lacks a required permission, by throwing {@link AuthException}.
     *
     * <p>NOTE(review): the message reads "login refused", although the caller is
     * authenticated but unauthorized. It is kept unchanged in case clients match on it.
     *
     * @param token    token taken from the request
     * @param request  current request
     * @param response current response
     * @param handler  target handler method
     * @throws AuthException always
     */
    @Override
    protected void forbiddenProcess(String token, HttpServletRequest request, HttpServletResponse response, HandlerMethod handler) {
        throw new AuthException("没有权限,拒绝登录");
    }

}
  2. 拦截器中用到的自定义异常和全局异常处理
/**
 * 自定义权限异常
 * Created by qirong on 2019/6/3.
 */
public class AuthException extends RuntimeException {

    public AuthException() {
        super();
    }

    public AuthException(String msg) {
        super(msg);
    }

    public AuthException(Throwable throwable) {
        super(throwable);
    }

}
/**
 * Global exception handler.
 *
 * <p>It turns {@link AuthException}s thrown during controller processing into a
 * response body carrying the exception message.
 */
@ControllerAdvice
public class WebExceptionHandler {

    private static final Logger LOG = LoggerFactory.getLogger(WebExceptionHandler.class);

    /**
     * Handles authentication/authorization failures.
     *
     * <p>NOTE(review): no {@code @ResponseStatus} is declared, so the response keeps the
     * default status. Clients must inspect the body to detect the failure.
     *
     * @param req the request that failed
     * @param e   the thrown {@link AuthException}
     * @return the exception message, written as the response body
     */
    @ExceptionHandler(value = AuthException.class)
    @ResponseBody
    public String onException(HttpServletRequest req, Exception e) {
        logException(req, e);
        return e.getMessage();
    }

    /**
     * Logs the failing request and the exception as one record, including the stack trace.
     */
    private void logException(HttpServletRequest req, Exception e) {
        // A single call keeps the record together in concurrent logs. Passing the exception
        // as the extra last argument (no matching placeholder) makes SLF4J print its full
        // stack trace, which already includes the message.
        LOG.error("请求异常:[uri={},method={},e={}]", req.getRequestURI(), req.getMethod(), e.getClass(), e);
    }

}

通过 ResponseBodyAdvice 支持跨域和 Token 刷新

/**
 * Post-processes every controller response body.
 *
 * <p>It adds CORS headers that are not already present and, when the auth layer
 * requested it, writes a refreshed access token header. The body is returned unchanged.
 * Created by qirong on 2019/8/10.
 */
@ControllerAdvice
public class ResponseHeaderAdvice implements ResponseBodyAdvice<Object> {

    /** Token header that has always been exposed to browser scripts; kept for compatibility. */
    private static final String LEGACY_EXPOSED_TOKEN_HEADER = "x-auth-token";

    @Autowired
    private AuthConfig authConfig;

    /** Applies to every handler and every message converter. */
    @Override
    public boolean supports(MethodParameter methodParameter, Class<? extends HttpMessageConverter<?>> aClass) {
        return true;
    }

    /**
     * Adds CORS and token-refresh headers on servlet responses.
     *
     * @return the body {@code o}, unchanged
     */
    @Override
    public Object beforeBodyWrite(Object o, MethodParameter methodParameter, MediaType mediaType, Class<? extends HttpMessageConverter<?>> aClass,
        ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse) {
        // Check with instanceof rather than casting unconditionally. Non-servlet wrappers
        // then pass through untouched instead of failing with ClassCastException.
        if (!(serverHttpRequest instanceof ServletServerHttpRequest)
            || !(serverHttpResponse instanceof ServletServerHttpResponse)) {
            return o;
        }
        HttpServletRequest request = ((ServletServerHttpRequest) serverHttpRequest).getServletRequest();
        HttpServletResponse response = ((ServletServerHttpResponse) serverHttpResponse).getServletResponse();
        if (request == null || response == null) {
            return o;
        }
        crossDomain(request, response);
        tokenRefresh(request, response);
        return o;
    }

    /**
     * Adds CORS headers that the response does not already carry.
     */
    private void crossDomain(HttpServletRequest request, HttpServletResponse response) {
        String allowOrigin = "Access-Control-Allow-Origin";
        if (!response.containsHeader(allowOrigin)) {
            String origin = request.getHeader("Origin");
            if (origin == null) {
                origin = originFromReferer(request.getHeader("Referer"));
            }
            // NOTE(review): echoing the caller's origin effectively allows any site.
            // Consider checking it against a configured allow-list.
            if (origin != null) {
                response.setHeader(allowOrigin, origin);
            }
        }

        String allowHeaders = "Access-Control-Allow-Headers";
        if (!response.containsHeader(allowHeaders)) {
            // Preflight requests list the headers they intend to send in
            // Access-Control-Request-Headers. Browsers never send Access-Control-Allow-Headers.
            String requestedHeaders = request.getHeader("Access-Control-Request-Headers");
            if (requestedHeaders != null) {
                response.setHeader(allowHeaders, requestedHeaders);
            }
        }

        String allowMethods = "Access-Control-Allow-Methods";
        if (!response.containsHeader(allowMethods)) {
            response.setHeader(allowMethods, "GET,POST,OPTIONS,HEAD");
        }

        // Essential: without it, browsers hide the token response header from ajax callers.
        String exposeHeaders = "access-control-expose-headers";
        if (!response.containsHeader(exposeHeaders)) {
            response.setHeader(exposeHeaders, exposedTokenHeaders());
        }
    }

    /**
     * Derives an origin ("scheme://host[:port]") from a Referer URL.
     *
     * @param referer Referer header value, may be {@code null}
     * @return the origin, or {@code null} if {@code referer} is null or has no "://"
     */
    private static String originFromReferer(String referer) {
        if (referer == null) {
            return null;
        }
        int schemeEnd = referer.indexOf("://");
        if (schemeEnd < 0) {
            return null;
        }
        // Search for the path separator after "://". A fixed offset of 7 hits the second
        // '/' of "https://", and yields -1 for a referer without a path.
        int pathStart = referer.indexOf('/', schemeEnd + 3);
        return pathStart < 0 ? referer : referer.substring(0, pathStart);
    }

    /**
     * Builds the list of header names exposed to scripts. It always includes the legacy
     * {@code x-auth-token}, plus the header {@link #tokenRefresh} writes when that differs.
     */
    private String exposedTokenHeaders() {
        return LEGACY_EXPOSED_TOKEN_HEADER.equalsIgnoreCase(ACCESS_TOKEN_HEADER_NAME)
            ? LEGACY_EXPOSED_TOKEN_HEADER
            : LEGACY_EXPOSED_TOKEN_HEADER + "," + ACCESS_TOKEN_HEADER_NAME;
    }

    /**
     * Writes a re-signed, AES-encrypted token header when the request attribute
     * {@code RENEWAL_TOKEN_REQUEST_ATT_NAME} is {@code Boolean.TRUE}.
     */
    private void tokenRefresh(HttpServletRequest request, HttpServletResponse response) {
        try {
            Object renewalFlag = request.getAttribute(RENEWAL_TOKEN_REQUEST_ATT_NAME);
            if (!Boolean.TRUE.equals(renewalFlag)) {
                return;
            }
            String token = (String) request.getAttribute(ACCESS_TOKEN_REQUEST_ATT_NAME);
            token = JwtUtils.sign(authConfig.getSecretKey(), authConfig.getTokenRefreshSpace(), JwtUtils.getClaimMap(token));
            token = CodingUtils.encryptAES(token, authConfig.getTokenSecretKey());
            response.setHeader(ACCESS_TOKEN_HEADER_NAME, token);
        } catch (Exception e) {
            // Best effort: a failed refresh must not break the response. In that case no
            // refreshed token header is sent.
            // TODO(review): route this through the project's logger instead of stderr.
            e.printStackTrace();
        }
    }

}

然后来看一下测试用的 Controller：

/**
 * Demo endpoints for exercising the login and permission checks.
 */
@RestController
@ApiOperation(value = "测试接口")
public class TestController {

    /** Request header documented (via Swagger) as carrying the token. */
    private static final String TOKEN_HEADER = "lp_token";

    /** Lifetime argument passed to {@code JwtUtils.sign} for demo tokens; its unit is defined by JwtUtils. */
    private static final int DEMO_TOKEN_TTL = 30000;

    /** Body returned by the guarded endpoints once the interceptor lets the call through. */
    private static final String OK_BODY = "success";

    @Autowired
    private AccountManager accountManager;

    /**
     * Creates a new account and returns a token signed with that account's uid and salt.
     */
    @GetMapping("testAuth")
    @ApiOperation(value = "测试登录服务", notes = "测试登录服务")
    public String testAuth() {
        Account freshAccount = accountManager.createAccount();
        return JwtUtils.sign(freshAccount.getUid(), freshAccount.getCredentialsSalt(), DEMO_TOKEN_TTL);
    }

    /**
     * Guarded endpoint requiring {@code SysPermission.TEST_PERMISSION_1}.
     */
    @NeedLogin(SysPermission.TEST_PERMISSION_1)
    @ApiImplicitParams({
        @ApiImplicitParam(paramType = "header", dataType = "String", name = TOKEN_HEADER, value = "令牌", required = true)
    })
    @GetMapping("testPermission")
    @ApiOperation(value = "测试权限服务", notes = "测试权限服务")
    public String testPermission() {
        return OK_BODY;
    }

    /**
     * Guarded endpoint requiring both {@code TEST_PERMISSION_1} and {@code TEST_PERMISSION_2}.
     * Its name suggests the demo account lacks one of them; confirm against AccountManager.
     */
    @NeedLogin({SysPermission.TEST_PERMISSION_1, SysPermission.TEST_PERMISSION_2})
    @ApiImplicitParams({
        @ApiImplicitParam(paramType = "header", dataType = "String", name = TOKEN_HEADER, value = "令牌", required = true)
    })
    @GetMapping("testNoPermission")
    @ApiOperation(value = "测试无权限服务", notes = "测试无权限服务")
    public String testNoPermission() {
        return OK_BODY;
    }

}

[图：测试截图]