Skip to content

Commit

Permalink
Update CoreML exports to support newer *.mlpackage outputs (#13222)
Browse files Browse the repository at this point in the history
* Implement *.mlpackage generation and make it the default for CoreML model exports

Signed-off-by: Ryan Hirasaki <[email protected]>

* Provide command line argument to export as *.mlmodel instead of *.mlpackage for CoreML

Signed-off-by: Ryan Hirasaki <[email protected]>

* Remove macOS check for CoreML quantization

The macOS requirement for quantization was removed in coremltools 6.0

Signed-off-by: Ryan Hirasaki <[email protected]>

* Undo removal of warning catching

Signed-off-by: Ryan Hirasaki <[email protected]>

* Change file extension references from mlmodel to mlpackage

Signed-off-by: Ryan Hirasaki <[email protected]>

* Auto-format by https://ultralytics.com/actions

---------

Signed-off-by: Ryan Hirasaki <[email protected]>
Co-authored-by: UltralyticsAssistant <[email protected]>
Co-authored-by: Glenn Jocher <[email protected]>
Co-authored-by: Ultralytics Assistant <[email protected]>
  • Loading branch information
4 people authored Jul 29, 2024
1 parent dcf1242 commit 6096750
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 20 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ VOC/
*.onnx
*.engine
*.mlmodel
*.mlpackage
*.torchscript
*.tflite
*.h5
Expand Down
2 changes: 1 addition & 1 deletion benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
ONNX | `onnx` | yolov5s.onnx
OpenVINO | `openvino` | yolov5s_openvino_model/
TensorRT | `engine` | yolov5s.engine
CoreML | `coreml` | yolov5s.mlmodel
CoreML | `coreml` | yolov5s.mlpackage
TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/
TensorFlow GraphDef | `pb` | yolov5s.pb
TensorFlow Lite | `tflite` | yolov5s.tflite
Expand Down
2 changes: 1 addition & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
yolov5s_openvino_model # OpenVINO
yolov5s.engine # TensorRT
yolov5s.mlmodel # CoreML (macOS-only)
yolov5s.mlpackage # CoreML (macOS-only)
yolov5s_saved_model # TensorFlow SavedModel
yolov5s.pb # TensorFlow GraphDef
yolov5s.tflite # TensorFlow Lite
Expand Down
66 changes: 50 additions & 16 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def export_formats():
["ONNX", "onnx", ".onnx", True, True],
["OpenVINO", "openvino", "_openvino_model", True, False],
["TensorRT", "engine", ".engine", False, True],
["CoreML", "coreml", ".mlmodel", True, False],
["CoreML", "coreml", ".mlpackage", True, False],
["TensorFlow SavedModel", "saved_model", "_saved_model", True, True],
["TensorFlow GraphDef", "pb", ".pb", True, True],
["TensorFlow Lite", "tflite", ".tflite", True, False],
Expand Down Expand Up @@ -520,7 +520,7 @@ def export_paddle(model, im, file, metadata, prefix=colorstr("PaddlePaddle:")):


@try_export
def export_coreml(model, im, file, int8, half, nms, prefix=colorstr("CoreML:")):
def export_coreml(model, im, file, int8, half, nms, mlmodel, prefix=colorstr("CoreML:")):
"""
Export a YOLOv5 model to CoreML format with optional NMS, INT8, and FP16 support.
Expand All @@ -531,6 +531,7 @@ def export_coreml(model, im, file, int8, half, nms, prefix=colorstr("CoreML:")):
int8 (bool): Flag indicating whether to use INT8 quantization (default is False).
half (bool): Flag indicating whether to use FP16 quantization (default is False).
nms (bool): Flag indicating whether to include Non-Maximum Suppression (default is False).
mlmodel (bool): Flag indicating whether to export as older *.mlmodel format (default is False).
prefix (str): Prefix string for logging purposes (default is 'CoreML:').
Returns:
Expand All @@ -548,27 +549,46 @@ def export_coreml(model, im, file, int8, half, nms, prefix=colorstr("CoreML:")):
model = Model(cfg, ch=3, nc=80)
im = torch.randn(1, 3, 640, 640)
file = Path("yolov5s_coreml")
export_coreml(model, im, file, int8=False, half=False, nms=True)
export_coreml(model, im, file, int8=False, half=False, nms=True, mlmodel=False)
```
"""
check_requirements("coremltools")
import coremltools as ct

LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...")
f = file.with_suffix(".mlmodel")
if mlmodel:
f = file.with_suffix(".mlmodel")
convert_to = "neuralnetwork"
precision = None
else:
f = file.with_suffix(".mlpackage")
convert_to = "mlprogram"
if half:
precision = ct.precision.FLOAT16
else:
precision = ct.precision.FLOAT32

if nms:
model = iOSModel(model, im)
ts = torch.jit.trace(model, im, strict=False) # TorchScript model
ct_model = ct.convert(ts, inputs=[ct.ImageType("image", shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
bits, mode = (8, "kmeans_lut") if int8 else (16, "linear") if half else (32, None)
ct_model = ct.convert(
ts,
inputs=[ct.ImageType("image", shape=im.shape, scale=1 / 255, bias=[0, 0, 0])],
convert_to=convert_to,
compute_precision=precision,
)
bits, mode = (8, "kmeans") if int8 else (16, "linear") if half else (32, None)
if bits < 32:
if MACOS: # quantization only supported on macOS
if mlmodel:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning
warnings.filterwarnings(
"ignore", category=DeprecationWarning
) # suppress numpy==1.20 float warning, fixed in coremltools==7.0
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
else:
print(f"{prefix} quantization only supported on macOS, skipping...")
elif bits == 8:
op_config = ct.optimize.coreml.OpPalettizerConfig(mode=mode, nbits=bits, weight_threshold=512)
config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)
ct_model = ct.optimize.coreml.palettize_weights(ct_model, config)
ct_model.save(f)
return f, ct_model

Expand Down Expand Up @@ -1070,7 +1090,7 @@ def add_tflite_metadata(file, metadata, num_outputs):
tmp_file.unlink()


def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:")):
def pipeline_coreml(model, im, file, names, y, mlmodel, prefix=colorstr("CoreML Pipeline:")):
"""
Convert a PyTorch YOLOv5 model to CoreML format with Non-Maximum Suppression (NMS), handling different input/output
shapes, and saving the model.
Expand All @@ -1082,6 +1102,7 @@ def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:
file (Path): Path to save the converted CoreML model.
names (dict[int, str]): Dictionary mapping class indices to class names.
y (torch.Tensor): Output tensor from the PyTorch model's forward pass.
mlmodel (bool): Flag indicating whether to export as older *.mlmodel format (default is False).
prefix (str): Custom prefix for logging messages.
Returns:
Expand Down Expand Up @@ -1114,6 +1135,11 @@ def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:
import coremltools as ct
from PIL import Image

if mlmodel:
f = file.with_suffix(".mlmodel") # filename
else:
f = file.with_suffix(".mlpackage") # filename

print(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
batch_size, ch, h, w = list(im.shape) # BCHW
t = time.time()
Expand Down Expand Up @@ -1156,7 +1182,12 @@ def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:
print(spec.description)

# Model from spec
model = ct.models.MLModel(spec)
weights_dir = None
if mlmodel:
weights_dir = None
else:
weights_dir = str(f / "Data/com.apple.CoreML/weights")
model = ct.models.MLModel(spec, weights_dir=weights_dir)

# 3. Create NMS protobuf
nms_spec = ct.proto.Model_pb2.Model()
Expand Down Expand Up @@ -1227,8 +1258,7 @@ def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:
)

# Save the model
f = file.with_suffix(".mlmodel") # filename
model = ct.models.MLModel(pipeline.spec)
model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)
model.input_description["image"] = "Input image"
model.input_description["iouThreshold"] = f"(optional) IOU Threshold override (default: {nms.iouThreshold})"
model.input_description["confidenceThreshold"] = (
Expand Down Expand Up @@ -1256,6 +1286,7 @@ def run(
per_tensor=False, # TF per tensor quantization
dynamic=False, # ONNX/TF/TensorRT: dynamic axes
simplify=False, # ONNX: simplify model
mlmodel=False, # CoreML: Export in *.mlmodel format
opset=12, # ONNX: opset version
verbose=False, # TensorRT: verbose log
workspace=4, # TensorRT: workspace size (GB)
Expand Down Expand Up @@ -1293,6 +1324,7 @@ def run(
topk_all (int): Top-K boxes for all classes to keep for TensorFlow.js NMS. Default is 100.
iou_thres (float): IoU threshold for NMS. Default is 0.45.
conf_thres (float): Confidence threshold for NMS. Default is 0.25.
mlmodel (bool): Flag to use *.mlmodel for CoreML export. Default is False.
Returns:
None
Expand Down Expand Up @@ -1320,6 +1352,7 @@ def run(
simplify=False,
opset=12,
verbose=False,
mlmodel=False,
workspace=4,
nms=False,
agnostic_nms=False,
Expand Down Expand Up @@ -1383,9 +1416,9 @@ def run(
if xml: # OpenVINO
f[3], _ = export_openvino(file, metadata, half, int8, data)
if coreml: # CoreML
f[4], ct_model = export_coreml(model, im, file, int8, half, nms)
f[4], ct_model = export_coreml(model, im, file, int8, half, nms, mlmodel)
if nms:
pipeline_coreml(ct_model, im, file, model.names, y)
pipeline_coreml(ct_model, im, file, model.names, y, mlmodel)
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
assert not tflite or not tfjs, "TFLite and TF.js models must be exported separately, please pass only one type."
assert not isinstance(model, ClassificationModel), "ClassificationModel export to TF formats not yet supported."
Expand Down Expand Up @@ -1473,6 +1506,7 @@ def parse_opt(known=False):
parser.add_argument("--per-tensor", action="store_true", help="TF per-tensor quantization")
parser.add_argument("--dynamic", action="store_true", help="ONNX/TF/TensorRT: dynamic axes")
parser.add_argument("--simplify", action="store_true", help="ONNX: simplify model")
parser.add_argument("--mlmodel", action="store_true", help="CoreML: Export in *.mlmodel format")
parser.add_argument("--opset", type=int, default=17, help="ONNX: opset version")
parser.add_argument("--verbose", action="store_true", help="TensorRT: verbose log")
parser.add_argument("--workspace", type=int, default=4, help="TensorRT: workspace size (GB)")
Expand Down
2 changes: 1 addition & 1 deletion models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def __init__(self, weights="yolov5s.pt", device=torch.device("cpu"), dnn=False,
# ONNX Runtime: *.onnx
# ONNX OpenCV DNN: *.onnx --dnn
# OpenVINO: *_openvino_model
# CoreML: *.mlmodel
# CoreML: *.mlpackage
# TensorRT: *.engine
# TensorFlow SavedModel: *_saved_model
# TensorFlow GraphDef: *.pb
Expand Down
2 changes: 1 addition & 1 deletion val.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
yolov5s_openvino_model # OpenVINO
yolov5s.engine # TensorRT
yolov5s.mlmodel # CoreML (macOS-only)
yolov5s.mlpackage # CoreML (macOS-only)
yolov5s_saved_model # TensorFlow SavedModel
yolov5s.pb # TensorFlow GraphDef
yolov5s.tflite # TensorFlow Lite
Expand Down

0 comments on commit 6096750

Please sign in to comment.