diff --git a/yudao-framework/yudao-spring-boot-starter-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/interceptor/DataPermissionInterceptor.java b/yudao-framework/yudao-spring-boot-starter-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/interceptor/DataPermissionInterceptor.java index a1af259e0..2209a3834 100644 --- a/yudao-framework/yudao-spring-boot-starter-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/interceptor/DataPermissionInterceptor.java +++ b/yudao-framework/yudao-spring-boot-starter-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/interceptor/DataPermissionInterceptor.java @@ -4,18 +4,20 @@ import cn.hutool.core.collection.CollUtil; import cn.iocoder.yudao.framework.common.util.collection.SetUtils; import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule; import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory; +import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils; import com.alibaba.ttl.TransmittableThreadLocal; import com.baomidou.mybatisplus.core.toolkit.CollectionUtils; import com.baomidou.mybatisplus.core.toolkit.PluginUtils; -import com.baomidou.mybatisplus.core.toolkit.StringPool; import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport; import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor; import lombok.RequiredArgsConstructor; import net.sf.jsqlparser.expression.*; import net.sf.jsqlparser.expression.operators.conditional.AndExpression; import net.sf.jsqlparser.expression.operators.conditional.OrExpression; -import net.sf.jsqlparser.expression.operators.relational.*; -import net.sf.jsqlparser.schema.Column; +import net.sf.jsqlparser.expression.operators.relational.ExistsExpression; +import net.sf.jsqlparser.expression.operators.relational.ExpressionList; +import net.sf.jsqlparser.expression.operators.relational.InExpression; +import net.sf.jsqlparser.expression.operators.relational.ItemsList; import net.sf.jsqlparser.schema.Table; import net.sf.jsqlparser.statement.delete.Delete; import net.sf.jsqlparser.statement.select.*; @@ -32,6 +34,15 @@ import java.sql.Connection; import java.util.*; import java.util.concurrent.ConcurrentHashMap; +/** + * 数据权限拦截器,通过 {@link DataPermissionRule} 数据权限规则,重写 SQL 的方式来实现 + * 主要的 SQL 重写方法,可见 {@link #builderExpression(Expression, Table)} 方法 + * + * 整体的代码实现上,参考 {@link com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor} 实现。 + * 所以每次 MyBatis Plus 升级时,需要 Review 下其具体的实现是否有变更! + * + * @author 芋道源码 + */ @RequiredArgsConstructor public class DataPermissionInterceptor extends JsqlParserSupport implements InnerInterceptor { @@ -40,7 +51,8 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne private final MappedStatementCache mappedStatementCache = new MappedStatementCache(); @Override // SELECT 场景 - public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) { + public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, + RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) { // 获得 Mapper 对应的数据权限的规则 List rules = ruleFactory.getDataPermissionRule(ms.getId()); if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过 @@ -59,12 +71,12 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne } } - @Override // 只处理 UPDATE / DELETE 场景 + @Override // 只处理 UPDATE / DELETE 场景,不处理 INSERT 场景 public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) { PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh); MappedStatement ms = mpSh.mappedStatement(); SqlCommandType sct = ms.getSqlCommandType(); - if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) { // 无需处理 Insert 语句 + if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) { // 获得 Mapper 对应的数据权限的规则 List rules = ruleFactory.getDataPermissionRule(ms.getId()); if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过 @@ -117,7 +129,8 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne @Override protected void processUpdate(Update update, int index, String sql, Object obj) { final Table table = update.getTable(); - update.setWhere(this.andExpression(table, update.getWhere())); +// update.setWhere(this.andExpression(table, update.getWhere())); + update.setWhere(this.builderExpression(update.getWhere(), table)); } /** @@ -125,26 +138,27 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne */ @Override protected void processDelete(Delete delete, int index, String sql, Object obj) { - delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere())); +// delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere())); + delete.setWhere(this.builderExpression(delete.getWhere(), delete.getTable())); } - /** - * delete update 语句 where 处理 - */ - protected BinaryExpression andExpression(Table table, Expression where) { - //获得where条件表达式 - EqualsTo equalsTo = new EqualsTo(); - equalsTo.setLeftExpression(this.getAliasColumn(table)); - equalsTo.setRightExpression(getTenantId()); - if (null != where) { - if (where instanceof OrExpression) { - return new AndExpression(equalsTo, new Parenthesis(where)); - } else { - return new AndExpression(equalsTo, where); - } - } - return equalsTo; - } +// /** +// * delete update 语句 where 处理 +// */ +// protected BinaryExpression andExpression(Table table, Expression where) { +// //获得where条件表达式 +// EqualsTo equalsTo = new EqualsTo(); +// equalsTo.setLeftExpression(this.getAliasColumn(table)); +// equalsTo.setRightExpression(getTenantId()); +// if (null != where) { +// if (where instanceof OrExpression) { +// return new AndExpression(equalsTo, new Parenthesis(where)); +// } else { +// return new AndExpression(equalsTo, where); +// } +// } +// return equalsTo; +// } /** * 处理 PlainSelect @@ -155,10 +169,11 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne processWhereSubSelect(where); if (fromItem instanceof Table) { Table fromTable = (Table) fromItem; - if (!ignoreTable(fromTable.getName())) { - //#1186 github - plainSelect.setWhere(builderExpression(where, fromTable)); - } +// if (!ignoreTable(fromTable.getName())) { +// //#1186 github +// plainSelect.setWhere(builderExpression(where, fromTable)); +// } + plainSelect.setWhere(builderExpression(where, fromTable)); } else { processFromItem(fromItem); } @@ -311,20 +326,21 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne processJoin(join); continue; } - // 当前表是否忽略 - boolean needIgnore = ignoreTable(fromTable.getName()); - // 表名压栈,忽略的表压入 null,以便后续不处理 - tables.push(needIgnore ? null : fromTable); +// // 当前表是否忽略 +// boolean needIgnore = ignoreTable(fromTable.getName()); +// // 表名压栈,忽略的表压入 null,以便后续不处理 +// tables.push(needIgnore ? null : fromTable); // 尾缀多个 on 表达式的时候统一处理 if (originOnExpressions.size() > 1) { Collection onExpressions = new LinkedList<>(); for (Expression originOnExpression : originOnExpressions) { Table currentTable = tables.poll(); - if (currentTable == null) { - onExpressions.add(originOnExpression); - } else { - onExpressions.add(builderExpression(originOnExpression, currentTable)); - } +// if (currentTable == null) { +// onExpressions.add(originOnExpression); +// } else { +// onExpressions.add(builderExpression(originOnExpression, currentTable)); +// } + onExpressions.add(builderExpression(originOnExpression, currentTable)); } join.setOnExpressions(onExpressions); } @@ -341,15 +357,18 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne protected void processJoin(Join join) { if (join.getRightItem() instanceof Table) { Table fromTable = (Table) join.getRightItem(); - if (ignoreTable(fromTable.getName())) { - // 过滤退出执行 - return; - } +// if (ignoreTable(fromTable.getName())) { +// // 过滤退出执行 +// return; +// } // 走到这里说明 on 表达式肯定只有一个 - Collection originOnExpressions = join.getOnExpressions(); - List onExpressions = new LinkedList<>(); - onExpressions.add(builderExpression(originOnExpressions.iterator().next(), fromTable)); - join.setOnExpressions(onExpressions); +// Collection originOnExpressions = join.getOnExpressions(); +// List onExpressions = new LinkedList<>(); +// onExpressions.add(builderExpression(originOnExpressions.iterator().next(), fromTable)); +// join.setOnExpressions(onExpressions); + Expression originOnExpression = CollUtil.getFirst(join.getOnExpressions()); + originOnExpression = builderExpression(originOnExpression, fromTable); + join.setOnExpressions(CollUtil.newArrayList(originOnExpression)); } } @@ -357,50 +376,69 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne * 处理条件 */ protected Expression builderExpression(Expression currentExpression, Table table) { - EqualsTo equalsTo = new EqualsTo(); - equalsTo.setLeftExpression(this.getAliasColumn(table)); - equalsTo.setRightExpression(getTenantId()); + // 获得 Table 对应的数据权限条件 + Expression equalsTo = buildDataPermissionExpression(table); + if (equalsTo == null) { // 如果没条件,则返回 currentExpression 默认 + return currentExpression; + } + + // 表达式为空,则直接返回 equalsTo if (currentExpression == null) { return equalsTo; } + // 如果表达式为 Or,则需要 (currentExpression) AND equalsTo if (currentExpression instanceof OrExpression) { return new AndExpression(new Parenthesis(currentExpression), equalsTo); - } else { - return new AndExpression(currentExpression, equalsTo); } + // 如果表达式为 And,则直接返回 currentExpression AND equalsTo + return new AndExpression(currentExpression, equalsTo); } +// /** +// * 租户字段别名设置 +// *

tenantId 或 tableAlias.tenantId

+// * +// * @param table 表对象 +// * @return 字段 +// */ +// protected Column getAliasColumn(Table table) { +// StringBuilder column = new StringBuilder(); +// if (table.getAlias() != null) { +// column.append(table.getAlias().getName()).append(StringPool.DOT); +// } +// column.append(getTenantIdColumn()); +// return new Column(column.toString()); +// } + /** - * 租户字段别名设置 - *

tenantId 或 tableAlias.tenantId

+ * 构建指定表的数据权限的 Expression 过滤条件 * - * @param table 表对象 - * @return 字段 + * @param table 表 + * @return Expression 过滤条件 */ - protected Column getAliasColumn(Table table) { - StringBuilder column = new StringBuilder(); - if (table.getAlias() != null) { - column.append(table.getAlias().getName()).append(StringPool.DOT); + private Expression buildDataPermissionExpression(Table table) { + // 生成条件 + Expression allExpression = null; + for (DataPermissionRule rule : ContextHolder.getRules()) { + // 判断表名是否匹配 + if (!rule.getTableNames().contains(table.getName())) { + continue; + } + // 单条规则的条件 + String tableName = MyBatisUtils.getTableName(table); + Expression oneExpress = rule.getExpression(tableName, table.getAlias()); + // 拼接到 allExpression 中 + allExpression = allExpression == null ? oneExpress + : new AndExpression(allExpression, oneExpress); } - column.append(getTenantIdColumn()); - return new Column(column.toString()); + + // 如果条件非空,说明已经重写了 + if (allExpression != null) { + ContextHolder.setRewrite(true); + } + return allExpression; } - // TODO 芋艿:未实现 - - private boolean ignoreTable(String tableName) { - return false; - } - - private String getTenantIdColumn() { - return "dept_id"; - } - - private Expression getTenantId() { - return new LongValue(1L); - } - - /** * 判断 SQL 是否重写。如果没有重写,则添加到 {@link MappedStatementCache} 中 * diff --git a/yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/util/MyBatisUtils.java b/yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/util/MyBatisUtils.java index f7f0cad1c..54e405817 100644 --- a/yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/util/MyBatisUtils.java +++ b/yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/util/MyBatisUtils.java @@ -7,6 +7,7 @@ import com.baomidou.mybatisplus.core.metadata.OrderItem; import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor; import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; +import net.sf.jsqlparser.schema.Table; import java.util.ArrayList; import java.util.Collection; @@ -18,6 +19,8 @@ import java.util.stream.Collectors; */ public class MyBatisUtils { + private static final String MYSQL_ESCAPE_CHARACTER = "`"; + public static Page buildPage(PageParam pageParam) { return buildPage(pageParam, null); } @@ -48,4 +51,20 @@ public class MyBatisUtils { interceptor.setInterceptors(inners); } + /** + * 获得 Table 对应的表名 + * + * 兼容 MySQL 转义表名 `t_xxx` + * + * @param table 表 + * @return 去除转移字符后的表名 + */ + public static String getTableName(Table table) { + String tableName = table.getName(); + if (tableName.startsWith(MYSQL_ESCAPE_CHARACTER) && tableName.endsWith(MYSQL_ESCAPE_CHARACTER)) { + tableName = tableName.substring(1, tableName.length() - 1); + } + return tableName; + } + }