文章目录
1 搭建环境
1.1 下载 libtorch c++ 版本
由于我当前在虚拟机中运行，只有 CPU、没有 GPU，因此下载 CPU 版本的 libtorch（cxx11 ABI）。
wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcpu.zip
unzip libtorch-cxx11-abi-shared-with-deps-2.0.0+cpu.zip
1.2 下载 MNIST 手写数字数据集
注意：main.cpp 从相对于程序运行目录的 data/MNIST/raw 读取数据，下载并解压后的四个文件需放在该目录下。
wget https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
wget https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
wget https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
wget https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
gunzip t10k-images-idx3-ubyte.gz
gunzip t10k-labels-idx1-ubyte.gz
gunzip train-images-idx3-ubyte.gz
gunzip train-labels-idx1-ubyte.gz
1.3 目录结构
xhome@ubuntu:~/project_ai$ ls
data image libtorch mnist_demo
2 源码目录
xhome@ubuntu:~/project_ai/mnist_demo$ ls
build CMakeLists.txt src
xhome@ubuntu:~/project_ai/mnist_demo/src$ ls
dataset.h main.cpp model.cpp model.h predict.cpp
2.1 CMakeLists.txt文件
xhome@ubuntu:~/project_ai/mnist_demo$ cat CMakeLists.txt
# CMake < 3.5 compatibility is removed in CMake 4.x (and deprecated before
# that), so require a version that modern CMake still accepts.
cmake_minimum_required(VERSION 3.10)
project(mnist_demo)

# libtorch 2.x headers require C++17; C++14 is not sufficient.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Point CMake at the LibTorch install (change to your actual path).
#set(CMAKE_PREFIX_PATH "${CMAKE_CURRENT_SOURCE_DIR}/libtorch")
# ... or use an absolute path:
set(CMAKE_PREFIX_PATH "/home/xhome/project_ai/libtorch")

find_package(Torch REQUIRED)
find_package(OpenCV REQUIRED)

# Compiler flags exported by TorchConfig.cmake (presumably the libstdc++ ABI
# define); applied before the targets are declared.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

# Training program.
add_executable(mnist_demo
    src/main.cpp
    src/model.cpp
)

# Prediction program.
add_executable(mnist_predict
    src/predict.cpp
    src/model.cpp
)

# Both programs include the project headers from src/.
target_include_directories(mnist_demo PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src")
target_include_directories(mnist_predict PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src")

# Link LibTorch, OpenCV and pthread.
target_link_libraries(mnist_demo "${TORCH_LIBRARIES}" ${OpenCV_LIBS} pthread)
target_link_libraries(mnist_predict "${TORCH_LIBRARIES}" ${OpenCV_LIBS} pthread)
2.2 训练主程序 main.cpp
xhome@ubuntu:~/project_ai/mnist_demo/src$ cat main.cpp
#include <torch/torch.h>
#include "model.h"
#include "dataset.h"
#include <iostream>
#include <iomanip>
#include <chrono>
// Train `model` for one epoch over `data_loader`.
//
// For every batch: zero the gradients, run the forward pass, compute the NLL
// loss, back-propagate and step the optimizer. A progress line is refreshed
// in place (via '\r') every 100 batches, and a final summary line is printed
// once the loader is exhausted.
//
// @param model        network to train; put into training mode here.
// @param data_loader  iterable of batches exposing `.data` and `.target`.
// @param optimizer    optimizer that updates `model`'s parameters.
// @param epoch        epoch number, used only for display.
// @param dataset_size total sample count, used only as the progress
//                     denominator.
template<typename DataLoader>
void train(
    Net& model,
    DataLoader& data_loader,
    torch::optim::Optimizer& optimizer,
    size_t epoch,
    size_t dataset_size) {
    model.train();

    size_t batch_idx = 0;
    size_t processed = 0;    // samples consumed so far this epoch
    float last_loss = 0.0f;  // loss of the most recent batch (not an epoch mean)

    // Writes the progress line without a trailing newline or flush. The
    // percentage is guarded so a zero `dataset_size` cannot print inf/nan.
    auto print_progress = [&]() {
        const double percent =
            dataset_size == 0 ? 0.0 : 100.0 * processed / dataset_size;
        std::cout << "\rTrain Epoch: " << epoch
                  << " [" << processed << "/" << dataset_size
                  << " (" << std::fixed << std::setprecision(1) << percent
                  << "%)] Loss: " << std::fixed << std::setprecision(4)
                  << last_loss;
    };

    for (auto& batch : data_loader) {
        auto data = batch.data;
        auto target = batch.target;

        optimizer.zero_grad();
        auto output = model.forward(data);
        auto loss = torch::nll_loss(output, target);
        loss.backward();
        optimizer.step();

        processed += batch.data.size(0);
        last_loss = loss.template item<float>();

        if (batch_idx++ % 100 == 0) {
            print_progress();
            std::cout << std::flush;
        }
    }

    // Report the samples actually processed rather than assuming the loader
    // covered all of `dataset_size`.
    print_progress();
    std::cout << std::endl;
}
// Evaluate `model` on `data_loader` and print the average NLL loss and the
// accuracy (predictions matching `target`).
//
// @param model        network to evaluate; put into eval mode here.
// @param data_loader  iterable of batches exposing `.data` and `.target`.
// @param dataset_size number of samples covered by the loader; denominator
//                     for both the average loss and the accuracy.
template<typename DataLoader>
void test(
    Net& model,
    DataLoader& data_loader,
    size_t dataset_size) {
    model.eval();
    double test_loss = 0;
    // int64_t matches the item<int64_t>() values accumulated below, so the
    // running count never narrows.
    int64_t correct = 0;
    // Gradients are not needed during evaluation.
    torch::NoGradGuard no_grad;

    for (const auto& batch : data_loader) {
        auto data = batch.data;
        auto target = batch.target;
        auto output = model.forward(data);
        // Summed (not averaged) per batch, so dividing by dataset_size below
        // yields a per-sample mean even when the last batch is smaller.
        test_loss += torch::nll_loss(
            output,
            target,
            /*weight=*/{},
            torch::Reduction::Sum
        ).template item<float>();
        // Index of the maximum along dim 1, compared element-wise to target.
        auto pred = output.argmax(1);
        correct += pred.eq(target).sum().template item<int64_t>();
    }

    if (dataset_size == 0) {
        // Nothing to average over; avoid printing NaN from 0/0.
        std::cout << "Test set: no samples to evaluate.\n";
        return;
    }

    test_loss /= dataset_size;
    const double accuracy = 100.0 * correct / dataset_size;
    std::cout << "Test set: Average loss: " << std::fixed << std::setprecision(4)
              << test_loss << ", Accuracy: " << correct << "/"
              << dataset_size << " (" << std::fixed << std::setprecision(1)
              << accuracy << "%)\n";
}
int main() {
try {
// 设置随机种子
torch::manual_seed(1);
// 使用CPU
torch::Device device(torch::kCPU);
std::cout << "Training on CPU." << std::endl;
// 创建模型
auto model = std::make_shared<Net>();
model->to(device);
// 创建数据集
const std::string data_path = "data/MNIST/raw";
std::cout << "Loading training dataset..." << std::endl;
auto train_dataset = MNISTDataset(data_path, MNISTDataset::Mode::kTrain)
.map(torch::data::transforms::Stack<>());
std::cout << "Loading test dataset..." << std::endl;
auto test_dataset = MNISTDataset(data_path, MNISTDataset::Mode