OrtSession — ONNX Runtime Java API Reference
The OrtSession class wraps an ONNX model and provides methods for running inference.

Package

ai.onnxruntime.OrtSession

Class Declaration

public class OrtSession implements AutoCloseable

Creating Sessions

Sessions are created through OrtEnvironment, not directly constructed.

From File Path

OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession session = env.createSession(
    "model.onnx",
    new OrtSession.SessionOptions()
);

From Byte Array

byte[] modelBytes = Files.readAllBytes(Paths.get("model.onnx"));
OrtSession session = env.createSession(
    modelBytes,
    new OrtSession.SessionOptions()
);

From ByteBuffer

ByteBuffer modelBuffer = ... // Direct ByteBuffer
OrtSession session = env.createSession(
    modelBuffer,
    new OrtSession.SessionOptions()
);

Properties

getNumInputs()

Returns the number of model inputs.
public long getNumInputs()
Example:
long numInputs = session.getNumInputs();
System.out.println("Model has " + numInputs + " inputs");

getNumOutputs()

Returns the number of model outputs.
public long getNumOutputs()

getInputNames()

Returns input names (ordered by input ID).
public Set<String> getInputNames()
Example:
Set<String> inputNames = session.getInputNames();
for (String name : inputNames) {
    System.out.println("Input: " + name);
}

getOutputNames()

Returns output names (ordered by output ID).
public Set<String> getOutputNames()

getInputInfo()

Returns detailed input information including types and shapes. (NOTE: depending on the ONNX Runtime version, the type and shape may need to be read via NodeInfo.getInfo(), cast to TensorInfo — verify against the Javadoc for your release.)
public Map<String, NodeInfo> getInputInfo() throws OrtException
Example:
Map<String, NodeInfo> inputInfo = session.getInputInfo();
for (Map.Entry<String, NodeInfo> entry : inputInfo.entrySet()) {
    NodeInfo info = entry.getValue();
    System.out.println("Input: " + entry.getKey());
    System.out.println("  Type: " + info.getType());
    System.out.println("  Shape: " + Arrays.toString(info.getShape()));
}

getOutputInfo()

Returns detailed output information.
public Map<String, NodeInfo> getOutputInfo() throws OrtException

Running Inference

run(Map)

Runs inference with all outputs.
public Result run(Map<String, ? extends OnnxTensorLike> inputs) 
    throws OrtException
Parameters:
  • inputs: Map of input name to tensor
Returns: a Result containing all outputs.
Example:
float[] data = {1.0f, 2.0f, 3.0f, 4.0f};
OnnxTensor tensor = OnnxTensor.createTensor(env, 
    FloatBuffer.wrap(data), 
    new long[]{1, 4}
);

try (OrtSession.Result results = session.run(
        Map.of("input", tensor))) {
    
    OnnxValue output = results.get(0);
    float[][] outputData = (float[][]) output.getValue();
    System.out.println(Arrays.deepToString(outputData));
} finally {
    tensor.close();
}

run(Map, Set)

Runs inference with specific output names.
public Result run(
    Map<String, ? extends OnnxTensorLike> inputs,
    Set<String> requestedOutputs
) throws OrtException
Example:
Set<String> outputs = Set.of("output1", "output2");
try (OrtSession.Result results = session.run(inputs, outputs)) {
    for (Map.Entry<String, OnnxValue> entry : results) {
        System.out.println(entry.getKey() + ": " + entry.getValue());
    }
}

run(Map, RunOptions)

Runs inference with custom run options.
public Result run(
    Map<String, ? extends OnnxTensorLike> inputs,
    RunOptions runOptions
) throws OrtException
Example:
OrtSession.RunOptions runOptions = new OrtSession.RunOptions();
runOptions.setLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE);

try (OrtSession.Result results = session.run(inputs, runOptions)) {
    // Process results
}

run with Pre-allocated Outputs

Runs inference using pre-allocated output tensors.
public Result run(
    Map<String, ? extends OnnxTensorLike> inputs,
    Set<String> requestedOutputs,
    Map<String, ? extends OnnxValue> pinnedOutputs,
    RunOptions runOptions
) throws OrtException
Example:
// Pre-allocate output tensor
float[] outputBuffer = new float[1000];
OnnxTensor outputTensor = OnnxTensor.createTensor(env,
    FloatBuffer.wrap(outputBuffer),
    new long[]{1, 1000}
);

Map<String, OnnxTensor> pinnedOutputs = Map.of("output", outputTensor);

try (OrtSession.Result results = session.run(
        inputs,
        Set.of("output"),
        pinnedOutputs,
        null)) {
    
    // Output data is in outputBuffer
    System.out.println(Arrays.toString(outputBuffer));
}

Result Class

The Result class contains inference outputs.

Accessing Results

try (OrtSession.Result results = session.run(inputs)) {
    // By index
    OnnxValue firstOutput = results.get(0);
    
    // By name
    Optional<OnnxValue> namedOutput = results.get("output_name");
    
    // Iterate all outputs
    for (Map.Entry<String, OnnxValue> entry : results) {
        String name = entry.getKey();
        OnnxValue value = entry.getValue();
        // Process output
    }
}

Extracting Data

OnnxValue output = results.get(0);

// Get as array
Object value = output.getValue();

// Type-specific access
if (output instanceof OnnxTensor) {
    OnnxTensor tensor = (OnnxTensor) output;
    FloatBuffer buffer = tensor.getFloatBuffer();
    long[] shape = tensor.getInfo().getShape();
}

SessionOptions

Configuration options for creating sessions.

Creating SessionOptions

OrtSession.SessionOptions options = new OrtSession.SessionOptions();

Optimization Level

options.setOptimizationLevel(
    OrtSession.SessionOptions.OptLevel.ALL_OPT
);

// Available levels:
// - NO_OPT: No optimizations
// - BASIC_OPT: Basic optimizations
// - EXTENDED_OPT: Extended optimizations
// - ALL_OPT: All optimizations

Execution Mode

options.setExecutionMode(
    OrtSession.SessionOptions.ExecutionMode.PARALLEL
);

// Available modes:
// - SEQUENTIAL: Execute operators sequentially
// - PARALLEL: Execute operators in parallel

Thread Configuration

// Intra-op threads (parallelism within operators)
options.setIntraOpNumThreads(4);

// Inter-op threads (parallelism between operators)
options.setInterOpNumThreads(2);

Memory Configuration

// Enable CPU memory arena
options.setCPUArenaAllocator(true);

// Enable memory pattern optimization
options.setMemoryPatternOptimization(true);

Logging

options.setLoggingLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING);

Execution Providers

// Add CUDA provider
options.addCUDA(0); // Device ID

// Add CPU provider
options.addCPU(true); // Use arena allocator

// Add NNAPI provider (Android)
import ai.onnxruntime.providers.NNAPIFlags;
EnumSet<NNAPIFlags> flags = EnumSet.of(
    NNAPIFlags.USE_FP16,
    NNAPIFlags.CPU_DISABLED
);
options.addNnapi(flags);

// Add CoreML provider (iOS/macOS)
import ai.onnxruntime.providers.CoreMLFlags;
options.addCoreML(EnumSet.of(CoreMLFlags.ENABLE_ON_SUBGRAPH));

Custom Operators

options.registerCustomOpLibrary("libcustom_ops.so");

Complete Examples

Batch Processing

/**
 * Runs a model over a batch of float vectors, one item at a time.
 *
 * <p>Implements {@link AutoCloseable} so instances can be managed with
 * try-with-resources, matching the cleanup guidance elsewhere on this page.
 * The shared {@link OrtEnvironment} singleton is deliberately not closed here.
 */
public class BatchProcessor implements AutoCloseable {
    private final OrtEnvironment env;
    private final OrtSession session;

    /**
     * Creates a session for the model at {@code modelPath} with full graph
     * optimization and 4 intra-op threads.
     *
     * @param modelPath path to the .onnx model file
     * @throws OrtException if the model cannot be loaded
     */
    public BatchProcessor(String modelPath) throws OrtException {
        env = OrtEnvironment.getEnvironment();

        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
        opts.setIntraOpNumThreads(4);

        session = env.createSession(modelPath, opts);
    }

    /**
     * Runs inference on each item of {@code batch} and returns the first
     * output row for each item, in input order.
     *
     * @param batch list of flat input vectors; each is fed with shape [1, length]
     * @return one output vector per input item
     * @throws OrtException if tensor creation or inference fails
     */
    public List<float[]> processBatch(List<float[]> batch)
            throws OrtException {

        List<float[]> results = new ArrayList<>(batch.size());

        // Resolve the (single) input name once, not per item.
        String inputName = session.getInputNames().iterator().next();

        for (float[] item : batch) {
            // try-with-resources closes both the input tensor and the Result,
            // even if run() throws.
            try (OnnxTensor tensor = OnnxTensor.createTensor(env,
                     FloatBuffer.wrap(item),
                     new long[]{1, item.length});
                 OrtSession.Result output = session.run(
                     Map.of(inputName, tensor))) {

                // Assumes the first output is a rank-2 float tensor with a
                // single row — TODO confirm against the model's output info.
                float[][] result = (float[][]) output.get(0).getValue();
                results.add(result[0]);
            }
        }

        return results;
    }

    /** Releases the native session. Safe to call more than once. */
    @Override
    public void close() {
        if (session != null) session.close();
    }
}

Multi-threaded Inference

import java.util.concurrent.*;

/**
 * Submits single-input inference tasks to a fixed thread pool backed by one
 * shared OrtSession (OrtSession.run() is safe for concurrent calls).
 *
 * <p>Implements {@link AutoCloseable}; {@link #shutdown()} now waits for
 * in-flight tasks to finish before closing the session — the previous
 * version closed the session immediately after a non-blocking
 * {@code executor.shutdown()}, racing running tasks against a closed session.
 */
public class ParallelInference implements AutoCloseable {
    private final OrtEnvironment env;
    private final OrtSession session;
    private final ExecutorService executor;
    private final String inputName; // resolved once; model assumed to have one input

    /**
     * @param modelPath  path to the .onnx model file
     * @param numThreads size of the worker pool
     * @throws OrtException if the model cannot be loaded
     */
    public ParallelInference(String modelPath, int numThreads)
            throws OrtException {
        env = OrtEnvironment.getEnvironment();
        session = env.createSession(modelPath,
            new OrtSession.SessionOptions());
        inputName = session.getInputNames().iterator().next();
        executor = Executors.newFixedThreadPool(numThreads);
    }

    /**
     * Schedules one inference; the input is fed with shape [1, input.length].
     *
     * @return a future yielding the first row of the first output
     */
    public Future<float[]> submitInference(float[] input) {
        return executor.submit(() -> {
            // try-with-resources closes both tensor and Result on all paths.
            try (OnnxTensor tensor = OnnxTensor.createTensor(env,
                     FloatBuffer.wrap(input),
                     new long[]{1, input.length});
                 OrtSession.Result results = session.run(
                     Map.of(inputName, tensor))) {

                float[][] output = (float[][]) results.get(0).getValue();
                return output[0];
            }
        });
    }

    /**
     * Stops accepting new tasks, waits (up to 60s) for running tasks to
     * complete, then closes the session.
     */
    public void shutdown() {
        executor.shutdown();
        try {
            if (!executor.awaitTermination(60, TimeUnit.SECONDS)) {
                executor.shutdownNow(); // give up on stragglers
            }
        } catch (InterruptedException e) {
            executor.shutdownNow();
            Thread.currentThread().interrupt(); // preserve interrupt status
        }
        session.close();
    }

    @Override
    public void close() {
        shutdown();
    }
}

// Usage
ParallelInference inference = new ParallelInference("model.onnx", 4);

List<Future<float[]>> futures = new ArrayList<>();
for (float[] input : inputs) {
    futures.add(inference.submitInference(input));
}

// Collect results
for (Future<float[]> future : futures) {
    float[] result = future.get();
    // Process result
}

inference.shutdown();

Model Metadata Inspection

/**
 * Prints a human-readable summary of the model at {@code modelPath}:
 * input/output counts followed by each node's name, type, and shape.
 *
 * @param modelPath path to the .onnx model file
 * @throws OrtException if the model cannot be loaded or inspected
 */
public void inspectModel(String modelPath) throws OrtException {
    try (OrtEnvironment env = OrtEnvironment.getEnvironment();
         OrtSession session = env.createSession(modelPath,
             new OrtSession.SessionOptions())) {

        System.out.println("=== Model Information ===");
        System.out.println("Number of inputs: " + session.getNumInputs());
        System.out.println("Number of outputs: " + session.getNumOutputs());

        printNodeSection("\n=== Inputs ===", session.getInputInfo());
        printNodeSection("\n=== Outputs ===", session.getOutputInfo());
    }
}

/** Prints a section header, then name/type/shape for every node in the map. */
private static void printNodeSection(String header, Map<String, NodeInfo> nodes) {
    System.out.println(header);
    for (Map.Entry<String, NodeInfo> entry : nodes.entrySet()) {
        NodeInfo info = entry.getValue();
        System.out.println("Name: " + entry.getKey());
        System.out.println("  Type: " + info.getType());
        System.out.println("  Shape: " + Arrays.toString(info.getShape()));
    }
}

Error Handling

try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
    OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
    
    try (OrtSession session = env.createSession("model.onnx", opts)) {
        // Run inference
        try (OnnxTensor input = createInput();
             OrtSession.Result results = session.run(
                 Map.of("input", input))) {
            
            processResults(results);
        }
    } catch (OrtException e) {
        System.err.println("Inference failed: " + e.getMessage());
        e.printStackTrace();
    }
} catch (Exception e) {
    System.err.println("Initialization failed: " + e.getMessage());
}

Best Practices

  1. Always use try-with-resources: Ensures proper cleanup
  2. Reuse sessions: Create once, use many times
  3. Configure SessionOptions: Enable optimizations
  4. Close tensors: Free memory after use
  5. Thread-safe inference: Sessions support concurrent run() calls
  6. Handle exceptions: Catch OrtException for ONNX Runtime errors

See Also