Last updated 7th February, 2020.
Scikit-learn is a popular machine learning library, and ONNX is a serialization format supported by OVHcloud ML Serving. This tutorial explains how to export a trained scikit-learn model to an ONNX file.
- A Python environment with scikit-learn installed
Convert a simple model into ONNX
ML Serving supports scikit-learn models through the ONNX serialization format.
Train a simple scikit-learn model
Let's take a simple scikit-learn model as an example:
```python
# Train a model
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)

classifier = RandomForestClassifier()
classifier.fit(X_train, y_train)
```
Install the sklearn-onnx module
```bash
pip install skl2onnx
```
Define the inputs of your serialized model
For each numpy array (also called a tensor in ONNX) fed as an input to the model, choose a name and declare its data type and its shape.
```python
# Import the needed data type
from skl2onnx.common.data_types import FloatTensorType

# Input tensors of your model: list of ('<wanted name of tensor>', DataType(<shape>))
# None means any number of rows; 4 is the number of features in the iris dataset
initial_type = [
    ('float_input', FloatTensorType([None, 4]))
]
```
Launch the conversion and save it to a file
The trained model is converted with the `convert_sklearn` function from skl2onnx:
```python
# Import the export function
from skl2onnx import convert_sklearn

# Export the model
onx = convert_sklearn(classifier, initial_types=initial_type)

# Save it into the wanted file
with open("my_model.onnx", "wb") as f:
    f.write(onx.SerializeToString())
```
Your model is now serialized on your local file system in the `my_model.onnx` file.
- For more information about how to serialize a scikit-learn model to the ONNX serialization format, refer to the official documentation. For example, it explains how to serialize a complex scikit-learn pipeline.
- You can check the OVHcloud documentation on how to deploy custom models.
- You can check the compatibility list for supported ONNX models.