最近经常需要做 Oracle 数据库或 MySQL 数据库的大量数据导入工作,所以利用 Python 编写了一个快速导入的工具。话不多说,直接上代码。

1、配置文件

[sync]
#original:1-ORACLE|2-MYSQL|3-SQL|4-CSV  (数据来源)
original = 4
#destination:1-ORACLE|2-MYSQL          (目标导入数据库类型)
destination = 2
#批量导入每次的条数
rowevery = 3000
#表名列表用逗号分隔;数据来源为SQL或CSV文件时同样只填表名,程序会自动读取当前目录下的 表名.sql 或 表名.csv 文件
tables = TABLE1,TABLE2
[mysql]
host=127.0.0.1
port=8911
database=data_test
user=adm
password=adm
[oracle]
host=192.168.0.21
port=1521
database=data_test
user=adm
password=adm

2、主程序

#coding=utf-8
import os
import sys
#import pdb 
#添加模块路径
#sys.path.append(sys.path[0] + '/inc')
sys.path.append('../inc')
import pymysql
import cx_Oracle
import ConfigParser
import time
import datetime
import re
import signal
import threading

# Shared run flag: the worker threads keep going while this is True and
# set it to False themselves when they hit an error.
g_flagRun = True
# Worker threads created by start().
g_arrThread = []
# Hand-off list between the reader thread (appends) and the writer thread
# (removes from the front).
g_lstImport = []
# Not referenced by any live code visible in this file.
g_lstLoadTables = []
# Configuration file, resolved relative to the current working directory.
con_ini = 'cfg_SyncDatabase.ini'
# Parsed once at import time; all functions read their settings from it.
parser = ConfigParser.ConfigParser()
parser.read(con_ini)

# Handle the process-exit signals.
def sigCHandler(signum, frame):
    """Ask the worker threads to stop, wait for them, then exit the process.

    Installed as the handler for the termination signals.

    Args:
        signum: Number of the delivered signal (unused).
        frame: Interrupted stack frame (unused).

    Raises:
        SystemExit: always, with status 0, after every thread in
            g_arrThread has been joined.
    """
    print("Input exit signal!")
    # The workers poll the module-level flag named g_flagRun. It must be
    # cleared under exactly that name, otherwise they never see the stop
    # request and the join() below waits for the whole import to finish.
    global g_flagRun
    g_flagRun = False

    for thr in g_arrThread:
        thr.join()     # block until the worker has terminated

    print(datetime.datetime.now().strftime('%Y%m%d%H%M%S:') + '应用程序退出;')
    sys.exit(0)

# Thread management.
def thdJoin(threads):
    """Block until every thread in *threads* has finished.

    Polls once per second instead of calling join(), so the main thread
    stays interruptible by Ctrl+C / signals while it waits.

    Args:
        threads: Iterable of threading.Thread objects.
    """
    for thr in threads:
        # is_alive() exists on Python 2.6+ and 3.x; the camelCase alias
        # isAlive() was removed in Python 3.9.
        while thr.is_alive():
            time.sleep(1)

# Connect to the MySQL database.
def dbconnectMySql(param):
    """Open a MySQL connection from a mapping of connection settings.

    Args:
        param: Mutable mapping of keyword arguments for pymysql.Connect.
            Its 'port' entry is converted to int in place.

    Returns:
        The connection object, or None when connecting failed (the
        exception is printed, not raised).
    """
    # Config values arrive as strings; the port is stored back as an int.
    param['port'] = int(param['port'])
    try:
        return pymysql.Connect(**dict(param))
    except Exception as err:
        print(err)
        return None
# Connect to the Oracle database.
def dbconnectOracle():
    """Open an Oracle connection using the [oracle] section of the config.

    The connect string has the form user/password@host[:port]/database.
    The port is appended only when the section defines a non-empty 'port',
    so configurations without one behave as before.

    Returns:
        The connection object, or None when the settings are incomplete
        or connecting failed (the exception is printed, not raised).
    """
    conn = None
    try:
        # Read the section once instead of re-parsing it for every field.
        cfg = dict(parser.items('oracle'))
        strHost = cfg['host']
        strPort = cfg.get('port', '').strip()
        if strPort:
            strHost = strHost + ':' + strPort
        dsn = cfg['user'] + '/' + cfg['password'] + '@' + strHost + '/' + cfg['database']
        conn = cx_Oracle.connect(dsn)
    except Exception as e:
        print(e)

    return conn

# Classify a value by its Python type.
def getVariateType(tVariate):
    """Return a label naming the type of *tVariate*.

    The candidate types are tested in a fixed order and the first match
    wins, so a bool is reported as "INT".

    Args:
        tVariate: Any value.

    Returns:
        One of "INT", "STRING", "FLOAT", "LIST", "TUPLE", "DICT", "SET",
        "DATETIME", or the string 'None' when no candidate matches.
    """
    lstCandidates = (
        (int, "INT"),
        (str, "STRING"),
        (float, "FLOAT"),
        (list, "LIST"),
        (tuple, "TUPLE"),
        (dict, "DICT"),
        (set, "SET"),
        (datetime.datetime, "DATETIME"),
    )
    for clsCandidate, strLabel in lstCandidates:
        if isinstance(tVariate, clsCandidate):
            return strLabel
    return 'None'

# Record describing one table import.
class ST_IMPORT:
    """Plain holder for the state of one table import; all fields start empty."""

    def __init__(self):
        # Table name.
        self.name = ''
        # SQL template string.
        self.sqlTemplet = ''
        # Arguments for the template; a fresh list per instance.
        self.sqlArgs = list()
        # Number of rows imported.
        self.count = 0
# Run the program.
def start():
    """Start the reader and writer threads and wait until both have ended.

    Registers thdReadDeal and thdWriteDeal in g_arrThread, starts them as
    daemon threads, installs sigCHandler for the termination signals and
    then blocks in thdJoin().
    """
    print(datetime.datetime.now().strftime('%Y%m%d%H%M%S:') + '应用程序启动;')

    g_arrThread.append(threading.Thread(target=thdReadDeal))
    g_arrThread.append(threading.Thread(target=thdWriteDeal))

    for thr in g_arrThread:
        # Daemon threads are reclaimed when the main thread exits. The flag
        # must be set before start(); the attribute form replaces the
        # deprecated setDaemon().
        thr.daemon = True
        thr.start()

    # Exit through the signal handler: Ctrl-C and `kill <pid>` everywhere,
    # plus terminal hang-up where the platform defines SIGHUP (the signal
    # module does not expose it on every platform).
    lstSignals = [signal.SIGINT, signal.SIGTERM]
    if hasattr(signal, 'SIGHUP'):
        lstSignals.append(signal.SIGHUP)
    for sig in lstSignals:
        signal.signal(sig, sigCHandler)

    # Polling wait instead of join(), so the handler above can still run.
    thdJoin(g_arrThread)
    
# Reader thread: produce the import statements.
def _rdStamp():
    """Return the timestamp prefix used by every log line of the reader."""
    return datetime.datetime.now().strftime('%Y%m%d%H%M%S:')


def _rdBuildInsertTemplet(tableName, fieldNames):
    """Build "insert into T(a,b) values (%s,%s)" for the given field names."""
    # NOTE(review): '%s' is the pymysql placeholder style; confirm the
    # template is usable when the destination is Oracle.
    strNames = '(' + ','.join(fieldNames) + ')'
    strMarks = '(' + ','.join(['%s'] * len(fieldNames)) + ')'
    return "insert into " + tableName + strNames + " values " + strMarks


def _rdWaitQueue(nLimit):
    """Sleep until g_lstImport holds at most nLimit items or the run flag clears."""
    while len(g_lstImport) > nLimit and g_flagRun:
        time.sleep(1)


def _rdFromDatabase(dbOrig, lstTablesName, nRowEvery):
    """Queue every row of each table from dbOrig, then close the connection."""
    global g_flagRun
    try:
        for tTablesName in lstTablesName:
            if not g_flagRun:
                break
            _rdWaitQueue(0)     # let the writer finish the previous table
            csOrig = None
            try:
                csOrig = dbOrig.cursor()
                strSqlOrig = 'select * from ' + tTablesName
                print(_rdStamp() + '查询原表数据(' + strSqlOrig + ');')
                if csOrig.execute(strSqlOrig) is None:
                    break
                g_lstImport.append(tTablesName)     # 1st item: table name
                strSqlTemplet = ''
                lstSqlArgs = []
                while True:
                    row = csOrig.fetchone()
                    if row is None or not g_flagRun:
                        if lstSqlArgs and g_flagRun:
                            g_lstImport.append(lstSqlArgs)
                        g_lstImport.append('END')   # last item of this table
                        break
                    if strSqlTemplet == '':
                        # The column list is only known once a row exists.
                        strSqlTemplet = _rdBuildInsertTemplet(
                            tTablesName, [desc[0] for desc in csOrig.description])
                        g_lstImport.append(strSqlTemplet)   # 2nd item: template
                    lstSqlArgs.append(row)
                    if len(lstSqlArgs) >= nRowEvery:
                        g_lstImport.append(lstSqlArgs)
                        lstSqlArgs = []
                    _rdWaitQueue(10)    # back-pressure: stay close to the writer
            except Exception as e:
                print(_rdStamp() + str(e) + ';[-1]')
                g_flagRun = False
            finally:
                # Release the cursor on every exit path, including errors.
                if csOrig is not None:
                    try:
                        csOrig.close()
                    except Exception as e:
                        print(_rdStamp() + str(e) + ';[-1]')
    finally:
        dbOrig.close()


def _rdFromSqlFiles(lstTablesName, nRowEvery):
    """Queue the non-blank lines of <table>.sql for each table as raw statements."""
    global g_flagRun
    for tTablesName in lstTablesName:
        if not g_flagRun:
            break
        _rdWaitQueue(0)
        try:
            # Text mode keeps the lines as str on both Python 2 and 3.
            hSqlFile = open(tTablesName + '.sql', 'r')
        except (IOError, OSError) as e:
            # Nothing was queued for this table, so no 'END' is owed.
            print(_rdStamp() + str(e) + ';[-1]')
            g_flagRun = False
            break
        g_lstImport.append(tTablesName)     # 1st item: table name
        g_lstImport.append('SQL-FILE')      # 2nd item: marks batches as raw SQL
        lstSqlArgs = []
        try:
            for tLine in hSqlFile:
                if not g_flagRun:
                    break
                if tLine.strip() == '':
                    continue
                lstSqlArgs.append(tLine)
                if len(lstSqlArgs) >= nRowEvery:
                    g_lstImport.append(lstSqlArgs)
                    lstSqlArgs = []
                _rdWaitQueue(10)
        except Exception as e:
            print(_rdStamp() + str(e) + ';[-1]')
            g_flagRun = False
        finally:
            hSqlFile.close()
            if lstSqlArgs and g_flagRun:
                g_lstImport.append(lstSqlArgs)
            g_lstImport.append('END')       # exactly one terminator per table


def _rdFromCsvFiles(lstTablesName, nRowEvery):
    """Queue the rows of <table>.csv for each table; the first line names the fields."""
    global g_flagRun
    for tTablesName in lstTablesName:
        if not g_flagRun:
            break
        _rdWaitQueue(0)
        try:
            # Text mode keeps the lines as str on both Python 2 and 3.
            hCsvFile = open(tTablesName + '.csv', 'r')
        except (IOError, OSError) as e:
            print(_rdStamp() + str(e) + ';[-1]')
            g_flagRun = False
            break
        g_lstImport.append(tTablesName)     # 1st item: table name
        strSqlTemplet = ''
        lstSqlArgs = []
        try:
            for tLine in hCsvFile:
                if not g_flagRun:
                    break
                if tLine.strip() == '':
                    # Skip blank lines (e.g. a trailing newline) instead of
                    # failing on them and dropping the pending rows.
                    continue
                # NOTE(security): every field is passed to eval(), so the
                # file content is executed as Python. Only import files from
                # a trusted source. Fields must be Python literals (quoted
                # strings, numbers) and must not contain commas.
                lstFields = [eval(t) for t in tLine.split(',')]
                if strSqlTemplet == '':
                    # Header line: field names become the column list.
                    strSqlTemplet = _rdBuildInsertTemplet(tTablesName, lstFields)
                    g_lstImport.append(strSqlTemplet)   # 2nd item: template
                    continue
                lstSqlArgs.append(tuple(lstFields))
                if len(lstSqlArgs) >= nRowEvery:
                    g_lstImport.append(lstSqlArgs)
                    lstSqlArgs = []
                _rdWaitQueue(10)
        except Exception as e:
            print(_rdStamp() + str(e) + ';[-1]')
            g_flagRun = False
        finally:
            hCsvFile.close()
            if lstSqlArgs and g_flagRun:
                g_lstImport.append(lstSqlArgs)
            g_lstImport.append('END')       # exactly one terminator per table


def thdReadDeal():
    """Reader thread body: read the configured source and fill g_lstImport.

    Settings come from the [sync] section: 'original' selects the source
    (1 Oracle, 2 MySQL, 3 <table>.sql files, 4 <table>.csv files), 'tables'
    is the comma-separated table list and 'rowevery' the batch size.

    For each table the queue receives, in order: the table name, the insert
    template (or the marker 'SQL-FILE'), zero or more lists of at most
    'rowevery' rows/statements, and 'END'. 'EXIT' is appended once at the
    very end. On any error the message is printed and g_flagRun is cleared.
    An unknown source clears g_flagRun and returns without queuing 'EXIT'.
    """
    global g_flagRun
    print(_rdStamp() + '导入语句生成线程启动;')
    dictSync = dict(parser.items('sync'))
    iSyncOriginal = int(dictSync['original'])

    dbOrig = None
    if iSyncOriginal == 1:
        dbOrig = dbconnectOracle()
    elif iSyncOriginal == 2:
        dbOrig = dbconnectMySql(dict(parser.items('mysql')))
    elif iSyncOriginal not in (3, 4):
        print(_rdStamp() + '未定义的数据来源,导入语句生成线程结束;')
        g_flagRun = False
        return

    lstTablesName = dictSync['tables'].split(",")
    nRowEvery = int(dictSync['rowevery'])

    if iSyncOriginal in (1, 2):
        # A failed connection was already reported by the connect helper.
        if dbOrig is not None:
            _rdFromDatabase(dbOrig, lstTablesName, nRowEvery)
    elif iSyncOriginal == 3:
        _rdFromSqlFiles(lstTablesName, nRowEvery)
    else:
        _rdFromCsvFiles(lstTablesName, nRowEvery)

    g_lstImport.append('EXIT')
    print(_rdStamp() + '导入语句生成线程结束;')

# Writer thread: execute the import statements.
def thdWriteDeal():
    """Writer thread body: take items from g_lstImport and run them.

    The destination comes from [sync] 'destination' (1 Oracle, 2 MySQL).
    Items are handled by kind:

    * a string found in the configured table list starts a new table;
    * 'END' reports the totals for the current table;
    * 'EXIT' ends the thread;
    * the first other string after a table name is the SQL template, or
      the marker 'SQL-FILE';
    * a list is one batch: raw statements executed one by one for
      'SQL-FILE', otherwise rows passed to executemany() with the
      template. Each batch is committed on its own.

    Failures are detected through driver exceptions only. The return
    values of execute()/executemany() are not inspected, because cx_Oracle
    returns None from both on a successful insert. On any error the
    message is printed and g_flagRun is cleared. The cursor and the
    connection are closed before the thread ends.
    """
    global g_flagRun
    print(datetime.datetime.now().strftime('%Y%m%d%H%M%S:') + '导入语句执行线程启动;')
    dictSync = dict(parser.items('sync'))
    iSyncDestination = int(dictSync['destination'])
    lstTablesName = dictSync['tables'].split(",")

    if iSyncDestination == 1:
        dbDest = dbconnectOracle()
    elif iSyncDestination == 2:
        dbDest = dbconnectMySql(dict(parser.items('mysql')))
    else:
        print(datetime.datetime.now().strftime('%Y%m%d%H%M%S:') + '未定义的数据目的;')
        g_flagRun = False
        return

    tablesName = ''     # table currently being imported
    sqlTemplet = ''     # insert template, or 'SQL-FILE'
    tmBeg = None        # start time of the current table
    nTimes = 0          # batches executed for the current table
    nRowTotal = 0       # rows executed for the current table
    csDest = None
    try:
        if dbDest is None:
            # The connect helper already printed the driver error.
            raise RuntimeError('destination database connection failed')
        csDest = dbDest.cursor()
        while True:
            if not g_lstImport:
                if not g_flagRun:   # queue drained and a stop was requested
                    break
                time.sleep(1)
                continue
            tImport = g_lstImport.pop(0)

            if isinstance(tImport, str):
                if tImport in lstTablesName:
                    tablesName = tImport
                    sqlTemplet = ''
                    tmBeg = datetime.datetime.now()
                    nTimes = 0
                    nRowTotal = 0
                    continue
                if tImport == "END":
                    if tmBeg is not None:
                        # total_seconds() also counts whole days; .seconds does not.
                        nElapsed = int((datetime.datetime.now() - tmBeg).total_seconds())
                        print(datetime.datetime.now().strftime('%Y%m%d%H%M%S:<<')+tablesName+'>导入完成,共导入'+str(nRowTotal)+'条,总用时'+str(nElapsed)+'秒;')
                    else:
                        print(datetime.datetime.now().strftime('%Y%m%d%H%M%S:<<')+tablesName+'>>无有效的导入内容;')
                    tablesName = ''
                    nTimes = 0
                    nRowTotal = 0
                    tmBeg = None
                    continue
                if tImport == "EXIT":
                    break
                if tablesName != '' and nTimes == 0:
                    sqlTemplet = tImport
                    continue
            if not isinstance(tImport, list):
                # Any other string, or a non-list item, breaks the queue
                # protocol; it must not reach the database as a batch.
                print(datetime.datetime.now().strftime('%Y%m%d%H%M%S:<<')+tablesName+'>>导入列表异常;')
                g_flagRun = False
                break

            if sqlTemplet == 'SQL-FILE':
                for tSql in tImport:
                    csDest.execute(tSql)
            else:
                csDest.executemany(sqlTemplet, tImport)
            dbDest.commit()
            nRowCount = len(tImport)
            nRowTotal = nRowTotal + nRowCount
            nTimes = nTimes + 1
            print(datetime.datetime.now().strftime('%Y%m%d%H%M%S:')+"<<%s>>第%d次导入,本次导入%d条,共导入%d条;"%(tablesName,nTimes,nRowCount,nRowTotal) )

    except Exception as e:
        print(datetime.datetime.now().strftime('%Y%m%d%H%M%S:')+str(e)+';[-1]')
        g_flagRun = False
    finally:
        # Release the cursor first, then the connection; a failing close
        # is reported but must not hide the outcome above.
        for hRes in (csDest, dbDest):
            if hRes is None:
                continue
            try:
                hRes.close()
            except Exception as e:
                print(datetime.datetime.now().strftime('%Y%m%d%H%M%S:')+str(e)+';[-1]')
    print(datetime.datetime.now().strftime('%Y%m%d%H%M%S:') + '导入语句执行线程结束;')
# Run the importer only when this file is executed as a script.
if __name__ == '__main__':
    start()