Went camping by the sea over the National Day holiday; the scenery was beautiful~
Lately I've been trying to run model inference directly in Java, to avoid using Python in production. Part of the motivation is performance and multithreading; the other part is keeping the production environment simple. When a team adopts microservices, team size is a major factor in deciding how fine-grained the services should be, so for a small team, being able to call a model directly from Java greatly reduces the operational burden: no separate inference service to canary-release, and no version mismatches between services to worry about.
What is ONNX?
ONNX (Open Neural Network Exchange) is an open format for exchanging neural network models, designed to enable interoperability between deep learning frameworks. It was introduced by Microsoft and Facebook in 2017 and is supported by many popular frameworks, such as PyTorch, TensorFlow, and Caffe2.
With ONNX, you can export a model from one framework and import it into another for inference or training. This flexibility makes it easy to move models between platforms and tools, improving both development efficiency and deployment flexibility.
ONNX also supports hardware-accelerator optimizations, so models can make more efficient use of resources when running on different hardware.
Take another look at PhotoRoom's share on this topic.
Before starting, two things need to be prepared:
- An ONNX model
- A description file specifying the input format the model expects, used to correctly convert input data into tensors. It can be a JSON file like this:
{
  "user_id": {
    "name": "user_id",
    "max": 56059,
    "dtype": "int64"
  },
  "colors": {
    "name": "colors",
    "max": 16896,
    "max_length": 5,
    "dtype": "int64_list",
    "separator": "|"
  },
  ...
}
Also prepare a test dataset, for example a CSV file in which each row holds the parameters for one inference.
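For instance, with the features.json above, the first few rows could look like this (the values are hypothetical, just to show the shape of the data):
user_id,colors
10023,3|17|256
10024,42|9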
With the model, the description file, and the test data in place, we can get started.
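The code below assumes three libraries on the classpath: the ONNX Runtime Java API, OpenCSV, and Jackson. A minimal Maven sketch (the version numbers are illustrative; check for current releases):
<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.16.3</version>
</dependency>
<dependency>
    <groupId>com.opencsv</groupId>
    <artifactId>opencsv</artifactId>
    <version>5.9</version>
</dependency>
<dependency>
    <groupId>com.fasterxml.jackson.core</groupId>
    <artifactId>jackson-databind</artifactId>
    <version>2.15.2</version>
</dependency>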
Create an inference class:
public class OnnxInference {
    private static final String MODEL_PATH = "/path/to/onnx/file";
    private static final String CSV_PATH = "/path/to/csv";
    private static final String JSON_PATH = "/path/to/features.json";
    private static final int BATCH_SIZE = 10;

    private static JsonNode featureConfig;

    private static void loadFeatureConfig() throws IOException {
        ObjectMapper mapper = new ObjectMapper();
        featureConfig = mapper.readTree(new File(JSON_PATH));
    }
}
featureConfig holds the input format the model expects, and BATCH_SIZE is set aside for batched inference to improve throughput; here we assume 10 rows per batch.
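As a sanity check, the expected inputs can also be read out of the model file itself. Here is a small sketch using the ONNX Runtime Java API that prints each declared input, which is handy for verifying that features.json matches the model:
// Sketch: print each model input name with its declared type info,
// to cross-check against features.json.
try (OrtEnvironment env = OrtEnvironment.getEnvironment();
     OrtSession session = env.createSession(MODEL_PATH)) {
    for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
        System.out.println(entry.getKey() + " -> " + entry.getValue().getInfo());
    }
}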
Next, a function that reads the CSV file and runs inference in batches:
private static void readCsvAndPredict(String csvPath, OrtSession session, OrtEnvironment env) throws IOException, CsvValidationException, OrtException {
    try (CSVReader reader = new CSVReader(new FileReader(csvPath))) {
        String[] headers = reader.readNext();
        List<Map<String, String>> batch = new ArrayList<>();
        String[] line;
        while ((line = reader.readNext()) != null) {
            Map<String, String> row = new HashMap<>();
            for (int i = 0; i < headers.length; i++) {
                row.put(headers[i], line[i]);
            }
            batch.add(row);
            if (batch.size() == BATCH_SIZE) {
                processBatch(batch, session, env);
                batch.clear();
            }
        }
        // Flush the final, possibly smaller batch.
        if (!batch.isEmpty()) {
            processBatch(batch, session, env);
        }
    }
}
Then comes the most important part, the inference function:
private static void processBatch(List<Map<String, String>> batch, OrtSession session, OrtEnvironment env) throws OrtException {
    Map<String, OnnxTensor> inputs = new HashMap<>();
    try {
        // Build one input tensor per feature described in features.json.
        for (Iterator<String> it = featureConfig.fieldNames(); it.hasNext(); ) {
            String featureName = it.next();
            JsonNode feature = featureConfig.get(featureName);
            String dtype = feature.get("dtype").asText();
            int maxLength = feature.has("max_length") ? feature.get("max_length").asInt() : 1;
            if (dtype.startsWith("int64")) {
                long[][] featureData = new long[batch.size()][maxLength];
                fillLongFeatureData(batch, featureName, feature, dtype, maxLength, featureData);
                inputs.put(featureName, OnnxTensor.createTensor(env, featureData));
            } else if (dtype.startsWith("float32")) {
                float[][] featureData = new float[batch.size()][maxLength];
                fillFloatFeatureData(batch, featureName, feature, dtype, maxLength, featureData);
                inputs.put(featureName, OnnxTensor.createTensor(env, featureData));
            } else {
                System.out.println("Warning: unknown dtype " + dtype + " for feature " + featureName);
            }
        }
        try (OrtSession.Result result = session.run(inputs)) {
            // Fetch the predictions (assumed here to be the third output).
            OnnxValue outputValue = result.get(2);
            System.out.println("+----------------------+--------------+");
            System.out.println("| User ID              |   Prediction |");
            System.out.println("+======================+==============+");
            if (outputValue instanceof OnnxTensor) {
                OnnxTensor outputTensor = (OnnxTensor) outputValue;
                float[][] predictions = (float[][]) outputTensor.getValue();
                for (int i = 0; i < batch.size(); i++) {
                    String userId = batch.get(i).get("user_id");
                    float prediction = predictions[i][0];
                    System.out.printf("| %-20s | %12.8f |%n", userId, prediction);
                }
            } else {
                System.out.println("Warning: output is not the expected tensor type");
            }
            System.out.println("+----------------------+--------------+");
            System.out.println();
        }
    } finally {
        // OnnxTensor wraps native memory; close every tensor we created.
        for (OnnxTensor tensor : inputs.values()) {
            tensor.close();
        }
    }
}
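Note the finally block, which I added here: OnnxTensor wraps native memory, so each input tensor is closed explicitly once the run completes; relying on garbage collection alone can exhaust native memory under sustained load.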
First the inputs are prepared according to each feature's dtype. Because ONNX does not yet support ragged tensors, shorter sequences have to be padded up to max_length:
private static void fillLongFeatureData(List<Map<String, String>> batch, String featureName, JsonNode feature, String dtype, int maxLength, long[][] featureData) {
    for (int i = 0; i < batch.size(); i++) {
        String value = batch.get(i).get(featureName);
        if (value == null || value.isEmpty()) {
            Arrays.fill(featureData[i], 0L);
            continue;
        }
        if (dtype.endsWith("_list")) {
            // Split the list on its separator; positions beyond the
            // sequence keep their default 0, which acts as padding.
            String separator = feature.get("separator").asText();
            String[] sequence = value.split(Pattern.quote(separator));
            for (int j = 0; j < Math.min(sequence.length, maxLength); j++) {
                featureData[i][j] = Long.parseLong(sequence[j]);
            }
        } else {
            featureData[i][0] = Long.parseLong(value);
        }
    }
}
private static void fillFloatFeatureData(List<Map<String, String>> batch, String featureName, JsonNode feature, String dtype, int maxLength, float[][] featureData) {
    for (int i = 0; i < batch.size(); i++) {
        String value = batch.get(i).get(featureName);
        if (value == null || value.isEmpty()) {
            System.out.println("Warning: feature " + featureName + " is missing or empty in this batch");
            Arrays.fill(featureData[i], 0.0f);
            continue;
        }
        if (dtype.endsWith("_list")) {
            String separator = feature.get("separator").asText();
            String[] sequence = value.split(Pattern.quote(separator));
            for (int j = 0; j < Math.min(sequence.length, maxLength); j++) {
                featureData[i][j] = Float.parseFloat(sequence[j]);
            }
        } else {
            featureData[i][0] = Float.parseFloat(value);
        }
    }
}
With that, OrtSession can produce the results. I cut a corner in handling the output: normally you would not address outputs by index, yet here I only print the third value. Done properly, there would be a description file covering both the inputs and the outputs, and the inference SDK would be a generic one that takes a description file plus a model file and handles the rest.
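As a sketch of that direction: assuming the description file gained an outputs section with an entry named prediction (a hypothetical extension, not part of the features.json above, held here in a JsonNode called ioConfig), the output could be resolved by name rather than by index:
// Hypothetical: look the output tensor up by the name declared in an
// extended description file, instead of hard-coding result.get(2).
String outputName = ioConfig.get("outputs").get("prediction").get("name").asText();
Optional<OnnxValue> output = result.get(outputName);
output.ifPresent(value -> {
    // handle the tensor exactly as in processBatch above
});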
Running it prints a table of user IDs and predictions in the format built in processBatch above.
Finally, the complete code:
import ai.onnxruntime.*;
import com.opencsv.CSVReader;
import com.opencsv.exceptions.CsvValidationException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.JsonNode;

import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.*;
import java.util.regex.Pattern;

public class OnnxInference {
    private static final String MODEL_PATH = "/path/to/onnx/file";
    private static final String CSV_PATH = "/path/to/csv";
    private static final String JSON_PATH = "/path/to/features.json";
    private static final int BATCH_SIZE = 10;

    private static JsonNode featureConfig;

    public static void main(String[] args) {
        try {
            loadFeatureConfig();
            try (OrtEnvironment env = OrtEnvironment.getEnvironment();
                 OrtSession session = env.createSession(MODEL_PATH)) {
                readCsvAndPredict(CSV_PATH, session, env);
            } catch (OrtException | IOException | CsvValidationException e) {
                e.printStackTrace();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    // Load the JSON description of the model's expected inputs.
    private static void loadFeatureConfig() throws IOException {
        ObjectMapper mapper = new ObjectMapper();
        featureConfig = mapper.readTree(new File(JSON_PATH));
    }

    // Read the CSV row by row and run inference in batches of BATCH_SIZE.
    private static void readCsvAndPredict(String csvPath, OrtSession session, OrtEnvironment env) throws IOException, CsvValidationException, OrtException {
        try (CSVReader reader = new CSVReader(new FileReader(csvPath))) {
            String[] headers = reader.readNext();
            List<Map<String, String>> batch = new ArrayList<>();
            String[] line;
            while ((line = reader.readNext()) != null) {
                Map<String, String> row = new HashMap<>();
                for (int i = 0; i < headers.length; i++) {
                    row.put(headers[i], line[i]);
                }
                batch.add(row);
                if (batch.size() == BATCH_SIZE) {
                    processBatch(batch, session, env);
                    batch.clear();
                }
            }
            // Flush the final, possibly smaller batch.
            if (!batch.isEmpty()) {
                processBatch(batch, session, env);
            }
        }
    }

    // Convert one batch of rows into input tensors, run the session,
    // and print the predictions as a table.
    private static void processBatch(List<Map<String, String>> batch, OrtSession session, OrtEnvironment env) throws OrtException {
        Map<String, OnnxTensor> inputs = new HashMap<>();
        try {
            for (Iterator<String> it = featureConfig.fieldNames(); it.hasNext(); ) {
                String featureName = it.next();
                JsonNode feature = featureConfig.get(featureName);
                String dtype = feature.get("dtype").asText();
                int maxLength = feature.has("max_length") ? feature.get("max_length").asInt() : 1;
                if (dtype.startsWith("int64")) {
                    long[][] featureData = new long[batch.size()][maxLength];
                    fillLongFeatureData(batch, featureName, feature, dtype, maxLength, featureData);
                    inputs.put(featureName, OnnxTensor.createTensor(env, featureData));
                } else if (dtype.startsWith("float32")) {
                    float[][] featureData = new float[batch.size()][maxLength];
                    fillFloatFeatureData(batch, featureName, feature, dtype, maxLength, featureData);
                    inputs.put(featureName, OnnxTensor.createTensor(env, featureData));
                } else {
                    System.out.println("Warning: unknown dtype " + dtype + " for feature " + featureName);
                }
            }
            try (OrtSession.Result result = session.run(inputs)) {
                // Fetch the predictions (assumed here to be the third output).
                OnnxValue outputValue = result.get(2);
                System.out.println("+----------------------+--------------+");
                System.out.println("| User ID              |   Prediction |");
                System.out.println("+======================+==============+");
                if (outputValue instanceof OnnxTensor) {
                    OnnxTensor outputTensor = (OnnxTensor) outputValue;
                    float[][] predictions = (float[][]) outputTensor.getValue();
                    for (int i = 0; i < batch.size(); i++) {
                        String userId = batch.get(i).get("user_id");
                        float prediction = predictions[i][0];
                        System.out.printf("| %-20s | %12.8f |%n", userId, prediction);
                    }
                } else {
                    System.out.println("Warning: output is not the expected tensor type");
                }
                System.out.println("+----------------------+--------------+");
                System.out.println();
            }
        } finally {
            // OnnxTensor wraps native memory; close every tensor we created.
            for (OnnxTensor tensor : inputs.values()) {
                tensor.close();
            }
        }
    }

    // Fill an int64 feature, padding unused positions with 0.
    private static void fillLongFeatureData(List<Map<String, String>> batch, String featureName, JsonNode feature, String dtype, int maxLength, long[][] featureData) {
        for (int i = 0; i < batch.size(); i++) {
            String value = batch.get(i).get(featureName);
            if (value == null || value.isEmpty()) {
                Arrays.fill(featureData[i], 0L);
                continue;
            }
            if (dtype.endsWith("_list")) {
                String separator = feature.get("separator").asText();
                String[] sequence = value.split(Pattern.quote(separator));
                for (int j = 0; j < Math.min(sequence.length, maxLength); j++) {
                    featureData[i][j] = Long.parseLong(sequence[j]);
                }
            } else {
                featureData[i][0] = Long.parseLong(value);
            }
        }
    }

    // Fill a float32 feature, padding unused positions with 0.0f.
    private static void fillFloatFeatureData(List<Map<String, String>> batch, String featureName, JsonNode feature, String dtype, int maxLength, float[][] featureData) {
        for (int i = 0; i < batch.size(); i++) {
            String value = batch.get(i).get(featureName);
            if (value == null || value.isEmpty()) {
                System.out.println("Warning: feature " + featureName + " is missing or empty in this batch");
                Arrays.fill(featureData[i], 0.0f);
                continue;
            }
            if (dtype.endsWith("_list")) {
                String separator = feature.get("separator").asText();
                String[] sequence = value.split(Pattern.quote(separator));
                for (int j = 0; j < Math.min(sequence.length, maxLength); j++) {
                    featureData[i][j] = Float.parseFloat(sequence[j]);
                }
            } else {
                featureData[i][0] = Float.parseFloat(value);
            }
        }
    }
}