文章目录
- 工作原理
- 需求实现
- 1)自定义防重复提交注解
- 2)定义防重复提交AOP切面
- 3)RedisLock 工具类
- 4)过滤器 + 请求工具类
- 5)测试Controller
- 6)测试结果
工作原理
分布式环境下,可能会遇到用户重复点击、导致同一接口被重复提交的场景。为了防止接口重复提交造成的问题,可以用 Redis 实现一个简单的分布式锁来解决。
在 Redis 中,SETNX
命令可以帮助我们实现互斥。SETNX 即 SET if Not eXists(对应 Java 中的 setIfAbsent
方法):只有当 key 不存在时才会设置 key 的值;如果 key 已经存在,SETNX 不做任何操作。
需求实现
- 自定义一个防止重复提交的注解,注解中可以携带到期时间和一个参数的key
- 为需要防止重复提交的接口添加注解
- 注解AOP会拦截加了此注解的请求,进行加解锁处理并且添加注解上设置的key超时时间
- Redis 中的
key = token + "-" + path + "-" + param_value;
(例如:token = 17800000001、path = /api/subscribe/、param_value = zhangsan 时,key 为 17800000001-/api/subscribe/-zhangsan)
- 如果重复调用某个加了注解的接口且key还未到期,就会返回重复提交的Result。
1)自定义防重复提交注解
自定义防止重复提交注解,注解中可设置 超时时间 + 要扫描的参数(请求中的某个参数,最终拼接后成为Redis中的key)
package com.lihw.lihwtestboot.noRepeatSubmit;import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Anti-duplicate-submission annotation.
 *
 * <p>Placed on a handler method; RepeatSubmitAspect intercepts methods carrying
 * this annotation and reads both elements below.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface NoRepeatSubmit {

    /**
     * Lock expiration time (passed by the aspect to RedisLock.tryLock, which
     * applies it in seconds).
     */
    int seconds() default 5;

    /**
     * Name of the request parameter to scan; its value becomes part of the
     * Redis key. The empty default means no parameter is appended.
     */
    String scanParam() default "";
}
2)定义防重复提交AOP切面
@Pointcut("@annotation(noRepeatSubmit)")
表示切点表达式,它使用了注解匹配的方式来选择被注解 @NoRepeatSubmit
标记的方法。
package com.lihw.lihwtestboot.noRepeatSubmit;import com.alibaba.fastjson.JSONObject;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.UUID;
/*** 重复提交aop*/
@Aspect
@Component
public class RepeatSubmitAspect {private static final Logger LOGGER = LoggerFactory.getLogger(RepeatSubmitAspect.class);@Autowiredprivate RedisLock redisLock;@Pointcut("@annotation(noRepeatSubmit)")public void pointCut(NoRepeatSubmit noRepeatSubmit) {}@Around("pointCut(noRepeatSubmit)")public Object around(ProceedingJoinPoint pjp, NoRepeatSubmit noRepeatSubmit) throws Throwable {//获取基本信息ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();HttpServletRequest request = attributes.getRequest();Assert.notNull(request, "request can not null");int lockSeconds = noRepeatSubmit.seconds();//过期时间String threadName = Thread.currentThread().getName();// 获取当前线程名称String param = noRepeatSubmit.scanParam();//请求参数String path = request.getServletPath();String type = request.getMethod();String param_value = "";if (type.equals("POST")){param_value = JSONObject.parseObject(new BodyReaderHttpServletRequestWrapper(request).getBodyString()).getString(param);}else if (type.equals("GET")){param_value = request.getParameter(param);}String token = request.getHeader("uid");LOGGER.info("线程:{}, 接口:{},重复提交验证",threadName,path);String key;if (!"".equals(param) && param != null){key = token + "-" + path + "-" + param_value;//生成key}else {key = token + "-" + path;//生成key}String clientId = getClientId();// 调接口时生成临时value(UUID)// 用于添加锁,如果添加成功返回true,失败返回false boolean isSuccess = redisLock.tryLock(key, clientId, lockSeconds);ApiResult result = new ApiResult();if (isSuccess) {LOGGER.info("加锁成功:接口 = {}, key = {}", path, key);// 获取锁成功Object obj;try {// 执行进程obj = pjp.proceed();// aop代理链执行的方法} finally {// 据key从redis中获取valueif (clientId.equals(redisLock.get(key))) {// 解锁redisLock.releaseLock(key, clientId);LOGGER.info("解锁成功:接口={}, key = {},",path, key);}}return obj;} else {// 添加锁失败,认为是重复提交的请求LOGGER.info("重复请求:接口 = {}, key = {}",path, key);result.setData("重复提交");return result;}}private String getClientId() {return UUID.randomUUID().toString();}public static 
String getRequestBodyData(HttpServletRequest request) throws IOException{BufferedReader bufferReader = new BufferedReader(request.getReader());StringBuilder sb = new StringBuilder();String line = null;while ((line = bufferReader.readLine()) != null) {sb.append(line);}return sb.toString();}
}
3)RedisLock 工具类
package com.lihw.lihwtestboot.noRepeatSubmit;import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import java.util.concurrent.TimeUnit;@Service
public class RedisLock {private static final Logger logger = LoggerFactory.getLogger(RedisLock.class);/** 不设置过期时长 */public final static long NOT_EXPIRE = -1;@Autowiredprivate StringRedisTemplate redisTemplate;/*** @param lockKey 加锁键* @param clientId 加锁客户端唯一标识(采用UUID)* @param seconds 锁过期时间* @return*/public boolean tryLock(String lockKey, String clientId, long seconds) {if (redisTemplate.opsForValue().setIfAbsent(lockKey, clientId,seconds, TimeUnit.SECONDS)) {return true;//得到锁}else{return false;}}/*** 与 tryLock 相对应,用作释放锁** @param lockKey* @param clientId* @return*/public boolean releaseLock(String lockKey, String clientId) {String currentValue = redisTemplate.opsForValue().get(lockKey);try {if (!StringUtils.isEmpty(currentValue) && currentValue.equals(clientId)) {redisTemplate.opsForValue().getOperations().delete(lockKey);return true;}else {return false;}} catch (Exception e) {logger.error("解锁异常,,{}" , e);return false;}}/*** 获取* @param key* @return*/public String get(String key) {return get(key, NOT_EXPIRE);}public String get(String key, long expire) {String value = redisTemplate.opsForValue().get(key);if(expire != NOT_EXPIRE){redisTemplate.expire(key, expire, TimeUnit.SECONDS);}return value;}/*** 删除* @param key*/public void delete(String key) {redisTemplate.delete(key);}
}
4)过滤器 + 请求工具类
Filter类
package com.lihw.lihwtestboot.noRepeatSubmit;import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.web.servlet.ServletComponentScan;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;@ServletComponentScan
@WebFilter(urlPatterns = "/*",filterName = "channelFilter")
public class ChannelFilter implements Filter {private final Logger logger = LoggerFactory.getLogger(this.getClass());@Overridepublic void init(FilterConfig filterConfig) throws ServletException {}@Overridepublic void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {logger.info("-----------------------Execute filter start---------------------");// 防止流读取一次后就没有了, 所以需要将流继续写出去HttpServletRequest httpServletRequest = (HttpServletRequest) servletRequest;ServletRequest requestWrapper = new BodyReaderHttpServletRequestWrapper(httpServletRequest);filterChain.doFilter(requestWrapper, servletResponse);}}
BodyReaderHttpServletRequestWrapper
对请求进行了封装:在构造时缓存请求体(Body),使其在过滤器之后仍可被重复读取(例如供 AOP 切面解析 POST 请求中的参数)。
package com.lihw.lihwtestboot.noRepeatSubmit;import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;public class BodyReaderHttpServletRequestWrapper extends HttpServletRequestWrapper{/*** Request请求参数获取处理类*/private final byte[] body;public BodyReaderHttpServletRequestWrapper(HttpServletRequest request) throws IOException {super(request);String sessionStream = getBodyString(request);body = sessionStream.getBytes(StandardCharsets.UTF_8);}/*** 获取请求Body** @param request* @return*/private String getBodyString(final ServletRequest request) {StringBuilder sb = new StringBuilder();InputStream inputStream = null;BufferedReader reader = null;try {inputStream = cloneInputStream(request.getInputStream());reader = new BufferedReader(new InputStreamReader(inputStream, Charset.forName("UTF-8")));String line = "";while ((line = reader.readLine()) != null) {sb.append(line);}} catch (IOException e) {e.printStackTrace();} finally {if (inputStream != null) {try {inputStream.close();} catch (IOException e) {e.printStackTrace();}}if (reader != null) {try {reader.close();} catch (IOException e) {e.printStackTrace();}}}return sb.toString();}public String getBodyString() {return new String(body, StandardCharsets.UTF_8);}/*** Description: 复制输入流** @param inputStream* @return*/public InputStream cloneInputStream(ServletInputStream inputStream) {ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();byte[] buffer = new byte[1024];int len;try {while ((len = inputStream.read(buffer)) > -1) {byteArrayOutputStream.write(buffer, 0, len);}byteArrayOutputStream.flush();} catch (IOException e) {e.printStackTrace();}InputStream byteArrayInputStream = new ByteArrayInputStream(byteArrayOutputStream.toByteArray());return byteArrayInputStream;}@Overridepublic BufferedReader getReader() throws IOException {return new BufferedReader(new InputStreamReader(getInputStream()));}@Overridepublic ServletInputStream getInputStream() throws IOException {final ByteArrayInputStream bais = new ByteArrayInputStream(body);return new ServletInputStream() 
{@Overridepublic int read() throws IOException {return bais.read();}@Overridepublic boolean isFinished() {return false;}@Overridepublic boolean isReady() {return false;}@Overridepublic void setReadListener(ReadListener readListener) {}};}
}
5)测试Controller
package com.lihw.lihwtestboot.noRepeatSubmit;import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import javax.validation.constraints.NotEmpty;@RestController
@RequestMapping("/api")
@Validated
public class noRepeatSubmitController {@GetMapping("/subscribe/{channel}")@NoRepeatSubmit(seconds = 10,scanParam = "username")public ApiResult subscribe(@RequestHeader(name = "uid") String phone,@RequestHeader(name = "username") String username,@PathVariable("channel") @NotEmpty(message = "channel不能为空") String channel) {System.out.println("phone=" + phone);System.out.println("username=" + username);System.out.println("channel=" + channel);try {Thread.sleep(5000);//模拟耗时} catch (InterruptedException e) {e.printStackTrace();}return new ApiResult("success","data");}
}
6)测试结果
重复点击