场景:限制请求后端接口的频率,例如1秒钟内的请求次数不能超过10次。通常的写法是:
1.先去从redis里面拿到当前请求次数
2.判断当前次数是否大于或等于限制次数
3.当前请求次数小于限制次数时进行自增
请求不密集时,这三步执行得很快,通常不会出问题;但如果两个请求几乎在同一时刻到来,"读取、判断、自增"这三步之间无法保证原子性:两个请求可能同时读到相同的次数(例如9),都通过了判断,再各自自增,最终实际请求次数超过限制。
改进方式:使用redis的lua脚本,将"读取值、判断大小、自增"放到redis的一次操作中。redis以单线程串行执行命令,一个lua脚本在执行期间不会有其他命令插入,因此整个脚本相当于一次原子操作。
自增的lua脚本如下
/**
 * Builds the Lua script that atomically checks and increments a rate-limit counter.
 *
 * <p>KEYS[1] is the counter key, ARGV[1] the maximum count, ARGV[2] the window length in seconds.
 * If the stored value already exceeds the maximum it is returned unchanged; otherwise the key is
 * incremented, and on the first increment (value 1) its expiry is set to the window length.
 * The script returns the resulting counter value.
 */
private String maxCountScriptText() {
    String[] scriptLines = {
        "local key = KEYS[1]",
        "local count = tonumber(ARGV[1])",
        "local time = tonumber(ARGV[2])",
        "local current = redis.call('get', key);",
        "if current and tonumber(current) > count then",
        " return tonumber(current);",
        "end",
        "current = redis.call('incr', key)",
        "if tonumber(current) == 1 then",
        " redis.call('expire', key, time)",
        "end",
        "return tonumber(current);"
    };
    // Joining with "\n" (no trailing separator) reproduces the original concatenation exactly.
    return String.join("\n", scriptLines);
}
将接口限流功能封装成一个注解@RateLimiter,在接口方法上面加上@RateLimiter就可以实现限流:
redis工具类:
package com.zhou.redis.util;import com.zhou.redis.dto.MyRedisMessage;
import com.zhou.redis.exception.LockException;
import com.zhou.redis.script.MaxCountQueryScript;
import com.zhou.redis.script.MaxCountScript;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.HashOperations;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ValueOperations;import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;@Configuration
@Slf4j
public class RedisUtil {

    /** Template used for every command; arguments go through its value serializer. */
    public RedisTemplate<String, Object> redisTemplate;

    /** Lua script that checks and increments a counter atomically. */
    private final MaxCountScript maxCountScript;

    /** Lua script that reads the current counter value. */
    private final MaxCountQueryScript maxCountQueryScript;

    /**
     * Typed constructor parameter: a raw {@code RedisTemplate} produced an unchecked assignment.
     * It also made any other template bean (e.g. a String template) an injection candidate.
     */
    public RedisUtil(RedisTemplate<String, Object> redisTemplate,
                     MaxCountScript maxCountScript,
                     MaxCountQueryScript maxCountQueryScript) {
        this.redisTemplate = redisTemplate;
        this.maxCountScript = maxCountScript;
        this.maxCountQueryScript = maxCountQueryScript;
    }

    /**
     * Tries to acquire a lock via SET-if-absent with an expiry.
     *
     * @param key   lock key
     * @param value value stored under the key
     * @param time  lock lifetime in seconds; {@code null} or non-positive falls back to 30
     * @return {@code true} if the key was set (lock acquired), {@code false} otherwise
     */
    public boolean tryLock(String key, Object value, Long time) {
        long seconds = (time == null || time <= 0) ? 30L : time;
        Boolean acquired = redisTemplate.opsForValue().setIfAbsent(key, value, Duration.ofSeconds(seconds));
        return Boolean.TRUE.equals(acquired);
    }

    /**
     * Releases a lock by deleting its key. Only call this after {@link #tryLock} succeeded.
     *
     * <p>NOTE(review): the key is deleted unconditionally. If the lock already expired and another
     * caller acquired it, this removes their lock. A compare-and-delete on the stored value would
     * be safer.
     *
     * @param key lock key
     * @return {@code true} if a key was deleted
     */
    public boolean unLock(String key) {
        return Boolean.TRUE.equals(redisTemplate.delete(key));
    }

    /**
     * Atomically increments the counter at {@code key}, unless it already exceeds {@code maxCount}.
     *
     * @param key      counter key
     * @param maxCount maximum allowed count; once the stored value exceeds it, it is no longer incremented
     * @param time     counting window in seconds, applied as the key's expiry on the first increment
     * @return the counter value after the call
     * @throws IllegalArgumentException if {@code time} is not positive
     */
    public Long incr(String key, int maxCount, int time) {
        if (time <= 0) {
            // EXPIRE with a non-positive timeout deletes the key. The counter would then restart at 1
            // on every call, silently disabling the limit.
            throw new IllegalArgumentException("time must be a positive number of seconds, got " + time);
        }
        List<String> keys = Collections.singletonList(key);
        return redisTemplate.execute(maxCountScript, keys, maxCount, time);
    }

    /**
     * Reads the current counter value.
     *
     * @param key counter key
     * @return the numeric value, or {@code null} when the key does not exist (the script returns nil)
     */
    public Long incrNow(String key) {
        List<String> keys = Collections.singletonList(key);
        return redisTemplate.execute(maxCountQueryScript, keys);
    }
}
redis配置类:
package com.zhou.redis.config;import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.zhou.redis.listener.MyRedisListener;
import com.zhou.redis.script.MaxCountQueryScript;
import com.zhou.redis.script.MaxCountScript;
import com.zhou.redis.util.RedisTopic;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.listener.PatternTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.data.redis.listener.Topic;
import org.springframework.data.redis.listener.adapter.MessageListenerAdapter;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;import java.util.Arrays;
import java.util.List;@Configuration
public class RedisConfig {@SuppressWarnings("all")@Beanpublic RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory factory) {RedisTemplate<String, Object> template = new RedisTemplate<>();template.setConnectionFactory(factory);//Json序列化配置Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class);ObjectMapper om = new ObjectMapper();om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);om.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);jackson2JsonRedisSerializer.setObjectMapper(om);//String的序列化StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();//key采用string的序列化template.setKeySerializer(stringRedisSerializer);//hash的key采用string的序列化template.setHashKeySerializer(stringRedisSerializer);//value序列化采用jacksontemplate.setValueSerializer(jackson2JsonRedisSerializer);//hash的value序列化方式采用jacksontemplate.setHashValueSerializer(jackson2JsonRedisSerializer);template.afterPropertiesSet();return template;}/*** Redis消息监听器容器* 这个容器加载了RedisConnectionFactory和消息监听器* 可以添加多个监听不同话题的redis监听器,只需要把消息监听器和相应的消息订阅处理器绑定,该消息监听器* 通过反射技术调用消息订阅处理器的相关方法进行一些业务处理** @param redisConnectionFactory 连接工厂* @param adapter 适配器* @return redis消息监听容器*/@Bean@SuppressWarnings("all")public RedisMessageListenerContainer container(RedisConnectionFactory redisConnectionFactory,FuncUpdateListener listener,MessageListenerAdapter adapter) {RedisMessageListenerContainer container = new RedisMessageListenerContainer();// 监听所有库的key过期事件container.setConnectionFactory(redisConnectionFactory);// 所有的订阅消息,都需要在这里进行注册绑定,new PatternTopic(TOPIC_NAME1)表示发布的主题信息// 可以添加多个 messageListener,配置不同的通道List<Topic> topicList = Arrays.asList(new PatternTopic(RedisTopic.TOPIC1),new PatternTopic(RedisTopic.TOPIC2));container.addMessageListener(listener, topicList);/*** 设置序列化对象* 特别注意:1. 发布的时候需要设置序列化;订阅方也需要设置序列化* 2. 
设置序列化对象必须放在[加入消息监听器]这一步后面,否则会导致接收器接收不到消息*/Jackson2JsonRedisSerializer seria = new Jackson2JsonRedisSerializer(Object.class);ObjectMapper objectMapper = new ObjectMapper();objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);objectMapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);seria.setObjectMapper(objectMapper);container.setTopicSerializer(seria);return container;}/*** 这个地方是给messageListenerAdapter 传入一个消息接受的处理器,利用反射的方法调用“receiveMessage”* 也有好几个重载方法,这边默认调用处理器的方法 叫OnMessage*/@SuppressWarnings("all")@Beanpublic MessageListenerAdapter listenerAdapter() {//MessageListenerAdapter receiveMessage = new MessageListenerAdapter(printMessageReceiver, "receiveMessage");MessageListenerAdapter receiveMessage = new MessageListenerAdapter();Jackson2JsonRedisSerializer seria = new Jackson2JsonRedisSerializer(Object.class);ObjectMapper objectMapper = new ObjectMapper();objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);objectMapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);seria.setObjectMapper(objectMapper);receiveMessage.setSerializer(seria);return receiveMessage;}@Beanpublic MaxCountScript maxCountScript() {return new MaxCountScript(maxCountScriptText());}@Beanpublic MaxCountQueryScript maxCountQueryScript() {return new MaxCountQueryScript(maxCountQueryScriptText());}/*** 自增过期时间的原子性脚本*/private String maxCountScriptText() {return "local key = KEYS[1]\n" +"local count = tonumber(ARGV[1])\n" +"local time = tonumber(ARGV[2])\n" +"local current = redis.call('get', key);\n" +"if current and tonumber(current) > count then\n" +" return tonumber(current);\n" +"end\n" +"current = redis.call('incr', key)\n" +"if tonumber(current) == 1 then\n" +" redis.call('expire', key, time)\n" +"end\n" +"return tonumber(current);";/*return "local limitMaxCount = tonumber(ARGV[1])\n" +"local limitSecond = tonumber(ARGV[2])\n" +"local num = tonumber(redis.call('get', KEYS[1]) or '-1')\n" +"if limitMaxCount then\n" +" 
return -1\n" +"end\n" +"if num == -1 then\n" +" redis.call('incr', KEYS[1])\n" +" redis.call('expire', KEYS[1], limitSecond)\n" +" return 1\n" +"else\n" +" if num >= limitMaxCount then\n" +" return 0\n" +" else\n" +" redis.call('incr', KEYS[1])\n" +" return 1\n" +" end\n" +"end";*/}/*** 查询当前值脚本*/private String maxCountQueryScriptText() {return "local key = KEYS[1]\n" +"local current = redis.call('get', key);\n" +"if current then\n" +" return tonumber(current);\n" +"else\n" +" return current\n" +"end\n";}
}
拦截模式枚举类:根据ip拦截或者方法拦截
package com.zhou.aop;/*** @author lang.zhou* @since 2023/1/31 17:56*/
public enum LimitType {

    /** Count requests per client IP: the caller's IP address becomes part of the rate-limit key. */
    IP,

    /** Count requests per method only: all callers share a single counter. */
    DEFAULT
}
封装自定义注解:@RateLimiter
package com.zhou.aop;import java.lang.annotation.*;/*** @author lang.zhou* @since 2023/1/31 17:49*/
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {

    /** Prefix of the Redis counter key. */
    String key() default "RateLimiter";

    /** Length of the counting window, in seconds. */
    int time() default 60;

    /** Maximum number of calls allowed within one window. */
    int count() default 100;

    /** Whether calls are counted per client IP or per method. */
    LimitType limitType() default LimitType.DEFAULT;

    /** Message of the exception raised when the limit is exceeded. */
    String limitMsg() default "访问过于频繁,请稍候再试";
}
注解的切面逻辑:
package com.zhou.aop;import com.zhou.redis.util.RedisUtil;
import com.zhou.common.utils.IpUtil;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;/*** 接口限流切面* @author lang.zhou* @since 2023/1/31 17:50*/
@Aspect
@Slf4j
@Component
public class RateLimiterAspect {@Autowiredprivate RedisUtil redisUtils;@Before("@annotation(rateLimiter)")public void doBefore(JoinPoint point, RateLimiter rateLimiter) {int time = rateLimiter.time();int count = rateLimiter.count();String combineKey = getCombineKey(rateLimiter, point);try {Long number = redisUtils.incr(combineKey, count, time);if (number == null || number.intValue() > count){log.info("请求【{}】被拦截,{}秒内请求次数{}",combineKey,time,number);throw new RuntimeException(rateLimiter.limitMsg());}} catch (ServiceRuntimeException e) {throw e;} catch (Exception e) {throw new RuntimeException("网络繁忙,请稍候再试");}}/*** 获取限流key*/public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {StringBuilder s = new StringBuilder(rateLimiter.key());ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();if(requestAttributes != null){HttpServletRequest request = requestAttributes.getRequest();if (rateLimiter.limitType() == LimitType.IP) {s.append(IpUtil.getIpAddr(request)).append("-");}}MethodSignature signature = (MethodSignature) point.getSignature();Method method = signature.getMethod();Class<?> targetClass = method.getDeclaringClass();s.append(targetClass.getName()).append(".").append(method.getName());return s.toString();}}
lua自增脚本类:
package com.zhou.redis.script;import org.springframework.data.redis.core.script.DefaultRedisScript;/*** @author lang.zhou* @since 2023/2/25*/
/**
 * Redis Lua script whose reply is read as a {@link Long}.
 * Wired as the atomic counter-increment script for rate limiting.
 */
public class MaxCountScript extends DefaultRedisScript<Long> {

    /**
     * @param scriptText Lua source of the script
     */
    public MaxCountScript(String scriptText) {
        super(scriptText, Long.class);
    }
}
lua查询当前值的脚本类:
package com.zhou.redis.script;import org.springframework.data.redis.core.script.DefaultRedisScript;/*** @author lang.zhou* @since 2023/2/25*/
/**
 * Redis Lua script whose reply is read as a {@link Long}.
 * Wired as the counter-read script for rate limiting.
 */
public class MaxCountQueryScript extends DefaultRedisScript<Long> {

    /**
     * @param scriptText Lua source of the script
     */
    public MaxCountQueryScript(String scriptText) {
        super(scriptText, Long.class);
    }
}
订阅消息通道的常量类:
package com.zhou.redis.util;

/**
 * Names of the Redis pub/sub channels used by the application.
 * A constants holder: final, with a private constructor so it cannot be instantiated or subclassed.
 */
public final class RedisTopic {

    public static final String TOPIC1 = "TOPIC1";

    public static final String TOPIC2 = "TOPIC2";

    private RedisTopic() {
        // constants only
    }
}
消息实体类:
package com.zhou.redis.dto;import lombok.Data;import java.io.Serializable;/*** redis订阅消息实体* @since 2022/11/11 17:34*/
@Data
public class MyRedisMessage implements Serializable {

    /** Explicit version id so the serialized form does not depend on compiler-generated details. */
    private static final long serialVersionUID = 1L;

    /** Message text carried by the pub/sub message. */
    private String msg;
}
订阅消息监听器:
package com.zhou.redis.listener;import com.zhou.redis.dto.MyRedisMessage;
import com.zhou.redis.util.RedisTopic;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;import javax.script.ScriptException;/*** @author lang.zhou*/
@Slf4j
@Component
public class MyRedisListener implements MessageListener {@Autowiredprivate RedisTemplate<String,Object> redisTemplate;@Overridepublic void onMessage(Message message, byte[] pattern) {String topic = new String(pattern);// 接收的topiclog.info("channel:{}" , topic);if(RedisTopic.TOPIC1.equals(topic)){//}else if(RedisTopic.TOPIC2.equals(topic)){//序列化对象(特别注意:发布的时候需要设置序列化;订阅方也需要设置序列化)MyRedisMessage msg = (MyRedisMessage) redisTemplate.getValueSerializer().deserialize(message.getBody());log.info("message:{}",msg);}}
}
注解使用方式:1秒内一个ip最多只能请求10次
@RestController
@RequestMapping("/test/api")
public class CheckController{@PostMapping("/limit")@RateLimiter(time = 1, count = 10, limitType = LimitType.IP, limitMsg = "请求过于频繁,请稍后重试")public void limit(HttpServletRequest request){//执行业务代码}}