I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,13 @@
import tensorboard
from torch._vendor.packaging.version import Version
if not hasattr(tensorboard, "__version__") or Version(
tensorboard.__version__
) < Version("1.15"):
raise ImportError("TensorBoard logging requires TensorBoard version 1.15 or above")
del Version
del tensorboard
from .writer import FileWriter, SummaryWriter # noqa: F401
from tensorboard.summary.writer.record_writer import RecordWriter # noqa: F401

View File

@ -0,0 +1,34 @@
# mypy: allow-untyped-defs
"""This module converts objects into numpy array."""
import numpy as np
import torch
def make_np(x):
"""
Convert an object into numpy array.
Args:
x: An instance of torch tensor
Returns:
numpy.array: Numpy array
"""
if isinstance(x, np.ndarray):
return x
if np.isscalar(x):
return np.array([x])
if isinstance(x, torch.Tensor):
return _prepare_pytorch(x)
raise NotImplementedError(
f"Got {type(x)}, but numpy array or torch tensor are expected."
)
def _prepare_pytorch(x):
if x.dtype == torch.bfloat16:
x = x.to(torch.float16)
x = x.detach().cpu().numpy()
return x

View File

@ -0,0 +1,86 @@
# mypy: allow-untyped-defs
import math
import numpy as np
from ._convert_np import make_np
from ._utils import make_grid
from tensorboard.compat import tf
from tensorboard.plugins.projector.projector_config_pb2 import EmbeddingInfo
_HAS_GFILE_JOIN = hasattr(tf.io.gfile, "join")
def _gfile_join(a, b):
# The join API is different between tensorboard's TF stub and TF:
# https://github.com/tensorflow/tensorboard/issues/6080
# We need to try both because `tf` may point to either the stub or the real TF.
if _HAS_GFILE_JOIN:
return tf.io.gfile.join(a, b)
else:
fs = tf.io.gfile.get_filesystem(a)
return fs.join(a, b)
def make_tsv(metadata, save_path, metadata_header=None):
if not metadata_header:
metadata = [str(x) for x in metadata]
else:
assert len(metadata_header) == len(
metadata[0]
), "len of header must be equal to the number of columns in metadata"
metadata = ["\t".join(str(e) for e in l) for l in [metadata_header] + metadata]
metadata_bytes = tf.compat.as_bytes("\n".join(metadata) + "\n")
with tf.io.gfile.GFile(_gfile_join(save_path, "metadata.tsv"), "wb") as f:
f.write(metadata_bytes)
# https://github.com/tensorflow/tensorboard/issues/44 image label will be squared
def make_sprite(label_img, save_path):
from PIL import Image
from io import BytesIO
# this ensures the sprite image has correct dimension as described in
# https://www.tensorflow.org/get_started/embedding_viz
nrow = int(math.ceil((label_img.size(0)) ** 0.5))
arranged_img_CHW = make_grid(make_np(label_img), ncols=nrow)
# augment images so that #images equals nrow*nrow
arranged_augment_square_HWC = np.zeros(
(arranged_img_CHW.shape[2], arranged_img_CHW.shape[2], 3)
)
arranged_img_HWC = arranged_img_CHW.transpose(1, 2, 0) # chw -> hwc
arranged_augment_square_HWC[: arranged_img_HWC.shape[0], :, :] = arranged_img_HWC
im = Image.fromarray(np.uint8((arranged_augment_square_HWC * 255).clip(0, 255)))
with BytesIO() as buf:
im.save(buf, format="PNG")
im_bytes = buf.getvalue()
with tf.io.gfile.GFile(_gfile_join(save_path, "sprite.png"), "wb") as f:
f.write(im_bytes)
def get_embedding_info(metadata, label_img, subdir, global_step, tag):
info = EmbeddingInfo()
info.tensor_name = f"{tag}:{str(global_step).zfill(5)}"
info.tensor_path = _gfile_join(subdir, "tensors.tsv")
if metadata is not None:
info.metadata_path = _gfile_join(subdir, "metadata.tsv")
if label_img is not None:
info.sprite.image_path = _gfile_join(subdir, "sprite.png")
info.sprite.single_image_dim.extend([label_img.size(3), label_img.size(2)])
return info
def write_pbtxt(save_path, contents):
config_path = _gfile_join(save_path, "projector_config.pbtxt")
with tf.io.gfile.GFile(config_path, "wb") as f:
f.write(tf.compat.as_bytes(contents))
def make_mat(matlist, save_path):
with tf.io.gfile.GFile(_gfile_join(save_path, "tensors.tsv"), "wb") as f:
for x in matlist:
x = [str(i.item()) for i in x]
f.write(tf.compat.as_bytes("\t".join(x) + "\n"))

View File

@ -0,0 +1,63 @@
# mypy: allow-untyped-defs
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.node_def_pb2 import NodeDef
from tensorboard.compat.proto.versions_pb2 import VersionDef
from tensorboard.compat.proto.attr_value_pb2 import AttrValue
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
def load_onnx_graph(fname):
import onnx
m = onnx.load(fname) # type: ignore[attr-defined]
g = m.graph
return parse(g)
def parse(graph):
nodes = []
import itertools
nodes_proto = list(itertools.chain(graph.input, graph.output))
for node in nodes_proto:
print(node.name)
shapeproto = TensorShapeProto(
dim=[
TensorShapeProto.Dim(size=d.dim_value)
for d in node.type.tensor_type.shape.dim
]
)
nodes.append(
NodeDef(
name=node.name.encode(encoding="utf_8"),
op="Variable",
input=[],
attr={
"dtype": AttrValue(type=node.type.tensor_type.elem_type),
"shape": AttrValue(shape=shapeproto),
},
)
)
for node in graph.node:
_attr = []
for s in node.attribute:
_attr.append(" = ".join([str(f[1]) for f in s.ListFields()]))
attr = ", ".join(_attr).encode(encoding="utf_8")
print(node.output[0])
nodes.append(
NodeDef(
name=node.output[0].encode(encoding="utf_8"),
op=node.op_type,
input=node.input,
attr={"parameters": AttrValue(s=attr)},
)
)
# two pass token replacement, appends opname to object id
mapping = {}
for node in nodes:
mapping[node.name] = node.op + "_" + node.name
return GraphDef(node=nodes, versions=VersionDef(producer=22))

View File

@ -0,0 +1,54 @@
# mypy: allow-untyped-defs
from typing import Optional
from tensorboard.compat.proto.node_def_pb2 import NodeDef
from tensorboard.compat.proto.attr_value_pb2 import AttrValue
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
def attr_value_proto(dtype, shape, s):
"""Create a dict of objects matching a NodeDef's attr field.
Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto
specifically designed for a NodeDef. The values have been reverse engineered from
standard TensorBoard logged data.
"""
attr = {}
if s is not None:
attr["attr"] = AttrValue(s=s.encode(encoding="utf_8"))
if shape is not None:
shapeproto = tensor_shape_proto(shape)
attr["_output_shapes"] = AttrValue(list=AttrValue.ListValue(shape=[shapeproto]))
return attr
def tensor_shape_proto(outputsize):
"""Create an object matching a tensor_shape field.
Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto .
"""
return TensorShapeProto(dim=[TensorShapeProto.Dim(size=d) for d in outputsize])
def node_proto(
name,
op="UnSpecified",
input=None,
dtype=None,
shape: Optional[tuple] = None,
outputsize=None,
attributes="",
):
"""Create an object matching a NodeDef.
Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto .
"""
if input is None:
input = []
if not isinstance(input, list):
input = [input]
return NodeDef(
name=name.encode(encoding="utf_8"),
op=op,
input=input,
attr=attr_value_proto(dtype, outputsize, attributes),
)

View File

@ -0,0 +1,381 @@
# mypy: allow-untyped-defs
from collections import OrderedDict
import contextlib
from typing import Dict, Any
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
from tensorboard.compat.proto.versions_pb2 import VersionDef
import torch
from ._proto_graph import node_proto
methods_OP = [
"attributeNames",
"hasMultipleOutputs",
"hasUses",
"inputs",
"kind",
"outputs",
"outputsSize",
"scopeName",
]
# Some additional methods to explure for methods_IO are
#
# 'unique' (type int)
# 'type' (type <Tensor<class 'torch._C.Type'>>)
#
# But the below are sufficient for now.
methods_IO = ["node", "offset", "debugName"]
GETATTR_KIND = "prim::GetAttr"
CLASSTYPE_KIND = "ClassType"
class NodeBase:
def __init__(
self,
debugName=None,
inputs=None,
scope=None,
tensor_size=None,
op_type="UnSpecified",
attributes="",
):
# TODO; Specify a __slots__ for this class or potentially
# used namedtuple instead
self.debugName = debugName
self.inputs = inputs
self.tensor_size = tensor_size
self.kind = op_type
self.attributes = attributes
self.scope = scope
def __repr__(self):
repr = []
repr.append(str(type(self)))
for m in dir(self):
if "__" not in m:
repr.append(
m + ": " + str(getattr(self, m)) + str(type(getattr(self, m)))
)
return "\n".join(repr) + "\n\n"
class NodePy(NodeBase):
def __init__(self, node_cpp, valid_methods):
super().__init__(node_cpp)
valid_methods = valid_methods[:]
self.inputs = []
for m in valid_methods:
if m == "inputs" or m == "outputs":
list_of_node = list(getattr(node_cpp, m)())
io_unique_names = []
io_tensor_sizes = []
for n in list_of_node:
io_unique_names.append(n.debugName())
if n.isCompleteTensor():
io_tensor_sizes.append(n.type().sizes())
else:
io_tensor_sizes.append(None)
setattr(self, m, io_unique_names)
setattr(self, m + "tensor_size", io_tensor_sizes)
else:
setattr(self, m, getattr(node_cpp, m)())
class NodePyIO(NodePy):
def __init__(self, node_cpp, input_or_output=None):
super().__init__(node_cpp, methods_IO)
try:
tensor_size = node_cpp.type().sizes()
except RuntimeError:
tensor_size = [
1,
] # fail when constant model is used.
self.tensor_size = tensor_size
# Kind attribute string is purely descriptive and will be shown
# in detailed information for the node in TensorBoard's graph plugin.
#
# NodePyOP nodes get this from their kind() method.
self.kind = "Parameter"
if input_or_output:
self.input_or_output = input_or_output
self.kind = "IO Node"
class NodePyOP(NodePy):
def __init__(self, node_cpp):
super().__init__(node_cpp, methods_OP)
# Replace single quote which causes strange behavior in TensorBoard
# TODO: See if we can remove this in the future
self.attributes = str(
{k: _node_get(node_cpp, k) for k in node_cpp.attributeNames()}
).replace("'", " ")
self.kind = node_cpp.kind()
class GraphPy:
"""Helper class to convert torch.nn.Module to GraphDef proto and visualization with TensorBoard.
GraphDef generation operates in two passes:
In the first pass, all nodes are read and saved to two lists.
One list is for input/output nodes (nodes_io), which only have inbound
or outbound connections, but not both. Another list is for internal
operator nodes (nodes_op). The first pass also saves all scope name
appeared in the nodes in scope_name_appeared list for later processing.
In the second pass, scope names are fully applied to all nodes.
debugNameToScopedName is a mapping from a node's ID to its fully qualified
scope name. e.g. Net1/Linear[0]/1. Unfortunately torch.jit doesn't have
totally correct scope output, so this is nontrivial. The function
populate_namespace_from_OP_to_IO and find_common_root are used to
assign scope name to a node based on the connection between nodes
in a heuristic kind of way. Bookkeeping is done with shallowest_scope_name
and scope_name_appeared.
"""
def __init__(self):
self.nodes_op = []
self.nodes_io = OrderedDict()
self.unique_name_to_scoped_name = {}
self.shallowest_scope_name = "default"
self.scope_name_appeared = []
def append(self, x):
if isinstance(x, NodePyIO):
self.nodes_io[x.debugName] = x
if isinstance(x, NodePyOP):
self.nodes_op.append(x)
def printall(self):
print("all nodes")
for node in self.nodes_op:
print(node)
for key in self.nodes_io:
print(self.nodes_io[key])
def find_common_root(self):
for fullscope in self.scope_name_appeared:
if fullscope:
self.shallowest_scope_name = fullscope.split("/")[0]
def populate_namespace_from_OP_to_IO(self):
for node in self.nodes_op:
for node_output, outputSize in zip(node.outputs, node.outputstensor_size):
self.scope_name_appeared.append(node.scopeName)
self.nodes_io[node_output] = NodeBase(
node_output,
node.inputs,
node.scopeName,
outputSize,
op_type=node.kind,
attributes=node.attributes,
)
self.find_common_root()
for node in self.nodes_op:
for input_node_id in node.inputs:
self.unique_name_to_scoped_name[input_node_id] = (
node.scopeName + "/" + input_node_id
)
for key, node in self.nodes_io.items():
if type(node) == NodeBase:
self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
if hasattr(node, "input_or_output"):
self.unique_name_to_scoped_name[key] = (
node.input_or_output + "/" + node.debugName
)
if hasattr(node, "scope") and node.scope is not None:
self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
if node.scope == "" and self.shallowest_scope_name:
self.unique_name_to_scoped_name[node.debugName] = (
self.shallowest_scope_name + "/" + node.debugName
)
# replace name
for key, node in self.nodes_io.items():
self.nodes_io[key].inputs = [
self.unique_name_to_scoped_name[node_input_id]
for node_input_id in node.inputs
]
if node.debugName in self.unique_name_to_scoped_name:
self.nodes_io[key].debugName = self.unique_name_to_scoped_name[
node.debugName
]
def to_proto(self):
"""Convert graph representation of GraphPy object to TensorBoard required format."""
# TODO: compute correct memory usage and CPU time once
# PyTorch supports it
nodes = []
for v in self.nodes_io.values():
nodes.append(
node_proto(
v.debugName,
input=v.inputs,
outputsize=v.tensor_size,
op=v.kind,
attributes=v.attributes,
)
)
return nodes
def parse(graph, trace, args=None, omit_useless_nodes=True):
"""Parse an optimized PyTorch model graph and produces a list of nodes and node stats.
Useful for eventual conversion to TensorBoard protobuf format.
Args:
graph (PyTorch module): The model graph to be parsed.
trace (PyTorch JIT TracedModule): The model trace to be parsed.
args (tuple): input tensor[s] for the model.
omit_useless_nodes (boolean): Whether to remove nodes from the graph.
"""
n_inputs = len(args)
scope = {}
nodes_py = GraphPy()
for node in graph.inputs():
if omit_useless_nodes:
if (
len(node.uses()) == 0
): # number of user of the node (= number of outputs/ fanout)
continue
if node.type().kind() != CLASSTYPE_KIND:
nodes_py.append(NodePyIO(node, "input"))
attr_to_scope: Dict[Any, str] = {}
for node in graph.nodes():
if node.kind() == GETATTR_KIND:
attr_name = node.s("name")
attr_key = node.output().debugName()
parent = node.input().node()
if (
parent.kind() == GETATTR_KIND
): # If the parent node is not the top-level "self" node
parent_attr_name = parent.s("name")
parent_attr_key = parent.output().debugName()
parent_scope = attr_to_scope[parent_attr_key]
attr_scope = parent_scope.split("/")[-1]
attr_to_scope[attr_key] = f"{parent_scope}/{attr_scope}.{attr_name}"
else:
attr_to_scope[attr_key] = f"__module.{attr_name}"
# We don't need classtype nodes; scope will provide this information
if node.output().type().kind() != CLASSTYPE_KIND:
node_py = NodePyOP(node)
node_py.scopeName = attr_to_scope[attr_key] # type: ignore[attr-defined]
nodes_py.append(node_py)
else:
nodes_py.append(NodePyOP(node))
for i, node in enumerate(graph.outputs()): # Create sink nodes for output ops
node_pyio = NodePyIO(node, "output")
node_pyio.debugName = f"output.{i + 1}"
node_pyio.inputs = [node.debugName()]
nodes_py.append(node_pyio)
def parse_traced_name(module):
if isinstance(module, torch.jit.TracedModule):
module_name = module._name
else:
module_name = getattr(module, "original_name", "Module")
return module_name
alias_to_name = {}
base_name = parse_traced_name(trace)
for name, module in trace.named_modules(prefix="__module"):
mod_name = parse_traced_name(module)
attr_name = name.split(".")[-1]
alias_to_name[name] = f"{mod_name}[{attr_name}]"
for node in nodes_py.nodes_op:
module_aliases = node.scopeName.split("/")
replacements = [
alias_to_name[alias] if alias in alias_to_name else alias.split(".")[-1]
for alias in module_aliases
]
node.scopeName = base_name
if any(replacements):
node.scopeName += "/" + "/".join(replacements)
nodes_py.populate_namespace_from_OP_to_IO()
return nodes_py.to_proto()
def graph(model, args, verbose=False, use_strict_trace=True):
"""
Process a PyTorch model and produces a `GraphDef` proto that can be logged to TensorBoard.
Args:
model (PyTorch module): The model to be parsed.
args (tuple): input tensor[s] for the model.
verbose (bool): Whether to print out verbose information while
processing.
use_strict_trace (bool): Whether to pass keyword argument `strict` to
`torch.jit.trace`. Pass False when you want the tracer to
record your mutable container types (list, dict)
"""
with _set_model_to_eval(model):
try:
trace = torch.jit.trace(model, args, strict=use_strict_trace)
graph = trace.graph
torch._C._jit_pass_inline(graph)
except RuntimeError as e:
print(e)
print("Error occurs, No graph saved")
raise e
if verbose:
print(graph)
list_of_nodes = parse(graph, trace, args)
# We are hardcoding that this was run on CPU even though it might have actually
# run on GPU. Note this is what is shown in TensorBoard and has no bearing
# on actual execution.
# TODO: See if we can extract GPU vs CPU information from the PyTorch model
# and pass it correctly to TensorBoard.
#
# Definition of StepStats and DeviceStepStats can be found at
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
# and
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
stepstats = RunMetadata(
step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")])
)
return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
# The producer version has been reverse engineered from standard
# TensorBoard logged data.
@contextlib.contextmanager
def _set_model_to_eval(model):
"""Context manager to temporarily set the training mode of ``model`` to eval."""
if not isinstance(model, torch.jit.ScriptFunction):
originally_training = model.training
model.train(False)
try:
yield
finally:
model.train(originally_training)
else:
# Do nothing for ScriptFunction
try:
yield
finally:
pass
def _node_get(node: torch._C.Node, key: str):
"""Get attributes of a node which is polymorphic over return type."""
sel = node.kindOf(key)
return getattr(node, sel)(key)

View File

@ -0,0 +1,126 @@
# mypy: allow-untyped-defs
import numpy as np
# Functions for converting
def figure_to_image(figures, close=True):
"""Render matplotlib figure to numpy format.
Note that this requires the ``matplotlib`` package.
Args:
figures (matplotlib.pyplot.figure or list of figures): figure or a list of figures
close (bool): Flag to automatically close the figure
Returns:
numpy.array: image in [CHW] order
"""
import matplotlib.pyplot as plt
import matplotlib.backends.backend_agg as plt_backend_agg
def render_to_rgb(figure):
canvas = plt_backend_agg.FigureCanvasAgg(figure)
canvas.draw()
data: np.ndarray = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
w, h = figure.canvas.get_width_height()
image_hwc = data.reshape([h, w, 4])[:, :, 0:3]
image_chw = np.moveaxis(image_hwc, source=2, destination=0)
if close:
plt.close(figure)
return image_chw
if isinstance(figures, list):
images = [render_to_rgb(figure) for figure in figures]
return np.stack(images)
else:
image = render_to_rgb(figures)
return image
def _prepare_video(V):
"""
Convert a 5D tensor into 4D tensor.
Convesrion is done from [batchsize, time(frame), channel(color), height, width] (5D tensor)
to [time(frame), new_width, new_height, channel] (4D tensor).
A batch of images are spreaded to a grid, which forms a frame.
e.g. Video with batchsize 16 will have a 4x4 grid.
"""
b, t, c, h, w = V.shape
if V.dtype == np.uint8:
V = np.float32(V) / 255.0
def is_power2(num):
return num != 0 and ((num & (num - 1)) == 0)
# pad to nearest power of 2, all at once
if not is_power2(V.shape[0]):
len_addition = int(2 ** V.shape[0].bit_length() - V.shape[0])
V = np.concatenate((V, np.zeros(shape=(len_addition, t, c, h, w))), axis=0)
n_rows = 2 ** ((b.bit_length() - 1) // 2)
n_cols = V.shape[0] // n_rows
V = np.reshape(V, newshape=(n_rows, n_cols, t, c, h, w))
V = np.transpose(V, axes=(2, 0, 4, 1, 5, 3))
V = np.reshape(V, newshape=(t, n_rows * h, n_cols * w, c))
return V
def make_grid(I, ncols=8):
# I: N1HW or N3HW
assert isinstance(I, np.ndarray), "plugin error, should pass numpy array here"
if I.shape[1] == 1:
I = np.concatenate([I, I, I], 1)
assert I.ndim == 4 and I.shape[1] == 3
nimg = I.shape[0]
H = I.shape[2]
W = I.shape[3]
ncols = min(nimg, ncols)
nrows = int(np.ceil(float(nimg) / ncols))
canvas = np.zeros((3, H * nrows, W * ncols), dtype=I.dtype)
i = 0
for y in range(nrows):
for x in range(ncols):
if i >= nimg:
break
canvas[:, y * H : (y + 1) * H, x * W : (x + 1) * W] = I[i]
i = i + 1
return canvas
# if modality == 'IMG':
# if x.dtype == np.uint8:
# x = x.astype(np.float32) / 255.0
def convert_to_HWC(tensor, input_format): # tensor: numpy array
assert len(set(input_format)) == len(
input_format
), f"You can not use the same dimension shordhand twice. input_format: {input_format}"
assert len(tensor.shape) == len(
input_format
), f"size of input tensor and input format are different. \
tensor shape: {tensor.shape}, input_format: {input_format}"
input_format = input_format.upper()
if len(input_format) == 4:
index = [input_format.find(c) for c in "NCHW"]
tensor_NCHW = tensor.transpose(index)
tensor_CHW = make_grid(tensor_NCHW)
return tensor_CHW.transpose(1, 2, 0)
if len(input_format) == 3:
index = [input_format.find(c) for c in "HWC"]
tensor_HWC = tensor.transpose(index)
if tensor_HWC.shape[2] == 1:
tensor_HWC = np.concatenate([tensor_HWC, tensor_HWC, tensor_HWC], 2)
return tensor_HWC
if len(input_format) == 2:
index = [input_format.find(c) for c in "HW"]
tensor = tensor.transpose(index)
tensor = np.stack([tensor, tensor, tensor], 2)
return tensor

View File

@ -0,0 +1,982 @@
# mypy: allow-untyped-defs
import json
import logging
import os
import struct
from typing import Any, List, Optional
import torch
import numpy as np
from google.protobuf import struct_pb2
from tensorboard.compat.proto.summary_pb2 import (
HistogramProto,
Summary,
SummaryMetadata,
)
from tensorboard.compat.proto.tensor_pb2 import TensorProto
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboard.plugins.custom_scalar import layout_pb2
from tensorboard.plugins.pr_curve.plugin_data_pb2 import PrCurvePluginData
from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
from ._convert_np import make_np
from ._utils import _prepare_video, convert_to_HWC
__all__ = [
"half_to_int",
"int_to_half",
"hparams",
"scalar",
"histogram_raw",
"histogram",
"make_histogram",
"image",
"image_boxes",
"draw_boxes",
"make_image",
"video",
"make_video",
"audio",
"custom_scalars",
"text",
"tensor_proto",
"pr_curve_raw",
"pr_curve",
"compute_curve",
"mesh",
]
logger = logging.getLogger(__name__)
def half_to_int(f: float) -> int:
"""Casts a half-precision float value into an integer.
Converts a half precision floating point value, such as `torch.half` or
`torch.bfloat16`, into an integer value which can be written into the
half_val field of a TensorProto for storage.
To undo the effects of this conversion, use int_to_half().
"""
buf = struct.pack("f", f)
return struct.unpack("i", buf)[0]
def int_to_half(i: int) -> float:
"""Casts an integer value to a half-precision float.
Converts an integer value obtained from half_to_int back into a floating
point value.
"""
buf = struct.pack("i", i)
return struct.unpack("f", buf)[0]
def _tensor_to_half_val(t: torch.Tensor) -> List[int]:
return [half_to_int(x) for x in t.flatten().tolist()]
def _tensor_to_complex_val(t: torch.Tensor) -> List[float]:
return torch.view_as_real(t).flatten().tolist()
def _tensor_to_list(t: torch.Tensor) -> List[Any]:
return t.flatten().tolist()
# type maps: torch.Tensor type -> (protobuf type, protobuf val field)
_TENSOR_TYPE_MAP = {
torch.half: ("DT_HALF", "half_val", _tensor_to_half_val),
torch.float16: ("DT_HALF", "half_val", _tensor_to_half_val),
torch.bfloat16: ("DT_BFLOAT16", "half_val", _tensor_to_half_val),
torch.float32: ("DT_FLOAT", "float_val", _tensor_to_list),
torch.float: ("DT_FLOAT", "float_val", _tensor_to_list),
torch.float64: ("DT_DOUBLE", "double_val", _tensor_to_list),
torch.double: ("DT_DOUBLE", "double_val", _tensor_to_list),
torch.int8: ("DT_INT8", "int_val", _tensor_to_list),
torch.uint8: ("DT_UINT8", "int_val", _tensor_to_list),
torch.qint8: ("DT_UINT8", "int_val", _tensor_to_list),
torch.int16: ("DT_INT16", "int_val", _tensor_to_list),
torch.short: ("DT_INT16", "int_val", _tensor_to_list),
torch.int: ("DT_INT32", "int_val", _tensor_to_list),
torch.int32: ("DT_INT32", "int_val", _tensor_to_list),
torch.qint32: ("DT_INT32", "int_val", _tensor_to_list),
torch.int64: ("DT_INT64", "int64_val", _tensor_to_list),
torch.complex32: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
torch.chalf: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
torch.complex64: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
torch.cfloat: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
torch.bool: ("DT_BOOL", "bool_val", _tensor_to_list),
torch.complex128: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
torch.cdouble: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
torch.uint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
torch.quint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
torch.quint4x2: ("DT_UINT8", "uint32_val", _tensor_to_list),
}
def _calc_scale_factor(tensor):
converted = tensor.numpy() if not isinstance(tensor, np.ndarray) else tensor
return 1 if converted.dtype == np.uint8 else 255
def _draw_single_box(
image,
xmin,
ymin,
xmax,
ymax,
display_str,
color="black",
color_text="black",
thickness=2,
):
from PIL import ImageDraw, ImageFont
font = ImageFont.load_default()
draw = ImageDraw.Draw(image)
(left, right, top, bottom) = (xmin, xmax, ymin, ymax)
draw.line(
[(left, top), (left, bottom), (right, bottom), (right, top), (left, top)],
width=thickness,
fill=color,
)
if display_str:
text_bottom = bottom
# Reverse list and print from bottom to top.
_left, _top, _right, _bottom = font.getbbox(display_str)
text_width, text_height = _right - _left, _bottom - _top
margin = np.ceil(0.05 * text_height)
draw.rectangle(
[
(left, text_bottom - text_height - 2 * margin),
(left + text_width, text_bottom),
],
fill=color,
)
draw.text(
(left + margin, text_bottom - text_height - margin),
display_str,
fill=color_text,
font=font,
)
return image
def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None):
"""Output three `Summary` protocol buffers needed by hparams plugin.
`Experiment` keeps the metadata of an experiment, such as the name of the
hyperparameters and the name of the metrics.
`SessionStartInfo` keeps key-value pairs of the hyperparameters
`SessionEndInfo` describes status of the experiment e.g. STATUS_SUCCESS
Args:
hparam_dict: A dictionary that contains names of the hyperparameters
and their values.
metric_dict: A dictionary that contains names of the metrics
and their values.
hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
contains names of the hyperparameters and all discrete values they can hold
Returns:
The `Summary` protobufs for Experiment, SessionStartInfo and
SessionEndInfo
"""
import torch
from tensorboard.plugins.hparams.api_pb2 import (
DataType,
Experiment,
HParamInfo,
MetricInfo,
MetricName,
Status,
)
from tensorboard.plugins.hparams.metadata import (
EXPERIMENT_TAG,
PLUGIN_DATA_VERSION,
PLUGIN_NAME,
SESSION_END_INFO_TAG,
SESSION_START_INFO_TAG,
)
from tensorboard.plugins.hparams.plugin_data_pb2 import (
HParamsPluginData,
SessionEndInfo,
SessionStartInfo,
)
# TODO: expose other parameters in the future.
# hp = HParamInfo(name='lr',display_name='learning rate',
# type=DataType.DATA_TYPE_FLOAT64, domain_interval=Interval(min_value=10,
# max_value=100))
# mt = MetricInfo(name=MetricName(tag='accuracy'), display_name='accuracy',
# description='', dataset_type=DatasetType.DATASET_VALIDATION)
# exp = Experiment(name='123', description='456', time_created_secs=100.0,
# hparam_infos=[hp], metric_infos=[mt], user='tw')
if not isinstance(hparam_dict, dict):
logger.warning("parameter: hparam_dict should be a dictionary, nothing logged.")
raise TypeError(
"parameter: hparam_dict should be a dictionary, nothing logged."
)
if not isinstance(metric_dict, dict):
logger.warning("parameter: metric_dict should be a dictionary, nothing logged.")
raise TypeError(
"parameter: metric_dict should be a dictionary, nothing logged."
)
hparam_domain_discrete = hparam_domain_discrete or {}
if not isinstance(hparam_domain_discrete, dict):
raise TypeError(
"parameter: hparam_domain_discrete should be a dictionary, nothing logged."
)
for k, v in hparam_domain_discrete.items():
if (
k not in hparam_dict
or not isinstance(v, list)
or not all(isinstance(d, type(hparam_dict[k])) for d in v)
):
raise TypeError(
f"parameter: hparam_domain_discrete[{k}] should be a list of same type as hparam_dict[{k}]."
)
hps = []
ssi = SessionStartInfo()
for k, v in hparam_dict.items():
if v is None:
continue
if isinstance(v, (int, float)):
ssi.hparams[k].number_value = v
if k in hparam_domain_discrete:
domain_discrete: Optional[struct_pb2.ListValue] = struct_pb2.ListValue(
values=[
struct_pb2.Value(number_value=d)
for d in hparam_domain_discrete[k]
]
)
else:
domain_discrete = None
hps.append(
HParamInfo(
name=k,
type=DataType.Value("DATA_TYPE_FLOAT64"),
domain_discrete=domain_discrete,
)
)
continue
if isinstance(v, str):
ssi.hparams[k].string_value = v
if k in hparam_domain_discrete:
domain_discrete = struct_pb2.ListValue(
values=[
struct_pb2.Value(string_value=d)
for d in hparam_domain_discrete[k]
]
)
else:
domain_discrete = None
hps.append(
HParamInfo(
name=k,
type=DataType.Value("DATA_TYPE_STRING"),
domain_discrete=domain_discrete,
)
)
continue
if isinstance(v, bool):
ssi.hparams[k].bool_value = v
if k in hparam_domain_discrete:
domain_discrete = struct_pb2.ListValue(
values=[
struct_pb2.Value(bool_value=d)
for d in hparam_domain_discrete[k]
]
)
else:
domain_discrete = None
hps.append(
HParamInfo(
name=k,
type=DataType.Value("DATA_TYPE_BOOL"),
domain_discrete=domain_discrete,
)
)
continue
if isinstance(v, torch.Tensor):
v = make_np(v)[0]
ssi.hparams[k].number_value = v
hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
continue
raise ValueError(
"value should be one of int, float, str, bool, or torch.Tensor"
)
content = HParamsPluginData(session_start_info=ssi, version=PLUGIN_DATA_VERSION)
smd = SummaryMetadata(
plugin_data=SummaryMetadata.PluginData(
plugin_name=PLUGIN_NAME, content=content.SerializeToString()
)
)
ssi = Summary(value=[Summary.Value(tag=SESSION_START_INFO_TAG, metadata=smd)])
mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict.keys()]
exp = Experiment(hparam_infos=hps, metric_infos=mts)
content = HParamsPluginData(experiment=exp, version=PLUGIN_DATA_VERSION)
smd = SummaryMetadata(
plugin_data=SummaryMetadata.PluginData(
plugin_name=PLUGIN_NAME, content=content.SerializeToString()
)
)
exp = Summary(value=[Summary.Value(tag=EXPERIMENT_TAG, metadata=smd)])
sei = SessionEndInfo(status=Status.Value("STATUS_SUCCESS"))
content = HParamsPluginData(session_end_info=sei, version=PLUGIN_DATA_VERSION)
smd = SummaryMetadata(
plugin_data=SummaryMetadata.PluginData(
plugin_name=PLUGIN_NAME, content=content.SerializeToString()
)
)
sei = Summary(value=[Summary.Value(tag=SESSION_END_INFO_TAG, metadata=smd)])
return exp, ssi, sei
def scalar(name, tensor, collections=None, new_style=False, double_precision=False):
"""Output a `Summary` protocol buffer containing a single scalar value.
The generated Summary has a Tensor.proto containing the input Tensor.
Args:
name: A name for the generated node. Will also serve as the series name in
TensorBoard.
tensor: A real numeric Tensor containing a single value.
collections: Optional list of graph collections keys. The new summary op is
added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
new_style: Whether to use new style (tensor field) or old style (simple_value
field). New style could lead to faster data loading.
Returns:
A scalar `Tensor` of type `string`. Which contains a `Summary` protobuf.
Raises:
ValueError: If tensor has the wrong shape or type.
"""
tensor = make_np(tensor).squeeze()
assert (
tensor.ndim == 0
), f"Tensor should contain one element (0 dimensions). Was given size: {tensor.size} and {tensor.ndim} dimensions."
# python float is double precision in numpy
scalar = float(tensor)
if new_style:
tensor_proto = TensorProto(float_val=[scalar], dtype="DT_FLOAT")
if double_precision:
tensor_proto = TensorProto(double_val=[scalar], dtype="DT_DOUBLE")
plugin_data = SummaryMetadata.PluginData(plugin_name="scalars")
smd = SummaryMetadata(plugin_data=plugin_data)
return Summary(
value=[
Summary.Value(
tag=name,
tensor=tensor_proto,
metadata=smd,
)
]
)
else:
return Summary(value=[Summary.Value(tag=name, simple_value=scalar)])
def tensor_proto(tag, tensor):
"""Outputs a `Summary` protocol buffer containing the full tensor.
The generated Summary has a Tensor.proto containing the input Tensor.
Args:
name: A name for the generated node. Will also serve as the series name in
TensorBoard.
tensor: Tensor to be converted to protobuf
Returns:
A tensor protobuf in a `Summary` protobuf.
Raises:
ValueError: If tensor is too big to be converted to protobuf, or
tensor data type is not supported
"""
if tensor.numel() * tensor.itemsize >= (1 << 31):
raise ValueError(
"tensor is bigger than protocol buffer's hard limit of 2GB in size"
)
if tensor.dtype in _TENSOR_TYPE_MAP:
dtype, field_name, conversion_fn = _TENSOR_TYPE_MAP[tensor.dtype]
tensor_proto = TensorProto(
**{
"dtype": dtype,
"tensor_shape": TensorShapeProto(
dim=[TensorShapeProto.Dim(size=x) for x in tensor.shape]
),
field_name: conversion_fn(tensor),
},
)
else:
raise ValueError(f"{tag} has unsupported tensor dtype {tensor.dtype}")
plugin_data = SummaryMetadata.PluginData(plugin_name="tensor")
smd = SummaryMetadata(plugin_data=plugin_data)
return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor_proto)])
def histogram_raw(name, min, max, num, sum, sum_squares, bucket_limits, bucket_counts):
# pylint: disable=line-too-long
"""Output a `Summary` protocol buffer with a histogram.
The generated
[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
has one summary value containing a histogram for `values`.
Args:
name: A name for the generated node. Will also serve as a series name in
TensorBoard.
min: A float or int min value
max: A float or int max value
num: Int number of values
sum: Float or int sum of all values
sum_squares: Float or int sum of squares for all values
bucket_limits: A numeric `Tensor` with upper value per bucket
bucket_counts: A numeric `Tensor` with number of values per bucket
Returns:
A scalar `Tensor` of type `string`. The serialized `Summary` protocol
buffer.
"""
hist = HistogramProto(
min=min,
max=max,
num=num,
sum=sum,
sum_squares=sum_squares,
bucket_limit=bucket_limits,
bucket=bucket_counts,
)
return Summary(value=[Summary.Value(tag=name, histo=hist)])
def histogram(name, values, bins, max_bins=None):
# pylint: disable=line-too-long
"""Output a `Summary` protocol buffer with a histogram.
The generated
[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
has one summary value containing a histogram for `values`.
This op reports an `InvalidArgument` error if any value is not finite.
Args:
name: A name for the generated node. Will also serve as a series name in
TensorBoard.
values: A real numeric `Tensor`. Any shape. Values to use to
build the histogram.
Returns:
A scalar `Tensor` of type `string`. The serialized `Summary` protocol
buffer.
"""
values = make_np(values)
hist = make_histogram(values.astype(float), bins, max_bins)
return Summary(value=[Summary.Value(tag=name, histo=hist)])
def make_histogram(values, bins, max_bins=None):
"""Convert values into a histogram proto using logic from histogram.cc."""
if values.size == 0:
raise ValueError("The input has no element.")
values = values.reshape(-1)
counts, limits = np.histogram(values, bins=bins)
num_bins = len(counts)
if max_bins is not None and num_bins > max_bins:
subsampling = num_bins // max_bins
subsampling_remainder = num_bins % subsampling
if subsampling_remainder != 0:
counts = np.pad(
counts,
pad_width=[[0, subsampling - subsampling_remainder]],
mode="constant",
constant_values=0,
)
counts = counts.reshape(-1, subsampling).sum(axis=-1)
new_limits = np.empty((counts.size + 1,), limits.dtype)
new_limits[:-1] = limits[:-1:subsampling]
new_limits[-1] = limits[-1]
limits = new_limits
# Find the first and the last bin defining the support of the histogram:
cum_counts = np.cumsum(np.greater(counts, 0))
start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
start = int(start)
end = int(end) + 1
del cum_counts
# TensorBoard only includes the right bin limits. To still have the leftmost limit
# included, we include an empty bin left.
# If start == 0, we need to add an empty one left, otherwise we can just include the bin left to the
# first nonzero-count bin:
counts = (
counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
)
limits = limits[start : end + 1]
if counts.size == 0 or limits.size == 0:
raise ValueError("The histogram is empty, please file a bug report.")
sum_sq = values.dot(values)
return HistogramProto(
min=values.min(),
max=values.max(),
num=len(values),
sum=values.sum(),
sum_squares=sum_sq,
bucket_limit=limits.tolist(),
bucket=counts.tolist(),
)
def image(tag, tensor, rescale=1, dataformats="NCHW"):
"""Output a `Summary` protocol buffer with images.
The summary has up to `max_images` summary values containing images. The
images are built from `tensor` which must be 3-D with shape `[height, width,
channels]` and where `channels` can be:
* 1: `tensor` is interpreted as Grayscale.
* 3: `tensor` is interpreted as RGB.
* 4: `tensor` is interpreted as RGBA.
The `name` in the outputted Summary.Value protobufs is generated based on the
name, with a suffix depending on the max_outputs setting:
* If `max_outputs` is 1, the summary value tag is '*name*/image'.
* If `max_outputs` is greater than 1, the summary value tags are
generated sequentially as '*name*/image/0', '*name*/image/1', etc.
Args:
tag: A name for the generated node. Will also serve as a series name in
TensorBoard.
tensor: A 3-D `uint8` or `float32` `Tensor` of shape `[height, width,
channels]` where `channels` is 1, 3, or 4.
'tensor' can either have values in [0, 1] (float32) or [0, 255] (uint8).
The image() function will scale the image values to [0, 255] by applying
a scale factor of either 1 (uint8) or 255 (float32). Out-of-range values
will be clipped.
Returns:
A scalar `Tensor` of type `string`. The serialized `Summary` protocol
buffer.
"""
tensor = make_np(tensor)
tensor = convert_to_HWC(tensor, dataformats)
# Do not assume that user passes in values in [0, 255], use data type to detect
scale_factor = _calc_scale_factor(tensor)
tensor = tensor.astype(np.float32)
tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8)
image = make_image(tensor, rescale=rescale)
return Summary(value=[Summary.Value(tag=tag, image=image)])
def image_boxes(
tag, tensor_image, tensor_boxes, rescale=1, dataformats="CHW", labels=None
):
"""Output a `Summary` protocol buffer with images."""
tensor_image = make_np(tensor_image)
tensor_image = convert_to_HWC(tensor_image, dataformats)
tensor_boxes = make_np(tensor_boxes)
tensor_image = tensor_image.astype(np.float32) * _calc_scale_factor(tensor_image)
image = make_image(
tensor_image.clip(0, 255).astype(np.uint8),
rescale=rescale,
rois=tensor_boxes,
labels=labels,
)
return Summary(value=[Summary.Value(tag=tag, image=image)])
def draw_boxes(disp_image, boxes, labels=None):
# xyxy format
num_boxes = boxes.shape[0]
list_gt = range(num_boxes)
for i in list_gt:
disp_image = _draw_single_box(
disp_image,
boxes[i, 0],
boxes[i, 1],
boxes[i, 2],
boxes[i, 3],
display_str=None if labels is None else labels[i],
color="Red",
)
return disp_image
def make_image(tensor, rescale=1, rois=None, labels=None):
"""Convert a numpy representation of an image to Image protobuf."""
from PIL import Image
height, width, channel = tensor.shape
scaled_height = int(height * rescale)
scaled_width = int(width * rescale)
image = Image.fromarray(tensor)
if rois is not None:
image = draw_boxes(image, rois, labels=labels)
ANTIALIAS = Image.Resampling.LANCZOS
image = image.resize((scaled_width, scaled_height), ANTIALIAS)
import io
output = io.BytesIO()
image.save(output, format="PNG")
image_string = output.getvalue()
output.close()
return Summary.Image(
height=height,
width=width,
colorspace=channel,
encoded_image_string=image_string,
)
def video(tag, tensor, fps=4):
tensor = make_np(tensor)
tensor = _prepare_video(tensor)
# If user passes in uint8, then we don't need to rescale by 255
scale_factor = _calc_scale_factor(tensor)
tensor = tensor.astype(np.float32)
tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8)
video = make_video(tensor, fps)
return Summary(value=[Summary.Value(tag=tag, image=video)])
def make_video(tensor, fps):
try:
import moviepy # noqa: F401
except ImportError:
print("add_video needs package moviepy")
return
try:
from moviepy import editor as mpy
except ImportError:
print(
"moviepy is installed, but can't import moviepy.editor.",
"Some packages could be missing [imageio, requests]",
)
return
import tempfile
t, h, w, c = tensor.shape
# encode sequence of images into gif string
clip = mpy.ImageSequenceClip(list(tensor), fps=fps)
filename = tempfile.NamedTemporaryFile(suffix=".gif", delete=False).name
try: # newer version of moviepy use logger instead of progress_bar argument.
clip.write_gif(filename, verbose=False, logger=None)
except TypeError:
try: # older version of moviepy does not support progress_bar argument.
clip.write_gif(filename, verbose=False, progress_bar=False)
except TypeError:
clip.write_gif(filename, verbose=False)
with open(filename, "rb") as f:
tensor_string = f.read()
try:
os.remove(filename)
except OSError:
logger.warning("The temporary file used by moviepy cannot be deleted.")
return Summary.Image(
height=h, width=w, colorspace=c, encoded_image_string=tensor_string
)
def audio(tag, tensor, sample_rate=44100):
array = make_np(tensor)
array = array.squeeze()
if abs(array).max() > 1:
print("warning: audio amplitude out of range, auto clipped.")
array = array.clip(-1, 1)
assert array.ndim == 1, "input tensor should be 1 dimensional."
array = (array * np.iinfo(np.int16).max).astype("<i2")
import io
import wave
fio = io.BytesIO()
with wave.open(fio, "wb") as wave_write:
wave_write.setnchannels(1)
wave_write.setsampwidth(2)
wave_write.setframerate(sample_rate)
wave_write.writeframes(array.data)
audio_string = fio.getvalue()
fio.close()
audio = Summary.Audio(
sample_rate=sample_rate,
num_channels=1,
length_frames=array.shape[-1],
encoded_audio_string=audio_string,
content_type="audio/wav",
)
return Summary(value=[Summary.Value(tag=tag, audio=audio)])
def custom_scalars(layout):
categories = []
for k, v in layout.items():
charts = []
for chart_name, chart_meatadata in v.items():
tags = chart_meatadata[1]
if chart_meatadata[0] == "Margin":
assert len(tags) == 3
mgcc = layout_pb2.MarginChartContent(
series=[
layout_pb2.MarginChartContent.Series(
value=tags[0], lower=tags[1], upper=tags[2]
)
]
)
chart = layout_pb2.Chart(title=chart_name, margin=mgcc)
else:
mlcc = layout_pb2.MultilineChartContent(tag=tags)
chart = layout_pb2.Chart(title=chart_name, multiline=mlcc)
charts.append(chart)
categories.append(layout_pb2.Category(title=k, chart=charts))
layout = layout_pb2.Layout(category=categories)
plugin_data = SummaryMetadata.PluginData(plugin_name="custom_scalars")
smd = SummaryMetadata(plugin_data=plugin_data)
tensor = TensorProto(
dtype="DT_STRING",
string_val=[layout.SerializeToString()],
tensor_shape=TensorShapeProto(),
)
return Summary(
value=[
Summary.Value(tag="custom_scalars__config__", tensor=tensor, metadata=smd)
]
)
def text(tag, text):
plugin_data = SummaryMetadata.PluginData(
plugin_name="text", content=TextPluginData(version=0).SerializeToString()
)
smd = SummaryMetadata(plugin_data=plugin_data)
tensor = TensorProto(
dtype="DT_STRING",
string_val=[text.encode(encoding="utf_8")],
tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]),
)
return Summary(
value=[Summary.Value(tag=tag + "/text_summary", metadata=smd, tensor=tensor)]
)
def pr_curve_raw(
tag, tp, fp, tn, fn, precision, recall, num_thresholds=127, weights=None
):
if num_thresholds > 127: # weird, value > 127 breaks protobuf
num_thresholds = 127
data = np.stack((tp, fp, tn, fn, precision, recall))
pr_curve_plugin_data = PrCurvePluginData(
version=0, num_thresholds=num_thresholds
).SerializeToString()
plugin_data = SummaryMetadata.PluginData(
plugin_name="pr_curves", content=pr_curve_plugin_data
)
smd = SummaryMetadata(plugin_data=plugin_data)
tensor = TensorProto(
dtype="DT_FLOAT",
float_val=data.reshape(-1).tolist(),
tensor_shape=TensorShapeProto(
dim=[
TensorShapeProto.Dim(size=data.shape[0]),
TensorShapeProto.Dim(size=data.shape[1]),
]
),
)
return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None):
# weird, value > 127 breaks protobuf
num_thresholds = min(num_thresholds, 127)
data = compute_curve(
labels, predictions, num_thresholds=num_thresholds, weights=weights
)
pr_curve_plugin_data = PrCurvePluginData(
version=0, num_thresholds=num_thresholds
).SerializeToString()
plugin_data = SummaryMetadata.PluginData(
plugin_name="pr_curves", content=pr_curve_plugin_data
)
smd = SummaryMetadata(plugin_data=plugin_data)
tensor = TensorProto(
dtype="DT_FLOAT",
float_val=data.reshape(-1).tolist(),
tensor_shape=TensorShapeProto(
dim=[
TensorShapeProto.Dim(size=data.shape[0]),
TensorShapeProto.Dim(size=data.shape[1]),
]
),
)
return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py
def compute_curve(labels, predictions, num_thresholds=None, weights=None):
_MINIMUM_COUNT = 1e-7
if weights is None:
weights = 1.0
# Compute bins of true positives and false positives.
bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
float_labels = labels.astype(np.float64)
histogram_range = (0, num_thresholds - 1)
tp_buckets, _ = np.histogram(
bucket_indices,
bins=num_thresholds,
range=histogram_range,
weights=float_labels * weights,
)
fp_buckets, _ = np.histogram(
bucket_indices,
bins=num_thresholds,
range=histogram_range,
weights=(1.0 - float_labels) * weights,
)
# Obtain the reverse cumulative sum.
tp = np.cumsum(tp_buckets[::-1])[::-1]
fp = np.cumsum(fp_buckets[::-1])[::-1]
tn = fp[0] - fp
fn = tp[0] - tp
precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
return np.stack((tp, fp, tn, fn, precision, recall))
def _get_tensor_summary(
name, display_name, description, tensor, content_type, components, json_config
):
"""Create a tensor summary with summary metadata.
Args:
name: Uniquely identifiable name of the summary op. Could be replaced by
combination of name and type to make it unique even outside of this
summary.
display_name: Will be used as the display name in TensorBoard.
Defaults to `name`.
description: A longform readable description of the summary data. Markdown
is supported.
tensor: Tensor to display in summary.
content_type: Type of content inside the Tensor.
components: Bitmask representing present parts (vertices, colors, etc.) that
belong to the summary.
json_config: A string, JSON-serialized dictionary of ThreeJS classes
configuration.
Returns:
Tensor summary with metadata.
"""
import torch
from tensorboard.plugins.mesh import metadata
tensor = torch.as_tensor(tensor)
tensor_metadata = metadata.create_summary_metadata(
name,
display_name,
content_type,
components,
tensor.shape,
description,
json_config=json_config,
)
tensor = TensorProto(
dtype="DT_FLOAT",
float_val=tensor.reshape(-1).tolist(),
tensor_shape=TensorShapeProto(
dim=[
TensorShapeProto.Dim(size=tensor.shape[0]),
TensorShapeProto.Dim(size=tensor.shape[1]),
TensorShapeProto.Dim(size=tensor.shape[2]),
]
),
)
tensor_summary = Summary.Value(
tag=metadata.get_instance_name(name, content_type),
tensor=tensor,
metadata=tensor_metadata,
)
return tensor_summary
def _get_json_config(config_dict):
"""Parse and returns JSON string from python dictionary."""
json_config = "{}"
if config_dict is not None:
json_config = json.dumps(config_dict, sort_keys=True)
return json_config
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/mesh/summary.py
def mesh(
tag, vertices, colors, faces, config_dict, display_name=None, description=None
):
"""Output a merged `Summary` protocol buffer with a mesh/point cloud.
Args:
tag: A name for this summary operation.
vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D
coordinates of vertices.
faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of
vertices within each triangle.
colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each
vertex.
display_name: If set, will be used as the display name in TensorBoard.
Defaults to `name`.
description: A longform readable description of the summary data. Markdown
is supported.
config_dict: Dictionary with ThreeJS classes names and configuration.
Returns:
Merged summary for mesh/point cloud representation.
"""
from tensorboard.plugins.mesh import metadata
from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData
json_config = _get_json_config(config_dict)
summaries = []
tensors = [
(vertices, MeshPluginData.VERTEX),
(faces, MeshPluginData.FACE),
(colors, MeshPluginData.COLOR),
]
tensors = [tensor for tensor in tensors if tensor[0] is not None]
components = metadata.get_components_bitmask(
[content_type for (tensor, content_type) in tensors]
)
for tensor, content_type in tensors:
summaries.append(
_get_tensor_summary(
tag,
display_name,
description,
tensor,
content_type,
components,
json_config,
)
)
return Summary(value=summaries)

File diff suppressed because it is too large Load Diff