CodingTour
Running ONNX Model Inference in Java

I went camping by the sea over the National Day holiday. The scenery was beautiful~

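I've recently been experimenting with running model inference directly in Java so that Python can stay out of production. Part of the motivation is performance and multithreading; the other part is reducing the complexity of the production environment. When a team adopts a microservice architecture, team size is one of the key factors in deciding how fine-grained the services should be. For a small team, calling the model directly from Java greatly reduces operational complexity, with no need to worry about issues such as canary releases or version mismatches between services.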

What is ONNX?

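ONNX (Open Neural Network Exchange) is an open format for exchanging neural network models, designed to promote interoperability between deep learning frameworks. It was introduced by Microsoft and Facebook in 2017 and is supported by many popular frameworks, such as PyTorch, TensorFlow, and Caffe2.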

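With ONNX, researchers can export a model from one framework and import it into another for inference or training. This makes it easier to move models between platforms and tools, which improves development efficiency and deployment flexibility.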

ONNX also supports optimizations for hardware accelerators, so a model can use resources more efficiently on different hardware.

It's also worth taking another look at what PhotoRoom has shared on this.

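Before getting started, you need to prepare two things:

  1. An ONNX model
  2. A description file that specifies the input format the model expects, so that input data can be converted into tensors correctly. It can be a JSON file: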
{
  "user_id": {
    "name": "user_id",
    "max": 56059,
    "dtype": "int64"
  },
  "colors": {
    "name": "colors",
    "max": 16896,
    "max_length": 5,
    "dtype": "int64_list",
    "separator": "|"
  },
  ...
}
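To make the shape of one entry concrete, here is a minimal sketch of a typed holder for the fields shown above. `FeatureSpec` and its methods are hypothetical names, not part of this article's code or of any library. Defaulting a missing `max_length` to 1 is an assumption that mirrors how the inference function later in this article reads that field, and the `max` field is left out because its exact meaning is not spelled out here.

```java
import java.util.Objects;

/** Hypothetical typed view of one entry in the feature description file. */
final class FeatureSpec {
    final String name;
    final String dtype;     // e.g. "int64" or "int64_list"
    final int maxLength;    // 1 for scalar features (assumed default)
    final String separator; // null for scalar features

    FeatureSpec(String name, String dtype, Integer maxLength, String separator) {
        this.name = Objects.requireNonNull(name, "name");
        this.dtype = Objects.requireNonNull(dtype, "dtype");
        this.maxLength = (maxLength == null) ? 1 : maxLength;
        this.separator = separator;
        if (this.maxLength < 1) {
            throw new IllegalArgumentException("max_length must be >= 1 for '" + name + "'");
        }
        if (isList() && separator == null) {
            throw new IllegalArgumentException("list feature '" + name + "' needs a separator");
        }
    }

    /** True when the dtype carries the "_list" suffix, as in "int64_list". */
    boolean isList() {
        return dtype.endsWith("_list");
    }

    /** Element type without the "_list" suffix, e.g. "int64". */
    String elementType() {
        return isList() ? dtype.substring(0, dtype.length() - "_list".length()) : dtype;
    }
}
```

Validating each entry once at load time means a malformed description file fails early, rather than in the middle of a batch.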

Also prepare a test dataset, for example a CSV file in which each row holds the parameters for one inference.

With that in place, we can begin.

Create an inference class:

public class OnnxInference {
    private static final String MODEL_PATH = "/path/to/onnx/file";
    private static final String CSV_PATH = "/path/to/csv";
    private static final String JSON_PATH = "/path/to/features.json";
  
    private static final int BATCH_SIZE = 10;
    private static JsonNode featureConfig;

    private static void loadFeatureConfig() throws IOException {
        ObjectMapper mapper = new ObjectMapper();
        featureConfig = mapper.readTree(new File(JSON_PATH));
    }

}

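featureConfig holds the input format the model expects. BATCH_SIZE is reserved for batched inference, which improves efficiency; here each inference call is assumed to handle 10 rows.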

Create a function that reads the CSV file and runs inference in batches:

private static void readCsvAndPredict(String csvPath, OrtSession session, OrtEnvironment env) throws IOException, CsvValidationException, OrtException {
  try (CSVReader reader = new CSVReader(new FileReader(csvPath))) {
    String[] headers = reader.readNext();
    List<Map<String, String>> batch = new ArrayList<>();
    String[] line;
    while ((line = reader.readNext()) != null) {
      Map<String, String> row = new HashMap<>();
      for (int i = 0; i < headers.length; i++) {
        row.put(headers[i], line[i]);
      }
      batch.add(row);
      if (batch.size() == BATCH_SIZE) {
        processBatch(batch, session, env);
        batch.clear();
      }
    }
    if (!batch.isEmpty()) {
      processBatch(batch, session, env);
    }
  }
}
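The batching pattern in this function (flush when the batch is full, then flush whatever remains at the end) can be isolated from the CSV and ONNX details. The sketch below uses only the standard library; `BatchingSketch` and `forEachBatch` are hypothetical names introduced for illustration, and handing the handler a copy of each batch is a design choice of this sketch rather than something the function above does.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

final class BatchingSketch {
    /**
     * Feeds rows to the handler in chunks of at most batchSize,
     * including the final partial chunk. Returns the number of chunks.
     */
    static <T> int forEachBatch(Iterable<T> rows, int batchSize, Consumer<List<T>> handler) {
        if (batchSize < 1) {
            throw new IllegalArgumentException("batchSize must be >= 1");
        }
        List<T> batch = new ArrayList<>(batchSize);
        int batches = 0;
        for (T row : rows) {
            batch.add(row);
            if (batch.size() == batchSize) {
                // Hand over a copy so that clearing the buffer cannot affect the handler
                handler.accept(new ArrayList<>(batch));
                batch.clear();
                batches++;
            }
        }
        if (!batch.isEmpty()) {
            // The last batch may be smaller than batchSize
            handler.accept(new ArrayList<>(batch));
            batches++;
        }
        return batches;
    }
}
```

With five rows and a batch size of 2, the handler is called three times, and the last call receives a single row. The inference code therefore has to size its arrays from the batch it receives rather than from BATCH_SIZE.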

Next comes the most important part, the inference function:

private static void processBatch(List<Map<String, String>> batch, OrtSession session, OrtEnvironment env) throws OrtException {
  Map<String, OnnxTensor> inputs = new HashMap<>();

  for (Iterator<String> it = featureConfig.fieldNames(); it.hasNext(); ) {
    String featureName = it.next();
    JsonNode feature = featureConfig.get(featureName);
    String dtype = feature.get("dtype").asText();
    int maxLength = feature.has("max_length") ? feature.get("max_length").asInt() : 1;

    if (dtype.startsWith("int64")) {
      long[][] featureData = new long[batch.size()][maxLength];
      fillLongFeatureData(batch, featureName, feature, dtype, maxLength, featureData);
      OnnxTensor tensor = OnnxTensor.createTensor(env, featureData);
      inputs.put(featureName, tensor);
    } else if (dtype.startsWith("float32")) {
      float[][] featureData = new float[batch.size()][maxLength];
      fillFloatFeatureData(batch, featureName, feature, dtype, maxLength, featureData);
      OnnxTensor tensor = OnnxTensor.createTensor(env, featureData);
      inputs.put(featureName, tensor);
    } else {
      System.out.println("Warning: unknown dtype " + dtype + " for feature " + featureName);
    }
  }

  try (OrtSession.Result result = session.run(inputs)) {
    // Fetch the prediction (assumed here to be the third output)
    OnnxValue outputValue = result.get(2);

    System.out.println("+----------------------+--------------+");
    System.out.println("| User ID              | Prediction   |");
    System.out.println("+======================+==============+");

    if (outputValue instanceof OnnxTensor) {
      OnnxTensor outputTensor = (OnnxTensor) outputValue;
      float[][] predictions = (float[][]) outputTensor.getValue();

      for (int i = 0; i < batch.size(); i++) {
        String userId = batch.get(i).get("user_id");
        float prediction = predictions[i][0];
        System.out.printf("| %-20s | %12.8f |%n", userId, prediction);
      }
    } else {
      System.out.println("Warning: the output is not the expected tensor type");
    }

    System.out.println("+----------------------+--------------+");
    System.out.println();
  } finally {
    // Input tensors hold native memory; release them once the run has finished
    inputs.values().forEach(OnnxTensor::close);
  }
}

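First, the inputs are prepared according to each feature's dtype. Because ONNX does not yet support ragged tensors, list features need a padding step: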

private static void fillLongFeatureData(List<Map<String, String>> batch, String featureName, JsonNode feature, String dtype, int maxLength, long[][] featureData) {
  for (int i = 0; i < batch.size(); i++) {
    String value = batch.get(i).get(featureName);
    if (value == null || value.isEmpty()) {
      Arrays.fill(featureData[i], 0L);
      continue;
    }

    if (dtype.endsWith("_list")) {
      String separator = feature.get("separator").asText();
      String[] sequence = value.split(Pattern.quote(separator));
      for (int j = 0; j < Math.min(sequence.length, maxLength); j++) {
        featureData[i][j] = Long.parseLong(sequence[j]);
      }
    } else {
      featureData[i][0] = Long.parseLong(value);
    }
  }
}

private static void fillFloatFeatureData(List<Map<String, String>> batch, String featureName, JsonNode feature, String dtype, int maxLength, float[][] featureData) {
  for (int i = 0; i < batch.size(); i++) {
    String value = batch.get(i).get(featureName);
    if (value == null || value.isEmpty()) {
      System.out.println("Warning: feature " + featureName + " is missing or empty in this batch");
      Arrays.fill(featureData[i], 0.0f);
      continue;
    }

    if (dtype.endsWith("_list")) {
      String separator = feature.get("separator").asText();
      String[] sequence = value.split(Pattern.quote(separator));
      for (int j = 0; j < Math.min(sequence.length, maxLength); j++) {
        featureData[i][j] = Float.parseFloat(sequence[j]);
      }
    } else {
      featureData[i][0] = Float.parseFloat(value);
    }
  }
}
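The padding rule in the two functions above can be shown for a single row. This is a minimal standalone sketch: `PaddingSketch` and `toFixedLength` are hypothetical names, and using 0 as the padding value is an assumption carried over from the code above, which relies on Java zero-initialising arrays. Whether 0 is a safe padding ID depends on how the model was trained.

```java
import java.util.regex.Pattern;

final class PaddingSketch {
    /**
     * Splits a delimited string into exactly maxLength longs:
     * extra elements are truncated, missing ones stay 0.
     */
    static long[] toFixedLength(String value, String separator, int maxLength) {
        long[] out = new long[maxLength]; // zero-initialised, so padding is implicit
        if (value == null || value.isEmpty()) {
            return out;
        }
        // Pattern.quote keeps separators such as "|" from being read as regex syntax
        String[] parts = value.split(Pattern.quote(separator));
        int n = Math.min(parts.length, maxLength);
        for (int j = 0; j < n; j++) {
            out[j] = Long.parseLong(parts[j].trim());
        }
        return out;
    }
}
```

For example, `toFixedLength("3|7", "|", 5)` yields `[3, 7, 0, 0, 0]`, and a six-element input with `maxLength` 5 keeps only its first five elements. Every row therefore has the same length, which is what lets the batch be passed to `OnnxTensor.createTensor` as a rectangular array.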

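This is enough to get results out of OrtSession. I did cut a corner when handling the result: normally you would not pick an output by index, but here I simply print the third output. To do this properly, we would want a description file that covers both inputs and outputs, and then turn the inference code into a general-purpose SDK that can run inference given only a description file and a model file.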

The program prints one table per batch, with a user ID column and a prediction column.

Finally, here is the complete code:

import ai.onnxruntime.*;
import com.opencsv.CSVReader;
import com.opencsv.exceptions.CsvValidationException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.JsonNode;

import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.*;
import java.util.regex.Pattern;

public class OnnxInference {
    private static final String MODEL_PATH = "/path/to/onnx/file";
    private static final String CSV_PATH = "/path/to/csv";
    private static final String JSON_PATH = "/path/to/features.json";

    private static final int BATCH_SIZE = 10;
    private static JsonNode featureConfig;

    public static void main(String[] args) {
        try {
            loadFeatureConfig();
            try (OrtEnvironment env = OrtEnvironment.getEnvironment();
                 OrtSession session = env.createSession(MODEL_PATH)) {

                readCsvAndPredict(CSV_PATH, session, env);

            } catch (OrtException | IOException | CsvValidationException e) {
                e.printStackTrace();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    private static void loadFeatureConfig() throws IOException {
        ObjectMapper mapper = new ObjectMapper();
        featureConfig = mapper.readTree(new File(JSON_PATH));
    }

    private static void readCsvAndPredict(String csvPath, OrtSession session, OrtEnvironment env) throws IOException, CsvValidationException, OrtException {
        try (CSVReader reader = new CSVReader(new FileReader(csvPath))) {
            String[] headers = reader.readNext();
            List<Map<String, String>> batch = new ArrayList<>();
            String[] line;
            while ((line = reader.readNext()) != null) {
                Map<String, String> row = new HashMap<>();
                for (int i = 0; i < headers.length; i++) {
                    row.put(headers[i], line[i]);
                }
                batch.add(row);
                if (batch.size() == BATCH_SIZE) {
                    processBatch(batch, session, env);
                    batch.clear();
                }
            }
            if (!batch.isEmpty()) {
                processBatch(batch, session, env);
            }
        }
    }

    private static void processBatch(List<Map<String, String>> batch, OrtSession session, OrtEnvironment env) throws OrtException {
        Map<String, OnnxTensor> inputs = new HashMap<>();

        for (Iterator<String> it = featureConfig.fieldNames(); it.hasNext(); ) {
            String featureName = it.next();
            JsonNode feature = featureConfig.get(featureName);
            String dtype = feature.get("dtype").asText();
            int maxLength = feature.has("max_length") ? feature.get("max_length").asInt() : 1;

            if (dtype.startsWith("int64")) {
                long[][] featureData = new long[batch.size()][maxLength];
                fillLongFeatureData(batch, featureName, feature, dtype, maxLength, featureData);
                OnnxTensor tensor = OnnxTensor.createTensor(env, featureData);
                inputs.put(featureName, tensor);
            } else if (dtype.startsWith("float32")) {
                float[][] featureData = new float[batch.size()][maxLength];
                fillFloatFeatureData(batch, featureName, feature, dtype, maxLength, featureData);
                OnnxTensor tensor = OnnxTensor.createTensor(env, featureData);
                inputs.put(featureName, tensor);
            } else {
                System.out.println("Warning: unknown dtype " + dtype + " for feature " + featureName);
            }
        }

        try (OrtSession.Result result = session.run(inputs)) {
            // Fetch the prediction (assumed here to be the third output)
            OnnxValue outputValue = result.get(2);

            System.out.println("+----------------------+--------------+");
            System.out.println("| User ID              | Prediction   |");
            System.out.println("+======================+==============+");

            if (outputValue instanceof OnnxTensor) {
                OnnxTensor outputTensor = (OnnxTensor) outputValue;
                float[][] predictions = (float[][]) outputTensor.getValue();

                for (int i = 0; i < batch.size(); i++) {
                    String userId = batch.get(i).get("user_id");
                    float prediction = predictions[i][0];
                    System.out.printf("| %-20s | %12.8f |%n", userId, prediction);
                }
            } else {
                System.out.println("Warning: the output is not the expected tensor type");
            }

            System.out.println("+----------------------+--------------+");
            System.out.println();
        } finally {
            // Input tensors hold native memory; release them once the run has finished
            inputs.values().forEach(OnnxTensor::close);
        }
    }

    private static void fillLongFeatureData(List<Map<String, String>> batch, String featureName, JsonNode feature, String dtype, int maxLength, long[][] featureData) {
        for (int i = 0; i < batch.size(); i++) {
            String value = batch.get(i).get(featureName);
            if (value == null || value.isEmpty()) {
                Arrays.fill(featureData[i], 0L);
                continue;
            }

            if (dtype.endsWith("_list")) {
                String separator = feature.get("separator").asText();
                String[] sequence = value.split(Pattern.quote(separator));
                for (int j = 0; j < Math.min(sequence.length, maxLength); j++) {
                    featureData[i][j] = Long.parseLong(sequence[j]);
                }
            } else {
                featureData[i][0] = Long.parseLong(value);
            }
        }
    }

    private static void fillFloatFeatureData(List<Map<String, String>> batch, String featureName, JsonNode feature, String dtype, int maxLength, float[][] featureData) {
        for (int i = 0; i < batch.size(); i++) {
            String value = batch.get(i).get(featureName);
            if (value == null || value.isEmpty()) {
                System.out.println("Warning: feature " + featureName + " is missing or empty in this batch");
                Arrays.fill(featureData[i], 0.0f);
                continue;
            }

            if (dtype.endsWith("_list")) {
                String separator = feature.get("separator").asText();
                String[] sequence = value.split(Pattern.quote(separator));
                for (int j = 0; j < Math.min(sequence.length, maxLength); j++) {
                    featureData[i][j] = Float.parseFloat(sequence[j]);
                }
            } else {
                featureData[i][0] = Float.parseFloat(value);
            }
        }
    }
}