【Python实现机器遗忘算法】复现2023年TNNLS期刊算法UNSIR
1 算法原理
Tarun A K, Chundawat V S, Mandal M, et al. Fast yet effective machine unlearning[J]. IEEE Transactions on Neural Networks and Learning Systems, 2023.
本文提出了一种名为 UNSIR(Unlearning with Single Pass Impair and Repair) 的机器遗忘框架,用于从深度神经网络中高效地卸载(遗忘)特定类别数据,同时保留模型对其他数据的性能。以下是算法的主要步骤:
1. 零隐私设置(Zero-Glance Privacy Setting)
- 假设:用户请求从已训练的模型中删除其数据(例如人脸图像),并且模型无法再访问这些数据,即使是为了权重调整。
- 目标:在不重新训练模型的情况下,使模型忘记特定类别的数据,同时保留对其他数据的性能。
2. 学习误差最大化噪声矩阵(Error-Maximizing Noise Matrix)
-
初始化:随机初始化噪声矩阵 N,其大小与模型输入相同。
-
优化目标:通过最大化模型对目标类别的损失函数来优化噪声矩阵 N。具体优化问题为:
argNminE(θ)=−L(f,y)+λ∥wnoise∥ argNminE(θ)=−L(f,y)+λ∥wnoise∥ argNminE(θ)=−L(f,y)+λ∥wnoise∥其中:
- L(f,y) 是针对要卸载的类别的分类损失函数。
- λ∥wnoise∥ 是正则化项,防止噪声值过大。
- 使用交叉熵损失函数 L 和 L2 归一化。
-
噪声矩阵的作用:生成的噪声矩阵 N 与要卸载的类别标签相关联,用于在后续步骤中破坏模型对这些类别的记忆。
3. 单次损伤与修复(Single Pass Impair and Repair)
- 损伤步骤(Impair Step):
- 操作:将噪声矩阵 N 与保留数据子集Dr结合,训练模型一个周期(epoch)。
- 目的:通过高学习率(例如 0.02)快速破坏模型对要卸载类别的权重。
- 结果:模型对要卸载类别的性能显著下降,同时对保留类别的性能也会受到一定影响。
- 修复步骤(Repair Step):
- 操作:仅使用保留数据子集 Dr再次训练模型一个周期(epoch),学习率较低(例如 0.01)。
- 目的:恢复模型对保留类别的性能,同时保持对要卸载类别的遗忘效果。
- 结果:最终模型在保留数据上保持较高的准确率,而在卸载数据上准确率接近于零。
2 Python代码实现
相关函数
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset,TensorDataset
from torch.amp import autocast, GradScaler
import numpy as np
import matplotlib.pyplot as plt
import os
import warnings
import random
from copy import deepcopy
random.seed(2024)
torch.manual_seed(2024)
np.random.seed(2024)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
warnings.filterwarnings("ignore")
MODEL_NAMES = "MLP"
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义三层全连接网络
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.fc1 = nn.Linear(28 * 28, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(-1, 28 * 28)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 加载MNIST数据集
def load_MNIST_data(batch_size,forgotten_classes,ratio):
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader