I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,129 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for graph plugin."""
from tensorboard.compat.proto import graph_pb2
def _prefixed_op_name(prefix, op_name):
return "%s/%s" % (prefix, op_name)
def _prefixed_func_name(prefix, func_name):
"""Returns function name prefixed with `prefix`.
For function libraries, which are often created out of autographed Python
function, are factored out in the graph vis. They are grouped under a
function name which often has a shape of
`__inference_[py_func_name]_[numeric_suffix]`.
While it does not have some unique information about which graph it is from,
creating another wrapping structure with graph prefix and "/" is less than
ideal so we join the prefix and func_name using underscore.
TODO(stephanwlee): add business logic to strip "__inference_" for more user
friendlier name
"""
return "%s_%s" % (prefix, func_name)
def _add_with_prepended_names(prefix, graph_to_add, destination_graph):
for node in graph_to_add.node:
new_node = destination_graph.node.add()
new_node.CopyFrom(node)
new_node.name = _prefixed_op_name(prefix, node.name)
new_node.input[:] = [
_prefixed_op_name(prefix, input_name) for input_name in node.input
]
# Remap tf.function method name in the PartitionedCall. 'f' is short for
# function.
if new_node.op == "PartitionedCall" and new_node.attr["f"]:
new_node.attr["f"].func.name = _prefixed_func_name(
prefix,
new_node.attr["f"].func.name,
)
for func in graph_to_add.library.function:
new_func = destination_graph.library.function.add()
new_func.CopyFrom(func)
new_func.signature.name = _prefixed_func_name(
prefix, new_func.signature.name
)
for gradient in graph_to_add.library.gradient:
new_gradient = destination_graph.library.gradient.add()
new_gradient.CopyFrom(gradient)
new_gradient.function_name = _prefixed_func_name(
prefix,
new_gradient.function_name,
)
new_gradient.gradient_func = _prefixed_func_name(
prefix,
new_gradient.gradient_func,
)
def merge_graph_defs(graph_defs):
"""Merges GraphDefs by adding unique prefix, `graph_{ind}`, to names.
All GraphDefs are expected to be of TensorBoard's.
When collecting graphs using the `tf.summary.trace` API, node names are not
guranteed to be unique. When non-unique names are not considered, it can
lead to graph visualization showing them as one which creates inaccurate
depiction of the flow of the graph (e.g., if there are A -> B -> C and D ->
B -> E, you may see {A, D} -> B -> E). To prevent such graph, we checked
for uniquenss while merging but it resulted in
https://github.com/tensorflow/tensorboard/issues/1929.
To remedy these issues, we simply "apply name scope" on each graph by
prefixing it with unique name (with a chance of collision) to create
unconnected group of graphs.
In case there is only one graph def passed, it returns the original
graph_def. In case no graph defs are passed, it returns an empty GraphDef.
Args:
graph_defs: TensorBoard GraphDefs to merge.
Returns:
TensorBoard GraphDef that merges all graph_defs with unique prefixes.
Raises:
ValueError in case GraphDef versions mismatch.
"""
if len(graph_defs) == 1:
return graph_defs[0]
elif len(graph_defs) == 0:
return graph_pb2.GraphDef()
dst_graph_def = graph_pb2.GraphDef()
if graph_defs[0].versions.producer:
dst_graph_def.versions.CopyFrom(graph_defs[0].versions)
for index, graph_def in enumerate(graph_defs):
if dst_graph_def.versions.producer != graph_def.versions.producer:
raise ValueError("Cannot combine GraphDefs of different versions.")
_add_with_prepended_names(
"graph_%d" % (index + 1),
graph_def,
dst_graph_def,
)
return dst_graph_def

View File

@ -0,0 +1,337 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard Graphs plugin."""
import json
from werkzeug import wrappers
from tensorboard import errors
from tensorboard import plugin_util
from tensorboard.backend import http_util
from tensorboard.backend import process_graph
from tensorboard.compat.proto import config_pb2
from tensorboard.compat.proto import graph_pb2
from tensorboard.data import provider
from tensorboard.plugins import base_plugin
from tensorboard.plugins.graph import graph_util
from tensorboard.plugins.graph import keras_util
from tensorboard.plugins.graph import metadata
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
class GraphsPlugin(base_plugin.TBPlugin):
"""Graphs Plugin for TensorBoard."""
plugin_name = metadata.PLUGIN_NAME
def __init__(self, context):
"""Instantiates GraphsPlugin via TensorBoard core.
Args:
context: A base_plugin.TBContext instance.
"""
self._data_provider = context.data_provider
def get_plugin_apps(self):
return {
"/graph": self.graph_route,
"/info": self.info_route,
"/run_metadata": self.run_metadata_route,
}
def is_active(self):
"""The graphs plugin is active iff any run has a graph or metadata."""
return False # `list_plugins` as called by TB core suffices
def data_plugin_names(self):
return (
metadata.PLUGIN_NAME,
metadata.PLUGIN_NAME_RUN_METADATA,
metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
metadata.PLUGIN_NAME_KERAS_MODEL,
metadata.PLUGIN_NAME_TAGGED_RUN_METADATA,
)
def frontend_metadata(self):
return base_plugin.FrontendMetadata(
element_name="tf-graph-dashboard",
# TODO(@chihuahua): Reconcile this setting with Health Pills.
disable_reload=True,
)
def info_impl(self, ctx, experiment=None):
"""Returns a dict of all runs and their data availabilities."""
result = {}
def add_row_item(run, tag=None):
run_item = result.setdefault(
run,
{
"run": run,
"tags": {},
# A run-wide GraphDef of ops.
"run_graph": False,
},
)
tag_item = None
if tag:
tag_item = run_item.get("tags").setdefault(
tag,
{
"tag": tag,
"conceptual_graph": False,
# A tagged GraphDef of ops.
"op_graph": False,
"profile": False,
},
)
return (run_item, tag_item)
mapping = self._data_provider.list_blob_sequences(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
)
for run_name, tags in mapping.items():
for tag, tag_data in tags.items():
# The Summary op is defined in TensorFlow and does not use a stringified proto
# as a content of plugin data. It contains single string that denotes a version.
# https://github.com/tensorflow/tensorflow/blob/11f4ecb54708865ec757ca64e4805957b05d7570/tensorflow/python/ops/summary_ops_v2.py#L789-L790
if tag_data.plugin_content != b"1":
logger.warning(
"Ignoring unrecognizable version of RunMetadata."
)
continue
(_, tag_item) = add_row_item(run_name, tag)
tag_item["op_graph"] = True
# Tensors associated with plugin name metadata.PLUGIN_NAME_RUN_METADATA
# contain both op graph and profile information.
mapping = self._data_provider.list_blob_sequences(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME_RUN_METADATA,
)
for run_name, tags in mapping.items():
for tag, tag_data in tags.items():
if tag_data.plugin_content != b"1":
logger.warning(
"Ignoring unrecognizable version of RunMetadata."
)
continue
(_, tag_item) = add_row_item(run_name, tag)
tag_item["profile"] = True
tag_item["op_graph"] = True
# Tensors associated with plugin name metadata.PLUGIN_NAME_KERAS_MODEL
# contain serialized Keras model in JSON format.
mapping = self._data_provider.list_blob_sequences(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME_KERAS_MODEL,
)
for run_name, tags in mapping.items():
for tag, tag_data in tags.items():
if tag_data.plugin_content != b"1":
logger.warning(
"Ignoring unrecognizable version of RunMetadata."
)
continue
(_, tag_item) = add_row_item(run_name, tag)
tag_item["conceptual_graph"] = True
mapping = self._data_provider.list_blob_sequences(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
)
for run_name, tags in mapping.items():
if metadata.RUN_GRAPH_NAME in tags:
(run_item, _) = add_row_item(run_name, None)
run_item["run_graph"] = True
# Top level `Event.tagged_run_metadata` represents profile data only.
mapping = self._data_provider.list_blob_sequences(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME_TAGGED_RUN_METADATA,
)
for run_name, tags in mapping.items():
for tag in tags:
(_, tag_item) = add_row_item(run_name, tag)
tag_item["profile"] = True
return result
def _read_blob(self, ctx, experiment, plugin_names, run, tag):
for plugin_name in plugin_names:
blob_sequences = self._data_provider.read_blob_sequences(
ctx,
experiment_id=experiment,
plugin_name=plugin_name,
run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
downsample=1,
)
blob_sequence_data = blob_sequences.get(run, {}).get(tag, ())
try:
blob_ref = blob_sequence_data[0].values[0]
except IndexError:
continue
return self._data_provider.read_blob(
ctx, blob_key=blob_ref.blob_key
)
raise errors.NotFoundError()
def graph_impl(
self,
ctx,
run,
tag,
is_conceptual,
experiment=None,
limit_attr_size=None,
large_attrs_key=None,
):
"""Result of the form `(body, mime_type)`; may raise `NotFound`."""
if is_conceptual:
keras_model_config = json.loads(
self._read_blob(
ctx,
experiment,
[metadata.PLUGIN_NAME_KERAS_MODEL],
run,
tag,
)
)
graph = keras_util.keras_model_to_graph_def(keras_model_config)
elif tag is None:
graph_raw = self._read_blob(
ctx,
experiment,
[metadata.PLUGIN_NAME],
run,
metadata.RUN_GRAPH_NAME,
)
graph = graph_pb2.GraphDef.FromString(graph_raw)
else:
# Op graph: could be either of two plugins. (Cf. `info_impl`.)
plugins = [
metadata.PLUGIN_NAME_RUN_METADATA,
metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
]
raw_run_metadata = self._read_blob(
ctx, experiment, plugins, run, tag
)
run_metadata = config_pb2.RunMetadata.FromString(raw_run_metadata)
graph = graph_util.merge_graph_defs(
[
func_graph.pre_optimization_graph
for func_graph in run_metadata.function_graphs
]
)
# This next line might raise a ValueError if the limit parameters
# are invalid (size is negative, size present but key absent, etc.).
process_graph.prepare_graph_for_ui(
graph, limit_attr_size, large_attrs_key
)
return (str(graph), "text/x-protobuf") # pbtxt
def run_metadata_impl(self, ctx, experiment, run, tag):
"""Result of the form `(body, mime_type)`; may raise `NotFound`."""
# Profile graph: could be either of two plugins. (Cf. `info_impl`.)
plugins = [
metadata.PLUGIN_NAME_TAGGED_RUN_METADATA,
metadata.PLUGIN_NAME_RUN_METADATA,
]
raw_run_metadata = self._read_blob(ctx, experiment, plugins, run, tag)
run_metadata = config_pb2.RunMetadata.FromString(raw_run_metadata)
return (str(run_metadata), "text/x-protobuf") # pbtxt
@wrappers.Request.application
def info_route(self, request):
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
info = self.info_impl(ctx, experiment)
return http_util.Respond(request, info, "application/json")
@wrappers.Request.application
def graph_route(self, request):
"""Given a single run, return the graph definition in protobuf
format."""
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
run = request.args.get("run")
tag = request.args.get("tag")
conceptual_arg = request.args.get("conceptual", False)
is_conceptual = True if conceptual_arg == "true" else False
if run is None:
return http_util.Respond(
request, 'query parameter "run" is required', "text/plain", 400
)
limit_attr_size = request.args.get("limit_attr_size", None)
if limit_attr_size is not None:
try:
limit_attr_size = int(limit_attr_size)
except ValueError:
return http_util.Respond(
request,
"query parameter `limit_attr_size` must be an integer",
"text/plain",
400,
)
large_attrs_key = request.args.get("large_attrs_key", None)
try:
result = self.graph_impl(
ctx,
run,
tag,
is_conceptual,
experiment,
limit_attr_size,
large_attrs_key,
)
except ValueError as e:
return http_util.Respond(request, e.message, "text/plain", code=400)
(body, mime_type) = result
return http_util.Respond(request, body, mime_type)
@wrappers.Request.application
def run_metadata_route(self, request):
"""Given a tag and a run, return the session.run() metadata."""
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
tag = request.args.get("tag")
run = request.args.get("run")
if tag is None:
return http_util.Respond(
request, 'query parameter "tag" is required', "text/plain", 400
)
if run is None:
return http_util.Respond(
request, 'query parameter "run" is required', "text/plain", 400
)
(body, mime_type) = self.run_metadata_impl(ctx, experiment, run, tag)
return http_util.Respond(request, body, mime_type)

View File

@ -0,0 +1,328 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for handling Keras model in graph plugin.
Two canonical types of Keras model are Functional and Sequential.
A model can be serialized as JSON and deserialized to reconstruct a model.
This utility helps with dealing with the serialized Keras model.
They have distinct structures to the configurations in shapes below:
Functional:
config
name: Name of the model. If not specified, it is 'model' with
an optional suffix if there are more than one instance.
input_layers: Keras.layers.Inputs in the model.
output_layers: Layer names that are outputs of the model.
layers: list of layer configurations.
layer: [*]
inbound_nodes: inputs to this layer.
Sequential:
config
name: Name of the model. If not specified, it is 'sequential' with
an optional suffix if there are more than one instance.
layers: list of layer configurations.
layer: [*]
[*]: Note that a model can be a layer.
Please refer to https://github.com/tensorflow/tfjs-layers/blob/master/src/keras_format/model_serialization.ts
for more complete definition.
"""
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.tensorflow_stub import dtypes
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
def _walk_layers(keras_layer):
"""Walks the nested keras layer configuration in preorder.
Args:
keras_layer: Keras configuration from model.to_json.
Yields:
A tuple of (name_scope, layer_config).
name_scope: a string representing a scope name, similar to that of tf.name_scope.
layer_config: a dict representing a Keras layer configuration.
"""
yield ("", keras_layer)
if keras_layer.get("config").get("layers"):
name_scope = keras_layer.get("config").get("name")
for layer in keras_layer.get("config").get("layers"):
for sub_name_scope, sublayer in _walk_layers(layer):
sub_name_scope = (
"%s/%s" % (name_scope, sub_name_scope)
if sub_name_scope
else name_scope
)
yield (sub_name_scope, sublayer)
def _scoped_name(name_scope, node_name):
"""Returns scoped name for a node as a string in the form '<scope>/<node
name>'.
Args:
name_scope: a string representing a scope name, similar to that of tf.name_scope.
node_name: a string representing the current node name.
Returns
A string representing a scoped name.
"""
if name_scope:
return "%s/%s" % (name_scope, node_name)
return node_name
def _is_model(layer):
"""Returns True if layer is a model.
Args:
layer: a dict representing a Keras model configuration.
Returns:
bool: True if layer is a model.
"""
return layer.get("config").get("layers") is not None
def _norm_to_list_of_layers(maybe_layers):
"""Normalizes to a list of layers.
Args:
maybe_layers: A list of data[1] or a list of list of data.
Returns:
List of list of data.
[1]: A Functional model has fields 'inbound_nodes' and 'output_layers' which can
look like below:
- ['in_layer_name', 0, 0]
- [['in_layer_is_model', 1, 0], ['in_layer_is_model', 1, 1]]
The data inside the list seems to describe [name, size, index].
"""
return (
maybe_layers if isinstance(maybe_layers[0], (list,)) else [maybe_layers]
)
def _get_inbound_nodes(layer):
"""Returns a list of [name, size, index] for all inbound nodes of the given layer."""
inbound_nodes = []
if layer.get("inbound_nodes") is not None:
for maybe_inbound_node in layer.get("inbound_nodes", []):
if not isinstance(maybe_inbound_node, dict):
# Note that the inbound node parsing is not backward compatible with
# Keras 2. If given a Keras 2 model, the input nodes will be missing
# in the final graph.
continue
for inbound_node_args in maybe_inbound_node.get("args", []):
# Sometimes this field is a list when there are multiple inbound nodes
# for the given layer.
if not isinstance(inbound_node_args, list):
inbound_node_args = [inbound_node_args]
for arg in inbound_node_args:
history = arg.get("config", {}).get("keras_history", [])
if len(history) < 3:
continue
inbound_nodes.append(history[:3])
return inbound_nodes
def _update_dicts(
name_scope,
model_layer,
input_to_in_layer,
model_name_to_output,
prev_node_name,
):
"""Updates input_to_in_layer, model_name_to_output, and prev_node_name
based on the model_layer.
Args:
name_scope: a string representing a scope name, similar to that of tf.name_scope.
model_layer: a dict representing a Keras model configuration.
input_to_in_layer: a dict mapping Keras.layers.Input to inbound layer.
model_name_to_output: a dict mapping Keras Model name to output layer of the model.
prev_node_name: a string representing a previous, in sequential model layout,
node name.
Returns:
A tuple of (input_to_in_layer, model_name_to_output, prev_node_name).
input_to_in_layer: a dict mapping Keras.layers.Input to inbound layer.
model_name_to_output: a dict mapping Keras Model name to output layer of the model.
prev_node_name: a string representing a previous, in sequential model layout,
node name.
"""
layer_config = model_layer.get("config")
if not layer_config.get("layers"):
raise ValueError("layer is not a model.")
node_name = _scoped_name(name_scope, layer_config.get("name"))
input_layers = layer_config.get("input_layers")
output_layers = layer_config.get("output_layers")
inbound_nodes = _get_inbound_nodes(model_layer)
is_functional_model = bool(input_layers and output_layers)
# In case of [1] and the parent model is functional, current layer
# will have the 'inbound_nodes' property.
is_parent_functional_model = bool(inbound_nodes)
if is_parent_functional_model and is_functional_model:
for input_layer, inbound_node in zip(input_layers, inbound_nodes):
input_layer_name = _scoped_name(node_name, input_layer)
inbound_node_name = _scoped_name(name_scope, inbound_node[0])
input_to_in_layer[input_layer_name] = inbound_node_name
elif is_parent_functional_model and not is_functional_model:
# Sequential model can take only one input. Make sure inbound to the
# model is linked to the first layer in the Sequential model.
prev_node_name = _scoped_name(name_scope, inbound_nodes[0][0])
elif (
not is_parent_functional_model
and prev_node_name
and is_functional_model
):
assert len(input_layers) == 1, (
"Cannot have multi-input Functional model when parent model "
"is not Functional. Number of input layers: %d" % len(input_layer)
)
input_layer = input_layers[0]
input_layer_name = _scoped_name(node_name, input_layer)
input_to_in_layer[input_layer_name] = prev_node_name
if is_functional_model and output_layers:
layers = _norm_to_list_of_layers(output_layers)
layer_names = [_scoped_name(node_name, layer[0]) for layer in layers]
model_name_to_output[node_name] = layer_names
else:
last_layer = layer_config.get("layers")[-1]
last_layer_name = last_layer.get("config").get("name")
output_node = _scoped_name(node_name, last_layer_name)
model_name_to_output[node_name] = [output_node]
return (input_to_in_layer, model_name_to_output, prev_node_name)
def keras_model_to_graph_def(keras_layer):
"""Returns a GraphDef representation of the Keras model in a dict form.
Note that it only supports models that implemented to_json().
Args:
keras_layer: A dict from Keras model.to_json().
Returns:
A GraphDef representation of the layers in the model.
"""
input_to_layer = {}
model_name_to_output = {}
g = GraphDef()
# Sequential model layers do not have a field "inbound_nodes" but
# instead are defined implicitly via order of layers.
prev_node_name = None
for name_scope, layer in _walk_layers(keras_layer):
if _is_model(layer):
(
input_to_layer,
model_name_to_output,
prev_node_name,
) = _update_dicts(
name_scope,
layer,
input_to_layer,
model_name_to_output,
prev_node_name,
)
continue
layer_config = layer.get("config")
node_name = _scoped_name(name_scope, layer_config.get("name"))
node_def = g.node.add()
node_def.name = node_name
if layer.get("class_name") is not None:
keras_cls_name = layer.get("class_name").encode("ascii")
node_def.attr["keras_class"].s = keras_cls_name
dtype_or_policy = layer_config.get("dtype")
dtype = None
has_unsupported_value = False
# If this is a dict, try and extract the dtype string from
# `config.name`. Keras will export like this for non-input layers and
# some other cases (e.g. tf/keras/mixed_precision/Policy, as described
# in issue #5548).
if isinstance(dtype_or_policy, dict) and "config" in dtype_or_policy:
dtype = dtype_or_policy.get("config").get("name")
elif dtype_or_policy is not None:
dtype = dtype_or_policy
if dtype is not None:
try:
tf_dtype = dtypes.as_dtype(dtype)
node_def.attr["dtype"].type = tf_dtype.as_datatype_enum
except TypeError:
has_unsupported_value = True
elif dtype_or_policy is not None:
has_unsupported_value = True
if has_unsupported_value:
# There's at least one known case when this happens, which is when
# mixed precision dtype policies are used, as described in issue
# #5548. (See https://keras.io/api/mixed_precision/).
# There might be a better way to handle this, but here we are.
logger.warning(
"Unsupported dtype value in graph model config (json):\n%s",
dtype_or_policy,
)
if layer.get("inbound_nodes") is not None:
for name, size, index in _get_inbound_nodes(layer):
inbound_name = _scoped_name(name_scope, name)
# An input to a layer can be output from a model. In that case, the name
# of inbound_nodes to a layer is a name of a model. Remap the name of the
# model to output layer of the model. Also, since there can be multiple
# outputs in a model, make sure we pick the right output_layer from the model.
inbound_node_names = model_name_to_output.get(
inbound_name, [inbound_name]
)
# There can be multiple inbound_nodes that reference the
# same upstream layer. This causes issues when looking for
# a particular index in that layer, since the indices
# captured in `inbound_nodes` doesn't necessarily match the
# number of entries in the `inbound_node_names` list. To
# avoid IndexErrors, we just use the last element in the
# `inbound_node_names` in this situation.
# Note that this is a quick hack to avoid IndexErrors in
# this situation, and might not be an appropriate solution
# to this problem in general.
input_name = (
inbound_node_names[index]
if index < len(inbound_node_names)
else inbound_node_names[-1]
)
node_def.input.append(input_name)
elif prev_node_name is not None:
node_def.input.append(prev_node_name)
if node_name in input_to_layer:
node_def.input.append(input_to_layer.get(node_name))
prev_node_name = node_def.name
return g

View File

@ -0,0 +1,43 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Information on the graph plugin."""
# This name is used as the plugin prefix route and to identify this plugin
# generally, and is also the `plugin_name` for run graphs after data-compat
# transformations.
PLUGIN_NAME = "graphs"
# The Summary API is implemented in TensorFlow because it uses TensorFlow internal APIs.
# As a result, this SummaryMetadata is a bit unconventional and uses non-public
# hardcoded name as the plugin name. Please refer to link below for the summary ops.
# https://github.com/tensorflow/tensorflow/blob/11f4ecb54708865ec757ca64e4805957b05d7570/tensorflow/python/ops/summary_ops_v2.py#L757
PLUGIN_NAME_RUN_METADATA = "graph_run_metadata"
# https://github.com/tensorflow/tensorflow/blob/11f4ecb54708865ec757ca64e4805957b05d7570/tensorflow/python/ops/summary_ops_v2.py#L788
PLUGIN_NAME_RUN_METADATA_WITH_GRAPH = "graph_run_metadata_graph"
# https://github.com/tensorflow/tensorflow/blob/565952cc2f17fdfd995e25171cf07be0f6f06180/tensorflow/python/ops/summary_ops_v2.py#L825
PLUGIN_NAME_KERAS_MODEL = "graph_keras_model"
# Plugin name used for `Event.tagged_run_metadata`. This doesn't fall into one
# of the above cases because (despite the name) `PLUGIN_NAME_RUN_METADATA` is
# _required_ to have both profile and op graphs, whereas tagged run metadata
# need only have profile data.
PLUGIN_NAME_TAGGED_RUN_METADATA = "graph_tagged_run_metadata"
# In the context of the data provider interface, tag name given to a
# graph read from the `graph_def` field of an `Event` proto, which is
# not attached to a summary and thus does not have a proper tag name of
# its own. Run level graphs always represent `GraphDef`s (graphs of
# TensorFlow ops), never conceptual graphs, profile graphs, etc. This is
# the only tag name used by the `"graphs"` plugin.
RUN_GRAPH_NAME = "__run_graph__"