Background:
Limit a given SQL statement to at most 3 executions within any 30-second period.
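Requirements analysis
The microservices are deployed in a distributed fashion, so the rate limiter has to be distributed as well. A natural first choice is to build it on Redis's sorted set (zset) data structure.
The zset operations break down into a few steps. First, check whether the number of elements whose score falls inside the time window has reached the threshold. If it has not, add an element and return true. If it has, return false immediately.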
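The check-then-add logic above can be sketched without Redis at all. The class below is only an in-memory illustration for a single process, not the distributed implementation. The name `SlidingWindowSketch` and its parameters are hypothetical, and it assumes timestamps are passed in non-decreasing order.

```java
import java.util.ArrayDeque;
import java.util.Deque;

/** In-memory sketch of the sliding-window check; the deque plays the role of the zset scores. */
final class SlidingWindowSketch {
    private final Deque<Long> timestamps = new ArrayDeque<>();
    private final long windowMillis;
    private final int limit;

    SlidingWindowSketch(long windowMillis, int limit) {
        this.windowMillis = windowMillis;
        this.limit = limit;
    }

    /** Returns true and records the call if fewer than `limit` calls fall inside the window. */
    synchronized boolean tryAcquire(long nowMillis) {
        // Drop entries at or before (now - window), mirroring a remove-by-score on the zset.
        while (!timestamps.isEmpty() && timestamps.peekFirst() <= nowMillis - windowMillis) {
            timestamps.pollFirst();
        }
        if (timestamps.size() < limit) {
            timestamps.addLast(nowMillis);
            return true;
        }
        return false;
    }
}
```

With a 30-second window and a limit of 3, the fourth call inside the window is rejected, and a call is accepted again once the oldest entry has aged out.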
Implementation
First, write a Lua script that implements the requirement:
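-- Remove entries whose score is older than the start of the window
redis.call('ZREMRANGEBYSCORE', KEYS[1], 0, tonumber(ARGV[3]))
local res = 0
-- Below the threshold: record this call and allow it
if (redis.call('ZCARD', KEYS[1]) < tonumber(ARGV[5])) then
    redis.call('ZADD', KEYS[1], tonumber(ARGV[2]), ARGV[1])
    res = 1
end
-- Refresh the key's expiry so idle keys clean themselves up
redis.call('EXPIRE', KEYS[1], tonumber(ARGV[4]))
return res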
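ARGV[1]: zset member
ARGV[2]: zset score (the current timestamp, in milliseconds)
ARGV[3]: the timestamp 30 seconds ago (start of the window)
ARGV[4]: expiry of the zset key, 30 seconds
ARGV[5]: rate-limit threshold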
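As a small illustration of how these five arguments fit together, the helper below builds them in the order the script expects. `RateLimitArgs` and `build` are hypothetical names, not part of the original code. One thing worth keeping in mind: ZADD with a member that already exists only updates its score, so the member should be unique per call.

```java
/** Hypothetical helper that assembles ARGV[1]..ARGV[5] for the rate-limit script. */
final class RateLimitArgs {

    private RateLimitArgs() {
    }

    /**
     * @param member        unique zset member for this call (ARGV[1])
     * @param nowMillis     current time in milliseconds, used as the score (ARGV[2])
     * @param windowSeconds window length, also used as the key expiry (ARGV[4])
     * @param limit         maximum number of calls per window (ARGV[5])
     */
    static Object[] build(String member, long nowMillis, long windowSeconds, int limit) {
        // ARGV[3]: everything with a score at or below this value is outside the window.
        long windowStart = nowMillis - windowSeconds * 1000L;
        return new Object[] {member, nowMillis, windowStart, windowSeconds, limit};
    }
}
```

For the requirement here, a call would look like `RateLimitArgs.build("ele0", System.currentTimeMillis(), 30, 3)`.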
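private final RedisTemplate<String, Object> redisTemplate;

// Runs the Lua script through RedisTemplate; the script's 1/0 reply is mapped to a Boolean.
public boolean execLuaScript(String luaStr, List<String> keys, List<Object> args) {
    RedisScript<Boolean> redisScript = RedisScript.of(luaStr, Boolean.class);
    // Boolean.TRUE.equals(...) avoids a NullPointerException if execute returns null.
    return Boolean.TRUE.equals(redisTemplate.execute(redisScript, keys, args.toArray()));
}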
Let's test it:
@SpringBootTest
public class ApiApplicationTest {

    @Test
    public void test2() throws InterruptedException {
        String luaStr = "redis.call('ZREMRANGEBYSCORE', KEYS[1], 0, tonumber(ARGV[3]))\n" +
                "local res = 0\n" +
                "if(redis.call('ZCARD', KEYS[1]) < tonumber(ARGV[5])) then\n" +
                "    redis.call('ZADD', KEYS[1], tonumber(ARGV[2]), ARGV[1])\n" +
                "    res = 1\n" +
                "end\n" +
                "redis.call('EXPIRE', KEYS[1], tonumber(ARGV[4]))\n" +
                "return res";
        // One call every 5 seconds; at most 3 calls may succeed within any 30-second window.
        // Note: tonumber(...) in the script requires the value serializer to write numbers as plain digits.
        for (int i = 0; i < 10; i++) {
            long now = System.currentTimeMillis();
            boolean res = execLuaScript(luaStr, Arrays.asList("aaaa"),
                    Arrays.asList("ele" + i, now, now - 30 * 1000, 30, 3));
            System.out.println(res);
            Thread.sleep(5000);
        }
    }
}
The results are as expected!
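Further reading
Sending the full script body with every call adds network traffic and latency, and the server has to process the whole script text each time, which is inefficient. In production it is therefore better not to send the Lua script itself on every call, but to load it once and invoke it by its hash.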
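To make "the hash" concrete: Redis identifies a cached script by the SHA1 digest of its body, written as 40 lowercase hex characters, and that is the value SCRIPT LOAD returns. The sketch below computes such a digest locally with the JDK only. `ScriptSha1` is a hypothetical helper name, and it assumes the script is sent to Redis as UTF-8 bytes.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

/** Hypothetical helper: computes the SHA1 hex digest of a script body. */
final class ScriptSha1 {

    private ScriptSha1() {
    }

    static String sha1Hex(String script) {
        try {
            MessageDigest md = MessageDigest.getInstance("SHA-1");
            byte[] digest = md.digest(script.getBytes(StandardCharsets.UTF_8));
            StringBuilder sb = new StringBuilder(digest.length * 2);
            for (byte b : digest) {
                // Two lowercase hex characters per byte.
                sb.append(String.format("%02x", b));
            }
            return sb.toString();
        } catch (NoSuchAlgorithmException e) {
            // SHA-1 is a standard JDK algorithm, so this is not expected to happen.
            throw new IllegalStateException(e);
        }
    }
}
```

Comparing the local digest with the value returned when the script is loaded can be a quick sanity check that both sides see the same script bytes.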
For convenience, let's first wrap the calls in a helper class:
import lombok.RequiredArgsConstructor;
import org.springframework.data.redis.connection.RedisScriptingCommands;
import org.springframework.data.redis.connection.ReturnType;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.stereotype.Component;

import java.util.List;

/**
 * @author 敖癸
 * @formatter:on
 * @since 2024/3/25
 */
@Component
@RequiredArgsConstructor
public class RedisService {

    private final RedisTemplate<String, Object> redisTemplate;
    // Lazily initialized and then reused for every call.
    private static RedisScriptingCommands commands;
    private static RedisSerializer keySerializer;
    private static RedisSerializer valSerializer;

    /** Loads the script into Redis (SCRIPT LOAD) and returns its hash. */
    public String loadScript(String luaStr) {
        byte[] bytes = RedisSerializer.string().serialize(luaStr);
        return this.getCommands().scriptLoad(bytes);
    }

    /** Executes a previously loaded script by its hash (EVALSHA). */
    public <T> T execLuaHashScript(String hash, Class<T> returnType, List<String> keys, Object[] args) {
        byte[][] keysAndArgs = toByteArray(this.getKeySerializer(), this.getValSerializer(), keys, args);
        return this.getCommands().evalSha(hash, ReturnType.fromJavaType(returnType), keys.size(), keysAndArgs);
    }

    /** Serializes keys first, then args, into the single byte[][] that evalSha expects. */
    private static byte[][] toByteArray(RedisSerializer keySerializer, RedisSerializer argsSerializer, List<String> keys, Object[] args) {
        final int keySize = keys != null ? keys.size() : 0;
        byte[][] keysAndArgs = new byte[args.length + keySize][];
        int i = 0;
        if (keys != null) {
            for (String key : keys) {
                keysAndArgs[i++] = keySerializer.serialize(key);
            }
        }
        for (Object arg : args) {
            if (arg instanceof byte[]) {
                keysAndArgs[i++] = (byte[]) arg;
            } else {
                keysAndArgs[i++] = argsSerializer.serialize(arg);
            }
        }
        return keysAndArgs;
    }

    private RedisScriptingCommands getCommands() {
        if (commands == null) {
            commands = redisTemplate.getRequiredConnectionFactory().getConnection().scriptingCommands();
        }
        return commands;
    }

    private RedisSerializer getKeySerializer() {
        if (keySerializer == null) {
            keySerializer = redisTemplate.getKeySerializer();
        }
        return keySerializer;
    }

    private RedisSerializer getValSerializer() {
        if (valSerializer == null) {
            valSerializer = redisTemplate.getValueSerializer();
        }
        return valSerializer;
    }
}
Let's test it:
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.springframework.beans.BeansException;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;

import java.util.Collections;
import java.util.List;

@SpringBootTest
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class ApiApplicationTest implements ApplicationContextAware {

    private static ApplicationContext context;
    private static RedisService redisService;
    public static String luaHash;

    private final static String LUA_STR = "redis.call('ZREMRANGEBYSCORE', KEYS[1], 0, tonumber(ARGV[3]))\n" +
            "local res = 0\n" +
            "if(redis.call('ZCARD', KEYS[1]) < tonumber(ARGV[5])) then\n" +
            "    redis.call('ZADD', KEYS[1], tonumber(ARGV[2]), ARGV[1])\n" +
            "    res = 1\n" +
            "end\n" +
            "redis.call('EXPIRE', KEYS[1], tonumber(ARGV[4]))\n" +
            "return res";

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        context = applicationContext;
    }

    // Load the script once before any test runs and keep its hash.
    @BeforeAll
    public static void before() {
        redisService = context.getBean(RedisService.class);
        luaHash = redisService.loadScript(LUA_STR);
        System.out.println("Lua script hash: " + luaHash);
    }

    @Test
    public void testLuaHash() throws InterruptedException {
        // One call every 3 seconds, executed by hash instead of sending the script body.
        for (int i = 0; i < 50; i++) {
            List<String> keys = Collections.singletonList("aaaa");
            Object[] args = new Object[]{"ele" + i, System.currentTimeMillis(), System.currentTimeMillis() - 30 * 1000, 30, 3};
            Boolean b = redisService.execLuaHashScript(luaHash, Boolean.class, keys, args);
            System.out.println(b);
            Thread.sleep(3000);
        }
    }
}
In real use, load the script once when the application starts, then simply call it by its hash from then on.
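One caveat with load-once-at-startup: the Redis script cache is not persistent, so after a server restart or a SCRIPT FLUSH, an EVALSHA call fails with a NOSCRIPT error. A possible way to handle this is to reload and retry once. The sketch below is generic and uses only the JDK. `ReloadingScriptCaller` is a hypothetical name, and it assumes the client surfaces the server's error text (containing "NOSCRIPT") somewhere in the exception chain, which may differ between client libraries.

```java
import java.util.function.Function;
import java.util.function.Supplier;

/** Hypothetical wrapper: calls a script by hash and reloads it once if Redis reports NOSCRIPT. */
final class ReloadingScriptCaller {
    private final Supplier<String> loader; // wraps the "load script, return hash" step
    private volatile String hash;

    ReloadingScriptCaller(Supplier<String> loader) {
        this.loader = loader;
        this.hash = loader.get(); // initial load, e.g. at application startup
    }

    <T> T call(Function<String, T> evalSha) {
        try {
            return evalSha.apply(hash);
        } catch (RuntimeException e) {
            if (!isNoScript(e)) {
                throw e;
            }
            // The script cache was emptied: load the script again and retry exactly once.
            hash = loader.get();
            return evalSha.apply(hash);
        }
    }

    /** Assumption: the NOSCRIPT error text is visible in the exception or one of its causes. */
    static boolean isNoScript(Throwable e) {
        for (Throwable t = e; t != null; t = t.getCause()) {
            String msg = t.getMessage();
            if (msg != null && msg.contains("NOSCRIPT")) {
                return true;
            }
        }
        return false;
    }
}
```

With the RedisService above, this might be wired as `new ReloadingScriptCaller(() -> redisService.loadScript(LUA_STR))` and invoked as `caller.call(h -> redisService.execLuaHashScript(h, Boolean.class, keys, args))`.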
All done!