# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
|
|
from onnx.reference.op_run import OpRun
|
|
|
|
|
|
def _conv_implementation( # type: ignore
|
|
X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides
|
|
):
|
|
if dilations is None:
|
|
dilations = [1 for s in X.shape[2:]]
|
|
if kernel_shape is None:
|
|
kernel_shape = W.shape[2:]
|
|
if pads is None:
|
|
pads = [0 for s in X.shape[2:]] * 2
|
|
if strides is None:
|
|
strides = [1 for s in X.shape[2:]]
|
|
|
|
if X.shape[1] != W.shape[1] * group or W.shape[0] % group != 0:
|
|
raise ValueError(
|
|
f"Shape inconsistencies, X.shape={X.shape}, W.shape={W.shape}, group={group}, "
|
|
f"W should be {(W.shape[0], X.shape[1] // group, np.prod(W.shape[1:]) // X.shape[1] * group)}."
|
|
)
|
|
if group > 1:
|
|
res = []
|
|
td = 0
|
|
mg = W.shape[0] // group
|
|
dw = W.shape[1]
|
|
|
|
for b in range(X.shape[0]):
|
|
for g in range(group):
|
|
gx = X[b : b + 1, g * dw : (g + 1) * dw]
|
|
gw = W[g * mg : (g + 1) * mg]
|
|
try:
|
|
cv = _conv_implementation(
|
|
gx,
|
|
gw,
|
|
None,
|
|
auto_pad,
|
|
dilations,
|
|
1,
|
|
kernel_shape,
|
|
pads,
|
|
strides,
|
|
)
|
|
except (ValueError, RuntimeError) as e:
|
|
raise ValueError(
|
|
f"Shape inconsistencies, X.shape={X.shape}, W.shape={W.shape}, group={g}/{group}, "
|
|
f"gx.shape={gx.shape}, gw.shape={gw.shape}, auto_pad={auto_pad}, "
|
|
f"dilations={dilations}, kernel_shape={kernel_shape}, pads={pads}, "
|
|
f"strides={strides}."
|
|
) from e
|
|
if b == 0:
|
|
td += cv.shape[1]
|
|
res.append((b, cv))
|
|
|
|
new_shape = [X.shape[0], *list(res[0][1].shape[1:])]
|
|
new_shape[1] = td
|
|
final = np.zeros(tuple(new_shape), dtype=res[0][1].dtype)
|
|
p = 0
|
|
for b, cv in res:
|
|
final[b : b + 1, p : p + cv.shape[1]] = cv
|
|
p += cv.shape[1]
|
|
if p >= final.shape[1]:
|
|
p = 0
|
|
if B is not None:
|
|
new_shape = [1 for s in final.shape]
|
|
new_shape[1] = B.shape[0]
|
|
b = B.reshape(tuple(new_shape))
|
|
final += b
|
|
return final
|
|
|
|
if dilations[0] != 1 or min(dilations) != max(dilations):
|
|
# Let's compute the dilated kernel.
|
|
nd = len(dilations)
|
|
new_kernel_shape = []
|
|
new_shape = list(W.shape[:-nd])
|
|
for i, d in enumerate(dilations):
|
|
di = len(W.shape) - nd + i
|
|
new_shape.append(W.shape[di] + (W.shape[di] - 1) * (d - 1))
|
|
new_kernel_shape.append(kernel_shape[i] + (kernel_shape[i] - 1) * (d - 1))
|
|
new_w = np.zeros(tuple(new_shape), dtype=W.dtype)
|
|
indices = [slice(0, new_w.shape[0]), slice(0, new_w.shape[1])]
|
|
for i, d in enumerate(dilations):
|
|
di = len(W.shape) - nd + i
|
|
indices.append(slice(0, new_w.shape[di], d))
|
|
new_w[tuple(indices)] = W
|
|
W = new_w
|
|
kernel_shape = new_kernel_shape
|
|
|
|
if auto_pad in {"SAME_LOWER", "SAME_UPPER", "VALID"}:
|
|
head = []
|
|
tail = []
|
|
for i in range(len(X.shape) - 2):
|
|
d = X.shape[i]
|
|
target_size = (d + strides[i] - 1) // strides[i]
|
|
pad_needed = (target_size - 1) * strides[i] + kernel_shape[i] - d
|
|
if auto_pad == "SAME_LOWER":
|
|
pad_head = (pad_needed + 1) // 2
|
|
else:
|
|
pad_head = pad_needed // 2
|
|
pad_tail = pad_needed - pad_head
|
|
head.append(pad_head)
|
|
tail.append(pad_tail)
|
|
pads = head + tail
|
|
|
|
if len(X.shape) == 3:
|
|
sN, sC, sH = X.shape
|
|
# M, C_group, kH, kW = W.shape
|
|
(kh,) = kernel_shape
|
|
(sth,) = strides
|
|
|
|
h_out = int(((sH - kh + pads[0] + pads[1]) / sth) + 1)
|
|
|
|
h0 = pads[0]
|
|
oh = -1 * (kh % 2)
|
|
bh = -h0
|
|
eh = h_out * sth
|
|
res = np.zeros((X.shape[0], W.shape[0], h_out)) # type: ignore[assignment]
|
|
if B is not None:
|
|
res[:, :, :] += B.reshape((1, -1, 1)) # type: ignore
|
|
|
|
for n in range(sN):
|
|
for nw in range(W.shape[0]):
|
|
for c in range(sC):
|
|
w = W[nw : nw + 1, c : c + 1]
|
|
for io in range(bh, eh, sth):
|
|
hr = (io - bh) // sth
|
|
if hr >= h_out:
|
|
continue
|
|
i = io + kh % 2
|
|
ih1, ih2 = max(0, i + oh), min(i + oh + kh, sH)
|
|
img = X[n : n + 1, c : c + 1, ih1:ih2]
|
|
if img.shape != w.shape:
|
|
jh1, jh2 = max(-oh - i, 0), min(kh, kh + sH - (i + oh + kh))
|
|
w_ = w[:1, :1, jh1:jh2]
|
|
if img.shape != w_.shape:
|
|
raise RuntimeError(
|
|
f"Unexpected shape {img.shape} != {w_.shape}, oh={oh}, "
|
|
f"i={i}, kh={kh}, sH={sH}, sth={sth}."
|
|
)
|
|
s = np.dot(img.reshape((1, -1)), w_.reshape((-1, 1)))[
|
|
0, 0
|
|
] # (img * w_).sum()
|
|
else:
|
|
s = np.dot(img.reshape((1, -1)), w.reshape((-1, 1)))[
|
|
0, 0
|
|
] # (img * w).sum()
|
|
res[n, nw, hr] += s # type: ignore
|
|
|
|
return res
|
|
|
|
if len(X.shape) == 4:
|
|
sN, sC, sH, sW = X.shape
|
|
# M, C_group, kH, kW = W.shape
|
|
kh, kw = kernel_shape
|
|
sth, stw = strides
|
|
|
|
h_out = int(((sH - kh + pads[0] + pads[2]) / sth) + 1)
|
|
w_out = int(((sW - kw + pads[1] + pads[3]) / stw) + 1)
|
|
|
|
h0, w0 = pads[0], pads[1]
|
|
oh, ow = -1 * (kh % 2), -1 * (kw % 2)
|
|
bh, bw = -h0, -w0
|
|
eh, ew = h_out * sth, w_out * stw
|
|
res = np.zeros((X.shape[0], W.shape[0], h_out, w_out)) # type: ignore[assignment]
|
|
if B is not None:
|
|
res[:, :, :, :] = B.reshape((1, -1, 1, 1)) # type: ignore
|
|
|
|
for n in range(sN):
|
|
for nw in range(W.shape[0]):
|
|
for c in range(sC):
|
|
w = W[nw : nw + 1, c : c + 1]
|
|
for io in range(bh, eh, sth):
|
|
hr = (io - bh) // sth
|
|
if hr >= h_out:
|
|
continue
|
|
i = io + kh % 2
|
|
ih1, ih2 = max(0, i + oh), min(i + oh + kh, sH)
|
|
for jo in range(bw, ew, stw):
|
|
wr = (jo - bw) // stw
|
|
if wr >= w_out:
|
|
continue
|
|
j = jo + kw % 2
|
|
iw1, iw2 = max(0, j + ow), min(j + ow + kw, sW)
|
|
img = X[n : n + 1, c : c + 1, ih1:ih2, iw1:iw2]
|
|
if img.shape != w.shape:
|
|
jh1, jh2 = (
|
|
max(-oh - i, 0),
|
|
min(kh, kh + sH - (i + oh + kh)),
|
|
)
|
|
jw1, jw2 = (
|
|
max(-ow - j, 0),
|
|
min(kw, kw + sW - (j + ow + kw)),
|
|
)
|
|
w_ = w[:1, :1, jh1:jh2, jw1:jw2]
|
|
if img.shape != w_.shape:
|
|
raise RuntimeError(
|
|
f"Unexpected shape {img.shape} != {w_.shape}, oh={oh}, ow={ow}, "
|
|
f"i={i}, j={j}, kh={kh}, kw={kw}, sH={sH}, sW={sW}, sth={sth}, stw={stw}."
|
|
)
|
|
s = np.dot(img.reshape((1, -1)), w_.reshape((-1, 1)))[
|
|
0, 0
|
|
] # (img * w_).sum()
|
|
else:
|
|
s = np.dot(img.reshape((1, -1)), w.reshape((-1, 1)))[
|
|
0, 0
|
|
] # (img * w).sum()
|
|
res[n, nw, hr, wr] += s # type: ignore
|
|
|
|
return res
|
|
|
|
if len(X.shape) == 5:
|
|
sN, sC, sH, sW, sZ = X.shape
|
|
kh, kw, kz = kernel_shape
|
|
sth, stw, stz = strides
|
|
|
|
h_out = int(((sH - kh + pads[0] + pads[3]) / sth) + 1)
|
|
w_out = int(((sW - kw + pads[1] + pads[4]) / stw) + 1)
|
|
z_out = int(((sZ - kz + pads[2] + pads[5]) / stz) + 1)
|
|
|
|
h0, w0, z0 = pads[0], pads[1], pads[2]
|
|
oh, ow, oz = -1 * (kh % 2), -1 * (kw % 2), -1 * (kz % 2)
|
|
bh, bw, bz = -h0, -w0, -z0
|
|
eh, ew, ez = h_out * sth, w_out * stw, z_out * stz
|
|
res = np.zeros((X.shape[0], W.shape[0], h_out, w_out, z_out)) # type: ignore[assignment]
|
|
if B is not None:
|
|
res[:, :, :, :, :] = B.reshape((1, -1, 1, 1, 1)) # type: ignore
|
|
|
|
for n in range(sN):
|
|
for nw in range(W.shape[0]):
|
|
for c in range(sC):
|
|
w = W[nw : nw + 1, c : c + 1]
|
|
for io in range(bh, eh, sth):
|
|
hr = (io - bh) // sth
|
|
if hr >= h_out:
|
|
continue
|
|
i = io + kh % 2
|
|
ih1, ih2 = max(0, i + oh), min(i + oh + kh, sH)
|
|
for jo in range(bw, ew, stw):
|
|
wr = (jo - bw) // stw
|
|
if wr >= w_out:
|
|
continue
|
|
j = jo + kw % 2
|
|
iw1, iw2 = max(0, j + ow), min(j + ow + kw, sW)
|
|
for zo in range(bz, ez, stz):
|
|
zr = (zo - bz) // stz
|
|
if zr >= z_out:
|
|
continue
|
|
z = zo + kz % 2
|
|
iz1, iz2 = max(0, z + oz), min(z + oz + kz, sZ)
|
|
img = X[n : n + 1, c : c + 1, ih1:ih2, iw1:iw2, iz1:iz2]
|
|
if img.shape != w.shape:
|
|
jh1, jh2 = (
|
|
max(-oh - i, 0),
|
|
min(kh, kh + sH - (i + oh + kh)),
|
|
)
|
|
jw1, jw2 = (
|
|
max(-ow - j, 0),
|
|
min(kw, kw + sW - (j + ow + kw)),
|
|
)
|
|
jz1, jz2 = (
|
|
max(-oz - z, 0),
|
|
min(kz, kz + sZ - (z + oz + kz)),
|
|
)
|
|
w_ = w[:1, :1, jh1:jh2, jw1:jw2, jz1:jz2]
|
|
if img.shape != w_.shape:
|
|
raise RuntimeError(
|
|
f"Unexpected shape {img.shape} != {w_.shape}, oh={oh}, ow={ow}, oz={oz}, "
|
|
f"i={i}, j={j}, z={z}, kh={kh}, kw={kw}, kz={kz}, "
|
|
f"sH={sH}, sW={sW}, sZ={sZ}, sth={sth}, stw={stw}, stz={stz}."
|
|
)
|
|
s = np.dot(
|
|
img.reshape((1, -1)), w_.reshape((-1, 1))
|
|
)[
|
|
0, 0
|
|
] # (img * w_).sum()
|
|
else:
|
|
s = np.dot(
|
|
img.reshape((1, -1)), w.reshape((-1, 1))
|
|
)[
|
|
0, 0
|
|
] # (img * w).sum()
|
|
res[n, nw, hr, wr, zr] += s # type: ignore
|
|
|
|
return res
|
|
|
|
raise RuntimeError(
|
|
f"The convolution for X.shape={X.shape}, W.shape={W.shape}, "
|
|
f"kernel_shape={kernel_shape} is not implemented yet."
|
|
)
|
|
|
|
|
|
class Conv(OpRun):
    """Convolution operator evaluated by delegating to ``_conv_implementation``."""

    def _run(  # type: ignore
        self,
        X,
        W,
        B=None,
        auto_pad=None,
        dilations=None,
        group=None,
        kernel_shape=None,
        pads=None,
        strides=None,
    ):
        """Convolve X with W and return the result as a one-element tuple.

        All arguments are forwarded unchanged to ``_conv_implementation``;
        the result is cast to the dtype of X.

        Raises:
            ValueError: if X has fewer than 3 dimensions.
        """
        rank = len(X.shape)
        if rank <= 2:
            raise ValueError(
                f"X must have at least 3 dimensions but its shape is {X.shape}."
            )

        convolved = _conv_implementation(
            X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides
        )
        # The result is wrapped in a tuple holding the single output.
        return (convolved.astype(X.dtype),)
|