前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >利用Redis实现限流

利用Redis实现限流

作者头像
大忽悠爱学习
发布2023-02-13 15:55:00
2.2K0
发布2023-02-13 15:55:00
举报
文章被收录于专栏:c++与qt学习c++与qt学习

利用Redis实现限流


思路

redis实现限流的核心思路是利用redis提供的key过期时间作为限流窗口期,key的值记录该窗口期内已经产生的访问资源次数,key本身记录限流的资源范围。

具体步骤如下:

  • 首先规定资源限制范围,一般都是限制对某个接口的调用频率,因此key使用接口方法名即可
  • 第一次访问资源时,key不存在,那么新创建一个key,并将值设置为1,最后设置key的过期时间,表示开启限流窗口期
  • 每一次访问资源,会首先判断当前是否存在限流窗口期,如果存在,将访问次数加一,并判断是否达到最大资源访问次数限制
  • 如果达到了,则抛出异常,告诉用户访问频繁,请稍后再试
  • 如果没达到,则放行请求
  • 在不是第一次访问资源的前提下,如果发现限流窗口期过了,那么重新开启一个

步骤

1.准备工作

  • 引入redis相关依赖
代码语言:javascript
复制
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
 <dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-pool2</artifactId>
 </dependency>
  • 添加相关配置信息
代码语言:javascript
复制
spring:
  redis:
    host: xxx
    port: 6379
    password: xxx
    lettuce:
      #只有自动配置连接池的依赖,连接池才会生效
      pool:
        max-active: 8 #最大连接
        max-idle: 8 #最大空闲连接
        min-idle: 0 #最小空闲连接
        max-wait: 100 #连接等待时间
  • 修改redisTemplate的序列化方式为JSON
代码语言:javascript
复制
    @ConditionalOnMissingBean
    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory)
    {
        //创建template
        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        //设置连接工厂
        redisTemplate.setConnectionFactory(redisConnectionFactory);
        //设置序列化工具
        GenericJackson2JsonRedisSerializer jsonRedisSerializer = new GenericJackson2JsonRedisSerializer();
        //key和hashKey采用String序列化
        redisTemplate.setKeySerializer(RedisSerializer.string());
        redisTemplate.setHashKeySerializer(RedisSerializer.string());
        //value和hashValue用JSON序列化
        redisTemplate.setValueSerializer(jsonRedisSerializer);
        redisTemplate.setHashValueSerializer(jsonRedisSerializer);
        return redisTemplate;
    }

2.限流核心类实现

  • 定义一个顶层的流量控制接口实现,pass方法返回true,表示方向请求,否则表示请求被拦截了
代码语言:javascript
复制
/**
 * 流量控制
 * @author 大忽悠
 * @create 2023/2/6 10:50
 */
public interface RateLimiter {
     /**
      * @param requestInfo 请求信息
      * @return 当前请求是否允许通过
      */
     boolean pass(RequestInformation requestInfo);
}
  • requestInfo提供当前请求的相关信息
代码语言:javascript
复制
/**
 * 请求信息
 * @author 大忽悠
 * @create 2023/2/6 10:55
 */
@Data
public class RequestInformation {
    /**
     * 限流key
     */
    private String key;
    /**
     * 限流时间
     */
    private int time;
    /**
     * time时间内最大请求资源次数
     */
    private int count;
    /**
     * 限流类型
     */
    private int limitType;
    /**
     * 请求的方法信息
     */
    private Method method;
    /**
     * 方法参数信息
     */
    private Object[] arguments;
    /**
     * 客户端IP地址
     */
    private String ip;
    private HttpServletRequest httpServletRequest;
    private HttpServletResponse httpServletResponse;

    public RequestInformation() {
    }
}
  • 提供一个限流注解,该注解可以标注在方法或者类上,标注在类上,则表示当前类所有方法都需要流量控制
代码语言:javascript
复制
/**
 * 限流注解
 * @author 大忽悠
 * @create 2023/2/6 10:39
 */
@Target({ElementType.METHOD,ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Limiter {
    /**
     * @return 限流key--默认为rate_limit:业务名:类名.方法名 ,如果限制了IP类型,则为: rate_limit:业务名:ip:类名.方法名
     */
    String key() default "";
    /**
     * @return 限流时间,单位为s
     */
    int time() default 60;
    /**
     * @return time时间内限制的资源请求次数
     */
    int count() default 100;
    /**
     * @return 限流类型
     */
    int limitType() default LimitType.DEFAULT;
}
  • redis作为限流器的实现
代码语言:javascript
复制
public class RedisRateLimiterImpl implements RateLimiter{
    private static final String RATE_LIMITER_KEY_PREFIX="rate_limiter";
    /**
     * 使用redis做限流处理使用的lua脚本
     */
    private static  final String LIMITER_LUA=
            "local key = KEYS[1]\n" +
            "local count = tonumber(ARGV[1])\n" +
            "local time = tonumber(ARGV[2])\n" +
            "local current = redis.call('get', key)\n" +
            "if current and tonumber(current) > count then\n" +
            "    return 1\n" +
            "end\n" +
            "current = redis.call('incr', key)\n" +
            "if tonumber(current) == 1 then\n" +
            "    redis.call('expire', key, time)\n" +
            "end\n" +
            "return 0\n";
    private RedisTemplate<String, Object> redisTemplate;

    public RedisRateLimiterImpl(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    /**
     * @param requestInfo 请求信息
     * @return 当前请求是否允许通过
     */
    @Override
    public boolean pass(RequestInformation requestInfo) {
        //拿到限流key
        String limiterKey=getRateLimiterKey(requestInfo);
        //执行lua脚本
        Long limiterRes = redisTemplate.execute(RedisScript.of(LIMITER_LUA,Long.class), List.of(limiterKey), requestInfo.getCount(), requestInfo.getTime());
        //判断限流结果
        return limiterRes==0L;
    }


    private String getRateLimiterKey(RequestInformation requestInfo) {
         return combineKey(RATE_LIMITER_KEY_PREFIX,
                 requestInfo.getKey(),
                 requestInfo.getIp(),
                 requestInfo.getMethod().getClass().getName(),
                 requestInfo.getMethod().getName());
    }

    private String combineKey(String ... keys) {
        StringBuilder keyBuilder=new StringBuilder();
        for (int i = 0; i < keys.length; i++) {
              if(StringUtils.isEmpty(keys[i])){
                  continue;
              }
              keyBuilder.append(keys[i]);
              if(i==keys.length-1){
                  continue;
              }
              keyBuilder.append(":");
        }
        return keyBuilder.toString();
    }
}

lua脚本解释:

KEYS 和 ARGV 都是一会调用时候传进来的参数,tonumber 就是把字符串转为数字,redis.call 就是执行具体的 redis 指令,具体流程是这样:

  • 首先获取到传进来的 key 以及 限流的 count 和时间 time。
  • 通过 get 获取到这个 key 对应的值,这个值就是当前时间窗内这个接口可以访问多少次。
  • 如果是第一次访问,此时拿到的结果为 nil,否则拿到的结果应该是一个数字,所以接下来就判断,如果拿到的结果是一个数字,并且这个数字还大于 count,那就说明已经超过流量限制了,那么返回1表示请求拦截。
  • 如果拿到的结果为 nil,说明是第一次访问,此时就给当前 key 自增 1,然后设置一个过期时间。
  • 最后返回0表示请求放行。

注意; lua脚本也可以定义在文件在,然后通过加载文件获取

代码语言:javascript
复制
@Bean
public DefaultRedisScript<Long> limitScript() {
    DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
    redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
    redisScript.setResultType(Long.class);
    return redisScript;
}

或者在 Redis 服务端定义好 Lua 脚本,然后计算出来一个散列值,在 Java 代码中,通过这个散列值锁定要执行哪个 Lua 脚本


3.aop相关逻辑实现

我们需要将限流逻辑在需要流量管控的方法执行前先执行,因此需要拦截目标方法,有两个思路:

  1. 通过@Aspect注解标注一个切面类,用@Before或者@Around注解标注在切面方法上,里面填写限流管控逻辑
  2. 手动编写一个advisor增强器,注入容器,并提供相关拦截器和pointcut实现

这里我采用的是手动编写advisor的方式进行实现,下面演示具体步骤:

  • 编写拦截器
代码语言:javascript
复制
/**
 * 限流方法拦截器
 * @author 大忽悠
 * @create 2023/2/6 11:08
 */
@Slf4j
public class RateLimiterMethodInterceptor implements MethodInterceptor {
    private final RateLimiter rateLimiter;

    public RateLimiterMethodInterceptor(RateLimiter rateLimiter) {
        this.rateLimiter=rateLimiter;
    }

    @Override
    public Object invoke(MethodInvocation invocation) throws Throwable {
        try{
            RequestInformation requestInformation = new RequestInformation();
            buildMethodInfo(requestInformation,invocation);
            buildLimitInfo(requestInformation);
            buildRequestInfo(requestInformation);
            if (rateLimiter.pass(requestInformation)){
                return invocation.proceed();
            }
            logWarn(requestInformation);
        }catch (Exception e){
           e.printStackTrace();
           throw e;
        }
        throw new RateLimiterException("访问过于频繁,请稍后再试!");
    }

    private void logWarn(RequestInformation requestInformation) {
        if(requestInformation.getHttpServletRequest()!=null){
            log.warn("rateLimiter拦截了一个请求,该请求信息如下: URI: {} ,IP: {} ,方法名: {} ,方法参数信息: {} ",
                    requestInformation.getHttpServletRequest().getRequestURI(),requestInformation.getIp(),requestInformation.getMethod().getName(),
                    Arrays.toString(requestInformation.getArguments()));
        }else {
            log.warn("rateLimiter拦截了一个请求,该请求信息如下: 方法名: {} ,方法参数信息: {} ",
                    requestInformation.getMethod().getName(), Arrays.toString(requestInformation.getArguments()));
        }
    }

    private void buildLimitInfo(RequestInformation requestInformation) throws RateLimiterException {
        Method method = requestInformation.getMethod();
        Limiter limiter;
        if(method.isAnnotationPresent(Limiter.class)){
            limiter = method.getAnnotation(Limiter.class);
        }else {
            limiter=method.getClass().getAnnotation(Limiter.class);
        }
        if(limiter==null){
            throw new RateLimiterException("无法在当前方法"+method.getName()+"或者类"+method.getClass().getName()+"上寻找到@Limiter注解");
        }
        requestInformation.setKey(limiter.key());
        requestInformation.setCount(limiter.count());
        requestInformation.setTime(limiter.time());
        requestInformation.setLimitType(limiter.limitType());
    }

    private void buildMethodInfo(RequestInformation requestInformation, MethodInvocation invocation) {
        requestInformation.setMethod(invocation.getMethod());
        if(invocation instanceof ReflectiveMethodInvocation){
            ReflectiveMethodInvocation reflectiveMethodInvocation = (ReflectiveMethodInvocation) invocation;
            requestInformation.setArguments(reflectiveMethodInvocation.getArguments());
        }
    }

    /**
     * 从线程上下文中取出请求和响应相关信息
     */
    private void buildRequestInfo(RequestInformation requestInformation) {
        RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
        if(requestAttributes instanceof ServletRequestAttributes){
            ServletRequestAttributes sra = (ServletRequestAttributes) requestAttributes;
            requestInformation.setHttpServletRequest(sra.getRequest());
            requestInformation.setHttpServletResponse(sra.getResponse());
        }
        if(requestInformation.getHttpServletRequest()!=null && requestInformation.getLimitType()==LimitType.IP){
            requestInformation.setIp(IPUtils.getIpAddress(requestInformation.getHttpServletRequest()));
        }
    }
}
  • 编写advisor增强器
代码语言:javascript
复制
/**
 * 限流增强器
 *
 * @author 大忽悠
 * @create 2023/2/6 10:57
 */
public class RateLimiterAdvisor extends AbstractPointcutAdvisor {
    private Pointcut pointcut;
    private RateLimiterMethodInterceptor rateLimiterMethodInterceptor;

    public RateLimiterAdvisor(RateLimiter rateLimiter) {
        pointcut = buildPointCut();
        rateLimiterMethodInterceptor=new RateLimiterMethodInterceptor(rateLimiter);
    }

    @Override
    public Pointcut getPointcut() {
        return pointcut;
    }

    @Override
    public Advice getAdvice() {
        return rateLimiterMethodInterceptor;
    }


    private Pointcut buildPointCut() {
        return new Pointcut() {
            @Override
            public ClassFilter getClassFilter() {
                return (c)-> AnnotationUtils.isCandidateClass(c,Limiter.class);
            }

            @Override
            public MethodMatcher getMethodMatcher() {
                return new StaticMethodMatcher() {
                    @Override
                    public boolean matches(Method method, Class<?> targetClass) {
                        return method.isAnnotationPresent(Limiter.class) || targetClass.isAnnotationPresent(Limiter.class);
                    }
                };
            }
        };
    }
}
  • 使用配置类将advisor增强器放入容器中
代码语言:javascript
复制
/**
 * @author 大忽悠
 * @create 2023/2/6 11:14
 */
@Configuration
public class RateLimiterAutoConfiguration {
    @Bean
    @ConditionalOnMissingBean
    public RateLimiterAdvisor rateLimiterAdvisor(RateLimiter rateLimiter) {
        return new RateLimiterAdvisor(rateLimiter);
    }

    @Bean
    @ConditionalOnMissingBean
    public RateLimiter rateLimiter(RedisTemplate<String, Object> redisTemplate) {
        return new RedisRateLimiterImpl(redisTemplate);
    }
    ...
}

采用切面进行实现,可以参考江南一点雨大佬给出的实现:

代码语言:javascript
复制
@Aspect
@Component
public class RateLimiterAspect {
    private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class);

    @Autowired
    private RedisTemplate<Object, Object> redisTemplate;

    @Autowired
    private RedisScript<Long> limitScript;

    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable {
        String key = rateLimiter.key();
        int time = rateLimiter.time();
        int count = rateLimiter.count();

        String combineKey = getCombineKey(rateLimiter, point);
        List<Object> keys = Collections.singletonList(combineKey);
        try {
            Long number = redisTemplate.execute(limitScript, keys, count, time);
            if (number==null || number.intValue() > count) {
                throw new ServiceException("访问过于频繁,请稍候再试");
            }
            log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number.intValue(), key);
        } catch (ServiceException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException("服务器限流异常,请稍候再试");
        }
    }

    public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
        StringBuffer stringBuffer = new StringBuffer(rateLimiter.key());
        if (rateLimiter.limitType() == LimitType.IP) {
            stringBuffer.append(IpUtils.getIpAddr(((ServletRequestAttributes) RequestContextHolder.currentRequestAttributes()).getRequest())).append("-");
        }
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        stringBuffer.append(targetClass.getName()).append("-").append(method.getName());
        return stringBuffer.toString();
    }
}

Redis 做接口限流,一个注解的事


4.全局异常拦截

代码语言:javascript
复制
    @ResponseStatus()
    @ExceptionHandler(UndeclaredThrowableException.class)
    public Result exception(UndeclaredThrowableException e) {
        log.error("错误类型为RateLimiterException : "+e);
        Throwable cause = e.getCause();
        if(cause instanceof RateLimiterException){
            RateLimiterException ex= (RateLimiterException)cause;
            return Result.error(ex.getMessage(), ex.getMessage());
        }
        return Result.error(cause.getMessage(),cause.getMessage());
    }

5.测试执行

代码语言:javascript
复制
@RestController
@RequestMapping("order")
public class OrderController {

    @Limiter(time = 20,count = 5,limitType = LimitType.IP)
    @GetMapping("/{orderId}")
    public Order queryOrderByUserId(@PathVariable("orderId") Long orderId) {
        return Order.builder().id(1L).name("大忽悠").userId(3L).price(10L).build();
    }
}

每一个 IP 地址,在 20 秒内只能访问5次,大家可以手动测试。

在这里插入图片描述
在这里插入图片描述

完整代码

代码语言:javascript
复制
/**
 * 限流注解
 * @author 大忽悠
 * @create 2023/2/6 10:39
 */
@Target({ElementType.METHOD,ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Limiter {
    /**
     * @return 限流key--默认为rate_limit:业务名:类名.方法名 ,如果限制了IP类型,则为: rate_limit:业务名:ip:类名.方法名
     */
    String key() default "";
    /**
     * @return 限流时间,单位为s
     */
    int time() default 60;
    /**
     * @return time时间内限制的资源请求次数
     */
    int count() default 100;
    /**
     * @return 限流类型
     */
    int limitType() default LimitType.DEFAULT;
}

/**
 * 限流类型
 * @author 大忽悠
 * @create 2023/2/6 10:43
 */
public interface LimitType {
    /**
     * 默认限流类型
     */
    int DEFAULT=0;
    /**
     * 根据IP进行限制
     */
    int IP=1;
}


/**
 * 流量控制
 * @author 大忽悠
 * @create 2023/2/6 10:50
 */
public interface RateLimiter {
     /**
      * @param requestInfo 请求信息
      * @return 当前请求是否允许通过
      */
     boolean pass(RequestInformation requestInfo);
}


/**
 * 限流增强器
 *
 * @author 大忽悠
 * @create 2023/2/6 10:57
 */
public class RateLimiterAdvisor extends AbstractPointcutAdvisor {
    private Pointcut pointcut;
    private RateLimiterMethodInterceptor rateLimiterMethodInterceptor;

    public RateLimiterAdvisor(RateLimiter rateLimiter) {
        pointcut = buildPointCut();
        rateLimiterMethodInterceptor=new RateLimiterMethodInterceptor(rateLimiter);
    }

    @Override
    public Pointcut getPointcut() {
        return pointcut;
    }

    @Override
    public Advice getAdvice() {
        return rateLimiterMethodInterceptor;
    }


    private Pointcut buildPointCut() {
        return new Pointcut() {
            @Override
            public ClassFilter getClassFilter() {
                return (c)-> AnnotationUtils.isCandidateClass(c,Limiter.class);
            }

            @Override
            public MethodMatcher getMethodMatcher() {
                return new StaticMethodMatcher() {
                    @Override
                    public boolean matches(Method method, Class<?> targetClass) {
                        return method.isAnnotationPresent(Limiter.class) || targetClass.isAnnotationPresent(Limiter.class);
                    }
                };
            }
        };
    }
}

/**
 * @author 大忽悠
 * @create 2023/2/6 11:14
 */
@Configuration
public class RateLimiterAutoConfiguration {
    @Bean
    @ConditionalOnMissingBean
    public RateLimiterAdvisor rateLimiterAdvisor(RateLimiter rateLimiter) {
        return new RateLimiterAdvisor(rateLimiter);
    }

    @Bean
    @ConditionalOnMissingBean
    public RateLimiter rateLimiter(RedisTemplate<String, Object> redisTemplate) {
        return new RedisRateLimiterImpl(redisTemplate);
    }

    @ConditionalOnMissingBean
    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory)
    {
        //创建template
        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        //设置连接工厂
        redisTemplate.setConnectionFactory(redisConnectionFactory);
        //设置序列化工具
        GenericJackson2JsonRedisSerializer jsonRedisSerializer = new GenericJackson2JsonRedisSerializer();
        //key和hashKey采用String序列化
        redisTemplate.setKeySerializer(RedisSerializer.string());
        redisTemplate.setHashKeySerializer(RedisSerializer.string());
        //value和hashValue用JSON序列化
        redisTemplate.setValueSerializer(jsonRedisSerializer);
        redisTemplate.setHashValueSerializer(jsonRedisSerializer);
        return redisTemplate;
    }
}

/**
 * 限流异常
 * @author 大忽悠
 * @create 2023/2/6 11:39
 */
public class RateLimiterException extends Exception {

    public RateLimiterException(String ex) {
        super(ex);
    }
}

/**
 * 限流方法拦截器
 * @author 大忽悠
 * @create 2023/2/6 11:08
 */
@Slf4j
public class RateLimiterMethodInterceptor implements MethodInterceptor {
    private final RateLimiter rateLimiter;

    public RateLimiterMethodInterceptor(RateLimiter rateLimiter) {
        this.rateLimiter=rateLimiter;
    }

    @Override
    public Object invoke(MethodInvocation invocation) throws Throwable {
        try{
            RequestInformation requestInformation = new RequestInformation();
            buildMethodInfo(requestInformation,invocation);
            buildLimitInfo(requestInformation);
            buildRequestInfo(requestInformation);
            if (rateLimiter.pass(requestInformation)){
                return invocation.proceed();
            }
            logWarn(requestInformation);
        }catch (Exception e){
           e.printStackTrace();
           throw e;
        }
        throw new RateLimiterException("访问过于频繁,请稍后再试!");
    }

    private void logWarn(RequestInformation requestInformation) {
        if(requestInformation.getHttpServletRequest()!=null){
            log.warn("rateLimiter拦截了一个请求,该请求信息如下: URI: {} ,IP: {} ,方法名: {} ,方法参数信息: {} ",
                    requestInformation.getHttpServletRequest().getRequestURI(),requestInformation.getIp(),requestInformation.getMethod().getName(),
                    Arrays.toString(requestInformation.getArguments()));
        }else {
            log.warn("rateLimiter拦截了一个请求,该请求信息如下: 方法名: {} ,方法参数信息: {} ",
                    requestInformation.getMethod().getName(), Arrays.toString(requestInformation.getArguments()));
        }
    }

    private void buildLimitInfo(RequestInformation requestInformation) throws RateLimiterException {
        Method method = requestInformation.getMethod();
        Limiter limiter;
        if(method.isAnnotationPresent(Limiter.class)){
            limiter = method.getAnnotation(Limiter.class);
        }else {
            limiter=method.getClass().getAnnotation(Limiter.class);
        }
        if(limiter==null){
            throw new RateLimiterException("无法在当前方法"+method.getName()+"或者类"+method.getClass().getName()+"上寻找到@Limiter注解");
        }
        requestInformation.setKey(limiter.key());
        requestInformation.setCount(limiter.count());
        requestInformation.setTime(limiter.time());
        requestInformation.setLimitType(limiter.limitType());
    }

    private void buildMethodInfo(RequestInformation requestInformation, MethodInvocation invocation) {
        requestInformation.setMethod(invocation.getMethod());
        if(invocation instanceof ReflectiveMethodInvocation){
            ReflectiveMethodInvocation reflectiveMethodInvocation = (ReflectiveMethodInvocation) invocation;
            requestInformation.setArguments(reflectiveMethodInvocation.getArguments());
        }
    }

    /**
     * 从线程上下文中取出请求和响应相关信息
     */
    private void buildRequestInfo(RequestInformation requestInformation) {
        RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
        if(requestAttributes instanceof ServletRequestAttributes){
            ServletRequestAttributes sra = (ServletRequestAttributes) requestAttributes;
            requestInformation.setHttpServletRequest(sra.getRequest());
            requestInformation.setHttpServletResponse(sra.getResponse());
        }
        if(requestInformation.getHttpServletRequest()!=null && requestInformation.getLimitType()==LimitType.IP){
            requestInformation.setIp(IPUtils.getIpAddress(requestInformation.getHttpServletRequest()));
        }
    }
}

public class RedisRateLimiterImpl implements RateLimiter{
    private static final String RATE_LIMITER_KEY_PREFIX="rate_limiter";
    /**
     * 使用redis做限流处理使用的lua脚本
     */
    private static  final String LIMITER_LUA=
            "local key = KEYS[1]\n" +
            "local count = tonumber(ARGV[1])\n" +
            "local time = tonumber(ARGV[2])\n" +
            "local current = redis.call('get', key)\n" +
            "if current and tonumber(current) > count then\n" +
            "    return 1\n" +
            "end\n" +
            "current = redis.call('incr', key)\n" +
            "if tonumber(current) == 1 then\n" +
            "    redis.call('expire', key, time)\n" +
            "end\n" +
            "return 0\n";
    private RedisTemplate<String, Object> redisTemplate;

    public RedisRateLimiterImpl(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    /**
     * @param requestInfo 请求信息
     * @return 当前请求是否允许通过
     */
    @Override
    public boolean pass(RequestInformation requestInfo) {
        //拿到限流key
        String limiterKey=getRateLimiterKey(requestInfo);
        //执行lua脚本
        Long limiterRes = redisTemplate.execute(RedisScript.of(LIMITER_LUA,Long.class), List.of(limiterKey), requestInfo.getCount(), requestInfo.getTime());
        //判断限流结果
        return limiterRes==0L;
    }


    private String getRateLimiterKey(RequestInformation requestInfo) {
         return combineKey(RATE_LIMITER_KEY_PREFIX,
                 requestInfo.getKey(),
                 requestInfo.getIp(),
                 requestInfo.getMethod().getClass().getName(),
                 requestInfo.getMethod().getName());
    }

    private String combineKey(String ... keys) {
        StringBuilder keyBuilder=new StringBuilder();
        for (int i = 0; i < keys.length; i++) {
              if(StringUtils.isEmpty(keys[i])){
                  continue;
              }
              keyBuilder.append(keys[i]);
              if(i==keys.length-1){
                  continue;
              }
              keyBuilder.append(":");
        }
        return keyBuilder.toString();
    }
}


/**
 * 请求信息
 * @author 大忽悠
 * @create 2023/2/6 10:55
 */
@Data
public class RequestInformation {
    /**
     * 限流key
     */
    private String key;
    /**
     * 限流时间
     */
    private int time;
    /**
     * time时间内最大请求资源次数
     */
    private int count;
    /**
     * 限流类型
     */
    private int limitType;
    /**
     * 请求的方法信息
     */
    private Method method;
    /**
     * 方法参数信息
     */
    private Object[] arguments;
    /**
     * 客户端IP地址
     */
    private String ip;
    private HttpServletRequest httpServletRequest;
    private HttpServletResponse httpServletResponse;

    public RequestInformation() {
    }
}

/**
 * @author 大忽悠
 * @create 2023/2/6 12:51
 */
public class IPUtils {
    /**
     * 获取用户真实IP地址,不使用request.getRemoteAddr();的原因是有可能用户使用了代理软件方式避免真实IP地址,
     * 参考文章: http://developer.51cto.com/art/201111/305181.htm
     *
     * 可是,如果通过了多级反向代理的话,X-Forwarded-For的值并不止一个,而是一串IP值,究竟哪个才是真正的用户端的真实IP呢?
     * 答案是取X-Forwarded-For中第一个非unknown的有效IP字符串。
     *
     * 如:X-Forwarded-For:192.168.1.110, 192.168.1.120, 192.168.1.130,
     * 192.168.1.100
     *
     * 用户真实IP为: 192.168.1.110
     *
     * @param request
     * @return
     */
    public static String getIpAddress(HttpServletRequest request) {
        String ip = request.getHeader("x-forwarded-for");
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_CLIENT_IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_X_FORWARDED_FOR");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
        }
        return ip;
    }
}
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2023-02-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 利用Redis实现限流
  • 思路
  • 步骤
    • 1.准备工作
      • 2.限流核心类实现
        • 3.aop相关逻辑实现
          • 4.全局异常拦截
            • 5.测试执行
            • 完整代码
            相关产品与服务
            云数据库 Redis
            腾讯云数据库 Redis(TencentDB for Redis)是腾讯云打造的兼容 Redis 协议的缓存和存储服务。丰富的数据结构能帮助您完成不同类型的业务场景开发。支持主从热备,提供自动容灾切换、数据备份、故障迁移、实例监控、在线扩容、数据回档等全套的数据库服务。
            领券
            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档