总监:喂,小王啊!起来没呢?加个班呗!
我:泥煤啊…
总监:我有个需求啊,这最近导入数据比较多,但是后台用户反映导入了数据,不想要了,删除起来麻烦啊!你也知道,顾客是上帝嘛,给我完成一个导入数据自动一键回滚的功能!
我:说啥也不干,今天休息,我还要打游戏。
总监:那个你申请一个在家办公,两倍工资,我这面批一个。
我:好嘞!
一:需求分析
- 导入一定分为很多种,有商品的,有图片的,有各种业务的,一定要兼容各种具体的业务,那么就不能依赖于具体实现。
- 分析在各个业务层,导入无谓就是处理完数据之后生成的增删改语句。
- 那我只需要处理sql语句就可以了,把增删改的语句生成它具体的相反的语句。insert生成delete语句,udapte生成delete和insert语句,delete生成insert语句。
- 那么多的mapper层接口的语句,怎么知道哪个语句是需要生成相反的语句呢?可以自定义一个注解,然后我们在执行之前看看该接口上面有没有这个注解就行了。
- 那在并行多次导入的时候怎么区分哪些任务是属于同一任务的呢?这么办,运用线程标识该次任务。那开启多线程怎么办呢?开启多线程就把每一个线程都存入任务名称。
- 好啦,差不多思路就是这些,总结一下就是在sql执行之前,拦截要执行的sql。判定该要执行的sql的mapper层接口是否有自定义约定的注解,如果有,那么该语句是需要生成相反sql的。判定该线程中是否存储有任务名称,有则生成相反语句并存储到redis中,该任务名称为redis中的key。value采用list结构,我们从左向右添加,要是执行的话,也是从左边进行执行。
二:代码编写
- 首先自定义一个注解,EnableReverseSql
@Target({ElementType.METHOD, ElementType.PARAMETER})
@Retention(RetentionPolicy.RUNTIME)
public @interface EnableReverseSql {
}
- 定义一个生成反向sql的顶级接口,以后用于适配不同的数据库
public interface ReverseSqlDb {
String getSql(Statement statement);
String insertGenerateDelete(Invocation invocation,String sql);
List<String> updateGenerateDeleteAndInsert(Invocation invocation,String sql,String className);
String deleteGenerateInsert(Invocation invocation,String sql,String className);
String getDbVersion();
}
- 因为要生成反向sql,比如删除,只会知道id,那么我们必须要查询该id的所有信息,才能生成insert语句,这里定义一个mapper接口,用于执行在代码中生成的查询语句。在有@ReverseSqlDb注解的接口层必须继承该类。
public interface ReverseMapper {
@Select("${sql}")
@InterceptorIgnore(tenantLine = "true")
LinkedHashMap<String,Object> performSql(@Param("sql") String sql);
}
例:
@Mapper
public interface TestMapper extends ReverseMapper {
}
这样我们就可以执行代码中任意生成的sql语句。
- 核心思想,要拦截sql,就需要实现Interceptor接口,用于拦截需要执行的语句,新建ReverseSqlInterceptor
@Slf4j
@Component
@Intercepts({
@Signature(type = StatementHandler.class, method = "update", args = Statement.class),
@Signature(type = StatementHandler.class, method = "batch", args = Statement.class)
})
public class ReverseSqlInterceptor implements Interceptor {
@Resource
ReverseSqlDbChainOfResponsibility reverseSqlDbChainOfResponsibility;
@Autowired
SpringConfigProperties springProperties;
/**
* 获取当前在使用的数据库
* @return 数据库名称
*/
private String getDbVersion(){
String druidUrl;
//这里只是为了适配不同数据源进行的判断,最终只是要过去正在使用的是什么数据库
if( springProperties.getDataSource().getDruid() == null){
druidUrl = springProperties.getDataSource().getUrl();
}else{
druidUrl = springProperties.getDataSource().getDruid().getUrl();
}
return druidUrl.split(":")[1];
}
@Override
public Object intercept(Invocation invocation) throws Throwable {
//获取Statement类对象
Statement statement = this.getStatement(invocation);
Object target = PluginUtils.realTarget(invocation.getTarget());
MetaObject metaObject = SystemMetaObject.forObject(target);
MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
//获取命名空间
String namespace = mappedStatement.getId();
//获取类名
String className = namespace.substring(0, namespace.lastIndexOf("."));
//获取当前类的方法名
String methodName = namespace.substring(namespace.lastIndexOf(".") + 1);
//获取当前类有哪些方法
Method[] ms = Class.forName(className).getMethods();
for (Method m : ms) {
if (m.getName().equals(methodName)) {
//判断是否有这个注解
Annotation annotation = m.getAnnotation(EnableReverseSql.class);
if (annotation != null) {
//通过反射redis类实例对象并获取
Method getMethod = InternalThreadLocal.getMethodForNameGet();
HashMap<String, Object> stringObjectHashMap = (HashMap<String, Object>) getMethod.invoke(InternalThreadLocal.treadUtilEntity);
String taskName = (String) stringObjectHashMap.get("name");
if(taskName==null){
throw new RuntimeException("未获取到线程中的任务名称,请添加任务名称");
}
reverseSqlDbChainOfResponsibility.selectDbAndExecuteChain(getDbVersion(),invocation,statement,className,taskName);
} else {
return invocation.proceed();
}
}
}
return invocation.proceed();
}
/**
* ThreadLocal内部静态类
*/
private static class InternalThreadLocal{
private static Class<?> treadUtilClass;
private static Constructor<?> declaredConstructor;
private static Object treadUtilEntity;
private static Class[] treadUtilArguments;
/**
* 加载ThreadUtil工具类
*/
static {
try {
treadUtilClass = Class.forName("com.common.util.ThreadUtil");
declaredConstructor = treadUtilClass.getDeclaredConstructor();
//强制使用私有的构造方法
declaredConstructor.setAccessible(true);
treadUtilEntity = declaredConstructor.newInstance();
treadUtilArguments = new Class[0];
} catch (InstantiationException | IllegalAccessException | ClassNotFoundException | NoSuchMethodException | InvocationTargetException e) {
log.error("获取ThreadUtil工具类失败");
e.printStackTrace();
}
}
private static Method getMethodForNameGet() throws NoSuchMethodException {
return treadUtilClass.getMethod("get",treadUtilArguments);
}
}
/**
* 获取statement
*/
private Statement getStatement(Invocation invocation) {
Statement statement;
Object firstArg = invocation.getArgs()[0];
if (Proxy.isProxyClass(firstArg.getClass())) {
statement = (Statement) SystemMetaObject.forObject(firstArg).getValue("h.statement");
} else {
statement = (Statement) firstArg;
}
MetaObject stmtMetaObj = SystemMetaObject.forObject(statement);
try {
statement = (Statement) stmtMetaObj.getValue("stmt.statement");
} catch (Exception e) {
//这个位置不需要捕获异常,会报错
}
if (stmtMetaObj.hasGetter("delegate")) {
try {
statement = (Statement) stmtMetaObj.getValue("delegate");
} catch (Exception e) {
//这个位置不需要捕获异常,会报错
}
}
if(statement != null){
return statement;
}else{
throw new RuntimeException("未获取到Statement类");
}
}
}
- 为了以后适配更多的数据库,新建ReverseSqlDbChainOfResponsibility类,用于适配不同的数据库
@Component
public class ReverseSqlDbChainOfResponsibility implements CommandLineRunner, ApplicationContextAware {
private Collection<ReverseSqlDb> reverseSqlDbList;
private volatile ApplicationContext applicationContext;
@Override
public void run(String... args) throws Exception {
init();
}
private void init() {
reverseSqlDbList = new LinkedList<>(this.applicationContext.getBeansOfType(ReverseSqlDb.class).values());
}
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.applicationContext=applicationContext;
}
@SneakyThrows
void selectDbAndExecuteChain(String dbVersion,
Invocation invocation,
Statement statement,
String className,
String taskName){
//反射获取列表的又添加
Method lSetMethod = RedisMethod.getLrSetObj();
//获取列表的右添加列表
Method lSetListMethod = RedisMethod.getLrSetList();
for (ReverseSqlDb reverseSqlDb:reverseSqlDbList
) {
if(reverseSqlDb instanceof Proxy){
continue;
}
if(dbVersion.equals(reverseSqlDb.getDbVersion())){
//判断sql是要执行增删改中的哪一个方法
String sql = reverseSqlDb.getSql(statement);
if (sql.contains("INSERT") || sql.contains("insert")) {
//调用新增方法生成反向sql
String reverseSql = reverseSqlDb.insertGenerateDelete(invocation, sql);
lSetMethod.invoke(RedisMethod.redisEntity,taskName,reverseSql);
} else if (sql.contains("UPDATE") || sql.contains("update")) {
//调用修改方法生成反向sql
List<String> reverseSqlList = reverseSqlDb.updateGenerateDeleteAndInsert(invocation, sql, className);
lSetListMethod.invoke(RedisMethod.redisEntity,taskName,reverseSqlList);
} else if (sql.contains("DELETE") || sql.contains("delete")) {
//调用删除方法生成反向sql
String reverseSql = reverseSqlDb.deleteGenerateInsert(invocation, sql, className);
lSetMethod.invoke(RedisMethod.redisEntity,taskName,reverseSql);
}
}
}
}
}
- 编写具体的实现类ReverseSqlDbPg
@Slf4j
@Component
public class ReverseSqlDbPg extends AbstractBusiness implements ReverseSqlDb {
private static final String DRUID_POOLED_PREPARED_STATEMENT = "com.alibaba.druid.pool.DruidPooledPreparedStatement";
private static final String T4C_PREPARED_STATEMENT = "oracle.jdbc.driver.T4CPreparedStatement";
private static final String ORACLE_PREPARED_STATEMENT_WRAPPER = "oracle.jdbc.driver.OraclePreparedStatementWrapper";
private Method oracleGetOriginalSqlMethod;
private Method druidGetSqlMethod;
static final String DB_VERSION = "postgresql";
/**
* 获取当前正在执行的sql
*
* @param statement 声明
* @return 当前要执行的语句
*/
@Override
public String getSql(Statement statement) {
String originalSql = null;
String stmtClassName = statement.getClass().getName();
if (DRUID_POOLED_PREPARED_STATEMENT.equals(stmtClassName)) {
try {
if (druidGetSqlMethod == null) {
Class<?> clazz = Class.forName(DRUID_POOLED_PREPARED_STATEMENT);
druidGetSqlMethod = clazz.getMethod("getSql");
}
Object stmtSql = druidGetSqlMethod.invoke(statement);
if (stmtSql instanceof String) {
originalSql = (String) stmtSql;
}
} catch (Exception e) {
e.printStackTrace();
}
} else if (T4C_PREPARED_STATEMENT.equals(stmtClassName)
|| ORACLE_PREPARED_STATEMENT_WRAPPER.equals(stmtClassName)) {
try {
if (oracleGetOriginalSqlMethod != null) {
Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
if (stmtSql instanceof String) {
originalSql = (String) stmtSql;
}
} else {
Class<?> clazz = Class.forName(stmtClassName);
oracleGetOriginalSqlMethod = getMethodRegular(clazz, "getOriginalSql");
if (oracleGetOriginalSqlMethod != null) {
oracleGetOriginalSqlMethod.setAccessible(true);
if (null != oracleGetOriginalSqlMethod) {
Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
if (stmtSql instanceof String) {
originalSql = (String) stmtSql;
}
}
}
}
} catch (Exception e) {
//ignore
}
}
if (originalSql == null) {
originalSql = statement.toString();
}
originalSql = originalSql.replaceAll("[\\s]+", StringPool.SPACE);
int index = indexOfSqlStart(originalSql);
if (index > 0) {
originalSql = originalSql.substring(index);
}
return originalSql;
}
/**
* 新增生成删除
*
* @param invocation
* @param sql 要执行的sql
* @return 生成的反向sql
*/
@Override
public String insertGenerateDelete(Invocation invocation, String sql) {
List<String> paramTerList = this.getParamTerList(invocation);
//添加的参数列表第一位是id,我们就默认第一位是id,添加的话只需要反向生成删除的sql即可
//获取要删除的表名,添加语句的insert into 表名,所以这里取列表中的第三位
String[] words = sql.split(" ");
String tableName = words[2];
//拼接反向sql
StringBuilder stringBuilder = new StringBuilder();
stringBuilder.append("delete from ");
stringBuilder.append(tableName);
stringBuilder.append(" where id = ");
stringBuilder.append(paramTerList.get(0));
return stringBuilder.toString();
}
/**
* 修改方法生成删除和新增方法实现
*/
@Override
public List<String> updateGenerateDeleteAndInsert(Invocation invocation, String sql, String className) {
ArrayList<String> resList = new ArrayList<>();
List<String> paramTerList = this.getParamTerList(invocation);
//修改的语句最后的参数为id,默认最后的条件为id.表明则为单词的第二个单词,由此获得id与表名
String[] words = sql.split(" ");
String tableName = words[1];
String id = paramTerList.get(paramTerList.size() - 1);
//生成删除语句
StringBuffer deleteBuffer = new StringBuffer();
deleteBuffer.append("delete from ");
deleteBuffer.append(tableName);
deleteBuffer.append(" where id = ");
deleteBuffer.append(id);
//修改我们需要查询该id的所有数据,这里通过反射注入该接口,并通过继承的方式必须实现我们规定的接口,从而执行拼接的查询sql
LinkedHashMap<String, Object> resMap = this.getMapById(className, tableName, id);
//获取所有的value
Set<String> keys = resMap.keySet();
//拼接新增语句
StringBuffer insertBuffer = new StringBuffer();
insertBuffer.append("insert into ");
insertBuffer.append(tableName);
insertBuffer.append(" values (");
for (String key : keys
) {
if (resMap.get(key) != null) {
insertBuffer.append("'");
insertBuffer.append(resMap.get(key));
insertBuffer.append("'");
insertBuffer.append(",");
} else {
insertBuffer.append(resMap.get(key));
insertBuffer.append(",");
}
}
insertBuffer.deleteCharAt(insertBuffer.length() - 1);
insertBuffer.append(")");
//结果添加到列表
resList.add(deleteBuffer.toString());
resList.add(insertBuffer.toString());
log.info(resMap.toString());
return resList;
}
/**
* 删除语句生成新增的具体执行方法
*/
@Override
public String deleteGenerateInsert(Invocation invocation, String sql, String className) {
List<String> paramTerList = this.getParamTerList(invocation);
//修改的语句最后的参数为id,默认最后的条件为id.表名则为单词的第三个单词,由此获得id与表名
String[] words = sql.split(" ");
String tableName = words[2];
String id = paramTerList.get(paramTerList.size()-1);
LinkedHashMap<String, Object> resMap = this.getMapById(className, tableName, id);
//获取所有的value
Set<String> keys = resMap.keySet();
//拼接新增语句
StringBuffer insertBuffer = new StringBuffer();
insertBuffer.append("insert into ");
insertBuffer.append(tableName);
insertBuffer.append(" values (");
for (String key:keys
) {
if(resMap.get(key)!=null){
insertBuffer.append("'");
insertBuffer.append(resMap.get(key));
insertBuffer.append("'");
insertBuffer.append(",");
}else{
insertBuffer.append(resMap.get(key));
insertBuffer.append(",");
}
}
insertBuffer.deleteCharAt(insertBuffer.length()-1);
insertBuffer.append(")");
insertBuffer.append(" on CONFLICT(id) do NOTHING ");
return insertBuffer.toString();
}
/**
* 获取当前执行器是哪个执行器
* @return 执行器名称
*/
@Override
public String getDbVersion() {
return DB_VERSION;
}
/**
* 通过反射获取接口并执行继承的方法
*/
private LinkedHashMap<String, Object> getMapById(String className, String tableName, String id) {
Class<? extends ReverseMapper> serviceClass;
try {
serviceClass = (Class<? extends ReverseMapper>) Class.forName(className);
} catch (Exception e) {
throw new RuntimeException("如使用**注解,请继承ReverseMapper接口");
}
//生成查询语句
StringBuffer selectBuffer = new StringBuffer();
selectBuffer.append("select * from ");
selectBuffer.append(tableName);
selectBuffer.append(" where id = ");
selectBuffer.append(id);
//反射调用规定好的方法
return super.getMapper(serviceClass).performSql(selectBuffer.toString());
}
/**
* 获取该语句的参数列表
*/
private List<String> getParamTerList(Invocation invocation) {
Object target = PluginUtils.realTarget(invocation.getTarget());
MetaObject metaObject = SystemMetaObject.forObject(target);
// 参数
BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
Object parameterObject = boundSql.getParameterObject();
List<ParameterMapping> parameterMappings = new ArrayList<>(boundSql.getParameterMappings());
if (parameterMappings.isEmpty() && parameterObject == null) {
log.warn("parameterMappings is empty or parameterObject is null");
}
MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
Configuration configuration = mappedStatement.getConfiguration();
TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
List<String> parameterList = new ArrayList<>();
MetaObject newMetaObject = configuration.newMetaObject(parameterObject);
for (ParameterMapping parameterMapping : parameterMappings) {
String parameter = null;
if (parameterMapping.getMode() == ParameterMode.OUT) {
continue;
}
String propertyName = parameterMapping.getProperty();
if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
parameter = getParameterValue(parameterObject);
} else if (newMetaObject.hasGetter(propertyName)) {
parameter = getParameterValue(newMetaObject.getValue(propertyName));
} else if (boundSql.hasAdditionalParameter(propertyName)) {
parameter = getParameterValue(boundSql.getAdditionalParameter(propertyName));
}
parameterList.add(parameter);
}
return parameterList;
}
/**
* 获取参数
*
* @param param Object类型参数
* @return 转换之后的参数
*/
private static String getParameterValue(Object param) {
if (param == null) {
return "null";
}
if (param instanceof Number) {
return param.toString();
}
String value = param.toString();
return StringUtils.quotaMark(value);
}
/**
* 获取此方法名的具体 Method
*
* @param clazz class 对象
* @param methodName 方法名
* @return 方法
*/
private Method getMethodRegular(Class<?> clazz, String methodName) {
if (Object.class.equals(clazz)) {
return null;
}
for (Method method : clazz.getDeclaredMethods()) {
if (method.getName().equals(methodName)) {
return method;
}
}
return getMethodRegular(clazz.getSuperclass(), methodName);
}
/**
* 获取sql语句开头部分
*
* @param sql ignore
* @return ignore
*/
private int indexOfSqlStart(String sql) {
String upperCaseSql = sql.toUpperCase();
Set<Integer> set = new HashSet<>();
set.add(upperCaseSql.indexOf("SELECT "));
set.add(upperCaseSql.indexOf("UPDATE "));
set.add(upperCaseSql.indexOf("INSERT "));
set.add(upperCaseSql.indexOf("DELETE "));
set.remove(-1);
if (CollectionUtils.isEmpty(set)) {
return -1;
}
List<Integer> list = new ArrayList<>(set);
list.sort(Comparator.naturalOrder());
return list.get(0);
}
}
- 这里在获取mapper接口中个各个方法的时候,为了防止具体使用者没有继承ReverseMapper接口,写了一个AbstractBusiness类抽象类。代码如下
public abstract class AbstractBusiness {
@Resource
BeanHelper helper;
public void setMapper(BeanHelper helper) {
this.helper = helper;
}
public BeanHelper getMapper() {
return helper;
}
public final <T extends ReverseMapper> T getMapper(Class<T> t){
return helper.getBean(t);
}
}
- BeanHelper代码
@Component
public class BeanHelper implements ApplicationContextAware {
public static ApplicationContext applicationContext;
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
BeanHelper.applicationContext = applicationContext;
}
public <T extends ReverseMapper> T getBean(Class<T> t) {
return applicationContext.getBean(t);
}
public Object getSpringBean(String s) {
return applicationContext.getBean(s);
}
}
- 其他帮助类
@Slf4j
public class RedisMethod {
private static Class<?> redisClass;
public static Object redisEntity;
static {
try {
redisClass = Class.forName("com.component.redis.Redis");
if(redisClass == null){
throw new RuntimeException("获取不到redis工具类,请引入包后重试");
}
redisEntity = redisClass.newInstance();
} catch (IllegalAccessException | InstantiationException | ClassNotFoundException e) {
log.error("获取Redis工具类失败");
e.printStackTrace();
}
}
/**
* 获取redis工具类右添加obj的方法
* @return 方法
*/
@SneakyThrows
public static Method getLrSetObj(){
Class[] rightPushArguments = new Class[2];
rightPushArguments[0] = String.class;
rightPushArguments[1] = Object.class;
return redisClass.getMethod("lrSetObj",rightPushArguments);
}
/**
* 获取又添加列表的方法
* @return 方法
*/
@SneakyThrows
public static Method getLrSetList(){
Class[] rightPushListArguments = new Class[2];
rightPushListArguments[0] = String.class;
rightPushListArguments[1] = List.class;
return redisClass.getMethod("lrSetList",rightPushListArguments);
}
}
@Slf4j
public class ThreadUtil {
private ThreadUtil(){ }
private static final ThreadLocal threadLocal = new ThreadLocal<>();
public static void set(Object o) {
threadLocal.set(o);
}
public static Object get() {
return threadLocal.get();
}
public static void remove(){
threadLocal.remove();
}
}
三:代码执行
public class TestController {
@Resource
private TestMapper testMapper;
public void test(){
//此处放入线程的为任务名称,根据业务自行调整,多线程时也需要放入线程中该任务名称
ThreadUtil.set("test");
testMapper.insert();
ThreadUtil.remove();
}
}
- 最后线程执行结束的时候,不要忘了调用ThreadUtil.remove()方法删除,删除线程中数据。
- 思路就是这样,代码还有优化的空间。可自行修改。
四:注意事项
- 当前代码只支持一条条插入,一条条的删除,一条一条的修改,而且只能是最基本的增删改,要使用此功能的话需要编写对应接口的标准的增删改语句,标准语句的格式在代码中有所描述。
- 数据在更新之前都会进行一次查询,速度会响应的减慢。
- 生成的反向sql是存储在redis中的,可以把redis中的key存储在数据库,(反向的sql已经有了,怎么用,怎么执行看自己的业务),redis中的数据可以在下一次导入之前进行删除。
我:喂,总监啊,搞定了啊!
总监:不错,小伙子。