# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import numpy as np

from onnx import subbyte
from onnx._custom_element_types import (
    bfloat16,
    float8e4m3fn,
    float8e4m3fnuz,
    float8e5m2,
    float8e5m2fnuz,
    int4,
    uint4,
)
from onnx.helper import (
    float32_to_bfloat16,
    float32_to_float8e4m3,
    float32_to_float8e5m2,
    tensor_dtype_to_np_dtype,
)
from onnx.numpy_helper import (
    bfloat16_to_float32,
    float8e4m3_to_float32,
    float8e5m2_to_float32,
)
from onnx.onnx_pb import TensorProto
from onnx.reference.op_run import OpRun


def cast_to(x, to, saturate):  # noqa: PLR0911
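    """Cast ``x`` to the element type given by ``to`` (a TensorProto value).

    ``saturate`` only applies to casts into the float8 types: when true,
    out-of-range values clamp to the target's largest finite value instead
    of becoming infinity/NaN.
    """
    # Source is bfloat16 (stored in a custom uint16-based dtype): decode
    # element-wise to float32, then let numpy cast to the requested type.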
    if x.dtype == bfloat16 and x.dtype.descr[0][0] == "bfloat16":
        if to == TensorProto.BFLOAT16:
            return x
        xr = x.ravel()
        xf = np.empty(xr.shape[0], dtype=np.float32)
        for i in range(xr.shape[0]):
            el = bfloat16_to_float32(xr[i])
            xf[i] = el
        dtype = tensor_dtype_to_np_dtype(to)
        return xf.astype(dtype).reshape(x.shape)

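    # Each float8 source type maps (custom numpy dtype, descriptor name,
    # TensorProto enum) to its float32 decoder.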
    f8 = {
        (float8e4m3fn, "e4m3fn", TensorProto.FLOAT8E4M3FN): float8e4m3_to_float32,
        (
            float8e4m3fnuz,
            "e4m3fnuz",
            TensorProto.FLOAT8E4M3FNUZ,
        ): lambda *args: float8e4m3_to_float32(*args, uz=True),
        (float8e5m2, "e5m2", TensorProto.FLOAT8E5M2): float8e5m2_to_float32,
        (
            float8e5m2fnuz,
            "e5m2fnuz",
            TensorProto.FLOAT8E5M2FNUZ,
        ): lambda *args: float8e5m2_to_float32(*args, fn=True, uz=True),
    }

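    # Source is a float8 type: decode element-wise to float32, then cast.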
    for (dt, st, proto_type), cvt in f8.items():
        if x.dtype == dt and x.dtype.descr[0][0] == st:
            if to == proto_type:
                return x
            xr = x.ravel()
            xf = np.empty(xr.shape[0], dtype=np.float32)
            for i in range(xr.shape[0]):
                el = cvt(xr[i])
                xf[i] = el
            dtype = tensor_dtype_to_np_dtype(to)
            return xf.astype(dtype).reshape(x.shape)

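    # Target is bfloat16: round-trip through float32, truncating the mantissa.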
    if to == TensorProto.BFLOAT16:
        xf = x.astype(np.float32).ravel()
        y = np.empty(xf.shape, dtype=bfloat16).ravel()
        for i in range(y.shape[0]):
            el = float32_to_bfloat16(xf[i], truncate=True)  # type: ignore[assignment]
            y[i] = el
        return y.reshape(x.shape)

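    # 4-bit integer types; values are stored unpacked, one per byte.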
    i4 = [
        (uint4, "uint4", TensorProto.UINT4, False),
        (int4, "int4", TensorProto.INT4, True),
    ]
    for np_type, np_desc, tensor_type, signed in i4:
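        # Source is a 4-bit type: numpy can cast the unpacked bytes directly.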
        if x.dtype == np_type and x.dtype.descr[0][0] == np_desc:
            if to == tensor_type:
                return x
            to_type = tensor_dtype_to_np_dtype(to)
            return x.astype(to_type)

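        # Target is a 4-bit type: convert element-wise to the unpacked form.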
        if to == tensor_type:
            xf = x.astype(np.float32).ravel()
            y = np.empty(xf.shape, dtype=np_type).ravel()
            for i in range(y.shape[0]):
                el = subbyte.float32_to_4bit_unpacked(xf[i], signed=signed)
                y[i] = el
            # This operator produces a tensor with the same shape for INT4.
            return y.reshape(x.shape)

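    # Maps each float8 target type to (custom numpy dtype, float32 encoder);
    # the encoders close over `saturate` to pick clamping vs. infinity/NaN.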
    f8back = {
        TensorProto.FLOAT8E4M3FN: (
            float8e4m3fn,
            lambda *args: float32_to_float8e4m3(*args, saturate=saturate),
        ),
        TensorProto.FLOAT8E4M3FNUZ: (
            float8e4m3fnuz,
            lambda *args: float32_to_float8e4m3(*args, uz=True, saturate=saturate),
        ),
        TensorProto.FLOAT8E5M2: (
            float8e5m2,
            lambda *args: float32_to_float8e5m2(*args, saturate=saturate),
        ),
        TensorProto.FLOAT8E5M2FNUZ: (
            float8e5m2fnuz,
            lambda *args: float32_to_float8e5m2(
                *args, fn=True, uz=True, saturate=saturate
            ),
        ),
    }
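    # Target is a float8 type: encode element-wise from float32.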
    for dt, (npdt, cvt) in f8back.items():
        if to == dt:
            xf = x.astype(np.float32).ravel()
            y = np.empty(xf.shape, dtype=npdt).ravel()
            for i in range(y.shape[0]):
                el = cvt(xf[i])  # type: ignore[assignment]
                y[i] = el
            return y.reshape(x.shape)

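    # Strings are handled by numpy's native unicode dtype.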
    if to == TensorProto.STRING:
        return x.astype(np.str_)

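    # All remaining targets map to standard numpy dtypes; plain astype works.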
    dtype = tensor_dtype_to_np_dtype(to)
    return x.astype(dtype)


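# Cast before opset 19 has no `saturate` attribute; the reference
# implementation saturates by default.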
class Cast_1(OpRun):
    def _run(self, x, to=None):  # type: ignore
        return (cast_to(x, to, saturate=True),)


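# Cast-19 adds the `saturate` attribute, which only affects float8 targets.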
class Cast_19(OpRun):
    def _run(self, x, to=None, saturate=None):  # type: ignore
        return (cast_to(x, to, saturate),)
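

# Minimal usage sketch (assumes a float32 input; all names are defined in
# this module):
#
#     x = np.array([1e6, -2.5, 0.1], dtype=np.float32)
#     y = cast_to(x, TensorProto.FLOAT16, saturate=True)
#     assert y.dtype == np.float16 and y.shape == x.shape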