在为 Spring WebFlux 接口做签名校验、防重放时,往往需要获取请求参数、请求方法等信息,而 Spring WebFlux 无法像 Spring MVC 那样方便地获取这些信息。这里结合之前的实践专门说明一下:
总体思路:
1、利用过滤器,从原 request 中获取信息后缓存到一个上下文对象中,然后构造新的 request 传给后续过滤器。因为原 request 的 body 是流式的,读取一次后便无法再次读取。
2、通过exchange的Attributes传递上下文对象,在不同的过滤器中使用即可。
1、上下文对象
@Getter
@Setter
@ToString
public class GatewayContext {

    /**
     * Attribute key under which a {@code GatewayContext} is stored in the exchange attributes.
     */
    public static final String CACHE_GATEWAY_CONTEXT = "cacheGatewayContext";

    /** Cached HTTP method name of the request. */
    private String requestMethod;

    /** Cached query-string parameters. */
    private MultiValueMap<String, String> queryParams;

    /** Cached request body decoded as a string (JSON requests). */
    private String requestBody;

    /** Cached response body. */
    private Object responseBody;

    /** Headers of the incoming request. */
    private HttpHeaders requestHeaders;

    /** Cached form data (url-encoded form requests). */
    private MultiValueMap<String, String> formData;

    /** Merged request data: query parameters plus form data; starts out empty. */
    private MultiValueMap<String, String> allRequestData = new LinkedMultiValueMap<String, String>(0);

    /** Cached raw request body bytes. */
    private byte[] requestBodyBytes;
}
2、在过滤器中获取请求参数、请求方法。
这里我们只对 application/json、application/x-www-form-urlencoded 这两种类型做 body 参数拦截;对于其他请求,则可以直接通过 URL 获取 query 参数。
@Slf4j
@Component
public class GatewayContextFilter implements WebFilter, Ordered {/*** default HttpMessageReader*/private static final List<HttpMessageReader<?>> MESSAGE_READERS = HandlerStrategies.withDefaults().messageReaders();@Overridepublic Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {ServerHttpRequest request = exchange.getRequest();GatewayContext gatewayContext = new GatewayContext();HttpHeaders headers = request.getHeaders();gatewayContext.setRequestHeaders(headers);gatewayContext.getAllRequestData().addAll(request.getQueryParams());gatewayContext.setRequestMethod(request.getMethodValue().toUpperCase());gatewayContext.setQueryParams(request.getQueryParams());/** save gateway context into exchange*/exchange.getAttributes().put(GatewayContext.CACHE_GATEWAY_CONTEXT, gatewayContext);MediaType contentType = headers.getContentType();if (headers.getContentLength() > 0) {if (MediaType.APPLICATION_JSON.equals(contentType)) {return readBody(exchange, chain, gatewayContext);}if (MediaType.APPLICATION_FORM_URLENCODED.equalsTypeAndSubtype(contentType)) {return readFormData(exchange, chain, gatewayContext);}}String path = request.getPath().value();if (!"/".equals(path)) {log.info("{} Gateway context is set with {}-{}", path, contentType, gatewayContext);}return chain.filter(exchange);}@Overridepublic int getOrder() {return Integer.MIN_VALUE + 1;}/*** ReadFormData*/private Mono<Void> readFormData(ServerWebExchange exchange, WebFilterChain chain, GatewayContext gatewayContext) {HttpHeaders headers = exchange.getRequest().getHeaders();return exchange.getFormData().doOnNext(multiValueMap -> {gatewayContext.setFormData(multiValueMap);gatewayContext.getAllRequestData().addAll(multiValueMap);log.debug("[GatewayContext]Read FormData Success");}).then(Mono.defer(() -> {Charset charset = headers.getContentType().getCharset();charset = charset == null ? 
StandardCharsets.UTF_8 : charset;String charsetName = charset.name();MultiValueMap<String, String> formData = gatewayContext.getFormData();/** formData is empty just return*/if (null == formData || formData.isEmpty()) {return chain.filter(exchange);}log.info("1. Gateway Context formData: {}", formData);StringBuilder formDataBodyBuilder = new StringBuilder();String entryKey;List<String> entryValue;try {/** repackage form data*/for (Map.Entry<String, List<String>> entry : formData.entrySet()) {entryKey = entry.getKey();entryValue = entry.getValue();if (entryValue.size() > 1) {for (String value : entryValue) {formDataBodyBuilder.append(URLEncoder.encode(entryKey, charsetName).replace("+", "%20").replace("*", "%2A").replace("%7E", "~")).append("=").append(URLEncoder.encode(value, charsetName).replace("+", "%20").replace("*", "%2A").replace("%7E", "~")).append("&");}} else {formDataBodyBuilder.append(URLEncoder.encode(entryKey, charsetName).replace("+", "%20").replace("*", "%2A").replace("%7E", "~")).append("=").append(URLEncoder.encode(entryValue.get(0), charsetName).replace("+", "%20").replace("*", "%2A").replace("%7E", "~")).append("&");}}} catch (UnsupportedEncodingException e) {log.error("GatewayContext readFormData error {}", e.getMessage(), e);}/** 1. substring with the last char '&'* 2. 
if the current request is encrypted, substring with the start chat 'secFormData'*/String formDataBodyString = "";String originalFormDataBodyString = "";if (formDataBodyBuilder.length() > 0) {formDataBodyString = formDataBodyBuilder.substring(0, formDataBodyBuilder.length() - 1);originalFormDataBodyString = formDataBodyString;}/** get data bytes*/byte[] bodyBytes = formDataBodyString.getBytes(charset);int contentLength = bodyBytes.length;gatewayContext.setRequestBodyBytes(originalFormDataBodyString.getBytes(charset));HttpHeaders httpHeaders = new HttpHeaders();httpHeaders.putAll(exchange.getRequest().getHeaders());httpHeaders.remove(HttpHeaders.CONTENT_LENGTH);/** in case of content-length not matched*/httpHeaders.setContentLength(contentLength);/** use BodyInserter to InsertFormData Body*/BodyInserter<String, ReactiveHttpOutputMessage> bodyInserter = BodyInserters.fromObject(formDataBodyString);CachedBodyOutputMessage cachedBodyOutputMessage = new CachedBodyOutputMessage(exchange, httpHeaders);log.info("2. 
GatewayContext Rewrite Form Data :{}", formDataBodyString);return bodyInserter.insert(cachedBodyOutputMessage, new BodyInserterContext()).then(Mono.defer(() -> {ServerHttpRequestDecorator decorator = new ServerHttpRequestDecorator(exchange.getRequest()) {@Overridepublic HttpHeaders getHeaders() {return httpHeaders;}@Overridepublic Flux<DataBuffer> getBody() {return cachedBodyOutputMessage.getBody();}};return chain.filter(exchange.mutate().request(decorator).build());}));}));}/*** ReadJsonBody*/private Mono<Void> readBody(ServerWebExchange exchange, WebFilterChain chain, GatewayContext gatewayContext) {return DataBufferUtils.join(exchange.getRequest().getBody()).flatMap(dataBuffer -> {/** read the body Flux<DataBuffer>, and release the buffer* when SpringCloudGateway Version Release To G.SR2,this can be update with the new version's feature* see PR https://github.com/spring-cloud/spring-cloud-gateway/pull/1095*/byte[] bytes = new byte[dataBuffer.readableByteCount()];dataBuffer.read(bytes);DataBufferUtils.release(dataBuffer);gatewayContext.setRequestBodyBytes(bytes);Flux<DataBuffer> cachedFlux = Flux.defer(() -> {DataBuffer buffer = exchange.getResponse().bufferFactory().wrap(bytes);DataBufferUtils.retain(buffer);return Mono.just(buffer);});/** repackage ServerHttpRequest*/ServerHttpRequest mutatedRequest = new ServerHttpRequestDecorator(exchange.getRequest()) {@Overridepublic Flux<DataBuffer> getBody() {return cachedFlux;}};ServerWebExchange mutatedExchange = exchange.mutate().request(mutatedRequest).build();return ServerRequest.create(mutatedExchange, MESSAGE_READERS).bodyToMono(String.class).doOnNext(objectValue -> {gatewayContext.setRequestBody(objectValue);if (objectValue != null && !objectValue.trim().startsWith("{")) {return;}try {gatewayContext.getAllRequestData().setAll(JsonUtil.fromJson(objectValue, Map.class));} catch (Exception e) {log.warn("Gateway context Read JsonBody error:{}", e.getMessage(), e);}}).then(chain.filter(mutatedExchange));});}}
3、签名、防重放校验
这里我们从上下文对象中取出参数即可
签名算法逻辑:
@Slf4j
@Component
public class GatewaySignCheckFilter implements WebFilter, Ordered {@Value("${api.rest.prefix}")private String apiPrefix;@Autowiredprivate RedisUtil redisUtil;//前后端约定签名密钥private static final String API_SECRET = "secret-xxx";@Overridepublic int getOrder() {return Integer.MIN_VALUE + 2;}@NotNull@Overridepublic Mono<Void> filter(ServerWebExchange exchange, @NotNull WebFilterChain chain) {ServerHttpRequest request = exchange.getRequest();String uri = request.getURI().getPath();GatewayContext gatewayContext = (GatewayContext) exchange.getAttributes().get(GatewayContext.CACHE_GATEWAY_CONTEXT);HttpHeaders headers = gatewayContext.getRequestHeaders();MediaType contentType = headers.getContentType();log.info("check url:{},method:{},contentType:{}", uri, gatewayContext.getRequestMethod(), contentType == null ? "" : contentType.toString());//如果contentType为空,只能是get请求if (contentType == null || StringUtils.isBlank(contentType.toString())) {if (request.getMethod() != HttpMethod.GET) {throw new RuntimeException("非法访问");}checkSign(uri, gatewayContext, exchange);} else {if (MediaType.APPLICATION_JSON.equals(contentType) || MediaType.APPLICATION_FORM_URLENCODED.equalsTypeAndSubtype(contentType)) {checkSign(uri, gatewayContext, exchange);}}return chain.filter(exchange);}private void checkSign(String uri, GatewayContext gatewayContext, ServerWebExchange exchange) {//忽略掉的请求List<String> ignores = Lists.newArrayList("/open/**", "/open/login/params", "/open/image");for (String ignore : ignores) {ignore = apiPrefix + ignore;if (uri.equals(ignore) || uri.startsWith(ignore.replace("/**", "/"))) {log.info("check sign ignore:{}", uri);return;}}String method = gatewayContext.getRequestMethod();log.info("start check sign {}-{}", method, uri);HttpHeaders headers = gatewayContext.getRequestHeaders();log.info("headers:{}", JsonUtils.objectToJson(headers));String clientId = getHeaderAttr(headers, SystemSign.CLIENT_ID);String timestamp = getHeaderAttr(headers, SystemSign.TIMESTAMP);String nonce = 
getHeaderAttr(headers, SystemSign.NONCE);String sign = getHeaderAttr(headers, SystemSign.SIGN);checkTime(timestamp);checkOnce(nonce);String headerStr = String.format("%s=%s&%s=%s&%s=%s", SystemSign.CLIENT_ID, clientId,SystemSign.NONCE, nonce, SystemSign.TIMESTAMP, timestamp);String signSecret = API_SECRET;String queryUri = uri + getQueryParam(gatewayContext.getQueryParams());log.info("headerStr:{},signSecret:{},queryUri:{}", headerStr, signSecret, queryUri);String realSign = calculatorSign(clientId, queryUri, gatewayContext, headerStr, signSecret);log.info("sign:{}, realSign:{}", sign, realSign);if (!realSign.equals(sign)) {log.warn("wrong sign");throw new RuntimeException("Illegal sign");}}private String getQueryParam(MultiValueMap<String, String> queryParams) {if (queryParams == null || queryParams.size() == 0) {return StringUtils.EMPTY;}StringBuilder builder = new StringBuilder("?");for (Map.Entry<String, List<String>> entry : queryParams.entrySet()) {String key = entry.getKey();List<String> value = entry.getValue();builder.append(key).append("=").append(value.get(0)).append("&");}builder.deleteCharAt(builder.length() - 1);return builder.toString();}private String getHeaderAttr(HttpHeaders headers, String key) {List<String> values = headers.get(key);if (CollectionUtils.isEmpty(values)) {log.warn("GatewaySignCheckFilter empty header:{}", key);throw new RuntimeException("GatewaySignCheckFilter empty header:" + key);}String value = values.get(0);if (StringUtils.isBlank(value)) {log.warn("GatewaySignCheckFilter empty header:{}", key);throw new RuntimeException("GatewaySignCheckFilter empty header:" + key);}return value;}private String calculatorSign(String clientId, String queryUri, GatewayContext gatewayContext, String headerStr, String signSecret) {String method = gatewayContext.getRequestMethod();byte[] bodyBytes = gatewayContext.getRequestBodyBytes();if (bodyBytes == null) {//空白的md5固定为:d41d8cd98f00b204e9800998ecf8427ebodyBytes = new byte[]{};}String bodyMd5 = 
UaaSignUtils.getMd5(bodyBytes);String ori = String.format("%s\n%s\n%s\n%s\n%s\n", method, clientId, headerStr, queryUri, bodyMd5);log.info("clientId:{},signSecret:{},headerStr:{},bodyMd5:{},queryUri:{},ori:{}", clientId, signSecret, headerStr, bodyMd5, queryUri, ori);return UaaSignUtils.sha256HMAC(ori, signSecret);}private void checkOnce(String nonce) {if (StringUtils.isBlank(nonce)) {log.warn("GatewaySignCheckFilter checkOnce Illegal");}String key = "api:auth:" + nonce;int fifteenMin = 60 * 15 * 1000;Boolean succ = redisUtil.setNxWithExpire(key, "1", fifteenMin);if (succ == null || !succ) {log.warn("GatewaySignCheckFilter checkOnce Repeat");throw new RuntimeException("checkOnce Repeat");}}private void checkTime(String timestamp) {long time;try {time = Long.parseLong(timestamp);} catch (Exception ex) {log.error("GatewaySignCheckFilter checkTime error:{}", ex.getMessage(), ex);throw new RuntimeException("checkTime error");}long now = DateTimeUtil.now();log.info("now: {}, time: {}", DateTimeUtil.millsToStr(now), DateTimeUtil.millsToStr(time));int fiveMinutes = 60 * 5 * 1000;long duration = now - time;if (duration > fiveMinutes || (-duration) > fiveMinutes) {log.warn("GatewaySignCheckFilter checkTime Late");throw new RuntimeException("checkTime Late");}}public interface SystemSign {/*** 客户端ID:固定值,由后端给前端颁发约定*/String CLIENT_ID = "client-id";/*** 客户端计算出的签名*/String SIGN = "sign";/*** 时间戳*/String TIMESTAMP = "timestamp";/*** 唯一值*/String NONCE = "nonce";}}