I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,277 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data ingestion logic backed by local event processing."""
import os
import re
import threading
import time
from tensorboard.backend.event_processing import data_provider
from tensorboard.backend.event_processing import plugin_event_multiplexer
from tensorboard.backend.event_processing import tag_types
from tensorboard.compat import tf
from tensorboard.data import ingester
from tensorboard.plugins.audio import metadata as audio_metadata
from tensorboard.plugins.histogram import metadata as histogram_metadata
from tensorboard.plugins.image import metadata as image_metadata
from tensorboard.plugins.pr_curve import metadata as pr_curve_metadata
from tensorboard.plugins.scalar import metadata as scalar_metadata
from tensorboard.util import tb_logging
DEFAULT_SIZE_GUIDANCE = {
tag_types.TENSORS: 10,
}
# TODO(@wchargin): Replace with something that works for third-party plugins.
DEFAULT_TENSOR_SIZE_GUIDANCE = {
scalar_metadata.PLUGIN_NAME: 1000,
image_metadata.PLUGIN_NAME: 10,
audio_metadata.PLUGIN_NAME: 10,
histogram_metadata.PLUGIN_NAME: 500,
pr_curve_metadata.PLUGIN_NAME: 100,
}
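# Illustrative note: these defaults are merged with the `--samples_per_plugin`
# flag in `LocalDataIngester.__init__` below, so individual entries can be
# overridden from the command line, e.g. (hypothetical logdir):
#
#   tensorboard --logdir /tmp/logs --samples_per_plugin "scalars=500,images=0"
#
# where a value of 0 is conventionally treated as "keep all samples".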
logger = tb_logging.get_logger()
class LocalDataIngester(ingester.DataIngester):
"""Data ingestion implementation to use when running locally."""
def __init__(self, flags):
"""Initializes a `LocalDataIngester` from `flags`.
Args:
flags: An argparse.Namespace containing TensorBoard CLI flags.
Returns:
The new `LocalDataIngester`.
"""
tensor_size_guidance = dict(DEFAULT_TENSOR_SIZE_GUIDANCE)
tensor_size_guidance.update(flags.samples_per_plugin)
self._multiplexer = plugin_event_multiplexer.EventMultiplexer(
size_guidance=DEFAULT_SIZE_GUIDANCE,
tensor_size_guidance=tensor_size_guidance,
purge_orphaned_data=flags.purge_orphaned_data,
max_reload_threads=flags.max_reload_threads,
event_file_active_filter=_get_event_file_active_filter(flags),
detect_file_replacement=flags.detect_file_replacement,
)
self._data_provider = data_provider.MultiplexerDataProvider(
self._multiplexer, flags.logdir or flags.logdir_spec
)
self._reload_interval = flags.reload_interval
self._reload_task = flags.reload_task
if flags.logdir:
self._path_to_run = {os.path.expanduser(flags.logdir): None}
else:
self._path_to_run = _parse_event_files_spec(flags.logdir_spec)
# Conditionally import tensorflow_io.
if getattr(tf, "__version__", "stub") != "stub":
_check_filesystem_support(self._path_to_run.keys())
@property
def data_provider(self):
return self._data_provider
@property
def deprecated_multiplexer(self):
return self._multiplexer
def start(self):
"""Starts ingesting data based on the ingester flag configuration."""
def _reload():
while True:
start = time.time()
logger.info("TensorBoard reload process beginning")
for path, name in self._path_to_run.items():
self._multiplexer.AddRunsFromDirectory(path, name)
logger.info(
"TensorBoard reload process: Reload the whole Multiplexer"
)
self._multiplexer.Reload()
duration = time.time() - start
logger.info(
"TensorBoard done reloading. Load took %0.3f secs", duration
)
if self._reload_interval == 0:
# Only load the multiplexer once. Do not continuously reload.
break
time.sleep(self._reload_interval)
if self._reload_task == "process":
logger.info("Launching reload in a child process")
import multiprocessing
process = multiprocessing.Process(target=_reload, name="Reloader")
# Best-effort cleanup; on exit, the main TB parent process will attempt to
# kill all its daemonic children.
process.daemon = True
process.start()
elif self._reload_task in ("thread", "auto"):
logger.info("Launching reload in a daemon thread")
thread = threading.Thread(target=_reload, name="Reloader")
# Make this a daemon thread, which won't block TB from exiting.
thread.daemon = True
thread.start()
elif self._reload_task == "blocking":
if self._reload_interval != 0:
raise ValueError(
"blocking reload only allowed with load_interval=0"
)
_reload()
else:
raise ValueError("unrecognized reload_task: %s" % self._reload_task)
def _get_event_file_active_filter(flags):
"""Returns a predicate for whether an event file load timestamp is active.
Returns:
A predicate function accepting a single UNIX timestamp float argument, or
None if multi-file loading is not enabled.
"""
if not flags.reload_multifile:
return None
inactive_secs = flags.reload_multifile_inactive_secs
if inactive_secs == 0:
return None
if inactive_secs < 0:
return lambda timestamp: True
return lambda timestamp: timestamp + inactive_secs >= time.time()
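# Illustrative sketch (hypothetical flag values): with
# --reload_multifile=true --reload_multifile_inactive_secs=3600, the returned
# predicate treats a file as active iff its latest event was written within
# the last hour:
#
#   is_active = _get_event_file_active_filter(flags)
#   is_active(time.time() - 100)    # True: written recently
#   is_active(time.time() - 7200)   # False: idle for two hours
#
# A negative --reload_multifile_inactive_secs keeps every file active forever.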
def _parse_event_files_spec(logdir_spec):
"""Parses `logdir_spec` into a map from paths to run group names.
The `--logdir_spec` flag format is a comma-separated list of path
specifications. A path spec looks like 'group_name:/path/to/directory' or
'/path/to/directory'; in the latter case, the group is unnamed. Group names
cannot start with a forward slash: /foo:bar/baz will be interpreted as a spec
with no name and path '/foo:bar/baz'.
Globs are not supported.
Args:
logdir_spec: A comma-separated list of run specifications.
Returns:
A dict mapping directory paths to names like {'/path/to/directory': 'name'}.
Groups without an explicit name are named after their path. If logdir_spec is
None, returns an empty dict, which is helpful for testing things that don't
require any valid runs.
"""
files = {}
if logdir_spec is None:
return files
# Make sure to keep this consistent with ParseURI in core/lib/io/path.cc.
uri_pattern = re.compile("[a-zA-Z][0-9a-zA-Z.]*://.*")
for specification in logdir_spec.split(","):
# Check whether the spec includes a group name. A spec starting with xyz:// is
# treated as a URI path spec rather than a group spec. If the spec looks like
# /foo:bar/baz, we assume it's a path with a colon. If the spec looks like
# [a-zA-Z]:\foo, we assume it's a Windows path and not a single-letter group
# name.
if (
uri_pattern.match(specification) is None
and ":" in specification
and specification[0] != "/"
and not os.path.splitdrive(specification)[0]
):
# We split at most once so run_name:/path:with/a/colon will work.
run_name, _, path = specification.partition(":")
else:
run_name = None
path = specification
if uri_pattern.match(path) is None:
path = os.path.realpath(os.path.expanduser(path))
files[path] = run_name
return files
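# Illustrative sketch of how a hypothetical spec parses under the rules above:
#
#   _parse_event_files_spec("train:/tmp/runs/train,gs://bucket/logs,/foo:bar/baz")
#   == {
#       os.path.realpath("/tmp/runs/train"): "train",  # named group
#       "gs://bucket/logs": None,                      # URI path, left untouched
#       os.path.realpath("/foo:bar/baz"): None,        # leading '/': a path with a colon
#   }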
def _get_filesystem_scheme(path):
"""Extracts filesystem scheme from a given path.
The filesystem scheme is usually separated by `://` from the local filesystem
path if given. For example, the scheme of `file://tmp/tf` is `file`.
Args:
path: A string representing an input log directory.
Returns:
The filesystem scheme, or None if the path doesn't contain one.
"""
if "://" not in path:
return None
return path.split("://")[0]
def _check_filesystem_support(paths):
"""Examines the list of filesystems user requested.
If TF I/O schemes are requested, try to import tensorflow_io module.
Args:
paths: A list of strings representing input log directories.
"""
get_registered_schemes = getattr(
tf.io.gfile, "get_registered_schemes", None
)
registered_schemes = (
None if get_registered_schemes is None else get_registered_schemes()
)
# Only need to check one path for each scheme.
scheme_to_path = {_get_filesystem_scheme(path): path for path in paths}
missing_scheme = None
for scheme, path in scheme_to_path.items():
if scheme is None:
continue
# Use `tf.io.gfile.get_registered_schemes` if possible.
if registered_schemes is not None:
if scheme not in registered_schemes:
missing_scheme = scheme
break
else:
# Fall back to `tf.io.gfile.exists`.
try:
tf.io.gfile.exists(path)
except tf.errors.UnimplementedError:
missing_scheme = scheme
break
except tf.errors.OpError:
# Swallow other errors; we aren't concerned about them at this point.
pass
if missing_scheme:
try:
import tensorflow_io # noqa: F401
except ImportError as e:
supported_schemes_msg = (
" (supported schemes: {})".format(registered_schemes)
if registered_schemes
else ""
)
raise tf.errors.UnimplementedError(
None,
None,
(
"Error: Unsupported filename scheme '{}'{}. For additional"
+ " filesystem support, consider installing TensorFlow I/O"
+ " (https://www.tensorflow.org/io) via `pip install tensorflow-io`."
).format(missing_scheme, supported_schemes_msg),
) from e


@@ -0,0 +1,538 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bridge from event multiplexer storage to generic data APIs."""
import base64
import collections
import json
import random
from tensorboard import errors
from tensorboard.compat.proto import summary_pb2
from tensorboard.data import provider
from tensorboard.util import tb_logging
from tensorboard.util import tensor_util
logger = tb_logging.get_logger()
class MultiplexerDataProvider(provider.DataProvider):
def __init__(self, multiplexer, logdir):
"""Trivial initializer.
Args:
multiplexer: A `plugin_event_multiplexer.EventMultiplexer` (note:
not a boring old `event_multiplexer.EventMultiplexer`).
logdir: The log directory from which data is being read. Only used
cosmetically. Should be a `str`.
"""
self._multiplexer = multiplexer
self._logdir = logdir
def __str__(self):
return "MultiplexerDataProvider(logdir=%r)" % self._logdir
def _validate_context(self, ctx):
if type(ctx).__name__ != "RequestContext":
raise TypeError("ctx must be a RequestContext; got: %r" % (ctx,))
def _validate_experiment_id(self, experiment_id):
# This data provider doesn't consume the experiment ID at all, but
# as a courtesy to callers we require that it be a valid string, to
# help catch usage errors.
if not isinstance(experiment_id, str):
raise TypeError(
"experiment_id must be %r, but got %r: %r"
% (str, type(experiment_id), experiment_id)
)
def _validate_downsample(self, downsample):
if downsample is None:
raise TypeError("`downsample` required but not given")
if isinstance(downsample, int):
return # OK
raise TypeError(
"`downsample` must be an int, but got %r: %r"
% (type(downsample), downsample)
)
def _test_run_tag(self, run_tag_filter, run, tag):
runs = run_tag_filter.runs
if runs is not None and run not in runs:
return False
tags = run_tag_filter.tags
if tags is not None and tag not in tags:
return False
return True
def _get_first_event_timestamp(self, run_name):
try:
return self._multiplexer.FirstEventTimestamp(run_name)
except ValueError as e:
return None
def experiment_metadata(self, ctx=None, *, experiment_id):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
return provider.ExperimentMetadata(data_location=self._logdir)
def list_plugins(self, ctx=None, *, experiment_id):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
# Note: This result may include plugins that only have time
# series with `DATA_CLASS_UNKNOWN`, which will not actually be
# accessible via `list_*` or `read_*`. This is inconsistent with
# the specification for `list_plugins`, but the bug should be
# mostly harmless.
return self._multiplexer.ActivePlugins()
def list_runs(self, ctx=None, *, experiment_id):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
return [
provider.Run(
run_id=run, # use names as IDs
run_name=run,
start_time=self._get_first_event_timestamp(run),
)
for run in self._multiplexer.Runs()
]
def list_scalars(
self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR
)
return self._list(provider.ScalarTimeSeries, index)
def read_scalars(
self,
ctx=None,
*,
experiment_id,
plugin_name,
downsample=None,
run_tag_filter=None,
):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
self._validate_downsample(downsample)
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR
)
return self._read(_convert_scalar_event, index, downsample)
def read_last_scalars(
self,
ctx=None,
*,
experiment_id,
plugin_name,
run_tag_filter=None,
):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR
)
run_tag_to_last_scalar_datum = collections.defaultdict(dict)
for run, tags_for_run in index.items():
for tag, metadata in tags_for_run.items():
events = self._multiplexer.Tensors(run, tag)
if events:
run_tag_to_last_scalar_datum[run][tag] = (
_convert_scalar_event(events[-1])
)
return run_tag_to_last_scalar_datum
def list_tensors(
self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_TENSOR
)
return self._list(provider.TensorTimeSeries, index)
def read_tensors(
self,
ctx=None,
*,
experiment_id,
plugin_name,
downsample=None,
run_tag_filter=None,
):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
self._validate_downsample(downsample)
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_TENSOR
)
return self._read(_convert_tensor_event, index, downsample)
def _index(self, plugin_name, run_tag_filter, data_class_filter):
"""List time series and metadata matching the given filters.
This is like `_list`, but doesn't traverse `Tensors(...)` to
compute metadata that's not always needed.
Args:
plugin_name: A string plugin name filter (required).
run_tag_filter: A `provider.RunTagFilter`, or `None`.
data_class_filter: A `summary_pb2.DataClass` filter (required).
Returns:
A nested dict `d` such that `d[run][tag]` is a
`SummaryMetadata` proto.
"""
if run_tag_filter is None:
run_tag_filter = provider.RunTagFilter(runs=None, tags=None)
runs = run_tag_filter.runs
tags = run_tag_filter.tags
# Optimization for a common case, reading a single time series.
if runs and len(runs) == 1 and tags and len(tags) == 1:
(run,) = runs
(tag,) = tags
try:
metadata = self._multiplexer.SummaryMetadata(run, tag)
except KeyError:
return {}
all_metadata = {run: {tag: metadata}}
else:
all_metadata = self._multiplexer.AllSummaryMetadata()
result = {}
for run, tag_to_metadata in all_metadata.items():
if runs is not None and run not in runs:
continue
result_for_run = {}
for tag, metadata in tag_to_metadata.items():
if tags is not None and tag not in tags:
continue
if metadata.data_class != data_class_filter:
continue
if metadata.plugin_data.plugin_name != plugin_name:
continue
result[run] = result_for_run
result_for_run[tag] = metadata
return result
def _list(self, construct_time_series, index):
"""Helper to list scalar or tensor time series.
Args:
construct_time_series: `ScalarTimeSeries` or `TensorTimeSeries`.
index: The result of `self._index(...)`.
Returns:
A list of objects of type given by `construct_time_series`,
suitable to be returned from `list_scalars` or `list_tensors`.
"""
result = {}
for run, tag_to_metadata in index.items():
result_for_run = {}
result[run] = result_for_run
for tag, summary_metadata in tag_to_metadata.items():
max_step = None
max_wall_time = None
for event in self._multiplexer.Tensors(run, tag):
if max_step is None or max_step < event.step:
max_step = event.step
if max_wall_time is None or max_wall_time < event.wall_time:
max_wall_time = event.wall_time
summary_metadata = self._multiplexer.SummaryMetadata(run, tag)
result_for_run[tag] = construct_time_series(
max_step=max_step,
max_wall_time=max_wall_time,
plugin_content=summary_metadata.plugin_data.content,
description=summary_metadata.summary_description,
display_name=summary_metadata.display_name,
)
return result
def _read(self, convert_event, index, downsample):
"""Helper to read scalar or tensor data from the multiplexer.
Args:
convert_event: Takes `plugin_event_accumulator.TensorEvent` to
either `provider.ScalarDatum` or `provider.TensorDatum`.
index: The result of `self._index(...)`.
downsample: Non-negative `int`; how many samples to return per
time series.
Returns:
A dict of dicts of values returned by `convert_event` calls,
suitable to be returned from `read_scalars` or `read_tensors`.
"""
result = {}
for run, tags_for_run in index.items():
result_for_run = {}
result[run] = result_for_run
for tag, metadata in tags_for_run.items():
events = self._multiplexer.Tensors(run, tag)
data = [convert_event(e) for e in events]
result_for_run[tag] = _downsample(data, downsample)
return result
def list_blob_sequences(
self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_BLOB_SEQUENCE
)
result = {}
for run, tag_to_metadata in index.items():
result_for_run = {}
result[run] = result_for_run
for tag, metadata in tag_to_metadata.items():
max_step = None
max_wall_time = None
max_length = None
for event in self._multiplexer.Tensors(run, tag):
if max_step is None or max_step < event.step:
max_step = event.step
if max_wall_time is None or max_wall_time < event.wall_time:
max_wall_time = event.wall_time
length = _tensor_size(event.tensor_proto)
if max_length is None or length > max_length:
max_length = length
result_for_run[tag] = provider.BlobSequenceTimeSeries(
max_step=max_step,
max_wall_time=max_wall_time,
max_length=max_length,
plugin_content=metadata.plugin_data.content,
description=metadata.summary_description,
display_name=metadata.display_name,
)
return result
def read_blob_sequences(
self,
ctx=None,
*,
experiment_id,
plugin_name,
downsample=None,
run_tag_filter=None,
):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
self._validate_downsample(downsample)
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_BLOB_SEQUENCE
)
result = {}
for run, tags in index.items():
result_for_run = {}
result[run] = result_for_run
for tag in tags:
events = self._multiplexer.Tensors(run, tag)
data_by_step = {}
for event in events:
if event.step in data_by_step:
continue
data_by_step[event.step] = _convert_blob_sequence_event(
experiment_id, plugin_name, run, tag, event
)
data = [datum for (step, datum) in sorted(data_by_step.items())]
result_for_run[tag] = _downsample(data, downsample)
return result
def read_blob(self, ctx=None, *, blob_key):
self._validate_context(ctx)
(
unused_experiment_id,
plugin_name,
run,
tag,
step,
index,
) = _decode_blob_key(blob_key)
summary_metadata = self._multiplexer.SummaryMetadata(run, tag)
if summary_metadata.data_class != summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
raise errors.NotFoundError(blob_key)
tensor_events = self._multiplexer.Tensors(run, tag)
# In case of multiple events at this step, take first (arbitrary).
matching_step = next((e for e in tensor_events if e.step == step), None)
if not matching_step:
raise errors.NotFoundError("%s: no such step %r" % (blob_key, step))
tensor = tensor_util.make_ndarray(matching_step.tensor_proto)
return tensor[index]
# TODO(davidsoergel): deduplicate with other implementations
def _encode_blob_key(experiment_id, plugin_name, run, tag, step, index):
"""Generate a blob key: a short, URL-safe string identifying a blob.
A blob can be located using a set of integer and string fields; here we
serialize these to allow passing the data through a URL. Specifically, we
1) construct a tuple of the arguments in order; 2) represent that as an
ascii-encoded JSON string (without whitespace); and 3) take the URL-safe
base64 encoding of that, with no padding. For example:
1) Tuple: ("some_id", "graphs", "train", "graph_def", 2, 0)
2) JSON: ["some_id","graphs","train","graph_def",2,0]
3) base64: WyJzb21lX2lkIiwiZ3JhcGhzIiwidHJhaW4iLCJncmFwaF9kZWYiLDIsMF0K
Args:
experiment_id: a string ID identifying an experiment.
plugin_name: string
run: string
tag: string
step: int
index: int
Returns:
A URL-safe base64-encoded string representing the provided arguments.
"""
# Encodes the blob key as a URL-safe string, as required by the
# `BlobReference` API in `tensorboard/data/provider.py`, because these keys
# may be used to construct URLs for retrieving blobs.
stringified = json.dumps(
(experiment_id, plugin_name, run, tag, step, index),
separators=(",", ":"),
)
bytesified = stringified.encode("ascii")
encoded = base64.urlsafe_b64encode(bytesified)
return encoded.decode("ascii").rstrip("=")
# Any changes to this function need not be backward-compatible, even though
# the current encoding was used to generate URLs. The reason is that the
# generated URLs are not considered permalinks: they need to be valid only
# within the context of the session that created them (via the matching
# `_encode_blob_key` function above).
def _decode_blob_key(key):
"""Decode a blob key produced by `_encode_blob_key` into component fields.
Args:
key: a blob key, as generated by `_encode_blob_key`.
Returns:
A tuple of `(experiment_id, plugin_name, run, tag, step, index)`, with types
matching the arguments of `_encode_blob_key`.
"""
decoded = base64.urlsafe_b64decode(key + "==") # pad past a multiple of 4.
stringified = decoded.decode("ascii")
(experiment_id, plugin_name, run, tag, step, index) = json.loads(
stringified
)
return (experiment_id, plugin_name, run, tag, step, index)
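# Illustrative sketch (hypothetical field values): blob keys round-trip through
# the two helpers above; the '=' padding stripped by _encode_blob_key is
# compensated for by _decode_blob_key, which over-pads before base64 decoding.
#
#   key = _encode_blob_key("exp", "images", "train", "input", 7, 0)
#   _decode_blob_key(key) == ("exp", "images", "train", "input", 7, 0)  # True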
def _convert_scalar_event(event):
"""Helper for `read_scalars`."""
return provider.ScalarDatum(
step=event.step,
wall_time=event.wall_time,
value=tensor_util.make_ndarray(event.tensor_proto).item(),
)
def _convert_tensor_event(event):
"""Helper for `read_tensors`."""
return provider.TensorDatum(
step=event.step,
wall_time=event.wall_time,
numpy=tensor_util.make_ndarray(event.tensor_proto),
)
def _convert_blob_sequence_event(experiment_id, plugin_name, run, tag, event):
"""Helper for `read_blob_sequences`."""
num_blobs = _tensor_size(event.tensor_proto)
values = tuple(
provider.BlobReference(
_encode_blob_key(
experiment_id,
plugin_name,
run,
tag,
event.step,
idx,
)
)
for idx in range(num_blobs)
)
return provider.BlobSequenceDatum(
wall_time=event.wall_time,
step=event.step,
values=values,
)
def _tensor_size(tensor_proto):
"""Compute the number of elements in a tensor.
This does not deserialize the full tensor contents.
Args:
tensor_proto: A `tensorboard.compat.proto.tensor_pb2.TensorProto`.
Returns:
A non-negative `int`.
"""
# This is the same logic that `tensor_util.make_ndarray` uses to
# compute the size, but without the actual buffer copies.
result = 1
for dim in tensor_proto.tensor_shape.dim:
result *= dim.size
return result
def _downsample(xs, k):
"""Downsample `xs` to at most `k` elements.
If `k` is larger than `len(xs)`, then the contents of `xs` itself will be
returned. If `k` is smaller than `len(xs)`, the last element of `xs` will
always be included (unless `k` is `0`) and the preceding elements
will be selected uniformly at random.
This differs from `random.sample` in that it returns a subsequence
(i.e., order is preserved) and that it permits `k > len(xs)`.
The random number generator will always be `random.Random(0)`, so
this function is deterministic (within a Python process).
Args:
xs: A sequence (`collections.abc.Sequence`).
k: A non-negative integer.
Returns:
A new list whose elements are a subsequence of `xs` of length
`min(k, len(xs))` and that is guaranteed to include the last
element of `xs`, uniformly selected among such subsequences.
"""
if k > len(xs):
return list(xs)
if k == 0:
return []
indices = random.Random(0).sample(range(len(xs) - 1), k - 1)
indices.sort()
indices += [len(xs) - 1]
return [xs[i] for i in indices]
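# Illustrative sketch: downsampling is deterministic and always keeps the
# final element.
#
#   xs = list(range(10))
#   _downsample(xs, 4)   # a sorted 4-element subsequence of xs, ending in 9
#   _downsample(xs, 4)   # identical to the previous call: the RNG is seeded with 0
#   _downsample(xs, 20)  # list(xs), since k may exceed len(xs)
#   _downsample(xs, 0)   # []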


@@ -0,0 +1,144 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation for a multi-file directory loader."""
from tensorboard.backend.event_processing import directory_watcher
from tensorboard.backend.event_processing import io_wrapper
from tensorboard.compat import tf
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
# Sentinel object for an inactive path.
_INACTIVE = object()
class DirectoryLoader:
"""Loader for an entire directory, maintaining multiple active file
loaders.
This class takes a directory, a factory for loaders, and optionally a
path filter and watches all the paths inside that directory for new data.
Each file loader created by the factory must read a path and produce an
iterator of (timestamp, value) pairs.
Unlike DirectoryWatcher, this class does not assume that only one file
receives new data at a time; there can be arbitrarily many active files.
However, any file whose maximum load timestamp fails an "active" predicate
will be marked as inactive and no longer checked for new data.
"""
def __init__(
self,
directory,
loader_factory,
path_filter=lambda x: True,
active_filter=lambda timestamp: True,
):
"""Constructs a new MultiFileDirectoryLoader.
Args:
directory: The directory to load files from.
loader_factory: A factory for creating loaders. The factory should take a
path and return an object that has a Load method returning an iterator
yielding (unix timestamp as float, value) pairs for any new data.
path_filter: If specified, only paths matching this filter are loaded.
active_filter: If specified, any loader whose maximum load timestamp does
not pass this filter will be marked as inactive and no longer read.
Raises:
ValueError: If directory or loader_factory are None.
"""
if directory is None:
raise ValueError("A directory is required")
if loader_factory is None:
raise ValueError("A loader factory is required")
self._directory = directory
self._loader_factory = loader_factory
self._path_filter = path_filter
self._active_filter = active_filter
self._loaders = {}
self._max_timestamps = {}
def Load(self):
"""Loads new values from all active files.
Yields:
All values that have not been yielded yet.
Raises:
DirectoryDeletedError: If the directory has been permanently deleted
(as opposed to being temporarily unavailable).
"""
try:
all_paths = io_wrapper.ListDirectoryAbsolute(self._directory)
paths = sorted(p for p in all_paths if self._path_filter(p))
for path in paths:
for value in self._LoadPath(path):
yield value
except tf.errors.OpError as e:
if not tf.io.gfile.exists(self._directory):
raise directory_watcher.DirectoryDeletedError(
"Directory %s has been permanently deleted"
% self._directory
)
else:
logger.info("Ignoring error during file loading: %s" % e)
def _LoadPath(self, path):
"""Generator for values from a single path's loader.
Args:
path: the path to load from
Yields:
All values from this path's loader that have not been yielded yet.
"""
max_timestamp = self._max_timestamps.get(path, None)
if max_timestamp is _INACTIVE or self._MarkIfInactive(
path, max_timestamp
):
logger.debug("Skipping inactive path %s", path)
return
loader = self._loaders.get(path, None)
if loader is None:
try:
loader = self._loader_factory(path)
except tf.errors.NotFoundError:
# Happens if a file was removed after we listed the directory.
logger.debug("Skipping nonexistent path %s", path)
return
self._loaders[path] = loader
logger.info("Loading data from path %s", path)
for timestamp, value in loader.Load():
if max_timestamp is None or timestamp > max_timestamp:
max_timestamp = timestamp
yield value
if not self._MarkIfInactive(path, max_timestamp):
self._max_timestamps[path] = max_timestamp
def _MarkIfInactive(self, path, max_timestamp):
"""If max_timestamp is inactive, returns True and marks the path as
such."""
logger.debug("Checking active status of %s at %s", path, max_timestamp)
if max_timestamp is not None and not self._active_filter(max_timestamp):
self._max_timestamps[path] = _INACTIVE
del self._loaders[path]
return True
return False
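# Illustrative sketch of wiring up a DirectoryLoader (hypothetical logdir and
# inactivity threshold; the loader factory and path filter come from sibling
# modules in this package):
#
#   import time
#   from tensorboard.backend.event_processing import event_file_loader
#   from tensorboard.backend.event_processing import io_wrapper
#
#   loader = DirectoryLoader(
#       "/tmp/logs/run1",
#       event_file_loader.TimestampedEventFileLoader,  # Load() yields (timestamp, event) pairs
#       path_filter=io_wrapper.IsTensorFlowEventsFile,
#       active_filter=lambda timestamp: timestamp + 3600 >= time.time(),
#   )
#   for event in loader.Load():
#       ...  # each newly written event, across all active files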


@@ -0,0 +1,273 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the implementation for the DirectoryWatcher class."""
import bisect
from tensorboard.backend.event_processing import io_wrapper
from tensorboard.compat import tf
from tensorboard.util import io_util
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
class DirectoryWatcher:
"""A DirectoryWatcher wraps a loader to load from a sequence of paths.
A loader reads a path and produces some kind of values as an iterator. A
DirectoryWatcher takes a directory, a factory for loaders, and optionally a
path filter and watches all the paths inside that directory.
This class is only valid under the assumption that only one path will be
written to by the data source at a time and that once the source stops writing
to a path, it will start writing to a new path that's lexicographically
greater and never come back. It uses some heuristics to check whether this is
true based on tracking changes to the files' sizes, but the check can have
false negatives. However, it should have no false positives.
"""
def __init__(self, directory, loader_factory, path_filter=lambda x: True):
"""Constructs a new DirectoryWatcher.
Args:
directory: The directory to load files from.
loader_factory: A factory for creating loaders. The factory should take a
path and return an object that has a Load method returning an
iterator that will yield all events that have not been yielded yet.
path_filter: If specified, only paths matching this filter are loaded.
Raises:
ValueError: If directory or loader_factory are None.
"""
if directory is None:
raise ValueError("A directory is required")
if loader_factory is None:
raise ValueError("A loader factory is required")
self._directory = directory
self._path = None
self._loader_factory = loader_factory
self._loader = None
self._path_filter = path_filter
self._ooo_writes_detected = False
# The file size for each file at the time it was finalized.
self._finalized_sizes = {}
def Load(self):
"""Loads new values.
The watcher will load from one path at a time; as soon as that path stops
yielding events, it will move on to the next path. We assume that old paths
are never modified after a newer path has been written. As a result, Load()
can be called multiple times in a row without losing events that have not
been yielded yet. In other words, we guarantee that every event will be
yielded exactly once.
Yields:
All values that have not been yielded yet.
Raises:
DirectoryDeletedError: If the directory has been permanently deleted
(as opposed to being temporarily unavailable).
"""
try:
for event in self._LoadInternal():
yield event
except tf.errors.OpError:
if not tf.io.gfile.exists(self._directory):
raise DirectoryDeletedError(
"Directory %s has been permanently deleted"
% self._directory
)
def _LoadInternal(self):
"""Internal implementation of Load().
The only difference between this and Load() is that the latter will throw
DirectoryDeletedError on I/O errors if it thinks that the directory has been
permanently deleted.
Yields:
All values that have not been yielded yet.
"""
# Initialize the loader if we don't have one yet.
if not self._loader:
self._InitializeLoader()
# If it still doesn't exist, there is no data
if not self._loader:
return
while True:
# Yield all the new events in the path we're currently loading from.
for event in self._loader.Load():
yield event
next_path = self._GetNextPath()
if not next_path:
logger.info("No path found after %s", self._path)
# Current path is empty and there are no new paths, so we're done.
return
# There's a new path, so check to make sure there weren't any events
# written between when we finished reading the current path and when we
# checked for the new one. The sequence of events might look something
# like this:
#
# 1. Event #1 written to path #1.
# 2. We check for events and yield event #1 from path #1
# 3. We check for events and see that there are no more events in path #1.
# 4. Event #2 is written to path #1.
# 5. Event #3 is written to path #2.
# 6. We check for a new path and see that path #2 exists.
#
# Without this loop, we would miss event #2. We're also guaranteed by the
# loader contract that no more events will be written to path #1 after
# events start being written to path #2, so we don't have to worry about
# that.
for event in self._loader.Load():
yield event
logger.info(
"Directory watcher advancing from %s to %s",
self._path,
next_path,
)
# Advance to the next path and start over.
self._SetPath(next_path)
# The number of paths before the current one to check for out of order writes.
_OOO_WRITE_CHECK_COUNT = 20
def OutOfOrderWritesDetected(self):
"""Returns whether any out-of-order writes have been detected.
Out-of-order writes are only checked as part of the Load() iterator. Once an
out-of-order write is detected, this function will always return true.
Note that out-of-order write detection is not performed on GCS paths, so
this function will always return false.
Returns:
Whether any out-of-order write has ever been detected by this watcher.
"""
return self._ooo_writes_detected
def _InitializeLoader(self):
path = self._GetNextPath()
if path:
self._SetPath(path)
def _SetPath(self, path):
"""Sets the current path to watch for new events.
This also records the size of the old path, if any. If the size can't be
found, an error is logged.
Args:
path: The full path of the file to watch.
"""
old_path = self._path
if old_path and not io_util.IsCloudPath(old_path):
try:
# We're done with the path, so store its size.
size = tf.io.gfile.stat(old_path).length
logger.debug("Setting latest size of %s to %d", old_path, size)
self._finalized_sizes[old_path] = size
except tf.errors.OpError as e:
logger.error("Unable to get size of %s: %s", old_path, e)
self._path = path
self._loader = self._loader_factory(path)
def _GetNextPath(self):
"""Gets the next path to load from.
This function also does the checking for out-of-order writes as it iterates
through the paths.
Returns:
The next path to load events from, or None if there are no more paths.
"""
paths = sorted(
path
for path in io_wrapper.ListDirectoryAbsolute(self._directory)
if self._path_filter(path)
)
if not paths:
return None
if self._path is None:
return paths[0]
# Don't bother checking if the paths are GCS (which we can't check) or if
# we've already detected an OOO write.
if not io_util.IsCloudPath(paths[0]) and not self._ooo_writes_detected:
# Check the previous _OOO_WRITE_CHECK_COUNT paths for out of order writes.
current_path_index = bisect.bisect_left(paths, self._path)
ooo_check_start = max(
0, current_path_index - self._OOO_WRITE_CHECK_COUNT
)
for path in paths[ooo_check_start:current_path_index]:
if self._HasOOOWrite(path):
self._ooo_writes_detected = True
break
next_paths = list(
path for path in paths if self._path is None or path > self._path
)
if next_paths:
return min(next_paths)
else:
return None
def _HasOOOWrite(self, path):
"""Returns whether the path has had an out-of-order write."""
# Check the sizes of each path before the current one.
size = tf.io.gfile.stat(path).length
old_size = self._finalized_sizes.get(path, None)
if size != old_size:
if old_size is None:
logger.error(
"File %s created after file %s even though it's "
"lexicographically earlier",
path,
self._path,
)
else:
logger.error(
"File %s updated even though the current file is %s",
path,
self._path,
)
return True
else:
return False
class DirectoryDeletedError(Exception):
"""Thrown by Load() when the directory is *permanently* gone.
We distinguish this from temporary errors so that other code can
decide to drop all of our data only when a directory has been
intentionally deleted, as opposed to due to transient filesystem
errors.
"""
pass
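# Illustrative sketch of wiring up a DirectoryWatcher (hypothetical logdir; the
# loader factory and path filter come from sibling modules in this package):
#
#   from tensorboard.backend.event_processing import event_file_loader
#   from tensorboard.backend.event_processing import io_wrapper
#
#   watcher = DirectoryWatcher(
#       "/tmp/logs/run1",
#       event_file_loader.EventFileLoader,    # Load() yields event protos
#       io_wrapper.IsTensorFlowEventsFile,
#   )
#   for event in watcher.Load():
#       ...  # every event is yielded exactly once across repeated Load() calls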


@@ -0,0 +1,951 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Takes a generator of values, and accumulates them for a frontend."""
import collections
import dataclasses
import threading
from typing import Optional, Sequence, Tuple
from tensorboard.backend.event_processing import directory_watcher
from tensorboard.backend.event_processing import event_file_loader
from tensorboard.backend.event_processing import event_util
from tensorboard.backend.event_processing import io_wrapper
from tensorboard.backend.event_processing import plugin_asset_util
from tensorboard.backend.event_processing import reservoir
from tensorboard.backend.event_processing import tag_types
from tensorboard.compat.proto import config_pb2
from tensorboard.compat.proto import event_pb2
from tensorboard.compat.proto import graph_pb2
from tensorboard.compat.proto import meta_graph_pb2
from tensorboard.compat.proto import tensor_pb2
from tensorboard.plugins.distribution import compressor
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
@dataclasses.dataclass(frozen=True)
class ScalarEvent:
"""Contains information of a scalar event.
Attributes:
wall_time: Timestamp of the event in seconds.
step: Global step of the event.
value: A float or int value of the scalar.
"""
wall_time: float
step: int
value: float
@dataclasses.dataclass(frozen=True)
class CompressedHistogramEvent:
"""Contains information of a compressed histogram event.
Attributes:
wall_time: Timestamp of the event in seconds.
step: Global step of the event.
compressed_histogram_values: A sequence of tuples of basis points and
associated values in a compressed histogram.
"""
wall_time: float
step: int
compressed_histogram_values: Sequence[Tuple[float, float]]
@dataclasses.dataclass(frozen=True)
class HistogramValue:
"""Holds the information of the histogram values.
Attributes:
min: A float or int min value.
max: A float or int max value.
num: Total number of values.
sum: Sum of all values.
sum_squares: Sum of squares for all values.
bucket_limit: Upper values per bucket.
bucket: Numbers of values per bucket.
"""
min: float
max: float
num: int
sum: float
sum_squares: float
bucket_limit: Sequence[float]
bucket: Sequence[int]
@dataclasses.dataclass(frozen=True)
class HistogramEvent:
"""Contains information of a histogram event.
Attributes:
wall_time: Timestamp of the event in seconds.
step: Global step of the event.
histogram_value: Information of the histogram values.
"""
wall_time: float
step: int
histogram_value: HistogramValue
@dataclasses.dataclass(frozen=True)
class ImageEvent:
"""Contains information of an image event.
Attributes:
wall_time: Timestamp of the event in seconds.
step: Global step of the event.
encoded_image_string: Image content encoded in bytes.
width: Width of the image.
height: Height of the image.
"""
wall_time: float
step: int
encoded_image_string: bytes
width: int
height: int
@dataclasses.dataclass(frozen=True)
class AudioEvent:
"""Contains information of an audio event.
Attributes:
wall_time: Timestamp of the event in seconds.
step: Global step of the event.
encoded_audio_string: Audio content encoded in bytes.
content_type: A string describing the type of the audio content.
sample_rate: Sample rate of the audio in Hz. Must be positive.
length_frames: Length of the audio in frames (samples per channel).
"""
wall_time: float
step: int
encoded_audio_string: bytes
content_type: str
sample_rate: float
length_frames: int
@dataclasses.dataclass(frozen=True)
class TensorEvent:
"""A tensor event.
Attributes:
wall_time: Timestamp of the event in seconds.
step: Global step of the event.
tensor_proto: A `TensorProto`.
"""
wall_time: float
step: int
tensor_proto: tensor_pb2.TensorProto
## Different types of summary events handled by the event_accumulator
SUMMARY_TYPES = {
"simple_value": "_ProcessScalar",
"histo": "_ProcessHistogram",
"image": "_ProcessImage",
"audio": "_ProcessAudio",
"tensor": "_ProcessTensor",
}
# Legacy aliases
COMPRESSED_HISTOGRAMS = tag_types.COMPRESSED_HISTOGRAMS
HISTOGRAMS = tag_types.HISTOGRAMS
IMAGES = tag_types.IMAGES
AUDIO = tag_types.AUDIO
SCALARS = tag_types.SCALARS
TENSORS = tag_types.TENSORS
GRAPH = tag_types.GRAPH
META_GRAPH = tag_types.META_GRAPH
RUN_METADATA = tag_types.RUN_METADATA
## Normal CDF for std_devs: (-Inf, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, Inf)
## naturally gives bands around median of width 1 std dev, 2 std dev, 3 std dev,
## and then the long tail.
NORMAL_HISTOGRAM_BPS = (0, 668, 1587, 3085, 5000, 6915, 8413, 9332, 10000)
DEFAULT_SIZE_GUIDANCE = {
COMPRESSED_HISTOGRAMS: 500,
IMAGES: 4,
AUDIO: 4,
SCALARS: 10000,
HISTOGRAMS: 1,
TENSORS: 10,
}
STORE_EVERYTHING_SIZE_GUIDANCE = {
COMPRESSED_HISTOGRAMS: 0,
IMAGES: 0,
AUDIO: 0,
SCALARS: 0,
HISTOGRAMS: 0,
TENSORS: 0,
}
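# Illustrative sketch of standalone usage of the accumulator defined below
# (hypothetical logdir and tag name):
#
#   acc = EventAccumulator(
#       "/tmp/logs/run1",
#       size_guidance={**DEFAULT_SIZE_GUIDANCE, SCALARS: 0},  # 0 = keep all scalars
#   )
#   acc.Reload()              # synchronously ingest everything written so far
#   acc.Tags()[SCALARS]       # tags that have scalar data
#   acc.Scalars("loss")       # [ScalarEvent(wall_time=..., step=..., value=...), ...]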
class EventAccumulator:
"""An `EventAccumulator` takes an event generator, and accumulates the
values.
The `EventAccumulator` is intended to provide a convenient Python interface
for loading Event data written during a TensorFlow run. TensorFlow writes out
`Event` protobuf objects, which have a timestamp and step number, and often
contain a `Summary`. Summaries can have different kinds of data like an image,
a scalar value, or a histogram. The Summaries also have a tag, which we use to
organize logically related data. The `EventAccumulator` supports retrieving
the `Event` and `Summary` data by its tag.
Calling `Tags()` gets a map from `tagType` (e.g. `'images'`,
`'compressedHistograms'`, `'scalars'`, etc) to the associated tags for those
data types. Then, various functional endpoints (e.g.
`Accumulator.Scalars(tag)`) allow for the retrieval of all data
associated with that tag.
The `Reload()` method synchronously loads all of the data written so far.
Histograms, audio, and images are very large, so storing all of them is not
recommended.
Fields:
audios: A reservoir.Reservoir of audio summaries.
compressed_histograms: A reservoir.Reservoir of compressed
histogram summaries.
histograms: A reservoir.Reservoir of histogram summaries.
images: A reservoir.Reservoir of image summaries.
most_recent_step: Step of last Event proto added. This should only
be accessed from the thread that calls Reload. This is -1 if
nothing has been loaded yet.
most_recent_wall_time: Timestamp of last Event proto added. This is
a float containing seconds from the UNIX epoch, or -1 if
nothing has been loaded yet. This should only be accessed from
the thread that calls Reload.
path: A file path to a directory containing tf events files, or a single
tf events file. The accumulator will load events from this path.
scalars: A reservoir.Reservoir of scalar summaries.
tensors: A reservoir.Reservoir of tensor summaries.
@@Tensors
"""
def __init__(
self,
path,
size_guidance=None,
compression_bps=NORMAL_HISTOGRAM_BPS,
purge_orphaned_data=True,
):
"""Construct the `EventAccumulator`.
Args:
path: A file path to a directory containing tf events files, or a single
tf events file. The accumulator will load events from this path.
size_guidance: Information on how much data the EventAccumulator should
store in memory. The DEFAULT_SIZE_GUIDANCE tries not to store too much
so as to avoid OOMing the client. The size_guidance should be a map
from a `tagType` string to an integer representing the number of
items to keep per tag for items of that `tagType`. If the size is 0,
all events are stored.
compression_bps: Information on how the `EventAccumulator` should compress
histogram data for the `CompressedHistograms` tag (for details see
`ProcessCompressedHistogram`).
purge_orphaned_data: Whether to discard any events that were "orphaned" by
a TensorFlow restart.
"""
size_guidance = size_guidance or DEFAULT_SIZE_GUIDANCE
sizes = {}
for key in DEFAULT_SIZE_GUIDANCE:
if key in size_guidance:
sizes[key] = size_guidance[key]
else:
sizes[key] = DEFAULT_SIZE_GUIDANCE[key]
self._first_event_timestamp = None
self.scalars = reservoir.Reservoir(size=sizes[SCALARS])
self._graph = None
self._graph_from_metagraph = False
self._meta_graph = None
self._tagged_metadata = {}
self.summary_metadata = {}
self.histograms = reservoir.Reservoir(size=sizes[HISTOGRAMS])
self.compressed_histograms = reservoir.Reservoir(
size=sizes[COMPRESSED_HISTOGRAMS], always_keep_last=False
)
self.images = reservoir.Reservoir(size=sizes[IMAGES])
self.audios = reservoir.Reservoir(size=sizes[AUDIO])
self.tensors = reservoir.Reservoir(size=sizes[TENSORS])
# Keep a mapping from plugin name to a dict mapping from tag to plugin data
# content obtained from the SummaryMetadata (metadata field of Value) for
# that plugin (This is not the entire SummaryMetadata proto - only the
# content for that plugin). The SummaryWriter only keeps the content on the
# first event encountered per tag, so we must store that first instance of
# content for each tag.
self._plugin_to_tag_to_content = collections.defaultdict(dict)
self._generator_mutex = threading.Lock()
self.path = path
self._generator = _GeneratorFromPath(path)
self._compression_bps = compression_bps
self.purge_orphaned_data = purge_orphaned_data
self.most_recent_step = -1
self.most_recent_wall_time = -1
self.file_version = None
# Name of the source writer that writes the event.
self._source_writer = None
# The attributes that get built up by the accumulator
self.accumulated_attrs = (
"scalars",
"histograms",
"compressed_histograms",
"images",
"audios",
)
self._tensor_summaries = {}
def Reload(self):
"""Loads all events added since the last call to `Reload`.
If `Reload` was never called, loads all events in the file.
Returns:
The `EventAccumulator`.
"""
with self._generator_mutex:
for event in self._generator.Load():
self._ProcessEvent(event)
return self
def PluginAssets(self, plugin_name):
"""Return a list of all plugin assets for the given plugin.
Args:
plugin_name: The string name of a plugin to retrieve assets for.
Returns:
A list of string plugin asset names, or empty list if none are available.
If the plugin was not registered, an empty list is returned.
"""
return plugin_asset_util.ListAssets(self.path, plugin_name)
def RetrievePluginAsset(self, plugin_name, asset_name):
"""Return the contents of a given plugin asset.
Args:
plugin_name: The string name of a plugin.
asset_name: The string name of an asset.
Returns:
The string contents of the plugin asset.
Raises:
KeyError: If the asset is not available.
"""
return plugin_asset_util.RetrieveAsset(
self.path, plugin_name, asset_name
)
def FirstEventTimestamp(self):
"""Returns the timestamp in seconds of the first event.
If the first event has been loaded (either by this method or by `Reload`),
this returns immediately. Otherwise, it will load in the first event. Note
that this means that calling `Reload` will cause this to block until
`Reload` has finished.
Returns:
The timestamp in seconds of the first event that was loaded.
Raises:
ValueError: If no events have been loaded and there were no events found
on disk.
"""
if self._first_event_timestamp is not None:
return self._first_event_timestamp
with self._generator_mutex:
try:
event = next(self._generator.Load())
self._ProcessEvent(event)
return self._first_event_timestamp
except StopIteration:
raise ValueError("No event timestamp could be found")
def GetSourceWriter(self) -> Optional[str]:
"""Returns the name of the event writer."""
if self._source_writer is not None:
return self._source_writer
with self._generator_mutex:
try:
event = next(self._generator.Load())
self._ProcessEvent(event)
return self._source_writer
except StopIteration:
logger.info(
"End of file in %s, no source writer was found.", self.path
)
def PluginTagToContent(self, plugin_name):
"""Returns a dict mapping tags to content specific to that plugin.
Args:
plugin_name: The name of the plugin for which to fetch plugin-specific
content.
Raises:
KeyError: if the plugin name is not found.
Returns:
A dict mapping tag names to bytestrings of plugin-specific content -- by
convention, in the form of binary serialized protos.
"""
if plugin_name not in self._plugin_to_tag_to_content:
raise KeyError("Plugin %r could not be found." % plugin_name)
return self._plugin_to_tag_to_content[plugin_name]
def SummaryMetadata(self, tag):
"""Given a summary tag name, return the associated metadata object.
Args:
tag: The name of a tag, as a string.
Raises:
KeyError: If the tag is not found.
Returns:
A `SummaryMetadata` protobuf.
"""
return self.summary_metadata[tag]
def _ProcessEvent(self, event):
"""Called whenever an event is loaded."""
if self._first_event_timestamp is None:
self._first_event_timestamp = event.wall_time
if event.HasField("source_metadata"):
new_source_writer = event_util.GetSourceWriter(
event.source_metadata
)
if self._source_writer and self._source_writer != new_source_writer:
logger.info(
(
"Found new source writer for event.proto. "
"Old: {0}, New: {1}"
).format(self._source_writer, new_source_writer)
)
self._source_writer = new_source_writer
if event.HasField("file_version"):
new_file_version = event_util.ParseFileVersion(event.file_version)
if self.file_version and self.file_version != new_file_version:
## This should not happen.
logger.warning(
(
"Found new file_version for event.proto. This will "
"affect purging logic for TensorFlow restarts. "
"Old: {0} New: {1}"
).format(self.file_version, new_file_version)
)
self.file_version = new_file_version
self._MaybePurgeOrphanedData(event)
## Process the event.
# GraphDef and MetaGraphDef are handled in a special way:
# If no graph_def Event is available, but a meta_graph_def is, and it
# contains a graph_def, then use the meta_graph_def.graph_def as our graph.
# If a graph_def Event is available, always prefer it to the graph_def
# inside the meta_graph_def.
if event.HasField("graph_def"):
if self._graph is not None:
logger.warning(
(
"Found more than one graph event per run, or there was "
"a metagraph containing a graph_def, as well as one or "
"more graph events. Overwriting the graph with the "
"newest event."
)
)
self._graph = event.graph_def
self._graph_from_metagraph = False
elif event.HasField("meta_graph_def"):
if self._meta_graph is not None:
logger.warning(
(
"Found more than one metagraph event per run. "
"Overwriting the metagraph with the newest event."
)
)
self._meta_graph = event.meta_graph_def
if self._graph is None or self._graph_from_metagraph:
# We may have a graph_def in the metagraph. If so, and no
# graph_def is directly available, use this one instead.
meta_graph = meta_graph_pb2.MetaGraphDef()
meta_graph.ParseFromString(self._meta_graph)
if meta_graph.graph_def:
if self._graph is not None:
logger.warning(
(
"Found multiple metagraphs containing graph_defs,"
"but did not find any graph events. Overwriting the "
"graph with the newest metagraph version."
)
)
self._graph_from_metagraph = True
self._graph = meta_graph.graph_def.SerializeToString()
elif event.HasField("tagged_run_metadata"):
tag = event.tagged_run_metadata.tag
if tag in self._tagged_metadata:
logger.warning(
'Found more than one "run metadata" event with tag '
+ tag
+ ". Overwriting it with the newest event."
)
self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata
elif event.HasField("summary"):
for value in event.summary.value:
if value.HasField("metadata"):
tag = value.tag
# We only store the first instance of the metadata. This check
# is important: the `FileWriter` does strip metadata from all
# values except the first one per each tag, but a new
# `FileWriter` is created every time a training job stops and
# restarts. Hence, we must also ignore non-initial metadata in
# this logic.
if tag not in self.summary_metadata:
self.summary_metadata[tag] = value.metadata
plugin_data = value.metadata.plugin_data
if plugin_data.plugin_name:
self._plugin_to_tag_to_content[
plugin_data.plugin_name
][tag] = plugin_data.content
else:
logger.warning(
(
"This summary with tag %r is oddly not associated with a "
"plugin."
),
tag,
)
for summary_type, summary_func in SUMMARY_TYPES.items():
if value.HasField(summary_type):
datum = getattr(value, summary_type)
tag = value.tag
if summary_type == "tensor" and not tag:
# This tensor summary was created using the old method that used
# plugin assets. We must still continue to support it.
tag = value.node_name
getattr(self, summary_func)(
tag, event.wall_time, event.step, datum
)
def Tags(self):
"""Return all tags found in the value stream.
Returns:
A `{tagType: ['list', 'of', 'tags']}` dictionary.
"""
return {
IMAGES: self.images.Keys(),
AUDIO: self.audios.Keys(),
HISTOGRAMS: self.histograms.Keys(),
SCALARS: self.scalars.Keys(),
COMPRESSED_HISTOGRAMS: self.compressed_histograms.Keys(),
TENSORS: self.tensors.Keys(),
# Use a heuristic: if the metagraph is available, but
# graph is not, then we assume the metagraph contains the graph.
GRAPH: self._graph is not None,
META_GRAPH: self._meta_graph is not None,
RUN_METADATA: list(self._tagged_metadata.keys()),
}
def Scalars(self, tag):
"""Given a summary tag, return all associated `ScalarEvent`s.
Args:
tag: A string tag associated with the events.
Raises:
KeyError: If the tag is not found.
Returns:
An array of `ScalarEvent`s.
"""
return self.scalars.Items(tag)
def Graph(self):
"""Return the graph definition, if there is one.
If the graph is stored directly, return that. If no graph is stored
directly but a metagraph is stored containing a graph, return that.
Raises:
ValueError: If there is no graph for this run.
Returns:
The `graph_def` proto.
"""
graph = graph_pb2.GraphDef()
if self._graph is not None:
graph.ParseFromString(self._graph)
return graph
raise ValueError("There is no graph in this EventAccumulator")
def MetaGraph(self):
"""Return the metagraph definition, if there is one.
Raises:
ValueError: If there is no metagraph for this run.
Returns:
The `meta_graph_def` proto.
"""
if self._meta_graph is None:
raise ValueError("There is no metagraph in this EventAccumulator")
meta_graph = meta_graph_pb2.MetaGraphDef()
meta_graph.ParseFromString(self._meta_graph)
return meta_graph
def RunMetadata(self, tag):
"""Given a tag, return the associated session.run() metadata.
Args:
tag: A string tag associated with the event.
Raises:
ValueError: If the tag is not found.
Returns:
The metadata in form of `RunMetadata` proto.
"""
if tag not in self._tagged_metadata:
raise ValueError("There is no run metadata with this tag name")
run_metadata = config_pb2.RunMetadata()
run_metadata.ParseFromString(self._tagged_metadata[tag])
return run_metadata
def Histograms(self, tag):
"""Given a summary tag, return all associated histograms.
Args:
tag: A string tag associated with the events.
Raises:
KeyError: If the tag is not found.
Returns:
An array of `HistogramEvent`s.
"""
return self.histograms.Items(tag)
def CompressedHistograms(self, tag):
"""Given a summary tag, return all associated compressed histograms.
Args:
tag: A string tag associated with the events.
Raises:
KeyError: If the tag is not found.
Returns:
An array of `CompressedHistogramEvent`s.
"""
return self.compressed_histograms.Items(tag)
def Images(self, tag):
"""Given a summary tag, return all associated images.
Args:
tag: A string tag associated with the events.
Raises:
KeyError: If the tag is not found.
Returns:
An array of `ImageEvent`s.
"""
return self.images.Items(tag)
def Audio(self, tag):
"""Given a summary tag, return all associated audio.
Args:
tag: A string tag associated with the events.
Raises:
KeyError: If the tag is not found.
Returns:
An array of `AudioEvent`s.
"""
return self.audios.Items(tag)
def Tensors(self, tag):
"""Given a summary tag, return all associated tensors.
Args:
tag: A string tag associated with the events.
Raises:
KeyError: If the tag is not found.
Returns:
An array of `TensorEvent`s.
"""
return self.tensors.Items(tag)
def _MaybePurgeOrphanedData(self, event):
"""Maybe purge orphaned data due to a TensorFlow crash.
When TensorFlow crashes at step T+O and restarts at step T, any events
written after step T are now "orphaned" and will be at best misleading if
they are included in TensorBoard.
This logic attempts to determine if there is orphaned data, and purge it
if it is found.
Args:
event: The event to use as a reference, to determine if a purge is needed.
"""
if not self.purge_orphaned_data:
return
## Check if the event happened after a crash, and purge expired tags.
if self.file_version and self.file_version >= 2:
## If the file_version is recent enough, use the SessionLog enum
## to check for restarts.
self._CheckForRestartAndMaybePurge(event)
else:
## If there is no file version, default to old logic of checking for
## out of order steps.
self._CheckForOutOfOrderStepAndMaybePurge(event)
def _CheckForRestartAndMaybePurge(self, event):
"""Check and discard expired events using SessionLog.START.
Check for a SessionLog.START event and purge all previously seen events
with larger steps, because they are out of date. Because of supervisor
threading, it is possible that this logic will cause the first few event
messages to be discarded since supervisor threading does not guarantee
that the START message is deterministically written first.
This method is preferred over _CheckForOutOfOrderStepAndMaybePurge which
can inadvertently discard events due to supervisor threading.
Args:
event: The event to use as reference. If the event is a START event, all
previously seen events with a greater event.step will be purged.
"""
if (
event.HasField("session_log")
and event.session_log.status == event_pb2.SessionLog.START
):
self._Purge(event, by_tags=False)
def _CheckForOutOfOrderStepAndMaybePurge(self, event):
"""Check for out-of-order event.step and discard expired events for
tags.
Check if the event is out of order relative to the global most recent step.
If it is, purge outdated summaries for tags that the event contains.
Args:
event: The event to use as reference. If the event is out-of-order, all
events with the same tags, but with a greater event.step will be purged.
"""
if event.step < self.most_recent_step and event.HasField("summary"):
self._Purge(event, by_tags=True)
else:
self.most_recent_step = event.step
self.most_recent_wall_time = event.wall_time
def _ConvertHistogramProtoToPopo(self, histo):
"""Converts histogram proto to Python object."""
return HistogramValue(
min=histo.min,
max=histo.max,
num=histo.num,
sum=histo.sum,
sum_squares=histo.sum_squares,
bucket_limit=list(histo.bucket_limit),
bucket=list(histo.bucket),
)
def _ProcessHistogram(self, tag, wall_time, step, histo):
"""Processes a proto histogram by adding it to accumulated state."""
histo = self._ConvertHistogramProtoToPopo(histo)
histo_ev = HistogramEvent(wall_time, step, histo)
self.histograms.AddItem(tag, histo_ev)
self.compressed_histograms.AddItem(
tag, histo_ev, self._CompressHistogram
)
def _CompressHistogram(self, histo_ev):
"""Callback for _ProcessHistogram."""
return CompressedHistogramEvent(
histo_ev.wall_time,
histo_ev.step,
compressor.compress_histogram_proto(
histo_ev.histogram_value, self._compression_bps
),
)
def _ProcessImage(self, tag, wall_time, step, image):
"""Processes an image by adding it to accumulated state."""
event = ImageEvent(
wall_time=wall_time,
step=step,
encoded_image_string=image.encoded_image_string,
width=image.width,
height=image.height,
)
self.images.AddItem(tag, event)
def _ProcessAudio(self, tag, wall_time, step, audio):
"""Processes a audio by adding it to accumulated state."""
event = AudioEvent(
wall_time=wall_time,
step=step,
encoded_audio_string=audio.encoded_audio_string,
content_type=audio.content_type,
sample_rate=audio.sample_rate,
length_frames=audio.length_frames,
)
self.audios.AddItem(tag, event)
def _ProcessScalar(self, tag, wall_time, step, scalar):
"""Processes a simple value by adding it to accumulated state."""
sv = ScalarEvent(wall_time=wall_time, step=step, value=scalar)
self.scalars.AddItem(tag, sv)
def _ProcessTensor(self, tag, wall_time, step, tensor):
tv = TensorEvent(wall_time=wall_time, step=step, tensor_proto=tensor)
self.tensors.AddItem(tag, tv)
def _Purge(self, event, by_tags):
"""Purge all events that have occurred after the given event.step.
If by_tags is True, purge all events that occurred after the given
event.step, but only for the tags that the event has. Non-sequential
event.steps suggest that a TensorFlow restart occurred, and we discard
the out-of-order events to display a consistent view in TensorBoard.
Discarding by tags is the safer method, when we are unsure whether a restart
has occurred, given that threading in supervisor can cause events of
different tags to arrive with unsynchronized step values.
If by_tags is False, then purge all events with event.step greater than the
given event.step. This can be used when we are certain that a TensorFlow
restart has occurred and these events can be discarded.
Args:
event: The event to use as reference for the purge. All events with
the same tags, but with a greater event.step will be purged.
by_tags: Bool to dictate whether to discard all out-of-order events or
only those that are associated with the given reference event.
"""
## Keep data in reservoirs that has a step less than event.step
_NotExpired = lambda x: x.step < event.step
if by_tags:
def _ExpiredPerTag(value):
return [
getattr(self, x).FilterItems(_NotExpired, value.tag)
for x in self.accumulated_attrs
]
expired_per_tags = [
_ExpiredPerTag(value) for value in event.summary.value
]
expired_per_type = [sum(x) for x in zip(*expired_per_tags)]
else:
expired_per_type = [
getattr(self, x).FilterItems(_NotExpired)
for x in self.accumulated_attrs
]
if sum(expired_per_type) > 0:
purge_msg = _GetPurgeMessage(
self.most_recent_step,
self.most_recent_wall_time,
event.step,
event.wall_time,
*expired_per_type,
)
logger.warning(purge_msg)
def _GetPurgeMessage(
most_recent_step,
most_recent_wall_time,
event_step,
event_wall_time,
num_expired_scalars,
num_expired_histos,
num_expired_comp_histos,
num_expired_images,
num_expired_audio,
):
"""Return the string message associated with TensorBoard purges."""
return (
"Detected out of order event.step likely caused by "
"a TensorFlow restart. Purging expired events from Tensorboard"
" display between the previous step: {} (timestamp: {}) and "
"current step: {} (timestamp: {}). Removing {} scalars, {} "
"histograms, {} compressed histograms, {} images, "
"and {} audio."
).format(
most_recent_step,
most_recent_wall_time,
event_step,
event_wall_time,
num_expired_scalars,
num_expired_histos,
num_expired_comp_histos,
num_expired_images,
num_expired_audio,
)
def _GeneratorFromPath(path):
"""Create an event generator for file or directory at given path string."""
if not path:
raise ValueError("path must be a valid string")
if io_wrapper.IsSummaryEventsFile(path):
return event_file_loader.LegacyEventFileLoader(path)
else:
return directory_watcher.DirectoryWatcher(
path,
event_file_loader.LegacyEventFileLoader,
io_wrapper.IsSummaryEventsFile,
)
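# A minimal usage sketch of the accumulator API above, assuming the
# `EventAccumulator` class defined earlier in this module and a hypothetical
# run directory "./logdir/run1" containing an events file; "loss" is a
# hypothetical tag name.
_example_accumulator = EventAccumulator("./logdir/run1")
_example_accumulator.Reload()  # Load everything written to disk so far.
print(_example_accumulator.Tags())  # e.g. {'scalars': [...], 'images': [...], ...}
for _scalar in _example_accumulator.Scalars("loss"):
    print(_scalar.step, _scalar.value)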

View File

@ -0,0 +1,465 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logic for TensorBoard inspector to help humans investigate event files.
Example usages:
tensorboard --inspect --event_file myevents.out
tensorboard --inspect --event_file myevents.out --tag loss
tensorboard --inspect --logdir mylogdir
tensorboard --inspect --logdir mylogdir --tag loss
This script runs over a logdir and creates an InspectionUnit for every
subdirectory with event files. If running over an event file, it creates only
one InspectionUnit. One block of output is printed to console for each
InspectionUnit.
The primary content of an InspectionUnit is the dict field_to_obs that maps
fields (e.g. "scalar", "histogram", "session_log:start", etc.) to a list of
Observations for the field. Observations correspond one-to-one with Events in an
event file but contain less information because they only store what is
necessary to generate the final console output.
The final output is rendered to console by applying some aggregating function
to the lists of Observations. Different functions are applied depending on the
type of field. For instance, for "scalar" fields, the inspector shows aggregate
statistics. For other fields like "session_log:start", all observed steps are
printed in order to aid debugging.
[1] Query a logdir or an event file for its logged tags and summary statistics
using --logdir or --event_file.
[[event_file]] contains these tags:
histograms
binary/Sign/Activations
binary/nn_tanh/act/Activations
binary/nn_tanh/biases
binary/nn_tanh/biases:gradient
binary/nn_tanh/weights
binary/nn_tanh/weights:gradient
images
input_images/image/0
input_images/image/1
input_images/image/2
scalars
Learning Rate
Total Cost
Total Cost (raw)
Debug output aggregated over all tags:
graph
first_step 0
last_step 0
max_step 0
min_step 0
num_steps 1
outoforder_steps []
histograms
first_step 491
last_step 659823
max_step 659823
min_step 491
num_steps 993
outoforder_steps []
images -
scalars
first_step 0
last_step 659823
max_step 659823
min_step 0
num_steps 1985
outoforder_steps []
sessionlog:checkpoint
first_step 7129
last_step 657167
max_step 657167
min_step 7129
num_steps 99
outoforder_steps []
sessionlog:start
outoforder_steps []
steps [0L]
sessionlog:stop -
[2] Drill down into a particular tag using --tag.
Debug output for binary/Sign/Activations:
histograms
first_step 491
last_step 659823
max_step 659823
min_step 491
num_steps 993
outoforder_steps []
"""
import dataclasses
import itertools
import os
from typing import Any, Generator, Mapping
from tensorboard.backend.event_processing import event_accumulator
from tensorboard.backend.event_processing import event_file_loader
from tensorboard.backend.event_processing import io_wrapper
from tensorboard.compat import tf
from tensorboard.compat.proto import event_pb2
# Map of field names within summary.proto to the user-facing names that this
# script outputs.
SUMMARY_TYPE_TO_FIELD = {
"simple_value": "scalars",
"histo": "histograms",
"image": "images",
"audio": "audio",
}
for summary_type in event_accumulator.SUMMARY_TYPES:
if summary_type not in SUMMARY_TYPE_TO_FIELD:
SUMMARY_TYPE_TO_FIELD[summary_type] = summary_type
# Types of summaries that we may want to query for by tag.
TAG_FIELDS = list(SUMMARY_TYPE_TO_FIELD.values())
# Summaries that we want to see every instance of.
LONG_FIELDS = ["sessionlog:start", "sessionlog:stop"]
# Summaries that we only want an abridged digest of, since they would
# take too much screen real estate otherwise.
SHORT_FIELDS = ["graph", "sessionlog:checkpoint"] + TAG_FIELDS
# All summary types that we can inspect.
TRACKED_FIELDS = SHORT_FIELDS + LONG_FIELDS
PRINT_SEPARATOR = "=" * 70 + "\n"
@dataclasses.dataclass(frozen=True)
class Observation:
"""Contains the data within each Event file that the inspector cares about.
The inspector accumulates Observations as it processes events.
Attributes:
step: Global step of the event.
wall_time: Timestamp of the event in seconds.
tag: Tag name associated with the event.
"""
step: int
wall_time: float
tag: str
@dataclasses.dataclass(frozen=True)
class InspectionUnit:
"""Created for each organizational structure in the event files.
An InspectionUnit is visible in the final terminal output. For instance, one
InspectionUnit is created for each subdirectory in logdir. When asked to inspect
a single event file, there may only be one InspectionUnit.
Attributes:
name: Name of the organizational unit that will be printed to console.
generator: A generator that yields `Event` protos.
field_to_obs: A mapping from string fields to `Observations` that the inspector
creates.
"""
name: str
generator: Generator[event_pb2.Event, Any, Any]
field_to_obs: Mapping[str, Observation]
def get_field_to_observations_map(generator, query_for_tag=""):
"""Return a field to `Observations` dict for the event generator.
Args:
generator: A generator over event protos.
query_for_tag: A string that if specified, only create observations for
events with this tag name.
Returns:
A dict mapping keys in `TRACKED_FIELDS` to an `Observation` list.
"""
def increment(stat, event, tag=""):
assert stat in TRACKED_FIELDS
field_to_obs[stat].append(
dataclasses.asdict(
Observation(step=event.step, wall_time=event.wall_time, tag=tag)
)
)
field_to_obs = dict([(t, []) for t in TRACKED_FIELDS])
for event in generator:
## Process the event
if event.HasField("graph_def") and (not query_for_tag):
increment("graph", event)
if event.HasField("session_log") and (not query_for_tag):
status = event.session_log.status
if status == event_pb2.SessionLog.START:
increment("sessionlog:start", event)
elif status == event_pb2.SessionLog.STOP:
increment("sessionlog:stop", event)
elif status == event_pb2.SessionLog.CHECKPOINT:
increment("sessionlog:checkpoint", event)
elif event.HasField("summary"):
for value in event.summary.value:
if query_for_tag and value.tag != query_for_tag:
continue
for proto_name, display_name in SUMMARY_TYPE_TO_FIELD.items():
if value.HasField(proto_name):
increment(display_name, event, value.tag)
return field_to_obs
def get_unique_tags(field_to_obs):
"""Returns a dictionary of tags that a user could query over.
Args:
field_to_obs: Dict that maps string field to `Observation` list.
Returns:
A dict that maps keys in `TAG_FIELDS` to a list of string tags present in
the event files. If the dict does not have any observations of the type,
maps to an empty list so that we can render this to console.
"""
return {
field: sorted(set([x.get("tag", "") for x in observations]))
for field, observations in field_to_obs.items()
if field in TAG_FIELDS
}
def print_dict(d, show_missing=True):
"""Prints a shallow dict to console.
Args:
d: Dict to print.
show_missing: Whether to show keys with empty values.
"""
for k, v in sorted(d.items()):
if (not v) and show_missing:
# No instances of the key, so print missing symbol.
print("{} -".format(k))
elif isinstance(v, list):
# Value is a list, so print each item of the list.
print(k)
for item in v:
print(" {}".format(item))
elif isinstance(v, dict):
# Value is a dict, so print each (key, value) pair of the dict.
print(k)
for kk, vv in sorted(v.items()):
print(" {:<20} {}".format(kk, vv))
def get_dict_to_print(field_to_obs):
"""Transform the field-to-obs mapping into a printable dictionary.
Args:
field_to_obs: Dict that maps string field to `Observation` list.
Returns:
A dict with the keys and values to print to console.
"""
def compressed_steps(steps):
return {
"num_steps": len(set(steps)),
"min_step": min(steps),
"max_step": max(steps),
"last_step": steps[-1],
"first_step": steps[0],
"outoforder_steps": get_out_of_order(steps),
}
def full_steps(steps):
return {"steps": steps, "outoforder_steps": get_out_of_order(steps)}
output = {}
for field, observations in field_to_obs.items():
if not observations:
output[field] = None
continue
steps = [x["step"] for x in observations]
if field in SHORT_FIELDS:
output[field] = compressed_steps(steps)
if field in LONG_FIELDS:
output[field] = full_steps(steps)
return output
def get_out_of_order(list_of_numbers):
"""Returns elements that break the monotonically non-decreasing trend.
This is used to find instances of global step values that are "out-of-order",
which may trigger TensorBoard event discarding logic.
Args:
list_of_numbers: A list of numbers.
Returns:
      A list of (previous, current) tuples, one for each pair of adjacent
      elements in which the second element is lower than the first.
"""
# TODO: Consider changing this to only check for out-of-order
# steps within a particular tag.
result = []
# pylint: disable=consider-using-enumerate
for i in range(len(list_of_numbers)):
if i == 0:
continue
if list_of_numbers[i] < list_of_numbers[i - 1]:
result.append((list_of_numbers[i - 1], list_of_numbers[i]))
return result
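# A small illustrative check of the helper above, using hypothetical step
# values: any step that drops relative to its predecessor is reported as a
# (previous, current) pair.
print(get_out_of_order([0, 10, 20, 5, 25]))  # [(20, 5)]
print(get_out_of_order([1, 2, 3]))  # []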
def generators_from_logdir(logdir):
"""Returns a list of event generators for subdirectories with event files.
The number of generators returned should equal the number of directories
within logdir that contain event files. If only logdir contains event files,
returns a list of length one.
Args:
logdir: A log directory that contains event files.
Returns:
List of event generators for each subdirectory with event files.
"""
subdirs = io_wrapper.GetLogdirSubdirectories(logdir)
generators = [
itertools.chain(
*[
generator_from_event_file(os.path.join(subdir, f))
for f in tf.io.gfile.listdir(subdir)
if io_wrapper.IsTensorFlowEventsFile(os.path.join(subdir, f))
]
)
for subdir in subdirs
]
return generators
def generator_from_event_file(event_file):
"""Returns a generator that yields events from an event file."""
return event_file_loader.LegacyEventFileLoader(event_file).Load()
def get_inspection_units(logdir="", event_file="", tag=""):
"""Returns a list of InspectionUnit objects given either logdir or
event_file.
If logdir is given, the number of InspectionUnits should equal the
number of directories or subdirectories that contain event files.
If event_file is given, the number of InspectionUnits should be 1.
Args:
logdir: A log directory that contains event files.
event_file: Or, a particular event file path.
tag: An optional tag name to query for.
Returns:
A list of InspectionUnit objects.
"""
if logdir:
subdirs = io_wrapper.GetLogdirSubdirectories(logdir)
inspection_units = []
for subdir in subdirs:
generator = itertools.chain(
*[
generator_from_event_file(os.path.join(subdir, f))
for f in tf.io.gfile.listdir(subdir)
if io_wrapper.IsTensorFlowEventsFile(
os.path.join(subdir, f)
)
]
)
inspection_units.append(
InspectionUnit(
name=subdir,
generator=generator,
field_to_obs=get_field_to_observations_map(generator, tag),
)
)
if inspection_units:
print(
"Found event files in:\n{}\n".format(
"\n".join([u.name for u in inspection_units])
)
)
elif io_wrapper.IsTensorFlowEventsFile(logdir):
print(
"It seems that {} may be an event file instead of a logdir. If this "
"is the case, use --event_file instead of --logdir to pass "
"it in.".format(logdir)
)
else:
print("No event files found within logdir {}".format(logdir))
return inspection_units
elif event_file:
generator = generator_from_event_file(event_file)
return [
InspectionUnit(
name=event_file,
generator=generator,
field_to_obs=get_field_to_observations_map(generator, tag),
)
]
return []
def inspect(logdir="", event_file="", tag=""):
"""Main function for inspector that prints out a digest of event files.
Args:
logdir: A log directory that contains event files.
event_file: Or, a particular event file path.
tag: An optional tag name to query for.
Raises:
ValueError: If neither logdir and event_file are given, or both are given.
"""
print(
PRINT_SEPARATOR
+ "Processing event files... (this can take a few minutes)\n"
+ PRINT_SEPARATOR
)
inspection_units = get_inspection_units(logdir, event_file, tag)
for unit in inspection_units:
if tag:
print("Event statistics for tag {} in {}:".format(tag, unit.name))
else:
# If the user is not inspecting a particular tag, also print the list of
# all available tags that they can query.
print("These tags are in {}:".format(unit.name))
print_dict(get_unique_tags(unit.field_to_obs))
print(PRINT_SEPARATOR)
print("Event statistics for {}:".format(unit.name))
print_dict(get_dict_to_print(unit.field_to_obs), show_missing=(not tag))
print(PRINT_SEPARATOR)
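# A minimal usage sketch with a hypothetical log directory; this mirrors what
# `tensorboard --inspect --logdir mylogdir --tag loss` does via this module.
inspect(logdir="mylogdir", tag="loss")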

View File

@ -0,0 +1,293 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functionality for loading events from a record file."""
import contextlib
from tensorboard import data_compat
from tensorboard import dataclass_compat
from tensorboard.compat import tf
from tensorboard.compat.proto import event_pb2
from tensorboard.util import platform_util
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
@contextlib.contextmanager
def _nullcontext():
"""Pre-Python-3.7-compatible standin for contextlib.nullcontext."""
yield
# Might as well make this a singleton.
_NULLCONTEXT = _nullcontext()
def _silence_deprecation_warnings():
"""Context manager that best-effort silences TF deprecation warnings."""
try:
# Learn this one weird trick to make TF deprecation warnings go away.
from tensorflow.python.util import deprecation
return deprecation.silence()
except (ImportError, AttributeError):
return _NULLCONTEXT
def _make_tf_record_iterator(file_path):
"""Returns an iterator over TF records for the given tfrecord file."""
# If we don't have TF at all, use the stub implementation.
if tf.__version__ == "stub":
# TODO(#1711): Reshape stub implementation to fit tf_record_iterator API
# rather than needlessly emulating the old PyRecordReader_New API.
logger.debug("Opening a stub record reader pointing at %s", file_path)
return _PyRecordReaderIterator(
tf.pywrap_tensorflow.PyRecordReader_New, file_path
)
# If PyRecordReader exists, use it, otherwise use tf_record_iterator().
# Check old first, then new, since tf_record_iterator existed previously but
# only gained the semantics we need at the time PyRecordReader was removed.
#
# TODO(#1711): Eventually remove PyRecordReader fallback once we can drop
# support for TF 2.1 and prior, and find a non-deprecated replacement for
# tf.compat.v1.io.tf_record_iterator.
try:
from tensorflow.python import pywrap_tensorflow
py_record_reader_new = pywrap_tensorflow.PyRecordReader_New
except (ImportError, AttributeError):
py_record_reader_new = None
if py_record_reader_new:
logger.debug("Opening a PyRecordReader pointing at %s", file_path)
return _PyRecordReaderIterator(py_record_reader_new, file_path)
else:
logger.debug("Opening a tf_record_iterator pointing at %s", file_path)
# TODO(#1711): Find non-deprecated replacement for tf_record_iterator.
with _silence_deprecation_warnings():
return tf.compat.v1.io.tf_record_iterator(file_path)
class _PyRecordReaderIterator:
"""Python iterator for TF Records based on PyRecordReader."""
def __init__(self, py_record_reader_new, file_path):
"""Constructs a _PyRecordReaderIterator for the given file path.
Args:
py_record_reader_new: pywrap_tensorflow.PyRecordReader_New
file_path: file path of the tfrecord file to read
"""
with tf.compat.v1.errors.raise_exception_on_not_ok_status() as status:
self._reader = py_record_reader_new(
tf.compat.as_bytes(file_path), 0, tf.compat.as_bytes(""), status
)
if not self._reader:
raise IOError(
"Failed to open a record reader pointing to %s" % file_path
)
def __iter__(self):
return self
def __next__(self):
try:
self._reader.GetNext()
except tf.errors.OutOfRangeError as e:
raise StopIteration
return self._reader.record()
next = __next__ # for python2 compatibility
class RawEventFileLoader:
"""An iterator that yields Event protos as serialized bytestrings."""
def __init__(self, file_path, detect_file_replacement=False):
"""Constructs a RawEventFileLoader for the given file path.
Args:
file_path: the event file path to read from
detect_file_replacement: if True, when Load() is called, the loader
will make a stat() call to check the size of the file. If it sees
that the file has grown, it will reopen the file entirely (while
preserving the current offset) before attempting to read from it.
Otherwise, Load() will simply poll at EOF for new data.
"""
if file_path is None:
raise ValueError("A file path is required")
self._file_path = platform_util.readahead_file_path(file_path)
self._detect_file_replacement = detect_file_replacement
self._file_size = None
self._iterator = _make_tf_record_iterator(self._file_path)
if self._detect_file_replacement and not hasattr(
self._iterator, "reopen"
):
logger.warning(
"File replacement detection requested, but not enabled because "
"TF record iterator impl does not support reopening. This "
"functionality requires TensorFlow 2.9+"
)
self._detect_file_replacement = False
def Load(self):
"""Loads all new events from disk as raw serialized proto bytestrings.
Calling Load multiple times in a row will not 'drop' events as long as the
return value is not iterated over.
Yields:
All event proto bytestrings in the file that have not been yielded yet.
"""
logger.debug("Loading events from %s", self._file_path)
if self._detect_file_replacement:
has_increased = self.CheckForIncreasedFileSize()
# Only act on the file size information if we got a concrete result.
if has_increased is not None:
if has_increased:
logger.debug(
"Reopening %s since file size has changed",
self._file_path,
)
self._iterator.close()
self._iterator.reopen()
else:
logger.debug(
"Skipping attempt to poll %s since file size has not "
"changed (still %d)",
self._file_path,
self._file_size,
)
return
while True:
try:
yield next(self._iterator)
except StopIteration:
logger.debug("End of file in %s", self._file_path)
break
except tf.errors.DataLossError as e:
# We swallow partial read exceptions; if the record was truncated
# and a later update completes it, retrying can then resume from
# the same point in the file since the iterator holds the offset.
logger.debug("Truncated record in %s (%s)", self._file_path, e)
break
logger.debug("No more events in %s", self._file_path)
def CheckForIncreasedFileSize(self):
"""Stats the file to get its updated size, returning True if it grew.
If the stat call fails or reports a smaller size than was previously
seen, then any previously cached size is left unchanged.
Returns:
boolean or None: True if the file size increased; False if it was
the same or decreased; or None if neither case could be detected
(either because the previous size had not been recorded yet, or
because the stat call for the current size failed).
"""
previous_size = self._file_size
try:
self._file_size = tf.io.gfile.stat(self._file_path).length
except tf.errors.OpError as e:
logger.error("Failed to stat %s: %s", self._file_path, e)
return None
logger.debug(
"Stat on %s got size %d, previous size %s",
self._file_path,
self._file_size,
previous_size,
)
if previous_size is None:
return None
if self._file_size > previous_size:
return True
if self._file_size < previous_size:
logger.warning(
"File %s shrank from previous size %d to size %d",
self._file_path,
previous_size,
self._file_size,
)
# In case this was transient, preserve the previously cached size,
# to avoid reporting a spurious increase next time. If the file was
# actually truncated, we can't recover anyway, so just ignore it.
self._file_size = previous_size
return False
class LegacyEventFileLoader(RawEventFileLoader):
"""An iterator that yields parsed Event protos."""
def Load(self):
"""Loads all new events from disk.
Calling Load multiple times in a row will not 'drop' events as long as the
return value is not iterated over.
Yields:
All events in the file that have not been yielded yet.
"""
for record in super().Load():
yield event_pb2.Event.FromString(record)
class EventFileLoader(LegacyEventFileLoader):
"""An iterator that passes events through read-time compat layers.
Specifically, this includes `data_compat` and `dataclass_compat`.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Track initial metadata for each tag, for `dataclass_compat`.
# This is meant to be tracked per run, not per event file, so
# there is a potential failure case when the second event file
# in a single run has no summary metadata. This only occurs when
# all of the following hold: (a) the events were written with
# the TensorFlow 1.x (not 2.x) writer, (b) the summaries were
# created by `tensorboard.summary.v1` ops and so do not undergo
# `data_compat` transformation, and (c) the file writer was
# reopened by calling `.reopen()` on it, which creates a new
# file but does not clear the tag cache. This is considered
# sufficiently improbable that we don't take extra mitigations.
self._initial_metadata = {} # from tag name to `SummaryMetadata`
def Load(self):
for event in super().Load():
event = data_compat.migrate_event(event)
events = dataclass_compat.migrate_event(
event, self._initial_metadata
)
for event in events:
yield event
class TimestampedEventFileLoader(EventFileLoader):
"""An iterator that yields (UNIX timestamp float, Event proto) pairs."""
def Load(self):
"""Loads all new events and their wall time values from disk.
Calling Load multiple times in a row will not 'drop' events as long as the
return value is not iterated over.
Yields:
Pairs of (UNIX timestamp float, Event proto) for all events in the file
that have not been yielded yet.
"""
for event in super().Load():
yield (event.wall_time, event)
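# A minimal usage sketch with a hypothetical event file name: iterate the
# parsed `Event` protos once, then call Load() again later to pick up records
# appended in the meantime.
_loader = EventFileLoader("events.out.tfevents.1234567890.myhost")
for _event in _loader.Load():
    print(_event.step, _event.wall_time)
# ...later, after the writer has appended more records:
for _event in _loader.Load():
    print("new event at step", _event.step)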

View File

@ -0,0 +1,523 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides an interface for working with multiple event files."""
import os
import threading
from typing import Optional
from tensorboard.backend.event_processing import directory_watcher
from tensorboard.backend.event_processing import event_accumulator
from tensorboard.backend.event_processing import io_wrapper
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
class EventMultiplexer:
"""An `EventMultiplexer` manages access to multiple `EventAccumulator`s.
Each `EventAccumulator` is associated with a `run`, which is a self-contained
TensorFlow execution. The `EventMultiplexer` provides methods for extracting
information about events from multiple `run`s.
Example usage for loading specific runs from files:
```python
x = EventMultiplexer({'run1': 'path/to/run1', 'run2': 'path/to/run2'})
x.Reload()
```
Example usage for loading a directory where each subdirectory is a run
```python
(eg:) /parent/directory/path/
/parent/directory/path/run1/
/parent/directory/path/run1/events.out.tfevents.1001
/parent/directory/path/run1/events.out.tfevents.1002
/parent/directory/path/run2/
/parent/directory/path/run2/events.out.tfevents.9232
/parent/directory/path/run3/
/parent/directory/path/run3/events.out.tfevents.9232
x = EventMultiplexer().AddRunsFromDirectory('/parent/directory/path')
(which is equivalent to:)
x = EventMultiplexer({'run1': '/parent/directory/path/run1', 'run2':...}
```
If you would like to watch `/parent/directory/path`, wait for it to be created
(if necessary) and then periodically pick up new runs, use
`AutoloadingMultiplexer`
@@Tensors
"""
def __init__(
self, run_path_map=None, size_guidance=None, purge_orphaned_data=True
):
"""Constructor for the `EventMultiplexer`.
Args:
run_path_map: Dict `{run: path}` which specifies the
name of a run, and the path to find the associated events. If it is
None, then the EventMultiplexer initializes without any runs.
size_guidance: A dictionary mapping from `tagType` to the number of items
to store for each tag of that type. See
`event_accumulator.EventAccumulator` for details.
purge_orphaned_data: Whether to discard any events that were "orphaned" by
a TensorFlow restart.
"""
logger.info("Event Multiplexer initializing.")
self._accumulators_mutex = threading.Lock()
self._accumulators = {}
self._paths = {}
self._reload_called = False
self._size_guidance = (
size_guidance or event_accumulator.DEFAULT_SIZE_GUIDANCE
)
self.purge_orphaned_data = purge_orphaned_data
if run_path_map is not None:
logger.info(
"Event Multplexer doing initialization load for %s",
run_path_map,
)
for run, path in run_path_map.items():
self.AddRun(path, run)
logger.info("Event Multiplexer done initializing")
def AddRun(self, path, name=None):
"""Add a run to the multiplexer.
If the name is not specified, it is the same as the path.
If a run by that name exists, and we are already watching the right path,
do nothing. If we are watching a different path, replace the event
accumulator.
If `Reload` has been called, it will `Reload` the newly created
accumulators.
Args:
path: Path to the event files (or event directory) for given run.
name: Name of the run to add. If not provided, is set to path.
Returns:
The `EventMultiplexer`.
"""
name = name or path
accumulator = None
with self._accumulators_mutex:
if name not in self._accumulators or self._paths[name] != path:
if name in self._paths and self._paths[name] != path:
# TODO(@decentralion) - Make it impossible to overwrite an old path
# with a new path (just give the new path a distinct name)
logger.warning(
"Conflict for name %s: old path %s, new path %s",
name,
self._paths[name],
path,
)
logger.info("Constructing EventAccumulator for %s", path)
accumulator = event_accumulator.EventAccumulator(
path,
size_guidance=self._size_guidance,
purge_orphaned_data=self.purge_orphaned_data,
)
self._accumulators[name] = accumulator
self._paths[name] = path
if accumulator:
if self._reload_called:
accumulator.Reload()
return self
def AddRunsFromDirectory(self, path, name=None):
"""Load runs from a directory; recursively walks subdirectories.
If path doesn't exist, no-op. This ensures that it is safe to call
`AddRunsFromDirectory` multiple times, even before the directory is made.
If path is a directory, load event files in the directory (if any exist) and
        recursively call AddRunsFromDirectory on any subdirectories. This means you
can call AddRunsFromDirectory at the root of a tree of event logs and
TensorBoard will load them all.
If the `EventMultiplexer` is already loaded this will cause
the newly created accumulators to `Reload()`.
Args:
path: A string path to a directory to load runs from.
name: Optionally, what name to apply to the runs. If name is provided
and the directory contains run subdirectories, the name of each subrun
is the concatenation of the parent name and the subdirectory name. If
name is provided and the directory contains event files, then a run
is added called "name" and with the events from the path.
Raises:
ValueError: If the path exists and isn't a directory.
Returns:
The `EventMultiplexer`.
"""
logger.info("Starting AddRunsFromDirectory: %s", path)
for subdir in io_wrapper.GetLogdirSubdirectories(path):
logger.info("Adding events from directory %s", subdir)
rpath = os.path.relpath(subdir, path)
subname = os.path.join(name, rpath) if name else rpath
self.AddRun(subdir, name=subname)
logger.info("Done with AddRunsFromDirectory: %s", path)
return self
def Reload(self):
"""Call `Reload` on every `EventAccumulator`."""
logger.info("Beginning EventMultiplexer.Reload()")
self._reload_called = True
# Build a list so we're safe even if the list of accumulators is modified
# even while we're reloading.
with self._accumulators_mutex:
items = list(self._accumulators.items())
names_to_delete = set()
for name, accumulator in items:
try:
accumulator.Reload()
except (OSError, IOError) as e:
logger.error("Unable to reload accumulator '%s': %s", name, e)
except directory_watcher.DirectoryDeletedError:
names_to_delete.add(name)
with self._accumulators_mutex:
for name in names_to_delete:
logger.warning("Deleting accumulator '%s'", name)
del self._accumulators[name]
logger.info("Finished with EventMultiplexer.Reload()")
return self
def PluginAssets(self, plugin_name):
"""Get index of runs and assets for a given plugin.
Args:
plugin_name: Name of the plugin we are checking for.
Returns:
A dictionary that maps from run_name to a list of plugin
assets for that run.
"""
with self._accumulators_mutex:
# To avoid nested locks, we construct a copy of the run-accumulator map
items = list(self._accumulators.items())
return {run: accum.PluginAssets(plugin_name) for run, accum in items}
def RetrievePluginAsset(self, run, plugin_name, asset_name):
"""Return the contents for a specific plugin asset from a run.
Args:
run: The string name of the run.
plugin_name: The string name of a plugin.
asset_name: The string name of an asset.
Returns:
The string contents of the plugin asset.
Raises:
KeyError: If the asset is not available.
"""
accumulator = self.GetAccumulator(run)
return accumulator.RetrievePluginAsset(plugin_name, asset_name)
def FirstEventTimestamp(self, run):
"""Return the timestamp of the first event of the given run.
This may perform I/O if no events have been loaded yet for the run.
Args:
run: A string name of the run for which the timestamp is retrieved.
Returns:
The wall_time of the first event of the run, which will typically be
seconds since the epoch.
Raises:
KeyError: If the run is not found.
ValueError: If the run has no events loaded and there are no events on
disk to load.
"""
accumulator = self.GetAccumulator(run)
return accumulator.FirstEventTimestamp()
def GetSourceWriter(self, run) -> Optional[str]:
"""Returns the source writer name from the first event of the given run.
Assuming each run has only one source writer.
Args:
run: A string name of the run from which the event source information
is retrieved.
Returns:
Name of the writer that wrote the events in the run.
"""
accumulator = self.GetAccumulator(run)
return accumulator.GetSourceWriter()
def Scalars(self, run, tag):
"""Retrieve the scalar events associated with a run and tag.
Args:
run: A string name of the run for which values are retrieved.
tag: A string name of the tag for which values are retrieved.
Raises:
KeyError: If the run is not found, or the tag is not available for
the given run.
Returns:
An array of `event_accumulator.ScalarEvents`.
"""
accumulator = self.GetAccumulator(run)
return accumulator.Scalars(tag)
def Graph(self, run):
"""Retrieve the graph associated with the provided run.
Args:
run: A string name of a run to load the graph for.
Raises:
KeyError: If the run is not found.
ValueError: If the run does not have an associated graph.
Returns:
The `GraphDef` protobuf data structure.
"""
accumulator = self.GetAccumulator(run)
return accumulator.Graph()
def SerializedGraph(self, run):
"""Retrieve the serialized graph associated with the provided run.
Args:
run: A string name of a run to load the graph for.
Raises:
KeyError: If the run is not found.
ValueError: If the run does not have an associated graph.
Returns:
The serialized form of the `GraphDef` protobuf data structure.
"""
accumulator = self.GetAccumulator(run)
return accumulator.SerializedGraph()
def MetaGraph(self, run):
"""Retrieve the metagraph associated with the provided run.
Args:
run: A string name of a run to load the graph for.
Raises:
KeyError: If the run is not found.
ValueError: If the run does not have an associated graph.
Returns:
The `MetaGraphDef` protobuf data structure.
"""
accumulator = self.GetAccumulator(run)
return accumulator.MetaGraph()
def RunMetadata(self, run, tag):
"""Get the session.run() metadata associated with a TensorFlow run and
tag.
Args:
run: A string name of a TensorFlow run.
tag: A string name of the tag associated with a particular session.run().
Raises:
KeyError: If the run is not found, or the tag is not available for the
given run.
Returns:
The metadata in the form of `RunMetadata` protobuf data structure.
"""
accumulator = self.GetAccumulator(run)
return accumulator.RunMetadata(tag)
def Histograms(self, run, tag):
"""Retrieve the histogram events associated with a run and tag.
Args:
run: A string name of the run for which values are retrieved.
tag: A string name of the tag for which values are retrieved.
Raises:
KeyError: If the run is not found, or the tag is not available for
the given run.
Returns:
An array of `event_accumulator.HistogramEvents`.
"""
accumulator = self.GetAccumulator(run)
return accumulator.Histograms(tag)
def CompressedHistograms(self, run, tag):
"""Retrieve the compressed histogram events associated with a run and
tag.
Args:
run: A string name of the run for which values are retrieved.
tag: A string name of the tag for which values are retrieved.
Raises:
KeyError: If the run is not found, or the tag is not available for
the given run.
Returns:
An array of `event_accumulator.CompressedHistogramEvents`.
"""
accumulator = self.GetAccumulator(run)
return accumulator.CompressedHistograms(tag)
def Images(self, run, tag):
"""Retrieve the image events associated with a run and tag.
Args:
run: A string name of the run for which values are retrieved.
tag: A string name of the tag for which values are retrieved.
Raises:
KeyError: If the run is not found, or the tag is not available for
the given run.
Returns:
An array of `event_accumulator.ImageEvents`.
"""
accumulator = self.GetAccumulator(run)
return accumulator.Images(tag)
def Audio(self, run, tag):
"""Retrieve the audio events associated with a run and tag.
Args:
run: A string name of the run for which values are retrieved.
tag: A string name of the tag for which values are retrieved.
Raises:
KeyError: If the run is not found, or the tag is not available for
the given run.
Returns:
An array of `event_accumulator.AudioEvents`.
"""
accumulator = self.GetAccumulator(run)
return accumulator.Audio(tag)
def Tensors(self, run, tag):
"""Retrieve the tensor events associated with a run and tag.
Args:
run: A string name of the run for which values are retrieved.
tag: A string name of the tag for which values are retrieved.
Raises:
KeyError: If the run is not found, or the tag is not available for
the given run.
Returns:
An array of `event_accumulator.TensorEvent`s.
"""
accumulator = self.GetAccumulator(run)
return accumulator.Tensors(tag)
def PluginRunToTagToContent(self, plugin_name):
"""Returns a 2-layer dictionary of the form {run: {tag: content}}.
        The `content` referred to above is the content field of the PluginData
        proto for the specified plugin within a Summary.Value proto.
Args:
plugin_name: The name of the plugin for which to fetch content.
Returns:
A dictionary of the form {run: {tag: content}}.
"""
mapping = {}
for run in self.Runs():
try:
tag_to_content = self.GetAccumulator(run).PluginTagToContent(
plugin_name
)
except KeyError:
# This run lacks content for the plugin. Try the next run.
continue
mapping[run] = tag_to_content
return mapping
def SummaryMetadata(self, run, tag):
"""Return the summary metadata for the given tag on the given run.
Args:
run: A string name of the run for which summary metadata is to be
retrieved.
tag: A string name of the tag whose summary metadata is to be
retrieved.
Raises:
KeyError: If the run is not found, or the tag is not available for
the given run.
Returns:
A `SummaryMetadata` protobuf.
"""
accumulator = self.GetAccumulator(run)
return accumulator.SummaryMetadata(tag)
def Runs(self):
"""Return all the run names in the `EventMultiplexer`.
Returns:
```
{runName: { images: [tag1, tag2, tag3],
scalarValues: [tagA, tagB, tagC],
histograms: [tagX, tagY, tagZ],
compressedHistograms: [tagX, tagY, tagZ],
graph: true, meta_graph: true}}
```
"""
with self._accumulators_mutex:
# To avoid nested locks, we construct a copy of the run-accumulator map
items = list(self._accumulators.items())
return {run_name: accumulator.Tags() for run_name, accumulator in items}
def RunPaths(self):
"""Returns a dict mapping run names to event file paths."""
return self._paths
def GetAccumulator(self, run):
"""Returns EventAccumulator for a given run.
Args:
run: String name of run.
Returns:
An EventAccumulator object.
Raises:
KeyError: If run does not exist.
"""
with self._accumulators_mutex:
return self._accumulators[run]
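# A minimal usage sketch with hypothetical paths, run names, and tags: load
# every run found under a parent directory and read one scalar series.
_multiplexer = EventMultiplexer()
_multiplexer.AddRunsFromDirectory("/parent/directory/path")
_multiplexer.Reload()
print(list(_multiplexer.Runs()))  # e.g. ['run1', 'run2', ...]
print(_multiplexer.Scalars("run1", "loss"))  # KeyError if the run or tag is absent.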

View File

@ -0,0 +1,68 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functionality for processing events."""
from typing import Optional
from tensorboard.compat.proto import event_pb2
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
# Maximum length for event writer name.
_MAX_WRITER_NAME_LEN = 128
def ParseFileVersion(file_version: str) -> float:
"""Convert the string file_version in event.proto into a float.
Args:
file_version: String file_version from event.proto
Returns:
Version number as a float.
"""
tokens = file_version.split("brain.Event:")
try:
return float(tokens[-1])
except ValueError:
## This should never happen according to the definition of file_version
## specified in event.proto.
logger.warning(
(
"Invalid event.proto file_version. Defaulting to use of "
"out-of-order event.step logic for purging expired events."
)
)
return -1
def GetSourceWriter(
source_metadata: event_pb2.SourceMetadata,
) -> Optional[str]:
"""Gets the source writer name from the source metadata proto."""
writer_name = source_metadata.writer
if not writer_name:
return None
# Checks the length of the writer name.
if len(writer_name) > _MAX_WRITER_NAME_LEN:
logger.error(
"Source writer name `%s` is too long, maximum allowed length is %d.",
writer_name,
_MAX_WRITER_NAME_LEN,
)
return None
return writer_name
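# Small illustrative checks of the two helpers above, using hypothetical
# inputs (the writer name below is made up).
print(ParseFileVersion("brain.Event:2"))  # 2.0
print(ParseFileVersion("not a version"))  # -1, after logging a warning
print(GetSourceWriter(event_pb2.SourceMetadata(writer="my_custom_writer")))  # my_custom_writer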

View File

@ -0,0 +1,224 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""IO helper functions."""
import collections
import os
import re
from tensorboard.compat import tf
from tensorboard.util import io_util
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
_ESCAPE_GLOB_CHARACTERS_REGEX = re.compile("([*?[])")
def PathSeparator(path):
return "/" if io_util.IsCloudPath(path) else os.sep
def IsTensorFlowEventsFile(path):
"""Check the path name to see if it is probably a TF Events file.
Args:
path: A file path to check if it is an event file.
Raises:
ValueError: If the path is an empty string.
Returns:
      Whether path is formatted like a TensorFlow events file. Dummy files such
      as those created with the '.profile-empty' suffix and meant to hold
      no `Summary` protos are still treated as true TensorFlowEventsFiles. For
      background, see: https://github.com/tensorflow/tensorboard/issues/2084.
"""
if not path:
raise ValueError("Path must be a nonempty string")
return "tfevents" in tf.compat.as_str_any(os.path.basename(path))
def IsSummaryEventsFile(path):
"""Check whether the path is probably a TF Events file containing Summary.
Args:
path: A file path to check if it is an event file containing `Summary`
protos.
Returns:
      Whether path is formatted like a TensorFlow events file that can contain
      `Summary` protos. Dummy files such as those created with the
      '.profile-empty' suffix and meant to hold no `Summary` protos are
      treated as `False`. For background, see:
      https://github.com/tensorflow/tensorboard/issues/2084.
"""
return IsTensorFlowEventsFile(path) and not path.endswith(".profile-empty")
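# Small illustrative checks of the two predicates above, using hypothetical
# paths: a '.profile-empty' dummy file counts as an events file but holds no
# `Summary` protos.
print(IsTensorFlowEventsFile("/logs/run1/events.out.tfevents.123.host"))  # True
print(IsSummaryEventsFile("/logs/run1/events.out.tfevents.123.host"))  # True
print(IsTensorFlowEventsFile("/logs/run1/events.out.tfevents.123.host.profile-empty"))  # True
print(IsSummaryEventsFile("/logs/run1/events.out.tfevents.123.host.profile-empty"))  # False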
def ListDirectoryAbsolute(directory):
"""Yields all files in the given directory.
The paths are absolute.
"""
return (
os.path.join(directory, path) for path in tf.io.gfile.listdir(directory)
)
def _EscapeGlobCharacters(path):
"""Escapes the glob characters in a path.
Python 3 has a glob.escape method, but python 2 lacks it, so we manually
implement this method.
Args:
path: The absolute path to escape.
Returns:
The escaped path string.
"""
drive, path = os.path.splitdrive(path)
return "%s%s" % (drive, _ESCAPE_GLOB_CHARACTERS_REGEX.sub(r"[\1]", path))
def ListRecursivelyViaGlobbing(top):
"""Recursively lists all files within the directory.
    This method lists subdirectories (in addition to regular files), and the
    file paths are all absolute. If the directory does not exist, this yields
    nothing.
This method does so by glob-ing deeper and deeper directories, ie
foo/*, foo/*/*, foo/*/*/* and so on until all files are listed. All file
paths are absolute, and this method lists subdirectories too.
For certain file systems, globbing via this method may prove significantly
faster than recursively walking a directory. Specifically, TF file systems
that implement TensorFlow's FileSystem.GetMatchingPaths method could save
costly disk reads by using this method. However, for other file systems, this
method might prove slower because the file system performs a walk per call to
glob (in which case it might as well just perform 1 walk).
Args:
top: A path to a directory.
Yields:
A (dir_path, file_paths) tuple for each directory/subdirectory.
"""
current_glob_string = os.path.join(_EscapeGlobCharacters(top), "*")
level = 0
while True:
logger.info("GlobAndListFiles: Starting to glob level %d", level)
glob = tf.io.gfile.glob(current_glob_string)
logger.info(
"GlobAndListFiles: %d files glob-ed at level %d", len(glob), level
)
if not glob:
# This subdirectory level lacks files. Terminate.
return
# Map subdirectory to a list of files.
pairs = collections.defaultdict(list)
for file_path in glob:
pairs[os.path.dirname(file_path)].append(file_path)
for dir_name, file_paths in pairs.items():
yield (dir_name, tuple(file_paths))
if len(pairs) == 1:
# If at any point the glob returns files that are all in a single
# directory, replace the current globbing path with that directory as the
# literal prefix. This should improve efficiency in cases where a single
            # subdir is significantly deeper than the rest of the subdirs.
current_glob_string = os.path.join(list(pairs.keys())[0], "*")
# Iterate to the next level of subdirectories.
current_glob_string = os.path.join(current_glob_string, "*")
level += 1
def ListRecursivelyViaWalking(top):
"""Walks a directory tree, yielding (dir_path, file_paths) tuples.
For each of `top` and its subdirectories, yields a tuple containing the path
to the directory and the path to each of the contained files. Note that
unlike os.Walk()/tf.io.gfile.walk()/ListRecursivelyViaGlobbing, this does not
list subdirectories. The file paths are all absolute. If the directory does
not exist, this yields nothing.
Walking may be incredibly slow on certain file systems.
Args:
top: A path to a directory.
Yields:
A (dir_path, file_paths) tuple for each directory/subdirectory.
"""
for dir_path, _, filenames in tf.io.gfile.walk(top, topdown=True):
yield (
dir_path,
(os.path.join(dir_path, filename) for filename in filenames),
)
def GetLogdirSubdirectories(path):
"""Obtains all subdirectories with events files.
The order of the subdirectories returned is unspecified. The internal logic
that determines order varies by scenario.
Args:
path: The path to a directory under which to find subdirectories.
Returns:
A tuple of absolute paths of all subdirectories each with at least 1 events
file directly within the subdirectory.
Raises:
ValueError: If the path passed to the method exists and is not a directory.
"""
if not tf.io.gfile.exists(path):
# No directory to traverse.
return ()
if not tf.io.gfile.isdir(path):
raise ValueError(
"GetLogdirSubdirectories: path exists and is not a "
"directory, %s" % path
)
if io_util.IsCloudPath(path):
# Glob-ing for files can be significantly faster than recursively
# walking through directories for some file systems.
logger.info(
"GetLogdirSubdirectories: Starting to list directories via glob-ing."
)
traversal_method = ListRecursivelyViaGlobbing
else:
# For other file systems, the glob-ing based method might be slower because
# each call to glob could involve performing a recursive walk.
logger.info(
"GetLogdirSubdirectories: Starting to list directories via walking."
)
traversal_method = ListRecursivelyViaWalking
return (
subdir
for (subdir, files) in traversal_method(path)
if any(IsTensorFlowEventsFile(f) for f in files)
)
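# A minimal usage sketch with a hypothetical log directory: print every
# subdirectory that directly contains at least one events file.
for _subdir in GetLogdirSubdirectories("/tmp/my_logdir"):
    print(_subdir)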

View File

@ -0,0 +1,105 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Load plugin assets from disk."""
import os.path
from tensorboard.compat import tf
_PLUGINS_DIR = "plugins"
def _IsDirectory(parent, item):
"""Helper that returns if parent/item is a directory."""
return tf.io.gfile.isdir(os.path.join(parent, item))
def PluginDirectory(logdir, plugin_name):
"""Returns the plugin directory for plugin_name."""
return os.path.join(logdir, _PLUGINS_DIR, plugin_name)
def ListPlugins(logdir):
"""List all the plugins that have registered assets in logdir.
If the plugins_dir does not exist, it returns an empty list. This maintains
compatibility with old directories that have no plugins written.
Args:
logdir: A directory that was created by a TensorFlow events writer.
Returns:
a list of plugin names, as strings
"""
plugins_dir = os.path.join(logdir, _PLUGINS_DIR)
try:
entries = tf.io.gfile.listdir(plugins_dir)
except tf.errors.NotFoundError:
return []
# Strip trailing slashes, which listdir() includes for some filesystems
# for subdirectories, after using them to bypass IsDirectory().
return [
x.rstrip("/")
for x in entries
if x.endswith("/") or _IsDirectory(plugins_dir, x)
]
def ListAssets(logdir, plugin_name):
"""List all the assets that are available for given plugin in a logdir.
Args:
logdir: A directory that was created by a TensorFlow summary.FileWriter.
plugin_name: A string name of a plugin to list assets for.
Returns:
A string list of available plugin assets. If the plugin subdirectory does
not exist (either because the logdir doesn't exist, or because the plugin
didn't register) an empty list is returned.
"""
plugin_dir = PluginDirectory(logdir, plugin_name)
try:
# Strip trailing slashes, which listdir() includes for some filesystems.
return [x.rstrip("/") for x in tf.io.gfile.listdir(plugin_dir)]
except tf.errors.NotFoundError:
return []
def RetrieveAsset(logdir, plugin_name, asset_name):
"""Retrieve a particular plugin asset from a logdir.
Args:
logdir: A directory that was created by a TensorFlow summary.FileWriter.
plugin_name: The plugin we want an asset from.
asset_name: The name of the requested asset.
Returns:
string contents of the plugin asset.
Raises:
KeyError: if the asset does not exist.
"""
asset_path = os.path.join(PluginDirectory(logdir, plugin_name), asset_name)
try:
with tf.io.gfile.GFile(asset_path, "r") as f:
return f.read()
except tf.errors.NotFoundError:
raise KeyError("Asset path %s not found" % asset_path)
except tf.errors.OpError as e:
raise KeyError(
"Couldn't read asset path: %s, OpError %s" % (asset_path, e)
)
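# Illustrative usage (a sketch, not part of the original module; the logdir
# path, plugin name, and asset name below are hypothetical):
#
#   for plugin in ListPlugins("/tmp/logs/run1"):
#       for asset in ListAssets("/tmp/logs/run1", plugin):
#           contents = RetrieveAsset("/tmp/logs/run1", plugin, asset)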


@ -0,0 +1,722 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Takes a generator of values, and accumulates them for a frontend."""
import collections
import dataclasses
import threading
from typing import Optional
from tensorboard.backend.event_processing import directory_loader
from tensorboard.backend.event_processing import directory_watcher
from tensorboard.backend.event_processing import event_file_loader
from tensorboard.backend.event_processing import event_util
from tensorboard.backend.event_processing import io_wrapper
from tensorboard.backend.event_processing import plugin_asset_util
from tensorboard.backend.event_processing import reservoir
from tensorboard.backend.event_processing import tag_types
from tensorboard.compat.proto import config_pb2
from tensorboard.compat.proto import event_pb2
from tensorboard.compat.proto import graph_pb2
from tensorboard.compat.proto import meta_graph_pb2
from tensorboard.compat.proto import tensor_pb2
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
# Legacy aliases
TENSORS = tag_types.TENSORS
GRAPH = tag_types.GRAPH
META_GRAPH = tag_types.META_GRAPH
RUN_METADATA = tag_types.RUN_METADATA
DEFAULT_SIZE_GUIDANCE = {
TENSORS: 500,
}
STORE_EVERYTHING_SIZE_GUIDANCE = {
TENSORS: 0,
}
_TENSOR_RESERVOIR_KEY = "." # arbitrary
@dataclasses.dataclass(frozen=True)
class TensorEvent:
"""A tensor event.
Attributes:
wall_time: Timestamp of the event in seconds.
step: Global step of the event.
tensor_proto: A `TensorProto`.
"""
wall_time: float
step: int
tensor_proto: tensor_pb2.TensorProto
class EventAccumulator:
"""An `EventAccumulator` takes an event generator, and accumulates the
values.
The `EventAccumulator` is intended to provide a convenient Python
interface for loading Event data written during a TensorFlow run.
TensorFlow writes out `Event` protobuf objects, which have a timestamp
and step number, and often contain a `Summary`. Summaries can have
different kinds of data stored as arbitrary tensors. The Summaries
also have a tag, which we use to organize logically related data. The
`EventAccumulator` supports retrieving the `Event` and `Summary` data
by its tag.
Calling `Tags()` gets a map from `tagType` (i.e., `tensors`) to the
associated tags for those data types. Then, the functional endpoint
(e.g., `Accumulator.Tensors(tag)`) allows for the retrieval of all
data associated with that tag.
The `Reload()` method synchronously loads all of the data written so far.
Fields:
most_recent_step: Step of last Event proto added. This should only
be accessed from the thread that calls Reload. This is -1 if
nothing has been loaded yet.
most_recent_wall_time: Timestamp of last Event proto added. This is
a float containing seconds from the UNIX epoch, or -1 if
nothing has been loaded yet. This should only be accessed from
the thread that calls Reload.
path: A file path to a directory containing tf events files, or a single
tf events file. The accumulator will load events from this path.
tensors_by_tag: A dictionary mapping each tag name to a
reservoir.Reservoir of tensor summaries. Each such reservoir will
only use a single key, given by `_TENSOR_RESERVOIR_KEY`.
@@Tensors
"""
def __init__(
self,
path,
size_guidance=None,
tensor_size_guidance=None,
purge_orphaned_data=True,
event_file_active_filter=None,
detect_file_replacement=None,
):
"""Construct the `EventAccumulator`.
Args:
path: A file path to a directory containing tf events files, or a single
tf events file. The accumulator will load events from this path.
size_guidance: Information on how much data the EventAccumulator should
store in memory. The DEFAULT_SIZE_GUIDANCE tries not to store too much
so as to avoid OOMing the client. The size_guidance should be a map
from a `tagType` string to an integer representing the number of
items to keep per tag for items of that `tagType`. If the size is 0,
all events are stored.
tensor_size_guidance: Like `size_guidance`, but allowing finer
granularity for tensor summaries. Should be a map from the
`plugin_name` field on the `PluginData` proto to an integer
representing the number of items to keep per tag. Plugins for
which there is no entry in this map will default to the value of
`size_guidance[event_accumulator.TENSORS]`. Defaults to `{}`.
purge_orphaned_data: Whether to discard any events that were "orphaned" by
a TensorFlow restart.
event_file_active_filter: Optional predicate for determining, from an
event file's latest load timestamp, whether the file should be considered
active. If passed, this enables multifile directory loading.
detect_file_replacement: Optional boolean; if True, event file loading
will try to detect when a file has been replaced with a new version
that contains additional data, by monitoring the file size.
"""
size_guidance = dict(size_guidance or DEFAULT_SIZE_GUIDANCE)
sizes = {}
for key in DEFAULT_SIZE_GUIDANCE:
if key in size_guidance:
sizes[key] = size_guidance[key]
else:
sizes[key] = DEFAULT_SIZE_GUIDANCE[key]
self._size_guidance = size_guidance
self._tensor_size_guidance = dict(tensor_size_guidance or {})
self._first_event_timestamp = None
self._graph = None
self._graph_from_metagraph = False
self._meta_graph = None
self._tagged_metadata = {}
self.summary_metadata = {}
self.tensors_by_tag = {}
self._tensors_by_tag_lock = threading.Lock()
# Keep a mapping from plugin name to a dict mapping from tag to plugin data
# content obtained from the SummaryMetadata (metadata field of Value) for
# that plugin (This is not the entire SummaryMetadata proto - only the
# content for that plugin). The SummaryWriter only keeps the content on the
# first event encountered per tag, so we must store that first instance of
# content for each tag.
self._plugin_to_tag_to_content = collections.defaultdict(dict)
# Locks the dict `_plugin_to_tag_to_content` as well as the
# dicts `_plugin_to_tag_to_content[p]` for each `p`.
self._plugin_tag_lock = threading.Lock()
self.path = path
self._generator = _GeneratorFromPath(
path, event_file_active_filter, detect_file_replacement
)
self._generator_mutex = threading.Lock()
self.purge_orphaned_data = purge_orphaned_data
self._seen_session_start = False
self.most_recent_step = -1
self.most_recent_wall_time = -1
self.file_version = None
# Name of the source writer that writes the event.
self._source_writer = None
def Reload(self):
"""Loads all events added since the last call to `Reload`.
If `Reload` was never called, loads all events in the file.
Returns:
The `EventAccumulator`.
"""
with self._generator_mutex:
for event in self._generator.Load():
self._ProcessEvent(event)
return self
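# Illustrative usage of the accumulator (a sketch, not part of the original
# module; the path and tag below are hypothetical):
#
#   acc = EventAccumulator("/tmp/logs/run1")
#   acc.Reload()           # load everything written so far
#   acc.Tags()             # e.g. {"tensors": ["loss"], "graph": False, ...}
#   acc.Tensors("loss")    # list of TensorEvent records for that tag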
def PluginAssets(self, plugin_name):
"""Return a list of all plugin assets for the given plugin.
Args:
plugin_name: The string name of a plugin to retrieve assets for.
Returns:
A list of string plugin asset names, or empty list if none are available.
If the plugin was not registered, an empty list is returned.
"""
return plugin_asset_util.ListAssets(self.path, plugin_name)
def RetrievePluginAsset(self, plugin_name, asset_name):
"""Return the contents of a given plugin asset.
Args:
plugin_name: The string name of a plugin.
asset_name: The string name of an asset.
Returns:
The string contents of the plugin asset.
Raises:
KeyError: If the asset is not available.
"""
return plugin_asset_util.RetrieveAsset(
self.path, plugin_name, asset_name
)
def FirstEventTimestamp(self):
"""Returns the timestamp in seconds of the first event.
If the first event has been loaded (either by this method or by `Reload`),
this returns immediately. Otherwise, it will load in the first event. Note
that this means that calling `Reload` will cause this to block until
`Reload` has finished.
Returns:
The timestamp in seconds of the first event that was loaded.
Raises:
ValueError: If no events have been loaded and there were no events found
on disk.
"""
if self._first_event_timestamp is not None:
return self._first_event_timestamp
with self._generator_mutex:
try:
event = next(self._generator.Load())
self._ProcessEvent(event)
return self._first_event_timestamp
except StopIteration:
raise ValueError("No event timestamp could be found")
def GetSourceWriter(self) -> Optional[str]:
"""Returns the name of the event writer."""
if self._source_writer is not None:
return self._source_writer
with self._generator_mutex:
try:
event = next(self._generator.Load())
self._ProcessEvent(event)
return self._source_writer
except StopIteration:
logger.info(
"End of file in %s, no source writer was found.", self.path
)
def PluginTagToContent(self, plugin_name):
"""Returns a dict mapping tags to content specific to that plugin.
Args:
plugin_name: The name of the plugin for which to fetch plugin-specific
content.
Raises:
KeyError: if the plugin name is not found.
Returns:
A dict mapping tag names to bytestrings of plugin-specific content-- by
convention, in the form of binary serialized protos.
"""
with self._plugin_tag_lock:
if plugin_name not in self._plugin_to_tag_to_content:
raise KeyError("Plugin %r could not be found." % plugin_name)
# Return a snapshot to avoid concurrent mutation and iteration issues.
return dict(self._plugin_to_tag_to_content[plugin_name])
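# Illustrative shape of the result (hypothetical plugin, tags, and content):
#
#   acc.PluginTagToContent("scalars")
#   # -> {"loss": b"...", "accuracy": b"..."}  (plugin-specific serialized bytes)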
def ActivePlugins(self):
"""Return a set of plugins with summary data.
Returns:
The distinct union of `plugin_data.plugin_name` fields from
all the `SummaryMetadata` protos stored in this accumulator.
"""
with self._plugin_tag_lock:
return frozenset(self._plugin_to_tag_to_content)
def SummaryMetadata(self, tag):
"""Given a summary tag name, return the associated metadata object.
Args:
tag: The name of a tag, as a string.
Raises:
KeyError: If the tag is not found.
Returns:
A `SummaryMetadata` protobuf.
"""
return self.summary_metadata[tag]
def AllSummaryMetadata(self):
"""Return summary metadata for all tags.
Returns:
A dict `d` such that `d[tag]` is a `SummaryMetadata` proto for
the keyed tag.
"""
return dict(self.summary_metadata)
def _ProcessEvent(self, event):
"""Called whenever an event is loaded."""
if self._first_event_timestamp is None:
self._first_event_timestamp = event.wall_time
if event.HasField("source_metadata"):
new_source_writer = event_util.GetSourceWriter(
event.source_metadata
)
if self._source_writer and self._source_writer != new_source_writer:
logger.info(
(
"Found new source writer for event.proto. "
"Old: {0}, New: {1}"
).format(self._source_writer, new_source_writer)
)
self._source_writer = new_source_writer
if event.HasField("file_version"):
new_file_version = event_util.ParseFileVersion(event.file_version)
if self.file_version and self.file_version != new_file_version:
## This should not happen.
logger.warning(
(
"Found new file_version for event.proto. This will "
"affect purging logic for TensorFlow restarts. "
"Old: {0} New: {1}"
).format(self.file_version, new_file_version)
)
self.file_version = new_file_version
self._MaybePurgeOrphanedData(event)
## Process the event.
# GraphDef and MetaGraphDef are handled in a special way:
# If no graph_def Event is available, but a meta_graph_def is, and it
# contains a graph_def, then use the meta_graph_def.graph_def as our graph.
# If a graph_def Event is available, always prefer it to the graph_def
# inside the meta_graph_def.
if event.HasField("graph_def"):
if self._graph is not None:
logger.warning(
(
"Found more than one graph event per run, or there was "
"a metagraph containing a graph_def, as well as one or "
"more graph events. Overwriting the graph with the "
"newest event."
)
)
self._graph = event.graph_def
self._graph_from_metagraph = False
elif event.HasField("meta_graph_def"):
if self._meta_graph is not None:
logger.warning(
(
"Found more than one metagraph event per run. "
"Overwriting the metagraph with the newest event."
)
)
self._meta_graph = event.meta_graph_def
if self._graph is None or self._graph_from_metagraph:
# We may have a graph_def in the metagraph. If so, and no
# graph_def is directly available, use this one instead.
meta_graph = meta_graph_pb2.MetaGraphDef()
meta_graph.ParseFromString(self._meta_graph)
if meta_graph.graph_def:
if self._graph is not None:
logger.warning(
(
"Found multiple metagraphs containing graph_defs,"
"but did not find any graph events. Overwriting the "
"graph with the newest metagraph version."
)
)
self._graph_from_metagraph = True
self._graph = meta_graph.graph_def.SerializeToString()
elif event.HasField("tagged_run_metadata"):
tag = event.tagged_run_metadata.tag
if tag in self._tagged_metadata:
logger.warning(
'Found more than one "run metadata" event with tag '
+ tag
+ ". Overwriting it with the newest event."
)
self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata
elif event.HasField("summary"):
for value in event.summary.value:
if value.HasField("metadata"):
tag = value.tag
# We only store the first instance of the metadata. This check
# is important: the `FileWriter` does strip metadata from all
# values except the first one per each tag, but a new
# `FileWriter` is created every time a training job stops and
# restarts. Hence, we must also ignore non-initial metadata in
# this logic.
if tag not in self.summary_metadata:
self.summary_metadata[tag] = value.metadata
plugin_data = value.metadata.plugin_data
if plugin_data.plugin_name:
with self._plugin_tag_lock:
self._plugin_to_tag_to_content[
plugin_data.plugin_name
][tag] = plugin_data.content
else:
logger.warning(
(
"This summary with tag %r is oddly not associated with a "
"plugin."
),
tag,
)
if value.HasField("tensor"):
datum = value.tensor
tag = value.tag
if not tag:
# This tensor summary was created using the old method that used
# plugin assets. We must still continue to support it.
tag = value.node_name
self._ProcessTensor(tag, event.wall_time, event.step, datum)
def Tags(self):
"""Return all tags found in the value stream.
Returns:
A `{tagType: ['list', 'of', 'tags']}` dictionary.
"""
return {
TENSORS: list(self.tensors_by_tag.keys()),
# Use a heuristic: if the metagraph is available, but
# graph is not, then we assume the metagraph contains the graph.
GRAPH: self._graph is not None,
META_GRAPH: self._meta_graph is not None,
RUN_METADATA: list(self._tagged_metadata.keys()),
}
def Graph(self):
"""Return the graph definition, if there is one.
If the graph is stored directly, return that. If no graph is stored
directly but a metagraph is stored containing a graph, return that.
Raises:
ValueError: If there is no graph for this run.
Returns:
The `graph_def` proto.
"""
graph = graph_pb2.GraphDef()
if self._graph is not None:
graph.ParseFromString(self._graph)
return graph
raise ValueError("There is no graph in this EventAccumulator")
def SerializedGraph(self):
"""Return the graph definition in serialized form, if there is one."""
return self._graph
def MetaGraph(self):
"""Return the metagraph definition, if there is one.
Raises:
ValueError: If there is no metagraph for this run.
Returns:
The `meta_graph_def` proto.
"""
if self._meta_graph is None:
raise ValueError("There is no metagraph in this EventAccumulator")
meta_graph = meta_graph_pb2.MetaGraphDef()
meta_graph.ParseFromString(self._meta_graph)
return meta_graph
def RunMetadata(self, tag):
"""Given a tag, return the associated session.run() metadata.
Args:
tag: A string tag associated with the event.
Raises:
ValueError: If the tag is not found.
Returns:
The metadata in form of `RunMetadata` proto.
"""
if tag not in self._tagged_metadata:
raise ValueError("There is no run metadata with this tag name")
run_metadata = config_pb2.RunMetadata()
run_metadata.ParseFromString(self._tagged_metadata[tag])
return run_metadata
def Tensors(self, tag):
"""Given a summary tag, return all associated tensors.
Args:
tag: A string tag associated with the events.
Raises:
KeyError: If the tag is not found.
Returns:
An array of `TensorEvent`s.
"""
return self.tensors_by_tag[tag].Items(_TENSOR_RESERVOIR_KEY)
def _MaybePurgeOrphanedData(self, event):
"""Maybe purge orphaned data due to a TensorFlow crash.
When TensorFlow crashes at step T+O and restarts at step T, any events
written after step T are now "orphaned" and will be at best misleading if
they are included in TensorBoard.
This logic attempts to determine if there is orphaned data, and purge it
if it is found.
Args:
event: The event to use as a reference, to determine if a purge is needed.
"""
if not self.purge_orphaned_data:
return
## Check if the event happened after a crash, and purge expired tags.
if self.file_version and self.file_version >= 2:
## If the file_version is recent enough, use the SessionLog enum
## to check for restarts.
self._CheckForRestartAndMaybePurge(event)
else:
## If there is no file version, default to old logic of checking for
## out of order steps.
self._CheckForOutOfOrderStepAndMaybePurge(event)
# After checking, update the most recent summary step and wall time.
if event.HasField("summary"):
self.most_recent_step = event.step
self.most_recent_wall_time = event.wall_time
def _CheckForRestartAndMaybePurge(self, event):
"""Check and discard expired events using SessionLog.START.
The first SessionLog.START event in a run indicates the start of a
supervisor session. Subsequent SessionLog.START events indicate a
*restart*, which may need to preempt old events. This method checks
for a session restart event and purges all previously seen events whose
step is larger than or equal to this event's step.
Because of supervisor threading, it is possible that this logic will
cause the first few event messages to be discarded since supervisor
threading does not guarantee that the START message is deterministically
written first.
This method is preferred over _CheckForOutOfOrderStepAndMaybePurge which
can inadvertently discard events due to supervisor threading.
Args:
event: The event to use as reference. If the event is a START event, all
previously seen events with a greater event.step will be purged.
"""
if event.session_log.status != event_pb2.SessionLog.START:
return
if not self._seen_session_start:
# Initial start event: does not indicate a restart.
self._seen_session_start = True
return
self._Purge(event, by_tags=False)
def _CheckForOutOfOrderStepAndMaybePurge(self, event):
"""Check for out-of-order event.step and discard expired events for
tags.
Check if the event is out of order relative to the global most recent step.
If it is, purge outdated summaries for tags that the event contains.
Args:
event: The event to use as reference. If the event is out-of-order, all
events with the same tags, but with a greater event.step will be purged.
"""
if event.step < self.most_recent_step and event.HasField("summary"):
self._Purge(event, by_tags=True)
def _ProcessTensor(self, tag, wall_time, step, tensor):
tv = TensorEvent(wall_time=wall_time, step=step, tensor_proto=tensor)
with self._tensors_by_tag_lock:
if tag not in self.tensors_by_tag:
reservoir_size = self._GetTensorReservoirSize(tag)
self.tensors_by_tag[tag] = reservoir.Reservoir(reservoir_size)
self.tensors_by_tag[tag].AddItem(_TENSOR_RESERVOIR_KEY, tv)
def _GetTensorReservoirSize(self, tag):
default = self._size_guidance[TENSORS]
summary_metadata = self.summary_metadata.get(tag)
if summary_metadata is None:
return default
return self._tensor_size_guidance.get(
summary_metadata.plugin_data.plugin_name, default
)
def _Purge(self, event, by_tags):
"""Purge all events that have occurred after the given event.step.
If by_tags is True, purge all events that occurred after the given
event.step, but only for the tags that the event has. Non-sequential
event.steps suggest that a TensorFlow restart occurred, and we discard
the out-of-order events to display a consistent view in TensorBoard.
Discarding by tags is the safer method, when we are unsure whether a restart
has occurred, given that threading in supervisor can cause events of
different tags to arrive with unsynchronized step values.
If by_tags is False, then purge all events with event.step greater than the
given event.step. This can be used when we are certain that a TensorFlow
restart has occurred and these events can be discarded.
Args:
event: The event to use as reference for the purge. All events with
the same tags, but with a greater event.step will be purged.
by_tags: Bool to dictate whether to discard all out-of-order events or
only those that are associated with the given reference event.
"""
## Keep data in reservoirs that has a step less than event.step
_NotExpired = lambda x: x.step < event.step
num_expired = 0
if by_tags:
for value in event.summary.value:
if value.tag in self.tensors_by_tag:
tag_reservoir = self.tensors_by_tag[value.tag]
num_expired += tag_reservoir.FilterItems(
_NotExpired, _TENSOR_RESERVOIR_KEY
)
else:
for tag_reservoir in self.tensors_by_tag.values():
num_expired += tag_reservoir.FilterItems(
_NotExpired, _TENSOR_RESERVOIR_KEY
)
if num_expired > 0:
purge_msg = _GetPurgeMessage(
self.most_recent_step,
self.most_recent_wall_time,
event.step,
event.wall_time,
num_expired,
)
logger.warning(purge_msg)
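# Worked example of the purge logic above (hypothetical numbers): if the most
# recently seen summary step was 100 and a new event arrives with step 40,
# every buffered TensorEvent with step >= 40 is removed from the affected
# reservoirs, because `_NotExpired` keeps only items with step < event.step.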
def _GetPurgeMessage(
most_recent_step,
most_recent_wall_time,
event_step,
event_wall_time,
num_expired,
):
"""Return the string message associated with TensorBoard purges."""
return (
"Detected out of order event.step likely caused by a TensorFlow "
"restart. Purging {} expired tensor events from Tensorboard display "
"between the previous step: {} (timestamp: {}) and current step: {} "
"(timestamp: {})."
).format(
num_expired,
most_recent_step,
most_recent_wall_time,
event_step,
event_wall_time,
)
def _GeneratorFromPath(
path, event_file_active_filter=None, detect_file_replacement=None
):
"""Create an event generator for file or directory at given path string."""
if not path:
raise ValueError("path must be a valid string")
if io_wrapper.IsSummaryEventsFile(path):
return event_file_loader.EventFileLoader(path, detect_file_replacement)
elif event_file_active_filter:
loader_factory = (
lambda path: event_file_loader.TimestampedEventFileLoader(
path, detect_file_replacement
)
)
return directory_loader.DirectoryLoader(
path,
loader_factory,
path_filter=io_wrapper.IsSummaryEventsFile,
active_filter=event_file_active_filter,
)
else:
loader_factory = lambda path: event_file_loader.EventFileLoader(
path, detect_file_replacement
)
return directory_watcher.DirectoryWatcher(
path,
loader_factory,
io_wrapper.IsSummaryEventsFile,
)
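# Illustrative dispatch behavior (a sketch; the paths below are hypothetical):
#
#   _GeneratorFromPath("events.out.tfevents.123.host")  # -> EventFileLoader
#   _GeneratorFromPath("/tmp/logs/run1")                 # -> DirectoryWatcher
#   _GeneratorFromPath("/tmp/logs/run1",
#                      event_file_active_filter=lambda ts: True)
#                                                        # -> DirectoryLoader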


@ -0,0 +1,524 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides an interface for working with multiple event files."""
import os
import queue
import threading
from typing import Optional
from tensorboard.backend.event_processing import directory_watcher
from tensorboard.backend.event_processing import (
plugin_event_accumulator as event_accumulator,
)
from tensorboard.backend.event_processing import io_wrapper
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
class EventMultiplexer:
"""An `EventMultiplexer` manages access to multiple `EventAccumulator`s.
Each `EventAccumulator` is associated with a `run`, which is a self-contained
TensorFlow execution. The `EventMultiplexer` provides methods for extracting
information about events from multiple `run`s.
Example usage for loading specific runs from files:
```python
x = EventMultiplexer({'run1': 'path/to/run1', 'run2': 'path/to/run2'})
x.Reload()
```
Example usage for loading a directory where each subdirectory is a run
```python
(e.g.) /parent/directory/path/
/parent/directory/path/run1/
/parent/directory/path/run1/events.out.tfevents.1001
/parent/directory/path/run1/events.out.tfevents.1002
/parent/directory/path/run2/
/parent/directory/path/run2/events.out.tfevents.9232
/parent/directory/path/run3/
/parent/directory/path/run3/events.out.tfevents.9232
x = EventMultiplexer().AddRunsFromDirectory('/parent/directory/path')
(which is equivalent to:)
x = EventMultiplexer({'run1': '/parent/directory/path/run1', 'run2': ...})
```
If you would like to watch `/parent/directory/path`, wait for it to be created
(if necessary) and then periodically pick up new runs, use
`AutoloadingMultiplexer`
@@Tensors
"""
def __init__(
self,
run_path_map=None,
size_guidance=None,
tensor_size_guidance=None,
purge_orphaned_data=True,
max_reload_threads=None,
event_file_active_filter=None,
detect_file_replacement=None,
):
"""Constructor for the `EventMultiplexer`.
Args:
run_path_map: Dict `{run: path}` which specifies the
name of a run, and the path to find the associated events. If it is
None, then the EventMultiplexer initializes without any runs.
size_guidance: A dictionary mapping from `tagType` to the number of items
to store for each tag of that type. See
`event_accumulator.EventAccumulator` for details.
tensor_size_guidance: A dictionary mapping from `plugin_name` to
the number of items to store for each tag of that type. See
`event_accumulator.EventAccumulator` for details.
purge_orphaned_data: Whether to discard any events that were "orphaned" by
a TensorFlow restart.
max_reload_threads: The max number of threads that TensorBoard can use
to reload runs. Each thread reloads one run at a time. If not provided,
reloads runs serially (one after another).
event_file_active_filter: Optional predicate for determining, from an
event file's latest load timestamp, whether the file should be considered
active. If passed, this enables multifile directory loading.
detect_file_replacement: Optional boolean; if True, event file loading
will try to detect when a file has been replaced with a new version
that contains additional data, by monitoring the file size.
"""
logger.info("Event Multiplexer initializing.")
self._accumulators_mutex = threading.Lock()
self._accumulators = {}
self._paths = {}
self._reload_called = False
self._size_guidance = (
size_guidance or event_accumulator.DEFAULT_SIZE_GUIDANCE
)
self._tensor_size_guidance = tensor_size_guidance
self.purge_orphaned_data = purge_orphaned_data
self._max_reload_threads = max_reload_threads or 1
self._event_file_active_filter = event_file_active_filter
self._detect_file_replacement = detect_file_replacement
if run_path_map is not None:
logger.info(
"Event Multplexer doing initialization load for %s",
run_path_map,
)
for run, path in run_path_map.items():
self.AddRun(path, run)
logger.info("Event Multiplexer done initializing")
def AddRun(self, path, name=None):
"""Add a run to the multiplexer.
If the name is not specified, it is the same as the path.
If a run by that name exists, and we are already watching the right path,
do nothing. If we are watching a different path, replace the event
accumulator.
If `Reload` has been called, it will `Reload` the newly created
accumulators.
Args:
path: Path to the event files (or event directory) for given run.
name: Name of the run to add. If not provided, is set to path.
Returns:
The `EventMultiplexer`.
"""
name = name or path
accumulator = None
with self._accumulators_mutex:
if name not in self._accumulators or self._paths[name] != path:
if name in self._paths and self._paths[name] != path:
# TODO(@decentralion) - Make it impossible to overwrite an old path
# with a new path (just give the new path a distinct name)
logger.warning(
"Conflict for name %s: old path %s, new path %s",
name,
self._paths[name],
path,
)
logger.info("Constructing EventAccumulator for %s", path)
accumulator = event_accumulator.EventAccumulator(
path,
size_guidance=self._size_guidance,
tensor_size_guidance=self._tensor_size_guidance,
purge_orphaned_data=self.purge_orphaned_data,
event_file_active_filter=self._event_file_active_filter,
detect_file_replacement=self._detect_file_replacement,
)
self._accumulators[name] = accumulator
self._paths[name] = path
if accumulator:
if self._reload_called:
accumulator.Reload()
return self
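# Illustrative usage (a sketch; the names and paths below are hypothetical):
#
#   multiplexer = EventMultiplexer()
#   multiplexer.AddRun("/tmp/logs/run1")             # run name defaults to the path
#   multiplexer.AddRun("/tmp/logs/run2", name="exp2")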
def AddRunsFromDirectory(self, path, name=None):
"""Load runs from a directory; recursively walks subdirectories.
If path doesn't exist, no-op. This ensures that it is safe to call
`AddRunsFromDirectory` multiple times, even before the directory is made.
If path is a directory, load event files in the directory (if any exist) and
recursively call AddRunsFromDirectory on any subdirectories. This means you
can call AddRunsFromDirectory at the root of a tree of event logs and
TensorBoard will load them all.
If the `EventMultiplexer` is already loaded this will cause
the newly created accumulators to `Reload()`.
Args:
path: A string path to a directory to load runs from.
name: Optionally, what name to apply to the runs. If name is provided
and the directory contains run subdirectories, the name of each subrun
is the concatenation of the parent name and the subdirectory name. If
name is provided and the directory contains event files, then a run named
`name` is added, containing the events from that path.
Raises:
ValueError: If the path exists and isn't a directory.
Returns:
The `EventMultiplexer`.
"""
path = os.path.expanduser(path)
logger.info("Starting AddRunsFromDirectory: %s", path)
for subdir in io_wrapper.GetLogdirSubdirectories(path):
logger.info("Adding run from directory %s", subdir)
rpath = os.path.relpath(subdir, path)
subname = os.path.join(name, rpath) if name else rpath
self.AddRun(subdir, name=subname)
logger.info("Done with AddRunsFromDirectory: %s", path)
return self
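# Illustrative usage (hypothetical layout): given /tmp/logs/run1 and
# /tmp/logs/run2 that each contain events files,
#
#   multiplexer.AddRunsFromDirectory("/tmp/logs")
#
# adds two runs named "run1" and "run2" (paths relative to the given root).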
def Reload(self):
"""Call `Reload` on every `EventAccumulator`."""
logger.info("Beginning EventMultiplexer.Reload()")
self._reload_called = True
# Build a list so we're safe even if the list of accumulators is modified
# even while we're reloading.
with self._accumulators_mutex:
items = list(self._accumulators.items())
items_queue = queue.Queue()
for item in items:
items_queue.put(item)
# Methods of built-in python containers are thread-safe so long as the GIL
# for the thread exists, but we might as well be careful.
names_to_delete = set()
names_to_delete_mutex = threading.Lock()
def Worker():
"""Keeps reloading accumulators til none are left."""
while True:
try:
name, accumulator = items_queue.get(block=False)
except queue.Empty:
# No more runs to reload.
break
try:
accumulator.Reload()
except (OSError, IOError) as e:
logger.error("Unable to reload accumulator %r: %s", name, e)
except directory_watcher.DirectoryDeletedError:
with names_to_delete_mutex:
names_to_delete.add(name)
finally:
items_queue.task_done()
if self._max_reload_threads > 1:
num_threads = min(self._max_reload_threads, len(items))
logger.info("Starting %d threads to reload runs", num_threads)
for i in range(num_threads):
thread = threading.Thread(target=Worker, name="Reloader %d" % i)
thread.daemon = True
thread.start()
items_queue.join()
else:
logger.info(
"Reloading runs serially (one after another) on the main "
"thread."
)
Worker()
with self._accumulators_mutex:
for name in names_to_delete:
logger.warning("Deleting accumulator %r", name)
del self._accumulators[name]
logger.info("Finished with EventMultiplexer.Reload()")
return self
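# Illustrative usage (a sketch; tag names are hypothetical):
#
#   multiplexer.Reload()   # reloads every accumulator, possibly on worker threads
#   multiplexer.Runs()     # e.g. {"run1": {"tensors": ["loss"], "graph": False, ...}}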
def PluginAssets(self, plugin_name):
"""Get index of runs and assets for a given plugin.
Args:
plugin_name: Name of the plugin we are checking for.
Returns:
A dictionary that maps from run_name to a list of plugin
assets for that run.
"""
with self._accumulators_mutex:
# To avoid nested locks, we construct a copy of the run-accumulator map
items = list(self._accumulators.items())
return {run: accum.PluginAssets(plugin_name) for run, accum in items}
def RetrievePluginAsset(self, run, plugin_name, asset_name):
"""Return the contents for a specific plugin asset from a run.
Args:
run: The string name of the run.
plugin_name: The string name of a plugin.
asset_name: The string name of an asset.
Returns:
The string contents of the plugin asset.
Raises:
KeyError: If the asset is not available.
"""
accumulator = self.GetAccumulator(run)
return accumulator.RetrievePluginAsset(plugin_name, asset_name)
def FirstEventTimestamp(self, run):
"""Return the timestamp of the first event of the given run.
This may perform I/O if no events have been loaded yet for the run.
Args:
run: A string name of the run for which the timestamp is retrieved.
Returns:
The wall_time of the first event of the run, which will typically be
seconds since the epoch.
Raises:
KeyError: If the run is not found.
ValueError: If the run has no events loaded and there are no events on
disk to load.
"""
accumulator = self.GetAccumulator(run)
return accumulator.FirstEventTimestamp()
def GetSourceWriter(self, run) -> Optional[str]:
"""Returns the source writer name from the first event of the given run.
Assuming each run has only one source writer.
Args:
run: A string name of the run from which the event source information
is retrieved.
Returns:
Name of the writer that wrote the events in the run.
"""
accumulator = self.GetAccumulator(run)
return accumulator.GetSourceWriter()
def Graph(self, run):
"""Retrieve the graph associated with the provided run.
Args:
run: A string name of a run to load the graph for.
Raises:
KeyError: If the run is not found.
ValueError: If the run does not have an associated graph.
Returns:
The `GraphDef` protobuf data structure.
"""
accumulator = self.GetAccumulator(run)
return accumulator.Graph()
def SerializedGraph(self, run):
"""Retrieve the serialized graph associated with the provided run.
Args:
run: A string name of a run to load the graph for.
Raises:
KeyError: If the run is not found.
ValueError: If the run does not have an associated graph.
Returns:
The serialized form of the `GraphDef` protobuf data structure.
"""
accumulator = self.GetAccumulator(run)
return accumulator.SerializedGraph()
def MetaGraph(self, run):
"""Retrieve the metagraph associated with the provided run.
Args:
run: A string name of a run to load the graph for.
Raises:
KeyError: If the run is not found.
ValueError: If the run does not have an associated graph.
Returns:
The `MetaGraphDef` protobuf data structure.
"""
accumulator = self.GetAccumulator(run)
return accumulator.MetaGraph()
def RunMetadata(self, run, tag):
"""Get the session.run() metadata associated with a TensorFlow run and
tag.
Args:
run: A string name of a TensorFlow run.
tag: A string name of the tag associated with a particular session.run().
Raises:
KeyError: If the run is not found, or the tag is not available for the
given run.
Returns:
The metadata in the form of `RunMetadata` protobuf data structure.
"""
accumulator = self.GetAccumulator(run)
return accumulator.RunMetadata(tag)
def Tensors(self, run, tag):
"""Retrieve the tensor events associated with a run and tag.
Args:
run: A string name of the run for which values are retrieved.
tag: A string name of the tag for which values are retrieved.
Raises:
KeyError: If the run is not found, or the tag is not available for
the given run.
Returns:
An array of `event_accumulator.TensorEvent`s.
"""
accumulator = self.GetAccumulator(run)
return accumulator.Tensors(tag)
def PluginRunToTagToContent(self, plugin_name):
"""Returns a 2-layer dictionary of the form {run: {tag: content}}.
The `content` referred above is the content field of the PluginData proto
for the specified plugin within a Summary.Value proto.
Args:
plugin_name: The name of the plugin for which to fetch content.
Returns:
A dictionary of the form {run: {tag: content}}.
"""
mapping = {}
for run in self.Runs():
try:
tag_to_content = self.GetAccumulator(run).PluginTagToContent(
plugin_name
)
except KeyError:
# This run lacks content for the plugin. Try the next run.
continue
mapping[run] = tag_to_content
return mapping
def ActivePlugins(self):
"""Return a set of plugins with summary data.
Returns:
The distinct union of `plugin_data.plugin_name` fields from
all the `SummaryMetadata` protos stored in any run known to
this multiplexer.
"""
with self._accumulators_mutex:
accumulators = list(self._accumulators.values())
return frozenset().union(*(a.ActivePlugins() for a in accumulators))
def SummaryMetadata(self, run, tag):
"""Return the summary metadata for the given tag on the given run.
Args:
run: A string name of the run for which summary metadata is to be
retrieved.
tag: A string name of the tag whose summary metadata is to be
retrieved.
Raises:
KeyError: If the run is not found, or the tag is not available for
the given run.
Returns:
A `SummaryMetadata` protobuf.
"""
accumulator = self.GetAccumulator(run)
return accumulator.SummaryMetadata(tag)
def AllSummaryMetadata(self):
"""Return summary metadata for all time series.
Returns:
A nested dict `d` such that `d[run][tag]` is a
`SummaryMetadata` proto for the keyed time series.
"""
with self._accumulators_mutex:
# To avoid nested locks, we construct a copy of the run-accumulator map
items = list(self._accumulators.items())
return {
run_name: accumulator.AllSummaryMetadata()
for run_name, accumulator in items
}
def Runs(self):
"""Return all the run names in the `EventMultiplexer`.
Returns:
```
{runName: { scalarValues: [tagA, tagB, tagC],
graph: true, meta_graph: true}}
```
"""
with self._accumulators_mutex:
# To avoid nested locks, we construct a copy of the run-accumulator map
items = list(self._accumulators.items())
return {run_name: accumulator.Tags() for run_name, accumulator in items}
def RunPaths(self):
"""Returns a dict mapping run names to event file paths."""
return self._paths
def GetAccumulator(self, run):
"""Returns EventAccumulator for a given run.
Args:
run: String name of run.
Returns:
An EventAccumulator object.
Raises:
KeyError: If run does not exist.
"""
with self._accumulators_mutex:
return self._accumulators[run]


@ -0,0 +1,267 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A key-value[] store that implements reservoir sampling on the values."""
import collections
import random
import threading
class Reservoir:
"""A map-to-arrays container, with deterministic Reservoir Sampling.
Items are added with an associated key. Items may be retrieved by key, and
a list of keys can also be retrieved. If size is not zero, then it dictates
the maximum number of items that will be stored with each key. Once there are
more items for a given key, they are replaced via reservoir sampling, such
that each item has an equal probability of being included in the sample.
Deterministic means that for any given seed and bucket size, the sequence of
values that are kept for any given tag will always be the same, and that this
is independent of any insertions on other tags. That is:
>>> separate_reservoir = reservoir.Reservoir(10)
>>> interleaved_reservoir = reservoir.Reservoir(10)
>>> for i in range(100):
>>> separate_reservoir.AddItem('key1', i)
>>> for i in range(100):
>>> separate_reservoir.AddItem('key2', i)
>>> for i in range(100):
>>> interleaved_reservoir.AddItem('key1', i)
>>> interleaved_reservoir.AddItem('key2', i)
separate_reservoir and interleaved_reservoir will be in identical states.
See: https://en.wikipedia.org/wiki/Reservoir_sampling
Adding items has amortized O(1) runtime.
Fields:
always_keep_last: Whether the latest seen sample is always at the
end of the reservoir. Defaults to True.
size: An integer of the maximum number of samples.
"""
def __init__(self, size, seed=0, always_keep_last=True):
"""Creates a new reservoir.
Args:
size: The number of values to keep in the reservoir for each tag. If 0,
all values will be kept.
seed: The seed of the random number generator to use when sampling.
Different values for |seed| will produce different samples from the same
input items.
always_keep_last: Whether to always keep the latest seen item at the
end of the reservoir. Defaults to True.
Raises:
ValueError: If size is negative or not an integer.
"""
if size < 0 or size != round(size):
raise ValueError("size must be nonnegative integer, was %s" % size)
self._buckets = collections.defaultdict(
lambda: _ReservoirBucket(
size, random.Random(seed), always_keep_last
)
)
# _mutex guards the keys - creating new keys, retrieving by key, etc
# the internal items are guarded by the ReservoirBuckets' internal mutexes
self._mutex = threading.Lock()
self.size = size
self.always_keep_last = always_keep_last
def Keys(self):
"""Return all the keys in the reservoir.
Returns:
['list', 'of', 'keys'] in the Reservoir.
"""
with self._mutex:
return list(self._buckets.keys())
def Items(self, key):
"""Return items associated with given key.
Args:
key: The key for which we are finding associated items.
Raises:
KeyError: If the key is not found in the reservoir.
Returns:
[list, of, items] associated with that key.
"""
with self._mutex:
if key not in self._buckets:
raise KeyError("Key %s was not found in Reservoir" % key)
bucket = self._buckets[key]
return bucket.Items()
def AddItem(self, key, item, f=lambda x: x):
"""Add a new item to the Reservoir with the given tag.
If the reservoir has not yet reached full size, the new item is guaranteed
to be added. If the reservoir is full, then behavior depends on the
always_keep_last boolean.
If always_keep_last was set to true, the new item is guaranteed to be added
to the reservoir, and either the previous last item will be replaced, or
(with low probability) an older item will be replaced.
If always_keep_last was set to false, then the new item will replace an
old item with low probability.
If f is provided, it will be applied to transform item (lazily, iff item is
going to be included in the reservoir).
Args:
key: The key to store the item under.
item: The item to add to the reservoir.
f: An optional function to transform the item prior to addition.
"""
with self._mutex:
bucket = self._buckets[key]
bucket.AddItem(item, f)
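# Illustrative usage (a sketch, not part of the original module; the key name
# is hypothetical):
#
#   r = Reservoir(size=2)
#   for i in range(10):
#       r.AddItem("loss", i)
#   r.Keys()                # -> ["loss"]
#   len(r.Items("loss"))    # -> 2 (a deterministic sample of the ten values)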
def FilterItems(self, filterFn, key=None):
"""Filter items within a Reservoir, using a filtering function.
Args:
filterFn: A function that returns True for the items to be kept.
key: An optional bucket key to filter. If not specified, all buckets
are filtered.
Returns:
The number of items removed.
"""
with self._mutex:
if key:
if key in self._buckets:
return self._buckets[key].FilterItems(filterFn)
else:
return 0
else:
return sum(
bucket.FilterItems(filterFn)
for bucket in self._buckets.values()
)
class _ReservoirBucket:
"""A container for items from a stream, that implements reservoir sampling.
It always stores the most recent item as its final item.
"""
def __init__(self, _max_size, _random=None, always_keep_last=True):
"""Create the _ReservoirBucket.
Args:
_max_size: The maximum size the reservoir bucket may grow to. If size is
zero, the bucket has unbounded size.
_random: The random number generator to use. If not specified, defaults to
random.Random(0).
always_keep_last: Whether the latest seen item should always be kept
at the end of the bucket.
Raises:
ValueError: if the size is not a nonnegative integer.
"""
if _max_size < 0 or _max_size != round(_max_size):
raise ValueError(
"_max_size must be nonnegative int, was %s" % _max_size
)
self.items = []
# This mutex protects the internal items, ensuring that calls to Items and
# AddItem are thread-safe
self._mutex = threading.Lock()
self._max_size = _max_size
self._num_items_seen = 0
if _random is not None:
self._random = _random
else:
self._random = random.Random(0)
self.always_keep_last = always_keep_last
def AddItem(self, item, f=lambda x: x):
"""Add an item to the ReservoirBucket, replacing an old item if
necessary.
The new item is guaranteed to be added to the bucket, and to be the last
element in the bucket. If the bucket has reached capacity, then an old item
will be replaced. With probability (_max_size/_num_items_seen) a random item
in the bucket will be popped out and the new item will be appended
to the end. With probability (1 - _max_size/_num_items_seen)
the last item in the bucket will be replaced.
Since the O(n) replacements occur with O(1/_num_items_seen) likelihood,
the amortized runtime is O(1).
Args:
item: The item to add to the bucket.
f: A function to transform item before addition, if it will be kept in
the reservoir.
"""
with self._mutex:
if len(self.items) < self._max_size or self._max_size == 0:
self.items.append(f(item))
else:
r = self._random.randint(0, self._num_items_seen)
if r < self._max_size:
self.items.pop(r)
self.items.append(f(item))
elif self.always_keep_last:
self.items[-1] = f(item)
self._num_items_seen += 1
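# Worked example of the replacement behavior above (hypothetical numbers):
# with _max_size == 10 and _num_items_seen == 99 at call time, randint(0, 99)
# yields r < 10 with probability 10/100, so roughly one in ten new items
# evicts a uniformly chosen old item; otherwise (when always_keep_last is
# True) the new item simply overwrites the last slot.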
def FilterItems(self, filterFn):
"""Filter items in a ReservoirBucket, using a filtering function.
Filtering items from the reservoir bucket must update the
internal state variable self._num_items_seen, which is used for determining
the rate of replacement in reservoir sampling. Ideally, self._num_items_seen
would contain the exact number of items that have ever been seen by the
ReservoirBucket and satisfy filterFn. However, the ReservoirBucket does not
have access to all items seen -- it only has access to the subset of items
that have survived sampling (self.items). Therefore, we estimate
self._num_items_seen by scaling it by the ratio of items that were not
removed from self.items.
Args:
filterFn: A function that returns True for items to be kept.
Returns:
The number of items removed from the bucket.
"""
with self._mutex:
size_before = len(self.items)
self.items = list(filter(filterFn, self.items))
size_diff = size_before - len(self.items)
# Estimate a correction the number of items seen
prop_remaining = (
len(self.items) / float(size_before) if size_before > 0 else 0
)
self._num_items_seen = int(
round(self._num_items_seen * prop_remaining)
)
return size_diff
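# Worked example of the estimate above (hypothetical numbers): if the bucket
# holds 10 items, 4 are filtered out, and _num_items_seen was 200, then
# prop_remaining is 6/10 and _num_items_seen becomes round(200 * 0.6) == 120,
# while FilterItems returns 4.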
def Items(self):
"""Get all the items in the bucket."""
with self._mutex:
return list(self.items)


@ -0,0 +1,29 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""String constants describing contents of an event accumulator."""
# Arbitrary strings chosen to pass the type information of the tag from
# the backend to the frontend.
TENSORS = "tensors"
GRAPH = "graph"
META_GRAPH = "meta_graph"
RUN_METADATA = "run_metadata"
# Legacy (pre-tensor) tag types.
COMPRESSED_HISTOGRAMS = "distributions"
HISTOGRAMS = "histograms"
IMAGES = "images"
AUDIO = "audio"
SCALARS = "scalars"