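Our project recently adopted the OAuth2 framework, and we found that tokens expire after a fixed 30 minutes. Forcing users to log out and log in again every 30 minutes makes for a terrible experience, so the backend needs to support token renewal.
I searched through quite a few resources online. A typical suggestion:
The backend exposes a refresh-token endpoint, and the frontend runs a timer that refreshes the token on schedule, based on the expiry time returned by the backend.
However, this does not meet our project's needs. The project allows the same account to be open in several browser pages, each of which must log in, and the frontend keeps the token in per-page session storage, so each token copy is tied to a single page. If one page obtains a new token through the refresh endpoint, the tokens held by the account's other pages become invalid.
To solve token renewal for this project, I dug into the source code and found a way through.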
Idea 1: how do we renew a token?
After a user logs in successfully, the following keys are cached in Redis:
private static final String ACCESS = "access:";
private static final String AUTH_TO_ACCESS = "auth_to_access:";
private static final String AUTH = "auth:";
private static final String CLIENT_ID_TO_ACCESS = "client_id_to_access:";
private static final String UNAME_TO_ACCESS = "uname_to_access:";
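All of these keys have an expiry time. To renew automatically on top of the original token, the basic idea is to extend the TTLs of these keys. (Practice later showed that, besides the key TTLs, the expiration field stored inside the corresponding values must also be updated.)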
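To make the five keys concrete, here is a small sketch of how I understand RedisTokenStore (2.0.x, no key prefix configured) names them for one token. `RedisKeyLayout` and `keysFor` are hypothetical names for illustration. The `auth_to_access:` suffix is the hash produced by the AuthenticationKeyGenerator, and the `uname_to_access:` suffix is `clientId:username`, which is presumably why the utility class later uses `uname_to_access:browser:` with `browser` as the client_id. Note that `access:` and `auth_to_access:` hold serialized access tokens, which carry their own expiration; that is why their values, not only their TTLs, need rewriting.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper: mirrors the key layout RedisTokenStore (2.0.x) writes
// for one access token, assuming no key prefix is configured.
public class RedisKeyLayout {
    static final String ACCESS = "access:";
    static final String AUTH_TO_ACCESS = "auth_to_access:";
    static final String AUTH = "auth:";
    static final String CLIENT_ID_TO_ACCESS = "client_id_to_access:";
    static final String UNAME_TO_ACCESS = "uname_to_access:";

    /** Returns the five keys whose TTLs should be extended together on renewal. */
    static List<String> keysFor(String tokenValue, String authKey, String clientId, String username) {
        return Arrays.asList(
            ACCESS + tokenValue,                          // value: serialized OAuth2AccessToken
            AUTH + tokenValue,                            // value: serialized OAuth2Authentication
            AUTH_TO_ACCESS + authKey,                     // value: serialized OAuth2AccessToken
            CLIENT_ID_TO_ACCESS + clientId,               // list of tokens issued to this client
            UNAME_TO_ACCESS + clientId + ":" + username); // list of tokens for client + user
    }

    public static void main(String[] args) {
        keysFor("4ae3b353-ded2-4796-a921-4b16ddede9e8", "someAuthKeyHash", "browser", "alice")
            .forEach(System.out::println);
    }
}
```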
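Idea 2: when do we renew the token?
Every frontend request carries the token. Add a filter to the gateway that intercepts all requests and parses the token; if the token will expire within, say, 10 minutes, reset its expiry to 30 minutes.
(Don't renew on every request: each renewal performs several Redis reads and writes, so doing it per request is wasteful.)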
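The "renew only near expiry" rule is easy to isolate and test on its own. A minimal sketch with hypothetical names `RenewalPolicy`, `shouldRenew` and `newExpiration`; the 10/30-minute values are just the example figures above. Note the `1000L`: multiplying a seconds value by the int literal `1000` overflows once the lifetime exceeds roughly 24.8 days.

```java
import java.util.Date;

// Hypothetical sketch of the renewal rule; thresholds are example values.
public class RenewalPolicy {
    /** Renew only when fewer than remainSeconds are left, so most requests skip Redis writes. */
    static boolean shouldRenew(int expiresInSeconds, int remainSeconds) {
        return expiresInSeconds < remainSeconds;
    }

    /** New absolute expiration; 1000L keeps the multiplication in long arithmetic. */
    static Date newExpiration(long nowMillis, int totalSeconds) {
        return new Date(nowMillis + totalSeconds * 1000L);
    }

    public static void main(String[] args) {
        int remain = 10 * 60; // renew when under 10 minutes remain
        int total = 30 * 60;  // reset the lifetime to 30 minutes
        System.out.println(shouldRenew(25 * 60, remain));       // false
        System.out.println(shouldRenew(9 * 60, remain));        // true
        System.out.println(newExpiration(0L, total).getTime()); // 1800000
    }
}
```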
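With both ideas in place, it's time to write code. First, a look at the relevant source, DefaultTokenServices.createAccessToken in spring-security-oauth2: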
@Transactional
public OAuth2AccessToken createAccessToken(OAuth2Authentication authentication) throws AuthenticationException {
OAuth2AccessToken existingAccessToken = tokenStore.getAccessToken(authentication);
OAuth2RefreshToken refreshToken = null;
if (existingAccessToken != null) {
if (existingAccessToken.isExpired()) {
if (existingAccessToken.getRefreshToken() != null) {
refreshToken = existingAccessToken.getRefreshToken();
// The token store could remove the refresh token when the
// access token is removed, but we want to
// be sure...
tokenStore.removeRefreshToken(refreshToken);
}
tokenStore.removeAccessToken(existingAccessToken);
}
else {
// Re-store the access token in case the authentication has changed
tokenStore.storeAccessToken(existingAccessToken, authentication);
return existingAccessToken;
}
}
// Only create a new refresh token if there wasn't an existing one
// associated with an expired access token.
// Clients might be holding existing refresh tokens, so we re-use it in
// the case that the old access token
// expired.
if (refreshToken == null) {
refreshToken = createRefreshToken(authentication);
}
// But the refresh token itself might need to be re-issued if it has
// expired.
else if (refreshToken instanceof ExpiringOAuth2RefreshToken) {
ExpiringOAuth2RefreshToken expiring = (ExpiringOAuth2RefreshToken) refreshToken;
if (System.currentTimeMillis() > expiring.getExpiration().getTime()) {
refreshToken = createRefreshToken(authentication);
}
}
OAuth2AccessToken accessToken = createAccessToken(authentication, refreshToken);
tokenStore.storeAccessToken(accessToken, authentication);
// In case it was modified
refreshToken = accessToken.getRefreshToken();
if (refreshToken != null) {
tokenStore.storeRefreshToken(refreshToken, authentication);
}
return accessToken;
}
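At login a token is produced: OAuth2 first uses the authentication built from the login request to look up a cached token object in Redis.
If a token exists, it checks the token's expiry: if it has expired, the token is deleted and a new one is generated; if not, the original token object is returned.
If no token exists, a new token object is generated directly.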
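A toy model of the login path just described shows why the expiration inside the stored value matters, not only the Redis TTL. This is a simplified in-memory stand-in, not the library code: `TokenReuseModel` is hypothetical, a `HashMap` plays the role of the `auth_to_access:` entry, and TTLs are not modeled at all.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

// Hypothetical toy model of the reuse rule in DefaultTokenServices.createAccessToken.
// authToAccess stands in for the auth_to_access:<authKey> entry; TTLs are ignored.
public class TokenReuseModel {
    static final class Token {
        final String value;
        long expirationMillis; // the expiration embedded in the stored token
        Token(String value, long expirationMillis) {
            this.value = value;
            this.expirationMillis = expirationMillis;
        }
        boolean isExpired(long now) { return expirationMillis < now; }
    }

    final Map<String, Token> authToAccess = new HashMap<>();

    /** Returns the stored token if it is still valid; otherwise replaces it with a new one. */
    Token login(String authKey, long now, long validityMillis) {
        Token existing = authToAccess.get(authKey);
        if (existing != null && !existing.isExpired(now)) {
            return existing; // every page of the same account gets the same token
        }
        Token fresh = new Token(UUID.randomUUID().toString(), now + validityMillis);
        authToAccess.put(authKey, fresh); // old token is dropped for all pages
        return fresh;
    }

    public static void main(String[] args) {
        TokenReuseModel store = new TokenReuseModel();
        Token a = store.login("k", 0L, 1_800_000L);
        Token b = store.login("k", 1_900_000L, 1_800_000L); // stored expiration was never renewed
        System.out.println(a.value.equals(b.value));        // false
        b.expirationMillis = 1_900_000L + 1_800_000L;       // renewal rewrites the stored expiration
        Token c = store.login("k", 2_000_000L, 1_800_000L);
        System.out.println(b.value.equals(c.value));        // true
    }
}
```

If only the TTL were extended, the key would still exist but the token inside it would report itself expired, so the next login from another page would discard it and issue a new token, invalidating the pages still holding the old one.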
The token lookup is the crucial part; it reads the AUTH_TO_ACCESS entry (RedisTokenStore.getAccessToken):
@Override
public OAuth2AccessToken getAccessToken(OAuth2Authentication authentication) {
String key = authenticationKeyGenerator.extractKey(authentication);
byte[] serializedKey = serializeKey(AUTH_TO_ACCESS + key);
byte[] bytes = null;
RedisConnection conn = getConnection();
try {
bytes = conn.get(serializedKey);
} finally {
conn.close();
}
OAuth2AccessToken accessToken = deserializeAccessToken(bytes);
if (accessToken != null
&& !key.equals(authenticationKeyGenerator.extractKey(readAuthentication(accessToken.getValue())))) {
// Keep the stores consistent (maybe the same user is
// represented by this authentication but the details have
// changed)
storeAccessToken(accessToken, authentication);
}
return accessToken;
}
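The `key` used above comes from `AuthenticationKeyGenerator.extractKey`. As far as I can tell, in 2.0.x `DefaultAuthenticationKeyGenerator` takes an MD5 hash of the username, client_id and sorted scopes, so every page of the same account with the same client and scopes maps to the same `auth_to_access:` entry and therefore shares one token. The following is a stand-alone approximation of that (hypothetical `AuthKeySketch`, no Spring dependency, and it always includes all three fields); check it against the source of the version you use before relying on the hashes matching.

```java
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.TreeSet;

// Hypothetical approximation of DefaultAuthenticationKeyGenerator.extractKey (2.0.x):
// MD5 over the string form of {username=..., client_id=..., scope=...}.
public class AuthKeySketch {
    static String authKey(String username, String clientId, String... scopes) {
        Map<String, String> values = new LinkedHashMap<>();
        values.put("username", username);
        values.put("client_id", clientId);
        // scopes are sorted and space-joined, so their order does not affect the key
        values.put("scope", String.join(" ", new TreeSet<>(Arrays.asList(scopes))));
        try {
            byte[] digest = MessageDigest.getInstance("MD5")
                .digest(values.toString().getBytes(StandardCharsets.UTF_8));
            return String.format("%032x", new BigInteger(1, digest));
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("MD5 not available", e);
        }
    }

    public static void main(String[] args) {
        // Two pages of the same account, client and scopes produce the same key.
        System.out.println(authKey("alice", "browser", "all"));
        System.out.println(authKey("alice", "browser", "all"));
    }
}
```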
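With the token-generation logic understood, renewal is straightforward: login reuses the token stored under AUTH_TO_ACCESS as long as it has not expired, so pushing that token's expiry forward (both the Redis TTLs and the expiration inside the stored values) keeps every page of the account on the same token.
The filter itself is simple, and parts of it are project-specific, so the focus here is the renewal logic.
The code follows.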
Maven dependency:
<!-- OAuth2: included only to pull in the jar needed to parse tokens -->
<dependency>
<groupId>org.springframework.security.oauth</groupId>
<artifactId>spring-security-oauth2</artifactId>
<version>2.0.15.RELEASE</version>
<exclusions>
<!-- Must be excluded; otherwise the Spring Security configuration gets loaded and requests stop working -->
<exclusion>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-config</artifactId>
</exclusion>
<exclusion>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-web</artifactId>
</exclusion>
</exclusions>
</dependency>
Filter code:
package xxx;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.google.gson.Gson;
import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.netflix.zuul.filters.support.FilterConstants;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;
import org.springframework.ui.Model;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
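import javax.servlet.http.HttpServletResponse;
// BusinessException, ResultMessage, CosSecurityProperties and TokenHandleUtil are project classes (imports omitted)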
@Slf4j
@Component
public class WebTokenFilter extends ZuulFilter {
private Gson gson = new Gson();
@Autowired
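private RedisTemplate<String, String> redisTemplate;
// Token parsing and renewal utility, shown below
@Autowired
private TokenHandleUtil tokenHandleUtil;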
@Resource
private CosSecurityProperties cosSecurityProperties;
@Override
public String filterType() {
return FilterConstants.PRE_TYPE;
}
@Override
public int filterOrder() {
return -1;
}
@Override
public boolean shouldFilter() {
RequestContext ctx = RequestContext.getCurrentContext();
String url = ctx.getRequest().getRequestURI();
if(getPass(url)){// whitelisted URLs skip token validation
return false;
}
return true;
}
public boolean getPass(String methodUrl){
for (String passUrl : cosSecurityProperties.getIgnoreUrlList()) {
if (methodUrl.contains(passUrl)) {
return true;
}
}
return false;
}
@Override
public Object run() {
log.info("---WebTokenFilter---");
RequestContext ctx = RequestContext.getCurrentContext();
Model error = checkToken(ctx);
if (error == null) {
ctx.set("isSuccess", true);
} else {
ctx.setSendZuulResponse(false);
ctx.set("isSuccess", false);
fillResponse(ctx, error);
}
return null;
}
/**
* Validates the token; on success, forwards the username to downstream services as a request header
* @param ctx
* @return
*/
private Model checkToken(RequestContext ctx) {
HttpServletRequest request = ctx.getRequest();
// Read the token from the request
String token = request.getHeader("Authorization");
try {
if (StringUtils.isBlank(token)) {
log.info("No token found in the request");
return null;
}
String username = tokenHandleUtil.getUsernameByToken(token);
if (StringUtils.isEmpty(username)) {
return null;
}
ctx.addZuulRequestHeader("username", username);
}catch (BusinessException e) {
log.error("[token] validation failed, token=[{}], errorMessage=[{}]", token, e.getMessage());
log.error("error:",e);
ctx.setSendZuulResponse(false);
ctx.setResponseStatusCode(HttpStatus.UNAUTHORIZED.value());
ResultMessage result = new ResultMessage(false, e.getErrMsg());
ctx.getResponse().setCharacterEncoding("UTF-8");
ctx.getResponse().setContentType("application/json; charset=utf-8");
ctx.setResponseBody(JSON.toJSONString(result, SerializerFeature.BrowserCompatible));
// ctx.set(CosSecurityConstants.KEY_IS_SECURITY_PASS, false);
}
return null;
}
/**
* Writes the error response
*
* @param ctx
* @param error
*/
private void fillResponse(RequestContext ctx, Model error) {
HttpServletRequest request = ctx.getRequest();
HttpServletResponse response = ctx.getResponse();
// Serialize the error message
String message = gson.toJson(error);
log.info("response message:{}", message);
String contentType = request.getHeader("Content-Type");
String accept = request.getHeader("accept");
if ((contentType != null && contentType.toLowerCase().contains("application/json"))
|| (accept != null && accept.toLowerCase().contains("application/json"))) {
response.setContentType("application/json;charset=UTF-8");
response.setHeader("Access-Control-Allow-Origin", "*");
response.setHeader("Access-Control-Allow-Methods", "*");
response.setHeader("Access-Control-Allow-Headers", "*");
ctx.setResponseBody(message);
} else {
ctx.setSendZuulResponse(false);
response.setContentType("text/html;charset=UTF-8");
response.setHeader("Access-Control-Allow-Origin", "*");
response.setHeader("Access-Control-Allow-Methods", "*");
response.setHeader("Access-Control-Allow-Headers", "*");
ctx.setResponseBody("<h3>error</h3><br/>Permission denied");
}
}
}
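One thing to watch in the filter: it passes the raw `Authorization` header to `getUsernameByToken`, which expects a bare token value such as `4ae3b353-ded2-4796-a921-4b16ddede9e8`, while `getAccessToken` in the utility below only handles the two-part `Bearer <token>` form. If clients may send either form, a small normalizing helper avoids the mismatch. A sketch only; `AuthorizationHeaders.extractToken` is a hypothetical name, not part of the project or of Spring.

```java
// Hypothetical helper: accepts "Bearer <token>" (any case) or a bare token value.
public class AuthorizationHeaders {
    private static final String BEARER = "bearer ";

    /** Returns the raw token value, or null when the header is missing or blank. */
    static String extractToken(String header) {
        if (header == null || header.trim().isEmpty()) {
            return null;
        }
        String h = header.trim();
        // Case-insensitive check for the "Bearer " scheme prefix
        if (h.length() > BEARER.length() && h.regionMatches(true, 0, BEARER, 0, BEARER.length())) {
            return h.substring(BEARER.length()).trim();
        }
        return h; // already a bare token
    }

    public static void main(String[] args) {
        System.out.println(extractToken("Bearer 4ae3b353-ded2-4796-a921-4b16ddede9e8")); // 4ae3b353-ded2-4796-a921-4b16ddede9e8
        System.out.println(extractToken("4ae3b353-ded2-4796-a921-4b16ddede9e8"));        // 4ae3b353-ded2-4796-a921-4b16ddede9e8
    }
}
```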
Token parsing utility class, including the renewal logic:
package xxx;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.AuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.DefaultAuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.store.redis.JdkSerializationStrategy;
import org.springframework.security.oauth2.provider.token.store.redis.RedisTokenStoreSerializationStrategy;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.util.Date;
import java.util.Map;
/**
*
* Class: TokenHandleUtil
* Description: token handling utility
* Author: pansh
* Created: 2021-12-22 10:57:21
*
*/
@Component
public class TokenHandleUtil {
private static final Logger logger = LoggerFactory.getLogger(TokenHandleUtil.class);
private static final String ACCESS = "access:";
private static final String AUTH_TO_ACCESS = "auth_to_access:";
private static final String AUTH = "auth:";
/**
* Remaining lifetime threshold (seconds): the token is renewed once less than this is left
*/
@Value("${oauth.token.expires.remain}")
private int expriesRemain;
/**
* Total lifetime (seconds) the token is reset to on renewal
*/
@Value("${oauth.token.expires.total}")
private int expriesTotal;
@Autowired
private RedisConnectionFactory redisConnectionFactory;
private RedisTokenStoreSerializationStrategy serializationStrategy = new JdkSerializationStrategy();
public String getAccessToken(String authorization){
String accessToken = null;
if(StringUtils.isEmpty(authorization)){
HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
authorization = request.getHeader("Authorization");
}
if (StringUtils.isNotEmpty(authorization)) {
String[] auth = authorization.split(" ");
if (auth.length == 2) {
accessToken = auth[1];
}
}
return accessToken;
}
/**
* Looks up the username for a token, renewing the token if it is close to expiry
*/
public String getUsernameByToken(String authorization){
String username = null;
// authorization is the bare token value, e.g. 4ae3b353-ded2-4796-a921-4b16ddede9e8
String token = authorization;
if (StringUtils.isNotEmpty(token)) {
// Load the access token stored under access:<token> and read the username from it
RedisConnection conn = getConnection();
try {
byte[] key = serializeKey(ACCESS + token);
byte[] bytes = conn.get(key);
DefaultOAuth2AccessToken accessToken = serializationStrategy.deserialize(bytes, DefaultOAuth2AccessToken.class);
if(accessToken != null){
Map<String, Object> userMap = accessToken.getAdditionalInformation();
logger.info("userMap={}", userMap);
if (userMap != null) {
username = (String) userMap.get("username");
}
int expiresIn = accessToken.getExpiresIn();
if (expiresIn < expriesRemain) {
logger.info("Renewing token, authorization={}", authorization);
// 1000L: compute in long arithmetic to avoid int overflow for long lifetimes
Date newExpiration = new Date(System.currentTimeMillis() + expriesTotal * 1000L);
accessToken.setExpiration(newExpiration);
// Rewrite the OAuth2AccessToken stored under access:<token> with the new expiration
conn.set(key, serializationStrategy.serialize(accessToken));
// Reset the TTLs of access:, auth: and uname_to_access: ("browser" is the client_id)
conn.expire(key, expriesTotal);
byte[] authKey = serializeKey(AUTH + token);
conn.expire(authKey, expriesTotal);
if (StringUtils.isNotEmpty(username)) {
conn.expire(serializeKey("uname_to_access:browser:" + username), expriesTotal);
}
// Recompute the auth_to_access key from the stored OAuth2Authentication
bytes = conn.get(authKey);
OAuth2Authentication authentication = serializationStrategy.deserialize(bytes, OAuth2Authentication.class);
if (authentication != null) {
AuthenticationKeyGenerator authenticationKeyGenerator = new DefaultAuthenticationKeyGenerator();
String authToAccessToken = authenticationKeyGenerator.extractKey(authentication);
byte[] authToAccessKey = serializeKey(AUTH_TO_ACCESS + authToAccessToken);
bytes = conn.get(authToAccessKey);
DefaultOAuth2AccessToken authToAccessTokenObj = serializationStrategy.deserialize(bytes, DefaultOAuth2AccessToken.class);
if (authToAccessTokenObj != null) {
// Rewrite the expiration inside the auth_to_access: value, then reset its TTL
authToAccessTokenObj.setExpiration(newExpiration);
conn.set(authToAccessKey, serializationStrategy.serialize(authToAccessTokenObj));
conn.expire(authToAccessKey, expriesTotal);
}
}
}
}
}finally{
conn.close();
}
}
return username;
}
private RedisConnection getConnection() {
return redisConnectionFactory.getConnection();
}
private byte[] serializeKey(String object) {
return serialize("" + object);
}
private byte[] serialize(String string) {
return serializationStrategy.serialize(string);
}
}
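Putting the pieces together, renewal has to touch every key listed at the start. The utility above resets `access:`, `auth:`, `uname_to_access:` and `auth_to_access:`, but not `client_id_to_access:`. The sketch below models the full checklist against an in-memory stand-in for Redis (hypothetical `RenewAllKeysSketch`, maps instead of a real connection): keys whose values are serialized access tokens get both a new embedded expiration and a new TTL, and the rest only a new TTL. In the real store the two `*_to_access` list keys also contain serialized tokens whose embedded expirations this sketch does not refresh.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical in-memory model of a full renewal; not a real Redis client.
public class RenewAllKeysSketch {
    final Map<String, Long> ttlSeconds = new HashMap<>();         // key -> TTL in seconds
    final Map<String, Long> embeddedExpiration = new HashMap<>(); // key -> expiration inside the value

    void renew(String token, String authKey, String clientId, String username,
               long nowMillis, int totalSeconds) {
        long newExpiration = nowMillis + totalSeconds * 1000L;
        // Keys whose value is a serialized access token: rewrite the value, then the TTL
        for (String k : new String[] {"access:" + token, "auth_to_access:" + authKey}) {
            embeddedExpiration.put(k, newExpiration);
            ttlSeconds.put(k, (long) totalSeconds);
        }
        // Remaining keys: only the TTL is extended in this model
        for (String k : new String[] {"auth:" + token,
                                      "client_id_to_access:" + clientId,
                                      "uname_to_access:" + clientId + ":" + username}) {
            ttlSeconds.put(k, (long) totalSeconds);
        }
    }

    public static void main(String[] args) {
        RenewAllKeysSketch store = new RenewAllKeysSketch();
        store.renew("t", "k", "browser", "alice", System.currentTimeMillis(), 1800);
        store.ttlSeconds.forEach((k, v) -> System.out.println(k + " ttl=" + v));
    }
}
```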