PyTorch 与 Spring AI 集成实战


一、前言

大多数深度学习模型仍由 Python 和 PyTorch 驱动,但越来越多的企业希望将这些模型嵌入到 Java 微服务中运行。

Spring AI 提供了灵活的方式,结合 RESTful 接口、容器部署、Tool Calling 和 Agent 架构,使 Java 与 PyTorch 模型之间的协作不再是梦。

本篇将带你完成:

  • PyTorch 模型部署为服务(REST API)
  • Spring AI 调用 PyTorch 模型进行问答、分类或推理
  • 实战示例:中文情感分析模型接入

二、发布PyTorch 模型

Java 无法直接运行 PyTorch 模型,但可以通过以下两种方式调用:

REST API 部署方式

也是本篇推荐使用 FastAPI 或 Flask 将模型包装为 HTTP 接口,第三节将重点介绍。

ONNX 转换方式

ONNX转换适用于通用模型,将模型转换为 ONNX 格式,用 JNI/ONNX Runtime 调用。详见《标准化模型格式ONNX介绍:打通AI模型从训练到部署的环节》

本篇我们将采用第一种REST API 部署方式:用 Python + FastAPI 部署 PyTorch 模型,由 Java 远程调用。


三、构建 PyTorch 服务端

模型保存

# train.py
import torch
model = MyModel()
... # 训练代码
torch.save(model.state_dict(), "sentiment_model.pt")

FastAPI 服务

# app.py
from fastapi import FastAPI, Request
import torch
import torch.nn.functional as F

app = FastAPI()
model = MyModel()
model.load_state_dict(torch.load("sentiment_model.pt"))
model.eval()

@app.post("/predict")
async def predict(request: Request):
    data = await request.json()
    text = data["text"]
    # TODO: text preprocessing & tokenizing
    with torch.no_grad():
        output = model(text)
        pred = F.softmax(output, dim=1).tolist()
    return {"result": pred}

四、Spring AI 调用 PyTorch 模型

使用 RestTemplate 访问

@RestController
public class InferenceController {

  @Autowired RestTemplate restTemplate;

  @PostMapping("/ai/sentiment")
  public String classify(@RequestBody String text) {
    HttpHeaders headers = new HttpHeaders();
    headers.setContentType(MediaType.APPLICATION_JSON);

    Map<String, String> body = Map.of("text", text);
    HttpEntity<Map<String, String>> req = new HttpEntity<>(body, headers);

    String url = "http://localhost:8000/predict";
    ResponseEntity<String> resp = restTemplate.postForEntity(url, req, String.class);
    return resp.getBody();
  }
}

在 Tool Calling 中集成

@AiFunction(name = "sentiment")
public String analyzeSentiment(@AiParam("text") String text) {
    return classify(text);
}

注册为 Spring AI 工具:

List<ToolSpecification> tools = FunctionCallingTools.fromBeans(appContext);
chatClient = new FunctionCallingChatClient(chatClient, tools);

现在,模型就可以被 LLM 调用啦!


五、实战演练:智能客服识别情绪

用户提问

“你们的服务真的太烂了,我再也不会买了!”

PromptTemplate 示例:

String prompt = "请判断以下内容的用户情绪类别(积极、消极、中性):{{text}}";

LLM 返回:消极

Spring AI 调用 PyTorch 接口进行二次验证:

String pyResult = analyzeSentiment(userText);

可用于模型投票融合、异常拦截等场景。


六、总结

通过本文,我们完成了从 PyTorch 模型训练、FastAPI 部署,到 Spring AI 调用推理的完整闭环。

Spring AI 可以将自研模型作为 Tool,嵌入智能 Agent 流程中,与大语言模型协同。

七、参考

《Java驱动AI革命:Spring AI八篇进阶指南——从架构基础到企业级智能系统实战》

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

勤奋的知更鸟

你的鼓励将是我创作的最大动力!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值