写在前面:

  1. 文件类型:txt(mtx改改也能用),存储方式:三元组。
  2. heatmap要求:绘制200x200的压缩矩阵,剔除边角料(如423x856的矩阵画出来实际上只用到了前400x800的数据,多余的数据丢弃。因为我的矩阵很大的缘故,这样的数据取舍不影响整体观察)。在heatmap中画出行列比。
  3. 一次性画出整个文件夹下的所有矩阵的heatmap并以图片的形式保存在本文件夹下。(这个文件夹下的所有矩阵文件的命名格式统一,且矩阵文件内容格式也统一!!)
  4. 电脑为Mac,不懂Windows上怎么用python。
  5. python版本为python3。(之前从老师那来的在python2下能运行的代码在python3下运行出错,所以一定要注意版本!!二者语法上会有细微差别。)还有各种什么库,该装的装(详见import…)。

画出来的heatmap是这样滴:

pytorch 矩阵热图 python热力学矩阵_pytorch 矩阵热图

矩阵文件命名,举个栗子🌰:

$top_circuit.@sub@_10.@sub@_11.full_25922_26768.txt
 $top_circuit.@sub@_11.@sub@_15.full_21906_22520.txt

矩阵的txt文件,举个🌰:

主要是体会一下代码中的23、24的意义。我没想到什么好的办法可以直接知道哪一行的内容是行数 列数 非零元数,以及哪一行开始是具体数据(一般就是行数 列数 非零元数的下一行)。像我这次这个格式的txt文件也许可以检测b[][]从哪行开始b[][0]不是“%”,但是我想了一下,成千上万个文件都花一遍时间去检测这个,还不如我自己直接数一下…主要也是因为我所有的矩阵的前23行都长得一样,所以可以这样。有一个学长想用我的代码画两个图,但是他一个文件中从第13行开始是矩阵,而另一个从第一行就是矩阵,像这种情况就不能统一像我这么处理了。也许可以检测一下从哪里开始b[][0]是数字,那么这是矩阵的数据的开始。

pytorch 矩阵热图 python热力学矩阵_数组_02

import os
import numpy as np
import numpy.ma as ma
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import matplotlib.patches as mpatches
import csv
import plotly.graph_objects as go
from matplotlib.animation import FuncAnimation
import re
from pylab import *

def IsSubString(SubStrList,Str):
    flag=True
    for substr in SubStrList:
        if not(substr in Str):
            flag=False
    return flag
 
def GetFileList(FindPath,FlageStr):
    FileList=[]
    FileNames=os.listdir(FindPath)
    for fn in FileNames:
         if (IsSubString(FlagStr,fn)):
             fullfilename=os.path.join(FindPath,fn)
             FileList.append(fullfilename)
    return FileList


def drawpic(filepath):
	#以下几行是为了画行列比而准备的,如果不需要画这两根线,以下几行可以删掉。
	#而且不同项目的文件命名方式也不一样,所以这几行大概率直接删掉比较好。
	#删了这里后面就还有一个地方要删,注意⚠️
    #totalcount记录的是文件名称的数组
    #举例说明:filepath = “$top_circuit.@sub@_10.@sub@_11.full_25922_26768.txt”
    totalCount = re.sub("\D"," ",filepath) #匹配非数字字符将其替换为空格
    #此时 totalCount = (空格)10(空格)11(空格)25922(空格)26768(空格)
    #print(totalCount)
    totalCount = totalCount.split() #以空格分割totalcount
    #print(totalCount)
    n = len(totalCount) #记录数组长度
    #print(int(totalCount[a-1]))

    #下面正式开始画矩阵
    file = open(filepath)
    a = file.read()
    b = a.split('\n') #用回车符分割file并存在b[]数组中
    for i in range(len(b)):
        b[i] = b[i].split() #对b数组中的每一个b[i]再次用回车符进行分割并将结果存在数组b[i][]中

    b = np.array(b)
    data = np.zeros((200,200),dtype=float)
    #我的矩阵数据第24行是行数 列数 非零元数,25行开始是具体数据
    row = int(int(b[23][0])/200)

    for i in range(int(b[23][2])):
        j = int(b[i+24][0])
        k = int(b[i+24][1])
        if j%row != 0:
            r = int(j/row)
        else:
            r = int(j/row-1)
    
        if k%row != 0:
            c = int(k/row)
        else:
            c = int(k/row-1)
        
        #剔除边角料。
        if r<200:
            if c<200:
                data[r][c] = data[r][c]+float(b[i+24][2])

    data = np.array(data)
    #print(data)

    #以下三行代码画行列比,不画的删掉
    x = 200*int(totalCount[n-2])/int(totalCount[n-1])
    #画行列比
    plot([x,x],[0,200],color='black',linewidth=1.0) #[x,0][x,200]
    plot([0,200],[x,x],color='black',linewidth=1.0) #[0,x][200,x]

    norm = cm.colors.Normalize(vmin=-10,vmax=10)#设定colorbar的数值范围为-10~10
    plt.imshow(data,cmap=cm.seismic,norm=norm)
    plt.colorbar()
    figname = filepath+'.png' #设定图片保存时的命名格式
    plt.savefig(figname)
    plt.clf()
    #plt.show()

FindPath = '/Users/luoyuchen/Fiels/code/something/heatmap' #需要绘制的矩阵所在文件夹的路径
FlagStr = ['$','txt'] #判断字符,表示收集该文件夹下所有命名中带有“$”和“txt”字样的文件,可根据需要修改判断字符
FileList = GetFileList(FindPath,FlagStr)
findex = 0
for fn in FileList:
    findex = findex+1
    if(os.path.isfile(fn)):
        print(str(findex)+'--->'+fn)
        drawpic(fn)

运行(mac系统):
终端中进入python文件所在的文件夹(我的 python文件命名为heatmap.py)输入:

python3 heatmap.py