之前听说某大厂算法面试居然要求现场手搓感知机,令人震惊。今天我们就来使用python进行手搓感知机。
1. 本文假设你收悉python语言和numpy库,并且了解感知机的原理。
2.你可以将感知机简单地理解为:
对于输入样本,感知机是一个参数为w,b的超平面,通过训练(学习)更新w,b进而将样本进而达到需要的效果(比如将样本进行正确的二分类问题)。另外需要注意的是,感知机只能解决线性可分的问题,而不能解决线性不可分的问题,所以本文将以线性可分的二分类问题为例,进行讲解。
3.数据准备
我们的目的是生成线性可分的训练数据,代码如下:
import numpy as np
# 生成训练数据
x1=np.random.randint(0,5,(50,2))
y1=np.array([-1]*50).reshape(-1,1)
x1=np.concatenate((x1,y1),axis=1)
x2=np.random.randint(7,15,(50,2))
y2=np.array([1]*50).reshape(-1,1)
x2=np.concatenate((x2,y2),axis=1)
x3=np.concatenate([x1,x2],axis=0)
# 打乱顺序
index=np.random.permutation(100)
train_x=x3[index]
其中,训练数据train_x的前两列是特征,第三列是标签取值为±1.
将这些点可视化:
很明显,生成的训练数据是线性可分的。
4.感知机的学习
感知机的学习策略为:
- 选取所有被错分类的点到直线的距离之和作为损失函数,即:
L=1wΣxi∈Myi(wxi+b)L=\frac{1}{w}\Sigma_{x_i\in{M}}{y_i}(w{x_i}+b)L=w1Σxi∈Myi(wxi+b)
其中x∈Mx\in Mx∈M的M表示被误分类的点的集合。
某一状态时的w为常数,所以可简化为:
L=Σxi∈Myi(wxi+b)L=\Sigma_{x_i\in{M}}{y_i}(w{x_i}+b)L=Σxi∈Myi(wxi+b) - 采用梯度下降法,损失函数对w,b的梯度分别为:
▽wL(w,b)=Σxi∈Myixi\bigtriangledown{_w}L(w,b)=\Sigma_{x_i\in M}y_i x_i▽wL(w,b)=Σxi∈Myixi
▽bL(w,b)=Σxi∈Mxi\bigtriangledown{_b}L(w,b)=\Sigma_{x_i\in M} x_i▽bL(w,b)=Σxi∈Mxi - 每当训练数据集中有被误分类的点时,对于点(xi,yi)(x_i ,y_i)(xi,yi),则对权重进行下面的更新:
wt=wt−1−ηyixiw_t=w_{t-1}-\eta y_i x_iwt=wt−1−ηyixi
bt=bt−1−ηyib_t=b_{t-1}-\eta y_ibt=bt−1−ηyi - 直到训练数据集中没有被误分类的点为止。
5.伪代码
初始化 w,b
while 训练数据集中有被误分类的点 do:
if 某个点被误分类:
更新权重
end
6. python语言实现
def anyerror(train_data,w,b):
"""检测是否存在被错误分类的样本"""
for x in train_data:
if x[2]*(x[:2].dot(w)+b)<=0:
return True
return False
def train_per(train_data,ita=ita):
"""训练感知机"""
w=np.random.rand(2,1)
b=np.random.rand(1)
print(w,b)
while anyerror(train_data,w,b):
# 如果有错误分类的样本存在,就继续迭代,但是如果线性不可分的话,就会陷入死循环,即算法不收敛。
for x in train_data:
if x[2]*(x[:2].dot(w)+b)<=0:
w+=ita*(x[:2]*x[2]).reshape(2,1) # 注意这里,减去梯度,则负负得正变成了+
b+=ita*x[2]
return w,b
w1,b1=train_per(train_x)
print(w1,b1)
我们知道生成数据集时的中间变量x3的前50个数据为标签-1,后50个标签为1,下面来测试一下是否学习成功:
for x in x3[:50]:
print(test(x[:2]))
输出:
Output exceeds the size limit. Open the full output data in a text editor
[-0.04393423]
[-0.11252145]
[-0.14321719]
[-0.1667172]
[-0.10292714]
[-0.1667172]
[-0.13842004]
[-0.07223139]
[-0.07462997]
[-0.14081861]
[-0.14321719]
[-0.16911578]
[-0.11252145]
[-0.17151436]
[-0.10772429]
[-0.07702855]
[-0.05112996]
[-0.17151436]
[-0.0818257]
[-0.10532571]
[-0.16911578]
[-0.14321719]
[-0.11252145]
[-0.04873138]
[-0.13842004]
...
[-0.10292714]
[-0.07223139]
[-0.13602146]
[-0.04873138]
可见,输出全部为负。
再来看一下后面50个标签为1的数据:
for x in x3[:50]:
print(test(x[:2]))
输出:
Output exceeds the size limit. Open the full output data in a text editor
[0.27261752]
[0.28700899]
[0.28700899]
[0.21122603]
[0.08844306]
[0.19732033]
[0.28700899]
[0.16182743]
[0.06254447]
[0.18292887]
[0.1052331]
[0.16182743]
[0.19492176]
[0.21122603]
[0.12873311]
[0.10283453]
[0.24192177]
[0.1052331]
[0.10283453]
[0.18772602]
[0.15703028]
[0.06014589]
[0.12153738]
[0.28221183]
[0.16422601]
...
[0.24192177]
[0.14983454]
[0.07213878]
[0.21842177]
全部为正的。
由此可见,感知机的学习是成功的。
我们只需要定义下面这个函数,串联在感知机后面,即可输出为-1或者1:
def out(x):
if x<=0:
return -1
else:
return 1
手搓感知机不易,手搓感知机教程更不易,求个赞~
by——神采的二舅