小样本学习之原型网络

本文探讨了在小样本分类问题中如何通过原型网络利用度量空间来减少过拟合。核心内容包括样本投影方法、同类样本聚类原理,以及如何计算未标注样本与类别的距离。讨论了距离度量的合理性及原型网络在解决样本稀疏问题上的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

参考链接:https://blog.csdn.net/Snoopy_S/article/details/88420013

在小样本分类问题中,最需要解决的一个问题是数据的过拟合,由于数据过少,一般的分类算法会表现出过拟合的现象,从而导致分类结果与实际结果有较大的误差。为了减少因数据量过少而导致的过拟合的影响,可以使用基于度量的元学习方法,而原型网络便是。在此方法中,需要将样本投影到一个度量空间,且在这个空间中同类样本距离较近,异类样本的距离较远,如图:

在这个投影空间中,存在有三个类别的样本,且相同类别的样本间距离较近。为了给一个未标注样本x进行标注,则将样本x投影至这个空间并计算x距哪个类别较近,则认为x属于哪个类别。

那么,现在有几个问题:

  1. 怎么将这些样本投影至一个空间且让同类样本间距离较近?
  2. 怎么说明一个类别所在的位置?从而能够让未标记的样本计算与类别的距离

原型网络的原理较为简单,但是有一点小问题就是,对于两个或多个样本的相似度,用距离较近来度量是否合理。

 

 

### 小样本学习中的原型网络 #### 1. 原型网络简介 原型网络是一种用于小样本学习的有效方法,其核心思想是通过计算类别的原型向量来完成分类任务。具体而言,在训练阶段,算法会从支持集(support set)中提取每种类别的特征表示,并将其平均化得到该类别的原型向量[^4]。 对于测试阶段的新样本,模型会计算它与各个类别原型之间的距离(通常采用欧氏距离),并将新样本分配给最近的原型所属的类别。 #### 2. 数学描述 假设我们有 $ N $ 类数据,每一类的数据集合为 $ S_i = \{x_{i,j}\}_{j=1}^{k} $ ,其中 $ k $ 表示每个类别的样例数。那么第 $ i $ 类的原型可以定义为: $$ c_i = \frac{1}{|S_i|} \sum_{x_j \in S_i} f(x_j) $$ 这里 $ f(\cdot) $ 是一个嵌入函数,通常是神经网络结构的一部分。最终,对于一个新的输入样本 $ x_q $,它的预测标签可以通过以下方式获得: $$ \hat{y}_q = \arg\min_i d(f(x_q), c_i) $$ 这里的 $ d(\cdot, \cdot) $ 可以选择不同的度量标准,比如欧式距离或者余弦相似度。 #### 3. 改进方向——原型补全网络 (ProtoComNet) 尽管传统的原型网络已经取得了不错的成果,但在某些情况下可能仍然存在不足之处。例如当某一类别的样本数量非常少时,仅依靠这些有限样本计算出来的原型可能会不够准确。因此,《Prototype Completion with Primitive Knowledge for Few-Shot Learning》提出了原型补全网络的概念[^3]。 此方法引入了一个额外的知识源——基础知识模块,用来辅助构建更鲁棒的类别表征。通过这种方式,即使面对极端稀疏的情况也能有效提升性能表现。 ```python import torch from torch import nn class ProtoNet(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super(ProtoNet, self).__init__() self.encoder = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, output_dim) ) def forward(self, support_set, query_set): # Extract features from the encoder network. proto_features = self.encoder(support_set).mean(dim=0) # Compute prototype by averaging embeddings of support samples. # Calculate distances between each query sample and prototypes. queries_embedded = self.encoder(query_set) dists = torch.cdist(proto_features.unsqueeze(0), queries_embedded) return -dists.squeeze() # Negative distance as logits. # Example usage: input_size = 784 # For MNIST-like data sets. hidden_layer_size = 512 output_embedding_size = 64 model = ProtoNet(input_size, hidden_layer_size, output_embedding_size) ``` 上述代码片段展示了一个简单的基于PyTorch框架实现的原型网络架构实例。 --- ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值