背景:最近对一个老项目进行改造,使其支持多机部署,其中最关键的一点就是实现多机session共享。项目有多老呢,jdk版本是1.6,spring版本是3.2,jedis版本是2.2。

1.方案的确定

接到这项目任务后,理所当然地google了,一搜索,发现解决方案分为两大类:

  1. tomcat的session管理
  2. spring-session

对于“tomcat的session管理”,很不幸,线上代码用的是resin,直接pass了;

对于“spring-session”,这是spring全家桶系列,项目中正好使用了spring,可以很方便集成,并且原业务代码不用做任何发动,似乎是个不错的选择。但是,在引入spring-session过程中发生了意外:项目中使用的jedis版本不支持!项目中使用的jedis版本是2.2,而spring-session中使用的jedis版本是2.5,有些命令像"set PX/EX NX/XX",项目中使用的redis是不支持的,但spring-session引入的jedis支持,直接引入的话,风险难以把控,而升级项目中的redis版本的话,代价就比较高了。

综上所述,以上两个方案都行不能,既然第三方组件行不通,那就只能自主实现了。

通过参考一些开源项目的实现,自主实现分布式session的关键点有以下几点:

  • 使用servlet的filter功能来接管session;
  • 使用redis来管理全局session

2.使用filter来接管session

为了实现此功能,我们定义如下几个类:

  • SessionFilter:servlet的过滤器,替换原始的session
  • DistributionSession:分布式session,实现了HttpSession类;
  • SessionRequestWrapper:在filter中用来接管session;

类的具体实现如下:

SessionFilter类

/**
 * 该类实现了Filter
 */
public class SessionFilter implements Filter {

    /** redis的相关操作 */
    @Autowired
    private RedisExtend redisExtend;

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {

    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
	//这里将request转换成自主实现的SessionRequestWrapper
	//经过传递后,项目中获取到的request就是SessionRequestWrapper
        ServletRequest servletRequest = new SessionRequestWrapper((HttpServletRequest)request,
                (HttpServletResponse)response, redisExtend);
        chain.doFilter(servletRequest, response);
    }

    @Override
    public void destroy() {

    }
}

ServletRequestWrap类

/**
 * 该类继承了HttpServletRequestWrapper并重写了session相关类
 * 之后项目中通过'request.getSession()'就是调用此类的getSession()方法了
 */
public class SessionRequestWrapper extends HttpServletRequestWrapper {

    private final Logger log = LoggerFactory.getLogger(SessionRequestWrapper.class);
	
    /** 原本的requst,用来获取原始的session */
    private HttpServletRequest request;
	
    /** 原始的response,操作cookie会用到 */
    private HttpServletResponse response;
	
    /** redis命令的操作类 */
    private RedisExtend redisExtend;
	
    /** session的缓存,存在本机的内存中 */
    private MemorySessionCache sessionCache;
	
    /** 自定义sessionId */
    private String sid;

    public SessionRequestWrapper(HttpServletRequest request, HttpServletResponse response, RedisExtend redisExtend) {
        super(request);
        this.request = request;
        this.response = response;
        this.redisExtend = redisExtend;
        this.sid = getSsessionIdFromCookie();
        this.sessionCache = MemorySessionCache.initAndGetInstance(request.getSession().getMaxInactiveInterval());
    }
	
   /**
    * 获取session的操作
    */
    @Override
    public HttpSession getSession(boolean create) {
       if (!create) {
           return null;
       }
       HttpSession httpSession = request.getSession();
       try {
           return sessionCache.getSession(httpSession.getId(), new Callable<DistributionSession>() {
                @Override
                public DistributionSession call() throws Exception {
                    return new DistributionSession(request, redisExtend, sessionCache, sid);
                }
           });
        } catch (Exception e) {
            log.error("从sessionCache获取session出错:{}", ExceptionUtils.getStackTrace(e));
            return new DistributionSession(request, redisExtend, sessionCache, sid);
        }
        return null;
    }

    @Override
    public HttpSession getSession() {
        return getSession(true);
    }
	
   /**
    * 从cookie里获取自定义sessionId,如果没有,则创建一个
    */
    private String getSsessionIdFromCookie() {
        String sid = CookieUtil.getCookie(SessionUtil.SESSION_KEY, this);
        if (StringUtils.isEmpty(sid)) {
            sid = java.util.UUID.randomUUID().toString();
            CookieUtil.setCookie(SessionUtil.SESSION_KEY, sid, this, response);
            this.setAttribute(SessionUtil.SESSION_KEY, sid);
        }
        return sid;
    }

}

DistributionSession类

/*
 * 分布式session的实现类,实现了session
 * 项目中由request.getSession()获取到的session就是该类
 */
public class DistributionSession implements HttpSession {

    private final Logger log = LoggerFactory.getLogger(DistributionSession.class);

    /** 自定义sessionId */
    private String sid;
	
    /** 原始的session */
    private HttpSession httpSession;
	
    /** redis操作类 */
    private RedisExtend redisExtend;

    /** session的本地内存缓存 */
    private MemorySessionCache sessionCache;

    /** 最后访问时间 */
    private final String LAST_ACCESSED_TIME = "lastAccessedTime";
    /** 创建时间 */
    private final String CREATION_TIME = "creationTime";

    public DistributionSession(HttpServletRequest request, RedisExtend redisExtend,
                               MemorySessionCache sessionCache, String sid) {
        this.httpSession = request.getSession();
        this.sid = sid;
        this.redisExtend = redisExtend;
        this.sessionCache = sessionCache;
        if(this.isNew()) {
            this.setAttribute(CREATION_TIME, System.currentTimeMillis());
        }
        this.refresh();
    }

    @Override
    public String getId() {
        return this.sid;
    }

    @Override
    public ServletContext getServletContext() {
        return httpSession.getServletContext();
    }

    @Override
    public Object getAttribute(String name) {
        byte[] content = redisExtend.hget(SafeEncoder.encode(SessionUtil.getSessionKey(sid)),
                SafeEncoder.encode(name));
        if(ArrayUtils.isNotEmpty(content)) {
            try {
                return ObjectSerializerUtil.deserialize(content);
            } catch (Exception e) {
                log.error("获取属性值失败:{}", ExceptionUtils.getStackTrace(e));
            }
        }
        return null;
    }

    @Override
    public Enumeration<String> getAttributeNames() {
        byte[] data = redisExtend.get(SafeEncoder.encode(SessionUtil.getSessionKey(sid)));
        if(ArrayUtils.isNotEmpty(data)) {
            try {
                Map<String, Object> map = (Map<String, Object>) ObjectSerializerUtil.deserialize(data);
                return (new Enumerator(map.keySet(), true));
            } catch (Exception e) {
                log.error("获取所有属性名失败:{}", ExceptionUtils.getStackTrace(e));
            }
        }
        return new Enumerator(new HashSet<String>(), true);
    }

    @Override
    public void setAttribute(String name, Object value) {
        if(null != name && null != value) {
            try {
                redisExtend.hset(SafeEncoder.encode(SessionUtil.getSessionKey(sid)),
                        SafeEncoder.encode(name), ObjectSerializerUtil.serialize(value));
            } catch (Exception e) {
                log.error("添加属性失败:{}", ExceptionUtils.getStackTrace(e));
            }
        }
    }

    @Override
    public void removeAttribute(String name) {
        if(null == name) {
            return;
        }
        redisExtend.hdel(SafeEncoder.encode(SessionUtil.getSessionKey(sid)), SafeEncoder.encode(name));
    }

    @Override
    public boolean isNew() {
        Boolean result = redisExtend.exists(SafeEncoder.encode(SessionUtil.getSessionKey(sid)));
        if(null == result) {
            return false;
        }
        return result;
    }

    @Override
    public void invalidate() {
        sessionCache.invalidate(sid);
        redisExtend.del(SafeEncoder.encode(SessionUtil.getSessionKey(sid)));
    }

    @Override
    public int getMaxInactiveInterval() {
        return httpSession.getMaxInactiveInterval();
    }

    @Override
    public long getCreationTime() {
        Object time = this.getAttribute(CREATION_TIME);
        if(null != time) {
            return (Long)time;
        }
        return 0L;
    }

    @Override
    public long getLastAccessedTime() {
        Object time = this.getAttribute(LAST_ACCESSED_TIME);
        if(null != time) {
            return (Long)time;
        }
        return 0L;
    }

    @Override
    public void setMaxInactiveInterval(int interval) {
        httpSession.setMaxInactiveInterval(interval);
    }

    @Override
    public Object getValue(String name) {
        throw new NotImplementedException();
    }

    @Override
    public HttpSessionContext getSessionContext() {
        throw new NotImplementedException();
    }

    @Override
    public String[] getValueNames() {
        throw new NotImplementedException();
    }

    @Override
    public void putValue(String name, Object value) {
        throw new NotImplementedException();
    }

    @Override
    public void removeValue(String name) {
        throw new NotImplementedException();
    }

    /**
     * 更新过期时间
     * 根据session的过期规则,每次访问时,都要更新redis的过期时间
     */
    public void refresh() {
        //更新最后访问时间
        this.setAttribute(LAST_ACCESSED_TIME, System.currentTimeMillis());
        //刷新有效期
        redisExtend.expire(SafeEncoder.encode(SessionUtil.getSessionKey(sid)),
                httpSession.getMaxInactiveInterval());
    }

    /**
     * Enumeration 的实现
     */
    class Enumerator implements Enumeration<String> {

        public Enumerator(Collection<String> collection) {
            this(collection.iterator());
        }

        public Enumerator(Collection<String> collection, boolean clone) {
            this(collection.iterator(), clone);
        }

        public Enumerator(Iterator<String> iterator) {
            super();
            this.iterator = iterator;
        }

        public Enumerator(Iterator<String> iterator, boolean clone) {
            super();
            if (!clone) {
                this.iterator = iterator;
            }
            else {
                List<String> list = new ArrayList<String>();
                while (iterator.hasNext()) {
                    list.add(iterator.next());
                }
                this.iterator = list.iterator();
            }
        }

        private Iterator<String> iterator = null;

        @Override
        public boolean hasMoreElements() {
            return (iterator.hasNext());
        }

        @Override
        public String nextElement() throws NoSuchElementException {
            return (iterator.next());
        }

    }
}

由项目中的redis操作类RedisExtend是由spring容器来实例化的,为了能在DistributionSession类中使用该实例,需要使用spring容器来实例化filter,在spring的配置文件中添加以下内容:

<!-- 分布式 session的filter -->
<bean id="sessionFilter" class="com.xxx.session.SessionFilter"></bean>

在web.xml中配置filter时,也要通过spring来管理:

<!-- 一般来说,该filter应该位于所有的filter之前。 -->
<filter>
	<!-- spring实例化时的实例名称 -->
	<filter-name>sessionFilter</filter-name>
	<!-- 采用spring代理来实现filter -->
	<filter-class>org.springframework.web.filter.DelegatingFilterProxy</filter-class>
	<init-param>
		<param-name>targetFilterLifecycle</param-name>
		<param-value>true</param-value>
	</init-param>
</filter>
<filter-mapping>
	<filter-name>sessionFilter</filter-name>
	<url-pattern>/*</url-pattern>
</filter-mapping>

3.全局的session管理:redis

使用redis来管理session时,对象应该使用什么序列化方式?首先,理所当然地想到使用json。我们来看看json序列化时究竟行不行。

在项目中,往session设置值和从session中获取值的操作分别如下:

/** 假设现在有一个user类,属性有:name与age*/
User user = new User("a", 13);
request.getSession().setAttribute("user", user);
//通过以下方式获取
User user = (User)request.getSession().getAttribute("user");

DistributionSession中实现setAttribute()方法时,可以采用如下方式:

public void setAttribute(String name, Object object) {
	String jsonString = JsonUtil.toJson(object);
	redisExtend.hset(this.sid, name, jsonString);
}

但在getAttribute()方法的实现上,json反序列化就无能为力了:

public Object getAttribute(String name) {
	String jsonString = redisExtend.hget(this.sid, name);
	return JsonUtil.toObject(jsonString, Object.class);
}

在json反序列化时,如果不指定类型,或指定为Object时,json序列化就有问题了:

  • fastJson会序列化成JSONObject
  • gson与jackson会序列化成Map
//这里的object实际类型是JSONObject或Map,取决于使用的json工具包
Object object = request.getSession().getAttribute("user");
//在类型转换时,这一句会报错
User user = (User)object;

有个小哥哥就比较聪明,在序列化时,把参数的类型一并带上了,如上面的json序列化成com.xxx.User:{"name":"a","age":13}再保存到redis中,这样在反序化时,先获取到com.xxx.User类,再来做json反序列:

String jsonString = redisExtend.hget(this.sid, name);
String[] array = jsonString.split(":");
Class type = Class.forname(array[0]);
Object obj = JsonUtil.toObject(array[1], type);

这样确实能解决一部分问题,但如果反序列化参数中有泛型就无能为力了!现在session存储的属性如下:

List<User> list = new ArrayList<>();
User user1 = new User("a", 13);
User user2 = new User("b", 12);
list.add(user1);
list.add(user2);
request.getSession().setAttribute("users", list);

这种情况下,序列出来的json会这样:

java.util.List:[{"name":"a","age":13}, {"name":"b","age":12}]

在反序列化时,会这样:

Object obj = JsonUtil.toObject(array[1], List.class);

到这里确实是没问题的,但我们可以看到泛型信息丢失了,我们在调用getAttribute()时,会这样调用:

//这里的obj实现类型是List,至于List的泛型类型,是JSONObject或Map,取决于使用的json工具包
Object obj = request.getSession().getAttribute("users");
//如果这样调用不用报错:List users = (List)obj;
//加上泛型值后,java编译器会认为是要把JSONObject或Map转成User,还是会导致类型转换错误
List<User> users = (List)obj;

这一步就会出现问题了,原因是在反序列化时,只传了List,没有指定List里面放的是什么对象,Json反序列化是按Object类型来处理的,前面提到fastJson会序列化成JSONObject,gson与jackson会序列化成Map,直接强转成User一定会报错。

为了解决这个问题,这里直接使用java的对象序列化方法:

public class ObjectSerializerUtil {

    /**
     * 序列化
     * @param obj
     * @return
     * @throws IOException
     */
    public static byte[] serialize(Object obj) throws IOException {
        byte[] bytes;
        ByteArrayOutputStream baos = null;
        ObjectOutputStream oos = null;
        try {
            baos = new ByteArrayOutputStream();
            oos = new ObjectOutputStream(baos);
            oos.writeObject(obj);
            bytes = baos.toByteArray();
        } finally {
            if(null != oos) {
                oos.close();
            }
            if(null != baos) {
                baos.close();
            }
        }
        return bytes;
    }

    /**
     * 反序列化
     * @param bytes
     * @return
     * @throws IOException
     * @throws ClassNotFoundException
     */
    public static Object deserialize(byte[] bytes) throws IOException, ClassNotFoundException {
        Object obj;
        ByteArrayInputStream bais = null;
        ObjectInputStream ois = null;
        try {
            bais = new ByteArrayInputStream(bytes);
            ois = new ObjectInputStream(bais);
            obj = ois.readObject();
        } finally {
            if(null != ois) {
                ois.close();
            }
            if(null != bais) {
                bais.close();
            }
        }
        return obj;
    }
}

4.jessionId的处理

session共享的关键就在于jessionId的处理了,正是cookie里有了jessonId的存在,http才会有所谓的登录/注销一说。对于jessionId,先提两个问题:

  1. jessionId是由客户端生成还是由服务端生成的?
  2. 如果客户端传了jessionId,服务端就不用再生成了?

对于第一个问题,jessionId是在服务端创建的,当用户首次访问时,服务端发现没有传jessionId,会在服务端分配一个jessionId,做一些初始化工作,并把jessionId返回到客户端。客户端收到后,会保存在cookie里,下次请求时,会把这个jessionId传过去,这样当服务端再次接收到请求后,不知道该用户之前已经访问过了,不用再做初始化工作了。

如果客户端的cookie里存在了jessionId,是不是就不会再在服务端生成jessionId了呢?答案是不一定。当服务端接收到jessionId后,会判断该jessionId是否由当前服务端创建,如果是,则使用此jessionId,否则会丢弃此jessionId而重新创建一个jessionId。

在集群环境中,客户端C第一次访问了服务端的S1服务器,并创建了一个jessionId1,当下一次再访问的时候,如果访问到的是服务端的S2服务器,此时客户端虽然上送了jessionId1,但S2服务器并不认,它会把C当作是首次访问,并分配新的jessionId,这就意味着用户需要重新登录。这种情景下,使用jessionId来区分用户就不太合理了。

为了解决这个问题,这里使用在cookie中保存自定义的sessionKey的形式来解决这个问题:

//完整代码见第二部分SessionRequestWrapper类
private String getSsessionIdFromCookie() {
	String sid = CookieUtil.getCookie(SessionUtil.SESSION_KEY, this);
	if (StringUtils.isEmpty(sid)) {
		sid = java.util.UUID.randomUUID().toString();
		CookieUtil.setCookie(SessionUtil.SESSION_KEY, sid, this, response);
		this.setAttribute(SessionUtil.SESSION_KEY, sid);
	}
	return sid;
}

cookie的操作代码如下:

CookieUtil类

public class CookieUtil {

	protected static final Log logger = LogFactory.getLog(CookieUtil.class);

	/**
	 * 设置cookie</br>
	 * 
	 * @param name
	 *            cookie名称
	 * @param value
	 *            cookie值
	 * @param request
	 *            http请求
	 * @param response
	 *            http响应
	 */
	public static void setCookie(String name, String value, HttpServletRequest request, HttpServletResponse response) {
		int maxAge = -1;
		CookieUtil.setCookie(name, value, maxAge, request, response);
	}

	/**
	 * 设置cookie</br>
	 * 
	 * @param name
	 *            cookie名称
	 * @param value
	 *            cookie值
	 * @param maxAge
	 *            最大生存时间
	 * @param request
	 *            http请求
	 * @param response
	 *            http响应
	 */
	public static void setCookie(String name, String value, int maxAge, HttpServletRequest request, HttpServletResponse response) {
		String domain = request.getServerName();
		setCookie(name, value, maxAge, domain, response);
	}

	public static void setCookie(String name, String value, int maxAge, String domain, HttpServletResponse response) {
		AssertUtil.assertNotEmpty(name, new NullPointerException("cookie名称不能为空."));
		AssertUtil.assertNotNull(value, new NullPointerException("cookie值不能为空."));

		Cookie cookie = new Cookie(name, value);
		cookie.setDomain(domain);
		cookie.setMaxAge(maxAge);
		cookie.setPath("/");
		response.addCookie(cookie);
	}

	/**
	 * 获取cookie的值</br>
	 * 
	 * @param name
	 *            cookie名称
	 * @param request
	 *            http请求
	 * @return cookie值
	 */
	public static String getCookie(String name, HttpServletRequest request) {
		AssertUtil.assertNotEmpty(name, new NullPointerException("cookie名称不能为空."));

		Cookie[] cookies = request.getCookies();
		if (cookies == null) {
			return null;
		}
		for (int i = 0; i < cookies.length; i++) {
			if (name.equalsIgnoreCase(cookies[i].getName())) {
				return cookies[i].getValue();
			}
		}
		return null;
	}

	/**
	 * 删除cookie</br>
	 * 
	 * @param name
	 *            cookie名称
	 * @param request
	 *            http请求
	 * @param response
	 *            http响应
	 */
	public static void deleteCookie(String name, HttpServletRequest request, HttpServletResponse response) {
		AssertUtil.assertNotEmpty(name, new RuntimeException("cookie名称不能为空."));
		CookieUtil.setCookie(name, "", -1, request, response);
	}

	/**
	 * 删除cookie</br>
	 * 
	 * @param name
	 *            cookie名称
	 * @param response
	 *            http响应
	 */
	public static void deleteCookie(String name, String domain, HttpServletResponse response) {
		AssertUtil.assertNotEmpty(name, new NullPointerException("cookie名称不能为空."));
		CookieUtil.setCookie(name, "", -1, domain, response);
	}

}

这样之后,项目中使用自定义sid来标识客户端,并且自定义sessionKey的处理全部由自己处理,不会像jessionId那样会判断是否由当前服务端生成。

5.进一步优化

1)DistributionSession并不需要每次重新生成 在SessionRequestWrapper类中,获取session的方法如下:

@Override
public HttpSession getSession(boolean create) {
	if (create) {
		HttpSession httpSession = request.getSession();
		try {
			return sessionCache.getSession(httpSession.getId(), new Callable<DistributionSession>() {
				@Override
				public DistributionSession call() throws Exception {
					return new DistributionSession(request, redisExtend, sessionCache, sid);
				}
			});
		} catch (Exception e) {
			log.error("从sessionCache获取session出错:{}", ExceptionUtils.getStackTrace(e));
			return new DistributionSession(request, redisExtend, sessionCache, sid);
		}
	} else {
		return null;
	}
}

这里采用了缓存技术,使用sid作为key来缓存DistributionSession,如果不采用缓存,则获取session的操作如下:

@Override
public HttpSession getSession(boolean create) {
	return new DistributionSession(request, redisExtend, sessionCache, sid);
}

如果同一sid多次访问同一服务器,并不需要每次都创建一个DistributionSession,这里就使用缓存来存储这些DistributionSession,这样下次访问时,就不用再次生成DistributionSession对象了。

缓存类如下:

MemorySessionCache类

public class MemorySessionCache {

    private Cache<String, DistributionSession> cache;

    private static AtomicBoolean initFlag = new AtomicBoolean(false);

    /**
     * 初始化,并返回实例
     * @param maxInactiveInterval
     * @return
     */
    public static MemorySessionCache initAndGetInstance(int maxInactiveInterval) {
        MemorySessionCache sessionCache = getInstance();
        //保证全局只初始化一次
        if(initFlag.compareAndSet(false, true)) {
            sessionCache.cache = CacheBuilder.newBuilder()
                    //考虑到并没有多少用户会同时在线,这里将缓存数设置为100,超过的值不保存在缓存中
                    .maximumSize(100)
                    //多久未访问,就清除
                    .expireAfterAccess(maxInactiveInterval, TimeUnit.SECONDS).build();
        }
        return sessionCache;
    }

    /**
     * 获取session
     * @param sid
     * @param callable
     * @return
     * @throws ExecutionException
     */
    public DistributionSession getSession(String sid, Callable<DistributionSession> callable)
            throws ExecutionException {
        DistributionSession session = getInstance().cache.get(sid, callable);
        session.refresh();
        return session;
    }

    /**
     * 将session从cache中删除
     * @param sid
     */
    public void invalidate(String sid) {
        getInstance().cache.invalidate(sid);
    }


    /**
     * 单例的内部类实现方式
     */
    private MemorySessionCache() {

    }

    private static class MemorySessionCacheHolder {
        private static final MemorySessionCache singletonPattern = new MemorySessionCache();
    }

    private static MemorySessionCache getInstance() {
        return MemorySessionCacheHolder.singletonPattern;
    }

}

总结:使用redis自主实现session共享,关键点有三个:

  1. 使用filter来接管全局session;
  2. 将java对象序列化成二进制数据保存到redis,反序列化时也使用java对象反序列化方式;
  3. 原始的jessionId可能会丢弃并重新生成,需要自主操作cookie重新定义sessionKey.