73 lines
2.2 KiB
Python
73 lines
2.2 KiB
Python
# Copyright (c) ONNX Project Contributors
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
|
|
# Inclusive value ranges of the 4-bit integer types. The signed range is the
# two's-complement range of 4 bits, [-2**3, 2**3 - 1]; the unsigned range is
# [0, 2**4 - 1].
INT4_MIN = -(1 << 3)
INT4_MAX = (1 << 3) - 1
UINT4_MIN = 0
UINT4_MAX = (1 << 4) - 1
|
|
|
|
|
|
def float32_to_4bit_unpacked(
    x: np.ndarray | np.dtype | float, signed: bool
) -> np.ndarray:
    """Round and saturate values to the 4-bit integer range, one value per byte.

    Values are first clipped to ``[INT4_MIN, INT4_MAX]`` (signed) or
    ``[UINT4_MIN, UINT4_MAX]`` (unsigned) and then rounded with ``np.rint``
    (ties go to the nearest even integer). No packing takes place. Array
    inputs are processed elementwise.

    Args:
        x: scalar or array of values to convert.
        signed: whether to target signed int4 rather than unsigned uint4.

    Returns:
        The converted value(s) with the same shape as ``x``, stored as int8
        when ``signed`` is true and as uint8 otherwise.
    """
    if signed:
        target_dtype, lower, upper = np.int8, INT4_MIN, INT4_MAX
    else:
        target_dtype, lower, upper = np.uint8, UINT4_MIN, UINT4_MAX

    values = x if isinstance(x, np.ndarray) else np.asarray(x)

    # Saturate before rounding/casting so out-of-range inputs can never wrap
    # around in the narrow integer dtype.
    rounded = np.rint(np.clip(values, lower, upper))
    return rounded.astype(target_dtype)  # type: ignore[no-any-return]
|
|
|
|
|
|
def float32x2_to_4bitx2(
    val_low: np.dtype, val_high: np.dtype, signed: bool
) -> np.ndarray:
    """Quantize two values to 4 bits each and pack them into one byte.

    Each value is rounded and saturated by ``float32_to_4bit_unpacked``. The
    two results are then combined so that ``val_low`` occupies bits 0-3 and
    ``val_high`` occupies bits 4-7. Broadcast-compatible arrays are packed
    elementwise.

    Args:
        val_low: value stored in the 4 least significant bits.
        val_high: value stored in the 4 most significant bits.
        signed: whether both values are quantized as signed int4.

    Returns:
        The packed byte(s). The dtype is int8 when ``signed`` is true, so a
        negative high nibble makes the byte negative. Otherwise the dtype is
        uint8.
    """
    high_bits = float32_to_4bit_unpacked(val_high, signed) << 4
    # Keep only the low nibble: a negative signed value carries sign bits
    # above bit 3 that would otherwise clobber the high nibble.
    low_bits = float32_to_4bit_unpacked(val_low, signed) & 0x0F
    return high_bits | low_bits  # type: ignore[operator]
|
|
|
|
|
|
def unpack_single_4bitx2(
    x: np.ndarray | np.dtype | float, signed: bool
) -> tuple[np.ndarray, np.ndarray]:
    """Unpack a single byte 4bitx2 to two 4 bit elements.

    The low nibble (4 LSB) is returned first and the high nibble (4 MSB)
    second. This matches the layout produced by ``float32x2_to_4bitx2``.
    Array inputs are unpacked elementwise.

    Args:
        x: Input data, one packed byte per element. Any integer dtype is
            accepted, including signed int8 bytes. Only the lowest 8 bits of
            each element are used.
        signed: boolean, whether to interpret as signed int4.

    Returns:
        A tuple ``(low, high)`` of 4-bit values. When ``signed`` is true they
        are sign-extended to int8; otherwise they are returned as uint8.
    """
    if not isinstance(x, np.ndarray):
        x = np.asarray(x)

    # Mask both nibbles explicitly. ``>>`` is an arithmetic shift on signed
    # dtypes, so without the mask a negative int8 byte would leak sign bits
    # into the high nibble. Widening to int16 lets the sign extension below
    # use plain subtraction instead of ``| 0xF0`` in the input dtype, which
    # NumPy 2 rejects for int8 (0xF0 does not fit).
    x_low = (x & 0x0F).astype(np.int16)
    x_high = ((x >> 4) & 0x0F).astype(np.int16)

    if signed:
        # Two's-complement sign extension of a 4-bit value: 8..15 -> -8..-1.
        x_low = np.where(x_low >= 8, x_low - 16, x_low)
        x_high = np.where(x_high >= 8, x_high - 16, x_high)

    dtype = np.int8 if signed else np.uint8
    return (x_low.astype(dtype), x_high.astype(dtype))
|