Skip to content

教程：使数据范围权限 DataScope 支持 Bean Searcher #246

@zengyufei

Description

@zengyufei
  1. 引入依赖
<dependency>
	<groupId>cn.zhxu</groupId>
	<artifactId>bean-searcher-boot-starter</artifactId>
	<version>4.1.2</version>
</dependency>
  2. 实现 SqlInterceptor（该拦截器同时实现了 MyBatis 的 Interceptor）
@RequiredArgsConstructor
@Intercepts({
		@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class DataPermissionInterceptor implements Interceptor, SqlInterceptor {

	private final DataScopeSqlProcessor dataScopeSqlProcessor;

	private final DataPermissionHandler dataPermissionHandler;

	/**
	 * MyBatis {@link Interceptor} entry point, registered for {@code StatementHandler#prepare}.
	 * The body is omitted in this tutorial.
	 */
	@Override
	public Object intercept(Invocation invocation) throws Throwable {
            ......
	}

	/**
	 * Bean Searcher hook: rewrites the cluster (count/summary) SQL and/or the list SQL of a search
	 * so that the applicable data scopes are enforced.
	 * <p>
	 * The cache key ({@code hashCode}) is derived from the cluster SQL when it is queried, otherwise
	 * from the list SQL. The key is recorded as "needs no data scope" only when <em>every</em>
	 * statement rewritten for this search matched no data scope. Recording it after a single
	 * statement could let a later identical search skip filtering on the other statement, for
	 * example a list SQL that touches tables the count SQL does not.
	 *
	 * @param searchSql the SQL produced by Bean Searcher; modified in place
	 * @param map       search parameters (not used here)
	 * @param fetchType tells which statements (cluster / list) will be executed
	 * @return the same {@code searchSql} instance
	 */
	@Override
	public <T> SearchSql<T> intercept(SearchSql<T> searchSql, Map<String, Object> map, FetchType fetchType) {
		final boolean shouldQueryCluster = fetchType.shouldQueryCluster();
		final boolean shouldQueryList = fetchType.shouldQueryList();
		if (!shouldQueryCluster && !shouldQueryList) {
			// Nothing will be executed, so there is nothing to restrict. This also avoids hashing a
			// SQL string that may not have been generated.
			return searchSql;
		}
		String sqlString = shouldQueryCluster ? searchSql.getClusterSqlString() : searchSql.getListSqlString();
		String hashCode = Convert.toStr(sqlString.hashCode());

		// Data scopes that still need to be applied for this SQL
		List<DataScope> filterDataScopes = dataPermissionHandler.filterDataScopes(hashCode);
		if (filterDataScopes == null || filterDataScopes.isEmpty()) {
			return searchSql;
		}

		// Decide from the user's permissions whether to intercept at all
		// (e.g. an administrator may see everything and is let through)
		if (dataPermissionHandler.ignorePermissionControl(filterDataScopes, hashCode)) {
			return searchSql;
		}

		// Stays true only if every rewritten statement reported exactly zero matches
		boolean matchedNothing = true;
		if (shouldQueryCluster) {
			ScopedSql rewritten = applyDataScopes(searchSql.getClusterSqlString(), filterDataScopes);
			searchSql.setClusterSqlString(rewritten.sql);
			matchedNothing = rewritten.matchedNothing();
		}
		if (shouldQueryList) {
			ScopedSql rewritten = applyDataScopes(searchSql.getListSqlString(), filterDataScopes);
			searchSql.setListSqlString(rewritten.sql);
			matchedNothing = matchedNothing && rewritten.matchedNothing();
		}

		// If no data scope matched any statement of this search (and all data scopes were
		// considered), remember the key so later identical searches can skip parsing.
		if (matchedNothing && dataPermissionHandler.dataScopes().size() == filterDataScopes.size()) {
			MappedStatementIdsWithoutDataScope.addToWithoutSet(filterDataScopes, hashCode);
		}
		return searchSql;
	}

	/**
	 * Rewrites one SQL statement with the given data scopes and collects the match count
	 * reported through {@code DataScopeMatchNumHolder} while parsing it.
	 *
	 * @param sql        the statement to rewrite
	 * @param dataScopes the data scopes to apply
	 * @return the rewritten SQL together with its match count
	 */
	private ScopedSql applyDataScopes(String sql, List<DataScope> dataScopes) {
		// Open a match counter for this statement; presumably updated by the SQL processor
		DataScopeMatchNumHolder.initMatchNum();
		try {
			String rewrittenSql = dataScopeSqlProcessor.parserSingle(sql, dataScopes);
			return new ScopedSql(rewrittenSql, DataScopeMatchNumHolder.pollMatchNum());
		} finally {
			DataScopeMatchNumHolder.removeIfEmpty();
		}
	}

	/**
	 * Wraps {@code StatementHandler} targets with this interceptor; other targets are returned as-is.
	 *
	 * @param target the object MyBatis offers for plugging
	 * @return a proxy for statement handlers, otherwise {@code target}
	 */
	@Override
	public Object plugin(Object target) {
		if (target instanceof StatementHandler) {
			return Plugin.wrap(target, this);
		}
		return target;
	}

	/** Result of rewriting one statement: the new SQL and the match count reported for it. */
	private static final class ScopedSql {

		final String sql;

		/** Number of data scopes matched while parsing; may be {@code null}. */
		final Integer matchNum;

		ScopedSql(String sql, Integer matchNum) {
			this.sql = sql;
			this.matchNum = matchNum;
		}

		/** Whether the parse positively reported zero matches (an unknown count does not count). */
		boolean matchedNothing() {
			return matchNum != null && matchNum == 0;
		}

	}

}
  3. 解决 Bean Searcher 不支持事务的问题：注册一个通过 Spring DataSourceUtils 获取连接的 SqlExecutor
@AutoConfiguration
@RequiredArgsConstructor
@ConditionalOnBean(DataScope.class)
public class DataScopeAutoConfiguration {

	/**
	 * Registers a {@link MyDefaultSqlExecutor}, which obtains JDBC connections through Spring's
	 * DataSourceUtils so that Bean Searcher queries can take part in Spring transactions.
	 * It is marked {@code @Primary}, presumably so that it takes precedence over the executor
	 * registered by the Bean Searcher starter.
	 *
	 * @param dataSource   the default data source handed to the executor
	 * @param slowListener optional slow-SQL listener; applied only when one is available
	 * @param config       Bean Searcher properties (source of the slow-SQL threshold)
	 * @return the configured executor
	 */
	@Bean
	@Primary
	public SqlExecutor regMyDefaultSqlExecutor(@Autowired DataSource dataSource, ObjectProvider<SqlExecutor.SlowListener> slowListener, BeanSearcherProperties config) {
		final MyDefaultSqlExecutor sqlExecutor = new MyDefaultSqlExecutor(dataSource);
		ifAvailable(slowListener, sqlExecutor::setSlowListener);
		final long threshold = config.getSql().getSlowSqlThreshold();
		sqlExecutor.setSlowSqlThreshold(threshold);
		return sqlExecutor;
	}

	/**
	 * Passes the provider's bean to {@code consumer} if one is available.
	 * {@code ObjectProvider#ifAvailable} is deliberately not used, to stay compatible with
	 * Spring Boot 1.x (down to v1.4).
	 */
	private <T> void ifAvailable(ObjectProvider<T> provider, Consumer<T> consumer) {
		final T candidate = provider.getIfAvailable();
		if (candidate == null) {
			return;
		}
		consumer.accept(candidate);
	}
}

MyDefaultSqlExecutor 类（改写自 Bean Searcher 的 DefaultSqlExecutor，改用 Spring 的 DataSourceUtils 获取和释放连接）：

package com.hccake.ballcat.common.datascope;

import cn.zhxu.bs.BeanMeta;
import cn.zhxu.bs.SearchException;
import cn.zhxu.bs.SearchSql;
import cn.zhxu.bs.SqlExecutor;
import cn.zhxu.bs.SqlResult;
import cn.zhxu.bs.bean.SearchBean;
import cn.zhxu.bs.implement.DefaultSqlExecutor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.jdbc.datasource.DataSourceUtils;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
 * JDBC Sql executor that obtains and releases connections through Spring's
 * {@link DataSourceUtils}, so that Bean Searcher queries use the connection bound to the
 * current Spring-managed transaction (if any).
 * <p>
 * Adapted from Bean Searcher's {@link DefaultSqlExecutor}. Each connection is released
 * against the same data source it was obtained from (default or named). Releasing it against
 * a different data source would make Spring treat it as unmanaged and close it.
 *
 * @author Troy.Zhou
 * @since 1.1.1
 */
public class MyDefaultSqlExecutor implements SqlExecutor {

    protected static final Logger log = LoggerFactory.getLogger(MyDefaultSqlExecutor.class);

    /**
     * Default data source, used when the bean does not name one.
     */
    private DataSource dataSource;

    /**
     * Named data sources, keyed by the trimmed name passed to {@link #setDataSource(String, DataSource)}.
     *
     * @since v3.0.0
     */
    private final Map<String, DataSource> dataSourceMap = new ConcurrentHashMap<>();

    /**
     * Whether to use a read-only transaction.
     * NOTE: this flag is currently not applied by this executor, because transaction handling
     * is left to Spring. It is kept so existing callers of the setter/getter still compile.
     */
    private boolean transactional = false;

    /**
     * Slow SQL threshold in milliseconds; default: 500 ms.
     *
     * @since v3.7.0
     */
    private long slowSqlThreshold = 500;

    /**
     * Slow SQL listener.
     *
     * @since v3.7.0
     */
    private SlowListener slowListener;


    public MyDefaultSqlExecutor() {
    }

    public MyDefaultSqlExecutor(DataSource dataSource) {
        this.dataSource = dataSource;
    }


    /**
     * Executes the cluster and/or list SQL of a search.
     * The returned {@link SqlResult} owns the connection and releases it on {@code close()}.
     * On any failure before that hand-off, the connection is released here.
     */
    @Override
    public <T> SqlResult<T> execute(SearchSql<T> searchSql) {
        if (!searchSql.isShouldQueryList() && !searchSql.isShouldQueryCluster()) {
            return new SqlResult<>(searchSql);
        }
        final DataSource connectionSource = resolveDataSource(searchSql.getBeanMeta());
        Connection connection;
        try {
            connection = getConnection(searchSql.getBeanMeta());
        } catch (SQLException e) {
            throw new SearchException("Can not get connection from dataSource!", e);
        }
        boolean handedOff = false;
        try {
            SqlResult<T> result = doExecute(searchSql, connection, connectionSource);
            handedOff = true;
            return result;
        } catch (SQLException e) {
            throw new SearchException("A exception occurred when executing sql.", e);
        } finally {
            if (!handedOff) {
                // Release immediately on failure; on success the SqlResult releases it on close
                DataSourceUtils.releaseConnection(connection, connectionSource);
            }
        }
    }

    /**
     * Determines the data source for a bean: the named one if the bean declares a name,
     * otherwise the default data source.
     *
     * @throws SearchException if the required data source is not configured
     */
    protected DataSource resolveDataSource(BeanMeta<?> beanMeta) {
        String name = beanMeta.getDataSource();
        if (name == null || "".equals(name)) {
            final DataSource dataSource = this.getDataSource();
            if (dataSource == null) {
                throw new SearchException("There is not a default dataSource for " + beanMeta.getBeanClass());
            }
            return dataSource;
        }
        DataSource dataSource = this.getDataSourceMap().get(name);
        if (dataSource == null) {
            throw new SearchException("There is not a dataSource named " + name + " for " + beanMeta.getBeanClass());
        }
        return dataSource;
    }

    /**
     * Obtains a connection via {@link DataSourceUtils}, so a transaction-bound connection is
     * reused when one exists.
     */
    protected Connection getConnection(BeanMeta<?> beanMeta) throws SQLException {
        return DataSourceUtils.doGetConnection(resolveDataSource(beanMeta));
    }

    /**
     * Kept for callers of the original signature: resolves the connection's data source from
     * the bean meta and delegates to {@link #doExecute(SearchSql, Connection, DataSource)}.
     */
    protected <T> SqlResult<T> doExecute(SearchSql<T> searchSql, Connection connection) throws SQLException {
        return doExecute(searchSql, connection, resolveDataSource(searchSql.getBeanMeta()));
    }

    /**
     * Runs the cluster SQL (if requested) and the list SQL (if requested; skipped when the
     * count is known to be zero).
     *
     * @param connectionSource the data source {@code connection} was obtained from; used to release it
     */
    protected <T> SqlResult<T> doExecute(SearchSql<T> searchSql, Connection connection,
                                         DataSource connectionSource) throws SQLException {
        SqlResult.ResultSet listResult = null;
        SqlResult.Result clusterResult = null;
        try {
            Number totalCount = null;
            if (searchSql.isShouldQueryCluster()) {
                clusterResult = executeClusterSql(searchSql, connection);
                String countAlias = searchSql.getCountAlias();
                if (countAlias != null) {
                    totalCount = (Number) clusterResult.get(countAlias);
                }
            }
            if (searchSql.isShouldQueryList()) {
                if (totalCount == null || totalCount.longValue() > 0) {
                    listResult = executeListSql(searchSql, connection);
                } else {
                    listResult = SqlResult.ResultSet.EMPTY;
                }
            }
        } catch (SQLException | RuntimeException e) {
            closeQuietly(clusterResult);
            throw e;
        }
        return new SqlResult<T>(searchSql, listResult, clusterResult) {
            @Override
            public void close() {
                try {
                    super.close();
                } finally {
                    // Hand the connection back to Spring (closes it only if not transaction-bound)
                    DataSourceUtils.releaseConnection(connection, connectionSource);
                }
            }
        };
    }

    protected SqlResult.ResultSet executeListSql(SearchSql<?> searchSql, Connection connection) throws SQLException {
        Result result = executeQuery(connection, searchSql.getListSqlString(),
                searchSql.getListSqlParams(), searchSql.getBeanMeta());
        ResultSet resultSet = result.resultSet;
        return new SqlResult.ResultSet() {
            @Override
            public boolean next() throws SQLException {
                return resultSet.next();
            }

            @Override
            public Object get(String columnLabel) throws SQLException {
                return resultSet.getObject(columnLabel);
            }

            @Override
            public void close() {
                result.close();
            }
        };
    }

    /**
     * Executes the cluster SQL and reads its single row; {@code get} returns {@code null} when
     * the query produced no row.
     */
    protected SqlResult.Result executeClusterSql(SearchSql<?> searchSql, Connection connection) throws SQLException {
        Result result = executeQuery(connection, searchSql.getClusterSqlString(),
                searchSql.getClusterSqlParams(), searchSql.getBeanMeta());
        ResultSet resultSet = result.resultSet;
        boolean hasValue;
        try {
            hasValue = resultSet.next();
        } catch (SQLException e) {
            result.close();
            throw e;
        }
        return new SqlResult.Result() {
            @Override
            public Object get(String columnLabel) throws SQLException {
                if (hasValue) {
                    return resultSet.getObject(columnLabel);
                }
                return null;
            }

            @Override
            public void close() {
                result.close();
            }
        };
    }

    /**
     * A statement together with its result set; closing closes both, quietly.
     */
    protected static class Result {

        final PreparedStatement statement;
        final ResultSet resultSet;

        public Result(PreparedStatement statement, ResultSet resultSet) {
            this.statement = statement;
            this.resultSet = resultSet;
        }

        public void close() {
            closeQuietly(resultSet);
            closeQuietly(statement);
        }

    }

    /**
     * Prepares, binds and executes a query. The statement is closed if anything fails before a
     * {@link Result} is returned. Execution time is reported to {@link #afterExecute}.
     */
    protected Result executeQuery(Connection connection, String sql, List<Object> params,
                                  BeanMeta<?> beanMeta) throws SQLException {
        PreparedStatement statement = connection.prepareStatement(sql);
        try {
            int size = params.size();
            for (int i = 0; i < size; i++) {
                statement.setObject(i + 1, params.get(i));
            }
        } catch (SQLException | RuntimeException e) {
            // Binding failed: do not leak the prepared statement
            closeQuietly(statement);
            throw e;
        }
        long t0 = System.currentTimeMillis();
        try {
            int timeout = beanMeta.getTimeout();
            if (timeout > 0) {
                // This call is relatively expensive, so only make it when a timeout is configured
                statement.setQueryTimeout(timeout);
            }
            ResultSet resultSet = statement.executeQuery();
            return new Result(statement, resultSet);
        } catch (SQLException | RuntimeException e) {
            closeQuietly(statement);
            throw e;
        } finally {
            long cost = System.currentTimeMillis() - t0;
            afterExecute(beanMeta, sql, params, cost);
        }
    }

    /**
     * Logs the executed SQL. At or above the slow-SQL threshold it notifies the slow listener
     * (if set) and logs a warning; otherwise it logs at debug level.
     */
    protected void afterExecute(BeanMeta<?> beanMeta, String sql, List<Object> params, long timeCost) {
        if (timeCost >= slowSqlThreshold) {
            Class<?> beanClass = beanMeta.getBeanClass();
            SlowListener listener = slowListener;
            if (listener != null) {
                listener.onSlowSql(beanClass, sql, params, timeCost);
            }
            log.warn("bean-searcher [{}ms] slow-sql: [{}] params: {} on [{}]", timeCost, sql, params, beanClass.getName());
        } else {
            log.debug("bean-searcher [{}ms] sql: [{}] params: {}", timeCost, sql, params);
        }
    }

    /**
     * Closes {@code resource} if non-null, logging (not propagating) any failure.
     */
    protected static void closeQuietly(AutoCloseable resource) {
        try {
            if (resource != null) {
                resource.close();
            }
        } catch (Exception e) {
            log.error("Can not close {}", resource.getClass().getSimpleName(), e);
        }
    }

    /**
     * Sets the default data source.
     *
     * @param dataSource the data source; must not be null
     */
    public void setDataSource(DataSource dataSource) {
        this.dataSource = Objects.requireNonNull(dataSource);
    }

    public DataSource getDataSource() {
        return dataSource;
    }

    /**
     * Registers a named data source. Null arguments are ignored; the name is trimmed.
     *
     * @param name       data source name
     * @param dataSource data source
     * @see SearchBean#dataSource()
     * @since v3.1.0
     */
    public void setDataSource(String name, DataSource dataSource) {
        if (name != null && dataSource != null) {
            dataSourceMap.put(name.trim(), dataSource);
        }
    }

    public Map<String, DataSource> getDataSourceMap() {
        return dataSourceMap;
    }

    /**
     * Sets whether to use a read-only transaction (currently not applied by this executor).
     *
     * @param transactional whether to use a transaction
     */
    public void setTransactional(boolean transactional) {
        this.transactional = transactional;
    }

    public boolean isTransactional() {
        return transactional;
    }

    public long getSlowSqlThreshold() {
        return slowSqlThreshold;
    }

    /**
     * Sets the slow SQL threshold (minimum execution time treated as slow).
     *
     * @param slowSqlThreshold threshold in milliseconds
     * @since v3.7.0
     */
    public void setSlowSqlThreshold(long slowSqlThreshold) {
        this.slowSqlThreshold = slowSqlThreshold;
    }

    public SlowListener getSlowListener() {
        return slowListener;
    }

    public void setSlowListener(SlowListener slowListener) {
        this.slowListener = slowListener;
    }

}
  4. 可选：简化前端多值参数的传递，支持 XX,YY,ZZ 形式的逗号分隔值
/**
     * Optional filter that simplifies passing multi-value parameters: a comma-separated value
     * is split into Bean Searcher's indexed form. For example, assuming separator {@code -} and
     * operator key {@code op}: {@code age=1,2,3&age-op=il} becomes {@code age-0=1, age-1=2,
     * age-2=3, age-op=il}.
     * Reference: https://github.com/troyzhxu/bean-searcher/issues/10
     * <p>
     * Splitting only happens when the field's operator is given as a string equal (after
     * trimming) to one of {@code mv, il, bt, nb, ol, ni}. Operators given as non-string values
     * are passed through unchanged instead of causing a {@link ClassCastException}.
     *
     * @param config Bean Searcher properties (source of the separator and the operator key)
     * @return the parameter filter
     */
    @Bean
    public ParamFilter myParamFilter(BeanSearcherProperties config) {
        final BeanSearcherProperties.Params configParams = config.getParams();
        final String separator = configParams.getSeparator();
        final String operatorKey = configParams.getOperatorKey();
        return new ParamFilter() {

            final String OP_SUFFIX = separator + operatorKey;

            /** Operator short names whose value may hold several comma-separated items. */
            final List<String> MULTI_VALUE_OPS = Arrays.asList("mv", "il", "bt", "nb", "ol", "ni");

            @Override
            public <T> Map<String, Object> doFilter(BeanMeta<T> beanMeta, Map<String, Object> paraMap) {
                Map<String, Object> newParaMap = new HashMap<>();
                paraMap.forEach((key, value) -> {
                    if (key == null) {
                        return;
                    }
                    boolean isOpKey = key.endsWith(OP_SUFFIX);
                    String opKey = isOpKey ? key : key + OP_SUFFIX;
                    Object opVal = paraMap.get(opKey);
                    // The operator value is not necessarily a string; only string operators are inspected
                    String op = opVal instanceof CharSequence ? StrUtil.trim((CharSequence) opVal) : null;
                    if (!MULTI_VALUE_OPS.contains(op)) {
                        newParaMap.put(key, value);
                        return;
                    }
                    // Skip keys already written (e.g. the operator key written while expanding its value key)
                    if (newParaMap.containsKey(key)) {
                        return;
                    }
                    String valKey = key;
                    Object valVal = value;
                    if (isOpKey) {
                        valKey = key.substring(0, key.length() - OP_SUFFIX.length());
                        valVal = paraMap.get(valKey);
                    }
                    if (strContainDou(valVal)) {
                        // "a, b ,c" -> valKey<sep>0=a, valKey<sep>1=b, valKey<sep>2=c (items trimmed)
                        final List<String> items = StrUtil.split((String) valVal, ",");
                        for (int i = 0; i < items.size(); i++) {
                            newParaMap.put(valKey + separator + i, StrUtil.trim(items.get(i)));
                        }
                        newParaMap.put(opKey, opVal);
                        return;
                    }
                    newParaMap.put(key, value);
                });
                return newParaMap;
            }

            /** Whether {@code value} is a string containing a comma. */
            private boolean strContainDou(Object value) {
                if (value instanceof String) {
                    String str = ((String) value).trim();
                    return str.contains(",");
                }
                return false;
            }

        };
    }

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions