博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
scikit-learn Quick Start
阅读量:4494 次
发布时间:2019-06-08

本文共 3419 字,大约阅读时间需要 11 分钟。

 

  Time:2017/02/24  21:50  at UTSZ

  Environment: pyCharm, python2.7

 

  一般来讲,学习是指利用一些已知的样例数据来预测未知数据的属性。

1. 我们可以将学习问题分为如下的类别:

  1. 监督学习:样例数据有自己的标签。
    1. 分类:要预测的值是离散的
    2. 回归:要预测的值是连续的
  2. 无监督学习:样例数据没有自己的标签。

 

2. 机器学习的常用步骤(python):

  1. 导入要用的包(sklearn)
  2. 加载训练数据
  3. 确定要用的训练算法(原例是用的svm)
  4. 训练,算法内部得到一些参数的值
  5. 预测
from sklearn import datasetsfrom sklearn import svm# 加载数据集digits = datasets.load_digits()# 分类算法clf = svm.SVC(gamma=0.001, C=100.)# 训练clf.fit(digits.data[:-1], digits.target[:-1])# 预测result = clf.predict(digits.data[-1:])# 输出结果print result

  

3. 训练模型的保存和重新加载

在实际的编码中,我们经常要将之前训练过的模型重复使用,为了不必每次训练,节省时间和空间,我们需要将当前训练过的模型保存起来,便于下次的使用。

这里用到了 pickle。

from sklearn import datasetsfrom sklearn import svmimport pickle# 加载数据iris = datasets.load_iris()X, y = iris.data, iris.target# 确定算法clf = svm.SVC()# 训练clf.fit(X, y)# 保存训练后的模型s = pickle.dumps(clf)# 加载之前保存的模型clf2 = pickle.loads(s)# 预测result = clf2.predict(X[0:1])# 输出结果print resultprint y[0]

 注:pickle会有一些安全性和维护性的问题。

 对于大数据而言,使用joblib.dump and joblib.load更好,它会将模型保存到磁盘中,而不仅仅只是保存成一个字符串。

from sklearn.externals import joblib# 保存模型joblib.dump(clf, 'filename.pkl')# 重新加载模型joblib.load('filename.pkl') #其他和pickle一样

 

4. 变量类型问题

  如果没有特别声明,输入将被转换成 float64:

import numpy as npfrom sklearn import random_projectionrng = np.random.RandomState(0)X = rng.rand(10, 2000)# 定义X是float32的数组X = np.array(X, dtype='float32')# 输出查看X的类型print X.dtype# 利用fit_transform(X)将X的类型改变为float64transformer = random_projection.GaussianRandomProjection()X_new = transformer.fit_transform(X)# 输出查看X_new的类型print X_new.dtype

  结果如下:

  

  另外的一个例子:

1 from sklearn import datasets 2 from sklearn.svm import SVC 3  4 iris = datasets.load_iris() 5 clf = SVC() 6 clf.fit(iris.data, iris.target) 7 print list(clf.predict(iris.data[:3])) 8  9 clf.fit(iris.data, iris.target_names[iris.target])10 print list(clf.predict(iris.data[:3]))

results:

因此,第一个预测中,由于训练的目标值target是一个整型数组,所以最后的预测值也是一个整型数组。

        第二个预测中,由于训练的目标值target_names是一个字符串数组,所以最后的预测值也是一个字符串数组。

 

5. 再训练和更新参数

  

import numpy as npfrom sklearn.svm import SVCrng = np.random.RandomState(0)X = rng.rand(100, 10)y = rng.binomial(1, 0.5, 100)X_test = rng.rand(5, 10)clf = SVC()# equal to the two sentences# clf2 = clf.set_params(kernel='linear')# clf2.fit(X, y)clf.set_params(kernel='linear').fit(X, y)print clf.predict(X_test)clf.set_params(kernel='rbf').fit(X, y)print clf.predict(X_test)

  results:

  因此,我们可以看到在算法中选用不同的核函数,会产生不同的预测结果。在实际的应用中,需要不断的调整参数的值。

 

6. 多类标 vs 多类标拟合

  当我们使用多类标分类器(multiclass classifiers)的时候,训练和预测很大程度上依赖于数据的格式。

from sklearn.svm import SVCfrom sklearn.multiclass import OneVsRestClassifierfrom sklearn.preprocessing import LabelBinarizerX = [[1, 2], [2, 4], [4, 5], [3, 2], [3, 1]]y = [0, 0, 1, 1, 2]clf = OneVsRestClassifier(estimator=SVC(random_state=0))print clf.fit(X, y).predict(X)y = LabelBinarizer().fit_transform(y)print clf.fit(X, y).predict(X)

  results:

因此,对于输入是两维来说,输出也是两维的。相互对应的多类标预测。

注:第四个和第五个样例的预测值都是0,这表示它们未能匹配到任意一个类标。

同理,对多于两个类标的情况。

from sklearn.svm import SVCfrom sklearn.multiclass import OneVsRestClassifierfrom sklearn.preprocessing import MultiLabelBinarizerX = [[1, 2], [2, 4], [4, 5], [3, 2], [3, 1]]y = [[0, 1], [0, 2], [1, 3], [0, 2, 3], [2, 4]]clf = OneVsRestClassifier(estimator=SVC(random_state=0))y = MultiLabelBinarizer().fit_transform(y)print clf.fit(X, y).predict(X)

  results:

  在上面的例子中,分类器对每个样例赋予多个类标。MultiLabelBinarizer 用来二值化多类标的训练数据,最终的预测结果也是一个多类标的二位数组。

 

参考资料:http://scikit-learn.org/stable/tutorial/basic/tutorial.html

转载于:https://www.cnblogs.com/aszhaoweiguo/p/6439931.html

你可能感兴趣的文章
h5播放音乐
查看>>
Python 的內建模块
查看>>
每天一个linux命令(54):ping命令
查看>>
Centos下搭建FTP服务器基础笔记
查看>>
jpa-入门.缓存配置ehcache.xml
查看>>
krpano 常用标签
查看>>
21069207《Linux内核原理与分析》第四周作业
查看>>
Linux系统中“动态库”和“静态库”那点事儿
查看>>
《linux备份与恢复之一》.tar.bz2与.tar.gz格式的文本压缩率比较
查看>>
005_nginx414_nginx 414 Request-URI Too Large
查看>>
Spring源码情操陶冶-ContextLoader
查看>>
Spring源码情操陶冶-PathMatchingResourcePatternResolver路径资源匹配溶解器
查看>>
C++数据结构大作业之大数加法、乘法、幂运算
查看>>
C++编程对缓冲区的理解
查看>>
windows下 安装 rabbitMQ 及操作常用命令
查看>>
Linux中 bash_profile和.bashrc的区别(启动文件)
查看>>
Tomcat出现java.lang.Exception: Socket bind failed
查看>>
AngularJS
查看>>
DBCP、C3P0、Proxool 、 BoneCP开源连接池的比较
查看>>
[.NET WebAPI系列01] WebAPI 简单例子
查看>>