调用速率限制是 Web API 中的常见要求,旨在防止滥用并确保公平使用资源。借助Spring Boot 中的 AOP,我们可以通过拦截方法调用并限制在特定时间范围内允许的请求数量来实现速率限制。
为了在 Spring Boot 中使用 AOP 实现速率限制:
可以使用各种技术在 Spring Boot API 中实现速率限制。一种常见的方法是使用 Spring AOP来拦截传入的请求并实施速率限制。
创建一个配置类,在其中定义速率限制参数,例如允许的请求数和时间段。
@Configuration
public class RateLimitConfig {
@Value("${rate.limit.requests}")
private int requests;
@Value("${rate.limit.seconds}")
private int seconds;
// Getters and setters
}
使用 Spring AOP 实现一个方面来拦截方法调用并强制执行速率限制。
@Aspect
@Component
public class RateLimitAspect {
@Autowired
private RateLimitConfig rateLimitConfig;
@Autowired
private RateLimiter rateLimiter;
@Around("@annotation(RateLimited)")
public Object enforceRateLimit(ProceedingJoinPoint joinPoint) throws Throwable {
String key = getKey(joinPoint);
if (!rateLimiter.tryAcquire(key, rateLimitConfig.getRequests(), rateLimitConfig.getSeconds())) {
throw new RateLimitExceededException("Rate limit exceeded");
}
return joinPoint.proceed();
}
private String getKey(ProceedingJoinPoint joinPoint) {
//为正在调用的方法生成唯一密钥
//方法签名、用户ID、IP地址等。
}
}
创建自定义注释来标记应受速率限制的方法。
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimited {
}
创建速率限制器组件,使用令牌桶算法或任何其他合适的算法来管理速率限制。
@Component
public class RateLimiter {
private final Map<String,RateLimitedSemaphore> semaphores = new ConcurrentHashMap<>();
public boolean tryAcquire(String key, int requests, int seconds) {
long currentTime = System.currentTimeMillis();
// 计算时间窗口
long startTime = currentTime - seconds * 1000;
// 过期删除
cleanupExpiredEntries(startTime);
// 获取semaphore
RateLimitedSemaphore semaphore = semaphores.computeIfAbsent(key, k -> {
RateLimitedSemaphore newSemaphore = new RateLimitedSemaphore(requests);
newSemaphore.setLastAcquireTime(currentTime); // Set last acquire time
return newSemaphore;
});
// 校验 semaphore
boolean acquired = semaphore.tryAcquire();
if (acquired) {
semaphore.setLastAcquireTime(currentTime);
// 更新
}
return acquired;
}
private void cleanupExpiredEntries(long startTime) {
Iterator<Map.Entry<String, RateLimitedSemaphore>> iterator = semaphores.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry<String, RateLimitedSemaphore> entry = iterator.next();
String key = entry.getKey();
RateLimitedSemaphore semaphore = entry.getValue();
if (semaphore.getLastAcquireTime() < startTime) {
iterator.remove();
}
}
}
private class RateLimitedSemaphore extends Semaphore {
private volatile long lastAcquireTime;
public RateLimitedSemaphore(int permits) {
super(permits);
}
public long getLastAcquireTime() {
return lastAcquireTime;
}
public void setLastAcquireTime(long lastAcquireTime) {
this.lastAcquireTime = lastAcquireTime;
}
}
}
用注解来注释应该进行速率限制的控制器方法 @RateLimited。
@RestController
public class MyController {
@RateLimited
@GetMapping("/api/resource")
public ResponseEntity<String> getResource() {
// Implementation
}
}
application.properties在您的 或 中配置速率限制属性 application.yml。
rate.limit.requests=10
rate.limit.seconds=60
要按 IP 地址限制请求,可以从传入请求中提取 IP 地址并将其用作速率限制的密钥:
private String getKey(HttpServletRequest request) {
String ipAddress = request.getRemoteAddr();
return ipAddress; //用ID做key
}
还需要修改enforceRateLimit 中的方法 RateLimitAspect 以将对象传递 HttpServletRequest 给 getKey 方法:
@Around("@annotation(RateLimited)")
public Object enforceRateLimit(ProceedingJoinPoint joinPoint) throws Throwable {
ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
HttpServletRequest request = requestAttributes.getRequest();
String key = getKey(request);
if (!rateLimiter.tryAcquire(key, rateLimitConfig.getRequests(), rateLimitConfig.getSeconds())) {
throw new RateLimitExceededException("Rate limit exceeded");
}
return joinPoint.proceed();
}