总监:喂,小王啊!起来没呢?加个班呗!

我:泥煤啊…

总监:我有个需求啊,这最近导入数据比较多,但是后台用户反映导入了数据,不想要了,删除起来麻烦啊!你也知道,顾客是上帝嘛,给我完成一个导入数据自动一键回滚的功能!

我:说啥也不干,今天休息,我还要打游戏。

总监:那个你申请一个在家办公,两倍工资,我这面批一个。

我:好嘞!


一:需求分析

  1. 导入一定分为很多种,有商品的,有图片的,有各种业务的,一定要兼容各种具体的业务,那么就不能依赖于具体实现。
  2. 分析在各个业务层,导入无谓就是处理完数据之后生成的增删改语句。
  3. 那我只需要处理sql语句就可以了,把增删改的语句生成它具体的相反的语句。insert生成delete语句,udapte生成delete和insert语句,delete生成insert语句。
  4. 那么多的mapper层接口的语句,怎么知道哪个语句是需要生成相反的语句呢?可以自定义一个注解,然后我们在执行之前看看该接口上面有没有这个注解就行了。
  5. 那在并行多次导入的时候怎么区分哪些任务是属于同一任务的呢?这么办,运用线程标识该次任务。那开启多线程怎么办呢?开启多线程就把每一个线程都存入任务名称。
  6. 好啦,差不多思路就是这些,总结一下就是在sql执行之前,拦截要执行的sql。判定该要执行的sql的mapper层接口是否有自定义约定的注解,如果有,那么该语句是需要生成相反sql的。判定该线程中是否存储有任务名称,有则生成相反语句并存储到redis中,该任务名称为redis中的key。value采用list结构,我们从左向右添加,要是执行的话,也是从左边进行执行。

二:代码编写

  1. 首先自定义一个注解,EnableReverseSql
@Target({ElementType.METHOD, ElementType.PARAMETER})
@Retention(RetentionPolicy.RUNTIME)
public @interface EnableReverseSql {
}
  1. 定义一个生成反向sql的顶级接口,以后用于适配不同的数据库
public interface ReverseSqlDb {

    String getSql(Statement statement);

    String  insertGenerateDelete(Invocation invocation,String sql);

    List<String> updateGenerateDeleteAndInsert(Invocation invocation,String sql,String className);

    String deleteGenerateInsert(Invocation invocation,String sql,String className);

    String getDbVersion();

}
  1. 因为要生成反向sql,比如删除,只会知道id,那么我们必须要查询该id的所有信息,才能生成insert语句,这里定义一个mapper接口,用于执行在代码中生成的查询语句。在有@ReverseSqlDb注解的接口层必须继承该类。
public interface ReverseMapper {
    @Select("${sql}")
    @InterceptorIgnore(tenantLine = "true")
    LinkedHashMap<String,Object> performSql(@Param("sql") String sql);
}

例:

@Mapper
public interface TestMapper extends ReverseMapper {
}

这样我们就可以执行代码中任意生成的sql语句。

  1. 核心思想,要拦截sql,就需要实现Interceptor接口,用于拦截需要执行的语句,新建ReverseSqlInterceptor
@Slf4j
@Component
@Intercepts({
        @Signature(type = StatementHandler.class, method = "update", args = Statement.class),
        @Signature(type = StatementHandler.class, method = "batch", args = Statement.class)
})
public class ReverseSqlInterceptor implements Interceptor {

    @Resource
    ReverseSqlDbChainOfResponsibility reverseSqlDbChainOfResponsibility;

    @Autowired
    SpringConfigProperties springProperties;


    /**
     * 获取当前在使用的数据库
     * @return 数据库名称
     */
    private String getDbVersion(){
        String druidUrl;
      	//这里只是为了适配不同数据源进行的判断,最终只是要过去正在使用的是什么数据库
        if( springProperties.getDataSource().getDruid() == null){
            druidUrl = springProperties.getDataSource().getUrl();
        }else{
            druidUrl = springProperties.getDataSource().getDruid().getUrl();
        }
        return druidUrl.split(":")[1];
    }

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        //获取Statement类对象
        Statement statement = this.getStatement(invocation);
        Object target = PluginUtils.realTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(target);
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        //获取命名空间
        String namespace = mappedStatement.getId();
        //获取类名
        String className = namespace.substring(0, namespace.lastIndexOf("."));
        //获取当前类的方法名
        String methodName = namespace.substring(namespace.lastIndexOf(".") + 1);
        //获取当前类有哪些方法
        Method[] ms = Class.forName(className).getMethods();
        for (Method m : ms) {
            if (m.getName().equals(methodName)) {
                //判断是否有这个注解
                Annotation annotation = m.getAnnotation(EnableReverseSql.class);
                if (annotation != null) {
                   //通过反射redis类实例对象并获取
                    Method getMethod = InternalThreadLocal.getMethodForNameGet();
                    HashMap<String, Object> stringObjectHashMap  = (HashMap<String, Object>) getMethod.invoke(InternalThreadLocal.treadUtilEntity);
                    String taskName = (String) stringObjectHashMap.get("name");
                    if(taskName==null){
                        throw new RuntimeException("未获取到线程中的任务名称,请添加任务名称");
                    }
                    reverseSqlDbChainOfResponsibility.selectDbAndExecuteChain(getDbVersion(),invocation,statement,className,taskName);
                } else {
                    return invocation.proceed();
                }
            }
        }
        return invocation.proceed();
    }


    /**
     * ThreadLocal内部静态类
     */
    private static class InternalThreadLocal{

        private static Class<?> treadUtilClass;

        private static Constructor<?> declaredConstructor;

        private static Object treadUtilEntity;

        private static Class[] treadUtilArguments;

        /**
         * 加载ThreadUtil工具类
         */
        static {
            try {
                treadUtilClass = Class.forName("com.common.util.ThreadUtil");
                declaredConstructor = treadUtilClass.getDeclaredConstructor();
                //强制使用私有的构造方法
                declaredConstructor.setAccessible(true);
                treadUtilEntity = declaredConstructor.newInstance();
                treadUtilArguments = new Class[0];
            } catch (InstantiationException | IllegalAccessException | ClassNotFoundException | NoSuchMethodException | InvocationTargetException e) {
                log.error("获取ThreadUtil工具类失败");
                e.printStackTrace();
            }
        }

        private static Method getMethodForNameGet() throws NoSuchMethodException {
            return treadUtilClass.getMethod("get",treadUtilArguments);
        }


    }

    /**
     * 获取statement
     */
    private Statement getStatement(Invocation invocation) {
        Statement statement;
        Object firstArg = invocation.getArgs()[0];
        if (Proxy.isProxyClass(firstArg.getClass())) {
            statement = (Statement) SystemMetaObject.forObject(firstArg).getValue("h.statement");
        } else {
            statement = (Statement) firstArg;
        }
        MetaObject stmtMetaObj = SystemMetaObject.forObject(statement);
        try {
            statement = (Statement) stmtMetaObj.getValue("stmt.statement");
        } catch (Exception e) {
            //这个位置不需要捕获异常,会报错
        }
        if (stmtMetaObj.hasGetter("delegate")) {
            try {
                statement = (Statement) stmtMetaObj.getValue("delegate");
            } catch (Exception e) {
                //这个位置不需要捕获异常,会报错
            }
        }
        if(statement != null){
            return statement;
        }else{
            throw new RuntimeException("未获取到Statement类");
        }
    }

}
  1. 为了以后适配更多的数据库,新建ReverseSqlDbChainOfResponsibility类,用于适配不同的数据库
@Component
public class ReverseSqlDbChainOfResponsibility implements CommandLineRunner, ApplicationContextAware {

    private Collection<ReverseSqlDb> reverseSqlDbList;

    private volatile ApplicationContext applicationContext;


    @Override
    public void run(String... args) throws Exception {
        init();
    }

    private void init() {
        reverseSqlDbList = new LinkedList<>(this.applicationContext.getBeansOfType(ReverseSqlDb.class).values());

    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.applicationContext=applicationContext;
    }


    @SneakyThrows
    void selectDbAndExecuteChain(String dbVersion,
                                 Invocation invocation,
                                 Statement statement,
                                 String className,
                                 String taskName){
        //反射获取列表的又添加
        Method lSetMethod = RedisMethod.getLrSetObj();
        //获取列表的右添加列表
        Method lSetListMethod = RedisMethod.getLrSetList();
        for (ReverseSqlDb reverseSqlDb:reverseSqlDbList
             ) {
            if(reverseSqlDb instanceof Proxy){
                continue;
            }
            if(dbVersion.equals(reverseSqlDb.getDbVersion())){
                //判断sql是要执行增删改中的哪一个方法
                String sql = reverseSqlDb.getSql(statement);
                if (sql.contains("INSERT") || sql.contains("insert")) {
                    //调用新增方法生成反向sql
                    String reverseSql = reverseSqlDb.insertGenerateDelete(invocation, sql);
                    lSetMethod.invoke(RedisMethod.redisEntity,taskName,reverseSql);
                } else if (sql.contains("UPDATE") || sql.contains("update")) {
                    //调用修改方法生成反向sql
                    List<String> reverseSqlList = reverseSqlDb.updateGenerateDeleteAndInsert(invocation, sql, className);
                    lSetListMethod.invoke(RedisMethod.redisEntity,taskName,reverseSqlList);
                } else if (sql.contains("DELETE") || sql.contains("delete")) {
                    //调用删除方法生成反向sql
                    String reverseSql = reverseSqlDb.deleteGenerateInsert(invocation, sql, className);
                    lSetMethod.invoke(RedisMethod.redisEntity,taskName,reverseSql);
                }
            }
        }
    }
}
  1. 编写具体的实现类ReverseSqlDbPg
@Slf4j
@Component
public class ReverseSqlDbPg extends AbstractBusiness implements ReverseSqlDb {

    private static final String DRUID_POOLED_PREPARED_STATEMENT = "com.alibaba.druid.pool.DruidPooledPreparedStatement";
    private static final String T4C_PREPARED_STATEMENT = "oracle.jdbc.driver.T4CPreparedStatement";
    private static final String ORACLE_PREPARED_STATEMENT_WRAPPER = "oracle.jdbc.driver.OraclePreparedStatementWrapper";

    private Method oracleGetOriginalSqlMethod;
    private Method druidGetSqlMethod;

    static final String DB_VERSION = "postgresql";

    /**
     * 获取当前正在执行的sql
     *
     * @param statement 声明
     * @return 当前要执行的语句
     */
    @Override
    public String getSql(Statement statement) {
        String originalSql = null;
        String stmtClassName = statement.getClass().getName();
        if (DRUID_POOLED_PREPARED_STATEMENT.equals(stmtClassName)) {
            try {
                if (druidGetSqlMethod == null) {
                    Class<?> clazz = Class.forName(DRUID_POOLED_PREPARED_STATEMENT);
                    druidGetSqlMethod = clazz.getMethod("getSql");
                }
                Object stmtSql = druidGetSqlMethod.invoke(statement);
                if (stmtSql instanceof String) {
                    originalSql = (String) stmtSql;
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        } else if (T4C_PREPARED_STATEMENT.equals(stmtClassName)
                || ORACLE_PREPARED_STATEMENT_WRAPPER.equals(stmtClassName)) {
            try {
                if (oracleGetOriginalSqlMethod != null) {
                    Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
                    if (stmtSql instanceof String) {
                        originalSql = (String) stmtSql;
                    }
                } else {
                    Class<?> clazz = Class.forName(stmtClassName);
                    oracleGetOriginalSqlMethod = getMethodRegular(clazz, "getOriginalSql");
                    if (oracleGetOriginalSqlMethod != null) {
                        oracleGetOriginalSqlMethod.setAccessible(true);
                        if (null != oracleGetOriginalSqlMethod) {
                            Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
                            if (stmtSql instanceof String) {
                                originalSql = (String) stmtSql;
                            }
                        }
                    }
                }
            } catch (Exception e) {
                //ignore
            }
        }
        if (originalSql == null) {
            originalSql = statement.toString();
        }
        originalSql = originalSql.replaceAll("[\\s]+", StringPool.SPACE);
        int index = indexOfSqlStart(originalSql);
        if (index > 0) {
            originalSql = originalSql.substring(index);
        }
        return originalSql;
    }

    /**
     * 新增生成删除
     *
     * @param invocation
     * @param sql        要执行的sql
     * @return 生成的反向sql
     */
    @Override
    public String insertGenerateDelete(Invocation invocation, String sql) {
        List<String> paramTerList = this.getParamTerList(invocation);
        //添加的参数列表第一位是id,我们就默认第一位是id,添加的话只需要反向生成删除的sql即可
        //获取要删除的表名,添加语句的insert into 表名,所以这里取列表中的第三位
        String[] words = sql.split(" ");
        String tableName = words[2];
        //拼接反向sql
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append("delete from ");
        stringBuilder.append(tableName);
        stringBuilder.append(" where id = ");
        stringBuilder.append(paramTerList.get(0));
        return stringBuilder.toString();
    }

    /**
     * 修改方法生成删除和新增方法实现
     */
    @Override
    public List<String> updateGenerateDeleteAndInsert(Invocation invocation, String sql, String className) {
        ArrayList<String> resList = new ArrayList<>();
        List<String> paramTerList = this.getParamTerList(invocation);
        //修改的语句最后的参数为id,默认最后的条件为id.表明则为单词的第二个单词,由此获得id与表名
        String[] words = sql.split(" ");
        String tableName = words[1];
        String id = paramTerList.get(paramTerList.size() - 1);
        //生成删除语句
        StringBuffer deleteBuffer = new StringBuffer();
        deleteBuffer.append("delete from ");
        deleteBuffer.append(tableName);
        deleteBuffer.append(" where id = ");
        deleteBuffer.append(id);
        //修改我们需要查询该id的所有数据,这里通过反射注入该接口,并通过继承的方式必须实现我们规定的接口,从而执行拼接的查询sql
        LinkedHashMap<String, Object> resMap = this.getMapById(className, tableName, id);
        //获取所有的value
        Set<String> keys = resMap.keySet();
        //拼接新增语句
        StringBuffer insertBuffer = new StringBuffer();
        insertBuffer.append("insert into ");
        insertBuffer.append(tableName);
        insertBuffer.append(" values (");
        for (String key : keys
        ) {
            if (resMap.get(key) != null) {
                insertBuffer.append("'");
                insertBuffer.append(resMap.get(key));
                insertBuffer.append("'");
                insertBuffer.append(",");
            } else {
                insertBuffer.append(resMap.get(key));
                insertBuffer.append(",");
            }

        }
        insertBuffer.deleteCharAt(insertBuffer.length() - 1);
        insertBuffer.append(")");
        //结果添加到列表
        resList.add(deleteBuffer.toString());
        resList.add(insertBuffer.toString());
        log.info(resMap.toString());
        return resList;
    }

    /**
     *  删除语句生成新增的具体执行方法
     */
    @Override
    public String deleteGenerateInsert(Invocation invocation, String sql, String className) {
        List<String> paramTerList = this.getParamTerList(invocation);
        //修改的语句最后的参数为id,默认最后的条件为id.表名则为单词的第三个单词,由此获得id与表名
        String[] words = sql.split(" ");
        String tableName = words[2];
        String id = paramTerList.get(paramTerList.size()-1);
        LinkedHashMap<String, Object> resMap = this.getMapById(className, tableName, id);
        //获取所有的value
        Set<String> keys = resMap.keySet();
        //拼接新增语句
        StringBuffer insertBuffer = new StringBuffer();
        insertBuffer.append("insert into ");
        insertBuffer.append(tableName);
        insertBuffer.append(" values (");
        for (String key:keys
        ) {
            if(resMap.get(key)!=null){
                insertBuffer.append("'");
                insertBuffer.append(resMap.get(key));
                insertBuffer.append("'");
                insertBuffer.append(",");
            }else{
                insertBuffer.append(resMap.get(key));
                insertBuffer.append(",");
            }
        }
        insertBuffer.deleteCharAt(insertBuffer.length()-1);
        insertBuffer.append(")");
        insertBuffer.append(" on CONFLICT(id) do NOTHING ");
        return insertBuffer.toString();
    }

    /**
     * 获取当前执行器是哪个执行器
     * @return 执行器名称
     */
    @Override
    public String getDbVersion() {
        return DB_VERSION;
    }

    /**
     * 通过反射获取接口并执行继承的方法
     */
    private LinkedHashMap<String, Object> getMapById(String className, String tableName, String id) {
        Class<? extends ReverseMapper> serviceClass;
        try {
            serviceClass = (Class<? extends ReverseMapper>) Class.forName(className);
        } catch (Exception e) {
            throw new RuntimeException("如使用**注解,请继承ReverseMapper接口");
        }
        //生成查询语句
        StringBuffer selectBuffer = new StringBuffer();
        selectBuffer.append("select * from ");
        selectBuffer.append(tableName);
        selectBuffer.append(" where id = ");
        selectBuffer.append(id);
        //反射调用规定好的方法
        return super.getMapper(serviceClass).performSql(selectBuffer.toString());
    }


    /**
     * 获取该语句的参数列表
     */
    private List<String> getParamTerList(Invocation invocation) {
        Object target = PluginUtils.realTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(target);
        // 参数
        BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
        Object parameterObject = boundSql.getParameterObject();
        List<ParameterMapping> parameterMappings = new ArrayList<>(boundSql.getParameterMappings());
        if (parameterMappings.isEmpty() && parameterObject == null) {
            log.warn("parameterMappings is empty or parameterObject is null");
        }
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        Configuration configuration = mappedStatement.getConfiguration();
        TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
        List<String> parameterList = new ArrayList<>();
        MetaObject newMetaObject = configuration.newMetaObject(parameterObject);
        for (ParameterMapping parameterMapping : parameterMappings) {
            String parameter = null;
            if (parameterMapping.getMode() == ParameterMode.OUT) {
                continue;
            }
            String propertyName = parameterMapping.getProperty();
            if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
                parameter = getParameterValue(parameterObject);
            } else if (newMetaObject.hasGetter(propertyName)) {
                parameter = getParameterValue(newMetaObject.getValue(propertyName));
            } else if (boundSql.hasAdditionalParameter(propertyName)) {
                parameter = getParameterValue(boundSql.getAdditionalParameter(propertyName));
            }
            parameterList.add(parameter);
        }
        return parameterList;
    }


    /**
     * 获取参数
     *
     * @param param Object类型参数
     * @return 转换之后的参数
     */
    private static String getParameterValue(Object param) {
        if (param == null) {
            return "null";
        }
        if (param instanceof Number) {
            return param.toString();
        }
        String value = param.toString();
        return StringUtils.quotaMark(value);
    }


    /**
     * 获取此方法名的具体 Method
     *
     * @param clazz      class 对象
     * @param methodName 方法名
     * @return 方法
     */
    private Method getMethodRegular(Class<?> clazz, String methodName) {
        if (Object.class.equals(clazz)) {
            return null;
        }
        for (Method method : clazz.getDeclaredMethods()) {
            if (method.getName().equals(methodName)) {
                return method;
            }
        }
        return getMethodRegular(clazz.getSuperclass(), methodName);
    }

    /**
     * 获取sql语句开头部分
     *
     * @param sql ignore
     * @return ignore
     */
    private int indexOfSqlStart(String sql) {
        String upperCaseSql = sql.toUpperCase();
        Set<Integer> set = new HashSet<>();
        set.add(upperCaseSql.indexOf("SELECT "));
        set.add(upperCaseSql.indexOf("UPDATE "));
        set.add(upperCaseSql.indexOf("INSERT "));
        set.add(upperCaseSql.indexOf("DELETE "));
        set.remove(-1);
        if (CollectionUtils.isEmpty(set)) {
            return -1;
        }
        List<Integer> list = new ArrayList<>(set);
        list.sort(Comparator.naturalOrder());
        return list.get(0);
    }


}
  1. 这里在获取mapper接口中个各个方法的时候,为了防止具体使用者没有继承ReverseMapper接口,写了一个AbstractBusiness类抽象类。代码如下
public abstract class AbstractBusiness {

    @Resource
    BeanHelper helper;
    public void setMapper(BeanHelper helper) {
        this.helper = helper;
    }

    public BeanHelper getMapper() {
        return helper;
    }

    public final <T extends ReverseMapper> T getMapper(Class<T> t){
        return helper.getBean(t);
    }

}
  1. BeanHelper代码
@Component
public class BeanHelper implements ApplicationContextAware {
    public static ApplicationContext applicationContext;

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        BeanHelper.applicationContext = applicationContext;
    }

    public <T extends ReverseMapper> T getBean(Class<T> t) {
        return applicationContext.getBean(t);
    }

    public Object getSpringBean(String s) {
        return  applicationContext.getBean(s);
    }
}
  1. 其他帮助类
@Slf4j
public class RedisMethod {

    private static Class<?> redisClass;

    public static Object redisEntity;

    static {
        try {
            redisClass = Class.forName("com.component.redis.Redis");
            if(redisClass == null){
                throw new RuntimeException("获取不到redis工具类,请引入包后重试");
            }
            redisEntity = redisClass.newInstance();
        } catch (IllegalAccessException | InstantiationException | ClassNotFoundException e) {
            log.error("获取Redis工具类失败");
            e.printStackTrace();
        }
    }

    /**
     * 获取redis工具类右添加obj的方法
     * @return 方法
     */
    @SneakyThrows
    public static Method getLrSetObj(){
        Class[] rightPushArguments = new Class[2];
        rightPushArguments[0] = String.class;
        rightPushArguments[1] = Object.class;
        return redisClass.getMethod("lrSetObj",rightPushArguments);
    }

    /**
     * 获取又添加列表的方法
     * @return 方法
     */
    @SneakyThrows
    public static Method getLrSetList(){
        Class[] rightPushListArguments = new Class[2];
        rightPushListArguments[0] = String.class;
        rightPushListArguments[1] = List.class;
        return redisClass.getMethod("lrSetList",rightPushListArguments);
    }
}

@Slf4j
public class ThreadUtil {

    private ThreadUtil(){ }


    private static final ThreadLocal threadLocal = new ThreadLocal<>();

    public static void set(Object o) {
        threadLocal.set(o);
    }

    public static Object get() {
        return threadLocal.get();
    }

    public static void remove(){
        threadLocal.remove();
    }

}

三:代码执行

public class TestController {
  
   @Resource
   private TestMapper testMapper;
  
   public void test(){
     //此处放入线程的为任务名称,根据业务自行调整,多线程时也需要放入线程中该任务名称
     ThreadUtil.set("test");
     testMapper.insert();
     ThreadUtil.remove();
   }
}
  1. 最后线程执行结束的时候,不要忘了调用ThreadUtil.remove()方法删除,删除线程中数据。
  2. 思路就是这样,代码还有优化的空间。可自行修改。

四:注意事项

  1. 当前代码只支持一条条插入,一条条的删除,一条一条的修改,而且只能是最基本的增删改,要使用此功能的话需要编写对应接口的标准的增删改语句,标准语句的格式在代码中有所描述。
  2. 数据在更新之前都会进行一次查询,速度会响应的减慢。
  3. 生成的反向sql是存储在redis中的,可以把redis中的key存储在数据库,(反向的sql已经有了,怎么用,怎么执行看自己的业务),redis中的数据可以在下一次导入之前进行删除。

我:喂,总监啊,搞定了啊!

总监:不错,小伙子。