Source code for omlt.io.onnx

import json
from pathlib import Path

import onnx

from omlt.io.input_bounds import load_input_bounds, write_input_bounds
from omlt.io.onnx_parser import NetworkParser


def write_onnx_model_with_bounds(filename, onnx_model=None, input_bounds=None):
    """
    Write an ONNX model to disk, optionally with an input-bounds sidecar file.

    The two outputs are independent. The model is written only when
    `onnx_model` is given, and the bounds are written only when
    `input_bounds` is given.

    Parameters
    ----------
    filename : str
        the path where the ONNX model is written; the bounds, if any, are
        written to ``f"{filename}.bounds.json"``
    onnx_model : onnx model or None
        the onnx model; it is serialized with ``SerializeToString()``
    input_bounds : None or dict-like or list
        bounds on the input variables
    """
    if onnx_model is not None:
        # Binary mode: SerializeToString() yields raw protobuf bytes.
        with open(filename, "wb") as handle:
            payload = onnx_model.SerializeToString()
            handle.write(payload)

    if input_bounds is not None:
        bounds_path = f"{filename}.bounds.json"
        write_input_bounds(bounds_path, input_bounds)
def load_onnx_neural_network_with_bounds(filename):
    """
    Load a NetworkDefinition from an ONNX file, with input bounds if available.

    The ONNX model is read from `filename`. If a sidecar file named
    ``f"{filename}.bounds.json"`` exists, its bounds are loaded and passed
    on. Otherwise the network is loaded without input bounds.

    Parameters
    ----------
    filename : str
        the path where the ONNX model (and optionally the input bounds file)
        are written

    Returns
    -------
    NetworkDefinition
    """
    onnx_model = onnx.load(filename)
    input_bounds_filename = Path(f"{filename}.bounds.json")
    input_bounds = None
    # Path.exists is a method and must be called. Without the parentheses,
    # the bound method object is always truthy, so loading would be
    # attempted (and fail) even when no bounds file exists.
    if input_bounds_filename.exists():
        input_bounds = load_input_bounds(input_bounds_filename)
    return load_onnx_neural_network(onnx_model, input_bounds=input_bounds)
def load_onnx_neural_network(onnx, scaling_object=None, input_bounds=None):
    """
    Build a NetworkDefinition from an in-memory ONNX model.

    The model's ``graph`` is handed to a fresh ``NetworkParser`` together
    with the optional scaling object and input bounds.

    Parameters
    ----------
    onnx : onnx model
        the loaded ONNX model; only its ``graph`` attribute is used here
        (note: this name shadows the ``onnx`` module inside this function)
    scaling_object : instance of object supporting ScalingInterface
    input_bounds : list of tuples

    Returns
    -------
    NetworkDefinition
    """
    network_parser = NetworkParser()
    model_graph = onnx.graph
    return network_parser.parse_network(model_graph, scaling_object, input_bounds)