【Python】Matplotlib基本用法(学习笔记)——参考《Python数据科学手册》
本文的参考资料:O’reilly出版的《Python数据科学手册》,该书在Github上开源,采用jupyter notebook编写。
Matplotlib官网:https://matplotlib.org/ Matplotlib画廊:https://matplotlib.org/gallery/index.html
以下使用MPL作为matplotlib的简称。
常用技巧
常用的Matplotlib导入方式
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
或者,如果你喜欢更接近于Matlab的使用方式:
import pandas as pd
from matplotlib.pylab import *
选择绘图风格
plt.style.avaliable # 查看可用的风格
plt.style.use('ggplot') # 使用某一种风格
两种画图接口
plt.plot(x, y)
plt.shw()
fig, ax = plt.figure(), plt.axes()
ax.plot(x, y)
fig.show()
如何显示画好的图像(脚本、shell、notebook):
如果是在脚本中使用MPL,或者是在原版python console中使用MPL,需要显式调用show函数来显示图像;
如果是在ipython中使用MPL,可以调用show一次性显示图形,也可以调用%matplotlib
魔法命令,这样每一步绘图操作都会实时更新到窗口中;
如果是在jupyter notebook中使用MPL,可以使用%matplotlib inline
魔法命令,这样画出的图形能够实时更新但是不能交互,或者使用%matplotlib notebook
魔法命令,显示可以交互的图像(画3D图是建议使用)。
简易线条
在MPL中绘制线条使用plot函数,可选的参数有linestyle、marker、linewidth、markersize、color等,更多参数请参考文档;
设置标题使用plt.title;
设置图例使用plt.legend;
这是xy轴标签使用plt.xlabel, plt.ylabel;
设置xy轴范围使用plt.xlim, plt.ylim, 或者plt.axis;
简易散点图
包导入:
%matplotlib inline
from matplotlib.pylab import *
plt.style.use('classic')
可以直接使用plot画出散点图:
x = np.random.randn(100)
y = np.random.randn(100)
plt.plot(x, y, linestyle='', marker='o')
plt.grid()
plt.show()
还可以使用专用的散点图函数:scatter:
s = np.random.randn(100) * 100
c = np.random.randn(100)
plt.scatter(x=x, y=y, s=s, c=c, cmap=plt.cm.RdGy)
plt.colorbar()
plt.grid()
plt.show()
其中,s代表每一个点的大小,c代表每一个点的颜色(如果c是标量,则所有点的颜色相同);
因为scatter需要一个一个渲染数据点,所以速度会比较慢;
在图像中显示误差
包导入
%matplotlib inline
from matplotlib.pylab import *
plt.style.use('classic')
如果需要绘制离散的误差,可以使用errorbar函数
x = linspace(-5, 5, 20)
err = np.random.randn(20)
y = 1.5 * x + 2 + err
plt.errorbar(x=x, y=y, yerr=err, capsize=4, elinewidth=2, ecolor='gray', color='black', fmt='o')
plt.grid()
plt.show()
如果需要绘制连续的误差,可以使用fill_between函数
x = linspace(-5, 5, 100)
y = sin(x) + sin(2 * x) + sin(4 * x)
err1 = np.random.randn(100) / 10 + 0.5
err2 = np.random.randn(100) / 10 + 0.5
plt.plot(x, y, linestyle='-', color='black', linewidth=2)
plt.fill_between(x, y - err1, y + err2, color='gray', alpha=0.5)
plt.grid()
plt.show()
等高线
包导入
%matplotlib inline
from matplotlib.pylab import *
plt.style.use('classic')
数据准备
meshgird函数可以方便的从一维向量来生成二维网格:
x = linspace(-5, 5, 200)
y = linspace(-5, 5, 200)
x, y = np.meshgrid(x, y)
z = np.sin(x) ** 10 + np.cos(x * y + 10) + np.cos(y)
使用contour函数绘制等高线图
如果指定colors参数,则使用colors指定的单一的颜色绘制等高线,并使用虚线表示负值:
plt.contour(x, y, z, colors='black')
plt.show()
如果不指定colors参数,则使用不同的颜色代表不同的数值大小,可以使用cmap参数指定颜色映射方案:
plt.contour(x, y, z, cmap='viridis')
plt.colorbar()
plt.show()
使用contourf绘制填充颜色的等高线图
plt.contourf(x, y, z, cmap='viridis')
plt.colorbar()
plt.show()
如果需要让颜色过渡的更加自然,可以使用imshow函数
plt.imshow(z, extent=[-5, 5, -5, 5], origin='lower', cmap='viridis')
plt.colorbar()
plt.show()
由于imshow是为了显示图片设计的,我们只能用etent参数指定xy的范围,又因为图像的坐标原点一般选取为左上角,我们需要使用origin参数将原点设定为左下角:
将两类等高线结合起来,可以获得特殊的效果
c = plt.contour(x, y, z, colors='black')
plt.clabel(c, inline=True, fontsize=8)
plt.imshow(z, extent=[-5, 5, -5, 5], origin='lower', cmap='viridis')
plt.colorbar()
plt.show()
直方图
包导入
%matplotlib inline
from matplotlib.pylab import *
plt.style.use('classic')
使用hist函数绘制一维直方图
hist函数底层使用的numpy提供的histogram函数对输入数据进行处理,(MPL)hist函数的返回值就是(numpy)histogram函数的返回值;
hist函数有很多参数,有很多用法,详细信息建议参考文档,这里给出两个示例:
data1 = np.random.randn(1000)
data2 = np.random.randn(1000)
plt.hist([data1, data2], bins=30, color=['red', 'green'], histtype='barstacked', edgecolor='white', label=['data1', 'data2'])
plt.legend(loc='upper right')
plt.show()
plt.hist([data1 - 3, data2, data3 + 3], bins=30, histtype='stepfilled', alpha=0.3, label=['data1', 'data2', 'data3'])
plt.legend(loc='upper right')
plt.grid()
plt.show()
使用hist2d绘制二维直方图
x, y = np.random.multivariate_normal([0, 0], [[1, 1], [1, 2]], 10000).T
plt.hist2d(x, y, bins=40, cmap='Blues')
plt.colorbar(label='this is colorbar')
plt.show()
使用hexbin绘制二维直方图
x, y = np.random.multivariate_normal([0, 0], [[1, 1], [1, 2]], 10000).T
plt.hexbin(x, y, gridsize=30, cmap='Blues')
plt.colorbar()
plt.show()
配置图例
包导入
%matplotlib inline
import pandas as pd
from matplotlib.pylab import *
from matplotlib.legend import Legend
plt.style.use('classic')
修改默认的图例样式
x = linspace(-5, 5, 100)
plt.plot(x, sin(x), label='sin')
plt.plot(x, cos(x), label='cos')
plt.legend(loc='upper right', frameon=True, fancybox=True, framealpha=0.1, shadow=True)
只为指定的线条绘制图例
只有在绘制时添加了label参数的线条,legend函数才会为他们绘制图例,如果只需要为图中部分线条绘制图例,推荐使用这种做法。
为散点图添加图例
data = pd.read_csv('./PythonDataScience/notebooks/data/california_cities.csv')
pos_X, pos_Y = data['longd'], data['latd']
population, area = data['population_total'], data['area_total_km2']
plt.scatter(
x=pos_X, y=pos_Y,
s=area, c=np.log10(population),
cmap='viridis', linewidths=0.1, alpha=0.5,
)
plt.xlabel('longitude')
plt.ylabel('latitude')
plt.colorbar(label='log$_{10}$population')
plt.title('areaandpopulation')
for a in[100,300,500]:
plt.scatter(x=[],y=[],s=a,c='g',linewidth=0.5,alpha=0.3,label=str(a)+'km$^2$')
plt.legend(frameon=False,title='cityarea',labelspacing=1,scatterpoints=1)
plt.show()
在调用scatter时,如果使用空列表作为x,y参数,则表示是在为图例绘制散点;
添加多个图例
MPL默认不能支持添加多个图例,但是可以通过调用底层的add_artist方法实现效果:
x = linspace(-5, 5, 100)
fig, ax = plt.figure(), plt.axes()
lines = []
lines += ax.plot(x, sin(x))
lines += ax.plot(x, sin(x + 0.3 * pi))
lines += ax.plot(x, sin(x + 0.6 * pi))
lines += ax.plot(x, sin(x + 0.9 * pi))
leg1 = Legend(
ax,
lines[:2],
loc='upper right',
frameon=True,
framealpha=0.3,
labels=['line 1', 'line 2']
)
ax.add_artist(leg1)
leg2 = Legend(
ax,
lines[2:],
loc='upper left',
frameon=True,
framealpha=0.7,
labels=['line 3', 'line 4']
)
ax.add_artist(leg2)
plt.show()
配置颜色调
包导入
from matplotlib.pylab import *
plt.style.use('classic')
设置颜色映射值的上下限
有时,我们只想让一定范围内的数据与颜色进行映射,忽略超出范围的数据(比如噪声),这时候,可以使用clim函数设定映射范围,并可以在调用colorbar函数时使用extend参数,效果如下:
x = linspace(0, 10, 1000)
z = sin(x) * cos(x[:, np.newaxis])
noise = np.random.random(z.shape)
z[noise < 0.01] = np.random.normal(0, 3, np.count_nonzero(noise < 0.01)) # 加噪声
plt.subplot(1, 2, 1)
plt.imshow(z, cmap='RdGy')
plt.colorbar()
plt.subplot(1, 2, 2)
plt.imshow(z, cmap='RdGy')
plt.clim(-1, 1)
plt.colorbar(extend='both')
plt.show()
很明显,右边的效果要比左边的理想。
使用离散的颜色条
如果不想要渐变的颜色效果而想要离散的颜色效果,可以使用plt.cm.getcmap
函数选择颜色映射方案并指定想要的区间数量:
x = y = np.linspace(0, 10, 100)
x, y = np.meshgrid(x, y)
z = np.sin(x) * np.cos(y)
plt.contourf(x, y, z, cmap=plt.cm.get_cmap('RdGy', lut=8))
plt.colorbar()
plt.show()
多子图设置
包导入
%matplotlib inline
import pandas as pd
from matplotlib.pylab import *
plt.style.use('classic')
使用Figure对象的add_axes函数手动添加子图:
fig = plt.figure()
ax1 = fig.add_axes([0.1, 0.1, 0.1, 0.1])
ax2 = fig.add_axes([0.2, 0.2, 0.2, 0.2])
x = linspace(-5, 5, 100)
y1 = sin(x)
y2 = cos(x)
ax1.plot(x, y1)
ax2.plot(x, y2)
使用subplot函数创建简易的子图
有两种方式可以实现,子图的下标从1开始:
plt.subplots_adjust(wspace=0.5, hspace=0.3)
for i in range(1, 7):
plt.subplot(2, 3, i)
plt.plot(x, sin(x + i * pi / 6))
plt.grid()
fig = plt.figure()
fig.subplots_adjust(hspace=0.3, wspace=0.4)
for i in range(1, 7):
ax = fig.add_subplot(2, 3, i)
ax.plot(x, sin(x + i * pi / 6))
ax.grid()
使用subplots函数创建大型的网格图
这个函数创建的子图的下表从0开始:
fig, axs = plt.subplots(3, 4, sharex='col', sharey='row')
for i in range(3):
for j in range(4):
ax = axs[i, j]
ax.text(0.5, 0.5, f'{i, j}', fontsize=18, ha='center')
ax.grid()
使用GridSpec函数创建复杂的排列
mean = [0, 0]
cov = [[1, 1], [1, 2]]
X, Y = np.random.multivariate_normal(mean, cov, 3000).T
fig = plt.figure()
grid = plt.GridSpec(4, 4, hspace=0.2, wspace=0.2)
main_ax = fig.add_subplot(grid[:-1, 1:])
x_ax = fig.add_subplot(grid[-1, 1:])
y_ax = fig.add_subplot(grid[:-1, 0])
main_ax.plot(X, Y, linestyle='', color='black', marker='o', markersize=3, alpha=0.2)
_ = x_ax.hist(X, 40, histtype='stepfilled', orientation='vertical', color='grey')
x_ax.invert_yaxis()
_ = y_ax.hist(Y, 40, histtype='stepfilled', orientation='horizontal', color='grey')
y_ax.invert_xaxis()
文字与注释
包导入
from matplotlib.pylab import *
plt.style.use('classic')
使用text添加注释
text是添加注释的最简单的方式,只需要传入文字的内容大小对齐方式和位置即可,可选的定位模式有三种:
fig = plt.figure(figsize=(10, 10))
ax = plt.axes()
ax.axis([0, 10, 0, 10])
ax.text(1, 5, '. (1, 5)data', transform=ax.transData)
ax.text(0.5, 0.3, '. (0.5, 0.3)Axes', transform=ax.transAxes)
ax.text(0.2, 0.4, '. (0.2, 0.4)Figure', transForm=fig.transFigure)
使用annotate函数创建复杂的注释
详情参照annotate函数的文档;
示例如下:
plt.plot()
plt.axis([0, 10, 0, 10])
plt.annotate('example 1', xy=(2, 3), xytext=(5, 6), arrowprops=dict(facecolor='black', shrink=0.05))
plt.annotate('example 2', xy=(9, 9), xytext=(7, 2), arrowprops=dict(arrowstyle='->', connectionstyle='angle3,angleA=0,angleB=-90'))
自定义坐标轴与刻度
主要是自定义坐标的Loactor和Formatter
包导入
from matplotlib.pylab import *
plt.style.use('classic')
主要坐标与次要坐标
创建一副空的图,查看其主要坐标与次要坐标:
ax = plt.axes(xscale='log', yscale='log')
ax.axis([1, 1000, 1, 1000])
ax.grid()
print(ax.xaxis.get_major_locator())
print(ax.xaxis.get_minor_locator())
print(ax.xaxis.get_major_formatter())
print(ax.xaxis.get_minor_formatter())
使用NullLocator与NullFormatter隐藏坐标
ax = plt.axes()
ax.grid()
ax.plot(np.random.randn(50))
ax.yaxis.set_major_formatter(plt.NullFormatter())
ax.xaxis.set_major_locator(plt.NullLocator())
使用MaxNLocator设置坐标的最多刻度数
_, _ = plt.subplots(4, 4, sharex='col', sharey='row')
fig, ax = plt.subplots(4, 4, sharex='col', sharey='row')
for i in range(4):
for j in range(4):
ax[i, j].grid()
ax[i, j].xaxis.set_major_locator(plt.MaxNLocator(2))
ax[i, j].yaxis.set_major_locator(plt.MaxNLocator(2))
使用MultipleLocator与FuncFormatter自定义坐标的位置与文字
x = linspace(-5, 5, 100)
y1, y2 = sin(x), cos(x)
ax = plt.axes()
ax.plot(x, y1); ax.plot(x, y2)
ax.xaxis.set_major_locator(plt.MultipleLocator(pi / 2))
ax.xaxis.set_minor_locator(plt.MultipleLocator(pi / 4))
def format_func(value, tick_number):
n = int(round(2 * value / pi))
return f'{n/2}$\pi$'
ax.xaxis.set_major_formatter(plt.FuncFormatter(format_func))
常用的Locator与Formatter
Matplotlib的外观自定义
包导入
from matplotlib.pylab import *
plt.style.use('classic')
自定义外观
当前MPL的外观设定参数储存在plt.rc
中(runtime configuration),使用plt.style.available
可以查看可用的预定义样式,使用plt.rc
函数可以自定义某项外观属性,示例:
_ = plt.rcParams # 查看当前配置
from matplotlib import cycler
colors = cycler('color',
['#EE6666', '#3388BB', '#9988DD',
'#EECC55', '#88BB44', '#FFBBBB'])
plt.rc('axes', facecolor='#E6E6E6', edgecolor='none',
axisbelow=True, grid=True, prop_cycle=colors)
plt.rc('grid', color='w', linestyle='solid')
plt.rc('xtick', direction='out', color='gray')
plt.rc('ytick', direction='out', color='gray')
plt.rc('patch', edgecolor='#E6E6E6')
plt.rc('lines', linewidth=2)
plt.subplot(1, 2, 2)
plt.plot(np.random.randn(50), label='a')
plt.plot(np.random.randn(30), label='b')
plt.plot(np.random.randn(20), label='c')
plt.legend()
plt.subplot(1, 2, 1)
plt.hist(np.random.randn(50))
Matplotlib中预置的一些样式
fivethirtyeight:
with plt.style.context('fivethirtyeight'):
x = linspace(-5, 5, 100)
plt.subplot(1, 2, 2)
plt.plot(x, sin(x), label='a')
plt.plot(x, sin(x) + sin(2 * x) + 2, label='b')
plt.plot(x, sin(2) + sin(2 * x) + sin(4 * x) + 4, label='c')
plt.legend(loc='upper right')
plt.subplot(1, 2, 1)
plt.hist(np.random.randn(50))
plt.title('FiveThirtyEight')
plt.show()
ggplot:
with plt.style.context('ggplot'):
x = linspace(-5, 5, 100)
plt.subplot(1, 2, 2)
plt.plot(x, sin(x), label='a')
plt.plot(x, sin(x) + sin(2 * x) + 2, label='b')
plt.plot(x, sin(2) + sin(2 * x) + sin(4 * x) + 4, label='c')
plt.legend(loc='upper right')
plt.subplot(1, 2, 1)
plt.hist(np.random.randn(50))
plt.title('ggplot')
plt.show()
bmh:
with plt.style.context('bmh'):
x = linspace(-5, 5, 100)
plt.subplot(1, 2, 2)
plt.plot(x, sin(x), label='a')
plt.plot(x, sin(x) + sin(2 * x) + 2, label='b')
plt.plot(x, sin(2) + sin(2 * x) + sin(4 * x) + 4, label='c')
plt.legend(loc='upper right')
plt.subplot(1, 2, 1)
plt.hist(np.random.randn(50))
plt.title('bmh')
plt.show()
dark_background:
with plt.style.context('dark_background'):
x = linspace(-5, 5, 100)
plt.subplot(1, 2, 2)
plt.plot(x, sin(x), label='a')
plt.plot(x, sin(x) + sin(2 * x) + 2, label='b')
plt.plot(x, sin(2) + sin(2 * x) + sin(4 * x) + 4, label='c')
plt.legend(loc='upper right')
plt.subplot(1, 2, 1)
plt.hist(np.random.randn(50))
plt.title('dark_background')
plt.show()
grayscale:
with plt.style.context('grayscale'):
x = linspace(-5, 5, 100)
plt.subplot(1, 2, 2)
plt.plot(x, sin(x), label='a')
plt.plot(x, sin(x) + sin(2 * x) + 2, label='b')
plt.plot(x, sin(2) + sin(2 * x) + sin(4 * x) + 4, label='c')
plt.legend(loc='upper right')
plt.subplot(1, 2, 1)
plt.hist(np.random.randn(50))
plt.title('grayscale')
plt.show()
三维图表
包导入
from mpl_toolkits import mplot3d
from matplotlib.pylab import *
plt.style.use('ggplot')
显示一个空的三维坐标轴
fig, ax = plt.figure(), plt.axes(projection='3d')
三维的线与点
x = linspace(0, 10, 1000)
y = sin(x)
z = cos(x)
ax.plot3D(x, y, z)
x_data = 15 * np.random.random(100)
y_data = sin(x_data) + 0.1 * np.random.randn(100)
z_data = cos(x_data) + 0.1 * np.random.randn(100)
ax.scatter3D(x_data, y_data, z_data, c=z_data, cmap='viridis')
fig.colorbar(cm.ScalarMappable())
三维等高线图
x = linspace(-5, 5, 100)
X, Y = meshgrid(x, x)
Z = sin(sqrt(X ** 2 + Y ** 2))
ax.contour3D(X, Y, Z, 100, cmap='viridis')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.view_init(60, 35) # 调整视角
fig.colorbar(cm.ScalarMappable())
网格图
x = linspace(-5, 5, 100)
X, Y = meshgrid(x, x)
Z = sin(sqrt(X ** 2 + Y ** 2))
ax.plot_wireframe(X, Y, Z, cmap='viridis')
ax.set_title('wireframe')
曲面图
x = linspace(-5, 5, 100)
X, Y = meshgrid(x, x)
Z = sin(sqrt(X ** 2 + Y ** 2))
ax.plot_surface(X, Y, Z, cmap='viridis')
ax.set_title('surface')
在三维坐标中使用极坐标(手动)
r = linspace(-5, 5, 30)
theta = np.linspace(0, 0.75 * pi, 40)
r, theta = meshgrid(r, theta)
X = r * sin(theta)
Y = r * cos(theta)
Z = sin(sqrt(X ** 2 + Y ** 2))
ax.plot_surface(X, Y, Z, cmap='viridis')
fig.colorbar(cm.ScalarMappable())
曲面三角剖分
如果要使用网格图或者曲面图,那么就要求数据点是均匀采样的,而现实中未必满足这样的条件,所以就需要三角面:
theta = 2 * pi * random(1000)
r = 6 * random(1000)
x = r * sin(theta)
y = r * cos(theta)
z = sin(sqrt(x ** 2 + y ** 2))
ax.plot_trisurf(x, y, z, cmap='viridis')
fig.colorbar(cm.ScalarMappable())
画一个莫比乌斯带
theta = linspace(0, 2 * pi, 30)
w = linspace(-0.25, 0.25, 8)
w, theta = meshgrid(w, theta)
phi = 0.5 * theta
r = 1 + w * cos(phi)
x = ravel(r * cos(theta))
y = ravel(r * sin(theta))
z = ravel(w * sin(phi))
from matplotlib.tri import Triangulation
tri = Triangulation(ravel(w), ravel(theta))
ax.plot_trisurf(x, y, z, triangles=tri.triangles, cmap='viridis')
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_zlim(-1, 1)
fig.colorbar(cm.ScalarMappable())
fig.show()
使用basemap可视化地理数据
根据Basemap的Github主页的说明,this package is being deprecated in favour of cartopy,所以这里就不将怎么使用basemap了,推荐使用cartopy。
使用Seaborn可视化数据
Seaborn是对Matplotlib的封装,使得一些原本需要大量代码的初始化操作变得简单,并且准备了一种比MPL默认风格更加现代的绘图风格;同时Seaborn也准备了一些经典的数据集供使用,可以在GitHub上下载,或是使用seaborn.load
加载。
包导入
from matplotlib.pylab import *
import pandas as pd
import seaborn as sns
sns.set()
密度图、直方图、KDE
data = np.random.multivariate_normal([0, 0], [[5, 2], [2, 2]], size=2000)
data = pd.DataFrame(data, columns=['x', 'y'])
for col in 'xy':
plt.hist(data[col], density=True, alpha=0.5, label='x')
plt.title('plt.hist')
plt.legend(loc='upper right')
plt.show()
for col in 'xy':
sns.kdeplot(data[col], shade=True)
plt.title('sns.kdeplot')
plt.show()
for col in 'xy':
sns.distplot(data[col], label=col)
plt.legend(loc='upper right')
plt.title('sns.distplot')
plt.show()
sns.kdeplot(data) # 传入二维数据,就可以绘制二维KDE
plt.title('sns.kdeplot 2D')
sns.jointplot('x', 'y', data, kind='kde') # 联合分布 KDE
plt.title('sns.joinplot KDE')
sns.jointplot('x', 'y', data, kind='hex') # 联合分布 HEX
plt.title('sns.joinplot HEX')
使用矩阵图初始化鸢尾花数据
iris = pd.read_csv('./seaborn_dataset/iris.csv')
iris.head()
sns.pairplot(iris, hue='species', size=2.5)
分面频次直方图(以sns的小费数据为例)
tips = pd.read_csv('./seaborn_dataset/tips.csv')
print(tips.head())
tips['tip_pct'] = 100 * tips['tip'] / tips['total_bill']
grid = sns.FacetGrid(tips, row='sex', col='time', margin_titles=True)
grid.map(plt.hist, 'tip_pct', bins=linspace(0, 40, 15))
plt.show()
因子图(以sns的小费数据为例)
with sns.axes_style(style='ticks'):
g = sns.factorplot('day', 'total_bill', 'sex', data=tips, kind='box')
g.set_axis_labels('Day', 'Total Bill')
plt.show()
计算联合分布的同时计算线性回归
sns.jointplot('total_bill', 'tip', data=tips, kind='reg')
条形图
planets = pd.read_csv('./seaborn_dataset/planets.csv')
print(planets.head())
with sns.axes_style('white'):
g = sns.factorplot('year', data=planets, aspect=4, kind='count', hue='method', order=range(2001, 2015))
g.set_ylabels('number of planets discovered')