I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,146 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard Histograms plugin.
See `http_api.md` in this directory for specifications of the routes for
this plugin.
"""
from werkzeug import wrappers
from tensorboard import errors
from tensorboard import plugin_util
from tensorboard.backend import http_util
from tensorboard.data import provider
from tensorboard.plugins import base_plugin
from tensorboard.plugins.histogram import metadata
# Fallback downsampling limit, used by `HistogramsPlugin.__init__` when the
# context's sampling hints hold no entry for this plugin.
_DEFAULT_DOWNSAMPLING = 500 # histograms per time series
class HistogramsPlugin(base_plugin.TBPlugin):
    """TensorBoard plugin that serves histogram time series.

    Both summary flavours are supported: old-style ones (written by
    TensorFlow ops directly into the `histo` field of the proto) and
    new-style ones (produced by the `tensorboard.plugins.histogram.summary`
    module).
    """

    plugin_name = metadata.PLUGIN_NAME

    # One more than a round number: sampling keeps both the start and the end
    # step, so N+1 samples split the step sequence into N intervals.
    SAMPLE_SIZE = 51

    def __init__(self, context):
        """Set up the plugin from the context handed over by TensorBoard core.

        Args:
          context: A base_plugin.TBContext instance.
        """
        sampling_hints = context.sampling_hints or {}
        self._downsample_to = sampling_hints.get(
            self.plugin_name, _DEFAULT_DOWNSAMPLING
        )
        self._data_provider = context.data_provider
        self._version_checker = plugin_util._MetadataVersionChecker(
            data_kind="histogram",
            latest_known_version=0,
        )

    def get_plugin_apps(self):
        """Return the mapping from route path to request handler."""
        return {
            "/histograms": self.histograms_route,
            "/tags": self.tags_route,
        }

    def is_active(self):
        """Always report inactive.

        `list_plugins` as called by TB core suffices.
        """
        return False

    def index_impl(self, ctx, experiment):
        """Build the tag index for an experiment.

        Every run reported by the data provider gets an entry, even when none
        of its tags passes the metadata version check.

        Returns:
          A dict of the form
          `{runName: {tagName: {displayName: ..., description: ...}}}`.
        """
        run_to_tags = self._data_provider.list_tensors(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME,
        )
        index = {}
        for run, tag_to_metadatum in run_to_tags.items():
            tag_entries = {}
            index[run] = tag_entries
            for tag, metadatum in tag_to_metadatum.items():
                safe_description = plugin_util.markdown_to_safe_html(
                    metadatum.description
                )
                plugin_data = metadata.parse_plugin_metadata(
                    metadatum.plugin_content
                )
                # Tags rejected by the version checker are left out.
                if self._version_checker.ok(plugin_data.version, run, tag):
                    tag_entries[tag] = {
                        "displayName": metadatum.display_name,
                        "description": safe_description,
                    }
        return index

    def frontend_metadata(self):
        """Describe the frontend element that renders this plugin."""
        element_name = "tf-histogram-dashboard"
        return base_plugin.FrontendMetadata(element_name=element_name)

    def histograms_impl(self, ctx, tag, run, experiment, downsample_to=None):
        """Read the histogram series for one run and tag.

        At most `downsample_to` events are requested from the data provider.
        When it is `None`, the plugin's configured default is used instead.

        Returns:
          A pair `(body, mime_type)`, where `body` is a list of
          `(wall_time, step, histogram_as_nested_list)` tuples.

        Raises:
          tensorboard.errors.NotFoundError: If the provider returns no data
            for the requested run and tag.
        """
        if downsample_to is None:
            sample_count = self._downsample_to
        else:
            sample_count = downsample_to
        run_tag_filter = provider.RunTagFilter(runs=[run], tags=[tag])
        read_result = self._data_provider.read_tensors(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME,
            downsample=sample_count,
            run_tag_filter=run_tag_filter,
        )
        series = read_result.get(run, {}).get(tag, None)
        if series is None:
            raise errors.NotFoundError(
                "No histogram tag %r for run %r" % (tag, run)
            )
        events = [
            (datum.wall_time, datum.step, datum.numpy.tolist())
            for datum in series
        ]
        return (events, "application/json")

    @wrappers.Request.application
    def tags_route(self, request):
        """Respond with the JSON tag index for the requested experiment."""
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        payload = self.index_impl(ctx, experiment=experiment)
        return http_util.Respond(request, payload, "application/json")

    @wrappers.Request.application
    def histograms_route(self, request):
        """Given a tag and single run, return array of histogram values."""
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        requested_tag = request.args.get("tag")
        requested_run = request.args.get("run")
        body, mime_type = self.histograms_impl(
            ctx,
            requested_tag,
            requested_run,
            experiment=experiment,
            downsample_to=self.SAMPLE_SIZE,
        )
        return http_util.Respond(request, body, mime_type)

View File

@ -0,0 +1,64 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Information about histogram summaries."""
from tensorboard.compat.proto import summary_pb2
from tensorboard.plugins.histogram import plugin_data_pb2
# Name identifying this plugin; written to the `plugin_name` field of the
# plugin data built by `create_summary_metadata`.
PLUGIN_NAME = "histograms"
# The most recent value for the `version` field of the
# `HistogramPluginData` proto.
PROTO_VERSION = 0
def create_summary_metadata(display_name, description):
    """Build the `summary_pb2.SummaryMetadata` proto for histogram plugin data.

    Arguments:
      display_name: Value stored in the metadata's `display_name` field.
      description: Value stored in the metadata's `summary_description`
        field.

    Returns:
      A `summary_pb2.SummaryMetadata` protobuf object whose plugin data names
      this plugin and carries a serialized `HistogramPluginData` at
      `PROTO_VERSION`.
    """
    serialized_content = plugin_data_pb2.HistogramPluginData(
        version=PROTO_VERSION
    ).SerializeToString()
    plugin_data = summary_pb2.SummaryMetadata.PluginData(
        plugin_name=PLUGIN_NAME, content=serialized_content
    )
    return summary_pb2.SummaryMetadata(
        display_name=display_name,
        summary_description=description,
        plugin_data=plugin_data,
    )
def parse_plugin_metadata(content):
    """Decode histogram plugin metadata into a proto message.

    Arguments:
      content: The `content` field of a `SummaryMetadata` proto
        corresponding to the histogram plugin, as `bytes`.

    Returns:
      A `HistogramPluginData` protobuf object.

    Raises:
      TypeError: If `content` is not a `bytes` instance.
    """
    if not isinstance(content, bytes):
        raise TypeError("Content type must be bytes")
    if content == b"{}":
        # Old-style JSON format. Equivalent to an all-default proto.
        return plugin_data_pb2.HistogramPluginData()
    # Version 0 is the only version known at this time, so there are no
    # migrations to do: the parsed message is returned unchanged whatever
    # version it reports.
    return plugin_data_pb2.HistogramPluginData.FromString(content)

View File

@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: tensorboard/plugins/histogram/plugin_data.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
# Default symbol database; the generated message class is registered with it
# further down.
_sym_db = _symbol_database.Default()
# Add the serialized file descriptor of plugin_data.proto to the default
# descriptor pool. The payload names the `tensorboard` package and a
# `HistogramPluginData` message with a `version` field.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n/tensorboard/plugins/histogram/plugin_data.proto\x12\x0btensorboard\"&\n\x13HistogramPluginData\x12\x0f\n\x07version\x18\x01 \x01(\x05\x62\x06proto3')
# Descriptor of the `HistogramPluginData` message, looked up by name.
_HISTOGRAMPLUGINDATA = DESCRIPTOR.message_types_by_name['HistogramPluginData']
# Concrete message class built from the descriptor above.
HistogramPluginData = _reflection.GeneratedProtocolMessageType('HistogramPluginData', (_message.Message,), {
'DESCRIPTOR' : _HISTOGRAMPLUGINDATA,
'__module__' : 'tensorboard.plugins.histogram.plugin_data_pb2'
# @@protoc_insertion_point(class_scope:tensorboard.HistogramPluginData)
})
_sym_db.RegisterMessage(HistogramPluginData)
# Only when `_USE_C_DESCRIPTORS` is false: clear the file options and record
# the message's start/end positions (presumably byte offsets into the
# serialized file above -- confirm against protoc output).
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_HISTOGRAMPLUGINDATA._serialized_start=64
_HISTOGRAMPLUGINDATA._serialized_end=102
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,236 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Histogram summaries and TensorFlow operations to create them.
A histogram summary stores a list of buckets. Each bucket is encoded as
a triple `[left_edge, right_edge, count]`. Thus, a full histogram is
encoded as a tensor of dimension `[k, 3]`.
In general, the value of `k` (the number of buckets) will be a constant,
like 30. There are two edge cases: if there is no data, then there are
no buckets (the shape is `[0, 3]`); and if there is data but all points
have the same value, then there is one bucket of width 1 centered on
that value (the shape is `[1, 3]`).
NOTE: This module is in beta, and its API is subject to change, but the
data that it stores to disk will be supported forever.
"""
import numpy as np
from tensorboard.plugins.histogram import metadata
from tensorboard.plugins.histogram import summary_v2
# Export V3 versions: this module's public `histogram` and `histogram_pb`
# are the implementations defined in `summary_v2`.
histogram = summary_v2.histogram
histogram_pb = summary_v2.histogram_pb
def _buckets(data, bucket_count=None):
    """Build a TensorFlow op that groups `data` into histogram buckets.

    Arguments:
      data: A `Tensor` of any shape. Must be castable to `float64`.
      bucket_count: Optional positive `int` or scalar `int32` `Tensor`.
        Defaults to `summary_v2.DEFAULT_BUCKET_COUNT`.

    Returns:
      A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is
      a triple `[left_edge, right_edge, count]` for a single bucket.
      `k` is `0` when there is no data, `1` when every value is equal (the
      bucket then spans that value minus and plus 0.5), and `bucket_count`
      otherwise.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if bucket_count is None:
        bucket_count = summary_v2.DEFAULT_BUCKET_COUNT
    with tf.name_scope("buckets", values=[data, bucket_count]):
        # The assertions are created inside the name scope and gate every op
        # built below.
        preconditions = [
            tf.assert_scalar(bucket_count),
            tf.assert_type(bucket_count, tf.int32),
        ]
        with tf.control_dependencies(preconditions):
            flat = tf.reshape(data, shape=[-1])  # flatten
            flat = tf.cast(flat, tf.float64)
            has_no_data = tf.equal(tf.size(input=flat), 0)

            def empty_histogram():
                return tf.constant([], shape=(0, 3), dtype=tf.float64)

            def populated_histogram():
                lowest = tf.reduce_min(input_tensor=flat)
                highest = tf.reduce_max(input_tensor=flat)
                span = highest - lowest
                all_values_equal = tf.equal(span, 0)

                def spread_histogram():
                    width = span / tf.cast(bucket_count, tf.float64)
                    distances = flat - lowest
                    raw_indices = tf.cast(
                        tf.floor(distances / width), dtype=tf.int32
                    )
                    # The maximum value lands exactly on the upper edge; fold
                    # it into the last bucket.
                    indices = tf.minimum(raw_indices, bucket_count - 1)
                    # Use float64 instead of float32 to avoid accumulating
                    # floating point error later in tf.reduce_sum when summing
                    # more than 2^24 individual `1.0` values. See
                    # https://github.com/tensorflow/tensorflow/issues/51419
                    # for details.
                    membership = tf.one_hot(
                        indices, depth=bucket_count, dtype=tf.float64
                    )
                    counts = tf.cast(
                        tf.reduce_sum(input_tensor=membership, axis=0),
                        dtype=tf.float64,
                    )
                    edges = tf.linspace(lowest, highest, bucket_count + 1)
                    lefts = edges[:-1]
                    rights = edges[1:]
                    return tf.transpose(a=tf.stack([lefts, rights, counts]))

                def single_value_histogram():
                    value = lowest
                    starts = tf.stack([value - 0.5])
                    ends = tf.stack([value + 0.5])
                    counts = tf.stack(
                        [tf.cast(tf.size(input=flat), tf.float64)]
                    )
                    return tf.transpose(a=tf.stack([starts, ends, counts]))

                return tf.cond(
                    all_values_equal, single_value_histogram, spread_histogram
                )

            return tf.cond(has_no_data, empty_histogram, populated_histogram)
def op(
    name,
    data,
    bucket_count=None,
    display_name=None,
    description=None,
    collections=None,
):
    """Create a legacy histogram summary op.

    Arguments:
      name: A unique name for the generated summary node.
      data: A `Tensor` of any shape. Must be castable to `float64`.
      bucket_count: Optional positive `int`. The output will have this
        many buckets, except in two edge cases. If there is no data, then
        there are no buckets. If there is data but all points have the
        same value, then there is a single bucket of width 1 centered on
        that value.
      display_name: Optional name for this summary in TensorBoard, as a
        constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.
      collections: Optional list of graph collections keys. The new
        summary op is added to these collections. Defaults to
        `[GraphKeys.SUMMARIES]`.

    Returns:
      A TensorFlow summary op.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    summary_metadata = metadata.create_summary_metadata(
        display_name=name if display_name is None else display_name,
        description=description,
    )
    with tf.name_scope(name):
        bucket_tensor = _buckets(data, bucket_count=bucket_count)
        return tf.summary.tensor_summary(
            name="histogram_summary",
            tensor=bucket_tensor,
            collections=collections,
            summary_metadata=summary_metadata,
        )
def pb(name, data, bucket_count=None, display_name=None, description=None):
    """Create a legacy histogram summary protobuf.

    Arguments:
      name: A unique name for the generated summary, including any desired
        name scopes.
      data: A `np.array` or array-like form of any shape. Must have type
        castable to `float`.
      bucket_count: Optional positive `int`. The output will have this
        many buckets, except in two edge cases. If there is no data, then
        there are no buckets. If there is data but all points have the
        same value, then there is a single bucket of width 1 centered on
        that value.
      display_name: Optional name for this summary in TensorBoard, as a
        `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        `str`. Markdown is supported. Defaults to empty.

    Returns:
      A `tf.Summary` protobuf object with one value tagged
      `<name>/histogram_summary`.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if bucket_count is None:
        bucket_count = summary_v2.DEFAULT_BUCKET_COUNT
    values = np.array(data).flatten().astype(float)
    if values.size == 0:
        buckets = np.array([]).reshape((0, 3))
    else:
        lowest = np.min(values)
        highest = np.max(values)
        span = highest - lowest
        if span == 0:
            # Every value is identical: one unit-wide bucket around it.
            buckets = np.array(
                [[lowest - 0.5, lowest + 0.5, float(values.size)]]
            )
        else:
            width = span / bucket_count
            raw_indices = np.floor((values - lowest) / width).astype(int)
            # The maximum value lands on the upper edge; fold it into the
            # last bucket.
            indices = np.minimum(raw_indices, bucket_count - 1)
            # Compare a column of indices against a row of bucket ids
            # (broadcast) to get a (data.size, bucket_count) membership table.
            membership = np.array([indices]).transpose() == np.arange(
                0, bucket_count
            )
            expected_shape = (values.size, bucket_count)
            assert membership.shape == expected_shape, (
                membership.shape,
                expected_shape,
            )
            counts = np.sum(membership, axis=0)
            edges = np.linspace(lowest, highest, bucket_count + 1)
            buckets = np.array([edges[:-1], edges[1:], counts]).transpose()
    tensor = tf.make_tensor_proto(buckets, dtype=tf.float64)

    summary_metadata = metadata.create_summary_metadata(
        display_name=name if display_name is None else display_name,
        description=description,
    )
    # Round-trip through bytes to obtain TensorFlow's own SummaryMetadata
    # class from the TensorBoard one.
    tf_summary_metadata = tf.SummaryMetadata.FromString(
        summary_metadata.SerializeToString()
    )

    summary = tf.Summary()
    summary.value.add(
        tag="%s/histogram_summary" % name,
        metadata=tf_summary_metadata,
        tensor=tensor,
    )
    return summary

View File

@ -0,0 +1,293 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Histogram summaries and TensorFlow operations to create them, V2 versions.
A histogram summary stores a list of buckets. Each bucket is encoded as a triple
`[left_edge, right_edge, count]`. Thus, a full histogram is encoded as a tensor
of dimension `[k, 3]`, where the first `k - 1` buckets are closed-open and the
last bucket is closed-closed.
In general, the shape of the output histogram is always constant (`[k, 3]`).
In the case of empty data, the output will be an all-zero histogram of shape
`[k, 3]`, where all edges and counts are zeros. If there is data but all points
have the same value, then all buckets' left and right edges are the same and only
the last bucket has nonzero count.
"""
import numpy as np
from tensorboard.compat import tf2 as tf
from tensorboard.compat.proto import summary_pb2
from tensorboard.plugins.histogram import metadata
from tensorboard.util import lazy_tensor_creator
from tensorboard.util import tensor_util
# Number of histogram buckets used when the caller does not specify one.
DEFAULT_BUCKET_COUNT = 30
def histogram_pb(tag, data, buckets=None, description=None):
    """Create a histogram summary protobuf.

    Arguments:
      tag: String tag for the summary.
      data: A `np.array` or array-like form of any shape. Must have type
        castable to `float`.
      buckets: Optional non-negative `int`. The output shape will always be
        [buckets, 3]. A negative value is treated as zero, as the
        TensorFlow `_buckets` op in this module does. If there is no data,
        then an all-zero array of shape [buckets, 3] will be returned. If
        there is data but all points have the same value, then all buckets'
        left and right endpoints are the same and only the last bucket has
        nonzero count. Defaults to 30 if not specified.
      description: Optional long-form description for this summary, as a
        `str`. Markdown is supported. Defaults to empty.

    Returns:
      A `summary_pb2.Summary` protobuf object.
    """
    bucket_count = DEFAULT_BUCKET_COUNT if buckets is None else buckets
    # Treat a negative bucket count as zero. Without this, `np.zeros` rejects
    # the negative dimension and the single-value branch builds ragged rows.
    bucket_count = max(0, bucket_count)
    data = np.array(data).flatten().astype(float)
    if bucket_count == 0 or data.size == 0:
        histogram_buckets = np.zeros((bucket_count, 3))
    else:
        min_ = np.min(data)
        max_ = np.max(data)
        range_ = max_ - min_
        if range_ == 0:
            # All values are equal: every edge is that value and the whole
            # count goes into the last bucket.
            left_edges = right_edges = np.array([min_] * bucket_count)
            bucket_counts = np.array([0] * (bucket_count - 1) + [data.size])
            histogram_buckets = np.array(
                [left_edges, right_edges, bucket_counts]
            ).transpose()
        else:
            bucket_width = range_ / bucket_count
            offsets = data - min_
            bucket_indices = np.floor(offsets / bucket_width).astype(int)
            # The maximum value lands on the upper edge; fold it into the
            # last bucket so that bucket is closed on both sides.
            clamped_indices = np.minimum(bucket_indices, bucket_count - 1)
            one_hots = np.array([clamped_indices]).transpose() == np.arange(
                0, bucket_count
            )  # broadcast
            assert one_hots.shape == (data.size, bucket_count), (
                one_hots.shape,
                (data.size, bucket_count),
            )
            bucket_counts = np.sum(one_hots, axis=0)
            edges = np.linspace(min_, max_, bucket_count + 1)
            left_edges = edges[:-1]
            right_edges = edges[1:]
            histogram_buckets = np.array(
                [left_edges, right_edges, bucket_counts]
            ).transpose()
    tensor = tensor_util.make_tensor_proto(histogram_buckets, dtype=np.float64)

    summary_metadata = metadata.create_summary_metadata(
        display_name=None, description=description
    )
    summary = summary_pb2.Summary()
    summary.value.add(tag=tag, metadata=summary_metadata, tensor=tensor)
    return summary
# This is the TPU compatible V3 histogram implementation as of 2021-12-01.
def histogram(name, data, step=None, buckets=None, description=None):
    """Write a histogram summary.

    See also `tf.summary.scalar`, `tf.summary.SummaryWriter`.

    The histogram goes to the current default summary writer and can later be
    inspected in TensorBoard's 'Histograms' and 'Distributions' dashboards
    (data written with this API shows up in both). As with
    `tf.summary.scalar` points, every histogram carries a `step` and a
    `name`; all histograms sharing a `name` form one time series.

    The histogram is computed over every element of the given `Tensor`,
    regardless of its shape or rank.

    This example writes 2 histograms:

    ```python
    w = tf.summary.create_file_writer('test/logs')
    with w.as_default():
        tf.summary.histogram("activations", tf.random.uniform([100, 50]), step=0)
        tf.summary.histogram("initial_weights", tf.random.normal([1000]), step=0)
    ```

    A common use case is to examine the changing activation patterns (or lack
    thereof) at specific layers in a neural network, over time.

    ```python
    w = tf.summary.create_file_writer('test/logs')
    with w.as_default():
        for step in range(100):
            # Generate fake "activations".
            activations = [
                tf.random.normal([1000], mean=step, stddev=1),
                tf.random.normal([1000], mean=step, stddev=10),
                tf.random.normal([1000], mean=step, stddev=100),
            ]

            tf.summary.histogram("layer1/activate", activations[0], step=step)
            tf.summary.histogram("layer2/activate", activations[1], step=step)
            tf.summary.histogram("layer3/activate", activations[2], step=step)
    ```

    Arguments:
      name: A name for this summary. The summary tag used for TensorBoard will
        be this name prefixed by any active name scopes.
      data: A `Tensor` of any shape. The histogram is computed over its elements,
        which must be castable to `float64`.
      step: Explicit `int64`-castable monotonic step value for this summary. If
        omitted, this defaults to `tf.summary.experimental.get_step()`, which must
        not be None.
      buckets: Optional positive `int`. The output always has this many
        buckets. If there is no data, every edge and count is zero. If there
        is data but all points have the same value, then all buckets' left
        and right endpoints are the same and only the last bucket has nonzero
        count. Defaults to 30 if not specified.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.

    Returns:
      True on success, or false if no summary was emitted because no default
      summary writer was available.

    Raises:
      ValueError: if a default writer exists, but no step was provided and
        `tf.summary.experimental.get_step()` is None.
    """
    # Avoid building unused gradient graphs for conds below. This works around
    # an error building second-order gradient graphs when XlaDynamicUpdateSlice
    # is used, and will generally speed up graph building slightly.
    data = tf.stop_gradient(data)
    summary_metadata = metadata.create_summary_metadata(
        display_name=None, description=description
    )
    # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback
    summary_scope = getattr(tf.summary.experimental, "summary_scope", None)
    if not summary_scope:
        summary_scope = tf.summary.summary_scope

    # TODO(ytjing): add special case handling.
    scope_values = [data, buckets, step]
    with summary_scope(name, "histogram_summary", values=scope_values) as (
        tag,
        _,
    ):

        def compute_buckets():
            return _buckets(data, buckets)

        # Hand the bucketing logic to write() as a callable, wrapped in a
        # LazyTensorCreator for backwards compatibility, so the work is only
        # done when summaries are actually written.
        lazy_tensor = lazy_tensor_creator.LazyTensorCreator(compute_buckets)
        return tf.summary.write(
            tag=tag,
            tensor=lazy_tensor,
            step=step,
            metadata=summary_metadata,
        )
def _buckets(data, bucket_count=None):
    """Build a TensorFlow op that groups `data` into histogram buckets.

    Arguments:
      data: A `Tensor` of any shape. Must be castable to `float64`.
      bucket_count: Optional non-negative `int` or scalar `int32` `Tensor`,
        defaults to 30. A negative value is treated as zero.

    Returns:
      A `Tensor` of shape `[k, 3]` and type `float64`, where `k` is the
      (clamped) `bucket_count`. The `i`th row is a triple
      `[left_edge, right_edge, count]` for a single bucket. When the input
      data is empty, all `k` rows are zero.
    """
    if bucket_count is None:
        bucket_count = DEFAULT_BUCKET_COUNT
    with tf.name_scope("buckets"):
        tf.debugging.assert_scalar(bucket_count)
        tf.debugging.assert_type(bucket_count, tf.int32)
        # Treat a negative bucket count as zero.
        bucket_count = tf.math.maximum(0, bucket_count)
        flat = tf.reshape(data, shape=[-1])  # flatten
        flat = tf.cast(flat, tf.float64)
        data_size = tf.size(input=flat)
        nothing_to_bucket = tf.logical_or(
            tf.equal(data_size, 0), tf.less_equal(bucket_count, 0)
        )

        def zero_histogram():
            """Handle empty input data or a zero bucket count.

            1. If bucket_count is zero, the result is an empty tensor of
               shape (0, 3).
            2. If the input data is empty, the result is an all-zero tensor
               of shape (bucket_count, 3).
            """
            return tf.zeros((bucket_count, 3), dtype=tf.float64)

        def populated_histogram():
            lowest = tf.reduce_min(input_tensor=flat)
            highest = tf.reduce_max(input_tensor=flat)
            span = highest - lowest
            all_values_equal = tf.equal(span, 0)

            def spread_histogram():
                """Handle input data holding more than one distinct value."""
                width = span / tf.cast(bucket_count, tf.float64)
                distances = flat - lowest
                raw_indices = tf.cast(
                    tf.floor(distances / width), dtype=tf.int32
                )
                # The maximum value lands on the upper edge; fold it into the
                # last bucket.
                indices = tf.minimum(raw_indices, bucket_count - 1)
                # Use float64 instead of float32 to avoid accumulating floating
                # point error later in tf.reduce_sum when summing more than
                # 2^24 individual `1.0` values. See
                # https://github.com/tensorflow/tensorflow/issues/51419 for
                # details.
                membership = tf.one_hot(
                    indices, depth=bucket_count, dtype=tf.float64
                )
                counts = tf.cast(
                    tf.reduce_sum(input_tensor=membership, axis=0),
                    dtype=tf.float64,
                )
                edges = tf.linspace(lowest, highest, bucket_count + 1)
                # Ensure edges[-1] == highest, which TF's linspace
                # implementation does not do, leaving it subject to the whim
                # of floating point rounding error.
                edges = tf.concat([edges[:-1], [highest]], 0)
                lefts = edges[:-1]
                rights = edges[1:]
                return tf.transpose(a=tf.stack([lefts, rights, counts]))

            def single_value_histogram():
                """Handle input data holding a single unique value."""
                # Left and right edges are the same for single value input.
                edges = tf.fill([bucket_count], highest)
                # Bucket counts are 0 except the last bucket (if
                # bucket_count > 0), which is `data_size`. Ensure that the
                # resulting counts vector has length `bucket_count` always,
                # including the bucket_count==0 case.
                zeroes = tf.fill([bucket_count], 0)
                padded_counts = tf.concat([zeroes[:-1], [data_size]], 0)
                counts = tf.cast(
                    padded_counts[:bucket_count], dtype=tf.float64
                )
                return tf.transpose(a=tf.stack([edges, edges, counts]))

            return tf.cond(
                all_values_equal, single_value_histogram, spread_histogram
            )

        return tf.cond(nothing_to_bucket, zero_histogram, populated_histogram)