I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,14 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

View File

@ -0,0 +1,94 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras integration for TensorBoard hparams.
Use `tensorboard.plugins.hparams.api` to access this module's contents.
"""
import tensorflow as tf
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import summary
from tensorboard.plugins.hparams import summary_v2
class Callback(tf.keras.callbacks.Callback):
    """Callback for logging hyperparameters to TensorBoard.

    NOTE: This callback only works in TensorFlow eager mode.
    """

    def __init__(self, writer, hparams, trial_id=None):
        """Create a callback for logging hyperparameters to TensorBoard.

        As with the standard `tf.keras.callbacks.TensorBoard` class, each
        callback object is valid for only one call to `model.fit`.

        Args:
            writer: The `SummaryWriter` object to which hparams should be
                written, or a logdir (as a `str`) to be passed to
                `tf.summary.create_file_writer` to create such a writer.
            hparams: A `dict` mapping hyperparameters to the values used in
                this session. Keys should be the names of `HParam` objects
                used in an experiment, or the `HParam` objects themselves.
                Values should be Python `bool`, `int`, `float`, or `string`
                values, depending on the type of the hyperparameter.
            trial_id: An optional `str` ID for the set of hyperparameter
                values used in this trial. Defaults to a hash of the
                hyperparameters.

        Raises:
            ValueError: If two entries in `hparams` share the same
                hyperparameter name.
        """
        self._hparams = dict(hparams)
        self._trial_id = trial_id
        # The real summary is written when training starts, so that its
        # timestamp is correct. This dry run only validates `hparams` now,
        # so that bad input fails fast at construction time.
        summary_v2.hparams_pb(self._hparams, trial_id=self._trial_id)
        if isinstance(writer, str):
            self._writer = tf.compat.v2.summary.create_file_writer(writer)
        elif writer is not None:
            self._writer = writer
        else:
            raise TypeError(
                "writer must be a `SummaryWriter` or `str`, not None"
            )

    def _get_writer(self):
        """Returns the writer, verifying the callback is still usable."""
        if self._writer is None:
            # `on_train_end` invalidates the writer; each callback object
            # supports exactly one `model.fit` call.
            raise RuntimeError(
                "hparams Keras callback cannot be reused across training sessions"
            )
        if not tf.executing_eagerly():
            raise RuntimeError(
                "hparams Keras callback only supported in TensorFlow eager mode"
            )
        return self._writer

    def on_train_begin(self, logs=None):
        """Writes the hparams summary at the start of training."""
        del logs  # unused
        writer = self._get_writer()
        with writer.as_default():
            summary_v2.hparams(self._hparams, trial_id=self._trial_id)

    def on_train_end(self, logs=None):
        """Writes a session-end marker and invalidates the writer."""
        del logs  # unused
        writer = self._get_writer()
        with writer.as_default():
            session_end = summary.session_end_pb(api_pb2.STATUS_SUCCESS)
            tf.compat.v2.summary.experimental.write_raw_pb(
                session_end.SerializeToString(), step=0
            )
        # Drop the writer so that reusing this callback fails loudly.
        self._writer = None

View File

@ -0,0 +1,132 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Public APIs for the HParams plugin.
This module supports a spectrum of use cases, depending on how much
structure you want. In the simplest case, you can simply collect your
hparams into a dict, and use a Keras callback to record them:
>>> from tensorboard.plugins.hparams import api as hp
>>> hparams = {
... "optimizer": "adam",
... "fc_dropout": 0.2,
... "neurons": 128,
... # ...
... }
>>>
>>> model = model_fn(hparams)
>>> callbacks = [
>>> tf.keras.callbacks.TensorBoard(logdir),
>>> hp.KerasCallback(logdir, hparams),
>>> ]
>>> model.fit(..., callbacks=callbacks)
The Keras callback requires that TensorFlow eager execution be enabled.
If not using Keras, use the `hparams` function to write the values
directly:
>>> # In eager mode:
>>> with tf.summary.create_file_writer(logdir).as_default():
...     hp.hparams(hparams)
>>>
>>> # In legacy graph mode:
>>> with tf.compat.v2.summary.create_file_writer(logdir).as_default() as w:
...     sess.run(w.init())
...     sess.run(hp.hparams(hparams))
...     sess.run(w.flush())
To control how hyperparameters and metrics appear in the TensorBoard UI,
you can define `HParam` and `Metric` objects, and write an experiment
summary to the top-level log directory:
>>> HP_OPTIMIZER = hp.HParam("optimizer")
>>> HP_FC_DROPOUT = hp.HParam(
... "fc_dropout",
... display_name="f.c. dropout",
... description="Dropout rate for fully connected subnet.",
... )
>>> HP_NEURONS = hp.HParam("neurons", description="Neurons per dense layer")
>>>
>>> with tf.summary.create_file_writer(base_logdir).as_default():
... hp.hparams_config(
... hparams=[
... HP_OPTIMIZER,
... HP_FC_DROPOUT,
... HP_NEURONS,
... ],
... metrics=[
... hp.Metric("xent", group="validation", display_name="cross-entropy"),
... hp.Metric("f1", group="validation", display_name="F₁ score"),
... hp.Metric("loss", group="train", display_name="training loss"),
... ],
... )
You can continue to pass a string-keyed dict to the Keras callback or
the `hparams` function, or you can use `HParam` objects as the keys. The
latter approach enables better static analysis: your favorite Python
linter can tell you if you misspell a hyperparameter name, your IDE can
help you find all the places where a hyperparameter is used, etc:
>>> hparams = {
... HP_OPTIMIZER: "adam",
... HP_FC_DROPOUT: 0.2,
... HP_NEURONS: 128,
... # ...
... }
>>>
>>> model = model_fn(hparams)
>>> callbacks = [
>>> tf.keras.callbacks.TensorBoard(logdir),
>>> hp.KerasCallback(logdir, hparams),
>>> ]
Finally, you can choose to annotate your hparam definitions with domain
information:
>>> HP_OPTIMIZER = hp.HParam("optimizer", hp.Discrete(["adam", "sgd"]))
>>> HP_FC_DROPOUT = hp.HParam("fc_dropout", hp.RealInterval(0.1, 0.4))
>>> HP_NEURONS = hp.HParam("neurons", hp.IntInterval(64, 256))
The TensorBoard HParams plugin does not provide tuners, but you can
integrate these domains into your preferred tuning framework if you so
desire. The domains will also be reflected in the TensorBoard UI.
See the `Experiment`, `HParam`, `Metric`, and `KerasCallback` classes
for API specifications. Consult the `hparams_demo.py` script in the
TensorBoard repository for an end-to-end MNIST example.
"""
from tensorboard.plugins.hparams import _keras
from tensorboard.plugins.hparams import summary_v2
# Re-export the public names from `summary_v2` and `_keras` so that users
# can access everything through this single `api` module.
Discrete = summary_v2.Discrete
Domain = summary_v2.Domain
HParam = summary_v2.HParam
IntInterval = summary_v2.IntInterval
Metric = summary_v2.Metric
RealInterval = summary_v2.RealInterval
hparams = summary_v2.hparams
hparams_pb = summary_v2.hparams_pb
hparams_config = summary_v2.hparams_config
hparams_config_pb = summary_v2.hparams_config_pb
KerasCallback = _keras.Callback
# Remove the module aliases from this namespace so that only the public
# API names above are visible to `import *` and tab completion.
del _keras
del summary_v2

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,607 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wraps the base_plugin.TBContext to stores additional data shared across API
handlers for the HParams plugin backend."""
import collections
import os
from tensorboard.data import provider
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import json_format_compat
from tensorboard.plugins.hparams import metadata
from google.protobuf import json_format
from tensorboard.plugins.scalar import metadata as scalar_metadata
# Maps discrete hyperparameter domain types reported by the data provider
# to the corresponding HParams plugin `api_pb2` data types.
_DISCRETE_DOMAIN_TYPE_TO_DATA_TYPE = {
    provider.HyperparameterDomainType.DISCRETE_BOOL: api_pb2.DATA_TYPE_BOOL,
    provider.HyperparameterDomainType.DISCRETE_FLOAT: api_pb2.DATA_TYPE_FLOAT64,
    provider.HyperparameterDomainType.DISCRETE_STRING: api_pb2.DATA_TYPE_STRING,
}
class Context:
    """Wraps the base_plugin.TBContext to store additional data shared across
    API handlers for the HParams plugin backend.
    Before adding fields to this class, carefully consider whether the
    field truly needs to be accessible to all API handlers or if it
    can be passed separately to the handler constructor. We want to
    avoid this class becoming a magic container of variables that have
    no better place. See http://wiki.c2.com/?MagicContainer
    """
    def __init__(self, tb_context):
        """Instantiates a context.
        Args:
            tb_context: base_plugin.TBContext. The "base" context we extend.
        """
        self._tb_context = tb_context
    def experiment_from_metadata(
        self,
        ctx,
        experiment_id,
        include_metrics,
        hparams_run_to_tag_to_content,
        data_provider_hparams,
        hparams_limit=None,
    ):
        """Returns the experiment proto defining the experiment.
        This method first attempts to find a metadata.EXPERIMENT_TAG tag and
        retrieve the associated proto.
        If no such tag is found, the method will attempt to build a minimal
        experiment proto by scanning for all metadata.SESSION_START_INFO_TAG
        tags (to compute the hparam_infos field of the experiment) and for all
        scalar tags (to compute the metric_infos field of the experiment).
        If no metadata.EXPERIMENT_TAG nor metadata.SESSION_START_INFO_TAG tags
        are found, then will build an experiment proto using the results from
        DataProvider.list_hyperparameters().
        Args:
            ctx: Request context, forwarded to data provider calls.
            experiment_id: String, from `plugin_util.experiment_id`.
            include_metrics: Whether to determine metrics_infos and include them
                in the result.
            hparams_run_to_tag_to_content: The output from an hparams_metadata()
                call. A dict `d` such that `d[run][tag]` is a `bytes` value with
                the summary metadata content for the keyed time series.
            data_provider_hparams: The output from an hparams_from_data_provider()
                call, corresponding to DataProvider.list_hyperparameters().
                A provider.ListHyperparametersResult.
            hparams_limit: Optional number of hyperparameter metadata to include
                in the result. If unset or zero, all metadata will be included.
        Returns:
            The experiment proto. If no data is found for an experiment proto to
            be built, returns an entirely empty experiment.
        """
        # Precedence 1: an Experiment proto logged explicitly under
        # metadata.EXPERIMENT_TAG.
        experiment = self._find_experiment_tag(
            hparams_run_to_tag_to_content, include_metrics
        )
        if experiment:
            _sort_and_reduce_to_hparams_limit(experiment, hparams_limit)
            return experiment
        # Precedence 2: a minimal experiment computed by scanning the runs'
        # session-start summaries and scalar tags.
        experiment_from_runs = self._compute_experiment_from_runs(
            ctx, experiment_id, include_metrics, hparams_run_to_tag_to_content
        )
        if experiment_from_runs:
            _sort_and_reduce_to_hparams_limit(
                experiment_from_runs, hparams_limit
            )
            return experiment_from_runs
        # Precedence 3: hyperparameters reported by the data provider.
        # NOTE(review): `hparams_limit` is not applied on this path --
        # presumably the data provider already applied the limit (see
        # hparams_from_data_provider(), which forwards `limit`); confirm.
        experiment_from_data_provider_hparams = (
            self._experiment_from_data_provider_hparams(
                ctx, experiment_id, include_metrics, data_provider_hparams
            )
        )
        return (
            experiment_from_data_provider_hparams
            if experiment_from_data_provider_hparams
            else api_pb2.Experiment()
        )
    @property
    def tb_context(self):
        """The wrapped base_plugin.TBContext."""
        return self._tb_context
    def _convert_plugin_metadata(self, data_provider_output):
        """Maps run -> tag -> plugin content bytes from data provider output."""
        return {
            run: {
                tag: time_series.plugin_content
                for (tag, time_series) in tag_to_time_series.items()
            }
            for (run, tag_to_time_series) in data_provider_output.items()
        }
    def hparams_metadata(self, ctx, experiment_id, run_tag_filter=None):
        """Reads summary metadata for all hparams time series.
        Args:
            ctx: Request context, forwarded to the data provider.
            experiment_id: String, from `plugin_util.experiment_id`.
            run_tag_filter: Optional `data.provider.RunTagFilter`, with
                the semantics as in `list_tensors`.
        Returns:
            A dict `d` such that `d[run][tag]` is a `bytes` value with the
            summary metadata content for the keyed time series.
        """
        return self._convert_plugin_metadata(
            self._tb_context.data_provider.list_tensors(
                ctx,
                experiment_id=experiment_id,
                plugin_name=metadata.PLUGIN_NAME,
                run_tag_filter=run_tag_filter,
            )
        )
    def scalars_metadata(self, ctx, experiment_id):
        """Reads summary metadata for all scalar time series.
        Args:
            ctx: Request context, forwarded to the data provider.
            experiment_id: String, from `plugin_util.experiment_id`.
        Returns:
            A dict `d` such that `d[run][tag]` is a `bytes` value with the
            summary metadata content for the keyed time series.
        """
        return self._convert_plugin_metadata(
            self._tb_context.data_provider.list_scalars(
                ctx,
                experiment_id=experiment_id,
                plugin_name=scalar_metadata.PLUGIN_NAME,
            )
        )
    def read_last_scalars(self, ctx, experiment_id, run_tag_filter):
        """Reads the most recent values from scalar time series.
        Args:
            ctx: Request context, forwarded to the data provider.
            experiment_id: String.
            run_tag_filter: Required `data.provider.RunTagFilter`, with
                the semantics as in `read_last_scalars`.
        Returns:
            A dict `d` such that `d[run][tag]` is a `provider.ScalarDatum`
            value, with keys only for runs and tags that actually had
            data, which may be a subset of what was requested.
        """
        return self._tb_context.data_provider.read_last_scalars(
            ctx,
            experiment_id=experiment_id,
            plugin_name=scalar_metadata.PLUGIN_NAME,
            run_tag_filter=run_tag_filter,
        )
    def hparams_from_data_provider(self, ctx, experiment_id, limit):
        """Calls DataProvider.list_hyperparameters() and returns the result."""
        return self._tb_context.data_provider.list_hyperparameters(
            ctx, experiment_ids=[experiment_id], limit=limit
        )
    def session_groups_from_data_provider(
        self, ctx, experiment_id, filters, sort, hparams_to_include
    ):
        """Calls DataProvider.read_hyperparameters() and returns the result."""
        return self._tb_context.data_provider.read_hyperparameters(
            ctx,
            experiment_ids=[experiment_id],
            filters=filters,
            sort=sort,
            hparams_to_include=hparams_to_include,
        )
    def _find_experiment_tag(
        self, hparams_run_to_tag_to_content, include_metrics
    ):
        """Finds the experiment associated with the metadata.EXPERIMENT_TAG
        tag.
        Returns:
            The experiment or None if no such experiment is found.
        """
        # We expect only one run to have an `EXPERIMENT_TAG`; look
        # through all of them and arbitrarily pick the first one.
        for tags in hparams_run_to_tag_to_content.values():
            maybe_content = tags.get(metadata.EXPERIMENT_TAG)
            if maybe_content is not None:
                experiment = metadata.parse_experiment_plugin_data(
                    maybe_content
                )
                if not include_metrics:
                    # metric_infos haven't technically been "calculated" in this
                    # case. They have been read directly from the Experiment
                    # proto.
                    # Delete them from the result so that they are not returned
                    # to the client.
                    experiment.ClearField("metric_infos")
                return experiment
        return None
    def _compute_experiment_from_runs(
        self, ctx, experiment_id, include_metrics, hparams_run_to_tag_to_content
    ):
        """Computes a minimal Experiment protocol buffer by scanning the runs.
        Returns None if there are no hparam infos logged.
        """
        hparam_infos = self._compute_hparam_infos(hparams_run_to_tag_to_content)
        # Metric infos are only computed when there are hparams at all and
        # the caller asked for metrics.
        metric_infos = (
            self._compute_metric_infos_from_runs(
                ctx, experiment_id, hparams_run_to_tag_to_content
            )
            if hparam_infos and include_metrics
            else []
        )
        if not hparam_infos and not metric_infos:
            return None
        return api_pb2.Experiment(
            hparam_infos=hparam_infos, metric_infos=metric_infos
        )
    def _compute_hparam_infos(self, hparams_run_to_tag_to_content):
        """Computes a list of api_pb2.HParamInfo from the current run, tag
        info.
        Finds all the SessionStartInfo messages and collects the hparams values
        appearing in each one. For each hparam attempts to deduce a type that fits
        all its values. Finally, sets the 'domain' of the resulting HParamInfo
        to be discrete if the type is string or boolean.
        Returns:
            A list of api_pb2.HParamInfo messages.
        """
        # Construct a dict mapping an hparam name to its list of values.
        hparams = collections.defaultdict(list)
        for tag_to_content in hparams_run_to_tag_to_content.values():
            if metadata.SESSION_START_INFO_TAG not in tag_to_content:
                continue
            start_info = metadata.parse_session_start_info_plugin_data(
                tag_to_content[metadata.SESSION_START_INFO_TAG]
            )
            for name, value in start_info.hparams.items():
                hparams[name].append(value)
        # Try to construct an HParamInfo for each hparam from its name and list
        # of values.
        result = []
        for name, values in hparams.items():
            hparam_info = self._compute_hparam_info_from_values(name, values)
            if hparam_info is not None:
                result.append(hparam_info)
        return result
    def _compute_hparam_info_from_values(self, name, values):
        """Builds an HParamInfo message from the hparam name and list of
        values.
        Args:
            name: string. The hparam name.
            values: list of google.protobuf.Value messages. The list of values
                for the hparam.
        Returns:
            An api_pb2.HParamInfo message, or None if no type could be deduced
            from the values.
        """
        # Figure out the type from the values.
        # Ignore values whose type is not listed in api_pb2.DataType
        # If all values have the same type, then that is the type used.
        # Otherwise, the returned type is DATA_TYPE_STRING.
        result = api_pb2.HParamInfo(name=name, type=api_pb2.DATA_TYPE_UNSET)
        for v in values:
            v_type = _protobuf_value_type(v)
            if not v_type:
                continue
            if result.type == api_pb2.DATA_TYPE_UNSET:
                result.type = v_type
            elif result.type != v_type:
                result.type = api_pb2.DATA_TYPE_STRING
            if result.type == api_pb2.DATA_TYPE_STRING:
                # A string result.type does not change, so we can exit the loop.
                break
        # If we couldn't figure out a type, then we can't compute the hparam_info.
        if result.type == api_pb2.DATA_TYPE_UNSET:
            return None
        if result.type == api_pb2.DATA_TYPE_STRING:
            distinct_string_values = set(
                _protobuf_value_to_string(v)
                for v in values
                if _can_be_converted_to_string(v)
            )
            result.domain_discrete.extend(distinct_string_values)
            # `differs` records whether more than one distinct value was seen.
            result.differs = len(distinct_string_values) > 1
        if result.type == api_pb2.DATA_TYPE_BOOL:
            distinct_bool_values = set(v.bool_value for v in values)
            result.domain_discrete.extend(distinct_bool_values)
            result.differs = len(distinct_bool_values) > 1
        if result.type == api_pb2.DATA_TYPE_FLOAT64:
            # Always uses interval domain type for numeric hparam values.
            distinct_float_values = sorted([v.number_value for v in values])
            if distinct_float_values:
                result.domain_interval.min_value = distinct_float_values[0]
                result.domain_interval.max_value = distinct_float_values[-1]
            result.differs = len(set(distinct_float_values)) > 1
        return result
    def _experiment_from_data_provider_hparams(
        self,
        ctx,
        experiment_id,
        include_metrics,
        data_provider_hparams,
    ):
        """Returns an experiment protobuffer based on data provider hparams.
        Args:
            ctx: Request context, forwarded to data provider calls.
            experiment_id: String, from `plugin_util.experiment_id`.
            include_metrics: Whether to determine metrics_infos and include them
                in the result.
            data_provider_hparams: The output from an hparams_from_data_provider()
                call, corresponding to DataProvider.list_hyperparameters().
                A provider.ListHyperparametersResult.
        Returns:
            The experiment proto. If there are no hyperparameters in the input,
            returns None.
        """
        # NOTE(review): despite the docstring, this implementation always
        # returns an Experiment (possibly with empty hparam_infos); it never
        # returns None -- confirm which behavior is intended.
        if isinstance(data_provider_hparams, list):
            # TODO: Support old return value of Collection[provider.Hyperparameters]
            # until all internal implementations of DataProvider can be
            # migrated to use new return value of provider.ListHyperparametersResult.
            hyperparameters = data_provider_hparams
            session_groups = []
        else:
            # Is instance of provider.ListHyperparametersResult
            hyperparameters = data_provider_hparams.hyperparameters
            session_groups = data_provider_hparams.session_groups
        hparam_infos = [
            self._convert_data_provider_hparam(dp_hparam)
            for dp_hparam in hyperparameters
        ]
        metric_infos = (
            self.compute_metric_infos_from_data_provider_session_groups(
                ctx, experiment_id, session_groups
            )
            if include_metrics
            else []
        )
        return api_pb2.Experiment(
            hparam_infos=hparam_infos, metric_infos=metric_infos
        )
    def _convert_data_provider_hparam(self, dp_hparam):
        """Builds an HParamInfo message from data provider Hyperparameter.
        Args:
            dp_hparam: The provider.Hyperparameter returned by the call to
                provider.DataProvider.list_hyperparameters().
        Returns:
            An HParamInfo to include in the Experiment.
        """
        hparam_info = api_pb2.HParamInfo(
            name=dp_hparam.hyperparameter_name,
            display_name=dp_hparam.hyperparameter_display_name,
            differs=dp_hparam.differs,
        )
        if dp_hparam.domain_type == provider.HyperparameterDomainType.INTERVAL:
            # Interval domains map to FLOAT64 with (min, max) bounds.
            hparam_info.type = api_pb2.DATA_TYPE_FLOAT64
            (dp_hparam_min, dp_hparam_max) = dp_hparam.domain
            hparam_info.domain_interval.min_value = dp_hparam_min
            hparam_info.domain_interval.max_value = dp_hparam_max
        elif dp_hparam.domain_type in _DISCRETE_DOMAIN_TYPE_TO_DATA_TYPE.keys():
            # Discrete domains map through the module-level lookup table.
            hparam_info.type = _DISCRETE_DOMAIN_TYPE_TO_DATA_TYPE.get(
                dp_hparam.domain_type
            )
            hparam_info.domain_discrete.extend(dp_hparam.domain)
        return hparam_info
    def _compute_metric_infos_from_runs(
        self, ctx, experiment_id, hparams_run_to_tag_to_content
    ):
        """Computes MetricInfos from runs that logged a session start."""
        # Session runs are those with a SESSION_START_INFO_TAG summary.
        session_runs = set(
            run
            for run, tags in hparams_run_to_tag_to_content.items()
            if metadata.SESSION_START_INFO_TAG in tags
        )
        # NOTE(review): this returns a generator expression (lazy), unlike
        # compute_metric_infos_from_data_provider_session_groups below, which
        # returns a list. Callers should iterate the result only once.
        return (
            api_pb2.MetricInfo(name=api_pb2.MetricName(group=group, tag=tag))
            for tag, group in self._compute_metric_names(
                ctx, experiment_id, session_runs
            )
        )
    def compute_metric_infos_from_data_provider_session_groups(
        self, ctx, experiment_id, session_groups
    ):
        """Computes MetricInfos from data provider session groups."""
        session_runs = set(
            generate_data_provider_session_name(s)
            for sg in session_groups
            for s in sg.sessions
        )
        return [
            api_pb2.MetricInfo(name=api_pb2.MetricName(group=group, tag=tag))
            for tag, group in self._compute_metric_names(
                ctx, experiment_id, session_runs
            )
        ]
    def _compute_metric_names(self, ctx, experiment_id, session_runs):
        """Computes the list of metric names from all the scalar (run, tag)
        pairs.
        The return value is a list of (tag, group) pairs representing the metric
        names. The list is sorted in Python tuple-order (lexicographical).
        For example, if the scalar (run, tag) pairs are:
            ("exp/session1", "loss")
            ("exp/session2", "loss")
            ("exp/session2/eval", "loss")
            ("exp/session2/validation", "accuracy")
            ("exp/no-session", "loss_2"),
        and the runs corresponding to sessions are "exp/session1", "exp/session2",
        this method will return [("accuracy", "validation"), ("loss", ""),
        ("loss", "eval")] (os.path.relpath does not include a leading separator
        in the group).
        More precisely, each scalar (run, tag) pair is converted to a (tag, group)
        metric name, where group is the suffix of run formed by removing the
        longest prefix which is a session run. If no session run is a prefix of
        'run', the pair is skipped.
        Returns:
            A python list containing pairs. Each pair is a (tag, group) pair
            representing a metric name used in some session.
        """
        metric_names_set = set()
        scalars_run_to_tag_to_content = self.scalars_metadata(
            ctx, experiment_id
        )
        for run, tags in scalars_run_to_tag_to_content.items():
            session = _find_longest_parent_path(session_runs, run)
            if session is None:
                # Scalars that do not live under any session run are skipped.
                continue
            group = os.path.relpath(run, session)
            # relpath() returns "." for the 'session' directory, we use an empty
            # string, unless the run name actually ends with ".".
            if group == "." and not run.endswith("."):
                group = ""
            metric_names_set.update((tag, group) for tag in tags)
        metric_names_list = list(metric_names_set)
        # Sort metrics for determinism.
        metric_names_list.sort()
        return metric_names_list
def generate_data_provider_session_name(session):
    """Generates a name from a HyperparameterSessionRun.

    The name is "<experiment_id>/<run>" when both parts are present,
    whichever single part is non-empty when the other is missing, and the
    empty string when the session carries neither experiment nor run
    information.
    """
    parts = [part for part in (session.experiment_id, session.run) if part]
    return "/".join(parts)
def _find_longest_parent_path(path_set, path):
"""Finds the longest "parent-path" of 'path' in 'path_set'.
This function takes and returns "path-like" strings which are strings
made of strings separated by os.sep. No file access is performed here, so
these strings need not correspond to actual files in some file-system..
This function returns the longest ancestor path
For example, for path_set=["/foo/bar", "/foo", "/bar/foo"] and
path="/foo/bar/sub_dir", returns "/foo/bar".
Args:
path_set: set of path-like strings -- e.g. a list of strings separated by
os.sep. No actual disk-access is performed here, so these need not
correspond to actual files.
path: a path-like string.
Returns:
The element in path_set which is the longest parent directory of 'path'.
"""
# This could likely be more efficiently implemented with a trie
# data-structure, but we don't want to add an extra dependency for that.
while path not in path_set:
if not path:
return None
path = os.path.dirname(path)
return path
def _can_be_converted_to_string(value):
    """Returns True if `value` has a supported type and can be serialized."""
    # A value with no recognized protobuf type cannot be stringified at all;
    # otherwise defer to the JSON-serializability check.
    return bool(_protobuf_value_type(value)) and (
        json_format_compat.is_serializable_value(value)
    )
def _protobuf_value_type(value):
    """Returns the type of the google.protobuf.Value message as an
    api.DataType.

    Returns None if the type of 'value' is not one of the types supported in
    api_pb2.DataType.

    Args:
        value: google.protobuf.Value message.
    """
    # Check each supported oneof field in turn and map it to the plugin's
    # corresponding data type.
    field_to_type = (
        ("number_value", api_pb2.DATA_TYPE_FLOAT64),
        ("string_value", api_pb2.DATA_TYPE_STRING),
        ("bool_value", api_pb2.DATA_TYPE_BOOL),
    )
    for field_name, data_type in field_to_type:
        if value.HasField(field_name):
            return data_type
    return None
def _protobuf_value_to_string(value):
    """Returns a string representation of given google.protobuf.Value message.

    Args:
        value: google.protobuf.Value message. Assumed to be of type 'number',
            'string' or 'bool'.
    """
    json_repr = json_format.MessageToJson(value)
    if not value.HasField("string_value"):
        return json_repr
    # For strings, strip the surrounding JSON quotation marks.
    return json_repr[1:-1]
def _sort_and_reduce_to_hparams_limit(experiment, hparams_limit=None):
"""Sorts and applies limit to the hparams in the given experiment proto.
Args:
experiment: An api_pb2.Experiment proto, which will be modified in place.
hparams_limit: Optional number of hyperparameter metadata to include in the
result. If unset or zero, no limit will be applied.
Returns:
None. `experiment` proto will be modified in place.
"""
if not hparams_limit:
# If limit is unset or zero, returns all hparams.
hparams_limit = len(experiment.hparam_infos)
# Prioritizes returning HParamInfo protos with `differed` values.
# Sorts by `differs` (True first), then by name.
limited_hparam_infos = sorted(
experiment.hparam_infos,
key=lambda hparam_info: (not hparam_info.differs, hparam_info.name),
)[:hparams_limit]
experiment.ClearField("hparam_infos")
experiment.hparam_infos.extend(limited_hparam_infos)

View File

@ -0,0 +1,159 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions for handling the DownloadData API call."""
import csv
import io
import math
from tensorboard.plugins.hparams import error
class OutputFormat:
    """An enum used to list the valid output formats for API calls."""
    # Values are the wire-format strings clients send in DownloadData
    # requests; they double as the selector in Handler.run().
    JSON = "json"
    CSV = "csv"
    LATEX = "latex"
class Handler:
    """Handles a DownloadData request.

    Renders the given experiment and session groups as JSON, CSV, or LaTeX,
    keeping only the columns the client marked as visible.
    """

    def __init__(
        self,
        context,
        experiment,
        session_groups,
        response_format,
        columns_visibility,
    ):
        """Constructor.

        Args:
            context: A backend_context.Context instance.
            experiment: Experiment proto.
            session_groups: ListSessionGroupsResponse proto.
            response_format: A string in the OutputFormat enum.
            columns_visibility: A list of boolean values to filter columns.
        """
        self._context = context
        self._experiment = experiment
        self._session_groups = session_groups
        self._response_format = response_format
        self._columns_visibility = columns_visibility

    def run(self):
        """Handles the request specified on construction.

        Returns:
            A response body.
            A mime type (string) for the response.

        Raises:
            error.HParamsError: If `response_format` is not one of the
                OutputFormat values.
        """
        experiment = self._experiment
        session_groups = self._session_groups
        response_format = self._response_format
        visibility = self._columns_visibility

        # Column order: hparams first (in experiment order), then metrics.
        header = []
        for hparam_info in experiment.hparam_infos:
            header.append(hparam_info.display_name or hparam_info.name)
        for metric_info in experiment.metric_infos:
            header.append(metric_info.display_name or metric_info.name.tag)

        def _filter_columns(row):
            # Keep only the columns marked visible by the client; extra
            # columns beyond len(visibility) are dropped by zip().
            return [value for value, visible in zip(row, visibility) if visible]

        header = _filter_columns(header)
        rows = []

        def _get_value(value):
            # Unwrap a google.protobuf.Value into a plain Python value.
            if value.HasField("number_value"):
                return value.number_value
            if value.HasField("string_value"):
                return value.string_value
            if value.HasField("bool_value"):
                return value.bool_value
            # hyperparameter values can be optional in a session group
            return ""

        def _get_metric_id(metric):
            return metric.group + "." + metric.tag

        for group in session_groups.session_groups:
            row = []
            for hparam_info in experiment.hparam_infos:
                row.append(_get_value(group.hparams[hparam_info.name]))
            metric_values = {}
            for metric_value in group.metric_values:
                metric_id = _get_metric_id(metric_value.name)
                metric_values[metric_id] = metric_value.value
            for metric_info in experiment.metric_infos:
                metric_id = _get_metric_id(metric_info.name)
                # Metrics without a recorded value become None, rendered as
                # null (JSON), empty (CSV), or "-" (LaTeX).
                row.append(metric_values.get(metric_id))
            rows.append(_filter_columns(row))

        if response_format == OutputFormat.JSON:
            mime_type = "application/json"
            body = dict(header=header, rows=rows)
        elif response_format == OutputFormat.LATEX:

            def latex_format(value):
                if value is None:
                    return "-"
                elif isinstance(value, int):
                    return "$%d$" % value
                elif isinstance(value, float):
                    if math.isnan(value):
                        return r"$\mathrm{NaN}$"
                    if value in (float("inf"), float("-inf")):
                        return r"$%s\infty$" % ("-" if value < 0 else "+")
                    scientific = "%.3g" % value
                    if "e" in scientific:
                        coefficient, exponent = scientific.split("e")
                        return "$%s\\cdot 10^{%d}$" % (
                            coefficient,
                            int(exponent),
                        )
                    return "$%s$" % scientific
                # Strings: escape characters that are special in LaTeX.
                return value.replace("_", "\\_").replace("%", "\\%")

            mime_type = "application/x-latex"
            top_part = "\\begin{table}[tbp]\n\\begin{tabular}{%s}\n" % (
                "l" * len(header)
            )
            header_part = (
                " & ".join(map(latex_format, header)) + " \\\\ \\hline\n"
            )
            middle_part = "".join(
                " & ".join(map(latex_format, row)) + " \\\\\n" for row in rows
            )
            bottom_part = "\\hline\n\\end{tabular}\n\\end{table}\n"
            body = top_part + header_part + middle_part + bottom_part
        elif response_format == OutputFormat.CSV:
            string_io = io.StringIO()
            writer = csv.writer(string_io)
            writer.writerow(header)
            writer.writerows(rows)
            body = string_io.getvalue()
            mime_type = "text/csv"
        else:
            # Fixed typo in the user-facing message ("reponses" -> "response").
            raise error.HParamsError(
                "Invalid response format: %s" % response_format
            )
        return body, mime_type

View File

@ -0,0 +1,26 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines an error (exception) class for the HParams plugin."""
class HParamsError(Exception):
    """An error whose message is meant to be shown to the end user.

    Raise this for failures a user can understand and act on (for
    example, a malformed request). Violations of internal invariants
    should be signaled with other exception types instead.
    """

View File

@ -0,0 +1,71 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions for handling the GetExperiment API call."""
from tensorboard.plugins.hparams import api_pb2
class Handler:
    """Handles a GetExperiment request."""

    def __init__(
        self, request_context, backend_context, experiment_id, request
    ):
        """Constructor.

        Args:
            request_context: A tensorboard.context.RequestContext.
            backend_context: A backend_context.Context instance.
            experiment_id: A string, as from `plugin_util.experiment_id`.
            request: A request proto.
        """
        self._request_context = request_context
        self._backend_context = backend_context
        self._experiment_id = experiment_id
        # Metrics default to being included when the request leaves the
        # `include_metrics` field unset.
        if request.HasField("include_metrics"):
            self._include_metrics = request.include_metrics
        else:
            self._include_metrics = True
        # Only a GetExperimentRequest carries an hparams limit; any other
        # request type imposes none.
        if isinstance(request, api_pb2.GetExperimentRequest):
            self._hparams_limit = request.hparams_limit
        else:
            self._hparams_limit = None

    def run(self):
        """Handles the request specified on construction.

        Returns:
            An Experiment object.
        """
        ctx = self._request_context
        eid = self._experiment_id
        backend = self._backend_context
        data_provider_hparams = backend.hparams_from_data_provider(
            ctx,
            eid,
            limit=self._hparams_limit,
        )
        hparams_run_metadata = backend.hparams_metadata(ctx, eid)
        return backend.experiment_from_metadata(
            ctx,
            eid,
            self._include_metrics,
            hparams_run_metadata,
            data_provider_hparams,
            self._hparams_limit,
        )

View File

@ -0,0 +1,207 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard HParams plugin.
See `http_api.md` in this directory for specifications of the routes for
this plugin.
"""
import json
import werkzeug
from werkzeug import wrappers
from tensorboard import plugin_util
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import backend_context
from tensorboard.plugins.hparams import download_data
from tensorboard.plugins.hparams import error
from tensorboard.plugins.hparams import get_experiment
from tensorboard.plugins.hparams import list_metric_evals
from tensorboard.plugins.hparams import list_session_groups
from tensorboard.plugins.hparams import metadata
from google.protobuf import json_format
from tensorboard.backend import http_util
from tensorboard.plugins import base_plugin
from tensorboard.plugins.scalar import metadata as scalars_metadata
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
class HParamsPlugin(base_plugin.TBPlugin):
    """HParams Plugin for TensorBoard.

    It supports both GETs and POSTs. See 'http_api.md' for more details.
    """

    plugin_name = metadata.PLUGIN_NAME

    def __init__(self, context):
        """Instantiates HParams plugin via TensorBoard core.

        Args:
            context: A base_plugin.TBContext instance.
        """
        self._context = backend_context.Context(context)

    def get_plugin_apps(self):
        """See base class."""
        # Route table mapping plugin-relative URL paths to WSGI handlers.
        return {
            "/download_data": self.download_data_route,
            "/experiment": self.get_experiment_route,
            "/session_groups": self.list_session_groups_route,
            "/metric_evals": self.list_metric_evals_route,
        }

    def is_active(self):
        return False  # `list_plugins` as called by TB core suffices

    def frontend_metadata(self):
        return base_plugin.FrontendMetadata(element_name="tf-hparams-dashboard")

    # ---- /download_data- -------------------------------------------------------
    @wrappers.Request.application
    def download_data_route(self, request):
        """Serves the session-groups table in the requested download format.

        Reads the output format from the `format` query parameter and the
        column selection from the JSON-encoded `columnsVisibility`
        parameter. HParams errors surface as HTTP 400 responses.
        """
        ctx = plugin_util.context(request.environ)
        experiment_id = plugin_util.experiment_id(request.environ)
        try:
            response_format = request.args.get("format")
            columns_visibility = json.loads(
                request.args.get("columnsVisibility")
            )
            request_proto = _parse_request_argument(
                request, api_pb2.ListSessionGroupsRequest
            )
            session_groups = list_session_groups.Handler(
                ctx, self._context, experiment_id, request_proto
            ).run()
            experiment = get_experiment.Handler(
                ctx, self._context, experiment_id, request_proto
            ).run()
            body, mime_type = download_data.Handler(
                self._context,
                experiment,
                session_groups,
                response_format,
                columns_visibility,
            ).run()
            return http_util.Respond(request, body, mime_type)
        except error.HParamsError as e:
            logger.error("HParams error: %s" % e)
            raise werkzeug.exceptions.BadRequest(description=str(e))

    # ---- /experiment -----------------------------------------------------------
    @wrappers.Request.application
    def get_experiment_route(self, request):
        """Serves the experiment definition as JSON."""
        ctx = plugin_util.context(request.environ)
        experiment_id = plugin_util.experiment_id(request.environ)
        try:
            request_proto = _parse_request_argument(
                request, api_pb2.GetExperimentRequest
            )
            response_proto = get_experiment.Handler(
                ctx,
                self._context,
                experiment_id,
                request_proto,
            ).run()
            response = plugin_util.proto_to_json(response_proto)
            return http_util.Respond(
                request,
                response,
                "application/json",
            )
        except error.HParamsError as e:
            logger.error("HParams error: %s" % e)
            raise werkzeug.exceptions.BadRequest(description=str(e))

    # ---- /session_groups -------------------------------------------------------
    @wrappers.Request.application
    def list_session_groups_route(self, request):
        """Serves the list of session groups as JSON."""
        ctx = plugin_util.context(request.environ)
        experiment_id = plugin_util.experiment_id(request.environ)
        try:
            request_proto = _parse_request_argument(
                request, api_pb2.ListSessionGroupsRequest
            )
            response_proto = list_session_groups.Handler(
                ctx,
                self._context,
                experiment_id,
                request_proto,
            ).run()
            response = plugin_util.proto_to_json(response_proto)
            return http_util.Respond(
                request,
                response,
                "application/json",
            )
        except error.HParamsError as e:
            logger.error("HParams error: %s" % e)
            raise werkzeug.exceptions.BadRequest(description=str(e))

    # ---- /metric_evals ---------------------------------------------------------
    @wrappers.Request.application
    def list_metric_evals_route(self, request):
        """Serves the evaluations of one metric in one session as JSON.

        Delegates the scalar reads to the scalars plugin; responds 404 if
        that plugin is not loaded.
        """
        ctx = plugin_util.context(request.environ)
        experiment_id = plugin_util.experiment_id(request.environ)
        try:
            request_proto = _parse_request_argument(
                request, api_pb2.ListMetricEvalsRequest
            )
            scalars_plugin = self._get_scalars_plugin()
            if not scalars_plugin:
                raise werkzeug.exceptions.NotFound("Scalars plugin not loaded")
            return http_util.Respond(
                request,
                list_metric_evals.Handler(
                    ctx, request_proto, scalars_plugin, experiment_id
                ).run(),
                "application/json",
            )
        except error.HParamsError as e:
            logger.error("HParams error: %s" % e)
            raise werkzeug.exceptions.BadRequest(description=str(e))

    def _get_scalars_plugin(self):
        """Tries to get the scalars plugin.

        Returns:
            The scalars plugin or None if it is not yet registered.
        """
        return self._context.tb_context.plugin_name_to_instance.get(
            scalars_metadata.PLUGIN_NAME
        )
def _parse_request_argument(request, proto_class):
    """Parses the request proto from an incoming HTTP request.

    For POST requests the proto is read from the request body; for other
    methods it is read from the `request` query parameter.

    Args:
        request: A werkzeug Request object.
        proto_class: The protobuf message class to parse into.

    Returns:
        An instance of `proto_class` populated from the request.

    Raises:
        error.HParamsError: If the request payload is missing or is not
            valid JSON for `proto_class`.
    """
    request_json = (
        request.data
        if request.method == "POST"
        else request.args.get("request")
    )
    if request_json is None:
        # Previously a missing payload surfaced only as an AttributeError
        # raised inside json_format.Parse ("'NoneType' object has no
        # attribute 'decode'"); detect it explicitly for a clearer error.
        raise error.HParamsError(
            "Expected a JSON-formatted request data of type: %s, "
            "but the request contained no data" % (proto_class,)
        )
    try:
        return json_format.Parse(request_json, proto_class())
    except (AttributeError, json_format.ParseError) as e:
        raise error.HParamsError(
            "Expected a JSON-formatted request data of type: {}, but got {}".format(
                proto_class, request_json
            )
        ) from e

View File

@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: tensorboard/plugins/hparams/hparams_util.proto
# NOTE(review): machine-generated module — regenerate with protoc from the
# .proto source above rather than editing by hand.
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
from tensorboard.plugins.hparams import api_pb2 as tensorboard_dot_plugins_dot_hparams_dot_api__pb2


# Serialized FileDescriptorProto for hparams_util.proto; registered in the
# default descriptor pool at import time.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n.tensorboard/plugins/hparams/hparams_util.proto\x12\x13tensorboard.hparams\x1a\x1cgoogle/protobuf/struct.proto\x1a%tensorboard/plugins/hparams/api.proto\"H\n\x0fHParamInfosList\x12\x35\n\x0chparam_infos\x18\x01 \x03(\x0b\x32\x1f.tensorboard.hparams.HParamInfo\"H\n\x0fMetricInfosList\x12\x35\n\x0cmetric_infos\x18\x01 \x03(\x0b\x32\x1f.tensorboard.hparams.MetricInfo\"\x8d\x01\n\x07HParams\x12:\n\x07hparams\x18\x01 \x03(\x0b\x32).tensorboard.hparams.HParams.HparamsEntry\x1a\x46\n\x0cHparamsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.google.protobuf.Value:\x02\x38\x01\x62\x06proto3')

_HPARAMINFOSLIST = DESCRIPTOR.message_types_by_name['HParamInfosList']
_METRICINFOSLIST = DESCRIPTOR.message_types_by_name['MetricInfosList']
_HPARAMS = DESCRIPTOR.message_types_by_name['HParams']
_HPARAMS_HPARAMSENTRY = _HPARAMS.nested_types_by_name['HparamsEntry']
HParamInfosList = _reflection.GeneratedProtocolMessageType('HParamInfosList', (_message.Message,), {
  'DESCRIPTOR' : _HPARAMINFOSLIST,
  '__module__' : 'tensorboard.plugins.hparams.hparams_util_pb2'
  # @@protoc_insertion_point(class_scope:tensorboard.hparams.HParamInfosList)
  })
_sym_db.RegisterMessage(HParamInfosList)

MetricInfosList = _reflection.GeneratedProtocolMessageType('MetricInfosList', (_message.Message,), {
  'DESCRIPTOR' : _METRICINFOSLIST,
  '__module__' : 'tensorboard.plugins.hparams.hparams_util_pb2'
  # @@protoc_insertion_point(class_scope:tensorboard.hparams.MetricInfosList)
  })
_sym_db.RegisterMessage(MetricInfosList)

HParams = _reflection.GeneratedProtocolMessageType('HParams', (_message.Message,), {

  'HparamsEntry' : _reflection.GeneratedProtocolMessageType('HparamsEntry', (_message.Message,), {
    'DESCRIPTOR' : _HPARAMS_HPARAMSENTRY,
    '__module__' : 'tensorboard.plugins.hparams.hparams_util_pb2'
    # @@protoc_insertion_point(class_scope:tensorboard.hparams.HParams.HparamsEntry)
    })
  ,
  'DESCRIPTOR' : _HPARAMS,
  '__module__' : 'tensorboard.plugins.hparams.hparams_util_pb2'
  # @@protoc_insertion_point(class_scope:tensorboard.hparams.HParams)
  })
_sym_db.RegisterMessage(HParams)
_sym_db.RegisterMessage(HParams.HparamsEntry)

if _descriptor._USE_C_DESCRIPTORS == False:
  # These descriptor options and byte offsets only apply to the
  # pure-Python descriptor implementation; the C extension computes them
  # itself.
  DESCRIPTOR._options = None
  _HPARAMS_HPARAMSENTRY._options = None
  _HPARAMS_HPARAMSENTRY._serialized_options = b'8\001'
  _HPARAMINFOSLIST._serialized_start=140
  _HPARAMINFOSLIST._serialized_end=212
  _METRICINFOSLIST._serialized_start=214
  _METRICINFOSLIST._serialized_end=286
  _HPARAMS._serialized_start=289
  _HPARAMS._serialized_end=430
  _HPARAMS_HPARAMSENTRY._serialized_start=360
  _HPARAMS_HPARAMSENTRY._serialized_end=430
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,38 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import math
def is_serializable_value(value):
    """Returns whether a protobuf Value will be serializable by MessageToJson.

    The json_format documentation states that "attempting to serialize NaN or
    Infinity results in error."
    https://protobuf.dev/reference/protobuf/google.protobuf/#value

    Args:
      value: A value of type protobuf.Value.

    Returns:
      True if the Value should be serializable without error by MessageToJson.
      False, otherwise.
    """
    # Only numeric payloads can hold NaN/Infinity; every other kind of
    # Value serializes cleanly.
    if value.HasField("number_value"):
        return math.isfinite(value.number_value)
    return True

View File

@ -0,0 +1,58 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions for handling the ListMetricEvals API call."""
from tensorboard.plugins.hparams import metrics
from tensorboard.plugins.scalar import scalars_plugin
class Handler:
    """Handles a ListMetricEvals request."""

    def __init__(
        self, request_context, request, scalars_plugin_instance, experiment
    ):
        """Constructor.

        Args:
            request_context: A tensorboard.context.RequestContext.
            request: A ListSessionGroupsRequest protobuf.
            scalars_plugin_instance: A scalars_plugin.ScalarsPlugin.
            experiment: A experiment ID, as a possibly-empty `str`.
        """
        self._request_context = request_context
        self._request = request
        self._scalars_plugin_instance = scalars_plugin_instance
        self._experiment = experiment

    def run(self):
        """Executes the request.

        Returns:
            An array of tuples representing the metric evaluations--each of
            the form (<wall time in secs>, <training step>, <metric value>).
        """
        run_name, tag_name = metrics.run_tag_from_session_and_metric(
            self._request.session_name, self._request.metric_name
        )
        evals, _ = self._scalars_plugin_instance.scalars_impl(
            self._request_context,
            tag_name,
            run_name,
            self._experiment,
            scalars_plugin.OutputFormat.JSON,
        )
        return evals

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,127 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constants used in the HParams plugin."""
from tensorboard.compat.proto import summary_pb2
from tensorboard.compat.proto import types_pb2
from tensorboard.plugins.hparams import error
from tensorboard.plugins.hparams import plugin_data_pb2
from tensorboard.util import tensor_util
# Name under which this plugin registers its summary metadata.
PLUGIN_NAME = "hparams"
# Version stamp written into every HParamsPluginData message; readers
# reject content whose version differs (see `_parse_plugin_data_as`).
PLUGIN_DATA_VERSION = 0

# Tensor value for use in summaries that really only need to store
# metadata. A length-0 float vector is of minimal serialized length
# (6 bytes) among valid tensors. Cache this: computing it takes on the
# order of tens of microseconds.
NULL_TENSOR = tensor_util.make_tensor_proto(
    [], dtype=types_pb2.DT_FLOAT, shape=(0,)
)

# Reserved tags under which the experiment-level and session-level
# metadata summaries are recorded.
EXPERIMENT_TAG = "_hparams_/experiment"
SESSION_START_INFO_TAG = "_hparams_/session_start_info"
SESSION_END_INFO_TAG = "_hparams_/session_end_info"
def create_summary_metadata(hparams_plugin_data_pb):
    """Returns a summary metadata for the HParams plugin.

    Returns a summary_pb2.SummaryMetadata holding a copy of the given
    HParamsPluginData message in its plugin_data.content field. The
    version field of the copy is set to PLUGIN_DATA_VERSION.

    Args:
      hparams_plugin_data_pb: the HParamsPluginData protobuffer to use.

    Raises:
      TypeError: If `hparams_plugin_data_pb` is not an HParamsPluginData.
    """
    if not isinstance(
        hparams_plugin_data_pb, plugin_data_pb2.HParamsPluginData
    ):
        raise TypeError(
            "Needed an instance of plugin_data_pb2.HParamsPluginData."
            " Got: %s" % type(hparams_plugin_data_pb)
        )
    # Stamp a copy so the caller's message is left untouched.
    stamped = plugin_data_pb2.HParamsPluginData()
    stamped.CopyFrom(hparams_plugin_data_pb)
    stamped.version = PLUGIN_DATA_VERSION
    plugin_data = summary_pb2.SummaryMetadata.PluginData(
        plugin_name=PLUGIN_NAME, content=stamped.SerializeToString()
    )
    return summary_pb2.SummaryMetadata(plugin_data=plugin_data)
def parse_experiment_plugin_data(content):
    """Extracts the `experiment` message from serialized plugin data.

    Args:
      content: The SummaryMetadata.plugin_data.content bytes to parse.

    Returns:
      The Experiment message stored in `content`.

    Raises:
      error.HParamsError: If `content` has an incompatible version or does
        not have its `experiment` field set.
    """
    return _parse_plugin_data_as(content, "experiment")
def parse_session_start_info_plugin_data(content):
    """Extracts the `session_start_info` message from serialized plugin data.

    Args:
      content: The SummaryMetadata.plugin_data.content bytes to parse.

    Returns:
      The SessionStartInfo message stored in `content`.

    Raises:
      error.HParamsError: If `content` has an incompatible version or does
        not have its `session_start_info` field set.
    """
    return _parse_plugin_data_as(content, "session_start_info")
def parse_session_end_info_plugin_data(content):
    """Extracts the `session_end_info` message from serialized plugin data.

    Args:
      content: The SummaryMetadata.plugin_data.content bytes to parse.

    Returns:
      The SessionEndInfo message stored in `content`.

    Raises:
      error.HParamsError: If `content` has an incompatible version or does
        not have its `session_end_info` field set.
    """
    return _parse_plugin_data_as(content, "session_end_info")
def _parse_plugin_data_as(content, data_oneof_field):
    """Returns one data-oneof field from serialized plugin_data.content.

    Args:
      content: The SummaryMetadata.plugin_data.content to use.
      data_oneof_field: string. The name of the data oneof field to return.

    Raises:
      error.HParamsError: If the stored version differs from
        PLUGIN_DATA_VERSION, or `data_oneof_field` is not the field that
        is set.
    """
    parsed = plugin_data_pb2.HParamsPluginData.FromString(content)
    if parsed.version != PLUGIN_DATA_VERSION:
        raise error.HParamsError(
            "Only supports plugin_data version: %s; found: %s in: %s"
            % (PLUGIN_DATA_VERSION, parsed.version, parsed)
        )
    if not parsed.HasField(data_oneof_field):
        raise error.HParamsError(
            "Expected plugin_data.%s to be set. Got: %s"
            % (data_oneof_field, parsed)
        )
    return getattr(parsed, data_oneof_field)

View File

@ -0,0 +1,42 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for dealing with metrics."""
import os
from tensorboard.plugins.hparams import api_pb2
def run_tag_from_session_and_metric(session_name, metric_name):
    """Returns a (run,tag) tuple storing the evaluations of the specified
    metric.

    Args:
      session_name: str.
      metric_name: MetricName protobuffer.

    Returns: (run, tag) tuple.
    """
    assert isinstance(session_name, str)
    assert isinstance(metric_name, api_pb2.MetricName)
    # os.path.join() appends a final slash when the group is empty, and
    # multiplexer.Tensors does not recognize paths ending in '/', so trim
    # a single trailing slash in that case.
    joined = os.path.join(session_name, metric_name.group)
    run = joined[:-1] if joined.endswith("/") else joined
    return run, metric_name.tag

View File

@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: tensorboard/plugins/hparams/plugin_data.proto
# NOTE(review): machine-generated module — regenerate with protoc from the
# .proto source above rather than editing by hand.
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


from tensorboard.plugins.hparams import api_pb2 as tensorboard_dot_plugins_dot_hparams_dot_api__pb2
from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2


# Serialized FileDescriptorProto for plugin_data.proto; registered in the
# default descriptor pool at import time.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n-tensorboard/plugins/hparams/plugin_data.proto\x12\x13tensorboard.hparams\x1a%tensorboard/plugins/hparams/api.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xe9\x01\n\x11HParamsPluginData\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x35\n\nexperiment\x18\x02 \x01(\x0b\x32\x1f.tensorboard.hparams.ExperimentH\x00\x12\x43\n\x12session_start_info\x18\x03 \x01(\x0b\x32%.tensorboard.hparams.SessionStartInfoH\x00\x12?\n\x10session_end_info\x18\x04 \x01(\x0b\x32#.tensorboard.hparams.SessionEndInfoH\x00\x42\x06\n\x04\x64\x61ta\"\xf4\x01\n\x10SessionStartInfo\x12\x43\n\x07hparams\x18\x01 \x03(\x0b\x32\x32.tensorboard.hparams.SessionStartInfo.HparamsEntry\x12\x11\n\tmodel_uri\x18\x02 \x01(\t\x12\x13\n\x0bmonitor_url\x18\x03 \x01(\t\x12\x12\n\ngroup_name\x18\x04 \x01(\t\x12\x17\n\x0fstart_time_secs\x18\x05 \x01(\x01\x1a\x46\n\x0cHparamsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.google.protobuf.Value:\x02\x38\x01\"T\n\x0eSessionEndInfo\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.tensorboard.hparams.Status\x12\x15\n\rend_time_secs\x18\x02 \x01(\x01\x62\x06proto3')

_HPARAMSPLUGINDATA = DESCRIPTOR.message_types_by_name['HParamsPluginData']
_SESSIONSTARTINFO = DESCRIPTOR.message_types_by_name['SessionStartInfo']
_SESSIONSTARTINFO_HPARAMSENTRY = _SESSIONSTARTINFO.nested_types_by_name['HparamsEntry']
_SESSIONENDINFO = DESCRIPTOR.message_types_by_name['SessionEndInfo']
HParamsPluginData = _reflection.GeneratedProtocolMessageType('HParamsPluginData', (_message.Message,), {
  'DESCRIPTOR' : _HPARAMSPLUGINDATA,
  '__module__' : 'tensorboard.plugins.hparams.plugin_data_pb2'
  # @@protoc_insertion_point(class_scope:tensorboard.hparams.HParamsPluginData)
  })
_sym_db.RegisterMessage(HParamsPluginData)

SessionStartInfo = _reflection.GeneratedProtocolMessageType('SessionStartInfo', (_message.Message,), {

  'HparamsEntry' : _reflection.GeneratedProtocolMessageType('HparamsEntry', (_message.Message,), {
    'DESCRIPTOR' : _SESSIONSTARTINFO_HPARAMSENTRY,
    '__module__' : 'tensorboard.plugins.hparams.plugin_data_pb2'
    # @@protoc_insertion_point(class_scope:tensorboard.hparams.SessionStartInfo.HparamsEntry)
    })
  ,
  'DESCRIPTOR' : _SESSIONSTARTINFO,
  '__module__' : 'tensorboard.plugins.hparams.plugin_data_pb2'
  # @@protoc_insertion_point(class_scope:tensorboard.hparams.SessionStartInfo)
  })
_sym_db.RegisterMessage(SessionStartInfo)
_sym_db.RegisterMessage(SessionStartInfo.HparamsEntry)

SessionEndInfo = _reflection.GeneratedProtocolMessageType('SessionEndInfo', (_message.Message,), {
  'DESCRIPTOR' : _SESSIONENDINFO,
  '__module__' : 'tensorboard.plugins.hparams.plugin_data_pb2'
  # @@protoc_insertion_point(class_scope:tensorboard.hparams.SessionEndInfo)
  })
_sym_db.RegisterMessage(SessionEndInfo)

if _descriptor._USE_C_DESCRIPTORS == False:
  # These descriptor options and byte offsets only apply to the
  # pure-Python descriptor implementation; the C extension computes them
  # itself.
  DESCRIPTOR._options = None
  _SESSIONSTARTINFO_HPARAMSENTRY._options = None
  _SESSIONSTARTINFO_HPARAMSENTRY._serialized_options = b'8\001'
  _HPARAMSPLUGINDATA._serialized_start=140
  _HPARAMSPLUGINDATA._serialized_end=373
  _SESSIONSTARTINFO._serialized_start=376
  _SESSIONSTARTINFO._serialized_end=620
  _SESSIONSTARTINFO_HPARAMSENTRY._serialized_start=550
  _SESSIONSTARTINFO_HPARAMSENTRY._serialized_end=620
  _SESSIONENDINFO._serialized_start=622
  _SESSIONENDINFO._serialized_end=706
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,206 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Summary creation methods for the HParams plugin.
Typical usage for exporting summaries in a hyperparameters-tuning experiment:
1. Create the experiment (once) by calling experiment_pb() and exporting
the resulting summary into a top-level (empty) run.
2. In each training session in the experiment, call session_start_pb() before
the session starts, exporting the resulting summary into a uniquely named
run for the session, say <session_name>.
3. Train the model in the session, exporting each metric as a scalar summary
in runs of the form <session_name>/<sub_dir>, where <sub_dir> can be empty
(in which case the run is just the <session_name>) and depends on the
metric. The name of such a metric is a (group, tag) pair given by
(<sub_dir>, tag) where tag is the tag of the scalar summary.
When calling experiment_pb in step 1, you'll need to pass all the metric
names used in the experiment.
4. When the session completes, call session_end_pb() and export the resulting
summary into the same session run <session_name>.
"""
import time
import tensorflow as tf
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import metadata
from tensorboard.plugins.hparams import plugin_data_pb2
def experiment_pb(
    hparam_infos, metric_infos, user="", description="", time_created_secs=None
):
    """Creates a summary that defines a hyperparameter-tuning experiment.

    Args:
      hparam_infos: Array of api_pb2.HParamInfo messages. Describes the
        hyperparameters used in the experiment.
      metric_infos: Array of api_pb2.MetricInfo messages. Describes the
        metrics used in the experiment. See the documentation at the top of
        this file for how to populate this.
      user: String. An id for the user running the experiment
      description: String. A description for the experiment. May contain
        markdown.
      time_created_secs: float. The time the experiment is created in
        seconds since the UNIX epoch. If None uses the current time.

    Returns:
      A summary protobuffer containing the experiment definition.
    """
    creation_time = (
        time.time() if time_created_secs is None else time_created_secs
    )
    experiment = api_pb2.Experiment(
        description=description,
        user=user,
        time_created_secs=creation_time,
        hparam_infos=hparam_infos,
        metric_infos=metric_infos,
    )
    plugin_data = plugin_data_pb2.HParamsPluginData(experiment=experiment)
    return _summary(metadata.EXPERIMENT_TAG, plugin_data)
def session_start_pb(
    hparams, model_uri="", monitor_url="", group_name="", start_time_secs=None
):
    """Constructs a SessionStartInfo protobuffer.

    Creates a summary holding a training session's metadata. One such
    summary should be created per training session, each in its own run.

    Args:
      hparams: A dictionary with string keys. Describes the hyperparameter
        values used in the session, mapping each hyperparameter name to its
        value. Supported value types are `bool`, `int`, `float`, `str`,
        `list`, `tuple`. The type of value must correspond to the type of
        hyperparameter (defined in the corresponding api_pb2.HParamInfo
        member of the Experiment protobuf) as follows:

            +-----------------+---------------------------------+
            |Hyperparameter   | Allowed (Python) value types    |
            |type             |                                 |
            +-----------------+---------------------------------+
            |DATA_TYPE_BOOL   | bool                            |
            |DATA_TYPE_FLOAT64| int, float                      |
            |DATA_TYPE_STRING | str, tuple, list                |
            +-----------------+---------------------------------+

        Tuple and list instances will be converted to their string
        representation.
      model_uri: See the comment for the field with the same name of
        plugin_data_pb2.SessionStartInfo.
      monitor_url: See the comment for the field with the same name of
        plugin_data_pb2.SessionStartInfo.
      group_name: See the comment for the field with the same name of
        plugin_data_pb2.SessionStartInfo.
      start_time_secs: float. The time to use as the session start time.
        Represented as seconds since the UNIX epoch. If None uses the
        current time.

    Returns:
      The summary protobuffer mentioned above.

    Raises:
      TypeError: If an hparam value has an unsupported type.
    """
    if start_time_secs is None:
        start_time_secs = time.time()
    session_start_info = plugin_data_pb2.SessionStartInfo(
        model_uri=model_uri,
        monitor_url=monitor_url,
        group_name=group_name,
        start_time_secs=start_time_secs,
    )
    for name, value in hparams.items():
        entry = session_start_info.hparams[name]
        # bool must be tested before int/float: isinstance(True, int) is
        # True in Python.
        if isinstance(value, bool):
            entry.bool_value = value
        elif isinstance(value, (float, int)):
            entry.number_value = value
        elif isinstance(value, str):
            entry.string_value = value
        elif isinstance(value, (list, tuple)):
            entry.string_value = str(value)
        else:
            raise TypeError(
                "hparams[%s]=%s has type: %s which is not supported"
                % (name, value, type(value))
            )
    return _summary(
        metadata.SESSION_START_INFO_TAG,
        plugin_data_pb2.HParamsPluginData(
            session_start_info=session_start_info
        ),
    )
def session_end_pb(status, end_time_secs=None):
    """Constructs a SessionEndInfo protobuffer.

    Creates a summary that contains status information for a completed
    training session. Should be exported after the training session is
    completed. One such summary per training session should be created.
    Each should have a different run.

    Args:
      status: A tensorboard.hparams.Status enumeration value denoting the
        status of the session.
      end_time_secs: float. The time to use as the session end time.
        Represented as seconds since the unix epoch. If None uses the
        current time.

    Returns:
      The summary protobuffer mentioned above.
    """
    # Default the end time to "now" when the caller does not supply one.
    end_time = time.time() if end_time_secs is None else end_time_secs
    end_info = plugin_data_pb2.SessionEndInfo(
        status=status,
        end_time_secs=end_time,
    )
    plugin_data = plugin_data_pb2.HParamsPluginData(session_end_info=end_info)
    return _summary(metadata.SESSION_END_INFO_TAG, plugin_data)
def _summary(tag, hparams_plugin_data):
    """Returns a summary holding the given HParamsPluginData message.

    Helper function.

    Args:
      tag: string. The tag to use.
      hparams_plugin_data: The HParamsPluginData message to use.
    """
    # Round-trip the TensorBoard-side metadata proto through serialization
    # to obtain the equivalent TensorFlow `SummaryMetadata` proto.
    serialized_metadata = metadata.create_summary_metadata(
        hparams_plugin_data
    ).SerializeToString()
    tf_metadata = tf.compat.v1.SummaryMetadata.FromString(serialized_metadata)
    result = tf.compat.v1.Summary()
    result.value.add(tag=tag, metadata=tf_metadata, tensor=_TF_NULL_TENSOR)
    return result
# Like `metadata.NULL_TENSOR`, but with the TensorFlow version of the
# proto. Slight kludge needed to expose the `TensorProto` type:
# `tf.make_tensor_proto(0)` yields an instance whose class is TF's own
# `TensorProto`, and `FromString` then re-parses the TensorBoard-side
# null tensor bytes into that TF type.
_TF_NULL_TENSOR = type(tf.make_tensor_proto(0)).FromString(
    metadata.NULL_TENSOR.SerializeToString()
)
# ---------------------------------------------------------------------------
# NOTE: extraction artifact removed here (diff separator between the end of
# `summary.py` above and the start of `summary_v2.py` below).
# ---------------------------------------------------------------------------
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental public APIs for the HParams plugin.
These are porcelain on top of `api_pb2` (`api.proto`) and `summary.py`.
"""
import abc
import hashlib
import json
import random
import time
import numpy as np
from tensorboard.compat import tf2 as tf
from tensorboard.compat.proto import summary_pb2
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import metadata
from tensorboard.plugins.hparams import plugin_data_pb2
def hparams(hparams, trial_id=None, start_time_secs=None):
    # NOTE: Keep docs in sync with `hparams_pb` below.
    """Write hyperparameter values for a single trial.

    Args:
      hparams: A `dict` mapping hyperparameters to the values used in this
        trial. Keys should be the names of `HParam` objects used in an
        experiment, or the `HParam` objects themselves. Values should be
        Python `bool`, `int`, `float`, or `string` values, depending on
        the type of the hyperparameter. The corresponding numpy types,
        like `np.float32`, are also permitted.
      trial_id: An optional `str` ID for the set of hyperparameter values
        used in this trial. Defaults to a hash of the hyperparameters.
      start_time_secs: The time that this trial started training, as
        seconds since epoch. Defaults to the current time.

    Returns:
      A tensor whose value is `True` on success, or `False` if no summary
      was written because no default summary writer was available.
    """
    # Build the summary proto, then delegate the actual write.
    summary_pb = hparams_pb(
        hparams, trial_id=trial_id, start_time_secs=start_time_secs
    )
    return _write_summary("hparams", summary_pb)
def hparams_pb(hparams, trial_id=None, start_time_secs=None):
    # NOTE: Keep docs in sync with `hparams` above.
    """Create a summary encoding hyperparameter values for a single trial.

    Args:
      hparams: A `dict` mapping hyperparameters to the values used in this
        trial. Keys should be the names of `HParam` objects used in an
        experiment, or the `HParam` objects themselves. Values should be
        Python `bool`, `int`, `float`, or `string` values, depending on
        the type of the hyperparameter. The corresponding numpy types,
        like `np.float32`, are also permitted.
      trial_id: An optional `str` ID for the set of hyperparameter values
        used in this trial. Defaults to a hash of the hyperparameters.
      start_time_secs: The time that this trial started training, as
        seconds since epoch. Defaults to the current time.

    Returns:
      A TensorBoard `summary_pb2.Summary` message.

    Raises:
      TypeError: If a normalized hyperparameter value is not a `bool`,
        `int`, `float`, or `str`, or if `trial_id` is not a `str`.
    """
    if start_time_secs is None:
        start_time_secs = time.time()
    hparams = _normalize_hparams(hparams)
    group_name = _derive_session_group_name(trial_id, hparams)
    session_start_info = plugin_data_pb2.SessionStartInfo(
        group_name=group_name,
        start_time_secs=start_time_secs,
    )
    # Iterate in sorted order for deterministic proto contents.
    for hp_name in sorted(hparams):
        hp_value = hparams[hp_name]
        # Check `bool` before `int`/`float`: in Python,
        # `isinstance(True, int)` is true.
        if isinstance(hp_value, bool):
            session_start_info.hparams[hp_name].bool_value = hp_value
        elif isinstance(hp_value, (float, int)):
            session_start_info.hparams[hp_name].number_value = hp_value
        elif isinstance(hp_value, str):
            session_start_info.hparams[hp_name].string_value = hp_value
        else:
            raise TypeError(
                "hparams[%r] = %r, of unsupported type %r"
                % (hp_name, hp_value, type(hp_value))
            )
    return _summary_pb(
        metadata.SESSION_START_INFO_TAG,
        plugin_data_pb2.HParamsPluginData(
            session_start_info=session_start_info
        ),
    )
def hparams_config(hparams, metrics, time_created_secs=None):
    # NOTE: Keep docs in sync with `hparams_config_pb` below.
    """Write a top-level experiment configuration.

    This configuration describes the hyperparameters and metrics that will
    be tracked in the experiment, but does not record any actual values of
    those hyperparameters and metrics. It can be created before any models
    are actually trained.

    Args:
      hparams: A list of `HParam` values.
      metrics: A list of `Metric` values.
      time_created_secs: The time that this experiment was created, as
        seconds since epoch. Defaults to the current time.

    Returns:
      A tensor whose value is `True` on success, or `False` if no summary
      was written because no default summary writer was available.
    """
    # Build the configuration proto, then delegate the actual write.
    config_pb = hparams_config_pb(
        hparams, metrics, time_created_secs=time_created_secs
    )
    return _write_summary("hparams_config", config_pb)
def hparams_config_pb(hparams, metrics, time_created_secs=None):
    # NOTE: Keep docs in sync with `hparams_config` above.
    """Create a top-level experiment configuration.

    This configuration describes the hyperparameters and metrics that will
    be tracked in the experiment, but does not record any actual values of
    those hyperparameters and metrics. It can be created before any models
    are actually trained.

    Args:
      hparams: A list of `HParam` values.
      metrics: A list of `Metric` values.
      time_created_secs: The time that this experiment was created, as
        seconds since epoch. Defaults to the current time.

    Returns:
      A TensorBoard `summary_pb2.Summary` message.
    """
    if time_created_secs is None:
        # Match the documented default (and the behavior of `hparams_pb`):
        # fall back to the current time instead of leaving the proto field
        # unset.
        time_created_secs = time.time()
    hparam_infos = []
    for hparam in hparams:
        info = api_pb2.HParamInfo(
            name=hparam.name,
            description=hparam.description,
            display_name=hparam.display_name,
        )
        domain = hparam.domain
        if domain is not None:
            # Each `Domain` fills in its own `type` and domain-variant
            # fields on the proto.
            domain.update_hparam_info(info)
        hparam_infos.append(info)
    metric_infos = [metric.as_proto() for metric in metrics]
    experiment = api_pb2.Experiment(
        hparam_infos=hparam_infos,
        metric_infos=metric_infos,
        time_created_secs=time_created_secs,
    )
    return _summary_pb(
        metadata.EXPERIMENT_TAG,
        plugin_data_pb2.HParamsPluginData(experiment=experiment),
    )
def _normalize_hparams(hparams):
    """Normalize a dict keyed by `HParam`s and/or raw strings.

    Args:
      hparams: A `dict` whose keys are `HParam` objects and/or strings
        representing hyperparameter names, and whose values are
        hyperparameter values. No two keys may have the same name.

    Returns:
      A `dict` whose keys are hyperparameter names (as strings) and whose
      values are the corresponding hyperparameter values, after numpy
      normalization (see `_normalize_numpy_value`).

    Raises:
      ValueError: If two entries in `hparams` share the same
        hyperparameter name.
    """
    normalized = {}
    for key, value in hparams.items():
        # `HParam` keys collapse to their names, so an `HParam` and a raw
        # string with the same name would collide here.
        name = key.name if isinstance(key, HParam) else key
        if name in normalized:
            raise ValueError("multiple values specified for hparam %r" % (name,))
        normalized[name] = _normalize_numpy_value(value)
    return normalized
def _normalize_numpy_value(value):
"""Convert a Python or Numpy scalar to a Python scalar.
For instance, `3.0`, `np.float32(3.0)`, and `np.float64(3.0)` all
map to `3.0`.
Args:
value: A Python scalar (`int`, `float`, `str`, or `bool`) or
rank-0 `numpy` equivalent (e.g., `np.int64`, `np.float32`).
Returns:
A Python scalar equivalent to `value`.
"""
if isinstance(value, np.generic):
return value.item()
else:
return value
def _derive_session_group_name(trial_id, hparams):
if trial_id is not None:
if not isinstance(trial_id, str):
raise TypeError(
"`trial_id` should be a `str`, but got: %r" % (trial_id,)
)
return trial_id
# Use `json.dumps` rather than `str` to ensure invariance under string
# type (incl. across Python versions) and dict iteration order.
jparams = json.dumps(hparams, sort_keys=True, separators=(",", ":"))
return hashlib.sha256(jparams.encode("utf-8")).hexdigest()
def _write_summary(name, pb):
    """Write a summary, returning the writing op.

    Args:
      name: As passed to `summary_scope`.
      pb: A `summary_pb2.Summary` message.

    Returns:
      A tensor whose value is `True` on success, or `False` if no summary
      was written because no default summary writer was available.
    """
    serialized = pb.SerializeToString()
    # Prefer the experimental `summary_scope` when it exists; otherwise
    # fall back to the stable location.
    scope_fn = getattr(tf.summary.experimental, "summary_scope", None)
    if scope_fn is None:
        scope_fn = tf.summary.summary_scope
    with scope_fn(name):
        return tf.summary.experimental.write_raw_pb(serialized, step=0)
def _summary_pb(tag, hparams_plugin_data):
    """Create a summary holding the given `HParamsPluginData` message.

    Args:
      tag: The `str` tag to use.
      hparams_plugin_data: The `HParamsPluginData` message to use.

    Returns:
      A TensorBoard `summary_pb2.Summary` message.
    """
    summary = summary_pb2.Summary()
    summary_metadata = metadata.create_summary_metadata(hparams_plugin_data)
    # The tensor content is irrelevant for hparams summaries; the null
    # tensor placeholder just satisfies the proto schema. (The added value
    # does not need to be kept in a local.)
    summary.value.add(
        tag=tag, metadata=summary_metadata, tensor=metadata.NULL_TENSOR
    )
    return summary
class HParam:
    """A hyperparameter in an experiment.

    This class describes a hyperparameter in the abstract. It ranges
    over a domain of values, but is not bound to any particular value.
    """

    def __init__(self, name, domain=None, display_name=None, description=None):
        """Create a hyperparameter object.

        Args:
          name: A string ID for this hyperparameter, which should be unique
            within an experiment.
          domain: An optional `Domain` object describing the values that
            this hyperparameter can take on.
          display_name: An optional human-readable display name (`str`).
          description: An optional Markdown string describing this
            hyperparameter.

        Raises:
          ValueError: If `domain` is neither `None` nor a `Domain`.
        """
        if domain is not None and not isinstance(domain, Domain):
            raise ValueError("not a domain: %r" % (domain,))
        self._name = name
        self._domain = domain
        self._display_name = display_name
        self._description = description

    def __str__(self):
        return "<HParam %r: %s>" % (self._name, self._domain)

    def __repr__(self):
        parts = (
            "name=%r" % (self._name,),
            "domain=%r" % (self._domain,),
            "display_name=%r" % (self._display_name,),
            "description=%r" % (self._description,),
        )
        return "HParam(%s)" % ", ".join(parts)

    @property
    def name(self):
        # The unique string ID of this hyperparameter.
        return self._name

    @property
    def domain(self):
        # The `Domain` of allowed values, or `None` if unconstrained.
        return self._domain

    @property
    def display_name(self):
        # Human-readable display name, or `None`.
        return self._display_name

    @property
    def description(self):
        # Markdown description, or `None`.
        return self._description
class Domain(metaclass=abc.ABCMeta):
    """The domain of a hyperparameter.

    Domains are restricted to values of the simple types `float`, `int`,
    `str`, and `bool`.
    """

    # `abc.abstractproperty` is deprecated since Python 3.3; stacking
    # `@property` over `@abc.abstractmethod` is the modern equivalent.
    @property
    @abc.abstractmethod
    def dtype(self):
        """Data type of this domain: `float`, `int`, `str`, or `bool`."""
        pass

    @abc.abstractmethod
    def sample_uniform(self, rng=random):
        """Sample a value from this domain uniformly at random.

        Args:
          rng: A `random.Random` interface; defaults to the `random` module
            itself.

        Raises:
          IndexError: If the domain is empty.
        """
        pass

    @abc.abstractmethod
    def update_hparam_info(self, hparam_info):
        """Update an `HParamInfo` proto to include this domain.

        This should update the `type` field on the proto and exactly one of
        the `domain` variants on the proto.

        Args:
          hparam_info: An `api_pb2.HParamInfo` proto to modify.
        """
        pass
class IntInterval(Domain):
    """A domain that takes on all integer values in a closed interval."""

    def __init__(self, min_value=None, max_value=None):
        """Create an `IntInterval`.

        Args:
          min_value: The lower bound (inclusive) of the interval.
          max_value: The upper bound (inclusive) of the interval.

        Raises:
          TypeError: If `min_value` or `max_value` is not an `int`.
          ValueError: If `min_value > max_value`.
        """
        # Validate both bounds before accepting them.
        if not isinstance(min_value, int):
            raise TypeError("min_value must be an int: %r" % (min_value,))
        if not isinstance(max_value, int):
            raise TypeError("max_value must be an int: %r" % (max_value,))
        if min_value > max_value:
            raise ValueError("%r > %r" % (min_value, max_value))
        self._min_value = min_value
        self._max_value = max_value

    def __str__(self):
        return "[%s, %s]" % (self._min_value, self._max_value)

    def __repr__(self):
        return "IntInterval(%r, %r)" % (self._min_value, self._max_value)

    @property
    def dtype(self):
        # All values in this domain are Python `int`s.
        return int

    @property
    def min_value(self):
        # Inclusive lower bound.
        return self._min_value

    @property
    def max_value(self):
        # Inclusive upper bound.
        return self._max_value

    def sample_uniform(self, rng=random):
        # `randint` is inclusive on both ends, matching the closed interval.
        return rng.randint(self._min_value, self._max_value)

    def update_hparam_info(self, hparam_info):
        # TODO(#1998): Add int dtype; float64 is the closest available type.
        hparam_info.type = api_pb2.DATA_TYPE_FLOAT64
        hparam_info.domain_interval.min_value = self._min_value
        hparam_info.domain_interval.max_value = self._max_value
class RealInterval(Domain):
    """A domain that takes on all real values in a closed interval."""

    def __init__(self, min_value=None, max_value=None):
        """Create a `RealInterval`.

        Args:
          min_value: The lower bound (inclusive) of the interval.
          max_value: The upper bound (inclusive) of the interval.

        Raises:
          TypeError: If `min_value` or `max_value` is not a `float`.
          ValueError: If `min_value > max_value`.
        """
        # Validate both bounds before accepting them.
        if not isinstance(min_value, float):
            raise TypeError("min_value must be a float: %r" % (min_value,))
        if not isinstance(max_value, float):
            raise TypeError("max_value must be a float: %r" % (max_value,))
        if min_value > max_value:
            raise ValueError("%r > %r" % (min_value, max_value))
        self._min_value = min_value
        self._max_value = max_value

    def __str__(self):
        return "[%s, %s]" % (self._min_value, self._max_value)

    def __repr__(self):
        return "RealInterval(%r, %r)" % (self._min_value, self._max_value)

    @property
    def dtype(self):
        # All values in this domain are Python `float`s.
        return float

    @property
    def min_value(self):
        # Inclusive lower bound.
        return self._min_value

    @property
    def max_value(self):
        # Inclusive upper bound.
        return self._max_value

    def sample_uniform(self, rng=random):
        return rng.uniform(self._min_value, self._max_value)

    def update_hparam_info(self, hparam_info):
        hparam_info.type = api_pb2.DATA_TYPE_FLOAT64
        hparam_info.domain_interval.min_value = self._min_value
        hparam_info.domain_interval.max_value = self._max_value
class Discrete(Domain):
    """A domain that takes on a fixed set of values.

    These values may be of any (single) domain type.
    """

    def __init__(self, values, dtype=None):
        """Construct a discrete domain.

        Args:
          values: A iterable of the values in this domain.
          dtype: The Python data type of values in this domain: one of
            `int`, `float`, `bool`, or `str`. If `values` is non-empty,
            `dtype` may be `None`, in which case it will be inferred as the
            type of the first element of `values`.

        Raises:
          ValueError: If `values` is empty but no `dtype` is specified.
          ValueError: If `dtype` or its inferred value is not `int`,
            `float`, `bool`, or `str`.
          TypeError: If an element of `values` is not an instance of
            `dtype`.
        """
        self._values = list(values)
        if dtype is None:
            # Infer the dtype from the first element; an empty domain
            # leaves nothing to infer from.
            if not self._values:
                raise ValueError("Empty domain with no dtype specified")
            dtype = type(self._values[0])
        if dtype not in (int, float, bool, str):
            raise ValueError("Unknown dtype: %r" % (dtype,))
        self._dtype = dtype
        for value in self._values:
            if not isinstance(value, self._dtype):
                raise TypeError(
                    "dtype mismatch: not isinstance(%r, %s)"
                    % (value, self._dtype.__name__)
                )
        # Keep values sorted so the domain has a canonical ordering.
        self._values.sort()

    def __str__(self):
        return "{%s}" % (", ".join(repr(x) for x in self._values))

    def __repr__(self):
        return "Discrete(%r)" % (self._values,)

    @property
    def dtype(self):
        # The single Python type shared by every value in this domain.
        return self._dtype

    @property
    def values(self):
        # Return a copy so callers cannot mutate the internal list.
        return list(self._values)

    def sample_uniform(self, rng=random):
        return rng.choice(self._values)

    def update_hparam_info(self, hparam_info):
        proto_types = {
            int: api_pb2.DATA_TYPE_FLOAT64,  # TODO(#1998): Add int dtype.
            float: api_pb2.DATA_TYPE_FLOAT64,
            bool: api_pb2.DATA_TYPE_BOOL,
            str: api_pb2.DATA_TYPE_STRING,
        }
        hparam_info.type = proto_types[self._dtype]
        hparam_info.ClearField("domain_discrete")
        hparam_info.domain_discrete.extend(self._values)
class Metric:
    """A metric in an experiment.

    A metric is a real-valued function of a model. Each metric is
    associated with a TensorBoard scalar summary, which logs the
    metric's value as the model trains.
    """

    # Dataset-type constants, mirroring the `api_pb2.DatasetType` enum.
    TRAINING = api_pb2.DATASET_TRAINING
    VALIDATION = api_pb2.DATASET_VALIDATION

    def __init__(
        self,
        tag,
        group=None,
        display_name=None,
        description=None,
        dataset_type=None,
    ):
        """
        Args:
          tag: The tag name of the scalar summary that corresponds to this
            metric (as a `str`).
          group: An optional string listing the subdirectory under the
            session's log directory containing summaries for this metric.
            For instance, if summaries for training runs are written to
            events files in `ROOT_LOGDIR/SESSION_ID/train`, then `group`
            should be `"train"`. Defaults to the empty string: i.e.,
            summaries are expected to be written to the session logdir.
          display_name: An optional human-readable display name.
          description: An optional Markdown string with a human-readable
            description of this metric, to appear in TensorBoard.
          dataset_type: Either `Metric.TRAINING` or `Metric.VALIDATION`, or
            `None`.

        Raises:
          ValueError: If `dataset_type` is not one of the allowed values.
        """
        if dataset_type not in (None, Metric.TRAINING, Metric.VALIDATION):
            raise ValueError("invalid dataset type: %r" % (dataset_type,))
        self._tag = tag
        self._group = group
        self._display_name = display_name
        self._description = description
        self._dataset_type = dataset_type

    def as_proto(self):
        # Convert this metric to its `api_pb2.MetricInfo` representation.
        name = api_pb2.MetricName(
            group=self._group,
            tag=self._tag,
        )
        return api_pb2.MetricInfo(
            name=name,
            display_name=self._display_name,
            description=self._description,
            dataset_type=self._dataset_type,
        )