解决 HttpServletRequest 的输入流只能读取一次的问题 - 今日头条

本文由 简悦 SimpRead 转码, 原文地址 www.toutiao.com

在一个项目中会有很多的接口,而不同的接口可能接收不同类型的数据,例如表单数据和 json 数据,表单数据还好说,调用 request 的 getParam

通常对安全性有要求的接口都会对请求参数做一些签名验证,而我们一般会把验签的逻辑统一放到过滤器或拦截器里,这样就不用每个接口都去重复编写验签的逻辑。

在一个项目中会有很多的接口,而不同的接口可能接收不同类型的数据,例如表单数据和 json 数据,表单数据还好说,调用 request 的 getParameterMap 就能全部取出来。而 json 数据就有些麻烦了,因为 json 数据放在 body 中,我们需要通过 request 的输入流去读取。

但问题在于 request 的输入流只能读取一次不能重复读取,所以我们在过滤器或拦截器里读取了 request 的输入流之后,请求走到 controller 层时就会报错。而本文的目的就是介绍如何解决在这种场景下遇到 HttpServletRequest 的输入流只能读取一次的问题。

注:本文代码基于 SpringBoot 框架

我们先来看看为什么 HttpServletRequest 的输入流只能读一次,当我们调用 getInputStream() 方法获取输入流时得到的是一个 InputStream 对象,而实际类型是 ServletInputStream,它继承于 InputStream。

InputStream 的 read() 方法内部有一个 postion,标志当前流被读取到的位置,每读取一次,该标志就会移动一次,如果读到最后,read() 会返回 - 1,表示已经读取完了。如果想要重新读取则需要调用 reset() 方法,position 就会移动到上次调用 mark 的位置,mark 默认是 0,所以就能从头再读了。调用 reset() 方法的前提是已经重写了 reset() 方法,当然能否 reset 也是有条件的,它取决于 markSupported() 方法是否返回 true。

InputStream 默认不实现 reset(),并且 markSupported() 默认也是返回 false,这一点查看其源码便知:

https://p26.toutiaoimg.com/origin/tos-cn-i-qvj2lq49k0/3585968466eb443798b6034586c59cf4?from=pc

我们再来看看 ServletInputStream,可以看到该类没有重写 mark(),reset() 以及 markSupported() 方法:

https://p26.toutiaoimg.com/origin/tos-cn-i-qvj2lq49k0/a9cf0c4c64b34f7f83ac5848ff86f05e?from=pc

综上,InputStream 默认不实现 reset 的相关方法,而 ServletInputStream 也没有重写 reset 的相关方法,这样就无法重复读取流,这就是我们从 request 对象中获取的输入流就只能读取一次的原因。

既然 ServletInputStream 不支持重新读写,那么为什么不把流读出来后用容器存储起来,后面就可以多次利用了。那么问题就来了,要如何存储这个流呢?

所幸 JavaEE 提供了一个 HttpServletRequestWrapper 类,从类名也可以知道它是一个 http 请求包装器,其基于装饰者模式实现了 HttpServletRequest 界面,部分源码如下:

https://p26.toutiaoimg.com/origin/tos-cn-i-qvj2lq49k0/a2875124e75a48dab2652cac3fb80ab6?from=pc

从上图中的部分源码可以看到,该类并没有真正去实现 HttpServletRequest 的方法,而只是在方法内又去调用 HttpServletRequest 的方法,所以我们可以通过继承该类并实现想要重新定义的方法以达到包装原生 HttpServletRequest 对象的目的。

首先我们要定义一个容器,将输入流里面的数据存储到这个容器里,这个容器可以是数组或集合。然后我们重写 getInputStream 方法,每次都从这个容器里读数据,这样我们的输入流就可以读取任意次了。

具体的实现代码如下:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
package com.example.wrapperdemo.controller.wrapper;

import lombok.extern.slf4j.Slf4j;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;

/**
 * @author 01
 * @program wrapper-demo
 * @description 包装HttpServletRequest,目的是让其输入流可重复读
 * @create 2018-12-24 20:48
 * @since 1.0
 **/
@Slf4j
public class RequestWrapper extends HttpServletRequestWrapper {
    /**
     * 存储body数据的容器
     */
    private final byte[] body;

    public RequestWrapper(HttpServletRequest request) throws IOException {
        super(request);

        // 将body数据存储起来
        String bodyStr = getBodyString(request);
        body = bodyStr.getBytes(Charset.defaultCharset());
    }

    /**
     * 获取请求Body
     *
     * @param request request
     * @return String
     */
    public String getBodyString(final ServletRequest request) {
        try {
            return inputStream2String(request.getInputStream());
        } catch (IOException e) {
            log.error("", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * 获取请求Body
     *
     * @return String
     */
    public String getBodyString() {
        final InputStream inputStream = new ByteArrayInputStream(body);

        return inputStream2String(inputStream);
    }

    /**
     * 将inputStream里的数据读取出来并转换成字符串
     *
     * @param inputStream inputStream
     * @return String
     */
    private String inputStream2String(InputStream inputStream) {
        StringBuilder sb = new StringBuilder();
        BufferedReader reader = null;

        try {
            reader = new BufferedReader(new InputStreamReader(inputStream, Charset.defaultCharset()));
            String line;
            while ((line = reader.readLine()) != null) {
                sb.append(line);
            }
        } catch (IOException e) {
            log.error("", e);
            throw new RuntimeException(e);
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e) {
                    log.error("", e);
                }
            }
        }

        return sb.toString();
    }

    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream()));
    }

    @Override
    public ServletInputStream getInputStream() throws IOException {

        final ByteArrayInputStream inputStream = new ByteArrayInputStream(body);

        return new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return inputStream.read();
            }

            @Override
            public boolean isFinished() {
                return false;
            }

            @Override
            public boolean isReady() {
                return false;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
            }
        };
    }
}

除了要写一个包装器外,我们还需要在过滤器里将原生的 HttpServletRequest 对象替换成我们的 RequestWrapper 对象,代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
package com.example.wrapperdemo.controller.filter;

import com.example.wrapperdemo.controller.wrapper.RequestWrapper;
import lombok.extern.slf4j.Slf4j;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;

/**
 * @author 01
 * @program wrapper-demo
 * @description 替换HttpServletRequest
 * @create 2018-12-24 21:04
 * @since 1.0
 **/
@Slf4j
public class ReplaceStreamFilter implements Filter {
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        log.info("StreamFilter初始化...");
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        ServletRequest requestWrapper = new RequestWrapper((HttpServletRequest) request);
        chain.doFilter(requestWrapper, response);
    }

    @Override
    public void destroy() {
        log.info("StreamFilter销毁...");
    }
}

然后我们就可以在拦截器中愉快的获取 json 数据也不慌 controller 层会报错了:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
package com.example.wrapperdemo.controller.interceptor;

import com.example.wrapperdemo.controller.wrapper.RequestWrapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

/**
 * @author 01
 * @program wrapper-demo
 * @description 签名拦截器
 * @create 2018-12-24 21:08
 * @since 1.0
 **/
@Slf4j
public class SignatureInterceptor implements HandlerInterceptor {
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        log.info("[preHandle] executing... request uri is {}", request.getRequestURI());
        if (isJson(request)) {
            // 获取json字符串
            String jsonParam = new RequestWrapper(request).getBodyString();
            log.info("[preHandle] json数据 : {}", jsonParam);

            // 验签逻辑...略...
        }

        return true;
    }

    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {

    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {

    }

    /**
     * 判断本次请求的数据类型是否为json
     *
     * @param request request
     * @return boolean
     */
    private boolean isJson(HttpServletRequest request) {
        if (request.getContentType() != null) {
            return request.getContentType().equals(MediaType.APPLICATION_JSON_VALUE) ||
                    request.getContentType().equals(MediaType.APPLICATION_JSON_UTF8_VALUE);
        }

        return false;
    }
}

编写完以上的代码后,还需要将过滤器和拦截器在配置类中进行注册才会生效,过滤器配置类代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
package com.example.wrapperdemo.config;

import com.example.wrapperdemo.controller.filter.ReplaceStreamFilter;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import javax.servlet.Filter;

/**
 * @author 01
 * @program wrapper-demo
 * @description 过滤器配置类
 * @create 2018-12-24 21:06
 * @since 1.0
 **/
@Configuration
public class FilterConfig {
    /**
     * 注册过滤器
     *
     * @return FilterRegistrationBean
     */
    @Bean
    public FilterRegistrationBean someFilterRegistration() {
        FilterRegistrationBean registration = new FilterRegistrationBean();
        registration.setFilter(replaceStreamFilter());
        registration.addUrlPatterns("/*");
        registration.setName("streamFilter");
        return registration;
    }

    /**
     * 实例化StreamFilter
     *
     * @return Filter
     */
    @Bean(name = "replaceStreamFilter")
    public Filter replaceStreamFilter() {
        return new ReplaceStreamFilter();
    }
}

拦截器配置类代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
package com.example.wrapperdemo.config;

import com.example.wrapperdemo.controller.interceptor.SignatureInterceptor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

/**
 * @author 01
 * @program wrapper-demo
 * @description
 * @create 2018-12-24 21:16
 * @since 1.0
 **/
@Configuration
public class InterceptorConfig implements WebMvcConfigurer {

    @Bean
    public SignatureInterceptor getSignatureInterceptor(){
        return new SignatureInterceptor();
    }

    /**
     * 注册拦截器
     *
     * @param registry registry
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(getSignatureInterceptor())
                .addPathPatterns("/**");
    }
}

接下来我们就可以测试一下在拦截器中读取了输入流后在 controller 层是否还能正常接收数据,首先定义一个实体类,代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
package com.example.wrapperdemo.param;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * @author 01
 * @program wrapper-demo
 * @description
 * @create 2018-12-24 21:11
 * @since 1.0
 **/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class UserParam {
    private String userName;

    private String phone;

    private String password;
}

然后写一个简单的 Controller,代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
package com.example.wrapperdemo.controller;

import com.example.wrapperdemo.param.UserParam;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

/**
 * @author 01
 * @program wrapper-demo
 * @description
 * @create 2018-12-24 20:47
 * @since 1.0
 **/
@RestController
@RequestMapping("/user")
public class DemoController {

    @PostMapping("/register")
    public UserParam register(@RequestBody UserParam userParam){
        return userParam;
    }
}

启动项目,请求结果如下,可以看到 controller 正常接收到数据并返回了:

https://p26.toutiaoimg.com/origin/tos-cn-i-qvj2lq49k0/8adc927e4c0a4b95af639084a43e5274?from=pc

控制台输出如下:

https://p26.toutiaoimg.com/origin/tos-cn-i-qvj2lq49k0/8598cf24d8044c3f9f4aa8a11ca61440?from=pc

https://cloud.tencent.com/developer/article/1702246