I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

@@ -0,0 +1,231 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard Audio plugin."""
import urllib.parse
from werkzeug import wrappers
from tensorboard import errors
from tensorboard import plugin_util
from tensorboard.backend import http_util
from tensorboard.data import provider
from tensorboard.plugins import base_plugin
from tensorboard.plugins.audio import metadata
_DEFAULT_MIME_TYPE = "application/octet-stream"
_DEFAULT_DOWNSAMPLING = 10 # audio clips per time series
_MIME_TYPES = {
metadata.Encoding.Value("WAV"): "audio/wav",
}
_ALLOWED_MIME_TYPES = frozenset(
list(_MIME_TYPES.values()) + [_DEFAULT_MIME_TYPE]
)
class AudioPlugin(base_plugin.TBPlugin):
"""Audio Plugin for TensorBoard."""
plugin_name = metadata.PLUGIN_NAME
def __init__(self, context):
"""Instantiates AudioPlugin via TensorBoard core.
Args:
context: A base_plugin.TBContext instance.
"""
self._data_provider = context.data_provider
self._downsample_to = (context.sampling_hints or {}).get(
self.plugin_name, _DEFAULT_DOWNSAMPLING
)
self._version_checker = plugin_util._MetadataVersionChecker(
data_kind="audio",
latest_known_version=0,
)
def get_plugin_apps(self):
return {
"/audio": self._serve_audio_metadata,
"/individualAudio": self._serve_individual_audio,
"/tags": self._serve_tags,
}
def is_active(self):
return False # `list_plugins` as called by TB core suffices
def frontend_metadata(self):
return base_plugin.FrontendMetadata(element_name="tf-audio-dashboard")
def _index_impl(self, ctx, experiment):
"""Return information about the tags in each run.
Result is a dictionary of the form
{
"runName1": {
"tagName1": {
"displayName": "The first tag",
"description": "<p>Long ago there was just one tag...</p>",
"samples": 3
},
"tagName2": ...,
...
},
"runName2": ...,
...
}
For each tag, `samples` is the greatest number of audio clips that
appear at any particular step. (It's not related to "samples of a
waveform.") For example, if for tag `minibatch_input` there are
five audio clips at step 0 and ten audio clips at step 1, then the
dictionary for `"minibatch_input"` will contain `"samples": 10`.
"""
mapping = self._data_provider.list_blob_sequences(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
)
result = {run: {} for run in mapping}
for run, tag_to_time_series in mapping.items():
for tag, time_series in tag_to_time_series.items():
md = metadata.parse_plugin_metadata(time_series.plugin_content)
if not self._version_checker.ok(md.version, run, tag):
continue
description = plugin_util.markdown_to_safe_html(
time_series.description
)
result[run][tag] = {
"displayName": time_series.display_name,
"description": description,
"samples": time_series.max_length,
}
return result
@wrappers.Request.application
def _serve_audio_metadata(self, request):
"""Given a tag and list of runs, serve a list of metadata for audio.
Note that the actual audio data are not sent; instead, we respond
with URLs to the audio. The frontend should treat these URLs as
opaque and should not try to parse information about them or
generate them itself, as the format may change.
Args:
request: A werkzeug.wrappers.Request object.
Returns:
A werkzeug.Response application.
"""
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
tag = request.args.get("tag")
run = request.args.get("run")
sample = int(request.args.get("sample", 0))
response = self._audio_response_for_run(
ctx, experiment, run, tag, sample
)
return http_util.Respond(request, response, "application/json")
def _audio_response_for_run(self, ctx, experiment, run, tag, sample):
"""Builds a JSON-serializable object with information about audio.
Args:
run: The name of the run.
tag: The name of the tag the audio entries all belong to.
sample: The zero-based index within each step's batch of audio clips
for which to retrieve information. For instance, setting `sample` to
`2` will fetch information about only the third audio clip of each
batch, and steps with fewer than three audio clips will be omitted
from the results.
Returns:
A list of dictionaries containing the wall time, step, label,
content type, and query string for each audio entry.
"""
all_audio = self._data_provider.read_blob_sequences(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
downsample=self._downsample_to,
run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
)
audio = all_audio.get(run, {}).get(tag, None)
if audio is None:
raise errors.NotFoundError(
"No audio data for run=%r, tag=%r" % (run, tag)
)
content_type = self._get_mime_type(ctx, experiment, run, tag)
response = []
for datum in audio:
if len(datum.values) <= sample:
continue
query = urllib.parse.urlencode(
{
"blob_key": datum.values[sample].blob_key,
"content_type": content_type,
}
)
response.append(
{
"wall_time": datum.wall_time,
"label": "",
"step": datum.step,
"contentType": content_type,
"query": query,
}
)
return response
def _get_mime_type(self, ctx, experiment, run, tag):
# TODO(@wchargin): Move this call from `/audio` (called many
# times) to `/tags` (called few times) to reduce data provider
# calls.
mapping = self._data_provider.list_blob_sequences(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
)
time_series = mapping.get(run, {}).get(tag, None)
if time_series is None:
raise errors.NotFoundError(
"No audio data for run=%r, tag=%r" % (run, tag)
)
parsed = metadata.parse_plugin_metadata(time_series.plugin_content)
return _MIME_TYPES.get(parsed.encoding, _DEFAULT_MIME_TYPE)
@wrappers.Request.application
def _serve_individual_audio(self, request):
"""Serve encoded audio data."""
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
mime_type = request.args["content_type"]
if mime_type not in _ALLOWED_MIME_TYPES:
raise errors.InvalidArgumentError(
"Illegal mime type %r" % mime_type
)
blob_key = request.args["blob_key"]
data = self._data_provider.read_blob(ctx, blob_key=blob_key)
return http_util.Respond(request, data, mime_type)
@wrappers.Request.application
def _serve_tags(self, request):
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
index = self._index_impl(ctx, experiment)
return http_util.Respond(request, index, "application/json")
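# A client-side sketch (illustrative, not part of the plugin proper) of how the
# three routes above compose: `/tags` lists runs and tags, `/audio` returns
# per-step metadata whose `query` field is opaque, and `/individualAudio`
# serves the encoded clip. It assumes a TensorBoard server already running at
# http://localhost:6006 with audio summaries logged, and the standard
# `/data/plugin/<plugin_name>/` route prefix.
if __name__ == "__main__":
    import json
    import urllib.request

    base = "http://localhost:6006/data/plugin/audio"

    def _get_json(url):
        with urllib.request.urlopen(url) as response:
            return json.loads(response.read().decode("utf-8"))

    tag_index = _get_json(base + "/tags")  # {run: {tag: {"samples": n, ...}}}
    run, tag = next((r, t) for r, tags in tag_index.items() for t in tags)
    params = urllib.parse.urlencode({"run": run, "tag": tag, "sample": 0})
    clips = _get_json(base + "/audio?" + params)
    # The per-clip `query` string is passed back verbatim, as the docstring
    # above instructs; clients should not construct it themselves.
    clip_bytes = urllib.request.urlopen(
        base + "/individualAudio?" + clips[0]["query"]
    ).read()
    print(len(clip_bytes), "bytes of", clips[0]["contentType"])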

@@ -0,0 +1,70 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Internal information about the audio plugin."""
from tensorboard.compat.proto import summary_pb2
from tensorboard.plugins.audio import plugin_data_pb2
PLUGIN_NAME = "audio"
# The most recent value for the `version` field of the `AudioPluginData`
# proto.
PROTO_VERSION = 0
# Expose the `Encoding` enum constants.
Encoding = plugin_data_pb2.AudioPluginData.Encoding
def create_summary_metadata(
display_name, description, encoding, *, converted_to_tensor=None
):
"""Create a `SummaryMetadata` proto for audio plugin data.
Returns:
A `SummaryMetadata` protobuf object.
"""
content = plugin_data_pb2.AudioPluginData(
version=PROTO_VERSION,
encoding=encoding,
converted_to_tensor=converted_to_tensor,
)
metadata = summary_pb2.SummaryMetadata(
display_name=display_name,
summary_description=description,
plugin_data=summary_pb2.SummaryMetadata.PluginData(
plugin_name=PLUGIN_NAME, content=content.SerializeToString()
),
)
return metadata
def parse_plugin_metadata(content):
"""Parse summary metadata to a Python object.
Arguments:
content: The `content` field of a `SummaryMetadata` proto
corresponding to the audio plugin.
Returns:
An `AudioPluginData` protobuf object.
"""
if not isinstance(content, bytes):
raise TypeError("Content type must be bytes")
result = plugin_data_pb2.AudioPluginData.FromString(content)
if result.version == 0:
return result
# No other versions known at this time, so no migrations to do.
return result
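# A round-trip sketch of the two helpers above (values are illustrative):
# build the `SummaryMetadata` proto for a WAV-encoded audio summary, then
# parse the plugin content back out of it.
if __name__ == "__main__":
    md = create_summary_metadata(
        display_name="waveform",
        description="A synthesized sine wave.",
        encoding=Encoding.Value("WAV"),
    )
    assert md.plugin_data.plugin_name == PLUGIN_NAME
    parsed = parse_plugin_metadata(md.plugin_data.content)
    assert parsed.version == PROTO_VERSION
    assert parsed.encoding == Encoding.Value("WAV")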

@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: tensorboard/plugins/audio/plugin_data.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n+tensorboard/plugins/audio/plugin_data.proto\x12\x0btensorboard\"\x9a\x01\n\x0f\x41udioPluginData\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x37\n\x08\x65ncoding\x18\x02 \x01(\x0e\x32%.tensorboard.AudioPluginData.Encoding\x12\x1b\n\x13\x63onverted_to_tensor\x18\x03 \x01(\x08\" \n\x08\x45ncoding\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x07\n\x03WAV\x10\x0b\x62\x06proto3')
_AUDIOPLUGINDATA = DESCRIPTOR.message_types_by_name['AudioPluginData']
_AUDIOPLUGINDATA_ENCODING = _AUDIOPLUGINDATA.enum_types_by_name['Encoding']
AudioPluginData = _reflection.GeneratedProtocolMessageType('AudioPluginData', (_message.Message,), {
'DESCRIPTOR' : _AUDIOPLUGINDATA,
'__module__' : 'tensorboard.plugins.audio.plugin_data_pb2'
# @@protoc_insertion_point(class_scope:tensorboard.AudioPluginData)
})
_sym_db.RegisterMessage(AudioPluginData)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_AUDIOPLUGINDATA._serialized_start=61
_AUDIOPLUGINDATA._serialized_end=215
_AUDIOPLUGINDATA_ENCODING._serialized_start=183
_AUDIOPLUGINDATA_ENCODING._serialized_end=215
# @@protoc_insertion_point(module_scope)

@@ -0,0 +1,232 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Audio summaries and TensorFlow operations to create them.
An audio summary stores a rank-2 string tensor of shape `[k, 2]`, where
`k` is the number of audio clips recorded in the summary. Each row of
the tensor is a pair `[encoded_audio, label]`, where `encoded_audio` is
a binary string whose encoding is specified in the summary metadata, and
`label` is a UTF-8 encoded Markdown string describing the audio clip.
NOTE: This module is in beta, and its API is subject to change, but the
data that it stores to disk will be supported forever.
"""
import functools
import warnings
import numpy as np
from tensorboard.util import encoder as encoder_util
from tensorboard.plugins.audio import metadata
from tensorboard.plugins.audio import summary_v2
# Export V2 versions.
audio = summary_v2.audio
_LABELS_WARNING = (
"Labels on audio summaries are deprecated and will be removed. "
"See <https://github.com/tensorflow/tensorboard/issues/3513>."
)
def op(
name,
audio,
sample_rate,
labels=None,
max_outputs=3,
encoding=None,
display_name=None,
description=None,
collections=None,
):
"""Create a legacy audio summary op for use in a TensorFlow graph.
Arguments:
name: A unique name for the generated summary node.
audio: A `Tensor` representing audio data with shape `[k, t, c]`,
where `k` is the number of audio clips, `t` is the number of
frames, and `c` is the number of channels. Elements should be
floating-point values in `[-1.0, 1.0]`. Any of the dimensions may
be statically unknown (i.e., `None`).
sample_rate: An `int` or rank-0 `int32` `Tensor` that represents the
sample rate, in Hz. Must be positive.
labels: Deprecated. Do not set.
max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this
many audio clips will be emitted at each step. When more than
`max_outputs` many clips are provided, the first `max_outputs`
many clips will be used and the rest silently discarded.
encoding: A constant `str` (not string tensor) indicating the
desired encoding. You can choose any format you like, as long as
it's "wav". Please see the "API compatibility note" below.
display_name: Optional name for this summary in TensorBoard, as a
constant `str`. Defaults to `name`.
description: Optional long-form description for this summary, as a
constant `str`. Markdown is supported. Defaults to empty.
collections: Optional list of graph collections keys. The new
summary op is added to these collections. Defaults to
`[tf.GraphKeys.SUMMARIES]`.
Returns:
A TensorFlow summary op.
API compatibility note: The default value of the `encoding`
argument is _not_ guaranteed to remain unchanged across TensorBoard
versions. In the future, we will by default encode as FLAC instead of
as WAV. If the specific format is important to you, please provide a
file format explicitly.
"""
if labels is not None:
warnings.warn(_LABELS_WARNING)
# TODO(nickfelt): remove on-demand imports once dep situation is fixed.
import tensorflow.compat.v1 as tf
if display_name is None:
display_name = name
if encoding is None:
encoding = "wav"
if encoding == "wav":
encoding = metadata.Encoding.Value("WAV")
encoder = functools.partial(
tf.audio.encode_wav, sample_rate=sample_rate
)
else:
raise ValueError("Unknown encoding: %r" % encoding)
with tf.name_scope(name), tf.control_dependencies(
[tf.assert_rank(audio, 3)]
):
limited_audio = audio[:max_outputs]
encoded_audio = tf.map_fn(
encoder, limited_audio, dtype=tf.string, name="encode_each_audio"
)
if labels is None:
limited_labels = tf.tile([""], tf.shape(input=limited_audio)[:1])
else:
limited_labels = labels[:max_outputs]
tensor = tf.transpose(a=tf.stack([encoded_audio, limited_labels]))
summary_metadata = metadata.create_summary_metadata(
display_name=display_name,
description=description,
encoding=encoding,
)
return tf.summary.tensor_summary(
name="audio_summary",
tensor=tensor,
collections=collections,
summary_metadata=summary_metadata,
)
def pb(
name,
audio,
sample_rate,
labels=None,
max_outputs=3,
encoding=None,
display_name=None,
description=None,
):
"""Create a legacy audio summary protobuf.
This behaves as if you were to create an `op` with the same arguments
(wrapped with constant tensors where appropriate) and then execute
that summary op in a TensorFlow session.
Arguments:
name: A unique name for the generated summary node.
audio: An `np.array` representing audio data with shape `[k, t, c]`,
where `k` is the number of audio clips, `t` is the number of
frames, and `c` is the number of channels. Elements should be
floating-point values in `[-1.0, 1.0]`.
sample_rate: An `int` that represents the sample rate, in Hz.
Must be positive.
labels: Deprecated. Do not set.
max_outputs: Optional `int`. At most this many audio clips will be
emitted. When more than `max_outputs` many clips are provided, the
first `max_outputs` many clips will be used and the rest silently
discarded.
encoding: A constant `str` indicating the desired encoding. You
can choose any format you like, as long as it's "wav". Please see
the "API compatibility note" below.
display_name: Optional name for this summary in TensorBoard, as a
`str`. Defaults to `name`.
description: Optional long-form description for this summary, as a
`str`. Markdown is supported. Defaults to empty.
Returns:
A `tf.Summary` protobuf object.
API compatibility note: The default value of the `encoding`
argument is _not_ guaranteed to remain unchanged across TensorBoard
versions. In the future, we will by default encode as FLAC instead of
as WAV. If the specific format is important to you, please provide a
file format explicitly.
"""
if labels is not None:
warnings.warn(_LABELS_WARNING)
# TODO(nickfelt): remove on-demand imports once dep situation is fixed.
import tensorflow.compat.v1 as tf
audio = np.array(audio)
if audio.ndim != 3:
raise ValueError("Shape %r must have rank 3" % (audio.shape,))
if encoding is None:
encoding = "wav"
if encoding == "wav":
encoding = metadata.Encoding.Value("WAV")
encoder = functools.partial(
encoder_util.encode_wav, samples_per_second=sample_rate
)
else:
raise ValueError("Unknown encoding: %r" % encoding)
limited_audio = audio[:max_outputs]
if labels is None:
limited_labels = [b""] * len(limited_audio)
else:
limited_labels = [
tf.compat.as_bytes(label) for label in labels[:max_outputs]
]
encoded_audio = [encoder(a) for a in limited_audio]
content = np.array([encoded_audio, limited_labels]).transpose()
tensor = tf.make_tensor_proto(content, dtype=tf.string)
if display_name is None:
display_name = name
summary_metadata = metadata.create_summary_metadata(
display_name=display_name, description=description, encoding=encoding
)
tf_summary_metadata = tf.SummaryMetadata.FromString(
summary_metadata.SerializeToString()
)
summary = tf.Summary()
summary.value.add(
tag="%s/audio_summary" % name,
metadata=tf_summary_metadata,
tensor=tensor,
)
return summary
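# An illustrative use of the legacy `pb()` API above: one second of a 440 Hz
# sine wave as a single mono clip. The sample rate and tone are arbitrary;
# TensorFlow must be installed for the on-demand import inside `pb()`.
if __name__ == "__main__":
    rate = 44100
    t = np.linspace(0.0, 1.0, rate, endpoint=False)
    wave = 0.5 * np.sin(2.0 * np.pi * 440.0 * t)
    clips = wave[np.newaxis, :, np.newaxis].astype(np.float32)  # [k=1, t, c=1]
    summary_proto = pb(
        name="sine_wave",
        audio=clips,
        sample_rate=rate,
        description="One second of A440.",
    )
    print(summary_proto.value[0].tag)  # e.g. "sine_wave/audio_summary"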

@@ -0,0 +1,125 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Audio summaries and TensorFlow operations to create them, V2 versions.
An audio summary stores a rank-2 string tensor of shape `[k, 2]`, where
`k` is the number of audio clips recorded in the summary. Each row of
the tensor is a pair `[encoded_audio, label]`, where `encoded_audio` is
a binary string whose encoding is specified in the summary metadata, and
`label` is a UTF-8 encoded Markdown string describing the audio clip.
"""
import functools
from tensorboard.compat import tf2 as tf
from tensorboard.plugins.audio import metadata
from tensorboard.util import lazy_tensor_creator
def audio(
name,
data,
sample_rate,
step=None,
max_outputs=3,
encoding=None,
description=None,
):
"""Write an audio summary.
Arguments:
name: A name for this summary. The summary tag used for TensorBoard will
be this name prefixed by any active name scopes.
data: A `Tensor` representing audio data with shape `[k, t, c]`,
where `k` is the number of audio clips, `t` is the number of
frames, and `c` is the number of channels. Elements should be
floating-point values in `[-1.0, 1.0]`. Any of the dimensions may
be statically unknown (i.e., `None`).
sample_rate: An `int` or rank-0 `int32` `Tensor` that represents the
sample rate, in Hz. Must be positive.
step: Explicit `int64`-castable monotonic step value for this summary. If
omitted, this defaults to `tf.summary.experimental.get_step()`, which must
not be None.
max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this
many audio clips will be emitted at each step. When more than
`max_outputs` many clips are provided, the first `max_outputs`
many clips will be used and the rest silently discarded.
encoding: Optional constant `str` for the desired encoding. Only "wav"
is currently supported, but this is not guaranteed to remain the
default, so if you want "wav" in particular, set this explicitly.
description: Optional long-form description for this summary, as a
constant `str`. Markdown is supported. Defaults to empty.
Returns:
True on success, or false if no summary was emitted because no default
summary writer was available.
Raises:
ValueError: if a default writer exists, but no step was provided and
`tf.summary.experimental.get_step()` is None.
"""
audio_ops = getattr(tf, "audio", None)
if audio_ops is None:
# Fallback for older versions of TF without tf.audio.
from tensorflow.python.ops import gen_audio_ops as audio_ops
if encoding is None:
encoding = "wav"
if encoding != "wav":
raise ValueError("Unknown encoding: %r" % encoding)
summary_metadata = metadata.create_summary_metadata(
display_name=None,
description=description,
encoding=metadata.Encoding.Value("WAV"),
)
inputs = [data, sample_rate, max_outputs, step]
# TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback
summary_scope = (
getattr(tf.summary.experimental, "summary_scope", None)
or tf.summary.summary_scope
)
with summary_scope(name, "audio_summary", values=inputs) as (tag, _):
# Defer audio encoding preprocessing by passing it as a callable to write(),
# wrapped in a LazyTensorCreator for backwards compatibility, so that we
# only do this work when summaries are actually written.
@lazy_tensor_creator.LazyTensorCreator
def lazy_tensor():
tf.debugging.assert_rank(data, 3)
tf.debugging.assert_non_negative(max_outputs)
limited_audio = data[:max_outputs]
encode_fn = functools.partial(
audio_ops.encode_wav, sample_rate=sample_rate
)
encoded_audio = tf.map_fn(
encode_fn,
limited_audio,
dtype=tf.string,
name="encode_each_audio",
)
# Workaround for map_fn returning float dtype for an empty elems input.
encoded_audio = tf.cond(
tf.shape(input=encoded_audio)[0] > 0,
lambda: encoded_audio,
lambda: tf.constant([], tf.string),
)
limited_labels = tf.tile([""], tf.shape(input=limited_audio)[:1])
return tf.transpose(a=tf.stack([encoded_audio, limited_labels]))
# To ensure that audio encoding logic is only executed when summaries
# are written, we pass a callable to the `tensor` parameter.
return tf.summary.write(
tag=tag, tensor=lazy_tensor, step=step, metadata=summary_metadata
)
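# A minimal usage sketch for the V2 `audio()` op above, assuming TensorFlow 2.x
# with eager execution; the logdir, tag, and data here are placeholders.
if __name__ == "__main__":
    import numpy as np
    import tensorflow as real_tf

    writer = real_tf.summary.create_file_writer("/tmp/audio_summary_demo")
    clips = np.random.uniform(-1.0, 1.0, size=(2, 16000, 1)).astype(np.float32)
    with writer.as_default():
        # Returns True if the summary was written by the default writer.
        wrote = audio("random_noise", data=clips, sample_rate=16000, step=0)
    writer.flush()
    print("written:", bool(wrote))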

@@ -0,0 +1,368 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorBoard Plugin abstract base class.
Every plugin in TensorBoard must extend and implement the abstract
methods of this base class.
"""
from abc import ABCMeta
from abc import abstractmethod
class TBPlugin(metaclass=ABCMeta):
"""TensorBoard plugin interface.
Every plugin must extend from this class.
Subclasses should have a trivial constructor that takes a TBContext
argument. Any operation that might throw an exception should either be
done lazily or made safe with a TBLoader subclass, so the plugin won't
negatively impact the rest of TensorBoard.
Fields:
plugin_name: The plugin_name will also be a prefix in the http
handlers, e.g. `data/plugin/$PLUGIN_NAME/$HANDLER`. The plugin
name must be unique for each registered plugin, or a ValueError
will be thrown when the application is constructed. The plugin
name must only contain characters among [A-Za-z0-9_.-], and must
be nonempty, or a ValueError will similarly be thrown.
"""
plugin_name = None
def __init__(self, context):
"""Initializes this plugin.
The default implementation does nothing. Subclasses are encouraged
to override this and save any necessary fields from the `context`.
Args:
context: A `base_plugin.TBContext` object.
"""
pass
@abstractmethod
def get_plugin_apps(self):
"""Returns a set of WSGI applications that the plugin implements.
Each application gets registered with the tensorboard app and is served
under a prefix path that includes the name of the plugin.
Returns:
A dict mapping route paths to WSGI applications. Each route path
should include a leading slash.
"""
raise NotImplementedError()
@abstractmethod
def is_active(self):
"""Determines whether this plugin is active.
A plugin may not be active for instance if it lacks relevant data. If a
plugin is inactive, the frontend may avoid issuing requests to its routes.
Returns:
A boolean value. Whether this plugin is active.
"""
raise NotImplementedError()
def frontend_metadata(self):
"""Defines how the plugin will be displayed on the frontend.
The base implementation returns a default value. Subclasses
should override this and specify either an `es_module_path` or
(for legacy plugins) an `element_name`, and are encouraged to
set any other relevant attributes.
"""
return FrontendMetadata()
def data_plugin_names(self):
"""Experimental. Lists plugins whose summary data this plugin reads.
Returns:
A collection of strings representing plugin names (as read
from `SummaryMetadata.plugin_data.plugin_name`) from which
this plugin may read data. Defaults to `(self.plugin_name,)`.
"""
return (self.plugin_name,)
class FrontendMetadata:
"""Metadata required to render a plugin on the frontend.
Each argument to the constructor is publicly accessible under a
field of the same name. See constructor docs for further details.
"""
def __init__(
self,
*,
disable_reload=None,
element_name=None,
es_module_path=None,
remove_dom=None,
tab_name=None,
is_ng_component=None,
):
"""Creates a `FrontendMetadata` value.
The argument list is sorted and may be extended in the future;
therefore, callers must pass only named arguments to this
constructor.
Args:
disable_reload: Whether to disable the reload button and
auto-reload timer. A `bool`; defaults to `False`.
element_name: For legacy plugins, name of the custom element
defining the plugin frontend: e.g., `"tf-scalar-dashboard"`.
A `str` or `None` (for iframed plugins). Mutually exclusive
with `es_module_path`.
es_module_path: ES module to use as an entry point to this plugin.
A `str` that is a key in the result of `get_plugin_apps()`, or
`None` for legacy plugins bundled with TensorBoard as part of
`webfiles.zip`. Mutually exclusive with legacy `element_name`.
remove_dom: Whether to remove the plugin DOM when switching to a
different plugin, to trigger the Polymer 'detached' event.
A `bool`; defaults to `False`.
tab_name: Name to show in the menu item for this dashboard within
the navigation bar. May differ from the plugin name: for
instance, the tab name should not use underscores to separate
words. Should be a `str` or `None` (the default; indicates to
use the plugin name as the tab name).
is_ng_component: Set to `True` only for built-in Angular plugins.
In this case, the `plugin_name` property of the Plugin, which is
mapped to the `id` property in JavaScript's `UiPluginMetadata` type,
is used to select the Angular component. A `True` value is mutually
exclusive with `element_name` and `es_module_path`.
"""
self._disable_reload = (
False if disable_reload is None else disable_reload
)
self._element_name = element_name
self._es_module_path = es_module_path
self._remove_dom = False if remove_dom is None else remove_dom
self._tab_name = tab_name
self._is_ng_component = (
False if is_ng_component is None else is_ng_component
)
@property
def disable_reload(self):
return self._disable_reload
@property
def element_name(self):
return self._element_name
@property
def is_ng_component(self):
return self._is_ng_component
@property
def es_module_path(self):
return self._es_module_path
@property
def remove_dom(self):
return self._remove_dom
@property
def tab_name(self):
return self._tab_name
def __eq__(self, other):
if not isinstance(other, FrontendMetadata):
return False
if self._disable_reload != other._disable_reload:
return False
if self._is_ng_component != other._is_ng_component:
    return False
if self._element_name != other._element_name:
return False
if self._es_module_path != other._es_module_path:
return False
if self._remove_dom != other._remove_dom:
return False
if self._tab_name != other._tab_name:
return False
return True
def __hash__(self):
return hash(
(
self._disable_reload,
self._element_name,
self._es_module_path,
self._remove_dom,
self._tab_name,
self._is_ng_component,
)
)
def __repr__(self):
return "FrontendMetadata(%s)" % ", ".join(
(
"disable_reload=%r" % self._disable_reload,
"element_name=%r" % self._element_name,
"es_module_path=%r" % self._es_module_path,
"remove_dom=%r" % self._remove_dom,
"tab_name=%r" % self._tab_name,
"is_ng_component=%r" % self._is_ng_component,
)
)
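# Two illustrative constructions of the value above (the names are made up):
# a legacy Polymer-element dashboard and a dynamically loaded plugin that
# ships its own ES module.
_EXAMPLE_LEGACY_METADATA = FrontendMetadata(
    element_name="tf-example-dashboard", tab_name="Example"
)
_EXAMPLE_DYNAMIC_METADATA = FrontendMetadata(
    es_module_path="/index.js", remove_dom=True
)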
class TBContext:
"""Magic container of information passed from TensorBoard core to plugins.
A TBContext instance is passed to the constructor of a TBPlugin class. Plugins
are strongly encouraged to assume that any of these fields can be None. In
cases when a field is considered mandatory by a plugin, it can either crash
with ValueError, or silently choose to disable itself by returning False from
its is_active method.
All fields in this object are thread safe.
"""
def __init__(
self,
*,
assets_zip_provider=None,
data_provider=None,
flags=None,
logdir=None,
multiplexer=None,
plugin_name_to_instance=None,
sampling_hints=None,
window_title=None,
):
"""Instantiates magic container.
The argument list is sorted and may be extended in the future; therefore,
callers must pass only named arguments to this constructor.
Args:
assets_zip_provider: A function that returns a newly opened file handle
for a zip file containing all static assets. The file names inside the
zip file are considered absolute paths on the web server. The file
handle this function returns must be closed. It is assumed that you
will pass this file handle to zipfile.ZipFile. This zip file should
also have been created by the tensorboard_zip_file build rule.
data_provider: Instance of `tensorboard.data.provider.DataProvider`. May
be `None` if `flags.generic_data` is set to `"false"`.
flags: An object mapping the runtime flags provided to TensorBoard to
their values.
logdir: The string logging directory TensorBoard was started with.
multiplexer: An EventMultiplexer with underlying TB data. Plugins should
copy this data over to the database when the db fields are set.
plugin_name_to_instance: A mapping between plugin name to instance.
Plugins may use this property to access other plugins. The context
object is passed to plugins during their construction, so a given
plugin may be absent from this mapping until it is registered. Plugin
logic should handle cases in which a plugin is absent from this
mapping, lest a KeyError be raised.
sampling_hints: Map from plugin name to `int` or `NoneType`, where
the value represents the user-specified downsampling limit as
given to the `--samples_per_plugin` flag, or `None` if none was
explicitly given for this plugin.
window_title: A string specifying the window title.
"""
self.assets_zip_provider = assets_zip_provider
self.data_provider = data_provider
self.flags = flags
self.logdir = logdir
self.multiplexer = multiplexer
self.plugin_name_to_instance = plugin_name_to_instance
self.sampling_hints = sampling_hints
self.window_title = window_title
class TBLoader:
"""TBPlugin factory base class.
Plugins can override this class to customize how a plugin is loaded at
startup. This might entail adding command-line arguments, checking if
optional dependencies are installed, and potentially also specializing
the plugin class at runtime.
When plugins use optional dependencies, the loader needs to be
specified in its own module. That way it's guaranteed to be
importable, even if the `TBPlugin` itself can't be imported.
Subclasses must have trivial constructors.
"""
def define_flags(self, parser):
"""Adds plugin-specific CLI flags to parser.
The default behavior is to do nothing.
When overriding this method, it's recommended that plugins call the
`parser.add_argument_group(plugin_name)` method for readability. No
flags should be specified that would cause `parse_args([])` to fail.
Args:
parser: The argument parsing object, which may be mutated.
"""
pass
def fix_flags(self, flags):
"""Allows flag values to be corrected or validated after parsing.
Args:
flags: The parsed argparse.Namespace object.
Raises:
base_plugin.FlagsError: If a flag is invalid or a required
flag is not passed.
"""
pass
def load(self, context):
"""Loads a TBPlugin instance during the setup phase.
Args:
context: The TBContext instance.
Returns:
A plugin instance or None if it could not be loaded. Loaders that return
None are skipped.
:type context: TBContext
:rtype: TBPlugin | None
"""
return None
class BasicLoader(TBLoader):
"""Simple TBLoader that's sufficient for most plugins."""
def __init__(self, plugin_class):
"""Creates simple plugin instance maker.
:param plugin_class: :class:`TBPlugin`
"""
self.plugin_class = plugin_class
def load(self, context):
return self.plugin_class(context)
class FlagsError(ValueError):
"""Raised when a command line flag is not specified or is invalid."""
pass
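# A bare-bones sketch of the interfaces above (all names hypothetical): a
# plugin that serves a single JSON route, plus a loader that adds one CLI
# flag. A real plugin would normally also override `frontend_metadata()` and
# build responses with werkzeug rather than raw WSGI.
class _GreeterPlugin(TBPlugin):
    plugin_name = "greeter"

    def __init__(self, context):
        self._greeting = getattr(context.flags, "greeting", None) or "hello"

    def get_plugin_apps(self):
        return {"/greeting": self._serve_greeting}

    def is_active(self):
        return True

    def _serve_greeting(self, environ, start_response):
        body = ('{"greeting": "%s"}' % self._greeting).encode("utf-8")
        headers = [
            ("Content-Type", "application/json"),
            ("Content-Length", str(len(body))),
        ]
        start_response("200 OK", headers)
        return [body]


class _GreeterLoader(TBLoader):
    def define_flags(self, parser):
        parser.add_argument("--greeting", type=str, default="hello")

    def load(self, context):
        return _GreeterPlugin(context)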

@@ -0,0 +1,755 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorBoard core plugin package."""
import argparse
import functools
import gzip
import io
import mimetypes
import posixpath
import zipfile
from werkzeug import utils
from werkzeug import wrappers
from tensorboard import plugin_util
from tensorboard.backend import http_util
from tensorboard.plugins import base_plugin
from tensorboard.util import grpc_util
from tensorboard.util import tb_logging
from tensorboard import version
logger = tb_logging.get_logger()
# If no port is specified, try to bind to this port. See help for --port
# for more details.
DEFAULT_PORT = 6006
# Valid javascript mimetypes that we have seen configured, in practice.
# Historically (~2020-2022) we saw "application/javascript" exclusively but with
# RFC 9239 (https://www.rfc-editor.org/rfc/rfc9239) we saw some systems quickly
# transition to 'text/javascript'.
JS_MIMETYPES = ["text/javascript", "application/javascript"]
JS_CACHE_EXPIRATION_IN_SECS = 86400
class CorePlugin(base_plugin.TBPlugin):
"""Core plugin for TensorBoard.
This plugin serves runs, configuration data, and static assets. This
plugin should always be present in a TensorBoard WSGI application.
"""
plugin_name = "core"
def __init__(self, context, include_debug_info=None):
"""Instantiates CorePlugin.
Args:
context: A base_plugin.TBContext instance.
include_debug_info: If true, `/data/environment` will include some
basic information like the TensorBoard server version. Disabled by
default to prevent surprising information leaks in custom builds of
TensorBoard.
"""
self._flags = context.flags
logdir_spec = context.flags.logdir_spec if context.flags else ""
self._logdir = context.logdir or logdir_spec
self._window_title = context.window_title
self._path_prefix = context.flags.path_prefix if context.flags else None
self._assets_zip_provider = context.assets_zip_provider
self._data_provider = context.data_provider
self._include_debug_info = bool(include_debug_info)
def is_active(self):
return True
def get_plugin_apps(self):
apps = {
"/___rPc_sWiTcH___": self._send_404_without_logging,
"/audio": self._redirect_to_index,
"/data/environment": self._serve_environment,
"/data/logdir": self._serve_logdir,
"/data/runs": self._serve_runs,
"/data/experiments": self._serve_experiments,
"/data/experiment_runs": self._serve_experiment_runs,
"/data/notifications": self._serve_notifications,
"/data/window_properties": self._serve_window_properties,
"/events": self._redirect_to_index,
"/favicon.ico": self._send_404_without_logging,
"/graphs": self._redirect_to_index,
"/histograms": self._redirect_to_index,
"/images": self._redirect_to_index,
}
apps.update(self.get_resource_apps())
return apps
def get_resource_apps(self):
apps = {}
if not self._assets_zip_provider:
return apps
with self._assets_zip_provider() as fp:
with zipfile.ZipFile(fp) as zip_:
for path in zip_.namelist():
content = zip_.read(path)
# Opt out of gzipping index.html
if path == "index.html":
apps["/" + path] = functools.partial(
self._serve_index, content
)
continue
gzipped_asset_bytes = _gzip(content)
wsgi_app = functools.partial(
self._serve_asset, path, gzipped_asset_bytes
)
apps["/" + path] = wsgi_app
apps["/"] = apps["/index.html"]
return apps
@wrappers.Request.application
def _send_404_without_logging(self, request):
return http_util.Respond(request, "Not found", "text/plain", code=404)
@wrappers.Request.application
def _redirect_to_index(self, unused_request):
return utils.redirect("/")
@wrappers.Request.application
def _serve_asset(self, path, gzipped_asset_bytes, request):
"""Serves a pre-gzipped static asset from the zip file."""
mimetype = mimetypes.guess_type(path)[0] or "application/octet-stream"
# Cache JS resources; leave all other resources uncached.
expires = (
JS_CACHE_EXPIRATION_IN_SECS
if request.args.get("_file_hash") and mimetype in JS_MIMETYPES
else 0
)
return http_util.Respond(
request,
gzipped_asset_bytes,
mimetype,
content_encoding="gzip",
expires=expires,
)
@wrappers.Request.application
def _serve_index(self, index_asset_bytes, request):
"""Serves index.html content.
Note that we opt out of gzipping index.html so that a preamble can be
written before the resource content. This inflates the resource size from
roughly 25 kiB to over 100 kiB, but we need the ability to flush the
preamble together with the HTML content.
"""
relpath = (
posixpath.relpath(self._path_prefix, request.script_root)
if self._path_prefix
else "."
)
meta_header = (
'<!doctype html><meta name="tb-relative-root" content="%s/">'
% relpath
)
content = meta_header.encode("utf-8") + index_asset_bytes
# By passing content_encoding="identity", disallow gzipping. This bloats
# the content from ~25 kiB to ~120 kiB but reduces CPU usage and avoids
# ~3 ms worth of gzipping.
return http_util.Respond(
request, content, "text/html", content_encoding="identity"
)
@wrappers.Request.application
def _serve_environment(self, request):
"""Serve a JSON object describing the TensorBoard parameters."""
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
md = self._data_provider.experiment_metadata(
ctx, experiment_id=experiment
)
environment = {
"version": version.VERSION,
"data_location": md.data_location,
"window_title": self._window_title,
"experiment_name": md.experiment_name,
"experiment_description": md.experiment_description,
"creation_time": md.creation_time,
}
if self._include_debug_info:
environment["debug"] = {
"data_provider": str(self._data_provider),
"flags": self._render_flags(),
}
return http_util.Respond(
request,
environment,
"application/json",
)
def _render_flags(self):
"""Return a JSON-and-human-friendly version of `self._flags`.
Like `json.loads(json.dumps(self._flags, default=str))` but
without the wasteful serialization overhead.
"""
if self._flags is None:
return None
def go(x):
if isinstance(x, (type(None), str, int, float)):
return x
if isinstance(x, (list, tuple)):
return [go(v) for v in x]
if isinstance(x, dict):
return {str(k): go(v) for (k, v) in x.items()}
return str(x)
return go(vars(self._flags))
@wrappers.Request.application
def _serve_logdir(self, request):
"""Respond with a JSON object containing this TensorBoard's logdir."""
# TODO(chihuahua): Remove this method once the frontend instead uses the
# /data/environment route (and no deps throughout Google use the
# /data/logdir route).
return http_util.Respond(
request, {"logdir": self._logdir}, "application/json"
)
@wrappers.Request.application
def _serve_window_properties(self, request):
"""Serve a JSON object containing this TensorBoard's window
properties."""
# TODO(chihuahua): Remove this method once the frontend instead uses the
# /data/environment route.
return http_util.Respond(
request, {"window_title": self._window_title}, "application/json"
)
@wrappers.Request.application
def _serve_runs(self, request):
"""Serve a JSON array of run names, ordered by run started time.
Sort order is by started time (aka first event time) with empty
times sorted last, and then ties are broken by sorting on the
run name.
"""
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
runs = sorted(
self._data_provider.list_runs(ctx, experiment_id=experiment),
key=lambda run: (
run.start_time if run.start_time is not None else float("inf"),
run.run_name,
),
)
run_names = [run.run_name for run in runs]
return http_util.Respond(request, run_names, "application/json")
@wrappers.Request.application
def _serve_experiments(self, request):
"""Serve a JSON array of experiments.
Experiments are ordered by experiment started time (aka first
event time) with empty times sorted last, and then ties are
broken by sorting on the experiment name.
"""
results = self.list_experiments_impl()
return http_util.Respond(request, results, "application/json")
def list_experiments_impl(self):
return []
@wrappers.Request.application
def _serve_experiment_runs(self, request):
"""Serve a JSON runs of an experiment, specified with query param
`experiment`, with their nested data, tag, populated.
Runs returned are ordered by started time (aka first event time)
with empty times sorted last, and then ties are broken by
sorting on the run name. Tags are sorted by its name,
displayName, and lastly, inserted time.
"""
results = []
return http_util.Respond(request, results, "application/json")
@wrappers.Request.application
def _serve_notifications(self, request):
"""Serve JSON payload of notifications to show in the UI."""
response = utils.redirect("../notifications_note.json")
# Disable Werkzeug's automatic Location header correction routine, which
# absolutizes relative paths "to be RFC conformant" [1], but this is
# based on an outdated HTTP/1.1 RFC; the current one allows them:
# https://tools.ietf.org/html/rfc7231#section-7.1.2
response.autocorrect_location_header = False
return response
class CorePluginLoader(base_plugin.TBLoader):
"""CorePlugin factory."""
def __init__(self, include_debug_info=None):
self._include_debug_info = include_debug_info
def define_flags(self, parser):
"""Adds standard TensorBoard CLI flags to parser."""
parser.add_argument(
"--logdir",
metavar="PATH",
type=str,
default="",
help="""\
Directory where TensorBoard will look to find TensorFlow event files
that it can display. TensorBoard will recursively walk the directory
structure rooted at logdir, looking for .*tfevents.* files.
A leading tilde will be expanded with the semantics of Python's
os.path.expanduser function.
""",
)
parser.add_argument(
"--logdir_spec",
metavar="PATH_SPEC",
type=str,
default="",
help="""\
Like `--logdir`, but with special interpretation for commas and colons:
commas separate multiple runs, where a colon specifies a new name for a
run. For example:
`tensorboard --logdir_spec=name1:/path/to/logs/1,name2:/path/to/logs/2`.
This flag is discouraged and can usually be avoided. TensorBoard walks
log directories recursively; for finer-grained control, prefer using a
symlink tree. Some features may not work when using `--logdir_spec`
instead of `--logdir`.
""",
)
parser.add_argument(
"--host",
metavar="ADDR",
type=str,
default=None, # like localhost, but prints a note about `--bind_all`
help="""\
What host to listen to (default: localhost). To serve to the entire local
network on both IPv4 and IPv6, see `--bind_all`, with which this option is
mutually exclusive.
""",
)
parser.add_argument(
"--bind_all",
action="store_true",
help="""\
Serve on all public interfaces. This will expose your TensorBoard instance to
the network on both IPv4 and IPv6 (where available). Mutually exclusive with
`--host`.
""",
)
parser.add_argument(
"--port",
metavar="PORT",
type=lambda s: (None if s == "default" else int(s)),
default="default",
help="""\
Port to serve TensorBoard on. Pass 0 to request an unused port selected
by the operating system, or pass "default" to try to bind to the default
port (%s) but search for a nearby free port if the default port is
unavailable. (default: "default").\
"""
% DEFAULT_PORT,
)
parser.add_argument(
"--reuse_port",
metavar="BOOL",
# Custom str-to-bool converter since regular bool() doesn't work.
type=lambda v: {"true": True, "false": False}.get(v.lower(), v),
choices=[True, False],
default=False,
help="""\
Enables the SO_REUSEPORT option on the socket opened by TensorBoard's HTTP
server, for platforms that support it. This is useful in cases when a parent
process has obtained the port already and wants to delegate access to the
port to TensorBoard as a subprocess. (default: %(default)s).\
""",
)
parser.add_argument(
"--load_fast",
type=str,
default="auto",
choices=["false", "auto", "true"],
help="""\
Use alternate mechanism to load data. Typically 100x faster or more, but only
available on some platforms and invocations. Defaults to "auto" to use this new
mode only if available, otherwise falling back to the legacy loading path. Set
to "true" to suppress the advisory note and hard-fail if the fast codepath is
not available. Set to "false" to always fall back. Feedback/issues:
https://github.com/tensorflow/tensorboard/issues/4784
(default: %(default)s)
""",
)
parser.add_argument(
"--extra_data_server_flags",
type=str,
default="",
help="""\
Experimental. With `--load_fast`, pass these additional command-line flags to
the data server. Subject to POSIX word splitting per `shlex.split`. Meant for
debugging; not officially supported.
""",
)
parser.add_argument(
"--grpc_creds_type",
type=grpc_util.ChannelCredsType,
default=grpc_util.ChannelCredsType.LOCAL,
choices=grpc_util.ChannelCredsType.choices(),
help="""\
Experimental. The type of credentials to use to connect to the data server.
(default: %(default)s)
""",
)
parser.add_argument(
"--grpc_data_provider",
metavar="PORT",
type=str,
default="",
help="""\
Experimental. Address of a gRPC server exposing a data provider. Set to empty
string to disable. (default: %(default)s)
""",
)
parser.add_argument(
"--purge_orphaned_data",
metavar="BOOL",
# Custom str-to-bool converter since regular bool() doesn't work.
type=lambda v: {"true": True, "false": False}.get(v.lower(), v),
choices=[True, False],
default=True,
help="""\
Whether to purge data that may have been orphaned due to TensorBoard
restarts. Setting --purge_orphaned_data=False can be used to debug data
disappearance. (default: %(default)s)\
""",
)
parser.add_argument(
"--db",
metavar="URI",
type=str,
default="",
help="""\
[experimental] sets SQL database URI and enables DB backend mode, which is
read-only unless --db_import is also passed.\
""",
)
parser.add_argument(
"--db_import",
action="store_true",
help="""\
[experimental] enables DB read-and-import mode, which in combination with
--logdir imports event files into a DB backend on the fly. The backing DB is
temporary unless --db is also passed to specify a DB path to use.\
""",
)
parser.add_argument(
"--inspect",
action="store_true",
help="""\
Prints digests of event files to command line.
This is useful when no data is shown on TensorBoard, or the data shown
looks weird.
Must specify one of the `logdir` or `event_file` flags.
Example usage:
`tensorboard --inspect --logdir mylogdir --tag loss`
See tensorboard/backend/event_processing/event_file_inspector.py for more info.\
""",
)
# This flag has a "_tb" suffix to avoid conflicting with an internal flag
# named --version. Note that due to argparse auto-expansion of unambiguous
# flag prefixes, you can still invoke this as `tensorboard --version`.
parser.add_argument(
"--version_tb",
action="store_true",
help="Prints the version of Tensorboard",
)
parser.add_argument(
"--tag",
metavar="TAG",
type=str,
default="",
help="tag to query for; used with --inspect",
)
parser.add_argument(
"--event_file",
metavar="PATH",
type=str,
default="",
help="""\
The particular event file to query for. Only used if --inspect is
present and --logdir is not specified.\
""",
)
parser.add_argument(
"--path_prefix",
metavar="PATH",
type=str,
default="",
help="""\
An optional, relative prefix to the path, e.g. "/path/to/tensorboard",
resulting in the new base URL being located at
localhost:6006/path/to/tensorboard under default settings. A leading
slash is required when specifying the path_prefix. A trailing slash is
optional and has no effect. The path_prefix can be leveraged for path
based routing of an ELB when the website base_url is not available e.g.
"example.site.com/path/to/tensorboard/".\
""",
)
parser.add_argument(
"--window_title",
metavar="TEXT",
type=str,
default="",
help="changes title of browser window",
)
parser.add_argument(
"--max_reload_threads",
metavar="COUNT",
type=int,
default=1,
help="""\
The max number of threads that TensorBoard can use to reload runs. Not
relevant for db read-only mode. Each thread reloads one run at a time.
(default: %(default)s)\
""",
)
parser.add_argument(
"--reload_interval",
metavar="SECONDS",
type=_nonnegative_float,
default=5.0,
help="""\
How often the backend should load more data, in seconds. Set to 0 to
load just once at startup. Must be non-negative. (default: %(default)s)\
""",
)
parser.add_argument(
"--reload_task",
metavar="TYPE",
type=str,
default="auto",
choices=["auto", "thread", "process", "blocking"],
help="""\
[experimental] The mechanism to use for the background data reload task.
The default "auto" option will conditionally use threads for legacy reloading
and a child process for DB import reloading. The "process" option is only
useful with DB import mode. The "blocking" option will block startup until
reload finishes, and requires --reload_interval=0. (default: %(default)s)\
""",
)
parser.add_argument(
"--reload_multifile",
metavar="BOOL",
# Custom str-to-bool converter since regular bool() doesn't work.
type=lambda v: {"true": True, "false": False}.get(v.lower(), v),
choices=[True, False],
default=None,
help="""\
[experimental] If true, this enables experimental support for continuously
polling multiple event files in each run directory for newly appended data
(rather than only polling the last event file). Event files will only be
polled as long as their most recently read data is newer than the threshold
defined by --reload_multifile_inactive_secs, to limit resource usage. Beware
of running out of memory if the logdir contains many active event files.
(default: false)\
""",
)
parser.add_argument(
"--reload_multifile_inactive_secs",
metavar="SECONDS",
type=int,
default=86400,
help="""\
[experimental] Configures the age threshold in seconds at which an event file
that has no event wall time more recent than that will be considered an
inactive file and no longer polled (to limit resource usage). If set to -1,
no maximum age will be enforced, but beware of running out of memory and
heavier filesystem read traffic. If set to 0, this reverts to the older
last-file-only polling strategy (akin to --reload_multifile=false).
(default: %(default)s - intended to ensure an event file remains active if
it receives new data at least once per 24 hour period)\
""",
)
parser.add_argument(
"--generic_data",
metavar="TYPE",
type=str,
default="auto",
choices=["false", "auto", "true"],
help="""\
[experimental] Hints whether plugins should read from generic data
provider infrastructure. For plugins that support only the legacy
multiplexer APIs or only the generic data APIs, this option has no
effect. The "auto" option enables this only for plugins that are
considered to have stable support for generic data providers. (default:
%(default)s)\
""",
)
parser.add_argument(
"--samples_per_plugin",
type=_parse_samples_per_plugin,
default="",
help="""\
An optional comma separated list of plugin_name=num_samples pairs to
explicitly specify how many samples to keep per tag for that plugin. For
unspecified plugins, TensorBoard randomly downsamples logged summaries
to reasonable values to prevent out-of-memory errors for long running
jobs. This flag allows fine control over that downsampling. Note that if a
plugin is not specified in this list, a plugin-specific default number of
samples will be enforced. (for example, 10 for images, 500 for histograms,
and 1000 for scalars). Most users should not need to set this flag.\
""",
)
parser.add_argument(
"--detect_file_replacement",
metavar="BOOL",
# Custom str-to-bool converter since regular bool() doesn't work.
type=lambda v: {"true": True, "false": False}.get(v.lower(), v),
choices=[True, False],
default=None,
help="""\
[experimental] If true, this enables experimental support for detecting when
event files are replaced with new versions that contain additional data. This is
not needed in the normal case where new data is either appended to an existing
file or written to a brand new file, but it arises, for example, when using
rsync without the --inplace option, in which new versions of the original file
are first written to a temporary file, then swapped into the final location.
This option is currently incompatible with --load_fast=true, and if passed will
disable fast-loading mode. (default: false)\
""",
)
def fix_flags(self, flags):
"""Fixes standard TensorBoard CLI flags to parser."""
FlagsError = base_plugin.FlagsError
if flags.version_tb:
pass
elif flags.inspect:
if flags.logdir_spec:
raise FlagsError(
"--logdir_spec is not supported with --inspect."
)
if flags.logdir and flags.event_file:
raise FlagsError(
"Must specify either --logdir or --event_file, but not both."
)
if not (flags.logdir or flags.event_file):
raise FlagsError(
"Must specify either --logdir or --event_file."
)
elif flags.logdir and flags.logdir_spec:
raise FlagsError("May not specify both --logdir and --logdir_spec")
elif (
not flags.db
and not flags.logdir
and not flags.logdir_spec
and not flags.grpc_data_provider
):
raise FlagsError(
"A logdir or db must be specified. "
"For example `tensorboard --logdir mylogdir` "
"or `tensorboard --db sqlite:~/.tensorboard.db`. "
"Run `tensorboard --helpfull` for details and examples."
)
elif flags.host is not None and flags.bind_all:
raise FlagsError("Must not specify both --host and --bind_all.")
elif (
flags.load_fast == "true" and flags.detect_file_replacement is True
):
raise FlagsError(
"Must not specify both --load_fast=true and"
"--detect_file_replacement=true"
)
flags.path_prefix = flags.path_prefix.rstrip("/")
if flags.path_prefix and not flags.path_prefix.startswith("/"):
raise FlagsError(
"Path prefix must start with slash, but got: %r."
% flags.path_prefix
)
def load(self, context):
"""Creates CorePlugin instance."""
return CorePlugin(context, include_debug_info=self._include_debug_info)
def _gzip(bytestring):
out = io.BytesIO()
# Set mtime to zero for deterministic results across TensorBoard launches.
with gzip.GzipFile(fileobj=out, mode="wb", compresslevel=3, mtime=0) as f:
f.write(bytestring)
return out.getvalue()
def _parse_samples_per_plugin(value):
"""Parses `value` as a string-to-int dict in the form `foo=12,bar=34`."""
result = {}
for token in value.split(","):
if token:
k, v = token.strip().split("=")
result[k] = int(v)
return result
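# Illustrative sketch, not part of the original module: it only shows how a
# `--samples_per_plugin` value maps to a dict via the parser defined above.
def _example_parse_samples_per_plugin():
    """Shows the plugin_name=num_samples format accepted by the flag."""
    parsed = _parse_samples_per_plugin("images=100,scalars=500")
    assert parsed == {"images": 100, "scalars": 500}
    # The empty default yields an empty dict, so every plugin falls back to
    # its own default downsampling rate (e.g., 10 for images).
    assert _parse_samples_per_plugin("") == {}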
def _nonnegative_float(v):
try:
v = float(v)
except ValueError:
raise argparse.ArgumentTypeError("invalid float: %r" % v)
if not (v >= 0): # no NaNs, please
raise argparse.ArgumentTypeError("must be non-negative: %r" % v)
return v

View File

@ -0,0 +1,321 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard Custom Scalars plugin.
This plugin lets the user create scalars plots with custom run-tag combinations
by specifying regular expressions.
See `http_api.md` in this directory for specifications of the routes for
this plugin.
"""
import re
from werkzeug import wrappers
from tensorboard import plugin_util
from tensorboard.backend import http_util
from tensorboard.compat import tf
from tensorboard.data import provider
from tensorboard.plugins import base_plugin
from tensorboard.plugins.custom_scalar import layout_pb2
from tensorboard.plugins.custom_scalar import metadata
from tensorboard.plugins.scalar import metadata as scalars_metadata
from tensorboard.plugins.scalar import scalars_plugin
# The name of the property in the response for whether the regex is valid.
_REGEX_VALID_PROPERTY = "regex_valid"
# The name of the property in the response for the payload (tag to ScalarEvents
# mapping).
_TAG_TO_EVENTS_PROPERTY = "tag_to_events"
# The number of seconds to wait in between checks for the config file specifying
# layout.
_CONFIG_FILE_CHECK_THROTTLE = 60
class CustomScalarsPlugin(base_plugin.TBPlugin):
"""CustomScalars Plugin for TensorBoard."""
plugin_name = metadata.PLUGIN_NAME
def __init__(self, context):
"""Instantiates ScalarsPlugin via TensorBoard core.
Args:
context: A base_plugin.TBContext instance.
"""
self._logdir = context.logdir
self._data_provider = context.data_provider
self._plugin_name_to_instance = context.plugin_name_to_instance
def _get_scalars_plugin(self):
"""Tries to get the scalars plugin.
Returns:
The scalars plugin. Or None if it is not yet registered.
"""
if scalars_metadata.PLUGIN_NAME in self._plugin_name_to_instance:
# The plugin is registered.
return self._plugin_name_to_instance[scalars_metadata.PLUGIN_NAME]
# The plugin is not yet registered.
return None
def get_plugin_apps(self):
return {
"/download_data": self.download_data_route,
"/layout": self.layout_route,
"/scalars": self.scalars_route,
}
def is_active(self):
"""Plugin is active if there is a custom layout for the dashboard."""
return False # `list_plugins` as called by TB core suffices
def frontend_metadata(self):
return base_plugin.FrontendMetadata(
element_name="tf-custom-scalar-dashboard",
tab_name="Custom Scalars",
)
@wrappers.Request.application
def download_data_route(self, request):
ctx = plugin_util.context(request.environ)
run = request.args.get("run")
tag = request.args.get("tag")
experiment = plugin_util.experiment_id(request.environ)
response_format = request.args.get("format")
try:
body, mime_type = self.download_data_impl(
ctx, run, tag, experiment, response_format
)
except ValueError as e:
return http_util.Respond(
request=request,
content=str(e),
content_type="text/plain",
code=400,
)
return http_util.Respond(request, body, mime_type)
def download_data_impl(self, ctx, run, tag, experiment, response_format):
"""Provides a response for downloading scalars data for a data series.
Args:
ctx: A tensorboard.context.RequestContext value.
run: The run.
tag: The specific tag.
experiment: An experiment ID, as a possibly-empty `str`.
response_format: A string. One of the values of the OutputFormat enum
of the scalar plugin.
Raises:
ValueError: If the scalars plugin is not registered.
Returns:
2 entities:
- A JSON object response body.
- A mime type (string) for the response.
"""
scalars_plugin_instance = self._get_scalars_plugin()
if not scalars_plugin_instance:
raise ValueError(
(
"Failed to respond to request for /download_data. "
"The scalars plugin is oddly not registered."
)
)
body, mime_type = scalars_plugin_instance.scalars_impl(
ctx, tag, run, experiment, response_format
)
return body, mime_type
@wrappers.Request.application
def scalars_route(self, request):
"""Given a tag regex and single run, return ScalarEvents.
This route takes 2 GET params:
run: A run string to find tags for.
tag: A string that is a regex used to find matching tags.
The response is a JSON object:
{
// Whether the regular expression is valid. Also false if empty.
regexValid: boolean,
// An object mapping tag name to a list of ScalarEvents.
payload: Object<string, ScalarEvent[]>,
}
"""
ctx = plugin_util.context(request.environ)
tag_regex_string = request.args.get("tag")
run = request.args.get("run")
experiment = plugin_util.experiment_id(request.environ)
mime_type = "application/json"
try:
body = self.scalars_impl(ctx, run, tag_regex_string, experiment)
except ValueError as e:
return http_util.Respond(
request=request,
content=str(e),
content_type="text/plain",
code=400,
)
# Produce the response.
return http_util.Respond(request, body, mime_type)
def scalars_impl(self, ctx, run, tag_regex_string, experiment):
"""Given a tag regex and single run, return ScalarEvents.
Args:
ctx: A tensorboard.context.RequestContext value.
run: A run string.
tag_regex_string: A regular expression that captures portions of tags.
experiment: An experiment ID, as a possibly-empty `str`.
Raises:
ValueError: if the scalars plugin is not registered.
Returns:
A dictionary that is the JSON-able response.
"""
if not tag_regex_string:
# The user provided no regex.
return {
_REGEX_VALID_PROPERTY: False,
_TAG_TO_EVENTS_PROPERTY: {},
}
# Construct the regex.
try:
regex = re.compile(tag_regex_string)
except re.error:
return {
_REGEX_VALID_PROPERTY: False,
_TAG_TO_EVENTS_PROPERTY: {},
}
# Fetch the tags for the run. Filter for tags that match the regex.
run_to_data = self._data_provider.list_scalars(
ctx,
experiment_id=experiment,
plugin_name=scalars_metadata.PLUGIN_NAME,
run_tag_filter=provider.RunTagFilter(runs=[run]),
)
tag_to_data = None
try:
tag_to_data = run_to_data[run]
except KeyError:
# The run could not be found. Perhaps a configuration specified a run that
# TensorBoard has not read from disk yet.
payload = {}
if tag_to_data:
scalars_plugin_instance = self._get_scalars_plugin()
if not scalars_plugin_instance:
raise ValueError(
(
"Failed to respond to request for /scalars. "
"The scalars plugin is oddly not registered."
)
)
form = scalars_plugin.OutputFormat.JSON
payload = {
tag: scalars_plugin_instance.scalars_impl(
ctx, tag, run, experiment, form
)[0]
for tag in tag_to_data.keys()
if regex.match(tag)
}
return {
_REGEX_VALID_PROPERTY: True,
_TAG_TO_EVENTS_PROPERTY: payload,
}
@wrappers.Request.application
def layout_route(self, request):
"""Fetches the custom layout specified by the config file in the logdir.
If more than 1 run contains a layout, this method merges the layouts by
merging charts within individual categories. If 2 categories with the same
name are found, the charts within are merged. The merging is based on the
order of the runs to which the layouts are written.
The response is a JSON object mirroring properties of the Layout proto if a
layout for any run is found.
The response is an empty object if no layout could be found.
"""
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
body = self.layout_impl(ctx, experiment)
return http_util.Respond(request, body, "application/json")
def layout_impl(self, ctx, experiment):
# Keep a mapping between title and category so we do not create duplicate
# categories.
title_to_category = {}
merged_layout = None
data = self._data_provider.read_tensors(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
run_tag_filter=provider.RunTagFilter(
tags=[metadata.CONFIG_SUMMARY_TAG]
),
downsample=1,
)
for run in sorted(data):
points = data[run][metadata.CONFIG_SUMMARY_TAG]
content = points[0].numpy.item()
layout_proto = layout_pb2.Layout()
layout_proto.ParseFromString(tf.compat.as_bytes(content))
if merged_layout:
# Append the categories within this layout to the merged layout.
for category in layout_proto.category:
if category.title in title_to_category:
# A category with this name has been seen before. Do not create a
# new one. Merge their charts, skipping any duplicates.
title_to_category[category.title].chart.extend(
[
c
for c in category.chart
if c
not in title_to_category[category.title].chart
]
)
else:
# This category has not been seen before.
merged_layout.category.add().MergeFrom(category)
title_to_category[category.title] = category
else:
# This is the first layout encountered.
merged_layout = layout_proto
for category in layout_proto.category:
title_to_category[category.title] = category
if merged_layout:
return plugin_util.proto_to_json(merged_layout)
else:
# No layout was found.
return {}

View File

@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: tensorboard/plugins/custom_scalar/layout.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n.tensorboard/plugins/custom_scalar/layout.proto\x12\x0btensorboard\"\x8d\x01\n\x05\x43hart\x12\r\n\x05title\x18\x01 \x01(\t\x12\x37\n\tmultiline\x18\x02 \x01(\x0b\x32\".tensorboard.MultilineChartContentH\x00\x12\x31\n\x06margin\x18\x03 \x01(\x0b\x32\x1f.tensorboard.MarginChartContentH\x00\x42\t\n\x07\x63ontent\"$\n\x15MultilineChartContent\x12\x0b\n\x03tag\x18\x01 \x03(\t\"\x83\x01\n\x12MarginChartContent\x12\x36\n\x06series\x18\x01 \x03(\x0b\x32&.tensorboard.MarginChartContent.Series\x1a\x35\n\x06Series\x12\r\n\x05value\x18\x01 \x01(\t\x12\r\n\x05lower\x18\x02 \x01(\t\x12\r\n\x05upper\x18\x03 \x01(\t\"L\n\x08\x43\x61tegory\x12\r\n\x05title\x18\x01 \x01(\t\x12!\n\x05\x63hart\x18\x02 \x03(\x0b\x32\x12.tensorboard.Chart\x12\x0e\n\x06\x63losed\x18\x03 \x01(\x08\"B\n\x06Layout\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\'\n\x08\x63\x61tegory\x18\x02 \x03(\x0b\x32\x15.tensorboard.Categoryb\x06proto3')
_CHART = DESCRIPTOR.message_types_by_name['Chart']
_MULTILINECHARTCONTENT = DESCRIPTOR.message_types_by_name['MultilineChartContent']
_MARGINCHARTCONTENT = DESCRIPTOR.message_types_by_name['MarginChartContent']
_MARGINCHARTCONTENT_SERIES = _MARGINCHARTCONTENT.nested_types_by_name['Series']
_CATEGORY = DESCRIPTOR.message_types_by_name['Category']
_LAYOUT = DESCRIPTOR.message_types_by_name['Layout']
Chart = _reflection.GeneratedProtocolMessageType('Chart', (_message.Message,), {
'DESCRIPTOR' : _CHART,
'__module__' : 'tensorboard.plugins.custom_scalar.layout_pb2'
# @@protoc_insertion_point(class_scope:tensorboard.Chart)
})
_sym_db.RegisterMessage(Chart)
MultilineChartContent = _reflection.GeneratedProtocolMessageType('MultilineChartContent', (_message.Message,), {
'DESCRIPTOR' : _MULTILINECHARTCONTENT,
'__module__' : 'tensorboard.plugins.custom_scalar.layout_pb2'
# @@protoc_insertion_point(class_scope:tensorboard.MultilineChartContent)
})
_sym_db.RegisterMessage(MultilineChartContent)
MarginChartContent = _reflection.GeneratedProtocolMessageType('MarginChartContent', (_message.Message,), {
'Series' : _reflection.GeneratedProtocolMessageType('Series', (_message.Message,), {
'DESCRIPTOR' : _MARGINCHARTCONTENT_SERIES,
'__module__' : 'tensorboard.plugins.custom_scalar.layout_pb2'
# @@protoc_insertion_point(class_scope:tensorboard.MarginChartContent.Series)
})
,
'DESCRIPTOR' : _MARGINCHARTCONTENT,
'__module__' : 'tensorboard.plugins.custom_scalar.layout_pb2'
# @@protoc_insertion_point(class_scope:tensorboard.MarginChartContent)
})
_sym_db.RegisterMessage(MarginChartContent)
_sym_db.RegisterMessage(MarginChartContent.Series)
Category = _reflection.GeneratedProtocolMessageType('Category', (_message.Message,), {
'DESCRIPTOR' : _CATEGORY,
'__module__' : 'tensorboard.plugins.custom_scalar.layout_pb2'
# @@protoc_insertion_point(class_scope:tensorboard.Category)
})
_sym_db.RegisterMessage(Category)
Layout = _reflection.GeneratedProtocolMessageType('Layout', (_message.Message,), {
'DESCRIPTOR' : _LAYOUT,
'__module__' : 'tensorboard.plugins.custom_scalar.layout_pb2'
# @@protoc_insertion_point(class_scope:tensorboard.Layout)
})
_sym_db.RegisterMessage(Layout)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_CHART._serialized_start=64
_CHART._serialized_end=205
_MULTILINECHARTCONTENT._serialized_start=207
_MULTILINECHARTCONTENT._serialized_end=243
_MARGINCHARTCONTENT._serialized_start=246
_MARGINCHARTCONTENT._serialized_end=377
_MARGINCHARTCONTENT_SERIES._serialized_start=324
_MARGINCHARTCONTENT_SERIES._serialized_end=377
_CATEGORY._serialized_start=379
_CATEGORY._serialized_end=455
_LAYOUT._serialized_start=457
_LAYOUT._serialized_end=523
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,36 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Information on the custom scalars plugin."""
from tensorboard.compat.proto import summary_pb2
# A special tag name used for the summary that stores the layout.
CONFIG_SUMMARY_TAG = "custom_scalars__config__"
PLUGIN_NAME = "custom_scalars"
def create_summary_metadata():
"""Create a `SummaryMetadata` proto for custom scalar plugin data.
Returns:
A `summary_pb2.SummaryMetadata` protobuf object.
"""
return summary_pb2.SummaryMetadata(
plugin_data=summary_pb2.SummaryMetadata.PluginData(
plugin_name=PLUGIN_NAME
)
)

View File

@ -0,0 +1,80 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains summaries related to laying out the custom scalars dashboard."""
from tensorboard.plugins.custom_scalar import layout_pb2
from tensorboard.plugins.custom_scalar import metadata
def op(scalars_layout, collections=None):
"""Creates a summary that contains a layout.
When users navigate to the custom scalars dashboard, they will see a layout
based on the proto provided to this function.
Args:
scalars_layout: The scalars_layout_pb2.Layout proto that specifies the
layout.
collections: Optional list of graph collection keys. The new
summary op is added to these collections. Defaults to
`[GraphKeys.SUMMARIES]`.
Returns:
A tensor summary op that writes the layout to disk.
"""
# TODO(nickfelt): remove on-demand imports once dep situation is fixed.
import tensorflow.compat.v1 as tf
assert isinstance(scalars_layout, layout_pb2.Layout)
summary_metadata = metadata.create_summary_metadata()
return tf.summary.tensor_summary(
name=metadata.CONFIG_SUMMARY_TAG,
tensor=tf.constant(scalars_layout.SerializeToString(), dtype=tf.string),
collections=collections,
summary_metadata=summary_metadata,
)
def pb(scalars_layout):
"""Creates a summary that contains a layout.
When users navigate to the custom scalars dashboard, they will see a layout
based on the proto provided to this function.
Args:
scalars_layout: The scalars_layout_pb2.Layout proto that specifies the
layout.
Returns:
A summary proto containing the layout.
"""
# TODO(nickfelt): remove on-demand imports once dep situation is fixed.
import tensorflow.compat.v1 as tf
assert isinstance(scalars_layout, layout_pb2.Layout)
tensor = tf.make_tensor_proto(
scalars_layout.SerializeToString(), dtype=tf.string
)
tf_summary_metadata = tf.SummaryMetadata.FromString(
metadata.create_summary_metadata().SerializeToString()
)
summary = tf.Summary()
summary.value.add(
tag=metadata.CONFIG_SUMMARY_TAG,
metadata=tf_summary_metadata,
tensor=tensor,
)
return summary
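# Illustrative usage sketch, not part of the original module: it builds a
# small Layout proto and converts it with `pb()`. The category, chart, and
# tag-regex names are hypothetical; writing the resulting summary to disk
# still requires a TensorFlow summary writer.
def _example_layout_summary():
    """Creates a layout with one category containing one multiline chart."""
    layout = layout_pb2.Layout(
        category=[
            layout_pb2.Category(
                title="losses",
                chart=[
                    layout_pb2.Chart(
                        title="cross entropy",
                        multiline=layout_pb2.MultilineChartContent(
                            tag=[r"loss(_\d+)?"]
                        ),
                    )
                ],
            )
        ]
    )
    # `pb()` yields a tf.Summary proto; `op()` would instead add a summary op
    # to the default graph under tf.compat.v1.
    return pb(layout)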

View File

@ -0,0 +1,632 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A wrapper around DebugDataReader used for retrieving tfdbg v2 data."""
import threading
from tensorboard import errors
# Dummy run name for the debugger.
# Currently, the `DebuggerV2ExperimentMultiplexer` class is tied to a single
# logdir, which holds at most one DebugEvent file set in the tfdbg v2 (tfdbg2
# for short) format.
# TODO(cais): When tfdbg2 allows there to be multiple DebugEvent file sets in
# the same logdir, replace this magic string with actual run names.
DEFAULT_DEBUGGER_RUN_NAME = "__default_debugger_run__"
# Default number of alerts per monitor type.
# Limiting the number of alerts is based on the consideration that usually
# only the first few alerting events are the most critical and the subsequent
# ones are either repetitions of the earlier ones or caused by the earlier ones.
DEFAULT_PER_TYPE_ALERT_LIMIT = 1000
# Default interval between successive calls to `DebugDataReader.update()`.
DEFAULT_RELOAD_INTERVAL_SEC = 30
def run_repeatedly_in_background(target, interval_sec):
"""Run a target task repeatedly in the background.
In the context of this module, `target` is the `update()` method of the
underlying reader for tfdbg2-format data.
This method is mocked by unit tests for deterministic behaviors during
testing.
Args:
target: The target task to run in the background, a callable with no args.
interval_sec: Time interval between repeats, in seconds.
Returns:
- A `threading.Event` object that can be used to interrupt an ongoing
waiting interval between successive runs of `target`. To interrupt the
interval, call the `set()` method of the object.
- The `threading.Thread` object on which `target` is run repeatedly.
"""
event = threading.Event()
def _run_repeatedly():
while True:
target()
event.wait(interval_sec)
event.clear()
# Use `daemon=True` to make sure the thread doesn't block program exit.
thread = threading.Thread(target=_run_repeatedly, daemon=True)
thread.start()
return event, thread
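# Illustrative usage sketch, assuming only the function above: `fake_update`
# is a hypothetical stand-in for the reader's `update()` callable.
def _example_run_repeatedly_in_background():
    """Starts a periodic background task and interrupts its waiting interval."""
    calls = []

    def fake_update():
        calls.append(True)

    event, thread = run_repeatedly_in_background(fake_update, 30)
    # Setting the event interrupts the 30-second wait, so the next call to
    # `fake_update` happens immediately; this mirrors how `_reloadReader()`
    # accelerates reloads while data is being actively queried.
    event.set()
    return event, thread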
def _alert_to_json(alert):
# TODO(cais): Replace this with Alert.to_json() when supported by the
# backend.
from tensorflow.python.debug.lib import debug_events_monitors
if isinstance(alert, debug_events_monitors.InfNanAlert):
return {
"alert_type": "InfNanAlert",
"op_type": alert.op_type,
"output_slot": alert.output_slot,
# TODO(cais): Once supported by backend, add 'op_name' key
# for intra-graph execution events.
"size": alert.size,
"num_neg_inf": alert.num_neg_inf,
"num_pos_inf": alert.num_pos_inf,
"num_nan": alert.num_nan,
"execution_index": alert.execution_index,
"graph_execution_trace_index": alert.graph_execution_trace_index,
}
else:
raise TypeError("Unrecognized alert subtype: %s" % type(alert))
def parse_tensor_name(tensor_name):
"""Helper function that extracts op name and slot from tensor name."""
output_slot = 0
if ":" in tensor_name:
op_name, output_slot = tensor_name.split(":")
output_slot = int(output_slot)
else:
op_name = tensor_name
return op_name, output_slot
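# Illustrative sketch (hypothetical op names), showing the op-name/slot split
# performed by `parse_tensor_name` above:
def _example_parse_tensor_name():
    assert parse_tensor_name("Dense_1/MatMul:0") == ("Dense_1/MatMul", 0)
    assert parse_tensor_name("Dense_1/BiasAdd") == ("Dense_1/BiasAdd", 0)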
class DebuggerV2EventMultiplexer:
"""A class used for accessing tfdbg v2 DebugEvent data on local filesystem.
This class is a short-term hack, mirroring the EventMultiplexer for the main
TensorBoard plugins (e.g., scalar, histogram, and graphs). As such, it only
implements the methods relevant to the Debugger V2 plugin.
TODO(cais): Integrate it with EventMultiplexer and use the integrated class
from MultiplexerDataProvider for a single path of accessing debugger and
non-debugger data.
"""
def __init__(self, logdir):
"""Constructor for the `DebugEventMultiplexer`.
Args:
logdir: Path to the directory to load the tfdbg v2 data from.
"""
self._logdir = logdir
self._reader = None
self._reader_lock = threading.Lock()
self._reload_needed_event = None
# Create the reader for the tfdbg2 data in the logdir as soon as
# the backend of the debugger-v2 plugin is created, so it doesn't need
# to wait for the first request from the FE to start loading data.
self._tryCreateReader()
def _tryCreateReader(self):
"""Try creating reader for tfdbg2 data in the logdir.
If the reader has already been created, a new one will not be created and
this function is a no-op.
If a reader has not been created, create it and start periodic calls to
`update()` on a separate thread.
"""
if self._reader:
return
with self._reader_lock:
if not self._reader:
try:
# TODO(cais): Avoid conditional imports and instead use
# plugin loader to gate the loading of this entire plugin.
from tensorflow.python.debug.lib import debug_events_reader
from tensorflow.python.debug.lib import (
debug_events_monitors,
)
except ImportError:
# This ensures graceful behavior when tensorflow install is
# unavailable or when the installed tensorflow version does not
# contain the required modules.
return
try:
self._reader = debug_events_reader.DebugDataReader(
self._logdir
)
except AttributeError:
# Gracefully fail for users without the required API changes to
# debug_events_reader.DebugDataReader introduced in
# TF 2.1.0.dev20200103. This should be safe to remove when
# TF 2.2 is released.
return
except ValueError:
# When no DebugEvent file set is found in the logdir, a
# `ValueError` is thrown.
return
self._monitors = [
debug_events_monitors.InfNanMonitor(
self._reader, limit=DEFAULT_PER_TYPE_ALERT_LIMIT
)
]
self._reload_needed_event, _ = run_repeatedly_in_background(
self._reader.update, DEFAULT_RELOAD_INTERVAL_SEC
)
def _reloadReader(self):
"""If a reader exists and has started period updating, unblock the update.
The updates are performed periodically with a sleep interval between
successive calls to the reader's update() method. Calling this method
interrupts the sleep immediately if one is ongoing.
"""
if self._reload_needed_event:
self._reload_needed_event.set()
def FirstEventTimestamp(self, run):
"""Return the timestamp of the first DebugEvent of the given run.
This may perform I/O if no events have been loaded yet for the run.
Args:
run: A string name of the run for which the timestamp is retrieved.
This currently must be hardcoded as `DEFAULT_DEBUGGER_RUN_NAME`,
as each logdir contains at most one DebugEvent file set (i.e., a
run of a tfdbg2-instrumented TensorFlow program).
Returns:
The wall_time of the first event of the run, which will be in seconds
since the epoch as a `float`.
"""
if self._reader is None:
raise ValueError("No tfdbg2 runs exists.")
if run != DEFAULT_DEBUGGER_RUN_NAME:
raise ValueError(
"Expected run name to be %s, but got %s"
% (DEFAULT_DEBUGGER_RUN_NAME, run)
)
return self._reader.starting_wall_time()
def PluginRunToTagToContent(self, plugin_name):
raise NotImplementedError(
"DebugDataMultiplexer.PluginRunToTagToContent() has not been "
"implemented yet."
)
def Runs(self):
"""Return all the tfdbg2 run names in the logdir watched by this instance.
The `Runs()` method of this class is specialized for the tfdbg2-format
DebugEvent files.
As a side effect, this method unblocks the underlying reader's periodic
reloading if a reader exists. This lets the reader update at a higher
frequency than the default 30-second sleep interval between reloads when
data is being actively queried from this instance.
Note that this `Runs()` method is used by all other public data-access
methods of this class (e.g., `ExecutionData()`, `GraphExecutionData()`).
Hence calls to those methods will lead to accelerated data reloading of
the reader.
Returns:
If tfdbg2-format data exists in the `logdir` of this object, returns:
```
{runName: { "debugger-v2": [tag1, tag2, tag3] } }
```
where `runName` is the hard-coded string `DEFAULT_DEBUGGER_RUN_NAME`.
This is related to the fact that tfdbg2 currently contains
at most one DebugEvent file set per directory.
If no tfdbg2-format data exists in the `logdir`, an empty `dict`.
"""
# Call `_tryCreateReader()` here to cover the possibility of tfdbg2
# data starting to be written to the logdir after the TensorBoard backend
# starts.
self._tryCreateReader()
if self._reader:
# If a _reader exists, unblock its reloading (on a separate thread)
# immediately.
self._reloadReader()
return {
DEFAULT_DEBUGGER_RUN_NAME: {
# TODO(cais): Add the semantically meaningful tag names such as
# 'execution_digests_book', 'alerts_book'
"debugger-v2": []
}
}
else:
return {}
def _checkBeginEndIndices(self, begin, end, total_count):
if begin < 0:
raise errors.InvalidArgumentError(
"Invalid begin index (%d)" % begin
)
if end > total_count:
raise errors.InvalidArgumentError(
"end index (%d) out of bounds (%d)" % (end, total_count)
)
if end >= 0 and end < begin:
raise errors.InvalidArgumentError(
"end index (%d) is unexpectedly less than begin index (%d)"
% (end, begin)
)
if end < 0: # This means all digests.
end = total_count
return end
def Alerts(self, run, begin, end, alert_type_filter=None):
"""Get alerts from the debugged TensorFlow program.
Args:
run: The tfdbg2 run to get Alerts from.
begin: Beginning alert index.
end: Ending alert index.
alert_type_filter: Optional filter string for alert type, used to
restrict retrieved alerts data to a single type. If used,
`begin` and `end` refer to the beginning and ending indices within
the filtered alert type.
"""
from tensorflow.python.debug.lib import debug_events_monitors
runs = self.Runs()
if run not in runs:
# TODO(cais): This should generate a 400 response instead.
return None
alerts = []
alerts_breakdown = dict()
alerts_by_type = dict()
for monitor in self._monitors:
monitor_alerts = monitor.alerts()
if not monitor_alerts:
continue
alerts.extend(monitor_alerts)
# TODO(cais): Replace this with Alert.to_json() when
# monitor.alert_type() is available.
if isinstance(monitor, debug_events_monitors.InfNanMonitor):
alert_type = "InfNanAlert"
else:
alert_type = "__MiscellaneousAlert__"
alerts_breakdown[alert_type] = len(monitor_alerts)
alerts_by_type[alert_type] = monitor_alerts
num_alerts = len(alerts)
if alert_type_filter is not None:
if alert_type_filter not in alerts_breakdown:
raise errors.InvalidArgumentError(
"Filtering of alerts failed: alert type %s does not exist"
% alert_type_filter
)
alerts = alerts_by_type[alert_type_filter]
end = self._checkBeginEndIndices(begin, end, len(alerts))
return {
"begin": begin,
"end": end,
"alert_type": alert_type_filter,
"num_alerts": num_alerts,
"alerts_breakdown": alerts_breakdown,
"per_type_alert_limit": DEFAULT_PER_TYPE_ALERT_LIMIT,
"alerts": [_alert_to_json(alert) for alert in alerts[begin:end]],
}
def ExecutionDigests(self, run, begin, end):
"""Get ExecutionDigests.
Args:
run: The tfdbg2 run to get `ExecutionDigest`s from.
begin: Beginning execution index.
end: Ending execution index.
Returns:
A JSON-serializable object containing the `ExecutionDigest`s and
related meta-information
"""
runs = self.Runs()
if run not in runs:
return None
# TODO(cais): For scalability, use begin and end kwargs when available in
# `DebugDataReader.executions()`.
execution_digests = self._reader.executions(digest=True)
end = self._checkBeginEndIndices(begin, end, len(execution_digests))
return {
"begin": begin,
"end": end,
"num_digests": len(execution_digests),
"execution_digests": [
digest.to_json() for digest in execution_digests[begin:end]
],
}
def ExecutionData(self, run, begin, end):
"""Get Execution data objects (Detailed, non-digest form).
Args:
run: The tfdbg2 run to get `ExecutionDigest`s from.
begin: Beginning execution index.
end: Ending execution index.
Returns:
A JSON-serializable object containing the `ExecutionDigest`s and
related meta-information
"""
runs = self.Runs()
if run not in runs:
return None
execution_digests = self._reader.executions(digest=True)
end = self._checkBeginEndIndices(begin, end, len(execution_digests))
execution_digests = execution_digests[begin:end]
executions = self._reader.executions(digest=False, begin=begin, end=end)
return {
"begin": begin,
"end": end,
"executions": [execution.to_json() for execution in executions],
}
def GraphExecutionDigests(self, run, begin, end, trace_id=None):
"""Get `GraphExecutionTraceDigest`s.
Args:
run: The tfdbg2 run to get `GraphExecutionTraceDigest`s from.
begin: Beginning graph-execution index.
end: Ending graph-execution index.
Returns:
A JSON-serializable object containing the `GraphExecutionTraceDigest`s
and related meta-information.
"""
runs = self.Runs()
if run not in runs:
return None
# TODO(cais): Implement support for trace_id once the joining of eager
# execution and intra-graph execution is supported by DebugDataReader.
if trace_id is not None:
raise NotImplementedError(
"trace_id support for GraphExecutionTraceDigest is "
"not implemented yet."
)
graph_exec_digests = self._reader.graph_execution_traces(digest=True)
end = self._checkBeginEndIndices(begin, end, len(graph_exec_digests))
return {
"begin": begin,
"end": end,
"num_digests": len(graph_exec_digests),
"graph_execution_digests": [
digest.to_json() for digest in graph_exec_digests[begin:end]
],
}
def GraphExecutionData(self, run, begin, end, trace_id=None):
"""Get `GraphExecutionTrace`s.
Args:
run: The tfdbg2 run to get `GraphExecutionTrace`s from.
begin: Beginning graph-execution index.
end: Ending graph-execution index.
Returns:
A JSON-serializable object containing the `GraphExecutionTrace`s and
related meta-information.
"""
runs = self.Runs()
if run not in runs:
return None
# TODO(cais): Implement support for trace_id once the joining of eager
# execution and intra-graph execution is supported by DebugDataReader.
if trace_id is not None:
raise NotImplementedError(
"trace_id support for GraphExecutionTraceData is "
"not implemented yet."
)
digests = self._reader.graph_execution_traces(digest=True)
end = self._checkBeginEndIndices(begin, end, len(digests))
graph_executions = self._reader.graph_execution_traces(
digest=False, begin=begin, end=end
)
return {
"begin": begin,
"end": end,
"graph_executions": [
graph_exec.to_json() for graph_exec in graph_executions
],
}
def GraphInfo(self, run, graph_id):
"""Get the information regarding a TensorFlow graph.
Args:
run: Name of the run.
graph_id: Debugger-generated ID of the graph in question.
This information is available in the return values
of `GraphOpInfo`, `GraphExecution`, etc.
Returns:
A JSON-serializable object containing the information regarding
the TensorFlow graph.
Raises:
NotFoundError if the graph_id is not known to the debugger.
"""
runs = self.Runs()
if run not in runs:
return None
try:
graph = self._reader.graph_by_id(graph_id)
except KeyError:
raise errors.NotFoundError(
'There is no graph with ID "%s"' % graph_id
)
return graph.to_json()
def GraphOpInfo(self, run, graph_id, op_name):
"""Get the information regarding a graph op's creation.
Args:
run: Name of the run.
graph_id: Debugger-generated ID of the graph that contains
the op in question. This ID is available from other methods
of this class, e.g., the return value of `GraphExecutionDigests()`.
op_name: Name of the op.
Returns:
A JSON-serializable object containing the information regarding
the op's creation and its immediate inputs and consumers.
Raises:
NotFoundError if the graph_id or op_name does not exist.
"""
runs = self.Runs()
if run not in runs:
return None
try:
graph = self._reader.graph_by_id(graph_id)
except KeyError:
raise errors.NotFoundError(
'There is no graph with ID "%s"' % graph_id
)
try:
op_creation_digest = graph.get_op_creation_digest(op_name)
except KeyError:
raise errors.NotFoundError(
'There is no op named "%s" in graph with ID "%s"'
% (op_name, graph_id)
)
data_object = self._opCreationDigestToDataObject(
op_creation_digest, graph
)
# Populate data about immediate inputs.
for input_spec in data_object["inputs"]:
try:
input_op_digest = graph.get_op_creation_digest(
input_spec["op_name"]
)
except KeyError:
input_op_digest = None
if input_op_digest:
input_spec["data"] = self._opCreationDigestToDataObject(
input_op_digest, graph
)
# Populate data about immediate consuming ops.
for slot_consumer_specs in data_object["consumers"]:
for consumer_spec in slot_consumer_specs:
try:
digest = graph.get_op_creation_digest(
consumer_spec["op_name"]
)
except KeyError:
digest = None
if digest:
consumer_spec["data"] = self._opCreationDigestToDataObject(
digest, graph
)
return data_object
def _opCreationDigestToDataObject(self, op_creation_digest, graph):
if op_creation_digest is None:
return None
json_object = op_creation_digest.to_json()
del json_object["graph_id"]
json_object["graph_ids"] = self._getGraphStackIds(
op_creation_digest.graph_id
)
# TODO(cais): "num_outputs" should be populated in to_json() instead.
json_object["num_outputs"] = op_creation_digest.num_outputs
del json_object["input_names"]
json_object["inputs"] = []
for input_tensor_name in op_creation_digest.input_names or []:
input_op_name, output_slot = parse_tensor_name(input_tensor_name)
json_object["inputs"].append(
{"op_name": input_op_name, "output_slot": output_slot}
)
json_object["consumers"] = []
for _ in range(json_object["num_outputs"]):
json_object["consumers"].append([])
for src_slot, consumer_op_name, dst_slot in graph.get_op_consumers(
json_object["op_name"]
):
json_object["consumers"][src_slot].append(
{"op_name": consumer_op_name, "input_slot": dst_slot}
)
return json_object
def _getGraphStackIds(self, graph_id):
"""Retrieve the IDs of all outer graphs of a graph.
Args:
graph_id: ID of the graph being queried, with respect to its outer
graph context.
Returns:
A list of graph_ids, ordered from outermost to innermost, including
the input `graph_id` argument as the last item.
"""
graph_ids = [graph_id]
graph = self._reader.graph_by_id(graph_id)
while graph.outer_graph_id:
graph_ids.insert(0, graph.outer_graph_id)
graph = self._reader.graph_by_id(graph.outer_graph_id)
return graph_ids
def SourceFileList(self, run):
runs = self.Runs()
if run not in runs:
return None
return self._reader.source_file_list()
def SourceLines(self, run, index):
runs = self.Runs()
if run not in runs:
return None
try:
host_name, file_path = self._reader.source_file_list()[index]
except IndexError:
raise errors.NotFoundError(
"There is no source-code file at index %d" % index
)
return {
"host_name": host_name,
"file_path": file_path,
"lines": self._reader.source_lines(host_name, file_path),
}
def StackFrames(self, run, stack_frame_ids):
runs = self.Runs()
if run not in runs:
return None
stack_frames = []
for stack_frame_id in stack_frame_ids:
if stack_frame_id not in self._reader._stack_frame_by_id:
raise errors.NotFoundError(
"Cannot find stack frame with ID %s" % stack_frame_id
)
# TODO(cais): Use public method (`stack_frame_by_id()`) when
# available.
# pylint: disable=protected-access
stack_frames.append(self._reader._stack_frame_by_id[stack_frame_id])
# pylint: enable=protected-access
return {"stack_frames": stack_frames}

View File

@ -0,0 +1,635 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An implementation of DataProvider that serves tfdbg v2 data.
This implementation is:
1. Based on reading data from a DebugEvent file set on the local filesystem.
2. Implements only the relevant methods for the debugger v2 plugin, including
- list_runs()
- read_blob_sequences()
- read_blob()
This class is a short-term hack. Before it can be used in production, it needs
to be integrated with a more complete implementation of DataProvider, such as
MultiplexerDataProvider.
"""
import json
from tensorboard.data import provider
from tensorboard.plugins.debugger_v2 import debug_data_multiplexer
PLUGIN_NAME = "debugger-v2"
ALERTS_BLOB_TAG_PREFIX = "alerts"
EXECUTION_DIGESTS_BLOB_TAG_PREFIX = "execution_digests"
EXECUTION_DATA_BLOB_TAG_PREFIX = "execution_data"
GRAPH_EXECUTION_DIGESTS_BLOB_TAG_PREFIX = "graphexec_digests"
GRAPH_EXECUTION_DATA_BLOB_TAG_PREFIX = "graphexec_data"
GRAPH_INFO_BLOB_TAG_PREFIX = "graph_info"
GRAPH_OP_INFO_BLOB_TAG_PREFIX = "graph_op_info"
SOURCE_FILE_LIST_BLOB_TAG = "source_file_list"
SOURCE_FILE_BLOB_TAG_PREFIX = "source_file"
STACK_FRAMES_BLOB_TAG_PREFIX = "stack_frames"
def alerts_run_tag_filter(run, begin, end, alert_type=None):
"""Create a RunTagFilter for Alerts.
Args:
run: tfdbg2 run name.
begin: Beginning index of alerts.
end: Ending index of alerts.
alert_type: Optional alert type, used to restrict retrieval of alerts
data to a single type of alerts.
Returns:
`RunTagFilter` for the run and range of Alerts.
"""
tag = "%s_%d_%d" % (ALERTS_BLOB_TAG_PREFIX, begin, end)
if alert_type is not None:
tag += "_%s" % alert_type
return provider.RunTagFilter(runs=[run], tags=[tag])
def _parse_alerts_blob_key(blob_key):
"""Parse the BLOB key for Alerts.
Args:
blob_key: The BLOB key to parse. By contract, it should have the format:
- `${ALERTS_BLOB_TAG_PREFIX}_${begin}_${end}.${run_id}` when there is no
alert type filter.
- `${ALERTS_BLOB_TAG_PREFIX}_${begin}_${end}_${alert_filter}.${run_id}`
when there is an alert type filter.
Returns:
- run ID
- begin index
- end index
- alert_type: alert type string used to filter retrieved alert data.
`None` if no filtering is used.
"""
key_body, run = blob_key.split(".", 1)
key_body = key_body[len(ALERTS_BLOB_TAG_PREFIX) :]
key_items = key_body.split("_", 3)
begin = int(key_items[1])
end = int(key_items[2])
alert_type = None
if len(key_items) > 3:
alert_type = key_items[3]
return run, begin, end, alert_type
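# Illustrative round trip, using the "<tag>.<run>" blob-key convention of
# `read_blob_sequences()` below; the run name "my_run" is hypothetical.
def _example_alerts_blob_key_round_trip():
    run_tag_filter = alerts_run_tag_filter(
        "my_run", begin=0, end=10, alert_type="InfNanAlert"
    )
    tag = next(iter(run_tag_filter.tags))  # "alerts_0_10_InfNanAlert"
    blob_key = "%s.%s" % (tag, "my_run")
    assert _parse_alerts_blob_key(blob_key) == ("my_run", 0, 10, "InfNanAlert")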
def execution_digest_run_tag_filter(run, begin, end):
"""Create a RunTagFilter for ExecutionDigests.
This differs from `execution_data_run_tag_filter()` in that it is for
the small-size digest objects for execution debug events, instead of the
full-size data objects.
Args:
run: tfdbg2 run name.
begin: Beginning index of ExecutionDigests.
end: Ending index of ExecutionDigests.
Returns:
`RunTagFilter` for the run and range of ExecutionDigests.
"""
return provider.RunTagFilter(
runs=[run],
tags=["%s_%d_%d" % (EXECUTION_DIGESTS_BLOB_TAG_PREFIX, begin, end)],
)
def _parse_execution_digest_blob_key(blob_key):
"""Parse the BLOB key for ExecutionDigests.
This differs from `_parse_execution_data_blob_key()` in that it is for
the small-size digest objects for execution debug events, instead of the
full-size data objects.
Args:
blob_key: The BLOB key to parse. By contract, it should have the format:
`${EXECUTION_DIGESTS_BLOB_TAG_PREFIX}_${begin}_${end}.${run_id}`
Returns:
- run ID
- begin index
- end index
"""
key_body, run = blob_key.split(".", 1)
key_body = key_body[len(EXECUTION_DIGESTS_BLOB_TAG_PREFIX) :]
begin = int(key_body.split("_")[1])
end = int(key_body.split("_")[2])
return run, begin, end
def execution_data_run_tag_filter(run, begin, end):
"""Create a RunTagFilter for Execution data objects.
This differs from `execution_digest_run_tag_filter()` in that it is
for the detailed data objects for execution, instead of the digests.
Args:
run: tfdbg2 run name.
begin: Beginning index of Execution.
end: Ending index of Execution.
Returns:
`RunTagFilter` for the run and range of Execution data objects.
"""
return provider.RunTagFilter(
runs=[run],
tags=["%s_%d_%d" % (EXECUTION_DATA_BLOB_TAG_PREFIX, begin, end)],
)
def _parse_execution_data_blob_key(blob_key):
"""Parse the BLOB key for Execution data objects.
This differs from `_parse_execution_digest_blob_key()` in that it is
for the detailed data objects for execution, instead of the digests.
Args:
blob_key: The BLOB key to parse. By contract, it should have the format:
`${EXECUTION_DATA_BLOB_TAG_PREFIX}_${begin}_${end}.${run_id}`
Returns:
- run ID
- begin index
- end index
"""
key_body, run = blob_key.split(".", 1)
key_body = key_body[len(EXECUTION_DATA_BLOB_TAG_PREFIX) :]
begin = int(key_body.split("_")[1])
end = int(key_body.split("_")[2])
return run, begin, end
def graph_execution_digest_run_tag_filter(run, begin, end, trace_id=None):
"""Create a RunTagFilter for GraphExecutionTraceDigests.
This differs from `graph_execution_data_run_tag_filter()` in that it is for
the small-size digest objects for intra-graph execution debug events, instead
of the full-size data objects.
Args:
run: tfdbg2 run name.
begin: Beginning index of GraphExecutionTraceDigests.
end: Ending index of GraphExecutionTraceDigests.
Returns:
`RunTagFilter` for the run and range of GraphExecutionTraceDigests.
"""
# TODO(cais): Implement support for trace_id once joining of eager
# execution and intra-graph execution is supported by DebugDataReader.
if trace_id is not None:
raise NotImplementedError(
"trace_id support for graph_execution_digest_run_tag_filter() is "
"not implemented yet."
)
return provider.RunTagFilter(
runs=[run],
tags=[
"%s_%d_%d" % (GRAPH_EXECUTION_DIGESTS_BLOB_TAG_PREFIX, begin, end)
],
)
def _parse_graph_execution_digest_blob_key(blob_key):
"""Parse the BLOB key for GraphExecutionTraceDigests.
This differs from `_parse_graph_execution_data_blob_key()` in that it is for
the small-size digest objects for intra-graph execution debug events,
instead of the full-size data objects.
Args:
blob_key: The BLOB key to parse. By contract, it should have the format:
`${GRAPH_EXECUTION_DIGESTS_BLOB_TAG_PREFIX}_${begin}_${end}.${run_id}`
Returns:
- run ID
- begin index
- end index
"""
# TODO(cais): Support parsing trace_id when it is supported.
key_body, run = blob_key.split(".", 1)
key_body = key_body[len(GRAPH_EXECUTION_DIGESTS_BLOB_TAG_PREFIX) :]
begin = int(key_body.split("_")[1])
end = int(key_body.split("_")[2])
return run, begin, end
def graph_execution_data_run_tag_filter(run, begin, end, trace_id=None):
"""Create a RunTagFilter for GraphExecutionTrace.
This method differs from `graph_execution_digest_run_tag_filter()` in that
it is for full-sized data objects for intra-graph execution events.
Args:
run: tfdbg2 run name.
begin: Beginning index of GraphExecutionTrace.
end: Ending index of GraphExecutionTrace.
Returns:
`RunTagFilter` for the run and range of GraphExecutionTrace.
"""
# TODO(cais): Implement support for trace_id once joining of eager
# execution and intra-graph execution is supported by DebugDataReader.
if trace_id is not None:
raise NotImplementedError(
"trace_id support for graph_execution_data_run_tag_filter() is "
"not implemented yet."
)
return provider.RunTagFilter(
runs=[run],
tags=["%s_%d_%d" % (GRAPH_EXECUTION_DATA_BLOB_TAG_PREFIX, begin, end)],
)
def _parse_graph_execution_data_blob_key(blob_key):
"""Parse the BLOB key for GraphExecutionTrace.
This method differs from `_parse_graph_execution_digest_blob_key()` in that
it is for full-sized data objects for intra-graph execution events.
Args:
blob_key: The BLOB key to parse. By contract, it should have the format:
`${GRAPH_EXECUTION_DATA_BLOB_TAG_PREFIX}_${begin}_${end}.${run_id}`
Returns:
- run ID
- begin index
- end index
"""
# TODO(cais): Support parsing trace_id when it is supported.
key_body, run = blob_key.split(".", 1)
key_body = key_body[len(GRAPH_EXECUTION_DATA_BLOB_TAG_PREFIX) :]
begin = int(key_body.split("_")[1])
end = int(key_body.split("_")[2])
return run, begin, end
def graph_op_info_run_tag_filter(run, graph_id, op_name):
"""Create a RunTagFilter for graph op info.
Args:
run: tfdbg2 run name.
graph_id: Debugger-generated ID of the graph. This is assumed to
be the ID of the graph that immediately encloses the op in question.
op_name: Name of the op in question. (e.g., "Dense_1/MatMul")
Returns:
`RunTagFilter` for the run and range of graph op info.
"""
if not graph_id:
raise ValueError("graph_id must not be None or empty.")
return provider.RunTagFilter(
runs=[run],
tags=["%s_%s_%s" % (GRAPH_OP_INFO_BLOB_TAG_PREFIX, graph_id, op_name)],
)
def _parse_graph_op_info_blob_key(blob_key):
"""Parse the BLOB key for graph op info.
Args:
blob_key: The BLOB key to parse. By contract, it should have the format:
`${GRAPH_OP_INFO_BLOB_TAG_PREFIX}_${graph_id}_${op_name}.${run_name}`,
wherein
- `graph_id` is a UUID
- op_name conforms to the TensorFlow spec:
`^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$`
- `run_name` is assumed to contain no dots (`'.'`s).
Returns:
- run name
- graph_id
- op name
"""
# NOTE: the op_name itself may include dots, which is why we use `rindex()`
# instead of `split()`.
last_dot_index = blob_key.rindex(".")
run = blob_key[last_dot_index + 1 :]
key_body = blob_key[:last_dot_index]
key_body = key_body[len(GRAPH_OP_INFO_BLOB_TAG_PREFIX) :]
_, graph_id, op_name = key_body.split("_", 2)
return run, graph_id, op_name
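# Illustrative round trip (hypothetical graph ID and op name), showing why the
# parser splits on the last dot: the op name itself may contain dots, while
# the run name is assumed not to. The "<tag>.<run>" key format follows
# `read_blob_sequences()` below.
def _example_graph_op_info_blob_key_round_trip():
    run_tag_filter = graph_op_info_run_tag_filter(
        "my_run", "g1", "model.dense_1/MatMul"
    )
    tag = next(iter(run_tag_filter.tags))
    blob_key = "%s.%s" % (tag, "my_run")
    assert _parse_graph_op_info_blob_key(blob_key) == (
        "my_run",
        "g1",
        "model.dense_1/MatMul",
    )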
def graph_info_run_tag_filter(run, graph_id):
"""Create a RunTagFilter for graph info.
Args:
run: tfdbg2 run name.
graph_id: Debugger-generated ID of the graph in question.
Returns:
`RunTagFilter` for the run and range of graph info.
"""
if not graph_id:
raise ValueError("graph_id must not be None or empty.")
return provider.RunTagFilter(
runs=[run],
tags=["%s_%s" % (GRAPH_INFO_BLOB_TAG_PREFIX, graph_id)],
)
def _parse_graph_info_blob_key(blob_key):
"""Parse the BLOB key for graph info.
Args:
blob_key: The BLOB key to parse. By contract, it should have the format:
`${GRAPH_INFO_BLOB_TAG_PREFIX}_${graph_id}.${run_name}`,
Returns:
- run name
- graph_id
"""
key_body, run = blob_key.split(".")
graph_id = key_body[len(GRAPH_INFO_BLOB_TAG_PREFIX) + 1 :]
return run, graph_id
def source_file_list_run_tag_filter(run):
"""Create a RunTagFilter for listing source files.
Args:
run: tfdbg2 run name.
Returns:
`RunTagFilter` for listing the source files in the tfdbg2 run.
"""
return provider.RunTagFilter(runs=[run], tags=[SOURCE_FILE_LIST_BLOB_TAG])
def _parse_source_file_list_blob_key(blob_key):
"""Parse the BLOB key for source file list.
Args:
blob_key: The BLOB key to parse. By contract, it should have the format:
`${SOURCE_FILE_LIST_BLOB_TAG}.${run_id}`
Returns:
- run ID
"""
return blob_key[blob_key.index(".") + 1 :]
def source_file_run_tag_filter(run, index):
"""Create a RunTagFilter for listing source files.
Args:
run: tfdbg2 run name.
index: The index for the source file of which the content is to be
accessed.
Returns:
`RunTagFilter` for accessing the content of the source file.
"""
return provider.RunTagFilter(
runs=[run],
tags=["%s_%d" % (SOURCE_FILE_BLOB_TAG_PREFIX, index)],
)
def _parse_source_file_blob_key(blob_key):
"""Parse the BLOB key for accessing the content of a source file.
Args:
blob_key: The BLOB key to parse. By contract, it should have the format:
`${SOURCE_FILE_BLOB_TAG_PREFIX}_${index}.${run_id}`
Returns:
- run ID, as a str.
- File index, as an int.
"""
key_body, run = blob_key.split(".", 1)
index = int(key_body[len(SOURCE_FILE_BLOB_TAG_PREFIX) + 1 :])
return run, index
def stack_frames_run_tag_filter(run, stack_frame_ids):
"""Create a RunTagFilter for querying stack frames.
Args:
run: tfdbg2 run name.
stack_frame_ids: The stack_frame_ids being requested.
Returns:
`RunTagFilter` for querying the specified stack frames.
"""
return provider.RunTagFilter(
runs=[run],
# The stack-frame IDs are UUIDs, which do not contain underscores.
# Hence it's safe to concatenate them with underscores.
tags=[STACK_FRAMES_BLOB_TAG_PREFIX + "_" + "_".join(stack_frame_ids)],
)
def _parse_stack_frames_blob_key(blob_key):
"""Parse the BLOB key for source file list.
Args:
blob_key: The BLOB key to parse. By contract, it should have the format:
`${STACK_FRAMES_BLOB_TAG_PREFIX}_` +
`${stack_frame_id_0}_..._${stack_frame_id_N}.${run_id}`
Returns:
- run ID
- The stack frame IDs as a tuple of strings.
"""
key_body, run = blob_key.split(".", 1)
key_body = key_body[len(STACK_FRAMES_BLOB_TAG_PREFIX) + 1 :]
stack_frame_ids = key_body.split("_")
return run, stack_frame_ids
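# Illustrative round trip (hypothetical frame IDs and run name): stack-frame
# IDs are joined with underscores in the tag and recovered by the parser.
def _example_stack_frames_blob_key_round_trip():
    run_tag_filter = stack_frames_run_tag_filter("my_run", ["id0", "id1"])
    tag = next(iter(run_tag_filter.tags))  # "stack_frames_id0_id1"
    blob_key = "%s.%s" % (tag, "my_run")
    assert _parse_stack_frames_blob_key(blob_key) == ("my_run", ["id0", "id1"])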
class LocalDebuggerV2DataProvider(provider.DataProvider):
"""A DataProvider implementation for tfdbg v2 data on local filesystem.
In this implementation, `experiment_id` is assumed to be the path to the
logdir that contains the DebugEvent file set.
"""
def __init__(self, logdir):
"""Constructor of LocalDebuggerV2DataProvider.
Args:
logdir: Path to the directory from which the tfdbg v2 data will be
loaded.
"""
super().__init__()
self._multiplexer = debug_data_multiplexer.DebuggerV2EventMultiplexer(
logdir
)
def list_runs(self, ctx=None, *, experiment_id):
"""List runs available.
Args:
experiment_id: currently unused, because the backing
DebuggerV2EventMultiplexer does not accommodate multiple experiments.
Returns:
Run names as a list of str.
"""
return [
provider.Run(
run_id=run, # use names as IDs
run_name=run,
start_time=self._get_first_event_timestamp(run),
)
for run in self._multiplexer.Runs()
]
def _get_first_event_timestamp(self, run_name):
try:
return self._multiplexer.FirstEventTimestamp(run_name)
except ValueError:
return None
def list_scalars(
self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
):
del experiment_id, plugin_name, run_tag_filter # Unused.
raise TypeError("Debugger V2 DataProvider doesn't support scalars.")
def read_scalars(
self,
ctx=None,
*,
experiment_id,
plugin_name,
downsample=None,
run_tag_filter=None,
):
del experiment_id, plugin_name, downsample, run_tag_filter
raise TypeError("Debugger V2 DataProvider doesn't support scalars.")
def read_last_scalars(
self,
ctx=None,
*,
experiment_id,
plugin_name,
run_tag_filter=None,
):
del experiment_id, plugin_name, run_tag_filter
raise TypeError("Debugger V2 DataProvider doesn't support scalars.")
def list_blob_sequences(
self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
):
del experiment_id, plugin_name, run_tag_filter # Unused currently.
# TODO(cais): Implement this.
raise NotImplementedError()
def read_blob_sequences(
self,
ctx=None,
*,
experiment_id,
plugin_name,
downsample=None,
run_tag_filter=None,
):
del experiment_id, downsample # Unused.
if plugin_name != PLUGIN_NAME:
raise ValueError("Unsupported plugin_name: %s" % plugin_name)
if run_tag_filter.runs is None:
raise ValueError(
"run_tag_filter.runs is expected to be specified, but is not."
)
if run_tag_filter.tags is None:
raise ValueError(
"run_tag_filter.tags is expected to be specified, but is not."
)
output = dict()
existing_runs = self._multiplexer.Runs()
for run in run_tag_filter.runs:
if run not in existing_runs:
continue
output[run] = dict()
for tag in run_tag_filter.tags:
if tag.startswith(
(
ALERTS_BLOB_TAG_PREFIX,
EXECUTION_DIGESTS_BLOB_TAG_PREFIX,
EXECUTION_DATA_BLOB_TAG_PREFIX,
GRAPH_EXECUTION_DIGESTS_BLOB_TAG_PREFIX,
GRAPH_EXECUTION_DATA_BLOB_TAG_PREFIX,
GRAPH_INFO_BLOB_TAG_PREFIX,
GRAPH_OP_INFO_BLOB_TAG_PREFIX,
SOURCE_FILE_BLOB_TAG_PREFIX,
STACK_FRAMES_BLOB_TAG_PREFIX,
)
) or tag in (SOURCE_FILE_LIST_BLOB_TAG,):
output[run][tag] = [
provider.BlobReference(blob_key="%s.%s" % (tag, run))
]
return output
def read_blob(self, ctx=None, *, blob_key):
if blob_key.startswith(ALERTS_BLOB_TAG_PREFIX):
run, begin, end, alert_type = _parse_alerts_blob_key(blob_key)
return json.dumps(
self._multiplexer.Alerts(
run, begin, end, alert_type_filter=alert_type
)
)
elif blob_key.startswith(EXECUTION_DIGESTS_BLOB_TAG_PREFIX):
run, begin, end = _parse_execution_digest_blob_key(blob_key)
return json.dumps(
self._multiplexer.ExecutionDigests(run, begin, end)
)
elif blob_key.startswith(EXECUTION_DATA_BLOB_TAG_PREFIX):
run, begin, end = _parse_execution_data_blob_key(blob_key)
return json.dumps(self._multiplexer.ExecutionData(run, begin, end))
elif blob_key.startswith(GRAPH_EXECUTION_DIGESTS_BLOB_TAG_PREFIX):
run, begin, end = _parse_graph_execution_digest_blob_key(blob_key)
return json.dumps(
self._multiplexer.GraphExecutionDigests(run, begin, end)
)
elif blob_key.startswith(GRAPH_EXECUTION_DATA_BLOB_TAG_PREFIX):
run, begin, end = _parse_graph_execution_data_blob_key(blob_key)
return json.dumps(
self._multiplexer.GraphExecutionData(run, begin, end)
)
elif blob_key.startswith(GRAPH_INFO_BLOB_TAG_PREFIX):
run, graph_id = _parse_graph_info_blob_key(blob_key)
return json.dumps(self._multiplexer.GraphInfo(run, graph_id))
elif blob_key.startswith(GRAPH_OP_INFO_BLOB_TAG_PREFIX):
run, graph_id, op_name = _parse_graph_op_info_blob_key(blob_key)
return json.dumps(
self._multiplexer.GraphOpInfo(run, graph_id, op_name)
)
elif blob_key.startswith(SOURCE_FILE_LIST_BLOB_TAG):
run = _parse_source_file_list_blob_key(blob_key)
return json.dumps(self._multiplexer.SourceFileList(run))
elif blob_key.startswith(SOURCE_FILE_BLOB_TAG_PREFIX):
run, index = _parse_source_file_blob_key(blob_key)
return json.dumps(self._multiplexer.SourceLines(run, index))
elif blob_key.startswith(STACK_FRAMES_BLOB_TAG_PREFIX):
run, stack_frame_ids = _parse_stack_frames_blob_key(blob_key)
return json.dumps(
self._multiplexer.StackFrames(run, stack_frame_ids)
)
else:
raise ValueError("Unrecognized blob_key: %s" % blob_key)

View File

@ -0,0 +1,505 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard Debugger V2 plugin."""
import threading
from werkzeug import wrappers
from tensorboard import errors
from tensorboard import plugin_util
from tensorboard.plugins import base_plugin
from tensorboard.plugins.debugger_v2 import debug_data_provider
from tensorboard.backend import http_util
def _error_response(request, error_message):
return http_util.Respond(
request,
{"error": error_message},
"application/json",
code=400,
)
def _missing_run_error_response(request):
return _error_response(request, "run parameter is not provided")
class DebuggerV2Plugin(base_plugin.TBPlugin):
"""Debugger V2 Plugin for TensorBoard."""
plugin_name = debug_data_provider.PLUGIN_NAME
def __init__(self, context):
"""Instantiates Debugger V2 Plugin via TensorBoard core.
Args:
context: A base_plugin.TBContext instance.
"""
super().__init__(context)
self._logdir = context.logdir
self._underlying_data_provider = None
# Held while initializing `_underlying_data_provider` for the first
# time, to make sure that we only construct one.
self._data_provider_init_lock = threading.Lock()
@property
def _data_provider(self):
if self._underlying_data_provider is not None:
return self._underlying_data_provider
with self._data_provider_init_lock:
if self._underlying_data_provider is not None:
return self._underlying_data_provider
# TODO(cais): Implement factory for DataProvider that takes into account
# the settings.
dp = debug_data_provider.LocalDebuggerV2DataProvider(self._logdir)
self._underlying_data_provider = dp
return dp
def get_plugin_apps(self):
# TODO(cais): Add routes as they are implemented.
return {
"/runs": self.serve_runs,
"/alerts": self.serve_alerts,
"/execution/digests": self.serve_execution_digests,
"/execution/data": self.serve_execution_data,
"/graph_execution/digests": self.serve_graph_execution_digests,
"/graph_execution/data": self.serve_graph_execution_data,
"/graphs/graph_info": self.serve_graph_info,
"/graphs/op_info": self.serve_graph_op_info,
"/source_files/list": self.serve_source_files_list,
"/source_files/file": self.serve_source_file,
"/stack_frames/stack_frames": self.serve_stack_frames,
}
def is_active(self):
"""The Debugger V2 plugin must be manually selected."""
return False
def frontend_metadata(self):
return base_plugin.FrontendMetadata(
is_ng_component=True, tab_name="Debugger V2", disable_reload=False
)
@wrappers.Request.application
def serve_runs(self, request):
experiment = plugin_util.experiment_id(request.environ)
runs = self._data_provider.list_runs(experiment_id=experiment)
run_listing = dict()
for run in runs:
run_listing[run.run_id] = {"start_time": run.start_time}
return http_util.Respond(request, run_listing, "application/json")
@wrappers.Request.application
def serve_alerts(self, request):
experiment = plugin_util.experiment_id(request.environ)
run = request.args.get("run")
if run is None:
return _missing_run_error_response(request)
begin = int(request.args.get("begin", "0"))
end = int(request.args.get("end", "-1"))
alert_type = request.args.get("alert_type", None)
run_tag_filter = debug_data_provider.alerts_run_tag_filter(
run, begin, end, alert_type=alert_type
)
blob_sequences = self._data_provider.read_blob_sequences(
experiment_id=experiment,
plugin_name=self.plugin_name,
run_tag_filter=run_tag_filter,
)
tag = next(iter(run_tag_filter.tags))
try:
return http_util.Respond(
request,
self._data_provider.read_blob(
blob_key=blob_sequences[run][tag][0].blob_key
),
"application/json",
)
except errors.InvalidArgumentError as e:
return _error_response(request, str(e))
@wrappers.Request.application
def serve_execution_digests(self, request):
experiment = plugin_util.experiment_id(request.environ)
run = request.args.get("run")
if run is None:
return _missing_run_error_response(request)
begin = int(request.args.get("begin", "0"))
end = int(request.args.get("end", "-1"))
run_tag_filter = debug_data_provider.execution_digest_run_tag_filter(
run, begin, end
)
blob_sequences = self._data_provider.read_blob_sequences(
experiment_id=experiment,
plugin_name=self.plugin_name,
run_tag_filter=run_tag_filter,
)
tag = next(iter(run_tag_filter.tags))
try:
return http_util.Respond(
request,
self._data_provider.read_blob(
blob_key=blob_sequences[run][tag][0].blob_key
),
"application/json",
)
except errors.InvalidArgumentError as e:
return _error_response(request, str(e))
@wrappers.Request.application
def serve_execution_data(self, request):
experiment = plugin_util.experiment_id(request.environ)
run = request.args.get("run")
if run is None:
return _missing_run_error_response(request)
begin = int(request.args.get("begin", "0"))
end = int(request.args.get("end", "-1"))
run_tag_filter = debug_data_provider.execution_data_run_tag_filter(
run, begin, end
)
blob_sequences = self._data_provider.read_blob_sequences(
experiment_id=experiment,
plugin_name=self.plugin_name,
run_tag_filter=run_tag_filter,
)
tag = next(iter(run_tag_filter.tags))
try:
return http_util.Respond(
request,
self._data_provider.read_blob(
blob_key=blob_sequences[run][tag][0].blob_key
),
"application/json",
)
except errors.InvalidArgumentError as e:
return _error_response(request, str(e))
@wrappers.Request.application
def serve_graph_execution_digests(self, request):
"""Serve digests of intra-graph execution events.
As the names imply, this route differs from `serve_execution_digests()`
in that it is for intra-graph execution, while `serve_execution_digests()`
is for top-level (eager) execution.
"""
experiment = plugin_util.experiment_id(request.environ)
run = request.args.get("run")
if run is None:
return _missing_run_error_response(request)
begin = int(request.args.get("begin", "0"))
end = int(request.args.get("end", "-1"))
run_tag_filter = (
debug_data_provider.graph_execution_digest_run_tag_filter(
run, begin, end
)
)
blob_sequences = self._data_provider.read_blob_sequences(
experiment_id=experiment,
plugin_name=self.plugin_name,
run_tag_filter=run_tag_filter,
)
tag = next(iter(run_tag_filter.tags))
try:
return http_util.Respond(
request,
self._data_provider.read_blob(
blob_key=blob_sequences[run][tag][0].blob_key
),
"application/json",
)
except errors.InvalidArgumentError as e:
return _error_response(request, str(e))
@wrappers.Request.application
def serve_graph_execution_data(self, request):
"""Serve detailed data objects of intra-graph execution events.
As the names imply, this route differs from `serve_execution_data()`
in that it is for intra-graph execution, while `serve_execution_data()`
is for top-level (eager) execution.
Unlike `serve_graph_execution_digests()`, this method serves the
full-sized data objects for intra-graph execution events.
"""
experiment = plugin_util.experiment_id(request.environ)
run = request.args.get("run")
if run is None:
return _missing_run_error_response(request)
begin = int(request.args.get("begin", "0"))
end = int(request.args.get("end", "-1"))
run_tag_filter = (
debug_data_provider.graph_execution_data_run_tag_filter(
run, begin, end
)
)
blob_sequences = self._data_provider.read_blob_sequences(
experiment_id=experiment,
plugin_name=self.plugin_name,
run_tag_filter=run_tag_filter,
)
tag = next(iter(run_tag_filter.tags))
try:
return http_util.Respond(
request,
self._data_provider.read_blob(
blob_key=blob_sequences[run][tag][0].blob_key
),
"application/json",
)
except errors.InvalidArgumentError as e:
return _error_response(request, str(e))
@wrappers.Request.application
def serve_graph_info(self, request):
"""Serve basic information about a TensorFlow graph.
The request specifies the debugger-generated ID of the graph being
queried.
The response contains a JSON object with the following fields:
- graph_id: The debugger-generated ID (echoing the request).
- name: The name of the graph (if any). For TensorFlow 2.x
Function Graphs (FuncGraphs), this is typically the name of
the underlying Python function, optionally prefixed with
TensorFlow-generated prefixes such as "__inference_".
Some graphs (e.g., certain outermost graphs) may have no names,
in which case this field is `null`.
- outer_graph_id: Outer graph ID (if any). For an outermost graph
without an outer graph context, this field is `null`.
- inner_graph_ids: Debugger-generated IDs of all the graphs
nested inside this graph. For a graph without any graphs nested
inside, this field is an empty array.
"""
experiment = plugin_util.experiment_id(request.environ)
run = request.args.get("run")
if run is None:
return _missing_run_error_response(request)
graph_id = request.args.get("graph_id")
run_tag_filter = debug_data_provider.graph_info_run_tag_filter(
run, graph_id
)
blob_sequences = self._data_provider.read_blob_sequences(
experiment_id=experiment,
plugin_name=self.plugin_name,
run_tag_filter=run_tag_filter,
)
tag = next(iter(run_tag_filter.tags))
try:
return http_util.Respond(
request,
self._data_provider.read_blob(
blob_key=blob_sequences[run][tag][0].blob_key
),
"application/json",
)
except errors.NotFoundError as e:
return _error_response(request, str(e))
@wrappers.Request.application
def serve_graph_op_info(self, request):
"""Serve information for ops in graphs.
The request specifies the op name and the ID of the graph that
contains the op.
The response contains a JSON object with the following fields:
- op_type
- op_name
- graph_ids: Stack of graph IDs that the op is located in, from
outermost to innermost. The length of this array is always >= 1.
The length is 1 if and only if the graph is an outermost graph.
- num_outputs: Number of output tensors.
- output_tensor_ids: The debugger-generated number IDs for the
symbolic output tensors of the op (an array of numbers).
- host_name: Name of the host on which the op is created.
- stack_trace: Stack frames of the op's creation.
- inputs: Specifications of all inputs to this op.
Currently only immediate (one level of) inputs are provided.
This is an array of length N_in, where N_in is the number of
data inputs received by the op. Each element of the array is an
object with the following fields:
- op_name: Name of the op that provides the input tensor.
- output_slot: 0-based output slot index from which the input
tensor emits.
- data: A recursive data structure of this same schema.
This field is not populated (undefined) at the leaf nodes
of this recursive data structure.
In the rare case wherein the data for an input cannot be
retrieved properly (e.g., special internal op types), this
field will be unpopulated.
This is an empty list for an op with no inputs.
- consumers: Specifications for all the downstream consuming ops of
this. Currently only immediate (one level of) consumers are provided.
This is an array of length N_out, where N_out is the number of
symbolic tensors output by this op.
Each element of the array is an array of which the length equals
the number of downstream ops that consume the corresponding symbolic
tensor (only data edges are tracked).
Each element of the array is an object with the following fields:
- op_name: Name of the op that receives the output tensor as an
input.
- input_slot: 0-based input slot index at which the downstream
op receives this output tensor.
- data: A recursive data structure of this very schema.
This field is not populated (undefined) at the leaf nodes
of this recursive data structure.
In the rare case wherein the data for a consumer op cannot be
retrieved properly (e.g., special internal op types), this
field will be unpopulated.
If this op has no output tensors, this is an empty array.
If one of the output tensors of this op has no consumers, the
corresponding element is an empty array.
"""
experiment = plugin_util.experiment_id(request.environ)
run = request.args.get("run")
if run is None:
return _missing_run_error_response(request)
graph_id = request.args.get("graph_id")
op_name = request.args.get("op_name")
run_tag_filter = debug_data_provider.graph_op_info_run_tag_filter(
run, graph_id, op_name
)
blob_sequences = self._data_provider.read_blob_sequences(
experiment_id=experiment,
plugin_name=self.plugin_name,
run_tag_filter=run_tag_filter,
)
tag = next(iter(run_tag_filter.tags))
try:
return http_util.Respond(
request,
self._data_provider.read_blob(
blob_key=blob_sequences[run][tag][0].blob_key
),
"application/json",
)
except errors.NotFoundError as e:
return _error_response(request, str(e))
@wrappers.Request.application
def serve_source_files_list(self, request):
"""Serves a list of all source files involved in the debugged program."""
experiment = plugin_util.experiment_id(request.environ)
run = request.args.get("run")
if run is None:
return _missing_run_error_response(request)
run_tag_filter = debug_data_provider.source_file_list_run_tag_filter(
run
)
blob_sequences = self._data_provider.read_blob_sequences(
experiment_id=experiment,
plugin_name=self.plugin_name,
run_tag_filter=run_tag_filter,
)
tag = next(iter(run_tag_filter.tags))
return http_util.Respond(
request,
self._data_provider.read_blob(
blob_key=blob_sequences[run][tag][0].blob_key
),
"application/json",
)
@wrappers.Request.application
def serve_source_file(self, request):
"""Serves the content of a given source file.
The source file is referred to by the index in the list of all source
files involved in the execution of the debugged program, which is
available via the `serve_source_files_list()` serving route.
Args:
request: HTTP request.
Returns:
Response to the request.
"""
experiment = plugin_util.experiment_id(request.environ)
run = request.args.get("run")
if run is None:
return _missing_run_error_response(request)
index = request.args.get("index")
# TODO(cais): When the need arises, support serving a subset of a
# source file's lines.
if index is None:
return _error_response(
request, "index is not provided for source file content"
)
index = int(index)
run_tag_filter = debug_data_provider.source_file_run_tag_filter(
run, index
)
blob_sequences = self._data_provider.read_blob_sequences(
experiment_id=experiment,
plugin_name=self.plugin_name,
run_tag_filter=run_tag_filter,
)
tag = next(iter(run_tag_filter.tags))
try:
return http_util.Respond(
request,
self._data_provider.read_blob(
blob_key=blob_sequences[run][tag][0].blob_key
),
"application/json",
)
except errors.NotFoundError as e:
return _error_response(request, str(e))
@wrappers.Request.application
def serve_stack_frames(self, request):
"""Serves the content of stack frames.
The stack frames being requested are referred to by their UUIDs,
separated by commas.
Args:
request: HTTP request.
Returns:
Response to the request.
"""
experiment = plugin_util.experiment_id(request.environ)
run = request.args.get("run")
if run is None:
return _missing_run_error_response(request)
stack_frame_ids = request.args.get("stack_frame_ids")
if stack_frame_ids is None:
return _error_response(request, "Missing stack_frame_ids parameter")
if not stack_frame_ids:
return _error_response(request, "Empty stack_frame_ids parameter")
stack_frame_ids = stack_frame_ids.split(",")
run_tag_filter = debug_data_provider.stack_frames_run_tag_filter(
run, stack_frame_ids
)
blob_sequences = self._data_provider.read_blob_sequences(
experiment_id=experiment,
plugin_name=self.plugin_name,
run_tag_filter=run_tag_filter,
)
tag = next(iter(run_tag_filter.tags))
try:
return http_util.Respond(
request,
self._data_provider.read_blob(
blob_key=blob_sequences[run][tag][0].blob_key
),
"application/json",
)
except errors.NotFoundError as e:
return _error_response(request, str(e))
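# The blob-backed routes above all follow the same read pattern. A minimal
# sketch of a shared helper they could delegate to (the helper name
# `_respond_with_first_blob` is hypothetical, not part of this plugin):
#
#     def _respond_with_first_blob(self, request, experiment, run, run_tag_filter):
#         blob_sequences = self._data_provider.read_blob_sequences(
#             experiment_id=experiment,
#             plugin_name=self.plugin_name,
#             run_tag_filter=run_tag_filter,
#         )
#         tag = next(iter(run_tag_filter.tags))
#         try:
#             return http_util.Respond(
#                 request,
#                 self._data_provider.read_blob(
#                     blob_key=blob_sequences[run][tag][0].blob_key
#                 ),
#                 "application/json",
#             )
#         except (errors.InvalidArgumentError, errors.NotFoundError) as e:
#             return _error_response(request, str(e))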

View File

@ -0,0 +1,159 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Package for histogram compression."""
import dataclasses
import numpy as np
from typing import Tuple
# Normal CDF for std_devs: (-Inf, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, Inf)
# naturally gives bands around median of width 1 std dev, 2 std dev, 3 std dev,
# and then the long tail.
NORMAL_HISTOGRAM_BPS = (0, 668, 1587, 3085, 5000, 6915, 8413, 9332, 10000)
@dataclasses.dataclass(frozen=True)
class CompressedHistogramValue:
"""Represents a value in a compressed histogram.
Attributes:
basis_point: Compression point represented in basis point, 1/100th of a
percent.
value: Cumulative weight at the basis point.
"""
basis_point: float
value: float
def as_tuple(self) -> Tuple[float, float]:
"""Returns the basis point and the value as a tuple."""
return (self.basis_point, self.value)
# TODO(@jart): Unfork these methods.
def compress_histogram_proto(histo, bps=NORMAL_HISTOGRAM_BPS):
"""Creates fixed size histogram by adding compression to accumulated state.
This routine transforms a histogram at a particular step by interpolating its
variable number of buckets to represent their cumulative weight at a constant
number of compression points. This significantly reduces the size of the
histogram and makes it suitable for a two-dimensional area plot where the
output of this routine constitutes the ranges for a single x coordinate.
Args:
histo: A HistogramProto object.
bps: Compression points represented in basis points, 1/100ths of a percent.
Defaults to normal distribution.
Returns:
List of values for each basis point.
"""
# See also: Histogram::Percentile() in core/lib/histogram/histogram.cc
if not histo.num:
return [CompressedHistogramValue(b, 0.0).as_tuple() for b in bps]
bucket = np.array(histo.bucket)
bucket_limit = list(histo.bucket_limit)
weights = (bucket * bps[-1] / (bucket.sum() or 1.0)).cumsum()
values = []
j = 0
while j < len(bps):
i = np.searchsorted(weights, bps[j], side="right")
while i < len(weights):
cumsum = weights[i]
cumsum_prev = weights[i - 1] if i > 0 else 0.0
if cumsum == cumsum_prev: # prevent lerp divide by zero
i += 1
continue
if not i or not cumsum_prev:
lhs = histo.min
else:
lhs = max(bucket_limit[i - 1], histo.min)
rhs = min(bucket_limit[i], histo.max)
weight = _lerp(bps[j], cumsum_prev, cumsum, lhs, rhs)
values.append(CompressedHistogramValue(bps[j], weight).as_tuple())
j += 1
break
else:
break
while j < len(bps):
values.append(CompressedHistogramValue(bps[j], histo.max).as_tuple())
j += 1
return values
def compress_histogram(buckets, bps=NORMAL_HISTOGRAM_BPS):
"""Creates fixed size histogram by adding compression to accumulated state.
This routine transforms a histogram at a particular step by linearly
interpolating its variable number of buckets to represent their cumulative
weight at a constant number of compression points. This significantly reduces
the size of the histogram and makes it suitable for a two-dimensional area
plot where the output of this routine constitutes the ranges for a single x
coordinate.
Args:
buckets: A list of buckets, each of which is a 3-tuple of the form
`(min, max, count)`.
bps: Compression points represented in basis points, 1/100ths of a percent.
Defaults to normal distribution.
Returns:
List of values for each basis point.
"""
# See also: Histogram::Percentile() in core/lib/histogram/histogram.cc
buckets = np.array(buckets)
if not buckets.size:
return [CompressedHistogramValue(b, 0.0).as_tuple() for b in bps]
(minmin, maxmax) = (buckets[0][0], buckets[-1][1])
counts = buckets[:, 2]
right_edges = list(buckets[:, 1])
weights = (counts * bps[-1] / (counts.sum() or 1.0)).cumsum()
result = []
bp_index = 0
while bp_index < len(bps):
i = np.searchsorted(weights, bps[bp_index], side="right")
while i < len(weights):
cumsum = weights[i]
cumsum_prev = weights[i - 1] if i > 0 else 0.0
if cumsum == cumsum_prev: # prevent division-by-zero in `_lerp`
i += 1
continue
if not i or not cumsum_prev:
lhs = minmin
else:
lhs = max(right_edges[i - 1], minmin)
rhs = min(right_edges[i], maxmax)
weight = _lerp(bps[bp_index], cumsum_prev, cumsum, lhs, rhs)
result.append(
CompressedHistogramValue(bps[bp_index], weight).as_tuple()
)
bp_index += 1
break
else:
break
while bp_index < len(bps):
result.append(
CompressedHistogramValue(bps[bp_index], maxmax).as_tuple()
)
bp_index += 1
return result
def _lerp(x, x0, x1, y0, y1):
"""Affinely map from [x0, x1] onto [y0, y1]."""
return y0 + (x - x0) * float(y1 - y0) / (x1 - x0)
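# A small usage sketch (illustrative, not part of the original module): the
# input to `compress_histogram` is a list of `(left_edge, right_edge, count)`
# triples, and the output has one `(basis_point, value)` pair per entry in
# `NORMAL_HISTOGRAM_BPS`.
#
#     buckets = [(0.0, 1.0, 3.0), (1.0, 2.0, 1.0)]
#     compressed = compress_histogram(buckets)
#     len(compressed)   # 9, one per basis point
#     compressed[0]     # (0, 0.0): the 0th basis point sits at the minimum
#     compressed[-1]    # (10000, 2.0): the last basis point sits at the maximum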

View File

@ -0,0 +1,117 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard Distributions (a.k.a. compressed histograms) plugin.
See `http_api.md` in this directory for specifications of the routes for
this plugin.
"""
from werkzeug import wrappers
from tensorboard import plugin_util
from tensorboard.backend import http_util
from tensorboard.plugins import base_plugin
from tensorboard.plugins.distribution import compressor
from tensorboard.plugins.distribution import metadata
from tensorboard.plugins.histogram import histograms_plugin
class DistributionsPlugin(base_plugin.TBPlugin):
"""Distributions Plugin for TensorBoard.
This supports both old-style summaries (created with TensorFlow ops
that output directly to the `histo` field of the proto) and new-
style summaries (as created by the
`tensorboard.plugins.histogram.summary` module).
"""
plugin_name = metadata.PLUGIN_NAME
# Use a round number + 1 since sampling includes both start and end steps,
# so N+1 samples correspond to dividing the step sequence into N intervals.
SAMPLE_SIZE = 501
def __init__(self, context):
"""Instantiates DistributionsPlugin via TensorBoard core.
Args:
context: A base_plugin.TBContext instance.
"""
self._histograms_plugin = histograms_plugin.HistogramsPlugin(context)
def get_plugin_apps(self):
return {
"/distributions": self.distributions_route,
"/tags": self.tags_route,
}
def is_active(self):
"""This plugin is active iff any run has at least one histogram tag.
(The distributions plugin uses the same data source as the
histogram plugin.)
"""
return self._histograms_plugin.is_active()
def data_plugin_names(self):
return (self._histograms_plugin.plugin_name,)
def frontend_metadata(self):
return base_plugin.FrontendMetadata(
element_name="tf-distribution-dashboard",
)
def distributions_impl(self, ctx, tag, run, experiment):
"""Result of the form `(body, mime_type)`.
Raises:
tensorboard.errors.PublicError: On invalid request.
"""
(histograms, mime_type) = self._histograms_plugin.histograms_impl(
ctx, tag, run, experiment=experiment, downsample_to=self.SAMPLE_SIZE
)
return (
[self._compress(histogram) for histogram in histograms],
mime_type,
)
def _compress(self, histogram):
(wall_time, step, buckets) = histogram
converted_buckets = compressor.compress_histogram(buckets)
return [wall_time, step, converted_buckets]
def index_impl(self, ctx, experiment):
return self._histograms_plugin.index_impl(ctx, experiment=experiment)
@wrappers.Request.application
def tags_route(self, request):
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
index = self.index_impl(ctx, experiment=experiment)
return http_util.Respond(request, index, "application/json")
@wrappers.Request.application
def distributions_route(self, request):
"""Given a tag and single run, return an array of compressed
histograms."""
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
tag = request.args.get("tag")
run = request.args.get("run")
(body, mime_type) = self.distributions_impl(
ctx, tag, run, experiment=experiment
)
return http_util.Respond(request, body, mime_type)
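# For reference (values illustrative): the JSON body returned by
# `distributions_route` is a list of events of the form
# [wall_time, step, compressed_buckets], where `compressed_buckets` comes from
# `compressor.compress_histogram` and holds one [basis_point, value] pair per
# compression point:
#
#     [
#       [1583441.25, 0, [[0, -2.1], [668, -1.4], ..., [10000, 2.3]]],
#       [1583442.75, 1, [[0, -1.9], [668, -1.2], ..., [10000, 2.5]]]
#     ]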

View File

@ -0,0 +1,20 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Information on the distributions plugin."""
# This name is used as the plugin prefix route and to identify this plugin
# generally.
PLUGIN_NAME = "distributions"

View File

@ -0,0 +1,129 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for graph plugin."""
from tensorboard.compat.proto import graph_pb2
def _prefixed_op_name(prefix, op_name):
return "%s/%s" % (prefix, op_name)
def _prefixed_func_name(prefix, func_name):
"""Returns function name prefixed with `prefix`.
Library functions, which are often created from autographed Python
functions, are factored out in the graph visualization. They are grouped
under a function name which often has the shape
`__inference_[py_func_name]_[numeric_suffix]`.
Such a name carries no unique information about which graph it came from,
but creating another wrapping structure with a graph prefix and "/" is less
than ideal, so we join the prefix and func_name with an underscore.
TODO(stephanwlee): add business logic to strip "__inference_" for a more
user-friendly name.
"""
return "%s_%s" % (prefix, func_name)
def _add_with_prepended_names(prefix, graph_to_add, destination_graph):
for node in graph_to_add.node:
new_node = destination_graph.node.add()
new_node.CopyFrom(node)
new_node.name = _prefixed_op_name(prefix, node.name)
new_node.input[:] = [
_prefixed_op_name(prefix, input_name) for input_name in node.input
]
# Remap tf.function method name in the PartitionedCall. 'f' is short for
# function.
if new_node.op == "PartitionedCall" and new_node.attr["f"]:
new_node.attr["f"].func.name = _prefixed_func_name(
prefix,
new_node.attr["f"].func.name,
)
for func in graph_to_add.library.function:
new_func = destination_graph.library.function.add()
new_func.CopyFrom(func)
new_func.signature.name = _prefixed_func_name(
prefix, new_func.signature.name
)
for gradient in graph_to_add.library.gradient:
new_gradient = destination_graph.library.gradient.add()
new_gradient.CopyFrom(gradient)
new_gradient.function_name = _prefixed_func_name(
prefix,
new_gradient.function_name,
)
new_gradient.gradient_func = _prefixed_func_name(
prefix,
new_gradient.gradient_func,
)
def merge_graph_defs(graph_defs):
"""Merges GraphDefs by adding unique prefix, `graph_{ind}`, to names.
All GraphDefs are expected to be of TensorBoard's.
When collecting graphs using the `tf.summary.trace` API, node names are not
guranteed to be unique. When non-unique names are not considered, it can
lead to graph visualization showing them as one which creates inaccurate
depiction of the flow of the graph (e.g., if there are A -> B -> C and D ->
B -> E, you may see {A, D} -> B -> E). To prevent such graph, we checked
for uniquenss while merging but it resulted in
https://github.com/tensorflow/tensorboard/issues/1929.
To remedy these issues, we simply "apply name scope" on each graph by
prefixing it with unique name (with a chance of collision) to create
unconnected group of graphs.
In case there is only one graph def passed, it returns the original
graph_def. In case no graph defs are passed, it returns an empty GraphDef.
Args:
graph_defs: TensorBoard GraphDefs to merge.
Returns:
TensorBoard GraphDef that merges all graph_defs with unique prefixes.
Raises:
ValueError in case GraphDef versions mismatch.
"""
if len(graph_defs) == 1:
return graph_defs[0]
elif len(graph_defs) == 0:
return graph_pb2.GraphDef()
dst_graph_def = graph_pb2.GraphDef()
if graph_defs[0].versions.producer:
dst_graph_def.versions.CopyFrom(graph_defs[0].versions)
for index, graph_def in enumerate(graph_defs):
if dst_graph_def.versions.producer != graph_def.versions.producer:
raise ValueError("Cannot combine GraphDefs of different versions.")
_add_with_prepended_names(
"graph_%d" % (index + 1),
graph_def,
dst_graph_def,
)
return dst_graph_def
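# A small usage sketch (illustrative, not part of the original module):
# merging two graphs that both define a node named "A" keeps each copy under
# its own `graph_{i}` prefix, and rewrites inputs accordingly.
#
#     g1 = graph_pb2.GraphDef()
#     g1.node.add(name="A")
#     g1.node.add(name="B", input=["A"])
#     g2 = graph_pb2.GraphDef()
#     g2.node.add(name="A")
#     merged = merge_graph_defs([g1, g2])
#     [n.name for n in merged.node]
#     # ['graph_1/A', 'graph_1/B', 'graph_2/A']
#     list(merged.node[1].input)
#     # ['graph_1/A']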

View File

@ -0,0 +1,337 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard Graphs plugin."""
import json
from werkzeug import wrappers
from tensorboard import errors
from tensorboard import plugin_util
from tensorboard.backend import http_util
from tensorboard.backend import process_graph
from tensorboard.compat.proto import config_pb2
from tensorboard.compat.proto import graph_pb2
from tensorboard.data import provider
from tensorboard.plugins import base_plugin
from tensorboard.plugins.graph import graph_util
from tensorboard.plugins.graph import keras_util
from tensorboard.plugins.graph import metadata
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
class GraphsPlugin(base_plugin.TBPlugin):
"""Graphs Plugin for TensorBoard."""
plugin_name = metadata.PLUGIN_NAME
def __init__(self, context):
"""Instantiates GraphsPlugin via TensorBoard core.
Args:
context: A base_plugin.TBContext instance.
"""
self._data_provider = context.data_provider
def get_plugin_apps(self):
return {
"/graph": self.graph_route,
"/info": self.info_route,
"/run_metadata": self.run_metadata_route,
}
def is_active(self):
"""The graphs plugin is active iff any run has a graph or metadata."""
return False # `list_plugins` as called by TB core suffices
def data_plugin_names(self):
return (
metadata.PLUGIN_NAME,
metadata.PLUGIN_NAME_RUN_METADATA,
metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
metadata.PLUGIN_NAME_KERAS_MODEL,
metadata.PLUGIN_NAME_TAGGED_RUN_METADATA,
)
def frontend_metadata(self):
return base_plugin.FrontendMetadata(
element_name="tf-graph-dashboard",
# TODO(@chihuahua): Reconcile this setting with Health Pills.
disable_reload=True,
)
def info_impl(self, ctx, experiment=None):
"""Returns a dict of all runs and their data availabilities."""
result = {}
def add_row_item(run, tag=None):
run_item = result.setdefault(
run,
{
"run": run,
"tags": {},
# A run-wide GraphDef of ops.
"run_graph": False,
},
)
tag_item = None
if tag:
tag_item = run_item.get("tags").setdefault(
tag,
{
"tag": tag,
"conceptual_graph": False,
# A tagged GraphDef of ops.
"op_graph": False,
"profile": False,
},
)
return (run_item, tag_item)
mapping = self._data_provider.list_blob_sequences(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
)
for run_name, tags in mapping.items():
for tag, tag_data in tags.items():
# The Summary op is defined in TensorFlow and does not use a stringified
# proto as the content of its plugin data. It contains a single string that
# denotes a version.
# https://github.com/tensorflow/tensorflow/blob/11f4ecb54708865ec757ca64e4805957b05d7570/tensorflow/python/ops/summary_ops_v2.py#L789-L790
if tag_data.plugin_content != b"1":
logger.warning(
"Ignoring unrecognizable version of RunMetadata."
)
continue
(_, tag_item) = add_row_item(run_name, tag)
tag_item["op_graph"] = True
# Tensors associated with plugin name metadata.PLUGIN_NAME_RUN_METADATA
# contain both op graph and profile information.
mapping = self._data_provider.list_blob_sequences(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME_RUN_METADATA,
)
for run_name, tags in mapping.items():
for tag, tag_data in tags.items():
if tag_data.plugin_content != b"1":
logger.warning(
"Ignoring unrecognizable version of RunMetadata."
)
continue
(_, tag_item) = add_row_item(run_name, tag)
tag_item["profile"] = True
tag_item["op_graph"] = True
# Tensors associated with plugin name metadata.PLUGIN_NAME_KERAS_MODEL
# contain serialized Keras model in JSON format.
mapping = self._data_provider.list_blob_sequences(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME_KERAS_MODEL,
)
for run_name, tags in mapping.items():
for tag, tag_data in tags.items():
if tag_data.plugin_content != b"1":
logger.warning(
"Ignoring unrecognizable version of RunMetadata."
)
continue
(_, tag_item) = add_row_item(run_name, tag)
tag_item["conceptual_graph"] = True
mapping = self._data_provider.list_blob_sequences(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
)
for run_name, tags in mapping.items():
if metadata.RUN_GRAPH_NAME in tags:
(run_item, _) = add_row_item(run_name, None)
run_item["run_graph"] = True
# Top level `Event.tagged_run_metadata` represents profile data only.
mapping = self._data_provider.list_blob_sequences(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME_TAGGED_RUN_METADATA,
)
for run_name, tags in mapping.items():
for tag in tags:
(_, tag_item) = add_row_item(run_name, tag)
tag_item["profile"] = True
return result
def _read_blob(self, ctx, experiment, plugin_names, run, tag):
for plugin_name in plugin_names:
blob_sequences = self._data_provider.read_blob_sequences(
ctx,
experiment_id=experiment,
plugin_name=plugin_name,
run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
downsample=1,
)
blob_sequence_data = blob_sequences.get(run, {}).get(tag, ())
try:
blob_ref = blob_sequence_data[0].values[0]
except IndexError:
continue
return self._data_provider.read_blob(
ctx, blob_key=blob_ref.blob_key
)
raise errors.NotFoundError()
def graph_impl(
self,
ctx,
run,
tag,
is_conceptual,
experiment=None,
limit_attr_size=None,
large_attrs_key=None,
):
"""Result of the form `(body, mime_type)`; may raise `NotFound`."""
if is_conceptual:
keras_model_config = json.loads(
self._read_blob(
ctx,
experiment,
[metadata.PLUGIN_NAME_KERAS_MODEL],
run,
tag,
)
)
graph = keras_util.keras_model_to_graph_def(keras_model_config)
elif tag is None:
graph_raw = self._read_blob(
ctx,
experiment,
[metadata.PLUGIN_NAME],
run,
metadata.RUN_GRAPH_NAME,
)
graph = graph_pb2.GraphDef.FromString(graph_raw)
else:
# Op graph: could be either of two plugins. (Cf. `info_impl`.)
plugins = [
metadata.PLUGIN_NAME_RUN_METADATA,
metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
]
raw_run_metadata = self._read_blob(
ctx, experiment, plugins, run, tag
)
run_metadata = config_pb2.RunMetadata.FromString(raw_run_metadata)
graph = graph_util.merge_graph_defs(
[
func_graph.pre_optimization_graph
for func_graph in run_metadata.function_graphs
]
)
# This next line might raise a ValueError if the limit parameters
# are invalid (size is negative, size present but key absent, etc.).
process_graph.prepare_graph_for_ui(
graph, limit_attr_size, large_attrs_key
)
return (str(graph), "text/x-protobuf") # pbtxt
def run_metadata_impl(self, ctx, experiment, run, tag):
"""Result of the form `(body, mime_type)`; may raise `NotFound`."""
# Profile graph: could be either of two plugins. (Cf. `info_impl`.)
plugins = [
metadata.PLUGIN_NAME_TAGGED_RUN_METADATA,
metadata.PLUGIN_NAME_RUN_METADATA,
]
raw_run_metadata = self._read_blob(ctx, experiment, plugins, run, tag)
run_metadata = config_pb2.RunMetadata.FromString(raw_run_metadata)
return (str(run_metadata), "text/x-protobuf") # pbtxt
@wrappers.Request.application
def info_route(self, request):
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
info = self.info_impl(ctx, experiment)
return http_util.Respond(request, info, "application/json")
@wrappers.Request.application
def graph_route(self, request):
"""Given a single run, return the graph definition in protobuf
format."""
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
run = request.args.get("run")
tag = request.args.get("tag")
conceptual_arg = request.args.get("conceptual", False)
is_conceptual = conceptual_arg == "true"
if run is None:
return http_util.Respond(
request, 'query parameter "run" is required', "text/plain", 400
)
limit_attr_size = request.args.get("limit_attr_size", None)
if limit_attr_size is not None:
try:
limit_attr_size = int(limit_attr_size)
except ValueError:
return http_util.Respond(
request,
"query parameter `limit_attr_size` must be an integer",
"text/plain",
400,
)
large_attrs_key = request.args.get("large_attrs_key", None)
try:
result = self.graph_impl(
ctx,
run,
tag,
is_conceptual,
experiment,
limit_attr_size,
large_attrs_key,
)
except ValueError as e:
return http_util.Respond(request, str(e), "text/plain", code=400)
(body, mime_type) = result
return http_util.Respond(request, body, mime_type)
@wrappers.Request.application
def run_metadata_route(self, request):
"""Given a tag and a run, return the session.run() metadata."""
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
tag = request.args.get("tag")
run = request.args.get("run")
if tag is None:
return http_util.Respond(
request, 'query parameter "tag" is required', "text/plain", 400
)
if run is None:
return http_util.Respond(
request, 'query parameter "run" is required', "text/plain", 400
)
(body, mime_type) = self.run_metadata_impl(ctx, experiment, run, tag)
return http_util.Respond(request, body, mime_type)
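# Request sketches for the routes above (query values are illustrative;
# TensorBoard core typically mounts plugin routes under
# `/data/plugin/graphs/`):
#
#     /data/plugin/graphs/info
#     /data/plugin/graphs/graph?run=train&limit_attr_size=1024&large_attrs_key=_too_large_attrs
#     /data/plugin/graphs/graph?run=train&tag=batch_2&conceptual=true
#     /data/plugin/graphs/run_metadata?run=train&tag=step_97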

View File

@ -0,0 +1,328 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for handling Keras model in graph plugin.
Two canonical types of Keras model are Functional and Sequential.
A model can be serialized as JSON and deserialized to reconstruct a model.
This utility helps with dealing with the serialized Keras model.
They have distinct structures to the configurations in shapes below:
Functional:
config
name: Name of the model. If not specified, it is 'model' with
an optional suffix if there is more than one instance.
input_layers: Keras.layers.Inputs in the model.
output_layers: Layer names that are outputs of the model.
layers: list of layer configurations.
layer: [*]
inbound_nodes: inputs to this layer.
Sequential:
config
name: Name of the model. If not specified, it is 'sequential' with
an optional suffix if there is more than one instance.
layers: list of layer configurations.
layer: [*]
[*]: Note that a model can be a layer.
Please refer to https://github.com/tensorflow/tfjs-layers/blob/master/src/keras_format/model_serialization.ts
for a more complete definition.
"""
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.tensorflow_stub import dtypes
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
def _walk_layers(keras_layer):
"""Walks the nested keras layer configuration in preorder.
Args:
keras_layer: Keras configuration from model.to_json.
Yields:
A tuple of (name_scope, layer_config).
name_scope: a string representing a scope name, similar to that of tf.name_scope.
layer_config: a dict representing a Keras layer configuration.
"""
yield ("", keras_layer)
if keras_layer.get("config").get("layers"):
name_scope = keras_layer.get("config").get("name")
for layer in keras_layer.get("config").get("layers"):
for sub_name_scope, sublayer in _walk_layers(layer):
sub_name_scope = (
"%s/%s" % (name_scope, sub_name_scope)
if sub_name_scope
else name_scope
)
yield (sub_name_scope, sublayer)
def _scoped_name(name_scope, node_name):
"""Returns scoped name for a node as a string in the form '<scope>/<node
name>'.
Args:
name_scope: a string representing a scope name, similar to that of tf.name_scope.
node_name: a string representing the current node name.
Returns:
A string representing a scoped name.
"""
if name_scope:
return "%s/%s" % (name_scope, node_name)
return node_name
def _is_model(layer):
"""Returns True if layer is a model.
Args:
layer: a dict representing a Keras model configuration.
Returns:
bool: True if layer is a model.
"""
return layer.get("config").get("layers") is not None
def _norm_to_list_of_layers(maybe_layers):
"""Normalizes to a list of layers.
Args:
maybe_layers: A list of data[1] or a list of list of data.
Returns:
List of list of data.
[1]: A Functional model has fields 'inbound_nodes' and 'output_layers' which can
look like below:
- ['in_layer_name', 0, 0]
- [['in_layer_is_model', 1, 0], ['in_layer_is_model', 1, 1]]
The data inside the list seems to describe [name, size, index].
"""
return (
maybe_layers if isinstance(maybe_layers[0], (list,)) else [maybe_layers]
)
def _get_inbound_nodes(layer):
"""Returns a list of [name, size, index] for all inbound nodes of the given layer."""
inbound_nodes = []
if layer.get("inbound_nodes") is not None:
for maybe_inbound_node in layer.get("inbound_nodes", []):
if not isinstance(maybe_inbound_node, dict):
# Note that the inbound node parsing is not backward compatible with
# Keras 2. If given a Keras 2 model, the input nodes will be missing
# in the final graph.
continue
for inbound_node_args in maybe_inbound_node.get("args", []):
# Sometimes this field is a list when there are multiple inbound nodes
# for the given layer.
if not isinstance(inbound_node_args, list):
inbound_node_args = [inbound_node_args]
for arg in inbound_node_args:
history = arg.get("config", {}).get("keras_history", [])
if len(history) < 3:
continue
inbound_nodes.append(history[:3])
return inbound_nodes
def _update_dicts(
name_scope,
model_layer,
input_to_in_layer,
model_name_to_output,
prev_node_name,
):
"""Updates input_to_in_layer, model_name_to_output, and prev_node_name
based on the model_layer.
Args:
name_scope: a string representing a scope name, similar to that of tf.name_scope.
model_layer: a dict representing a Keras model configuration.
input_to_in_layer: a dict mapping Keras.layers.Input to inbound layer.
model_name_to_output: a dict mapping Keras Model name to output layer of the model.
prev_node_name: a string representing a previous, in sequential model layout,
node name.
Returns:
A tuple of (input_to_in_layer, model_name_to_output, prev_node_name).
input_to_in_layer: a dict mapping Keras.layers.Input to inbound layer.
model_name_to_output: a dict mapping Keras Model name to output layer of the model.
prev_node_name: a string representing a previous, in sequential model layout,
node name.
"""
layer_config = model_layer.get("config")
if not layer_config.get("layers"):
raise ValueError("layer is not a model.")
node_name = _scoped_name(name_scope, layer_config.get("name"))
input_layers = layer_config.get("input_layers")
output_layers = layer_config.get("output_layers")
inbound_nodes = _get_inbound_nodes(model_layer)
is_functional_model = bool(input_layers and output_layers)
# In case of [1] and the parent model is functional, current layer
# will have the 'inbound_nodes' property.
is_parent_functional_model = bool(inbound_nodes)
if is_parent_functional_model and is_functional_model:
for input_layer, inbound_node in zip(input_layers, inbound_nodes):
input_layer_name = _scoped_name(node_name, input_layer)
inbound_node_name = _scoped_name(name_scope, inbound_node[0])
input_to_in_layer[input_layer_name] = inbound_node_name
elif is_parent_functional_model and not is_functional_model:
# Sequential model can take only one input. Make sure inbound to the
# model is linked to the first layer in the Sequential model.
prev_node_name = _scoped_name(name_scope, inbound_nodes[0][0])
elif (
not is_parent_functional_model
and prev_node_name
and is_functional_model
):
assert len(input_layers) == 1, (
"Cannot have multi-input Functional model when parent model "
"is not Functional. Number of input layers: %d" % len(input_layer)
)
input_layer = input_layers[0]
input_layer_name = _scoped_name(node_name, input_layer)
input_to_in_layer[input_layer_name] = prev_node_name
if is_functional_model and output_layers:
layers = _norm_to_list_of_layers(output_layers)
layer_names = [_scoped_name(node_name, layer[0]) for layer in layers]
model_name_to_output[node_name] = layer_names
else:
last_layer = layer_config.get("layers")[-1]
last_layer_name = last_layer.get("config").get("name")
output_node = _scoped_name(node_name, last_layer_name)
model_name_to_output[node_name] = [output_node]
return (input_to_in_layer, model_name_to_output, prev_node_name)
def keras_model_to_graph_def(keras_layer):
"""Returns a GraphDef representation of the Keras model in a dict form.
Note that it only supports models that implemented to_json().
Args:
keras_layer: A dict from Keras model.to_json().
Returns:
A GraphDef representation of the layers in the model.
"""
input_to_layer = {}
model_name_to_output = {}
g = GraphDef()
# Sequential model layers do not have a field "inbound_nodes" but
# instead are defined implicitly via order of layers.
prev_node_name = None
for name_scope, layer in _walk_layers(keras_layer):
if _is_model(layer):
(
input_to_layer,
model_name_to_output,
prev_node_name,
) = _update_dicts(
name_scope,
layer,
input_to_layer,
model_name_to_output,
prev_node_name,
)
continue
layer_config = layer.get("config")
node_name = _scoped_name(name_scope, layer_config.get("name"))
node_def = g.node.add()
node_def.name = node_name
if layer.get("class_name") is not None:
keras_cls_name = layer.get("class_name").encode("ascii")
node_def.attr["keras_class"].s = keras_cls_name
dtype_or_policy = layer_config.get("dtype")
dtype = None
has_unsupported_value = False
# If this is a dict, try and extract the dtype string from
# `config.name`. Keras will export like this for non-input layers and
# some other cases (e.g. tf/keras/mixed_precision/Policy, as described
# in issue #5548).
if isinstance(dtype_or_policy, dict) and "config" in dtype_or_policy:
dtype = dtype_or_policy.get("config").get("name")
elif dtype_or_policy is not None:
dtype = dtype_or_policy
if dtype is not None:
try:
tf_dtype = dtypes.as_dtype(dtype)
node_def.attr["dtype"].type = tf_dtype.as_datatype_enum
except TypeError:
has_unsupported_value = True
elif dtype_or_policy is not None:
has_unsupported_value = True
if has_unsupported_value:
# There's at least one known case when this happens, which is when
# mixed precision dtype policies are used, as described in issue
# #5548. (See https://keras.io/api/mixed_precision/).
# There might be a better way to handle this, but here we are.
logger.warning(
"Unsupported dtype value in graph model config (json):\n%s",
dtype_or_policy,
)
if layer.get("inbound_nodes") is not None:
for name, size, index in _get_inbound_nodes(layer):
inbound_name = _scoped_name(name_scope, name)
# An input to a layer can be the output of a model. In that case, the name
# in inbound_nodes for the layer is the name of that model. Remap the model
# name to the output layer of the model. Also, since a model can have
# multiple outputs, make sure we pick the right output_layer from the model.
inbound_node_names = model_name_to_output.get(
inbound_name, [inbound_name]
)
# There can be multiple inbound_nodes that reference the
# same upstream layer. This causes issues when looking for
# a particular index in that layer, since the indices
# captured in `inbound_nodes` doesn't necessarily match the
# number of entries in the `inbound_node_names` list. To
# avoid IndexErrors, we just use the last element in the
# `inbound_node_names` in this situation.
# Note that this is a quick hack to avoid IndexErrors in
# this situation, and might not be an appropriate solution
# to this problem in general.
input_name = (
inbound_node_names[index]
if index < len(inbound_node_names)
else inbound_node_names[-1]
)
node_def.input.append(input_name)
elif prev_node_name is not None:
node_def.input.append(prev_node_name)
if node_name in input_to_layer:
node_def.input.append(input_to_layer.get(node_name))
prev_node_name = node_def.name
return g
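# A usage sketch (illustrative; assumes a TensorFlow/Keras installation, which
# this module itself does not require):
#
#     import json
#     import tensorflow as tf
#
#     model = tf.keras.Sequential(
#         [tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(4), tf.keras.layers.Dense(1)]
#     )
#     graph_def = keras_model_to_graph_def(json.loads(model.to_json()))
#     # Node names are scoped under the model name, e.g. "sequential/dense".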

View File

@ -0,0 +1,43 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Information on the graph plugin."""
# This name is used as the plugin prefix route and to identify this plugin
# generally, and is also the `plugin_name` for run graphs after data-compat
# transformations.
PLUGIN_NAME = "graphs"
# The Summary API is implemented in TensorFlow because it uses TensorFlow internal APIs.
# As a result, this SummaryMetadata is a bit unconventional and uses a non-public,
# hardcoded name as the plugin name. Please refer to the links below for the summary ops.
# https://github.com/tensorflow/tensorflow/blob/11f4ecb54708865ec757ca64e4805957b05d7570/tensorflow/python/ops/summary_ops_v2.py#L757
PLUGIN_NAME_RUN_METADATA = "graph_run_metadata"
# https://github.com/tensorflow/tensorflow/blob/11f4ecb54708865ec757ca64e4805957b05d7570/tensorflow/python/ops/summary_ops_v2.py#L788
PLUGIN_NAME_RUN_METADATA_WITH_GRAPH = "graph_run_metadata_graph"
# https://github.com/tensorflow/tensorflow/blob/565952cc2f17fdfd995e25171cf07be0f6f06180/tensorflow/python/ops/summary_ops_v2.py#L825
PLUGIN_NAME_KERAS_MODEL = "graph_keras_model"
# Plugin name used for `Event.tagged_run_metadata`. This doesn't fall into one
# of the above cases because (despite the name) `PLUGIN_NAME_RUN_METADATA` is
# _required_ to have both profile and op graphs, whereas tagged run metadata
# need only have profile data.
PLUGIN_NAME_TAGGED_RUN_METADATA = "graph_tagged_run_metadata"
# In the context of the data provider interface, tag name given to a
# graph read from the `graph_def` field of an `Event` proto, which is
# not attached to a summary and thus does not have a proper tag name of
# its own. Run level graphs always represent `GraphDef`s (graphs of
# TensorFlow ops), never conceptual graphs, profile graphs, etc. This is
# the only tag name used by the `"graphs"` plugin.
RUN_GRAPH_NAME = "__run_graph__"

View File

@ -0,0 +1,146 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard Histograms plugin.
See `http_api.md` in this directory for specifications of the routes for
this plugin.
"""
from werkzeug import wrappers
from tensorboard import errors
from tensorboard import plugin_util
from tensorboard.backend import http_util
from tensorboard.data import provider
from tensorboard.plugins import base_plugin
from tensorboard.plugins.histogram import metadata
_DEFAULT_DOWNSAMPLING = 500 # histograms per time series
class HistogramsPlugin(base_plugin.TBPlugin):
"""Histograms Plugin for TensorBoard.
This supports both old-style summaries (created with TensorFlow ops
that output directly to the `histo` field of the proto) and new-
style summaries (as created by the
`tensorboard.plugins.histogram.summary` module).
"""
plugin_name = metadata.PLUGIN_NAME
# Use a round number + 1 since sampling includes both start and end steps,
# so N+1 samples correspond to dividing the step sequence into N intervals.
SAMPLE_SIZE = 51
def __init__(self, context):
"""Instantiates HistogramsPlugin via TensorBoard core.
Args:
context: A base_plugin.TBContext instance.
"""
self._downsample_to = (context.sampling_hints or {}).get(
self.plugin_name, _DEFAULT_DOWNSAMPLING
)
self._data_provider = context.data_provider
self._version_checker = plugin_util._MetadataVersionChecker(
data_kind="histogram",
latest_known_version=0,
)
def get_plugin_apps(self):
return {
"/histograms": self.histograms_route,
"/tags": self.tags_route,
}
def is_active(self):
return False # `list_plugins` as called by TB core suffices
def index_impl(self, ctx, experiment):
"""Return {runName: {tagName: {displayName: ..., description:
...}}}."""
mapping = self._data_provider.list_tensors(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
)
result = {run: {} for run in mapping}
for run, tag_to_content in mapping.items():
for tag, metadatum in tag_to_content.items():
description = plugin_util.markdown_to_safe_html(
metadatum.description
)
md = metadata.parse_plugin_metadata(metadatum.plugin_content)
if not self._version_checker.ok(md.version, run, tag):
continue
result[run][tag] = {
"displayName": metadatum.display_name,
"description": description,
}
return result
def frontend_metadata(self):
return base_plugin.FrontendMetadata(
element_name="tf-histogram-dashboard"
)
def histograms_impl(self, ctx, tag, run, experiment, downsample_to=None):
"""Result of the form `(body, mime_type)`.
At most `downsample_to` events will be returned. If this value is
`None`, then default downsampling will be performed.
Raises:
tensorboard.errors.PublicError: On invalid request.
"""
sample_count = (
downsample_to if downsample_to is not None else self._downsample_to
)
all_histograms = self._data_provider.read_tensors(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
downsample=sample_count,
run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
)
histograms = all_histograms.get(run, {}).get(tag, None)
if histograms is None:
raise errors.NotFoundError(
"No histogram tag %r for run %r" % (tag, run)
)
events = [(e.wall_time, e.step, e.numpy.tolist()) for e in histograms]
return (events, "application/json")
@wrappers.Request.application
def tags_route(self, request):
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
index = self.index_impl(ctx, experiment=experiment)
return http_util.Respond(request, index, "application/json")
@wrappers.Request.application
def histograms_route(self, request):
"""Given a tag and single run, return array of histogram values."""
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
tag = request.args.get("tag")
run = request.args.get("run")
(body, mime_type) = self.histograms_impl(
ctx, tag, run, experiment=experiment, downsample_to=self.SAMPLE_SIZE
)
return http_util.Respond(request, body, mime_type)
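# For reference (values illustrative): the JSON body returned by
# `histograms_route` is a list of events of the form [wall_time, step,
# buckets], where `buckets` is a [k, 3] list of [left_edge, right_edge, count]
# triples:
#
#     [
#       [1583441.25, 0, [[-1.0, 0.0, 12.0], [0.0, 1.0, 30.0]]],
#       [1583442.75, 1, [[-0.5, 0.5, 20.0], [0.5, 1.5, 22.0]]]
#     ]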

View File

@ -0,0 +1,64 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Information about histogram summaries."""
from tensorboard.compat.proto import summary_pb2
from tensorboard.plugins.histogram import plugin_data_pb2
PLUGIN_NAME = "histograms"
# The most recent value for the `version` field of the
# `HistogramPluginData` proto.
PROTO_VERSION = 0
def create_summary_metadata(display_name, description):
"""Create a `summary_pb2.SummaryMetadata` proto for histogram plugin data.
Returns:
A `summary_pb2.SummaryMetadata` protobuf object.
"""
content = plugin_data_pb2.HistogramPluginData(version=PROTO_VERSION)
return summary_pb2.SummaryMetadata(
display_name=display_name,
summary_description=description,
plugin_data=summary_pb2.SummaryMetadata.PluginData(
plugin_name=PLUGIN_NAME, content=content.SerializeToString()
),
)
def parse_plugin_metadata(content):
"""Parse summary metadata to a Python object.
Arguments:
content: The `content` field of a `SummaryMetadata` proto
corresponding to the histogram plugin.
Returns:
A `HistogramPluginData` protobuf object.
"""
if not isinstance(content, bytes):
raise TypeError("Content type must be bytes")
if content == b"{}":
# Old-style JSON format. Equivalent to an all-default proto.
return plugin_data_pb2.HistogramPluginData()
else:
result = plugin_data_pb2.HistogramPluginData.FromString(content)
if result.version == 0:
return result
# No other versions known at this time, so no migrations to do.
return result
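# A small self-contained sketch (added for illustration) of the round trip
# defined above: build summary metadata for a histogram time series, then
# parse the plugin content back out of it. The display name and description
# are placeholders.
if __name__ == "__main__":
    md = create_summary_metadata(
        display_name="weights", description="Layer weight distribution."
    )
    plugin_data = parse_plugin_metadata(md.plugin_data.content)
    assert md.plugin_data.plugin_name == PLUGIN_NAME
    assert plugin_data.version == PROTO_VERSION
    print(md)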

View File

@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: tensorboard/plugins/histogram/plugin_data.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n/tensorboard/plugins/histogram/plugin_data.proto\x12\x0btensorboard\"&\n\x13HistogramPluginData\x12\x0f\n\x07version\x18\x01 \x01(\x05\x62\x06proto3')
_HISTOGRAMPLUGINDATA = DESCRIPTOR.message_types_by_name['HistogramPluginData']
HistogramPluginData = _reflection.GeneratedProtocolMessageType('HistogramPluginData', (_message.Message,), {
'DESCRIPTOR' : _HISTOGRAMPLUGINDATA,
'__module__' : 'tensorboard.plugins.histogram.plugin_data_pb2'
# @@protoc_insertion_point(class_scope:tensorboard.HistogramPluginData)
})
_sym_db.RegisterMessage(HistogramPluginData)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_HISTOGRAMPLUGINDATA._serialized_start=64
_HISTOGRAMPLUGINDATA._serialized_end=102
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,236 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Histogram summaries and TensorFlow operations to create them.
A histogram summary stores a list of buckets. Each bucket is encoded as
a triple `[left_edge, right_edge, count]`. Thus, a full histogram is
encoded as a tensor of dimension `[k, 3]`.
In general, the value of `k` (the number of buckets) will be a constant,
like 30. There are two edge cases: if there is no data, then there are
no buckets (the shape is `[0, 3]`); and if there is data but all points
have the same value, then there is one bucket whose left and right
endpoints are the same (the shape is `[1, 3]`).
NOTE: This module is in beta, and its API is subject to change, but the
data that it stores to disk will be supported forever.
"""
import numpy as np
from tensorboard.plugins.histogram import metadata
from tensorboard.plugins.histogram import summary_v2
# Export V3 versions.
histogram = summary_v2.histogram
histogram_pb = summary_v2.histogram_pb
def _buckets(data, bucket_count=None):
"""Create a TensorFlow op to group data into histogram buckets.
Arguments:
data: A `Tensor` of any shape. Must be castable to `float64`.
bucket_count: Optional positive `int` or scalar `int32` `Tensor`.
Returns:
A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is
a triple `[left_edge, right_edge, count]` for a single bucket.
The value of `k` is either `bucket_count` or `1` or `0`.
"""
# TODO(nickfelt): remove on-demand imports once dep situation is fixed.
import tensorflow.compat.v1 as tf
if bucket_count is None:
bucket_count = summary_v2.DEFAULT_BUCKET_COUNT
with tf.name_scope(
"buckets", values=[data, bucket_count]
), tf.control_dependencies(
[tf.assert_scalar(bucket_count), tf.assert_type(bucket_count, tf.int32)]
):
data = tf.reshape(data, shape=[-1]) # flatten
data = tf.cast(data, tf.float64)
is_empty = tf.equal(tf.size(input=data), 0)
def when_empty():
return tf.constant([], shape=(0, 3), dtype=tf.float64)
def when_nonempty():
min_ = tf.reduce_min(input_tensor=data)
max_ = tf.reduce_max(input_tensor=data)
range_ = max_ - min_
is_singular = tf.equal(range_, 0)
def when_nonsingular():
bucket_width = range_ / tf.cast(bucket_count, tf.float64)
offsets = data - min_
bucket_indices = tf.cast(
tf.floor(offsets / bucket_width), dtype=tf.int32
)
clamped_indices = tf.minimum(bucket_indices, bucket_count - 1)
# Use float64 instead of float32 to avoid accumulating floating point error
# later in tf.reduce_sum when summing more than 2^24 individual `1.0` values.
# See https://github.com/tensorflow/tensorflow/issues/51419 for details.
one_hots = tf.one_hot(
clamped_indices, depth=bucket_count, dtype=tf.float64
)
bucket_counts = tf.cast(
tf.reduce_sum(input_tensor=one_hots, axis=0),
dtype=tf.float64,
)
edges = tf.linspace(min_, max_, bucket_count + 1)
left_edges = edges[:-1]
right_edges = edges[1:]
return tf.transpose(
a=tf.stack([left_edges, right_edges, bucket_counts])
)
def when_singular():
center = min_
bucket_starts = tf.stack([center - 0.5])
bucket_ends = tf.stack([center + 0.5])
bucket_counts = tf.stack(
[tf.cast(tf.size(input=data), tf.float64)]
)
return tf.transpose(
a=tf.stack([bucket_starts, bucket_ends, bucket_counts])
)
return tf.cond(is_singular, when_singular, when_nonsingular)
return tf.cond(is_empty, when_empty, when_nonempty)
def op(
name,
data,
bucket_count=None,
display_name=None,
description=None,
collections=None,
):
"""Create a legacy histogram summary op.
Arguments:
name: A unique name for the generated summary node.
data: A `Tensor` of any shape. Must be castable to `float64`.
bucket_count: Optional positive `int`. The output will have this
many buckets, except in two edge cases. If there is no data, then
there are no buckets. If there is data but all points have the
same value, then there is one bucket whose left and right
endpoints are the same.
display_name: Optional name for this summary in TensorBoard, as a
constant `str`. Defaults to `name`.
description: Optional long-form description for this summary, as a
constant `str`. Markdown is supported. Defaults to empty.
collections: Optional list of graph collections keys. The new
summary op is added to these collections. Defaults to
`[GraphKeys.SUMMARIES]`.
Returns:
A TensorFlow summary op.
"""
# TODO(nickfelt): remove on-demand imports once dep situation is fixed.
import tensorflow.compat.v1 as tf
if display_name is None:
display_name = name
summary_metadata = metadata.create_summary_metadata(
display_name=display_name, description=description
)
with tf.name_scope(name):
tensor = _buckets(data, bucket_count=bucket_count)
return tf.summary.tensor_summary(
name="histogram_summary",
tensor=tensor,
collections=collections,
summary_metadata=summary_metadata,
)
def pb(name, data, bucket_count=None, display_name=None, description=None):
"""Create a legacy histogram summary protobuf.
Arguments:
name: A unique name for the generated summary, including any desired
name scopes.
data: A `np.array` or array-like form of any shape. Must have type
castable to `float`.
bucket_count: Optional positive `int`. The output will have this
many buckets, except in two edge cases. If there is no data, then
there are no buckets. If there is data but all points have the
same value, then there is one bucket whose left and right
endpoints are the same.
display_name: Optional name for this summary in TensorBoard, as a
`str`. Defaults to `name`.
description: Optional long-form description for this summary, as a
`str`. Markdown is supported. Defaults to empty.
Returns:
A `tf.Summary` protobuf object.
"""
# TODO(nickfelt): remove on-demand imports once dep situation is fixed.
import tensorflow.compat.v1 as tf
if bucket_count is None:
bucket_count = summary_v2.DEFAULT_BUCKET_COUNT
data = np.array(data).flatten().astype(float)
if data.size == 0:
buckets = np.array([]).reshape((0, 3))
else:
min_ = np.min(data)
max_ = np.max(data)
range_ = max_ - min_
if range_ == 0:
center = min_
buckets = np.array([[center - 0.5, center + 0.5, float(data.size)]])
else:
bucket_width = range_ / bucket_count
offsets = data - min_
bucket_indices = np.floor(offsets / bucket_width).astype(int)
clamped_indices = np.minimum(bucket_indices, bucket_count - 1)
one_hots = np.array([clamped_indices]).transpose() == np.arange(
0, bucket_count
) # broadcast
assert one_hots.shape == (data.size, bucket_count), (
one_hots.shape,
(data.size, bucket_count),
)
bucket_counts = np.sum(one_hots, axis=0)
edges = np.linspace(min_, max_, bucket_count + 1)
left_edges = edges[:-1]
right_edges = edges[1:]
buckets = np.array(
[left_edges, right_edges, bucket_counts]
).transpose()
tensor = tf.make_tensor_proto(buckets, dtype=tf.float64)
if display_name is None:
display_name = name
summary_metadata = metadata.create_summary_metadata(
display_name=display_name, description=description
)
tf_summary_metadata = tf.SummaryMetadata.FromString(
summary_metadata.SerializeToString()
)
summary = tf.Summary()
summary.value.add(
tag="%s/histogram_summary" % name,
metadata=tf_summary_metadata,
tensor=tensor,
)
return summary
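# A hedged usage sketch for the legacy `pb` helper above: build a histogram
# summary proto from NumPy data and inspect the resulting [k, 3] bucket
# tensor. The tag and the demo data are placeholders.
if __name__ == "__main__":
    import tensorflow.compat.v1 as tf

    values = np.random.normal(loc=0.0, scale=1.0, size=10_000)
    summary = pb(
        "weights",
        values,
        bucket_count=30,
        description="Sampled layer weights (demo data).",
    )
    buckets = tf.make_ndarray(summary.value[0].tensor)
    assert buckets.shape == (30, 3)
    assert buckets[:, 2].sum() == values.size  # every point lands in a bucket
    print(buckets[:3])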

View File

@ -0,0 +1,293 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Histogram summaries and TensorFlow operations to create them, V2 versions.
A histogram summary stores a list of buckets. Each bucket is encoded as a triple
`[left_edge, right_edge, count]`. Thus, a full histogram is encoded as a tensor
of dimension `[k, 3]`, where the first `k - 1` buckets are closed-open and the
last bucket is closed-closed.
In general, the shape of the output histogram is always constant (`[k, 3]`).
In the case of empty data, the output will be an all-zero histogram of shape
`[k, 3]`, where all edges and counts are zeros. If there is data but all points
have the same value, then all buckets' left and right edges are the same and only
the last bucket has nonzero count.
"""
import numpy as np
from tensorboard.compat import tf2 as tf
from tensorboard.compat.proto import summary_pb2
from tensorboard.plugins.histogram import metadata
from tensorboard.util import lazy_tensor_creator
from tensorboard.util import tensor_util
DEFAULT_BUCKET_COUNT = 30
def histogram_pb(tag, data, buckets=None, description=None):
"""Create a histogram summary protobuf.
Arguments:
tag: String tag for the summary.
data: A `np.array` or array-like form of any shape. Must have type
castable to `float`.
buckets: Optional positive `int`. The output shape will always be
[buckets, 3]. If there is no data, then an all-zero array of shape
[buckets, 3] will be returned. If there is data but all points have
the same value, then all buckets' left and right endpoints are the
same and only the last bucket has nonzero count. Defaults to 30 if
not specified.
description: Optional long-form description for this summary, as a
`str`. Markdown is supported. Defaults to empty.
Returns:
A `summary_pb2.Summary` protobuf object.
"""
bucket_count = DEFAULT_BUCKET_COUNT if buckets is None else buckets
data = np.array(data).flatten().astype(float)
if bucket_count == 0 or data.size == 0:
histogram_buckets = np.zeros((bucket_count, 3))
else:
min_ = np.min(data)
max_ = np.max(data)
range_ = max_ - min_
if range_ == 0:
left_edges = right_edges = np.array([min_] * bucket_count)
bucket_counts = np.array([0] * (bucket_count - 1) + [data.size])
histogram_buckets = np.array(
[left_edges, right_edges, bucket_counts]
).transpose()
else:
bucket_width = range_ / bucket_count
offsets = data - min_
bucket_indices = np.floor(offsets / bucket_width).astype(int)
clamped_indices = np.minimum(bucket_indices, bucket_count - 1)
one_hots = np.array([clamped_indices]).transpose() == np.arange(
0, bucket_count
) # broadcast
assert one_hots.shape == (data.size, bucket_count), (
one_hots.shape,
(data.size, bucket_count),
)
bucket_counts = np.sum(one_hots, axis=0)
edges = np.linspace(min_, max_, bucket_count + 1)
left_edges = edges[:-1]
right_edges = edges[1:]
histogram_buckets = np.array(
[left_edges, right_edges, bucket_counts]
).transpose()
tensor = tensor_util.make_tensor_proto(histogram_buckets, dtype=np.float64)
summary_metadata = metadata.create_summary_metadata(
display_name=None, description=description
)
summary = summary_pb2.Summary()
summary.value.add(tag=tag, metadata=summary_metadata, tensor=tensor)
return summary
# This is the TPU compatible V3 histogram implementation as of 2021-12-01.
def histogram(name, data, step=None, buckets=None, description=None):
"""Write a histogram summary.
See also `tf.summary.scalar`, `tf.summary.SummaryWriter`.
Writes a histogram to the current default summary writer, for later analysis
in TensorBoard's 'Histograms' and 'Distributions' dashboards (data written
using this API will appear in both places). Like `tf.summary.scalar` points,
each histogram is associated with a `step` and a `name`. All the histograms
with the same `name` constitute a time series of histograms.
The histogram is calculated over all the elements of the given `Tensor`
without regard to its shape or rank.
This example writes 2 histograms:
```python
w = tf.summary.create_file_writer('test/logs')
with w.as_default():
tf.summary.histogram("activations", tf.random.uniform([100, 50]), step=0)
tf.summary.histogram("initial_weights", tf.random.normal([1000]), step=0)
```
A common use case is to examine the changing activation patterns (or lack
thereof) at specific layers in a neural network, over time.
```python
w = tf.summary.create_file_writer('test/logs')
with w.as_default():
for step in range(100):
# Generate fake "activations".
activations = [
tf.random.normal([1000], mean=step, stddev=1),
tf.random.normal([1000], mean=step, stddev=10),
tf.random.normal([1000], mean=step, stddev=100),
]
tf.summary.histogram("layer1/activate", activations[0], step=step)
tf.summary.histogram("layer2/activate", activations[1], step=step)
tf.summary.histogram("layer3/activate", activations[2], step=step)
```
Arguments:
name: A name for this summary. The summary tag used for TensorBoard will
be this name prefixed by any active name scopes.
data: A `Tensor` of any shape. The histogram is computed over its elements,
which must be castable to `float64`.
step: Explicit `int64`-castable monotonic step value for this summary. If
omitted, this defaults to `tf.summary.experimental.get_step()`, which must
not be None.
buckets: Optional positive `int`. The output will have this
many buckets, except in two edge cases. If there is no data, then
there are no buckets. If there is data but all points have the
same value, then all buckets' left and right endpoints are the same
and only the last bucket has nonzero count. Defaults to 30 if not
specified.
description: Optional long-form description for this summary, as a
constant `str`. Markdown is supported. Defaults to empty.
Returns:
True on success, or false if no summary was emitted because no default
summary writer was available.
Raises:
ValueError: if a default writer exists, but no step was provided and
`tf.summary.experimental.get_step()` is None.
"""
# Avoid building unused gradient graphs for conds below. This works around
# an error building second-order gradient graphs when XlaDynamicUpdateSlice
# is used, and will generally speed up graph building slightly.
data = tf.stop_gradient(data)
summary_metadata = metadata.create_summary_metadata(
display_name=None, description=description
)
# TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback
summary_scope = (
getattr(tf.summary.experimental, "summary_scope", None)
or tf.summary.summary_scope
)
# TODO(ytjing): add special case handling.
with summary_scope(
name, "histogram_summary", values=[data, buckets, step]
) as (tag, _):
# Defer histogram bucketing logic by passing it as a callable to
# write(), wrapped in a LazyTensorCreator for backwards
# compatibility, so that we only do this work when summaries are
# actually written.
@lazy_tensor_creator.LazyTensorCreator
def lazy_tensor():
return _buckets(data, buckets)
return tf.summary.write(
tag=tag,
tensor=lazy_tensor,
step=step,
metadata=summary_metadata,
)
def _buckets(data, bucket_count=None):
"""Create a TensorFlow op to group data into histogram buckets.
Arguments:
data: A `Tensor` of any shape. Must be castable to `float64`.
bucket_count: Optional non-negative `int` or scalar `int32` `Tensor`,
defaults to 30.
Returns:
A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is
a triple `[left_edge, right_edge, count]` for a single bucket.
The value of `k` is either `bucket_count` or `0` (when input data
is empty).
"""
if bucket_count is None:
bucket_count = DEFAULT_BUCKET_COUNT
with tf.name_scope("buckets"):
tf.debugging.assert_scalar(bucket_count)
tf.debugging.assert_type(bucket_count, tf.int32)
# Treat a negative bucket count as zero.
bucket_count = tf.math.maximum(0, bucket_count)
data = tf.reshape(data, shape=[-1]) # flatten
data = tf.cast(data, tf.float64)
data_size = tf.size(input=data)
is_empty = tf.logical_or(
tf.equal(data_size, 0), tf.less_equal(bucket_count, 0)
)
def when_empty():
"""When input data is empty or bucket_count is zero.
1. If bucket_count is specified as zero, an empty tensor of shape
(0, 3) will be returned.
2. If the input data is empty, a tensor of shape (bucket_count, 3)
of all zero values will be returned.
"""
return tf.zeros((bucket_count, 3), dtype=tf.float64)
def when_nonempty():
min_ = tf.reduce_min(input_tensor=data)
max_ = tf.reduce_max(input_tensor=data)
range_ = max_ - min_
has_single_value = tf.equal(range_, 0)
def when_multiple_values():
"""When input data contains multiple values."""
bucket_width = range_ / tf.cast(bucket_count, tf.float64)
offsets = data - min_
bucket_indices = tf.cast(
tf.floor(offsets / bucket_width), dtype=tf.int32
)
clamped_indices = tf.minimum(bucket_indices, bucket_count - 1)
# Use float64 instead of float32 to avoid accumulating floating point error
# later in tf.reduce_sum when summing more than 2^24 individual `1.0` values.
# See https://github.com/tensorflow/tensorflow/issues/51419 for details.
one_hots = tf.one_hot(
clamped_indices, depth=bucket_count, dtype=tf.float64
)
bucket_counts = tf.cast(
tf.reduce_sum(input_tensor=one_hots, axis=0),
dtype=tf.float64,
)
edges = tf.linspace(min_, max_, bucket_count + 1)
# Ensure edges[-1] == max_, which TF's linspace implementation does not
# do, leaving it subject to the whim of floating point rounding error.
edges = tf.concat([edges[:-1], [max_]], 0)
left_edges = edges[:-1]
right_edges = edges[1:]
return tf.transpose(
a=tf.stack([left_edges, right_edges, bucket_counts])
)
def when_single_value():
"""When input data contains a single unique value."""
# Left and right edges are the same for single value input.
edges = tf.fill([bucket_count], max_)
# Bucket counts are 0 except the last bucket (if bucket_count > 0),
# which is `data_size`. Ensure that the resulting counts vector has
# length `bucket_count` always, including the bucket_count==0 case.
zeroes = tf.fill([bucket_count], 0)
bucket_counts = tf.cast(
tf.concat([zeroes[:-1], [data_size]], 0)[:bucket_count],
dtype=tf.float64,
)
return tf.transpose(a=tf.stack([edges, edges, bucket_counts]))
return tf.cond(
has_single_value, when_single_value, when_multiple_values
)
return tf.cond(is_empty, when_empty, when_nonempty)
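# A short illustrative sketch (not exercised by TensorBoard itself) of the
# edge cases documented above, using the pure-Python `histogram_pb` helper:
# empty input yields an all-zero [buckets, 3] tensor, and constant input puts
# the entire count in the last bucket.
if __name__ == "__main__":
    empty = histogram_pb("empty", [], buckets=5)
    constant = histogram_pb("constant", [3.0, 3.0, 3.0], buckets=5)
    empty_buckets = tensor_util.make_ndarray(empty.value[0].tensor)
    constant_buckets = tensor_util.make_ndarray(constant.value[0].tensor)
    assert empty_buckets.shape == (5, 3) and not empty_buckets.any()
    assert constant_buckets.shape == (5, 3)
    assert constant_buckets[-1][2] == 3.0
    print(constant_buckets)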

View File

@ -0,0 +1,14 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

View File

@ -0,0 +1,94 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras integration for TensorBoard hparams.
Use `tensorboard.plugins.hparams.api` to access this module's contents.
"""
import tensorflow as tf
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import summary
from tensorboard.plugins.hparams import summary_v2
class Callback(tf.keras.callbacks.Callback):
"""Callback for logging hyperparameters to TensorBoard.
NOTE: This callback only works in TensorFlow eager mode.
"""
def __init__(self, writer, hparams, trial_id=None):
"""Create a callback for logging hyperparameters to TensorBoard.
As with the standard `tf.keras.callbacks.TensorBoard` class, each
callback object is valid for only one call to `model.fit`.
Args:
writer: The `SummaryWriter` object to which hparams should be
written, or a logdir (as a `str`) to be passed to
`tf.summary.create_file_writer` to create such a writer.
hparams: A `dict` mapping hyperparameters to the values used in
this session. Keys should be the names of `HParam` objects used
in an experiment, or the `HParam` objects themselves. Values
should be Python `bool`, `int`, `float`, or `string` values,
depending on the type of the hyperparameter.
trial_id: An optional `str` ID for the set of hyperparameter
values used in this trial. Defaults to a hash of the
hyperparameters.
Raises:
ValueError: If two entries in `hparams` share the same
hyperparameter name.
"""
# Defer creating the actual summary until we write it, so that the
# timestamp is correct. But create a "dry-run" first to fail fast in
# case the `hparams` are invalid.
self._hparams = dict(hparams)
self._trial_id = trial_id
summary_v2.hparams_pb(self._hparams, trial_id=self._trial_id)
if writer is None:
raise TypeError(
"writer must be a `SummaryWriter` or `str`, not None"
)
elif isinstance(writer, str):
self._writer = tf.compat.v2.summary.create_file_writer(writer)
else:
self._writer = writer
def _get_writer(self):
if self._writer is None:
raise RuntimeError(
"hparams Keras callback cannot be reused across training sessions"
)
if not tf.executing_eagerly():
raise RuntimeError(
"hparams Keras callback only supported in TensorFlow eager mode"
)
return self._writer
def on_train_begin(self, logs=None):
del logs # unused
with self._get_writer().as_default():
summary_v2.hparams(self._hparams, trial_id=self._trial_id)
def on_train_end(self, logs=None):
del logs # unused
with self._get_writer().as_default():
pb = summary.session_end_pb(api_pb2.STATUS_SUCCESS)
raw_pb = pb.SerializeToString()
tf.compat.v2.summary.experimental.write_raw_pb(raw_pb, step=0)
self._writer = None
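# A hedged end-to-end sketch of this callback. The model, log directory, and
# hyperparameter values are placeholders; the point is only to show the
# callback being passed to `model.fit` next to the standard TensorBoard
# callback.
if __name__ == "__main__":
    hparams = {"optimizer": "adam", "dropout": 0.2}
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dropout(hparams["dropout"], input_shape=(784,)),
            tf.keras.layers.Dense(10, activation="softmax"),
        ]
    )
    model.compile(
        optimizer=hparams["optimizer"],
        loss="sparse_categorical_crossentropy",
    )
    x = tf.random.normal([256, 784])
    y = tf.random.uniform([256], maxval=10, dtype=tf.int32)
    logdir = "/tmp/hparams_keras_demo/run1"
    callbacks = [
        tf.keras.callbacks.TensorBoard(logdir),
        Callback(logdir, hparams),  # the class defined above
    ]
    model.fit(x, y, epochs=1, callbacks=callbacks)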

View File

@ -0,0 +1,132 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Public APIs for the HParams plugin.
This module supports a spectrum of use cases, depending on how much
structure you want. In the simplest case, you can simply collect your
hparams into a dict, and use a Keras callback to record them:
>>> from tensorboard.plugins.hparams import api as hp
>>> hparams = {
... "optimizer": "adam",
... "fc_dropout": 0.2,
... "neurons": 128,
... # ...
... }
>>>
>>> model = model_fn(hparams)
>>> callbacks = [
>>> tf.keras.callbacks.TensorBoard(logdir),
>>> hp.KerasCallback(logdir, hparams),
>>> ]
>>> model.fit(..., callbacks=callbacks)
The Keras callback requires that TensorFlow eager execution be enabled.
If not using Keras, use the `hparams` function to write the values
directly:
>>> # In eager mode:
>>> with tf.summary.create_file_writer(logdir).as_default():
... hp.hparams(hparams)
>>>
>>> # In legacy graph mode:
>>> with tf.compat.v2.summary.create_file_writer(logdir).as_default() as w:
... sess.run(w.init())
... sess.run(hp.hparams(hparams))
... sess.run(w.flush())
To control how hyperparameters and metrics appear in the TensorBoard UI,
you can define `HParam` and `Metric` objects, and write an experiment
summary to the top-level log directory:
>>> HP_OPTIMIZER = hp.HParam("optimizer")
>>> HP_FC_DROPOUT = hp.HParam(
... "fc_dropout",
... display_name="f.c. dropout",
... description="Dropout rate for fully connected subnet.",
... )
>>> HP_NEURONS = hp.HParam("neurons", description="Neurons per dense layer")
>>>
>>> with tf.summary.create_file_writer(base_logdir).as_default():
... hp.hparams_config(
... hparams=[
... HP_OPTIMIZER,
... HP_FC_DROPOUT,
... HP_NEURONS,
... ],
... metrics=[
... hp.Metric("xent", group="validation", display_name="cross-entropy"),
... hp.Metric("f1", group="validation", display_name="F₁ score"),
... hp.Metric("loss", group="train", display_name="training loss"),
... ],
... )
You can continue to pass a string-keyed dict to the Keras callback or
the `hparams` function, or you can use `HParam` objects as the keys. The
latter approach enables better static analysis: your favorite Python
linter can tell you if you misspell a hyperparameter name, your IDE can
help you find all the places where a hyperparameter is used, etc:
>>> hparams = {
... HP_OPTIMIZER: "adam",
... HP_FC_DROPOUT: 0.2,
... HP_NEURONS: 128,
... # ...
... }
>>>
>>> model = model_fn(hparams)
>>> callbacks = [
>>> tf.keras.callbacks.TensorBoard(logdir),
>>> hp.KerasCallback(logdir, hparams),
>>> ]
Finally, you can choose to annotate your hparam definitions with domain
information:
>>> HP_OPTIMIZER = hp.HParam("optimizer", hp.Discrete(["adam", "sgd"]))
>>> HP_FC_DROPOUT = hp.HParam("fc_dropout", hp.RealInterval(0.1, 0.4))
>>> HP_NEURONS = hp.HParam("neurons", hp.IntInterval(64, 256))
The TensorBoard HParams plugin does not provide tuners, but you can
integrate these domains into your preferred tuning framework if you so
desire. The domains will also be reflected in the TensorBoard UI.
See the `Experiment`, `HParam`, `Metric`, and `KerasCallback` classes
for API specifications. Consult the `hparams_demo.py` script in the
TensorBoard repository for an end-to-end MNIST example.
"""
from tensorboard.plugins.hparams import _keras
from tensorboard.plugins.hparams import summary_v2
Discrete = summary_v2.Discrete
Domain = summary_v2.Domain
HParam = summary_v2.HParam
IntInterval = summary_v2.IntInterval
Metric = summary_v2.Metric
RealInterval = summary_v2.RealInterval
hparams = summary_v2.hparams
hparams_pb = summary_v2.hparams_pb
hparams_config = summary_v2.hparams_config
hparams_config_pb = summary_v2.hparams_config_pb
KerasCallback = _keras.Callback
del _keras
del summary_v2
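# An illustrative sketch (added for this listing; hedged, with placeholder
# paths and values) tying the APIs above together in a tiny sweep: write the
# experiment configuration once at the top-level log directory, then one
# hparams summary plus a stand-in metric per trial in its own subdirectory.
if __name__ == "__main__":
    import tensorflow as tf

    base_logdir = "/tmp/hparams_api_demo"
    HP_LR = HParam("learning_rate", RealInterval(1e-4, 1e-1))
    with tf.summary.create_file_writer(base_logdir).as_default():
        hparams_config(
            hparams=[HP_LR],
            metrics=[Metric("loss", display_name="training loss")],
        )
    for i, lr in enumerate([1e-3, 1e-2]):
        run_dir = "%s/run-%d" % (base_logdir, i)
        with tf.summary.create_file_writer(run_dir).as_default():
            hparams({HP_LR: lr})  # record this trial's hyperparameters
            tf.summary.scalar("loss", 1.0 / (i + 1), step=1)  # stand-in metric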

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,607 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wraps the base_plugin.TBContext to store additional data shared across API
handlers for the HParams plugin backend."""
import collections
import os
from tensorboard.data import provider
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import json_format_compat
from tensorboard.plugins.hparams import metadata
from google.protobuf import json_format
from tensorboard.plugins.scalar import metadata as scalar_metadata
_DISCRETE_DOMAIN_TYPE_TO_DATA_TYPE = {
provider.HyperparameterDomainType.DISCRETE_BOOL: api_pb2.DATA_TYPE_BOOL,
provider.HyperparameterDomainType.DISCRETE_FLOAT: api_pb2.DATA_TYPE_FLOAT64,
provider.HyperparameterDomainType.DISCRETE_STRING: api_pb2.DATA_TYPE_STRING,
}
class Context:
"""Wraps the base_plugin.TBContext to store additional data shared across
API handlers for the HParams plugin backend.
Before adding fields to this class, carefully consider whether the
field truly needs to be accessible to all API handlers or if it
can be passed separately to the handler constructor. We want to
avoid this class becoming a magic container of variables that have
no better place. See http://wiki.c2.com/?MagicContainer
"""
def __init__(self, tb_context):
"""Instantiates a context.
Args:
tb_context: base_plugin.TBContext. The "base" context we extend.
"""
self._tb_context = tb_context
def experiment_from_metadata(
self,
ctx,
experiment_id,
include_metrics,
hparams_run_to_tag_to_content,
data_provider_hparams,
hparams_limit=None,
):
"""Returns the experiment proto defining the experiment.
This method first attempts to find a metadata.EXPERIMENT_TAG tag and
retrieve the associated proto.
If no such tag is found, the method will attempt to build a minimal
experiment proto by scanning for all metadata.SESSION_START_INFO_TAG
tags (to compute the hparam_infos field of the experiment) and for all
scalar tags (to compute the metric_infos field of the experiment).
If neither metadata.EXPERIMENT_TAG nor metadata.SESSION_START_INFO_TAG tags
are found, this method will build an experiment proto using the results from
DataProvider.list_hyperparameters().
Args:
experiment_id: String, from `plugin_util.experiment_id`.
include_metrics: Whether to determine metrics_infos and include them
in the result.
hparams_run_to_tag_to_content: The output from an hparams_metadata()
call. A dict `d` such that `d[run][tag]` is a `bytes` value with the
summary metadata content for the keyed time series.
data_provider_hparams: The output from an hparams_from_data_provider()
call, corresponding to DataProvider.list_hyperparameters().
A provider.ListHyperparametersResult.
hparams_limit: Optional number of hyperparameter metadata to include in the
result. If unset or zero, all metadata will be included.
Returns:
The experiment proto. If no data is found for an experiment proto to
be built, returns an entirely empty experiment.
"""
experiment = self._find_experiment_tag(
hparams_run_to_tag_to_content, include_metrics
)
if experiment:
_sort_and_reduce_to_hparams_limit(experiment, hparams_limit)
return experiment
experiment_from_runs = self._compute_experiment_from_runs(
ctx, experiment_id, include_metrics, hparams_run_to_tag_to_content
)
if experiment_from_runs:
_sort_and_reduce_to_hparams_limit(
experiment_from_runs, hparams_limit
)
return experiment_from_runs
experiment_from_data_provider_hparams = (
self._experiment_from_data_provider_hparams(
ctx, experiment_id, include_metrics, data_provider_hparams
)
)
return (
experiment_from_data_provider_hparams
if experiment_from_data_provider_hparams
else api_pb2.Experiment()
)
@property
def tb_context(self):
return self._tb_context
def _convert_plugin_metadata(self, data_provider_output):
return {
run: {
tag: time_series.plugin_content
for (tag, time_series) in tag_to_time_series.items()
}
for (run, tag_to_time_series) in data_provider_output.items()
}
def hparams_metadata(self, ctx, experiment_id, run_tag_filter=None):
"""Reads summary metadata for all hparams time series.
Args:
experiment_id: String, from `plugin_util.experiment_id`.
run_tag_filter: Optional `data.provider.RunTagFilter`, with
the semantics as in `list_tensors`.
Returns:
A dict `d` such that `d[run][tag]` is a `bytes` value with the
summary metadata content for the keyed time series.
"""
return self._convert_plugin_metadata(
self._tb_context.data_provider.list_tensors(
ctx,
experiment_id=experiment_id,
plugin_name=metadata.PLUGIN_NAME,
run_tag_filter=run_tag_filter,
)
)
def scalars_metadata(self, ctx, experiment_id):
"""Reads summary metadata for all scalar time series.
Args:
experiment_id: String, from `plugin_util.experiment_id`.
Returns:
A dict `d` such that `d[run][tag]` is a `bytes` value with the
summary metadata content for the keyed time series.
"""
return self._convert_plugin_metadata(
self._tb_context.data_provider.list_scalars(
ctx,
experiment_id=experiment_id,
plugin_name=scalar_metadata.PLUGIN_NAME,
)
)
def read_last_scalars(self, ctx, experiment_id, run_tag_filter):
"""Reads the most recent values from scalar time series.
Args:
experiment_id: String.
run_tag_filter: Required `data.provider.RunTagFilter`, with
the semantics as in `read_last_scalars`.
Returns:
A dict `d` such that `d[run][tag]` is a `provider.ScalarDatum`
value, with keys only for runs and tags that actually had
data, which may be a subset of what was requested.
"""
return self._tb_context.data_provider.read_last_scalars(
ctx,
experiment_id=experiment_id,
plugin_name=scalar_metadata.PLUGIN_NAME,
run_tag_filter=run_tag_filter,
)
def hparams_from_data_provider(self, ctx, experiment_id, limit):
"""Calls DataProvider.list_hyperparameters() and returns the result."""
return self._tb_context.data_provider.list_hyperparameters(
ctx, experiment_ids=[experiment_id], limit=limit
)
def session_groups_from_data_provider(
self, ctx, experiment_id, filters, sort, hparams_to_include
):
"""Calls DataProvider.read_hyperparameters() and returns the result."""
return self._tb_context.data_provider.read_hyperparameters(
ctx,
experiment_ids=[experiment_id],
filters=filters,
sort=sort,
hparams_to_include=hparams_to_include,
)
def _find_experiment_tag(
self, hparams_run_to_tag_to_content, include_metrics
):
"""Finds the experiment associated with the metadata.EXPERIMENT_TAG
tag.
Returns:
The experiment or None if no such experiment is found.
"""
# We expect only one run to have an `EXPERIMENT_TAG`; look
# through all of them and arbitrarily pick the first one.
for tags in hparams_run_to_tag_to_content.values():
maybe_content = tags.get(metadata.EXPERIMENT_TAG)
if maybe_content is not None:
experiment = metadata.parse_experiment_plugin_data(
maybe_content
)
if not include_metrics:
# metric_infos haven't technically been "calculated" in this
# case. They have been read directly from the Experiment
# proto.
# Delete them from the result so that they are not returned
# to the client.
experiment.ClearField("metric_infos")
return experiment
return None
def _compute_experiment_from_runs(
self, ctx, experiment_id, include_metrics, hparams_run_to_tag_to_content
):
"""Computes a minimal Experiment protocol buffer by scanning the runs.
Returns None if there are no hparam infos logged.
"""
hparam_infos = self._compute_hparam_infos(hparams_run_to_tag_to_content)
metric_infos = (
self._compute_metric_infos_from_runs(
ctx, experiment_id, hparams_run_to_tag_to_content
)
if hparam_infos and include_metrics
else []
)
if not hparam_infos and not metric_infos:
return None
return api_pb2.Experiment(
hparam_infos=hparam_infos, metric_infos=metric_infos
)
def _compute_hparam_infos(self, hparams_run_to_tag_to_content):
"""Computes a list of api_pb2.HParamInfo from the current run, tag
info.
Finds all the SessionStartInfo messages and collects the hparams values
appearing in each one. For each hparam, attempts to deduce a type that fits
all its values. Finally, sets the 'domain' of the resulting HParamInfo
to be discrete if the type is string or boolean.
Returns:
A list of api_pb2.HParamInfo messages.
"""
# Construct a dict mapping an hparam name to its list of values.
hparams = collections.defaultdict(list)
for tag_to_content in hparams_run_to_tag_to_content.values():
if metadata.SESSION_START_INFO_TAG not in tag_to_content:
continue
start_info = metadata.parse_session_start_info_plugin_data(
tag_to_content[metadata.SESSION_START_INFO_TAG]
)
for name, value in start_info.hparams.items():
hparams[name].append(value)
# Try to construct an HParamInfo for each hparam from its name and list
# of values.
result = []
for name, values in hparams.items():
hparam_info = self._compute_hparam_info_from_values(name, values)
if hparam_info is not None:
result.append(hparam_info)
return result
def _compute_hparam_info_from_values(self, name, values):
"""Builds an HParamInfo message from the hparam name and list of
values.
Args:
name: string. The hparam name.
values: list of google.protobuf.Value messages. The list of values for the
hparam.
Returns:
An api_pb2.HParamInfo message.
"""
# Figure out the type from the values.
# Ignore values whose type is not listed in api_pb2.DataType
# If all values have the same type, then that is the type used.
# Otherwise, the returned type is DATA_TYPE_STRING.
result = api_pb2.HParamInfo(name=name, type=api_pb2.DATA_TYPE_UNSET)
for v in values:
v_type = _protobuf_value_type(v)
if not v_type:
continue
if result.type == api_pb2.DATA_TYPE_UNSET:
result.type = v_type
elif result.type != v_type:
result.type = api_pb2.DATA_TYPE_STRING
if result.type == api_pb2.DATA_TYPE_STRING:
# A string result.type does not change, so we can exit the loop.
break
# If we couldn't figure out a type, then we can't compute the hparam_info.
if result.type == api_pb2.DATA_TYPE_UNSET:
return None
if result.type == api_pb2.DATA_TYPE_STRING:
distinct_string_values = set(
_protobuf_value_to_string(v)
for v in values
if _can_be_converted_to_string(v)
)
result.domain_discrete.extend(distinct_string_values)
result.differs = len(distinct_string_values) > 1
if result.type == api_pb2.DATA_TYPE_BOOL:
distinct_bool_values = set(v.bool_value for v in values)
result.domain_discrete.extend(distinct_bool_values)
result.differs = len(distinct_bool_values) > 1
if result.type == api_pb2.DATA_TYPE_FLOAT64:
# Always uses interval domain type for numeric hparam values.
distinct_float_values = sorted([v.number_value for v in values])
if distinct_float_values:
result.domain_interval.min_value = distinct_float_values[0]
result.domain_interval.max_value = distinct_float_values[-1]
result.differs = len(set(distinct_float_values)) > 1
return result
def _experiment_from_data_provider_hparams(
self,
ctx,
experiment_id,
include_metrics,
data_provider_hparams,
):
"""Returns an experiment protocol buffer based on data provider hparams.
Args:
experiment_id: String, from `plugin_util.experiment_id`.
include_metrics: Whether to determine metrics_infos and include them
in the result.
data_provider_hparams: The output from an hparams_from_data_provider()
call, corresponding to DataProvider.list_hyperparameters().
A provider.ListHyperparametersResult.
Returns:
The experiment proto. If there are no hyperparameters in the input,
returns None.
"""
if isinstance(data_provider_hparams, list):
# TODO: Support old return value of Collection[provider.Hyperparameters]
# until all internal implementations of DataProvider can be
# migrated to use new return value of provider.ListHyperparametersResult.
hyperparameters = data_provider_hparams
session_groups = []
else:
# Is instance of provider.ListHyperparametersResult
hyperparameters = data_provider_hparams.hyperparameters
session_groups = data_provider_hparams.session_groups
hparam_infos = [
self._convert_data_provider_hparam(dp_hparam)
for dp_hparam in hyperparameters
]
metric_infos = (
self.compute_metric_infos_from_data_provider_session_groups(
ctx, experiment_id, session_groups
)
if include_metrics
else []
)
return api_pb2.Experiment(
hparam_infos=hparam_infos, metric_infos=metric_infos
)
def _convert_data_provider_hparam(self, dp_hparam):
"""Builds an HParamInfo message from data provider Hyperparameter.
Args:
dp_hparam: The provider.Hyperparameter returned by the call to
provider.DataProvider.list_hyperparameters().
Returns:
An HParamInfo to include in the Experiment.
"""
hparam_info = api_pb2.HParamInfo(
name=dp_hparam.hyperparameter_name,
display_name=dp_hparam.hyperparameter_display_name,
differs=dp_hparam.differs,
)
if dp_hparam.domain_type == provider.HyperparameterDomainType.INTERVAL:
hparam_info.type = api_pb2.DATA_TYPE_FLOAT64
(dp_hparam_min, dp_hparam_max) = dp_hparam.domain
hparam_info.domain_interval.min_value = dp_hparam_min
hparam_info.domain_interval.max_value = dp_hparam_max
elif dp_hparam.domain_type in _DISCRETE_DOMAIN_TYPE_TO_DATA_TYPE.keys():
hparam_info.type = _DISCRETE_DOMAIN_TYPE_TO_DATA_TYPE.get(
dp_hparam.domain_type
)
hparam_info.domain_discrete.extend(dp_hparam.domain)
return hparam_info
def _compute_metric_infos_from_runs(
self, ctx, experiment_id, hparams_run_to_tag_to_content
):
session_runs = set(
run
for run, tags in hparams_run_to_tag_to_content.items()
if metadata.SESSION_START_INFO_TAG in tags
)
return (
api_pb2.MetricInfo(name=api_pb2.MetricName(group=group, tag=tag))
for tag, group in self._compute_metric_names(
ctx, experiment_id, session_runs
)
)
def compute_metric_infos_from_data_provider_session_groups(
self, ctx, experiment_id, session_groups
):
session_runs = set(
generate_data_provider_session_name(s)
for sg in session_groups
for s in sg.sessions
)
return [
api_pb2.MetricInfo(name=api_pb2.MetricName(group=group, tag=tag))
for tag, group in self._compute_metric_names(
ctx, experiment_id, session_runs
)
]
def _compute_metric_names(self, ctx, experiment_id, session_runs):
"""Computes the list of metric names from all the scalar (run, tag)
pairs.
The return value is a list of (tag, group) pairs representing the metric
names. The list is sorted in Python tuple-order (lexicographical).
For example, if the scalar (run, tag) pairs are:
("exp/session1", "loss")
("exp/session2", "loss")
("exp/session2/eval", "loss")
("exp/session2/validation", "accuracy")
("exp/no-session", "loss_2"),
and the runs corresponding to sessions are "exp/session1", "exp/session2",
this method will return [("loss", ""), ("loss", "/eval"), ("accuracy",
"/validation")]
More precisely, each scalar (run, tag) pair is converted to a (tag, group)
metric name, where group is the suffix of run formed by removing the
longest prefix which is a session run. If no session run is a prefix of
'run', the pair is skipped.
Returns:
A python list containing pairs. Each pair is a (tag, group) pair
representing a metric name used in some session.
"""
metric_names_set = set()
scalars_run_to_tag_to_content = self.scalars_metadata(
ctx, experiment_id
)
for run, tags in scalars_run_to_tag_to_content.items():
session = _find_longest_parent_path(session_runs, run)
if session is None:
continue
group = os.path.relpath(run, session)
# relpath() returns "." for the 'session' directory; use an empty string
# instead, unless the run name actually ends with ".".
if group == "." and not run.endswith("."):
group = ""
metric_names_set.update((tag, group) for tag in tags)
metric_names_list = list(metric_names_set)
# Sort metrics for determinism.
metric_names_list.sort()
return metric_names_list
def generate_data_provider_session_name(session):
"""Generates a name from a HyperparameterSessionRun.
If the HyperparameterSessionRun contains no experiment or run information
then the name is set to the original experiment_id.
"""
if not session.experiment_id and not session.run:
return ""
elif not session.experiment_id:
return session.run
elif not session.run:
return session.experiment_id
else:
return f"{session.experiment_id}/{session.run}"
def _find_longest_parent_path(path_set, path):
"""Finds the longest "parent-path" of 'path' in 'path_set'.
This function takes and returns "path-like" strings, i.e. strings made of
segments separated by os.sep. No file access is performed here, so these
strings need not correspond to actual files in some file-system.
This function returns the longest element of 'path_set' that is an ancestor
path of 'path'. For example, for path_set=["/foo/bar", "/foo", "/bar/foo"]
and path="/foo/bar/sub_dir", returns "/foo/bar".
Args:
path_set: set of path-like strings -- e.g. a list of strings separated by
os.sep. No actual disk-access is performed here, so these need not
correspond to actual files.
path: a path-like string.
Returns:
The element in path_set which is the longest parent directory of 'path'.
"""
# This could likely be more efficiently implemented with a trie
# data-structure, but we don't want to add an extra dependency for that.
while path not in path_set:
if not path:
return None
path = os.path.dirname(path)
return path
def _can_be_converted_to_string(value):
if not _protobuf_value_type(value):
return False
return json_format_compat.is_serializable_value(value)
def _protobuf_value_type(value):
"""Returns the type of the google.protobuf.Value message as an
api.DataType.
Returns None if the type of 'value' is not one of the types supported in
api_pb2.DataType.
Args:
value: google.protobuf.Value message.
"""
if value.HasField("number_value"):
return api_pb2.DATA_TYPE_FLOAT64
if value.HasField("string_value"):
return api_pb2.DATA_TYPE_STRING
if value.HasField("bool_value"):
return api_pb2.DATA_TYPE_BOOL
return None
def _protobuf_value_to_string(value):
"""Returns a string representation of given google.protobuf.Value message.
Args:
value: google.protobuf.Value message. Assumed to be of type 'number',
'string' or 'bool'.
"""
value_in_json = json_format.MessageToJson(value)
if value.HasField("string_value"):
# Remove the quotations.
return value_in_json[1:-1]
return value_in_json
def _sort_and_reduce_to_hparams_limit(experiment, hparams_limit=None):
"""Sorts and applies limit to the hparams in the given experiment proto.
Args:
experiment: An api_pb2.Experiment proto, which will be modified in place.
hparams_limit: Optional number of hyperparameter metadata to include in the
result. If unset or zero, no limit will be applied.
Returns:
None. `experiment` proto will be modified in place.
"""
if not hparams_limit:
# If limit is unset or zero, returns all hparams.
hparams_limit = len(experiment.hparam_infos)
# Prioritizes returning HParamInfo protos whose `differs` field is True.
# Sorts by `differs` (True first), then by name.
limited_hparam_infos = sorted(
experiment.hparam_infos,
key=lambda hparam_info: (not hparam_info.differs, hparam_info.name),
)[:hparams_limit]
experiment.ClearField("hparam_infos")
experiment.hparam_infos.extend(limited_hparam_infos)
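# A small illustrative sketch (not used by the plugin) of the metric-name
# grouping performed in `Context._compute_metric_names` above, reusing the
# example run names from its docstring. The group of a scalar run is its
# path suffix relative to the longest matching session run.
if __name__ == "__main__":
    session_runs = {"exp/session1", "exp/session2"}
    scalar_run_tag_pairs = [
        ("exp/session1", "loss"),
        ("exp/session2", "loss"),
        ("exp/session2/eval", "loss"),
        ("exp/session2/validation", "accuracy"),
        ("exp/no-session", "loss_2"),  # no session prefix, so it is skipped
    ]
    metric_names = set()
    for run, tag in scalar_run_tag_pairs:
        session = _find_longest_parent_path(session_runs, run)
        if session is None:
            continue
        group = os.path.relpath(run, session)
        if group == "." and not run.endswith("."):
            group = ""
        metric_names.add((tag, group))
    print(sorted(metric_names))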

View File

@ -0,0 +1,159 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions for handling the DownloadData API call."""
import csv
import io
import math
from tensorboard.plugins.hparams import error
class OutputFormat:
"""An enum used to list the valid output formats for API calls."""
JSON = "json"
CSV = "csv"
LATEX = "latex"
class Handler:
"""Handles a DownloadData request."""
def __init__(
self,
context,
experiment,
session_groups,
response_format,
columns_visibility,
):
"""Constructor.
Args:
context: A backend_context.Context instance.
experiment: Experiment proto.
session_groups: ListSessionGroupsResponse proto.
response_format: A string in the OutputFormat enum.
columns_visibility: A list of boolean values to filter columns.
"""
self._context = context
self._experiment = experiment
self._session_groups = session_groups
self._response_format = response_format
self._columns_visibility = columns_visibility
def run(self):
"""Handles the request specified on construction.
Returns:
A response body.
A mime type (string) for the response.
"""
experiment = self._experiment
session_groups = self._session_groups
response_format = self._response_format
visibility = self._columns_visibility
header = []
for hparam_info in experiment.hparam_infos:
header.append(hparam_info.display_name or hparam_info.name)
for metric_info in experiment.metric_infos:
header.append(metric_info.display_name or metric_info.name.tag)
def _filter_columns(row):
return [value for value, visible in zip(row, visibility) if visible]
header = _filter_columns(header)
rows = []
def _get_value(value):
if value.HasField("number_value"):
return value.number_value
if value.HasField("string_value"):
return value.string_value
if value.HasField("bool_value"):
return value.bool_value
# hyperparameter values can be optional in a session group
return ""
def _get_metric_id(metric):
return metric.group + "." + metric.tag
for group in session_groups.session_groups:
row = []
for hparam_info in experiment.hparam_infos:
row.append(_get_value(group.hparams[hparam_info.name]))
metric_values = {}
for metric_value in group.metric_values:
metric_id = _get_metric_id(metric_value.name)
metric_values[metric_id] = metric_value.value
for metric_info in experiment.metric_infos:
metric_id = _get_metric_id(metric_info.name)
row.append(metric_values.get(metric_id))
rows.append(_filter_columns(row))
if response_format == OutputFormat.JSON:
mime_type = "application/json"
body = dict(header=header, rows=rows)
elif response_format == OutputFormat.LATEX:
def latex_format(value):
if value is None:
return "-"
elif isinstance(value, int):
return "$%d$" % value
elif isinstance(value, float):
if math.isnan(value):
return r"$\mathrm{NaN}$"
if value in (float("inf"), float("-inf")):
return r"$%s\infty$" % ("-" if value < 0 else "+")
scientific = "%.3g" % value
if "e" in scientific:
coefficient, exponent = scientific.split("e")
return "$%s\\cdot 10^{%d}$" % (
coefficient,
int(exponent),
)
return "$%s$" % scientific
return value.replace("_", "\\_").replace("%", "\\%")
mime_type = "application/x-latex"
top_part = "\\begin{table}[tbp]\n\\begin{tabular}{%s}\n" % (
"l" * len(header)
)
header_part = (
" & ".join(map(latex_format, header)) + " \\\\ \\hline\n"
)
middle_part = "".join(
" & ".join(map(latex_format, row)) + " \\\\\n" for row in rows
)
bottom_part = "\\hline\n\\end{tabular}\n\\end{table}\n"
body = top_part + header_part + middle_part + bottom_part
elif response_format == OutputFormat.CSV:
string_io = io.StringIO()
writer = csv.writer(string_io)
writer.writerow(header)
writer.writerows(rows)
body = string_io.getvalue()
mime_type = "text/csv"
else:
raise error.HParamsError(
"Invalid response format: %s" % response_format
)
return body, mime_type
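# An illustrative stand-alone sketch (mirroring, not reusing, the nested
# `latex_format` helper above) of how float metric values are rendered for
# the LaTeX download format: format with "%.3g" first, then rewrite any
# scientific notation as an explicit power of ten.
if __name__ == "__main__":

    def _format_float_for_latex(value):
        scientific = "%.3g" % value
        if "e" in scientific:
            coefficient, exponent = scientific.split("e")
            return "$%s\\cdot 10^{%d}$" % (coefficient, int(exponent))
        return "$%s$" % scientific

    print(_format_float_for_latex(1.23456e-05))  # $1.23\cdot 10^{-5}$
    print(_format_float_for_latex(3.14159))  # $3.14$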

View File

@ -0,0 +1,26 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines an error (exception) class for the HParams plugin."""
class HParamsError(Exception):
"""Represents an error that is meaningful to the end-user.
Such an error should have a meaningful error message. Other errors
(such as those resulting from internal invariants being violated)
should be represented by other exceptions.
"""
pass

View File

@ -0,0 +1,71 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions for handling the GetExperiment API call."""
from tensorboard.plugins.hparams import api_pb2
class Handler:
"""Handles a GetExperiment request."""
def __init__(
self, request_context, backend_context, experiment_id, request
):
"""Constructor.
Args:
request_context: A tensorboard.context.RequestContext.
backend_context: A backend_context.Context instance.
experiment_id: A string, as from `plugin_util.experiment_id`.
request: A request proto.
"""
self._request_context = request_context
self._backend_context = backend_context
self._experiment_id = experiment_id
self._include_metrics = (
# Metrics are included by default if include_metrics is not
# specified in the request.
not request.HasField("include_metrics")
or request.include_metrics
)
self._hparams_limit = (
request.hparams_limit
if isinstance(request, api_pb2.GetExperimentRequest)
else None
)
def run(self):
"""Handles the request specified on construction.
Returns:
An Experiment object.
"""
data_provider_hparams = (
self._backend_context.hparams_from_data_provider(
self._request_context,
self._experiment_id,
limit=self._hparams_limit,
)
)
return self._backend_context.experiment_from_metadata(
self._request_context,
self._experiment_id,
self._include_metrics,
self._backend_context.hparams_metadata(
self._request_context, self._experiment_id
),
data_provider_hparams,
self._hparams_limit,
)

View File

@ -0,0 +1,207 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard HParams plugin.
See `http_api.md` in this directory for specifications of the routes for
this plugin.
"""
import json
import werkzeug
from werkzeug import wrappers
from tensorboard import plugin_util
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import backend_context
from tensorboard.plugins.hparams import download_data
from tensorboard.plugins.hparams import error
from tensorboard.plugins.hparams import get_experiment
from tensorboard.plugins.hparams import list_metric_evals
from tensorboard.plugins.hparams import list_session_groups
from tensorboard.plugins.hparams import metadata
from google.protobuf import json_format
from tensorboard.backend import http_util
from tensorboard.plugins import base_plugin
from tensorboard.plugins.scalar import metadata as scalars_metadata
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
class HParamsPlugin(base_plugin.TBPlugin):
"""HParams Plugin for TensorBoard.
It supports both GETs and POSTs. See 'http_api.md' for more details.
"""
plugin_name = metadata.PLUGIN_NAME
def __init__(self, context):
"""Instantiates HParams plugin via TensorBoard core.
Args:
context: A base_plugin.TBContext instance.
"""
self._context = backend_context.Context(context)
def get_plugin_apps(self):
"""See base class."""
return {
"/download_data": self.download_data_route,
"/experiment": self.get_experiment_route,
"/session_groups": self.list_session_groups_route,
"/metric_evals": self.list_metric_evals_route,
}
def is_active(self):
return False # `list_plugins` as called by TB core suffices
def frontend_metadata(self):
return base_plugin.FrontendMetadata(element_name="tf-hparams-dashboard")
# ---- /download_data --------------------------------------------------------
@wrappers.Request.application
def download_data_route(self, request):
ctx = plugin_util.context(request.environ)
experiment_id = plugin_util.experiment_id(request.environ)
try:
response_format = request.args.get("format")
columns_visibility = json.loads(
request.args.get("columnsVisibility")
)
request_proto = _parse_request_argument(
request, api_pb2.ListSessionGroupsRequest
)
session_groups = list_session_groups.Handler(
ctx, self._context, experiment_id, request_proto
).run()
experiment = get_experiment.Handler(
ctx, self._context, experiment_id, request_proto
).run()
body, mime_type = download_data.Handler(
self._context,
experiment,
session_groups,
response_format,
columns_visibility,
).run()
return http_util.Respond(request, body, mime_type)
except error.HParamsError as e:
logger.error("HParams error: %s" % e)
raise werkzeug.exceptions.BadRequest(description=str(e))
# ---- /experiment -----------------------------------------------------------
@wrappers.Request.application
def get_experiment_route(self, request):
ctx = plugin_util.context(request.environ)
experiment_id = plugin_util.experiment_id(request.environ)
try:
request_proto = _parse_request_argument(
request, api_pb2.GetExperimentRequest
)
response_proto = get_experiment.Handler(
ctx,
self._context,
experiment_id,
request_proto,
).run()
response = plugin_util.proto_to_json(response_proto)
return http_util.Respond(
request,
response,
"application/json",
)
except error.HParamsError as e:
logger.error("HParams error: %s" % e)
raise werkzeug.exceptions.BadRequest(description=str(e))
# ---- /session_groups -------------------------------------------------------
@wrappers.Request.application
def list_session_groups_route(self, request):
ctx = plugin_util.context(request.environ)
experiment_id = plugin_util.experiment_id(request.environ)
try:
request_proto = _parse_request_argument(
request, api_pb2.ListSessionGroupsRequest
)
response_proto = list_session_groups.Handler(
ctx,
self._context,
experiment_id,
request_proto,
).run()
response = plugin_util.proto_to_json(response_proto)
return http_util.Respond(
request,
response,
"application/json",
)
except error.HParamsError as e:
logger.error("HParams error: %s" % e)
raise werkzeug.exceptions.BadRequest(description=str(e))
# ---- /metric_evals ---------------------------------------------------------
@wrappers.Request.application
def list_metric_evals_route(self, request):
ctx = plugin_util.context(request.environ)
experiment_id = plugin_util.experiment_id(request.environ)
try:
request_proto = _parse_request_argument(
request, api_pb2.ListMetricEvalsRequest
)
scalars_plugin = self._get_scalars_plugin()
if not scalars_plugin:
raise werkzeug.exceptions.NotFound("Scalars plugin not loaded")
return http_util.Respond(
request,
list_metric_evals.Handler(
ctx, request_proto, scalars_plugin, experiment_id
).run(),
"application/json",
)
except error.HParamsError as e:
logger.error("HParams error: %s" % e)
raise werkzeug.exceptions.BadRequest(description=str(e))
def _get_scalars_plugin(self):
"""Tries to get the scalars plugin.
Returns:
The scalars plugin or None if it is not yet registered.
"""
return self._context.tb_context.plugin_name_to_instance.get(
scalars_metadata.PLUGIN_NAME
)
def _parse_request_argument(request, proto_class):
request_json = (
request.data
if request.method == "POST"
else request.args.get("request")
)
try:
return json_format.Parse(request_json, proto_class())
# if request_json is None, json_format.Parse will throw an AttributeError:
# 'NoneType' object has no attribute 'decode'.
except (AttributeError, json_format.ParseError) as e:
raise error.HParamsError(
"Expected a JSON-formatted request data of type: {}, but got {} ".format(
proto_class, request_json
)
) from e
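# A small self-contained sketch of the parsing step above: for GET requests
# the `request` query argument carries a JSON-encoded proto, which
# json_format.Parse converts into the typed message. The empty JSON payload
# here is a hypothetical example of a valid GetExperimentRequest.
def _example_parse_get_experiment_request():
    request_json = "{}"
    return json_format.Parse(request_json, api_pb2.GetExperimentRequest())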

View File

@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: tensorboard/plugins/hparams/hparams_util.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
from tensorboard.plugins.hparams import api_pb2 as tensorboard_dot_plugins_dot_hparams_dot_api__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n.tensorboard/plugins/hparams/hparams_util.proto\x12\x13tensorboard.hparams\x1a\x1cgoogle/protobuf/struct.proto\x1a%tensorboard/plugins/hparams/api.proto\"H\n\x0fHParamInfosList\x12\x35\n\x0chparam_infos\x18\x01 \x03(\x0b\x32\x1f.tensorboard.hparams.HParamInfo\"H\n\x0fMetricInfosList\x12\x35\n\x0cmetric_infos\x18\x01 \x03(\x0b\x32\x1f.tensorboard.hparams.MetricInfo\"\x8d\x01\n\x07HParams\x12:\n\x07hparams\x18\x01 \x03(\x0b\x32).tensorboard.hparams.HParams.HparamsEntry\x1a\x46\n\x0cHparamsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.google.protobuf.Value:\x02\x38\x01\x62\x06proto3')
_HPARAMINFOSLIST = DESCRIPTOR.message_types_by_name['HParamInfosList']
_METRICINFOSLIST = DESCRIPTOR.message_types_by_name['MetricInfosList']
_HPARAMS = DESCRIPTOR.message_types_by_name['HParams']
_HPARAMS_HPARAMSENTRY = _HPARAMS.nested_types_by_name['HparamsEntry']
HParamInfosList = _reflection.GeneratedProtocolMessageType('HParamInfosList', (_message.Message,), {
'DESCRIPTOR' : _HPARAMINFOSLIST,
'__module__' : 'tensorboard.plugins.hparams.hparams_util_pb2'
# @@protoc_insertion_point(class_scope:tensorboard.hparams.HParamInfosList)
})
_sym_db.RegisterMessage(HParamInfosList)
MetricInfosList = _reflection.GeneratedProtocolMessageType('MetricInfosList', (_message.Message,), {
'DESCRIPTOR' : _METRICINFOSLIST,
'__module__' : 'tensorboard.plugins.hparams.hparams_util_pb2'
# @@protoc_insertion_point(class_scope:tensorboard.hparams.MetricInfosList)
})
_sym_db.RegisterMessage(MetricInfosList)
HParams = _reflection.GeneratedProtocolMessageType('HParams', (_message.Message,), {
'HparamsEntry' : _reflection.GeneratedProtocolMessageType('HparamsEntry', (_message.Message,), {
'DESCRIPTOR' : _HPARAMS_HPARAMSENTRY,
'__module__' : 'tensorboard.plugins.hparams.hparams_util_pb2'
# @@protoc_insertion_point(class_scope:tensorboard.hparams.HParams.HparamsEntry)
})
,
'DESCRIPTOR' : _HPARAMS,
'__module__' : 'tensorboard.plugins.hparams.hparams_util_pb2'
# @@protoc_insertion_point(class_scope:tensorboard.hparams.HParams)
})
_sym_db.RegisterMessage(HParams)
_sym_db.RegisterMessage(HParams.HparamsEntry)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_HPARAMS_HPARAMSENTRY._options = None
_HPARAMS_HPARAMSENTRY._serialized_options = b'8\001'
_HPARAMINFOSLIST._serialized_start=140
_HPARAMINFOSLIST._serialized_end=212
_METRICINFOSLIST._serialized_start=214
_METRICINFOSLIST._serialized_end=286
_HPARAMS._serialized_start=289
_HPARAMS._serialized_end=430
_HPARAMS_HPARAMSENTRY._serialized_start=360
_HPARAMS_HPARAMSENTRY._serialized_end=430
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,38 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import math
def is_serializable_value(value):
"""Returns whether a protobuf Value will be serializable by MessageToJson.
The json_format documentation states that "attempting to serialize NaN or
Infinity results in error."
https://protobuf.dev/reference/protobuf/google.protobuf/#value
Args:
value: A value of type protobuf.Value.
Returns:
True if the Value should be serializable without error by MessageToJson.
False otherwise.
"""
if not value.HasField("number_value"):
return True
number_value = value.number_value
return not math.isnan(number_value) and not math.isinf(number_value)
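# A quick illustration of the rule above (values are hypothetical): finite
# numbers and non-number values are considered serializable, while NaN and
# Infinity are not.
def _example_serializability_checks():
    from google.protobuf import struct_pb2

    finite = struct_pb2.Value(number_value=0.5)
    not_finite = struct_pb2.Value(number_value=float("nan"))
    text = struct_pb2.Value(string_value="ok")
    # Expected results: True, False, True.
    return (
        is_serializable_value(finite),
        is_serializable_value(not_finite),
        is_serializable_value(text),
    )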

View File

@ -0,0 +1,58 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions for handling the ListMetricEvals API call."""
from tensorboard.plugins.hparams import metrics
from tensorboard.plugins.scalar import scalars_plugin
class Handler:
"""Handles a ListMetricEvals request."""
def __init__(
self, request_context, request, scalars_plugin_instance, experiment
):
"""Constructor.
Args:
request_context: A tensorboard.context.RequestContext.
request: A ListMetricEvalsRequest protobuf.
scalars_plugin_instance: A scalars_plugin.ScalarsPlugin.
experiment: An experiment ID, as a possibly-empty `str`.
"""
self._request_context = request_context
self._request = request
self._scalars_plugin_instance = scalars_plugin_instance
self._experiment = experiment
def run(self):
"""Executes the request.
Returns:
An array of tuples representing the metric evaluations--each of the
form (<wall time in secs>, <training step>, <metric value>).
"""
run, tag = metrics.run_tag_from_session_and_metric(
self._request.session_name, self._request.metric_name
)
body, _ = self._scalars_plugin_instance.scalars_impl(
self._request_context,
tag,
run,
self._experiment,
scalars_plugin.OutputFormat.JSON,
)
return body
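# A minimal sketch of the session/metric to run/tag mapping used by run()
# above; the session and metric names are hypothetical, and the MetricName
# fields shown are assumed from the hparams api proto.
def _example_run_tag_mapping():
    from tensorboard.plugins.hparams import api_pb2

    metric_name = api_pb2.MetricName(group="validation", tag="loss")
    return metrics.run_tag_from_session_and_metric("session_1", metric_name)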

File diff suppressed because it is too large

View File

@ -0,0 +1,127 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constants used in the HParams plugin."""
from tensorboard.compat.proto import summary_pb2
from tensorboard.compat.proto import types_pb2
from tensorboard.plugins.hparams import error
from tensorboard.plugins.hparams import plugin_data_pb2
from tensorboard.util import tensor_util
PLUGIN_NAME = "hparams"
PLUGIN_DATA_VERSION = 0
# Tensor value for use in summaries that really only need to store
# metadata. A length-0 float vector is of minimal serialized length
# (6 bytes) among valid tensors. Cache this: computing it takes on the
# order of tens of microseconds.
NULL_TENSOR = tensor_util.make_tensor_proto(
[], dtype=types_pb2.DT_FLOAT, shape=(0,)
)
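# A quick sanity-check sketch of the size claim in the comment above (for
# illustration only): the serialized length of the cached null tensor should
# be small, on the order of the 6 bytes mentioned.
def _example_null_tensor_serialized_length():
    return len(NULL_TENSOR.SerializeToString())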
EXPERIMENT_TAG = "_hparams_/experiment"
SESSION_START_INFO_TAG = "_hparams_/session_start_info"
SESSION_END_INFO_TAG = "_hparams_/session_end_info"
def create_summary_metadata(hparams_plugin_data_pb):
"""Returns a summary metadata for the HParams plugin.
Returns a summary_pb2.SummaryMetadata holding a copy of the given
HParamsPluginData message in its plugin_data.content field.
Sets the version field of the hparams_plugin_data_pb copy to
PLUGIN_DATA_VERSION.
Args:
hparams_plugin_data_pb: the HParamsPluginData protobuffer to use.
"""
if not isinstance(
hparams_plugin_data_pb, plugin_data_pb2.HParamsPluginData
):
raise TypeError(
"Needed an instance of plugin_data_pb2.HParamsPluginData."
" Got: %s" % type(hparams_plugin_data_pb)
)
content = plugin_data_pb2.HParamsPluginData()
content.CopyFrom(hparams_plugin_data_pb)
content.version = PLUGIN_DATA_VERSION
return summary_pb2.SummaryMetadata(
plugin_data=summary_pb2.SummaryMetadata.PluginData(
plugin_name=PLUGIN_NAME, content=content.SerializeToString()
)
)
def parse_experiment_plugin_data(content):
"""Returns the experiment from HParam's
SummaryMetadata.plugin_data.content.
Raises HParamsError if the content doesn't have 'experiment' set or
this file is incompatible with the version of the metadata stored.
Args:
content: The SummaryMetadata.plugin_data.content to use.
"""
return _parse_plugin_data_as(content, "experiment")
def parse_session_start_info_plugin_data(content):
"""Returns session_start_info from the plugin_data.content.
Raises HParamsError if the content doesn't have 'session_start_info' set or
this file is incompatible with the version of the metadata stored.
Args:
content: The SummaryMetadata.plugin_data.content to use.
"""
return _parse_plugin_data_as(content, "session_start_info")
def parse_session_end_info_plugin_data(content):
"""Returns session_end_info from the plugin_data.content.
Raises HParamsError if the content doesn't have 'session_end_info' set or
this file is incompatible with the version of the metadata stored.
Args:
content: The SummaryMetadata.plugin_data.content to use.
"""
return _parse_plugin_data_as(content, "session_end_info")
def _parse_plugin_data_as(content, data_oneof_field):
"""Returns a data oneof's field from plugin_data.content.
Raises HParamsError if the content doesn't have 'data_oneof_field' set or
this file is incompatible with the version of the metadata stored.
Args:
content: The SummaryMetadata.plugin_data.content to use.
data_oneof_field: string. The name of the data oneof field to return.
"""
plugin_data = plugin_data_pb2.HParamsPluginData.FromString(content)
if plugin_data.version != PLUGIN_DATA_VERSION:
raise error.HParamsError(
"Only supports plugin_data version: %s; found: %s in: %s"
% (PLUGIN_DATA_VERSION, plugin_data.version, plugin_data)
)
if not plugin_data.HasField(data_oneof_field):
raise error.HParamsError(
"Expected plugin_data.%s to be set. Got: %s"
% (data_oneof_field, plugin_data)
)
return getattr(plugin_data, data_oneof_field)
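# A short end-to-end sketch using the helpers above (the hparam value is
# hypothetical): build an HParamsPluginData message, wrap it in summary
# metadata, then parse it back out of the serialized plugin content.
def _example_metadata_round_trip():
    plugin_data = plugin_data_pb2.HParamsPluginData()
    plugin_data.session_start_info.hparams["learning_rate"].number_value = 0.01
    summary_metadata = create_summary_metadata(plugin_data)
    return parse_session_start_info_plugin_data(
        summary_metadata.plugin_data.content
    )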

Some files were not shown because too many files have changed in this diff