篇1:SVM原理及多分类python代码实例讲解(鸢尾花数据)
SVM原理
支持向量机(Support Vector Machine,SVM),主要用于小样本下的二分类、多分类以及回归分析,是一种有监督学习的算法。基本思想是寻找一个超平面来对样本进行分割,把样本中的正例和反例用超平面分开,其原则是使正例和反例之间的间隔最大。
SVM学习的基本想法是求解能够正确划分训练数据集并且几何间隔最大的分离超平面。如下图所示,wx+b=0即为分离超平面,对于线性可分的数据集来说,这样的超平面有无穷多个(即感知机),但是几何间隔最大的分离超平面却是唯一的。
SVM实现分类代码
1.数据集介绍——鸢尾花数据集
下载方式:通过UCI Machine Learning Repository下载或者直接使用代码
from sklearn.datasets import load_iris
数据展示与介绍(iris.data)
Iris.data中有5个属性,包括4个预测属性(萼片长度、萼片宽度、花瓣长度、花瓣宽度)和1个类别属性(Iris-setosa、Iris-versicolor、Iris-virginica三种类别)。首先,需要将第五列类别信息转换为数字,再选择输入数据和标签。
2.多分类python代码(二分类可看做只有两类的多分类)
from sklearn import svm #引入svm包
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from sklearn.model_selection import train_test_split
#定义字典,将字符与数字对应起来
def Iris_label(s):
it={b'Iris-setosa':0, b'Iris-versicolor':1, b'Iris-virginica':2}
return it[s]
#读取数据,利用np.loadtxt()读取text中的数据
path='iris.data' #将下载的原始数据放到项目文件夹,即可不用写路径
data= np.loadtxt(path, dtype=float, delimiter=',', converters={4:Iris_label}) #分隔符为‘,'
#确定输入和输出
x,y=np.split(data,(4,),axis=1) #将data按前4列返回给x作为输入,最后1列给y作为标签值
x=x[:,0:2] #取x的前2列作为svm的输入,为了便于可视化展示
#划分数据集和标签:利用sklearn中的train_test_split对原始数据集进行划分,