Spring 统一接口后过滤自定义字段

Posted on May 17, 2020

继上一篇文章 Spring 实现 RESTful 统一返回值和错误处理 之后, 我又遇到了一个需求, 那就是如何过滤有些不需要的字段, 比如我现在要获取一个用户的所有资料, 但是不包括密码和邮箱, 那我是可以写一条 SQL 语句去避开他的, 但是如果我明天不想要密码, 但是我想要邮箱了呢, 那我再写一条?
这当然也没什么问题, 只是如果变换多了之后写起来还是挺蛋疼的

对于 Jackson 内提供的注解 @JsonIgnore 可以解决一部分问题, 但是对于我上面的自定义忽略就无能为力了, 而 @JsonView 则可以达到我的需求, 但是我在使用 @JsonView 的过程中, 发现他返回给我一个 {}, 结果为空, 这就意味着和我上次的统一返回值处理中用到的 ResponseBodyAdvice 发生了冲突, 据我不可靠的观察, 这个应该是由于 JsonViewResponseBodyAdvice 继承了 AbstractMappingJacksonResponseBodyAdvice, 而 AbstractMappingJacksonResponseBodyAdvice 实现了 ResponseBodyAdvice, 我之前的统一返回也是一样通过 ResponseBodyAdvice 进行的返回拦截处理, 应该是由于他们之间的执行顺序不同导致了这个问题.
所以 @JsonView 并不能解决我的需求, 我想了想, 不想写dto, 好像只能通过反射进行处理了

首先, 我们定义一个注解

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface CustomFilter {
    String[] include() default {};

    String[] exclude() default {};

    boolean filterNull() default false;
}

我采取的策略是, include 优先级最高, 当 include 有被设置, 那么 exclude 将不再生效, filterNull 则是过滤掉返回值中 null 的部分.
我们的返回值中可能包含两种类型, 一个是直接的 Object, 一个是列表, 我们需要对他们分别处理
首先我们先处理列表

1
2
3
4
5
6
7
private List<Object> handleList(List<?> objectList) throws IllegalAccessException {
    List<Object> newList = new ArrayList<>();
    for (Object obj : objectList) {
        newList.add(handelObject(obj));
    }
    return newList;
}

很简单, 就不解释了.
接下来处理直接的 Object

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
private Object handelObject(Object object) throws IllegalAccessException {
    Map<String, Object> objectMap = new LinkedHashMap<>();
    // 获取所有声明的属性
    Field[] fieldList = object.getClass().getDeclaredFields();
    // 遍历
    for (Field field : fieldList) {
        String fieldName = field.getName();
        // 如果设置了 include 就不会去走 exclude 那段分支
        if (include.size() > 0) {
            if (include.contains(fieldName)) {
                field.setAccessible(true);
                // 如果设置了过滤 null
                if (filterNull && field.get(object) == null) {
                    continue;
                }
                // 如果这个 Object 是个 list
                if (object instanceof List) {
                    objectMap.put(fieldName, handleList((List<?>) field.get(object)));
                } else {
                    objectMap.put(fieldName, field.get(object));
                }
            }
        } else {
            if ((exclude.size() > 0 && exclude.contains(fieldName))) {
                continue;
            }
            field.setAccessible(true);
            if (filterNull && field.get(object) == null) {
                continue;
            }
            objectMap.put(fieldName, field.get(object));
        }
    }
    return objectMap;
}

总体就是通过反射获取所有声明的属性, 然后进行过滤.

最后附上 GlobalResultResolver 所有的代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
@EnableWebMvc
@Configuration
public class GlobalResultResolver {

    @RestControllerAdvice
    static class ResultBodyConfigAdvice implements ResponseBodyAdvice<Object> {
        private HashSet<String> include;
        private HashSet<String> exclude;
        private boolean filterNull;

        @Override
        public boolean supports(MethodParameter returnType, Class<? extends HttpMessageConverter<?>> converterType) {
            return true;
        }

        @Override
        public Object beforeBodyWrite(Object body, MethodParameter returnType, MediaType selectedContentType, Class<? extends HttpMessageConverter<?>> selectedConverterType, ServerHttpRequest request, ServerHttpResponse response) {
            if (Objects.requireNonNull(returnType.getMethod()).isAnnotationPresent(CustomFilter.class)) {
                CustomFilter customFilter = returnType.getMethodAnnotation(CustomFilter.class);
                include = new HashSet<>();
                exclude = new HashSet<>();
                if (customFilter != null) {
                    include.addAll(Arrays.asList(customFilter.include()));
                    exclude.addAll(Arrays.asList(customFilter.exclude()));
                    filterNull = customFilter.filterNull();
                    try {
                        if (body instanceof List) {
                            body = handleList((List<?>) body);
                        } else {
                            body = handelObject(body);
                        }
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            }
            if (body instanceof ResultBody) {
                return body;
            }
            return new ResultBody<>(body);
        }

        private Object handelObject(Object object) throws IllegalAccessException {
            Map<String, Object> objectMap = new LinkedHashMap<>();
            Field[] fieldList = object.getClass().getDeclaredFields();
            for (Field field : fieldList) {
                String fieldName = field.getName();
                if (include.size() > 0) {
                    if (include.contains(fieldName)) {
                        field.setAccessible(true);
                        if (filterNull && field.get(object) == null) {
                            continue;
                        }
                        if (object instanceof List) {
                            objectMap.put(fieldName, handleList((List<?>) field.get(object)));
                        } else {
                            objectMap.put(fieldName, field.get(object));
                        }
                    }
                } else {
                    if ((exclude.size() > 0 && exclude.contains(fieldName))) {
                        continue;
                    }
                    field.setAccessible(true);
                    if (filterNull && field.get(object) == null) {
                        continue;
                    }
                    objectMap.put(fieldName, field.get(object));
                }
            }
            return objectMap;
        }

        private List<Object> handleList(List<?> objectList) throws IllegalAccessException {
            List<Object> newList = new ArrayList<>();
            for (Object obj : objectList) {
                newList.add(handelObject(obj));
            }
            return newList;
        }
    }
}

这样, 我想得到的效果也已经实现了, 如果还有更好的方法, 也可以告诉我……