令牌桶算法属于流量控制算法,用于保证在一定时间内某个键(key)的访问量不超过设定的阈值。其关键是设置一个令牌桶:在每个时间段内生成一定数量的令牌,每次访问时从桶中获取一个令牌;如果桶中没有令牌,就拒绝访问。(注:下文代码实际是用 Redis 的 INCR + EXPIRE 实现的固定时间窗口计数限流,效果与令牌桶类似,但并不是严格意义上的令牌桶。)
参考网上一个博主写的:https://blog.csdn.net/xdx_dili/article/details/133683315
注意:我这边只是学习实践加上修改对应的代码记录下而已
第一步:记得要下载redis并配置好
第二步:创建springboot项目并引入maven,配置好配置文件
(注意我这边使用的springboot版本是2.6.x,因为2.7开始博主的部分代码不可用了)
<dependencies><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-aop</artifactId></dependency><dependency><groupId>com.google.guava</groupId><artifactId>guava</artifactId><version>21.0</version></dependency><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-test</artifactId></dependency><dependency><groupId>org.apache.commons</groupId><artifactId>commons-lang3</artifactId></dependency><dependency><groupId>com.alibaba</groupId><artifactId>fastjson</artifactId><version>1.2.78</version></dependency>
</dependencies>
#application.properties
#redis
spring.redis.host=127.0.0.1
spring.redis.port=6379
server.port=8081
第三步:代码部分
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.lettuce.LettuceConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.GenericJackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;import java.io.Serializable;/*** RedisTemplate调用实例*/
/**
 * Spring configuration that exposes the {@link RedisTemplate} used by the rate limiter.
 *
 * <p>Keys are written as plain strings and values as JSON, so the integer arguments
 * passed to the limiter script reach Redis as readable numbers.
 */
@Configuration
public class RedisLimiterHelper {

    /**
     * Builds the template used to run the rate-limit script.
     *
     * @param redisConnectionFactory Lettuce connection factory injected by Spring
     * @return a template with a string key serializer and a JSON value serializer
     */
    @Bean
    public RedisTemplate<String, Serializable> limitRedisTemplate(LettuceConnectionFactory redisConnectionFactory) {
        RedisTemplate<String, Serializable> template = new RedisTemplate<>();
        template.setKeySerializer(new StringRedisSerializer());
        template.setValueSerializer(new GenericJackson2JsonRedisSerializer());
        template.setConnectionFactory(redisConnectionFactory);
        return template;
    }
}
/**
 * Kinds of key a rate limit can be applied to.
 */
public enum LimitType {

    /** The key is supplied explicitly by the caller. */
    CUSTOMER,

    /** The key is the IP address of the requester. */
    IP
}
import java.lang.annotation.*;/*** 自定义限流注解*/
/**
 * Declares a Redis-backed rate limit.
 *
 * <p>Meta-annotations:
 * <ul>
 *   <li>{@code @Target({METHOD, TYPE})} - may be placed on methods and on classes.</li>
 *   <li>{@code @Retention(RUNTIME)} - the annotation is readable via reflection at run time.</li>
 *   <li>{@code @Inherited} - a class-level use is inherited by subclasses
 *       (it has no effect on method-level use).</li>
 *   <li>{@code @Documented} - the annotation is shown in the Javadoc of annotated elements.</li>
 * </ul>
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
public @interface RedisLimit {

    /** Descriptive name of the limited resource. */
    String name() default "";

    /** Custom key identifying the counter. */
    String key() default "";

    /** Prefix prepended to the key. */
    String prefix() default "";

    /** Length of the time window, in seconds. */
    int period();

    /** Maximum number of calls allowed within one window. */
    int count();

    /** Whether the key is the custom key or the requester's IP. */
    LimitType limitType() default LimitType.CUSTOMER;
}
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;import javax.servlet.http.HttpServletRequest;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;/*** 限流切面实现*/
/**
 * Aspect that enforces {@link RedisLimit} on annotated methods.
 *
 * <p>The limit is a fixed-window counter kept in Redis: the first call in a window
 * creates the counter with an expiry of {@code period} seconds, and each further call
 * increments it until {@code count} is exceeded.
 */
@Aspect
@Configuration
public class LimitInterceptor {

    private static final Logger logger = LoggerFactory.getLogger(LimitInterceptor.class);

    private static final String UNKNOWN = "unknown";

    /**
     * Lua source of the limiter. KEYS[1] is the counter key, ARGV[1] the maximum
     * count and ARGV[2] the window length in seconds.
     */
    private static final String LIMIT_LUA =
            "local c"
                    + "\nc = redis.call('get',KEYS[1])"
                    // Already over the limit: report the current value without incrementing.
                    // tonumber() makes this an integer reply like the INCR path below;
                    // a bare GET result would be returned as a string reply.
                    + "\nif c and tonumber(c) > tonumber(ARGV[1]) then"
                    + "\nreturn tonumber(c);"
                    + "\nend"
                    + "\nc = redis.call('incr',KEYS[1])"
                    // The first hit of a window starts the expiry clock.
                    + "\nif tonumber(c) == 1 then"
                    + "\nredis.call('expire',KEYS[1],ARGV[2])"
                    + "\nend"
                    + "\nreturn c;";

    /** Shared script instance, so the script object is not rebuilt on every call. */
    private static final RedisScript<Long> LIMIT_SCRIPT = new DefaultRedisScript<>(LIMIT_LUA, Long.class);

    private final RedisTemplate<String, Serializable> limitRedisTemplate;

    @Autowired
    public LimitInterceptor(RedisTemplate<String, Serializable> limitRedisTemplate) {
        this.limitRedisTemplate = limitRedisTemplate;
    }

    /**
     * Around advice triggered by methods annotated with {@link RedisLimit}.
     *
     * @param pjp the intercepted call
     * @return the target method's result, {@code "服务忙,请稍后再试"} when the limit is
     *         exceeded, or {@code "服务异常"} when the Redis check itself fails
     * @throws Throwable whatever the target method throws; it is no longer converted
     *         into a returned message
     */
    @Around("execution(public * *(..)) && @annotation(com.zhangximing.redis_springboot.annotate.RedisLimit)")
    public Object interceptor(ProceedingJoinPoint pjp) throws Throwable {
        MethodSignature signature = (MethodSignature) pjp.getSignature();
        Method method = signature.getMethod();
        RedisLimit limitAnnotation = method.getAnnotation(RedisLimit.class);
        int limitPeriod = limitAnnotation.period();
        int limitCount = limitAnnotation.count();
        List<String> keys = Arrays.asList(limitAnnotation.prefix() + resolveKey(limitAnnotation, method));

        Long count;
        try {
            count = limitRedisTemplate.execute(LIMIT_SCRIPT, keys, limitCount, limitPeriod);
        } catch (RuntimeException e) {
            // Fail closed, but record the cause and do not leak it to the caller.
            logger.error("Rate limit check failed for keys {}", keys, e);
            return "服务异常";
        }

        if (count == null || count > limitCount) {
            logger.warn("Rate limit exceeded for keys {}: count={}, limit={}", keys, count, limitCount);
            return "服务忙,请稍后再试";
        }
        // Called outside the try block so business exceptions reach the caller untouched.
        return pjp.proceed();
    }

    /**
     * Picks the counter key for a call: the requester IP for {@link LimitType#IP},
     * otherwise the annotation's key. A blank result falls back to the upper-cased
     * method name so that unrelated methods do not share one counter.
     */
    private String resolveKey(RedisLimit limitAnnotation, Method method) {
        String key;
        switch (limitAnnotation.limitType()) {
            case IP:
                key = getIpAddress();
                break;
            case CUSTOMER:
            default:
                key = limitAnnotation.key();
        }
        return StringUtils.isBlank(key) ? StringUtils.upperCase(method.getName()) : key;
    }

    /**
     * Returns the Lua source of the rate-limit script.
     *
     * @return the script text
     */
    public String buildLuaScript() {
        return LIMIT_LUA;
    }

    /**
     * Resolves the client IP address of the current request.
     *
     * <p>Checks {@code x-forwarded-for} (first entry only), then {@code Proxy-Client-IP},
     * then {@code WL-Proxy-Client-IP}, and finally the socket's remote address.
     * NOTE(review): these headers are set by the client unless a trusted proxy
     * overwrites them, so they can be spoofed to evade an IP-based limit.
     *
     * @return the resolved address
     * @throws IllegalStateException if no servlet request is bound to the current thread
     */
    public String getIpAddress() {
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (attributes == null) {
            throw new IllegalStateException("No current HTTP request; cannot resolve client IP");
        }
        HttpServletRequest request = attributes.getRequest();
        // The header may hold "client, proxy1, proxy2"; only the first entry is the client.
        String ip = StringUtils.trim(StringUtils.substringBefore(request.getHeader("x-forwarded-for"), ","));
        if (isUnknown(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (isUnknown(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (isUnknown(ip)) {
            ip = request.getRemoteAddr();
        }
        return ip;
    }

    /** True when a header value is absent, empty or the literal "unknown". */
    private static boolean isUnknown(String ip) {
        return ip == null || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip);
    }
}
import com.zhangximing.redis_springboot.annotate.LimitType;
import com.zhangximing.redis_springboot.annotate.RedisLimit;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.concurrent.atomic.AtomicInteger;/*** @Author: zhangximing* @Email: 530659058@qq.com* @Date: 2023/12/20 10:59* @Description: 限流测试类 参考 https://blog.csdn.net/xdx_dili/article/details/133683315*/
@RestController
@RequestMapping("/limit")
public class LimitController {private static final AtomicInteger ATOMIC_INTEGER_1 = new AtomicInteger();//十秒同一IP限制访问5次@RedisLimit(period = 10, count = 5, name = "测试接口", limitType = LimitType.IP)@RequestMapping("/test")public String testLimit(){return "SUCCESS:"+ATOMIC_INTEGER_1.incrementAndGet();}}
效果展示: