手写感知机(基于python)

本文介绍了如何使用Python从头实现感知机算法,通过生成线性可分的二分类数据集,运用梯度下降法更新权重和偏置,直至所有样本正确分类,证明了感知机在解决线性可分问题上的有效性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

之前听说某大厂算法面试居然要求现场手搓感知机,令人震惊。今天我们就来使用python进行手搓感知机。

1. 本文假设你收悉python语言和numpy库,并且了解感知机的原理。
2.你可以将感知机简单地理解为:

对于输入样本,感知机是一个参数为w,b的超平面,通过训练(学习)更新w,b进而将样本进而达到需要的效果(比如将样本进行正确的二分类问题)。另外需要注意的是,感知机只能解决线性可分的问题,而不能解决线性不可分的问题,所以本文将以线性可分的二分类问题为例,进行讲解。

3.数据准备

我们的目的是生成线性可分的训练数据,代码如下:

import numpy as np
# 生成训练数据
x1=np.random.randint(0,5,(50,2))
y1=np.array([-1]*50).reshape(-1,1)
x1=np.concatenate((x1,y1),axis=1)
x2=np.random.randint(7,15,(50,2))
y2=np.array([1]*50).reshape(-1,1)
x2=np.concatenate((x2,y2),axis=1)
x3=np.concatenate([x1,x2],axis=0)

# 打乱顺序
index=np.random.permutation(100)
train_x=x3[index]

其中,训练数据train_x的前两列是特征,第三列是标签取值为±1.
将这些点可视化:
可视化出的训练数据
很明显,生成的训练数据是线性可分的。

4.感知机的学习

感知机的学习策略为:

  • 选取所有被错分类的点到直线的距离之和作为损失函数,即:
    L=1wΣxi∈Myi(wxi+b)L=\frac{1}{w}\Sigma_{x_i\in{M}}{y_i}(w{x_i}+b)L=w1ΣxiMyi(wxi+b)
    其中x∈Mx\in MxM的M表示被误分类的点的集合。
    某一状态时的w为常数,所以可简化为:
    L=Σxi∈Myi(wxi+b)L=\Sigma_{x_i\in{M}}{y_i}(w{x_i}+b)L=ΣxiMyi(wxi+b)
  • 采用梯度下降法,损失函数对w,b的梯度分别为:
    ▽wL(w,b)=Σxi∈Myixi\bigtriangledown{_w}L(w,b)=\Sigma_{x_i\in M}y_i x_iwL(w,b)=ΣxiMyixi
    ▽bL(w,b)=Σxi∈Mxi\bigtriangledown{_b}L(w,b)=\Sigma_{x_i\in M} x_ibL(w,b)=ΣxiMxi
  • 每当训练数据集中有被误分类的点时,对于点(xi,yi)(x_i ,y_i)(xi,yi),则对权重进行下面的更新:
    wt=wt−1−ηyixiw_t=w_{t-1}-\eta y_i x_iwt=wt1ηyixi
    bt=bt−1−ηyib_t=b_{t-1}-\eta y_ibt=bt1ηyi
  • 直到训练数据集中没有被误分类的点为止。
5.伪代码
初始化 w,b
while 训练数据集中有被误分类的点 do:
	if 某个点被误分类:
		更新权重
end
6. python语言实现
def anyerror(train_data,w,b):
    """检测是否存在被错误分类的样本"""
    for x in train_data:
        if x[2]*(x[:2].dot(w)+b)<=0:
            return True
    return False

def train_per(train_data,ita=ita):
    """训练感知机"""
    w=np.random.rand(2,1)
    b=np.random.rand(1)
    print(w,b)
    
    while anyerror(train_data,w,b):
        # 如果有错误分类的样本存在,就继续迭代,但是如果线性不可分的话,就会陷入死循环,即算法不收敛。
        for x in train_data:
            if x[2]*(x[:2].dot(w)+b)<=0:
                w+=ita*(x[:2]*x[2]).reshape(2,1)    # 注意这里,减去梯度,则负负得正变成了+
                b+=ita*x[2]
    return w,b

w1,b1=train_per(train_x)
print(w1,b1)

我们知道生成数据集时的中间变量x3的前50个数据为标签-1,后50个标签为1,下面来测试一下是否学习成功:

for x in x3[:50]:
    print(test(x[:2]))

输出:

Output exceeds the size limit. Open the full output data in a text editor
[-0.04393423]
[-0.11252145]
[-0.14321719]
[-0.1667172]
[-0.10292714]
[-0.1667172]
[-0.13842004]
[-0.07223139]
[-0.07462997]
[-0.14081861]
[-0.14321719]
[-0.16911578]
[-0.11252145]
[-0.17151436]
[-0.10772429]
[-0.07702855]
[-0.05112996]
[-0.17151436]
[-0.0818257]
[-0.10532571]
[-0.16911578]
[-0.14321719]
[-0.11252145]
[-0.04873138]
[-0.13842004]
...
[-0.10292714]
[-0.07223139]
[-0.13602146]
[-0.04873138]

可见,输出全部为负。
再来看一下后面50个标签为1的数据:

for x in x3[:50]:
    print(test(x[:2]))

输出:

Output exceeds the size limit. Open the full output data in a text editor
[0.27261752]
[0.28700899]
[0.28700899]
[0.21122603]
[0.08844306]
[0.19732033]
[0.28700899]
[0.16182743]
[0.06254447]
[0.18292887]
[0.1052331]
[0.16182743]
[0.19492176]
[0.21122603]
[0.12873311]
[0.10283453]
[0.24192177]
[0.1052331]
[0.10283453]
[0.18772602]
[0.15703028]
[0.06014589]
[0.12153738]
[0.28221183]
[0.16422601]
...
[0.24192177]
[0.14983454]
[0.07213878]
[0.21842177]

全部为正的。
由此可见,感知机的学习是成功的。
我们只需要定义下面这个函数,串联在感知机后面,即可输出为-1或者1:

def out(x):
	if x<=0:
		return -1
	else:
		return 1
手搓感知机不易,手搓感知机教程更不易,求个赞~

请添加图片描述
by——神采的二舅

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值