限流模块主要是三种限流的算法+aop实现

@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Import({RedisBloomFilterRegistar.class, RedisLimiterRegistar.class})
public @interface EnableRedisAux {
    String[] bloomFilterPath() default "";
    boolean enableLimit() default false;
    boolean transaction() default false;

}

然后spring会加载@Import的类,被注入的类通过获取注解上的信息来确定是否启用切面

public class RedisLimiterRegistar implements ImportBeanDefinitionRegistrar {
    @Override
    public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
        Map<String, Object> attributes = importingClassMetadata
                .getAnnotationAttributes(EnableRedisAux.class.getCanonicalName());
        Boolean enableLimit = (Boolean) attributes.get("enableLimit");
        //如果开启限流,则扫描组件、初始化对应的限流器和切面
        if(enableLimit){
            ClassPathBeanDefinitionScanner scanConfigure =
                    new ClassPathBeanDefinitionScanner(registry, true);
            scanConfigure.scan("com.opensource.redisaux.limiter.autoconfigure");
        }
    }

}

然后到配置类,主要是加载三个脚本,并且把对应的限流器缓存起来并加载切面类

@Configuration
@AutoConfigureAfter(RedisAutoConfiguration.class)
@ConditionalOnBean(RedisTemplate.class)
public class RedisLimiterAutoConfiguration {

    @Autowired
    @Qualifier(BloomFilterConsts.INNERTEMPLATE)
    private RedisTemplate redisTemplate;

    /**
     * 滑动窗口的lua脚本,步骤:
     * 1.记录当前时间戳
     * 2.把小于(当前时间戳-窗口大小得到的时间戳)的key删掉
     * 3.返回该窗口内的成员个数
     * @return
     */
    @Bean
    public DefaultRedisScript windowLimitScript() {
        DefaultRedisScript script = new DefaultRedisScript();
        script.setResultType(Boolean.class);
        script.setScriptText("redis.call('zadd',KEYS[1],ARGV[1],ARGV[1]) redis.call('zremrangebyscore',KEYS[1],0,ARGV[2]) return redis.call('zcard',KEYS[1]) <= tonumber(ARGV[3])");
        return script;
    }

    /**
     * 具体思想看lua脚本注释
     * @return
     */
    @Bean
    public DefaultRedisScript tokenLimitScript() {
        DefaultRedisScript script = new DefaultRedisScript();
        script.setResultType(Long.class);
        script.setLocation(new ClassPathResource("TokenRateLimit.lua"));
        return script;
    }
    /**
     * 具体思想看lua脚本注释
     * @return
     */
    @Bean
    public DefaultRedisScript funnelLimitScript() {
        DefaultRedisScript script = new DefaultRedisScript();
        script.setResultType(Boolean.class);
        script.setLocation(new ClassPathResource("FunnelRateLimit.lua"));
        return script;
    }
    /**
     * 切面
     * @return
     */
    @Bean
    public LimiterAspect limiterAspect(){
        Map<Integer, BaseRateLimiter> map = new HashMap();
        map.put(BaseRateLimiter.WINDOW_LIMITER, new WindowRateLimiter(redisTemplate, windowLimitScript()));
        map.put(BaseRateLimiter.TOKEN_LIMITER, new TokenRateLimiter(redisTemplate, tokenLimitScript()));
        map.put(BaseRateLimiter.FUNNEL_LIMITER, new FunnelRateLimiter(redisTemplate, funnelLimitScript()));
        return new LimiterAspect(map);
    }



}

切面类,主要是查看当前方法的注解所对应的类型是哪个,然后去map里面找对应实体类,进行限流操作,操作不通过则调用同类的失败方法,默认不传原有方法的参数,如果启用了传输参数,请把原来方法的参数搬到失败方法那里,并且不可以同名

@SuppressWarnings("unchecked")
@Aspect
public class LimiterAspect  {


    private final Map<Integer, BaseRateLimiter> rateLimiterMap;

    private final Map<String, Annotation> annotationMap;


    public LimiterAspect(Map<Integer, BaseRateLimiter> rateLimiterMap
    ) {
        this.rateLimiterMap = rateLimiterMap;
        this.annotationMap = new ConcurrentHashMap();
    }


    @Pointcut("@annotation(com.opensource.redisaux.limiter.annonations.TokenLimiter)||@annotation(com.opensource.redisaux.limiter.annonations.WindowLimiter)||@annotation(com.opensource.redisaux.limiter.annonations.FunnelLimiter)")
    public void limitPoint() {

    }


    @Around("limitPoint()")
    public Object doAroundAdvice(ProceedingJoinPoint proceedingJoinPoint) throws Throwable {
        MethodSignature signature = (MethodSignature) proceedingJoinPoint.getSignature();
        Class<?> beanClass = proceedingJoinPoint.getTarget().getClass();
        //获取所在类名
        String targetName = beanClass.getName();
        //获取执行的方法
        Method method = signature.getMethod();
        String methodKey = CommonUtil.getMethodKey(targetName, method);
        //该注解用于获取对应限流器
        LimiterType baseLimiter = null;
        Annotation target = null;
        if ((target = annotationMap.get(methodKey)) == null) {
            //找出限流器并且把对应的注解存到map里面
            Annotation[] annotations = signature.getMethod().getAnnotations();
            for (Annotation annotation : annotations) {
                if (annotation.annotationType().isAnnotationPresent(LimiterType.class)) {
                    target = annotation;
                    annotationMap.put(methodKey, target);
                    break;
                }
            }
        }
        baseLimiter = target.annotationType().getAnnotation(LimiterType.class);
        BaseRateLimiter rateLimiter = rateLimiterMap.get(baseLimiter.mode());
        if (rateLimiter.canExecute(target, methodKey)) {
            return proceedingJoinPoint.proceed();
        } else {
            //否则执行失败逻辑
            Object bean =proceedingJoinPoint.getTarget();
            BaseRateLimiter.KeyInfoNode keyInfoNode = BaseRateLimiter.keyInfoMap.get(methodKey);
            String fallBackMethodStr = keyInfoNode.getFallBackMethod();
            if ("".equals(fallBackMethodStr)) {
                return "too much request";
            }

           Method fallBackMethod= keyInfoNode.isPassArgs()?
                   beanClass.getMethod(fallBackMethodStr,method.getParameterTypes()):
                   beanClass.getMethod(fallBackMethodStr);
            fallBackMethod.setAccessible(true);
           return keyInfoNode.isPassArgs()?fallBackMethod.invoke(bean,proceedingJoinPoint.getArgs()):fallBackMethod.invoke(bean);
        }
    }



}

然后到限流设计部分,四个注解

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface LimiterType {
    /**
     * 模式,用于去相应的map里面寻找对应的limiter
     * @return
     */
    int mode();
}
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@LimiterType(mode = BaseRateLimiter.TOKEN_LIMITER)
public @interface TokenLimiter {

    /**
     * 令牌桶容量
     *
     * @return
     */
    double capacity();

    /**
     * 令牌生成速率
     *
     * @return
     */
    double rate();

    /**
     * 速率时间单位,默认秒
     *
     * @return
     */
    TimeUnit rateUnit() default TimeUnit.SECONDS;

    /**
     * 每次请求所需要的令牌数
     *
     * @return
     */
    double need();

    /**
     * 是否阻塞等待
     *
     * @return
     */
    boolean isAbort() default false;

    /**
     * 阻塞超时时间
     *
     * @return
     */
    int timeout() default -1;

    /**
     * 单位,默认毫秒
     *
     * @return
     */
    TimeUnit timeoutUnit() default TimeUnit.MILLISECONDS;

    String fallback() default "";

    boolean passArgs() default false;

}
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@LimiterType(mode = BaseRateLimiter.FUNNEL_LIMITER)
public @interface FunnelLimiter {

    /**
     * 漏斗容量
     * @return
     */
    double capacity();
    /**
     *每秒漏出的速率
     * @return
     */
    double passRate();
    /**
     *时间单位
     * @return
     */
    TimeUnit timeUnit() default TimeUnit.SECONDS;
    /**
     *每次请求所需加的水量
     * @return
     */
    double addWater();

    String fallback() default "";

    boolean passArgs() default false;

}

@LimiterType是为了确定当前方法需要的限流类型

其他的注解里面的信息用于后面业务逻辑处理

这里有一个抽象限流器,里面有一些公共信息map,和对应的方法,这个KeyInfoNode里的keyList是接口映射到redis的keyName,因为脚本执行传参是list,passArgs和fallBack就是失败后的方法信息,有一个判断是否能通过的方法

public abstract class BaseRateLimiter {
    public final static int WINDOW_LIMITER =1;
    public final static int TOKEN_LIMITER =2;
    public final static int FUNNEL_LIMITER=3;
    //存放的是keyNameList、是否传参,回调方法名
    public static Map<String, KeyInfoNode> keyInfoMap =new ConcurrentHashMap();



    
   static List<String> getKey(String methodKey,String method,boolean passArgs){
        KeyInfoNode keyInfoNode;
        if((keyInfoNode= keyInfoMap.get(methodKey))==null){
            keyInfoNode= new KeyInfoNode();
            keyInfoNode.fallBackMethod=method;
            keyInfoNode.passArgs=passArgs;
            keyInfoNode.keyNameList= Collections.singletonList(methodKey);
            keyInfoMap.put(methodKey, keyInfoNode);
        }
        return keyInfoNode.getKeyNameList();
    }

    /**
     * 限流情况下,是否可以通过执行
     * @param redisLimiter
     * @param methodKey
     * @return
     */
    public   Boolean canExecute(Annotation redisLimiter, String methodKey){return null;};


    public static class KeyInfoNode{

        private  List<String> keyNameList;
        private  boolean passArgs;
        private String fallBackMethod;


        public List<String> getKeyNameList() {
            return keyNameList;
        }

        public boolean isPassArgs() {
            return passArgs;
        }

        public String getFallBackMethod() {
            return fallBackMethod;
        }
    }

}

看一下漏斗限流,这里通过aop拦截方法所获取的注解信息来确定限流方法的执行参数

public class FunnelRateLimiter extends BaseRateLimiter {
    private RedisTemplate redisTemplate;
    private DefaultRedisScript redisScript;


    public FunnelRateLimiter(RedisTemplate redisTemplate, DefaultRedisScript redisScript) {
        this.redisScript = redisScript;
        this.redisTemplate = redisTemplate;

    }

    @Override
    public Boolean canExecute(Annotation baseLimiter, String methodKey) {
        FunnelLimiter funnelLimiter = (FunnelLimiter) baseLimiter;
        TimeUnit timeUnit = funnelLimiter.timeUnit();
        double capacity = funnelLimiter.capacity();
        double need = funnelLimiter.addWater();
        double rate = funnelLimiter.passRate();
        long l = timeUnit.toMillis(1);
        double millRate = rate / l;
        String methodName=funnelLimiter.fallback();
        boolean passArgs=funnelLimiter.passArgs();
        List<String> keyList = BaseRateLimiter.getKey(methodKey,methodName,passArgs);
        return (Boolean) redisTemplate.execute(redisScript, keyList, new Object[]{capacity, millRate, need, Double.valueOf(System.currentTimeMillis())});
    }


}

执行的脚本如下

通过redis的hash表来构造一个漏斗器对象,它的属性有,漏斗容量,漏水的速率,一次请求所加的水,最后一次请求的时间,当前的水量

当请求来时,根据上一次请求的时间和本次时间来计算这段时间所流出的水,然后设置当前的水量,再判断是否可以加水,如果可以的话,更新最后一次请求时间和当前水量,这种限流方式可以保证通过的请求量是稳定的,因为漏斗的单位时间通过的水量是恒定的。

--参数说明,key[1]为对应服务接口的信息,argv1为capacity,argv2为漏水速率,argv3为一次所需流出的水量,argv4为时间戳
local limitInfo = redis.call('hmget', KEYS[1], 'capacity', 'passRate', 'addWater','water', 'lastTs')
local capacity = limitInfo[1]
local passRate = limitInfo[2]
local addWater= limitInfo[3]
local water = limitInfo[4]
local lastTs = limitInfo[5]

--初始化漏斗
if capacity == false then
    capacity = tonumber(ARGV[1])
    passRate = tonumber(ARGV[2])
    --请求一次所要加的水量
    addWater=tonumber(ARGV[3])
    --当前水量
    water = 0
    lastTs = tonumber(ARGV[4])
    redis.call('hmset', KEYS[1], 'capacity', capacity, 'passRate', passRate,'addWater',addWater,'water', water, 'lastTs', lastTs)
    return true
else
    local nowTs = tonumber(ARGV[4])
    --计算距离上一次请求到现在的漏水量
    local waterPass = tonumber((nowTs - lastTs)* passRate)
    --计算当前水量,即执行漏水
    water=math.max(0,water-waterPass)
    --设置本次请求的时间
    lastTs = nowTs
    --判断是否可以加水
    addWater=tonumber(addWater)
    if capacity-water >= addWater then
        --加水
        water=water+addWater
        --更新当前水量和时间戳
        redis.call('hmset', KEYS[1], 'water', water, 'lastTs', lastTs)
        return true
    end
    return false
end

令牌桶限流

这里就直接贴脚本了,因为都是通过获取注解来给脚本作参数的,同样的,也是通过hash表来构造一个令牌桶对象,令牌桶数量、令牌生成速率、每次请求所需的令牌,上一次请求的时间

大概原理其实和漏斗很像,只不过请求所处的角色不同,这里可以看作是消费者,而漏斗算法那里请求可以看作是生产者。

过程:通过计算当前时间与上一次时间的时间段生成的令牌,然后是否够本次请求使用,并更新对应的信息,令牌桶限流由于定期生产令牌,所以可以响应瞬时的突发请求,比如某个时刻,令牌桶中有10个令牌,那么1秒甚至更短的时间内可以相应10个请求也不会出错,但由于漏斗限流是固定流出水量,当1s内发生10个请求,流速不一定跟的上,满了以后就只能拒绝服务了

--参数说明,key[1]为对应服务接口的信息,argv1为capacity,argv2为令牌生成速率,argv3为每次需要的令牌数,argv4为当前时间戳
local limitInfo = redis.call('hmget', KEYS[1], 'capacity', 'passRate', 'leftToken', 'lastTs')
local capacity = limitInfo[1]
local rate = limitInfo[2]
local leftToken = limitInfo[3]
local lastTs = limitInfo[4]

--初始化令牌桶
if capacity == false then
    capacity = tonumber(ARGV[1])
    rate = tonumber(ARGV[2])
    leftToken = tonumber(ARGV[1])
    lastTs = tonumber(ARGV[4])
    redis.call('hmset', KEYS[1], 'capacity', capacity, 'passRate', rate, 'leftToken', leftToken, 'lastTs', lastTs)
    return -1
else
    local nowTs = tonumber(ARGV[4])
    --计算距离上一次请求到现在生产令牌数
    local genTokenNum = tonumber((nowTs - lastTs)* rate)
    --计算该段时间的剩余令牌
    leftToken = genTokenNum + leftToken
    --设置剩余令牌
    leftToken = math.min(capacity, leftToken)
    --设置本次请求的时间
    lastTs = nowTs
    local need = tonumber(ARGV[3])
    --返回需要等待的毫秒数,-1则不用等待
    if leftToken >= need then
        --减去需要的令牌
        leftToken = leftToken - need
        --更新剩余空间和上一次的漏水时间戳
        redis.call('hmset', KEYS[1], 'leftToken', leftToken, 'lastTs', lastTs)
        return -1
    end
    return (need-leftToken)/rate
end

然后到滑动窗口,主要是用到了sorted set结构,ARGV[1]是当前请求时间戳,然后把超出窗口外的时间戳的请求数量都删除,返回当前的请求数是否小于滑动窗口单位时间通过的请求数

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@LimiterType(mode = BaseRateLimiter.WINDOW_LIMITER)
public @interface WindowLimiter {
    /**
     * 持续时间,窗口间隔
     * @return
     */
    int during();

    TimeUnit timeUnit() default TimeUnit.SECONDS;

    /**
     * 通过的请求数
     * @return
     */
    long value();

    String fallback() default "";

    boolean passArgs() default false;



}
redis.call('zadd',KEYS[1],ARGV[1],ARGV[1])
 redis.call('zremrangebyscore',KEYS[1],0,ARGV[2]) 
return redis.call('zcard',KEYS[1]) <= tonumber(ARGV[3])

目前的限流功能还有许多可以改进的地方,以后有时间就更新下

大概的实现思路就是这样,详情可以看github上的代码