使用Plotting API进行开发#
Scikit-learn定义了一个简单的API,用于创建机器学习的可视化。此API的主要功能是运行一次计算,并在事后灵活地调整可视化。本节面向希望开发或维护绘图工具的开发人员。有关用法,用户应参阅 User Guide .
绘图API概述#
该逻辑被封装到显示对象中,其中存储计算出的数据,并在 plot method. The display object's _ _init___'方法仅包含创建可视化所需的数据。的 `plot 方法接受仅与可视化有关的参数,例如matplotlib轴。的 plot 方法将matplotlib艺术家存储为属性,允许通过显示对象进行样式调整。的 Display 类应该定义一个或两个类方法: from_estimator 和 from_predictions .这些方法允许创建 Display 对象从估计器和一些数据或从真实值和预测值。在这些类方法之后,使用计算值创建显示对象,然后调用显示的plot方法。注意到 plot 方法定义与matplotlib相关的属性,例如线条艺术家。这允许在调用 plot 法
比如说 RocCurveDisplay 定义以下方法和属性::
class RocCurveDisplay:
def __init__(self, fpr, tpr, roc_auc, estimator_name):
...
self.fpr = fpr
self.tpr = tpr
self.roc_auc = roc_auc
self.estimator_name = estimator_name
@classmethod
def from_estimator(cls, estimator, X, y):
# get the predictions
y_pred = estimator.predict_proba(X)[:, 1]
return cls.from_predictions(y, y_pred, estimator.__class__.__name__)
@classmethod
def from_predictions(cls, y, y_pred, estimator_name):
# do ROC computation from y and y_pred
fpr, tpr, roc_auc = ...
viz = RocCurveDisplay(fpr, tpr, roc_auc, estimator_name)
return viz.plot()
def plot(self, ax=None, name=None, **kwargs):
...
self.line_ = ...
self.ax_ = ax
self.figure_ = ax.figure_
阅读更多在 sphx_glr_auto_examples_miscellaneous_plot_roc_curve_visualization_api.py 和 User Guide .
多轴绘图#
一些绘图工具,例如 from_estimator 和 PartialDependenceDisplay 支持多轴绘图。支持两种不同的场景:
1.如果传递轴列表, plot 将检查轴的数量是否与预期的轴的数量一致,然后在这些轴上绘制。2.如果传递了单个轴,该轴定义了放置多个轴的空间。在这种情况下,我们建议使用matplotlib的 ~matplotlib.gridspec.GridSpecFromSubplotSpec 分割空间::
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpecFromSubplotSpec
fig, ax = plt.subplots()
gs = GridSpecFromSubplotSpec(2, 2, subplot_spec=ax.get_subplotspec())
ax_top_left = fig.add_subplot(gs[0, 0])
ax_top_right = fig.add_subplot(gs[0, 1])
ax_bottom = fig.add_subplot(gs[1, :])
默认情况下 ax 关键词 plot 是 None .在这种情况下,将创建单个轴,并使用gridspec api来创建要绘制的区域。
例如, from_estimator 它使用此API绘制多条线和轮廓。定义边界框的轴保存在 bounding_ax_ 属性创建的各个轴存储在 axes_ ndray,对应于网格上的轴位置。未使用的位置设置为 None .此外,matplotlib艺术家存储在 lines_ 和 contours_ 其中关键是网格上的位置。当传递轴列表时, axes_ , lines_ ,而且 contours_ 是与传递的轴列表对应的1d nd数组。