在后台开发中,服务端的限流器是一个很常见并且十分有用的组件,利用好限流器可以限制请求速率,保护后台服务。 比较常见的限流器分为两种,漏桶算法和令牌桶算法。
漏桶算法原理很简单,用一个漏斗来控制请求的速率。在漏斗上方是收到的所有请求,请求就像水一样会进入漏斗中,同时漏斗也会以恒定的速度将水(请求)从下方进行排出,被排出的水(请求)才能访问服务。当请求量不大时候,如进水速率 < 出水速率那么其实漏斗并没有起到作用;当请求量很大的时候,超过漏斗容量的请求将被溢出,并且出水口可以一直保证恒定的速率。
令牌桶算法原理也很简单,假设我们的服务允许请求速度上限为5000次/分,那么这就意味着桶内的令牌数为5000,并且每隔一分钟桶内的令牌数就会被重置为5000。每一个请求过来都需要从桶内拿一块令牌,如果能取得令牌则允许访问服务,否则将会拒绝请求。
本文将基于redis来设计一个在分布式场景下的令牌桶算法,旨在重点解决以下问题:
在实际场景中,服务的限流往往会和一些参数绑定在一起,比如:限制同一个ip地址的请求速率为5000次/分,限制某一个业务id的请求速率为5000次/分,根据这些绑定的变量数值,我们可以在redis中设置对应的key,通过不断累加该key对应的数值来实现限流器的设计。
假设我们服务请求速率的最大值max
为5000
次/分。
当服务器收到请求时,首先判断redis中对应键k
的数值v
是否超过5000
,如果是则拒绝请求,如果为否则继续判断v
是否为0
,当v
为0
的时候,我们需要进行初始化。初始化需要将v
的值置为1
,并且设置过期时间为60
s。考虑以下几个问题:
incr
命令?对于第一个问题,答案肯定是必然的,我们需要保证只有一个请求
能进行初始化,否则在并发情况下会出现多个请求线程都对v
进行置1
操作,从而导致计数器不准确。
那么如何进行加锁操作呢?在分布式场景下是用本地锁是不正确的,因此我们同样可以利用redis的SET .. NX
命令来实现分布式锁,来保证只有一个线程能进行初始化。
有一个需要注意的细节是:线程在获得锁之后,还需要在读取一次v的值
,如果此时读取到数值不为0
则说明在此之前已经被其他线程捷足先登了,此时就应该放弃初始化。
对于第二个问题,虽然redis的incr
命令也可以保证只有一个请求线程能进行置1
操作(因为redis是单线程的,天然满足锁),但是incr
没有办法设置过期时间,因此不能直接使用incr
命令。
代码如下:
func initCount(key string) (int, error) {
lockKey := key + "_" + initLockKeySuffix
getLock := false
var err error
// 一个循环去抢占锁
for !getLock {
// 再读取一次count数值,保证只有一个线程进行初始化
// 这一步很重要,可能在抢锁的过程中已经有其他线程完成了初始化,那么此线程就不需要初始化了
count, err := redis.GetInt(key)
if err != nil {
// 记录错误继续抢占初始化
log.Errorf("get redis failed, err: %v", err)
time.Sleep(time.Millisecond * 10)
continue
}
if count > 0 {
// 不是第一个线程,放弃初始化
return -1, nil
}
// 设置一个3s过期的nx锁
getLock, err = redis.SetNxWithExpire(lockKey, "ok", 3)
if err != nil {
// 记录错误继续抢占初始化
log.Errorf("set nx failed, err: %v", err)
time.Sleep(time.Millisecond * 10)
continue
}
if getLock {
// 抢到,退出循环
break
} else {
// 没抢到锁,等下一次
time.Sleep(time.Millisecond * 100)
}
}
// 获得锁之后开始进行初始化
ok, err := redis.SetIntWithExpire(key, 1, comm.Interval)
if err != nil || !ok {
// 初始化失败
return 0, myError.WithMessage(err, "redis init failed")
}
log.Info("init success")
// 删除锁
e := redis.DelKey(lockKey)
if e != nil {
// 只能被动等待锁过期
log.Errorf("redis del key failed, key: %s, err: %s", key, e)
}
// 初始化成功
return 1, nil
}
对于上述抢占失败的线程,以及新来的请求线程就没有必要继续初始化了,而是直接对v
值进行加1
操作。考虑以下几个问题:
1
操作是否需要加锁?使用redis的incr
命令进行加1
操作,由于redis天然是单线程的,因此加1
操作是不需要进行加锁的。对于每一个请求,可以通过判断incr
返回值是否大于max
来决定是否拒绝请求。
在并发的情况下,假设我们的服务限制访问速率为5000
次/分,在某一时刻t
请求数量已经达到了4999
次,此时突然并发来了10
个请求,按照上面设计的流程,这10个请求首先读取redis中对应键k
的数值v
,同时读取到了4999
这个值,那么则会都进行加1
操作,于是redis中对应键k
最终值则为5009
,超过了5000
,虽然不影响服务,但是redis中值却超过了预期值,为了解决边界问题
,我采用了阈值法
,根据业务的需求可以事先估计一个阈值δ
,比如80%
,当redis中对应键k
的数值v
小于max * δ
时,则不加锁直接使用incr
进行加1
,当超过时,则进行加锁排队加1
。
代码如下
func increaseCountWithoutLock(key string) (int, error) {
// 直接进行加1,上层对加1后的数值进行判断
return redis.IncrOne(key)
}
func increaseCountWithLock(bid int, key string, count int) (int, error) {
if float64(count) < comm.RateThreshold*comm.RateLimit {
// 没有达到阈值,直接使用redis的incr来保证原子性
return redis.IncrOne(key)
}
// 达到阈值后incr操作需要排队
newCount, err := increaseSerialized(bid, key)
if err != nil {
return 0, myError.WithMessage(err, "increaseSerialized failed")
}
return newCount, nil
}
func increaseSerialized(bid int, key string) (int, error) {
lockKey := strconv.Itoa(bid) + incrLockKeySuffix
getLock := false
var err error
for !getLock {
// 再读取一次
oldCount, err := redis.GetInt(key)
if err != nil {
// 记录错误继续抢占资源
log.Errorf("set nx failed, err: %v", err)
time.Sleep(time.Millisecond * 10)
continue
}
if oldCount >= comm.RateLimit {
// redis不用加1,直接返回
return oldCount + 1, nil
}
// 设置一个3s过期的nx锁
getLock, err = redis.SetNxWithExpire(lockKey, "ok", 3)
if err != nil {
return -1, myError.WithMessage(err, "set redis lock failed")
}
if getLock {
break
} else {
time.Sleep(time.Millisecond * 100)
}
}
// 获得锁之后开始进行加1操作
newCount, err := redis.IncrOne(key)
// 删除锁
e := redis.DelKey(lockKey)
if e != nil {
// 只能被动等待锁过期
log.Errorf("redis del key failed, key: %s, err: %s", key, e)
}
return newCount, err
}
第一种方式不需要加锁,代码简单,但是没有保证redis中计数器的正确性,即没有满足解决问题(但是不影响业务);第二种方式在达到阈值后需要加锁,代码较为复杂。
在初始化redis计数器时,我们使用了SET...EX
方式设置了过期时间,但是在实际中可能出现key
过期后却没有自动删除的现象,于是这里加上了手动删除过期key
的监控,采用redis的ttl
和del
命令组合来重置计数器。
代码如下
func ttlCount(key string) error {
leftTime, err := redis.GetTtl(key)
log.Infof("key: %s, left time: %d", key, leftTime)
if err != nil {
return myError.WithMessage(err, "")
}
if leftTime == -1 || leftTime > comm.Interval {
// 说明此时key没有设置过期时间或者超时时间出错,则进行删除
return redis.DelKey(key)
}
return nil
}