【技术应用】springboot+redis+lua实现分布式限流

  • 一、前言
  • 二、lua介绍
  • 三、限流实现分析
  • 四、实现代码
  • 五、总结


一、前言

最近看到很多关于lua的内容,很多语言都支持使用lua脚本,并且我们经常使用的开源组件redisnginx等都已经支持lua脚本的使用,尤其在游戏里用到的相对比较多,既然当前工作中应用这么广,我们应该学习总结一下;

最近项目有涉及到接口限流的功能,今天就举例redis+lua实现限流功能代码

注:当前如果参加面试涉及redis的内容,我相信面试官大概率的会问redis操作如何保障原子性,这时候就可以使用lua脚本实现原子性

二、lua介绍

Lua 是一个简洁、轻量、可扩展的脚本语言,它有着相对简单的API 因此很容易嵌入应用中,很多应用程序使用Lua作为自己的嵌入式脚本语言,以此来实现可配置性、可扩展性。
Redis 从 2.6 版本开始支持 Lua 脚本,客户端通过 Lua 脚本,可以将多个 Redis 命令组合成一个原子性操作在服务器上执行。

Lua 脚本有以下优点:

  • 保证操作原子性
  • 减少网络开销,将多个指令组合到一个脚本中,与服务器的交互从多次变为一次
  • 可重复使用,在初次载入脚本之后,服务器会为脚本生成缓存,后续执行脚本时可直接使用缓存

三、限流实现分析

1、通过aop实现request请求连接,解析请求者ip;
2、redis+lua实现请求次数统计;

开发环境:springboot + aop + redis + lua

四、实现代码

1、lua脚本

限流的判断逻辑是在iplimite.lua脚本中实现的

-- 为某个接口的请求IP设置计数器,比如:127.0.0.1请求课程接口
-- KEYS[1] = 127.0.0.1 也就是用户的IP
-- ARGV[1] = 过期时间 30m
-- ARGV[2] = 限制的次数
local limitCount = redis.call('incr',KEYS[1]);
if limitCount == 1 then
    redis.call("expire",KEYS[1],ARGV[2])
end
-- 如果次数还没有过期,并且还在规定的次数内,说明还在请求同一接口
if limitCount > tonumber(ARGV[1]) then
    return false
end

return true

redis 的incr原理:

  • redis Incr 命令将 key 中储存的数字值增一。
  • 如果 key 不存在,那么 key 的值会先被初始化为 0 ,然后再执行 INCR 操作。
  • 如果值包含错误的类型,或字符串类型的值不能表示为数字,那么返回一个错误。
  • 本操作的值限制在 64 位(bit)有符号数字表示之内。

incr是原子操作的,对于这种场景,可以不用获取原来的值,直接用redis的incr实现readwrite的打包原子操作,就不会出现读了一半,然后被别人篡改了。真实场景可能不仅仅是这种库存问题,那么像批量设置多个值的场景可以用mset,批量获取多个值的mget,与incr相对应的decr,这些都是原子的。

:redis是IO多路复用模型,即一个线程来处理多个TCP连接,这样的好处就是,即使客户端并发请求,也得排队处理,一定程度上解决了多线程模型带的并发问题,所以redis是并发安全的?从redis本身的架构模式来说,并发是安全的,不存在同时执行两个客户端的命令。但是如果因为某些业务场景用的有问题,那么即使是单线程的redis也无能为力。

2、java初始化LuaConfig类

package com.qbb.limit.config;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.scripting.support.ResourceScriptSource;

@Configuration
public class LuaConfig {
    /**
     * 将lua脚本的内容加载出来放入到DefaultRedisScript
     *
     * @return
     */
    @Bean
    public DefaultRedisScript<Boolean> ipLimitLua() {
        DefaultRedisScript<Boolean> defaultRedisScript = new DefaultRedisScript<>();
        defaultRedisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/iplimite.lua")));
        defaultRedisScript.setResultType(Boolean.class);
        return defaultRedisScript;
    }
}

Spring提供了RedisScript接口,方便开发者调用Lua脚本。

public interface RedisScript<T> {
   //该方法用来获取脚本的SHA1
    String getSha1();
   //用来获取返回类型
    @Nullable
    Class<T> getResultType();
   //用来获取脚本字符串
    String getScriptAsString();
}

3、初始化redisConfig类

package com.qbb.limit.config;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.GenericJackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;

@Configuration
public class RedisConfig {

    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        // 1: 开始创建一个redistemplate
        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        // 2:开始redis连接工厂跪安了
        redisTemplate.setConnectionFactory(redisConnectionFactory);
        // 创建一个json的序列化方式
        GenericJackson2JsonRedisSerializer jackson2JsonRedisSerializer = new GenericJackson2JsonRedisSerializer();
        // 设置key用string序列化方式
        redisTemplate.setKeySerializer(new StringRedisSerializer());
        // 设置value用jackjson进行处理
        redisTemplate.setValueSerializer(jackson2JsonRedisSerializer);
        // hash也要进行修改
        redisTemplate.setHashKeySerializer(new StringRedisSerializer());
        redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer);
        // 默认调用
        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }

}

4、自定义注解
设置接口限流信息

package com.qbb.limit.aop;

import java.lang.annotation.*;

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface AccessLimiter {
    // 每timeout限制请求的个数
    int limit() default 10;

    // 时间,单位默认是秒
    int timeout() default 1;
}

5、拦截器内容

package com.qbb.limit.core;

import com.google.common.collect.Lists;
import com.qbb.limit.aop.AccessLimiter;
import com.qbb.limit.utils.RequestUtils;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.PrintWriter;
import java.lang.reflect.Method;

@Component
@Aspect
@Slf4j
public class LimiterAspect {
    @Autowired
    private StringRedisTemplate stringRedisTemplate;
    
    @Autowired
    private DefaultRedisScript<Boolean> ipLimitLua;

    // 1: 切入点
    @Pointcut("@annotation(com.qbb.limit.aop.AccessLimiter)")
    public void limiterPointcut() {
    }

    @Before("limiterPointcut()")
    public void limiter(JoinPoint joinPoint) {
        log.info("限流进来了.......");
        // 1:获取方法的签名作为key
        MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
        Method method = methodSignature.getMethod();
        String classname = methodSignature.getMethod().getDeclaringClass().getName();
        String packageName = methodSignature.getMethod().getDeclaringClass().getPackage().getName();
        log.info("classname:{},packageName:{}", classname, packageName);
        // 4: 读取方法的注解信息获取限流参数
        AccessLimiter annotation = method.getAnnotation(AccessLimiter.class);
        // 5:获取注解方法名
        String methodNameKey = method.getName();
        // 6:获取服务请求的对象
        ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        HttpServletRequest request = requestAttributes.getRequest();
        HttpServletResponse response = requestAttributes.getResponse();
        String userIp = RequestUtils.getIpAddr(request);
        log.info("用户IP是:.......{}", userIp);
        // 7:通过方法反射获取注解的参数
        Integer limit = annotation.limit();
        Integer timeout = annotation.timeout();
        String redisKey = method + ":" + userIp;
        // 8: 请求lua脚本
        Boolean acquired = stringRedisTemplate.execute(ipLimitLua, Lists.newArrayList(redisKey), limit.toString(), timeout.toString());
        // 如果超过限流限制
        if (!acquired) {
            // 抛出异常,然后让全局异常去处理
            response.setCharacterEncoding("UTF-8");
            response.setContentType("text/html;charset=UTF-8");

            try (PrintWriter writer = response.getWriter();) {
                // 解决报错:getWriter() has already been called for this response] with root cause
                writer.print("<h1>手速慢点,请稍后在试一试!!!</h1>");
                writer.flush();
            } catch (Exception ex) {
                throw new RuntimeException("手速慢点,请稍后在试一试!!!");
            }
        }
    }
}

6、获取IP信息

package com.qbb.limit.utils;

import javax.servlet.http.HttpServletRequest;

public class RequestUtils {

    public static String getIpAddr(HttpServletRequest request) {
        if (request == null) {
            return "unknown";
        }
        String ip = request.getHeader("x-forwarded-for");
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("X-Forwarded-For");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("X-Real-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
        }
        return "0:0:0:0:0:0:0:1".equals(ip) ? "127.0.0.1" : ip;
    }
}

7、业务接口

package com.qbb.limit.controller;

import com.qbb.limit.aop.AccessLimiter;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
public class HelloController {
    @GetMapping("/hello")
    @AccessLimiter(timeout = 1, limit = 3) // 1秒钟超过3次限流
    public String index() {
        // 分布锁
        return "success";
    }

}

8、请求结果

java springboot 录屏如何实现 lua springboot_限流

java springboot 录屏如何实现 lua springboot_lua_02

五、总结

当前的限流方案是在业务层实现的,我们也可以在nginxgatewaynocas等实现限流,限流方式也分为滑动窗口漏斗算法令牌桶算法实现,大家可以结合自己使用场景选择request请求限流的方式。