Java Spring Boot 使用DJL 部署python训练的PyTorch模型(MNIST)

本文介绍了如何在Java Spring Boot应用中使用DJL库部署和调用Python训练的PyTorch模型,特别是MNIST手写数字识别。讲解了Translator、Criteria和NDArray的概念及用法,包括模型输入输出的转换、数据类型的变更、Tensor运算、切片、赋值和翻转等操作,并展示了MNIST实战的结果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Java 使用 DJL 训练模型:https://blog.csdn.net/xundh/category_11361043.html?spm=1001.2014.3001.5515

DJL官网:https://docs.djl.ai/index.html

Python 训练Pytorch模型

本项目采用PyTorch==1.10.0版本训练。

pytorch                   1.10.0          py3.9_cuda11.3_cudnn8_0    pytorch
pytorch-mutex             1.0                        cuda    pytorch
requests                  2.28.1                   pypi_0    pypi
scipy                     1.9.3                    pypi_0    pypi
setuptools                65.6.3             pyhd8ed1ab_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
six                       1.16.0                   pypi_0    pypi
tbb                       2021.7.0             h91493d7_1    conda-forge
tk                        8.6.12               h8ffe710_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
torchaudio                0.10.0               py39_cu113    pytorch
torchsummary              1.5.1                    pypi_0    pypi
torchvision               0.11.0               py39_cu113    pytorch

修改模型保存方法:

model.eval()  # 模型验证模式
example = torch.rand(1, 1, 28, 28).to(device)  # 模型输入层
traced_script_module = torch.jit.trace(model, example)  # trace
traced_script_module.save('models/{}_model.pt'.format(val_ac))  # 保存模型

Java Spring Boot使用DJL调用模型

pom.xml

<!-- djl 依赖 -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.19.0</version>
<scope>runtime</scope>
</dependency>

<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>1.12.1-0.19.0</version>
<scope>runtime</scope>
</dependency>

<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.19.0</version>
</dependency>

<dependency>
<groupId>ai.djl.opencv</groupId>
<artifactId>opencv</artifactId>
<version>0.19.0</version>
</dependency>

Translator

模型的输入类和输出类是可以自定义的,但Pytorch模型不可能接收你自己定义的类对象啊,它只会接受Tensor类型,所以我们就需要使用Translator接口来定义如何将我们的自定义输入输出类转换为Tensor类型。

private Translator
<think>好的,我现在需要解决用户在Java使用DJLPyTorch模型时遇到的ModelNotFoundException问题,特别是“no matching model input output type found”的错误。首先,我得回忆一下DJL的基本工作原理和模型载流程。 用户提到ModelNotFoundException,这通常意味着DJL在指定路径下找不到模型文件。首先应该检查模型文件的存放位置是否正确。DJL默认会在classpath的model目录下查找,所以用户需要确认模型文件是否放在项目的src/main/resources/model目录下,或者通过系统属性指定的其他路径。例如,是否设置了-Dai.djl.repository.zoo.location=path/to/model。 接下来,模型名称和扩展名是否正确?DJL支持的PyTorch模型扩展名通常是.pt或者.pth。用户需要确认文件名是否匹配,比如模型文件是否命名为model.pt,并且在代码中引用的名称是否一致。例如,如果模型文件是my_model.pt,代码中应该使用Criteria.builder().optModelName("my_model")。 然后,模型输入输出类型不匹配的问题。DJL需要明确指定输入输出的数据类型,比如Tensor.class。用户可能在模型时没有正确设置这些类型,导致DJL无法找到匹配的模型。需要检查Criteria中的setTypes方法是否正确调用了,例如.setTypes(Tensor.class, Tensor.class)。 另外,模型格式是否正确?有时候用户可能保存的是整个PyTorch模型(包含结构和参数),而DJL需要的是仅参数的state_dict。这时候需要用torch.save(model.state_dict(), "model.pt")来保存,否则载时会出错。 还有模型元数据的问题,DJL可能需要metadata.json文件来描述模型的输入输出。如果用户没有提供这个文件,或者文件内容不正确,也会导致错误。需要检查是否在模型目录下存在正确的metadata.json,并且其中的输入输出类型与代码中的设置一致。 框架版本兼容性也是一个可能的问题。用户使用DJL版本是否与PyTorch版本兼容?例如,DJL 0.20.0可能只支持PyTorch 1.12.1,如果用户安装了PyTorch 2.0,可能会有兼容性问题。需要确认两者的版本是否匹配。 最后,用户可能需要启用日志来查看更详细的错误信息。通过设置系统属性-Dorg.slf4j.simpleLogger.defaultLogLevel=debug,可以输出DJL的调试日志,帮助定位问题所在。 总结可能的解决步骤:检查模型路径、文件名和扩展名;确认输入输出类型设置正确;检查模型保存方式是否正确;验证metadata.json是否存在且正确;确认DJLPyTorch版本兼容;启用调试日志进一步排查。 现在需要将这些步骤整理成清晰的解决方案,并确保引用正确的文档或资源,比如DJL官方文档中的模型训练和保存部分,以及如何处理模型输入输出类型的问题。同时,根据用户提供的引用,可能还需要考虑向量存储相关的配置,但当前问题更直接相关的是模型载的配置,所以重点放在模型路径、格式和元数据上。</think>### Java DJLPyTorch模型时ModelNotFoundException解决方案 #### 1. 检查模型文件路径和命名规范 - **模型存放路径**:DJL默认从`classpath:model/`目录模型,需确保模型文件位于`src/main/resources/model`目录下。若自定义路径,需通过系统参数指定: ```java -Dai.djl.repository.zoo.location=path/to/model_directory ``` - **文件名要求**:PyTorch模型文件扩展名需为`.pt`或`.pth`,且文件名需与代码中引用的名称一致。例如模型文件为`resnet18.pt`,代码中应配置: ```java Criteria<Image, Classifications> criteria = Criteria.builder() .setTypes(Image.class, Classifications.class) .optModelName("resnet18") // 与文件名一致(不含扩展名) .build(); ``` #### 2. 验证输入输出类型匹配 DJL需显式定义输入输出类型,**类型不匹配是常见错误原因**: ```java // 错误示例:未指定Tensor类型 Criteria.builder().setTypes(Object.class, Object.class); // 正确示例:明确指定输入输出为Tensor Criteria<Tensor, Tensor> criteria = Criteria.builder() .setTypes(Tensor.class, Tensor.class) .optModelPath(Paths.get("model/path")) .optEngine("PyTorch") .build(); ``` #### 3. 检查模型保存格式 - **仅保存参数**:PyTorch模型需以`state_dict`格式保存: ```python torch.save(model.state_dict(), "model.pt") ``` - **完整模型问题**:若保存完整模型(含结构),需额外提供元数据描述输入输出格式[^2]。 #### 4. 添metadata.json文件 在模型目录中添`metadata.json`,示例内容: ```json { "engine": "PyTorch", "input": {"type": "Tensor", "shape": [1, 3, 224, 224]}, "output": {"type": "Tensor", "shape": [1, 1000]} } ``` #### 5. 版本兼容性验证 | DJL版本 | 支持PyTorch版本 | |---------|----------------| | 0.20.0 | 1.12.1 | | 0.22.1 | 2.0.1 | 通过`pom.xml`确保版本匹配: ```xml <dependency> <groupId>ai.djl.pytorch</groupId> <artifactId>pytorch-engine</artifactId> <version>0.22.1</version> </dependency> ``` #### 6. 调试日志分析 启动时添JVM参数: ```bash -Dorg.slf4j.simpleLogger.defaultLogLevel=debug ``` 日志会显示模型载过程: ``` DEBUG ai.djl.repository.zoo - Searching model in: file:/project/model/ DEBUG ai.djl.repository.zoo - Found metadata.json in: /project/model/ ``` #### 完整代码示例 ```java try (ZooModel<Tensor, Tensor> model = ModelZoo.loadModel(criteria)) { try (Predictor<Tensor, Tensor> predictor = model.newPredictor()) { Tensor input = Tensor.newInstance(new float[3*224*224], new Shape(1,3,224,224)); Tensor output = predictor.predict(input); } } ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值