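We recently introduced the OAuth2 framework into our project and found that tokens expire after a fixed 30 minutes. Forcing users to log out and log in again every 30 minutes makes for a very poor experience, so the backend needs to provide token renewal.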

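I searched online and found a number of approaches, for example:

The backend exposes a refresh-token endpoint, and the frontend runs a timer that refreshes the token periodically, based on the expiry time the backend returns.

That approach does not fit our project, though. The project allows the same account to be open in several browser pages at once, each of which has to log in, and the frontend keeps the token in a local session cache, so a token is only valid for a single page. If one page obtains a new token through the refresh-token endpoint, the tokens held by the other pages become invalid.

To solve the renewal problem I analyzed the source code and found a way through.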

Idea 1: how to renew the token?
After a successful login, the following keys are cached in Redis:

private static final String ACCESS = "access:";
private static final String AUTH_TO_ACCESS = "auth_to_access:";
private static final String AUTH = "auth:";
private static final String CLIENT_ID_TO_ACCESS = "client_id_to_access:";
private static final String UNAME_TO_ACCESS = "uname_to_access:";

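All of these keys have an expiry time. To renew the original token in place, it should be enough to update the expiry time of these keys. (Practice later showed that besides the keys' expiry time, the expiration stored inside the corresponding values must be updated as well.)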
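To illustrate why extending the key's TTL alone is not enough, here is a minimal in-memory sketch. It does not use Redis or the OAuth2 classes; `TtlVersusValueExpiry`, `StoredToken` and the method names are hypothetical stand-ins that I made up for this illustration. The point is that the key and the value each carry their own deadline, and a reader who checks the value's expiration still sees an expired token when only the key was extended.

```java
import java.util.HashMap;
import java.util.Map;

/** In-memory stand-in for Redis keys whose values carry their own expiration. */
final class TtlVersusValueExpiry {

    /** Stand-in for a serialized access token (hypothetical, not the OAuth2 class). */
    static final class StoredToken {
        long expiresAtMillis;

        StoredToken(long expiresAtMillis) {
            this.expiresAtMillis = expiresAtMillis;
        }

        boolean isExpired(long nowMillis) {
            return nowMillis > expiresAtMillis;
        }
    }

    /** One key: the stored value plus the moment the key itself disappears. */
    private static final class Entry {
        final StoredToken token;
        long keyDeadlineMillis;

        Entry(StoredToken token, long keyDeadlineMillis) {
            this.token = token;
            this.keyDeadlineMillis = keyDeadlineMillis;
        }
    }

    private final Map<String, Entry> entries = new HashMap<>();

    /** Stores a token whose key TTL and embedded expiration start out equal. */
    void store(String key, long nowMillis, long lifetimeMillis) {
        long deadline = nowMillis + lifetimeMillis;
        entries.put(key, new Entry(new StoredToken(deadline), deadline));
    }

    /** Extends only the key's TTL, like calling EXPIRE on its own. */
    void extendKeyOnly(String key, long nowMillis, long lifetimeMillis) {
        Entry e = entries.get(key);
        if (e != null) {
            e.keyDeadlineMillis = nowMillis + lifetimeMillis;
        }
    }

    /** Extends the key's TTL and rewrites the expiration inside the value. */
    void extendKeyAndValue(String key, long nowMillis, long lifetimeMillis) {
        Entry e = entries.get(key);
        if (e != null) {
            e.keyDeadlineMillis = nowMillis + lifetimeMillis;
            e.token.expiresAtMillis = nowMillis + lifetimeMillis;
        }
    }

    /** Returns the token while the key is alive, or null once the key's TTL has passed. */
    StoredToken read(String key, long nowMillis) {
        Entry e = entries.get(key);
        return (e == null || nowMillis > e.keyDeadlineMillis) ? null : e.token;
    }
}
```

With a 30-minute lifetime and a key-only extension at minute 25, a read at minute 35 still finds the key, but the token inside reports itself as expired. Extending both keeps the token valid.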

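Idea 2: when to renew the token?

Every frontend request carries the token. Add a filter to the gateway that intercepts all requests and parses the token; if the token will expire within 10 minutes (for example), reset its expiry time to 30 minutes.
(Do not renew on every request; that would be very inefficient.)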
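The threshold check described above can be sketched in isolation. This is only an illustration: `RenewalPolicy` and its method names are hypothetical, and the 10-minute and 30-minute values are the example numbers from the text. The sketch also assumes that a token which has already expired should not be renewed.

```java
import java.time.Duration;
import java.time.Instant;

/** Decides whether a token is close enough to expiry to be renewed. */
final class RenewalPolicy {
    private final Duration renewThreshold; // renew when less than this remains
    private final Duration fullLifetime;   // lifetime granted on renewal

    RenewalPolicy(Duration renewThreshold, Duration fullLifetime) {
        this.renewThreshold = renewThreshold;
        this.fullLifetime = fullLifetime;
    }

    /** True only when the token is still valid but has less than the threshold left. */
    boolean shouldRenew(Instant expiresAt, Instant now) {
        Duration remaining = Duration.between(now, expiresAt);
        return !remaining.isNegative() && remaining.compareTo(renewThreshold) < 0;
    }

    /** The new expiry: a full lifetime counted from the moment of renewal. */
    Instant renewedExpiry(Instant now) {
        return now.plus(fullLifetime);
    }
}
```

With a threshold of 10 minutes, a token with 20 minutes left passes through untouched, so most requests cost only one comparison; only requests arriving in the last 10 minutes trigger the writes.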

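With the approach settled, it is time to write code. Before that, let's look at the source (DefaultTokenServices.createAccessToken):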

@Transactional
	public OAuth2AccessToken createAccessToken(OAuth2Authentication authentication) throws AuthenticationException {

		OAuth2AccessToken existingAccessToken = tokenStore.getAccessToken(authentication);
		OAuth2RefreshToken refreshToken = null;
		if (existingAccessToken != null) {
			if (existingAccessToken.isExpired()) {
				if (existingAccessToken.getRefreshToken() != null) {
					refreshToken = existingAccessToken.getRefreshToken();
					// The token store could remove the refresh token when the
					// access token is removed, but we want to
					// be sure...
					tokenStore.removeRefreshToken(refreshToken);
				}
				tokenStore.removeAccessToken(existingAccessToken);
			}
			else {
				// Re-store the access token in case the authentication has changed
				tokenStore.storeAccessToken(existingAccessToken, authentication);
				return existingAccessToken;
			}
		}

		// Only create a new refresh token if there wasn't an existing one
		// associated with an expired access token.
		// Clients might be holding existing refresh tokens, so we re-use it in
		// the case that the old access token
		// expired.
		if (refreshToken == null) {
			refreshToken = createRefreshToken(authentication);
		}
		// But the refresh token itself might need to be re-issued if it has
		// expired.
		else if (refreshToken instanceof ExpiringOAuth2RefreshToken) {
			ExpiringOAuth2RefreshToken expiring = (ExpiringOAuth2RefreshToken) refreshToken;
			if (System.currentTimeMillis() > expiring.getExpiration().getTime()) {
				refreshToken = createRefreshToken(authentication);
			}
		}

		OAuth2AccessToken accessToken = createAccessToken(authentication, refreshToken);
		tokenStore.storeAccessToken(accessToken, authentication);
		// In case it was modified
		refreshToken = accessToken.getRefreshToken();
		if (refreshToken != null) {
			tokenStore.storeRefreshToken(refreshToken, authentication);
		}
		return accessToken;

	}

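A token is generated at login. OAuth2 first uses the login request parameters to look up a cached token object in Redis.
If a token exists, its expiry time is checked: if it has expired, the token is deleted and a new one is generated; if it has not expired, the original token object is returned.
If no token exists, a new token object is generated directly.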
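The three cases above can be reduced to a small in-memory sketch. It is a simplification, not the library code: refresh tokens are left out, a plain map stands in for the token store, and `TokenReuseSketch`, `Token` and `login` are hypothetical names.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

/** Simplified model of "reuse the existing token unless it has expired". */
final class TokenReuseSketch {

    /** Minimal token: an opaque value and an absolute expiry time. */
    static final class Token {
        final String value;
        final long expiresAtMillis;

        Token(String value, long expiresAtMillis) {
            this.value = value;
            this.expiresAtMillis = expiresAtMillis;
        }
    }

    // Keyed by a string derived from the login, standing in for auth_to_access:
    private final Map<String, Token> byAuthKey = new HashMap<>();
    private final long lifetimeMillis;

    TokenReuseSketch(long lifetimeMillis) {
        this.lifetimeMillis = lifetimeMillis;
    }

    Token login(String authKey, long nowMillis) {
        Token existing = byAuthKey.get(authKey);
        if (existing != null) {
            if (nowMillis <= existing.expiresAtMillis) {
                return existing;          // still valid: hand back the same token
            }
            byAuthKey.remove(authKey);    // expired: drop it and fall through
        }
        Token fresh = new Token(UUID.randomUUID().toString(), nowMillis + lifetimeMillis);
        byAuthKey.put(authKey, fresh);
        return fresh;
    }
}
```

This is why several pages logged in with the same account end up sharing one token as long as it has not expired, and why renewing that token in place keeps all of them working.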

The code that fetches the token is the key part; it reads the AUTH_TO_ACCESS entry:

@Override
	public OAuth2AccessToken getAccessToken(OAuth2Authentication authentication) {
		String key = authenticationKeyGenerator.extractKey(authentication);
		byte[] serializedKey = serializeKey(AUTH_TO_ACCESS + key);
		byte[] bytes = null;
		RedisConnection conn = getConnection();
		try {
			bytes = conn.get(serializedKey);
		} finally {
			conn.close();
		}
		OAuth2AccessToken accessToken = deserializeAccessToken(bytes);
		if (accessToken != null
				&& !key.equals(authenticationKeyGenerator.extractKey(readAuthentication(accessToken.getValue())))) {
			// Keep the stores consistent (maybe the same user is
			// represented by this authentication but the details have
			// changed)
			storeAccessToken(accessToken, authentication);
		}
		return accessToken;
	}
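The lookup above depends on authenticationKeyGenerator.extractKey(authentication). As far as I understand the default generator, it hashes the username, client id and sorted scopes with MD5, so the same login always maps to the same auth_to_access key. The sketch below models that with the JDK only. `AuthKeySketch` is a hypothetical name, and the exact fields and formatting are my assumption, so check the library source before relying on the output matching the real keys.

```java
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.TreeSet;

/** Approximation of deriving a stable key from username, client id and scopes. */
final class AuthKeySketch {

    static String extractKey(String username, String clientId, Collection<String> scopes) {
        // Insertion order is fixed so the string form is deterministic
        Map<String, String> values = new LinkedHashMap<>();
        values.put("username", username);
        values.put("client_id", clientId);
        // Sorting makes the key independent of the order in which scopes were requested
        values.put("scope", String.join(" ", new TreeSet<>(scopes)));
        try {
            MessageDigest md5 = MessageDigest.getInstance("MD5");
            byte[] digest = md5.digest(values.toString().getBytes(StandardCharsets.UTF_8));
            return String.format("%032x", new BigInteger(1, digest)); // 32 hex characters
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("MD5 is not available", e);
        }
    }
}
```

Because nothing page-specific goes into the key, every page that logs in with the same account, client and scopes resolves to the same stored token.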

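Once the token generation logic is understood, renewing the token is straightforward.

The filter itself is simple. Some project-specific details are left out for confidentiality, so what follows focuses on the token renewal logic.

Maven dependency: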

<!--   OAuth2 dependency, included only for the jar needed to parse tokens     -->
        <dependency>
            <groupId>org.springframework.security.oauth</groupId>
            <artifactId>spring-security-oauth2</artifactId>
            <version>2.0.15.RELEASE</version>
            <exclusions>
                <!--       Must be excluded; otherwise the security configuration is loaded and requests fail       -->
                <exclusion>
                    <groupId>org.springframework.security</groupId>
                    <artifactId>spring-security-config</artifactId>
                </exclusion>
                <exclusion>
                    <groupId>org.springframework.security</groupId>
                    <artifactId>spring-security-web</artifactId>
                </exclusion>
            </exclusions>
        </dependency>

Filter code:

package xxx;

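// BusinessException, ResultMessage, CosSecurityProperties and TokenHandleUtil are
// project classes; import them from your own packages.
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.google.gson.Gson;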
import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.netflix.zuul.filters.support.FilterConstants;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;
import org.springframework.ui.Model;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

@Slf4j
@Component
public class WebTokenFilter extends ZuulFilter {

    private Gson gson = new Gson();
    
    @Autowired
    private RedisTemplate<String, String> redisTemplate;
    
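    @Resource
    private CosSecurityProperties cosSecurityProperties;

    // Token parsing and renewal helper (the TokenHandleUtil class shown further down)
    @Autowired
    private TokenHandleUtil tokenHandleUtil;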
    
    @Override
    public String filterType() {
        return FilterConstants.PRE_TYPE;
    }

    @Override
    public int filterOrder() {
        return -1;
    }

    @Override
    public boolean shouldFilter() {
        RequestContext ctx = RequestContext.getCurrentContext();
        String url = ctx.getRequest().getRequestURI();
        if (getPass(url)) { // whitelisted URLs skip the token check
            return false;
        }
        return true;
    }

    public boolean getPass(String methodUrl){
        for (String passUrl : cosSecurityProperties.getIgnoreUrlList()) {
            if (methodUrl.contains(passUrl)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public Object run() {
        log.info("---WebTokenFilter---");
        RequestContext ctx = RequestContext.getCurrentContext();
        Model error = checkToken(ctx);
        if (error == null) {
            ctx.set("isSuccess", true);
        } else {
            ctx.setSendZuulResponse(false);
            ctx.set("isSuccess", false);
            fillResponse(ctx, error);
        }
        return null;
    }

    /**
     * Validates the token and forwards the resolved username downstream.
     * On failure this method writes the 401 response itself; it always returns null.
     * @param ctx current Zuul request context
     * @return always null
     */
    private Model checkToken(RequestContext ctx) {
        HttpServletRequest request = ctx.getRequest();
        // Read the token from the request header
        String token = request.getHeader("Authorization");
        try {
            if (StringUtils.isBlank(token)) {
                log.info("No token found in the request");
                return null;
            }
            // Resolves the username and renews the token if it is close to expiring
            String username = tokenHandleUtil.getUsernameByToken(token);

            if (StringUtils.isEmpty(username)) {
                return null;
            }
            ctx.addZuulRequestHeader("username", username);
        } catch (BusinessException e) {
            log.error("[token] validation failed, token=[{}], errorMessage=[{}]", token, e.getMessage());
            log.error("error:", e);
            // Stop routing and answer with 401 plus a JSON error body
            ctx.setSendZuulResponse(false);
            ctx.setResponseStatusCode(HttpStatus.UNAUTHORIZED.value());
            ResultMessage result = new ResultMessage(false, e.getErrMsg());
            ctx.getResponse().setCharacterEncoding("UTF-8");
            ctx.getResponse().setContentType("application/json; charset=utf-8");
            ctx.setResponseBody(JSON.toJSONString(result, SerializerFeature.BrowserCompatible));
//            ctx.set(CosSecurityConstants.KEY_IS_SECURITY_PASS, false);
        }
        return null;
    }

    /**
     * Fills in the error response.
     *
     * @param ctx current Zuul request context
     * @param error error details to return to the caller
     */
    private void fillResponse(RequestContext ctx, Model error) {
        HttpServletRequest request = ctx.getRequest();
        HttpServletResponse response = ctx.getResponse();
        // Serialize the error message
        String message = gson.toJson(error);
        log.info("response message:{}", message);

        String contentType = request.getHeader("Content-Type");
        String accept = request.getHeader("accept");
        if ((contentType != null && contentType.toLowerCase().contains("application/json"))
                || (accept  != null && accept.toLowerCase().contains("application/json"))) {
            response.setContentType("application/json;charset=UTF-8");
            response.setHeader("Access-Control-Allow-Origin", "*");
            response.setHeader("Access-Control-Allow-Methods", "*");
            response.setHeader("Access-Control-Allow-Headers", "*");
            ctx.setResponseBody(message);
        } else {
            ctx.setSendZuulResponse(false);
            response.setContentType("text/html;charset=UTF-8");
            response.setHeader("Access-Control-Allow-Origin", "*");
            response.setHeader("Access-Control-Allow-Methods", "*");
            response.setHeader("Access-Control-Allow-Headers", "*");
            ctx.setResponseBody("<h3>error</h3><br/>Permission denied");
        }
    }
}

Token parsing utility class, including the renewal logic:

package xxx;

import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.AuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.DefaultAuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.store.redis.JdkSerializationStrategy;
import org.springframework.security.oauth2.provider.token.store.redis.RedisTokenStoreSerializationStrategy;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.util.Date;
import java.util.Map;

/**
*
* Class: TokenHandleUtil
* Description: token handling utility
* Author: pansh
* Created: 2021-12-22 10:57:21
*
 */
@Component
public class TokenHandleUtil {

	private static final Logger logger = LoggerFactory.getLogger(TokenHandleUtil.class);

	private static final String ACCESS = "access:";
	private static final String AUTH_TO_ACCESS = "auth_to_access:";
	private static final String AUTH = "auth:";

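	/**
	 * Renewal threshold in seconds: renew when the token has less than this left
	 */
	@Value("${oauth.token.expires.remain}")
	private int expriesRemain;

	/**
	 * Full token lifetime in seconds, applied on renewal
	 */
	@Value("${oauth.token.expires.total}")
	private int expriesTotal;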

	@Autowired
    private RedisConnectionFactory redisConnectionFactory;

	private RedisTokenStoreSerializationStrategy serializationStrategy = new JdkSerializationStrategy();

	public String getAccessToken(String authorization){
		String accessToken = null;
		if(StringUtils.isEmpty(authorization)){
			HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
			authorization = request.getHeader("Authorization");
		}
		if (StringUtils.isNotEmpty(authorization)) {
    		String[] auth = authorization.split(" ");
			if (auth.length == 2) {
				accessToken = auth[1];
			}
		}
		return accessToken;
	}

	/**
	 * Looks up the username for a token and renews the token when it is close to expiring.
	 */
	public String getUsernameByToken(String authorization){
		String username = null;
		// authorization format: 4ae3b353-ded2-4796-a921-4b16ddede9e8
		String token = authorization;
		if (StringUtils.isNotEmpty(token)) {
			// Look up the username by token
			RedisConnection conn = getConnection();
			try {
				byte[] key = serializeKey(ACCESS + token);
				byte[] bytes = conn.get(key);
				DefaultOAuth2AccessToken accessToken = serializationStrategy.deserialize(bytes, DefaultOAuth2AccessToken.class);
				if(accessToken != null){
					Map<String, Object> userMap = accessToken.getAdditionalInformation();
					logger.info("userMap={}", userMap);
					if (userMap != null) {
						username = (String) userMap.get("username");
					}

					// Remaining lifetime in seconds
					int expiresIn = accessToken.getExpiresIn();
					if (expiresIn < expriesRemain) {
						logger.info("renewing token, authorization={}", authorization);
						// 1000L forces long arithmetic so a large lifetime cannot overflow int
						Date newExpiration = new Date(System.currentTimeMillis() + expriesTotal * 1000L);
						accessToken.setExpiration(newExpiration);
						// Rewrite the expiration stored inside the OAuth2AccessToken under access:
						conn.set(key, serializationStrategy.serialize(accessToken));
						// Reset the TTL of the access:, auth: and uname_to_access: keys
						conn.expire(key, expriesTotal);

						byte[] authKey = serializeKey(AUTH + token);
						conn.expire(authKey, expriesTotal);
						if (StringUtils.isNotEmpty(username)) {
							// Key format is uname_to_access:<clientId>:<username>; the client id "browser" is hardcoded here
							conn.expire(serializeKey("uname_to_access:browser:" + username), expriesTotal);
						}

						// Load the stored authentication to derive the auth_to_access: key
						bytes = conn.get(authKey);
						OAuth2Authentication authentication = serializationStrategy.deserialize(bytes, OAuth2Authentication.class);
						if (authentication != null) {
							AuthenticationKeyGenerator authenticationKeyGenerator = new DefaultAuthenticationKeyGenerator();
							// Compute the key under which Redis stores the auth_to_access: entry
							String authToAccessToken = authenticationKeyGenerator.extractKey(authentication);
							byte[] authToAccessKey = serializeKey(AUTH_TO_ACCESS + authToAccessToken);

							bytes = conn.get(authToAccessKey);
							DefaultOAuth2AccessToken authToAccessTokenObj = serializationStrategy.deserialize(bytes, DefaultOAuth2AccessToken.class);
							if (authToAccessTokenObj != null) {
								authToAccessTokenObj.setExpiration(newExpiration);
								// Rewrite the expiration inside the auth_to_access: value, then reset its TTL
								conn.set(authToAccessKey, serializationStrategy.serialize(authToAccessTokenObj));
								conn.expire(authToAccessKey, expriesTotal);
							}
						}
					}
				}
			}finally{
				conn.close();
			}
		}
		return username;
	}

	private RedisConnection getConnection() {
		return redisConnectionFactory.getConnection();
	}

	private byte[] serializeKey(String object) {
		return serialize("" + object);
	}

	private byte[] serialize(String string) {
		return serializationStrategy.serialize(string);
	}
}