sklearn.datasets.make_classification函数用法简介
郝伟 2021/08/08

简介

sklearn.datasets.make_classification函数主要作用:随机生成n类分类问题的数据。最初创建关于边长为 2*class_sepn_informative 维超立方体的顶点正态分布 (std=1) 的点簇,并为每个类分配相同数量的簇。 它引入了这些特征之间的相互依赖性,并为数据添加了各种类型的进一步噪声。在没有改组的情况下,X 按以下顺序水平堆叠特征:主要的 n_informative 特征,然后是信息特征的 n_redundant 线性组合,然后是 n_repeated 重复项,随机抽取并替换信息和冗余特征。 其余特征充满随机噪声。 因此,无需改组,所有有用的特征都包含在列 X[:, :n_informative + n_redundant + n_repeated] 中。

接口说明

参数说明

函数的参数非常多,所有参数如下:

使用经验

虽然参数非常多,但是在使用的时候基本就以下几个

示例

这是一个显示多张效果图的示例代码,显示效果在代码后。

#作为参考,这里还加入了两个类似的函数:
#make_blobs:生成简化的变量。
#make_multilabel_classification:独立生成多个标签。

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.datasets import make_blobs
from sklearn.datasets import make_gaussian_quantiles

fontSize='small'
plt.figure(figsize=(8, 8))
plt.subplots_adjust(bottom=.05, top=.9, left=.05, right=.95)

# 图1:1个信息特征,每个类1个集群
plt.subplot(321)
plt.title("One informative feature, one cluster per class", fontsize=fontSize)
xs1, ys1 = make_classification(n_features=2, n_redundant=0, n_informative=1, n_clusters_per_class=1, shuffle=False)
plt.scatter(xs1[:, 0], xs1[:, 1], marker='o', c=ys1,  s=25, edgecolor='k')
 

# 图2:2个信息特征,每个类1个集群
plt.subplot(322)
plt.title("Two informative features, one cluster per class", fontsize=fontSize)
xs1, ys1 = make_classification(n_features=2, n_redundant=0, n_informative=2, n_clusters_per_class=1)
plt.scatter(xs1[:, 0], xs1[:, 1], marker='o', c=ys1, s=25, edgecolor='k')

# 图3:2个信息特征,每个类2个集群
plt.subplot(323)
plt.title("Two informative features, two clusters per class", fontsize=fontSize)
xs2, ys2 = make_classification(n_features=2, n_redundant=0, n_informative=2)
plt.scatter(xs2[:, 0], xs2[:, 1], marker='o', c=ys2, s=25, edgecolor='k')

# 图4:多分类,2个信息特征,1个集群
plt.subplot(324)
plt.title("Multi-class, two informative features, one cluster",fontsize=fontSize)
xs1, ys1 = make_classification(n_features=2, n_redundant=0, n_informative=2, n_clusters_per_class=1, n_classes=3)
plt.scatter(xs1[:, 0], xs1[:, 1], marker='o', c=ys1,  s=25, edgecolor='k')

# 图1:3个斑点
plt.subplot(325)
plt.title("Three blobs", fontsize=fontSize)
xs1, ys1 = make_blobs(n_features=2, centers=3)
plt.scatter(xs1[:, 0], xs1[:, 1], marker='o', c=ys1,  s=25, edgecolor='k')

# 图1:高斯划分三等分
plt.subplot(326)
plt.title("Gaussian divided into three quantiles", fontsize=fontSize)
xs1, ys1 = make_gaussian_quantiles(n_features=2, n_classes=3)
plt.scatter(xs1[:, 0], xs1[:, 1], marker='o', c=ys1,  s=25, edgecolor='k')

plt.show()
▶︎
all
running...
2021-08-11T09:22:37.234634 image/svg+xml Matplotlib v3.3.4, https://matplotlib.org/