import pymysql


class DataDict(object):
    """Generate a Markdown data dictionary for tables of one MySQL database.

    Table and column metadata is read from ``information_schema`` and appended
    to ``<database>.md`` in the current working directory.
    """

    # Column metadata query. Rows are ordered by ORDINAL_POSITION so the
    # dictionary lists columns in the table's declared order (without an
    # ORDER BY, MySQL does not guarantee any row order).
    _COLUMNS_SQL = (
        "SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE, COLUMN_DEFAULT, "
        "COLUMN_KEY, COLUMN_COMMENT FROM information_schema.COLUMNS "
        "WHERE TABLE_SCHEMA = %s AND `TABLE_NAME` = %s "
        "ORDER BY ORDINAL_POSITION"
    )

    # Keys of a COLUMNS row, in the order they are rendered as table cells
    # (matches the Markdown header written by _write_table).
    _ROW_KEYS = ('COLUMN_NAME', 'COLUMN_TYPE', 'IS_NULLABLE',
                 'COLUMN_DEFAULT', 'COLUMN_KEY', 'COLUMN_COMMENT')

    def __init__(self, db, host='127.0.0.1', port=3306, user='root',
                 password='root'):
        """Store the connection settings for database *db*.

        :param db: name of the database (schema) to document; also the base
            name of the output file.
        :param host: MySQL server host.
        :param port: MySQL server port.
        :param user: login user name.
        :param password: login password. The root/root defaults are kept only
            for backward compatibility; pass real credentials explicitly.
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = db

    def run(self, tables):
        """Append the data dictionary for *tables* to ``<database>.md``.

        :param tables: ``'all'`` to document every table of the database, or a
            comma-separated list of table names such as ``'t1,t2,t3'``.
        :raises SystemExit: with status 1 if the database connection fails.
        """
        conn = self._connect()
        file_name = self.database + '.md'
        try:
            with conn.cursor(cursor=pymysql.cursors.DictCursor) as cursor:
                # Append mode: repeated runs accumulate sections in one file.
                with open(file_name, mode='a', encoding='UTF-8') as dict_file:
                    for table_name in self._resolve_tables(cursor, tables):
                        if self._write_table(cursor, dict_file, table_name):
                            print('[%s]表生成完毕' % table_name)
                        else:
                            print('[%s]表不存在' % table_name)
        finally:
            # Release the connection even if a query or file write fails.
            conn.close()
        print('[%s]字典已生成' % file_name)

    def _connect(self):
        """Open a connection to the configured database; exit(1) on failure."""
        try:
            conn = pymysql.connect(host=self.host,
                                   port=self.port,
                                   user=self.user,
                                   password=self.password,
                                   database=self.database)
        except pymysql.MySQLError as exc:
            print('数据库[%s:%d]连接失败' % (self.host, self.port))
            # Show the driver's reason (bad password, unknown database, ...).
            print(exc)
            raise SystemExit(1) from exc
        print('数据库[%s:%d]连接成功' % (self.host, self.port))
        return conn

    def _resolve_tables(self, cursor, tables):
        """Return the table names selected by *tables* (see ``run``)."""
        if tables.strip() == 'all':
            cursor.execute(
                "SELECT `TABLE_NAME` FROM information_schema.TABLES "
                "WHERE TABLE_SCHEMA = %s ORDER BY `TABLE_NAME`",
                (self.database,))
            return [row['TABLE_NAME'] for row in cursor.fetchall()]
        return self._split_table_names(tables)

    def _write_table(self, cursor, dict_file, table_name):
        """Write one table's Markdown section.

        :returns: ``False`` (writing nothing) if the table does not exist in
            the database, otherwise ``True``.
        """
        # Parameterized queries: table names come straight from user input.
        cursor.execute(
            "SELECT `TABLE_COMMENT` FROM information_schema.TABLES "
            "WHERE TABLE_SCHEMA = %s AND `TABLE_NAME` = %s",
            (self.database, table_name))
        table_info = cursor.fetchone()
        if table_info is None:
            return False

        lines = ['#### %s %s' % (table_name, table_info['TABLE_COMMENT']),
                 '| 字段名称 | 字段类型 | 允许NULL | 默认值 | 索引 | 字段注释 |',
                 '| --- | --- | --- | --- | --- | --- |']
        cursor.execute(self._COLUMNS_SQL, (self.database, table_name))
        lines.extend(self._format_row(field) for field in cursor.fetchall())
        dict_file.write('\n'.join(lines) + '\n')
        return True

    @classmethod
    def _format_row(cls, field):
        """Render one information_schema.COLUMNS row as a Markdown table row."""
        cells = (cls._md_cell(field[key]) for key in cls._ROW_KEYS)
        return '| ' + ' | '.join(cells) + ' |'

    @staticmethod
    def _md_cell(value):
        """Return *value* as text that cannot break a Markdown table cell.

        ``None`` renders as ``'None'`` (as before). A ``|`` is escaped and line
        breaks become ``<br>`` so the row stays on a single line.
        """
        return '<br>'.join(str(value).replace('|', '\\|').splitlines())

    @staticmethod
    def _split_table_names(tables):
        """Split a comma-separated list into stripped, non-empty table names."""
        return [name.strip() for name in tables.split(',') if name.strip()]


# Program entry point: ask for a database name once, then keep generating
# dictionaries for the requested tables until the user enters "q".
if __name__ == '__main__':
    generator = DataDict(input('请输入数据库名:'))
    prompt = '输入表名<t1,t2,t3>获取指定表,或输入<all>获取全部,输入<q>退出:\n'
    while True:
        answer = input(prompt)
        if answer != 'q':
            generator.run(answer)
            continue
        print('Exit...')
        break