5. Bridge TensorFlow* to run on Intel®
nGraph™ backends
https://github.com/NervanaSystems/ngraph-tf
https://github.com/NervanaSystems/ngraph-tf/tree/r0.5/
13. src/ngraph_rewrite_pass.cc
// Runs the variable-capture rewrite on the graph referenced by `options`.
//
// Calls CaptureVariables() on options.graph->get(); the call is wrapped in
// TF_RETURN_IF_ERROR, and Status::OK() is returned once it completes.
//
// Note: `override` is only valid on the in-class declaration, so it must not
// appear on this out-of-class definition.
Status NGraphVariableCapturePass::Run(
    const GraphOptimizationPassOptions& options) {
  // For filename generation purposes, grab a fresh index. This is just an
  // arbitrary integer to avoid filename collisions resulting from subsequent
  // runs of this pass.
  // NOTE(review): idx is not consumed in this excerpt; presumably the
  // graph-dump calls that use it were elided -- confirm against the full file.
  int idx = FreshIndex();
  // Do variable capture then, if requested, dump the graphs.
  TF_RETURN_IF_ERROR(CaptureVariables(options.graph->get()));
  return Status::OK();
}
NGraphVariableCapturePass
27. src/ngraph_rewrite_pass.cc
// Pass that rewrites the graph for nGraph operation.
//
// The pass has several phases, each executed in sequence:
//
// 1. Marking [ngraph_mark_for_clustering.cc]
// 2. Cluster Assignment [ngraph_assign_clusters.cc]
// 3. Cluster Deassignment [ngraph_deassign_clusters.cc]
// 4. Cluster Encapsulation [ngraph_encapsulate_clusters.cc]
NGraphEncapsulatePass
28. src/ngraph_rewrite_pass.cc
class NGraphEncapsulationPass : public NGraphRewritePass {
public:
// Runs the nGraph rewrite phases, in order, on options.graph->get():
// MarkForClustering, AssignClusters, DeassignClusters, EncapsulateClusters,
// and RewriteForTracking. Each phase call is wrapped in TF_RETURN_IF_ERROR.
// Before the first phase and after every phase, the graphs are dumped via
// DumpGraphs() when the corresponding Dump*Graphs() predicate returns true;
// all dumps from one invocation share the same index `idx`.
// Returns Status::OK() once every phase has completed.
Status Run(const GraphOptimizationPassOptions& options) override {
// For filename generation purposes, grab a fresh index. This is just an
// arbitrary integer to avoid filename collisions resulting from subsequent
// runs of this pass.
int idx = FreshIndex();
// If requested, dump unmarked graphs.
if (DumpUnmarkedGraphs()) {
DumpGraphs(options, idx, "unmarked", "Unmarked Graph");
}
NGraphEncapsulatePass
29. src/ngraph_rewrite_pass.cc
// 1. Marking [ngraph_mark_for_clustering.cc]
// Mark for clustering then, if requested, dump the graphs.
TF_RETURN_IF_ERROR(MarkForClustering(options.graph->get()));
if (DumpMarkedGraphs()) {
DumpGraphs(options, idx, "marked", "Graph Marked for Clustering");
}
NGraphEncapsulatePass
30. src/ngraph_rewrite_pass.cc
// 2. Cluster Assignment [ngraph_assign_clusters.cc]
// Assign clusters then, if requested, dump the graphs.
TF_RETURN_IF_ERROR(AssignClusters(options.graph->get()));
if (DumpClusteredGraphs()) {
DumpGraphs(options, idx, "clustered", "Graph with Clusters Assigned");
}
NGraphEncapsulatePass
31. src/ngraph_rewrite_pass.cc
// 3. Cluster Deassignment [ngraph_deassign_clusters.cc]
// Deassign trivial clusters then, if requested, dump the graphs.
TF_RETURN_IF_ERROR(DeassignClusters(options.graph->get()));
if (DumpDeclusteredGraphs()) {
DumpGraphs(options, idx, "declustered",
"Graph with Trivial Clusters De-Assigned");
}
NGraphEncapsulatePass
32. src/ngraph_rewrite_pass.cc
// 4. Cluster Encapsulation [ngraph_encapsulate_clusters.cc]
// Encapsulate clusters then, if requested, dump the graphs.
TF_RETURN_IF_ERROR(EncapsulateClusters(options.graph->get()));
if (DumpEncapsulatedGraphs()) {
DumpGraphs(options, idx, "encapsulated",
"Graph with Clusters Encapsulated");
}
NGraphEncapsulatePass
33. src/ngraph_rewrite_pass.cc
// Final step (not numbered in the phase list above): runs after
// encapsulation.
// Rewrite for tracking then, if requested, dump the graphs.
TF_RETURN_IF_ERROR(RewriteForTracking(options.graph->get()));
if (DumpTrackedGraphs()) {
DumpGraphs(options, idx, "tracked",
"Graph with Variables Rewritten for Tracking");
}
return Status::OK();
}
NGraphEncapsulatePass
36. src/ngraph_encapsulate_clusters.cc
// Pass 1: Populate the cluster-index-to-device name map for each existing
// cluster.
if (it != device_name_map.end()) {
if (it->second != node->requested_device()) {
std::stringstream ss_err;
// ここでエラーメッセージを生成
return errors::Internal(ss_err.str());
}
} else {
device_name_map[cluster_idx] = node->requested_device();
}
}
EncapsulateClusters
37. src/ngraph_encapsulate_clusters.cc
// Pass 2: Find all nodes that are feeding into/out of each cluster, and
// add inputs for them to the corresponding FunctionDef(s).
std::map<int, int> retval_index_count;
std::map<int, int> arg_index_count;
for (auto edge : graph->edges()) {
if (edge->IsControlEdge()) {
continue;
}
Node* src = edge->src();
Node* dst = edge->dst();
EncapsulateClusters
38. src/ngraph_encapsulate_clusters.cc
if (!src->IsOp() || !dst->IsOp()) {
continue;
}
int dst_cluster_idx;
bool dst_clustered =
(GetNodeCluster(dst, &dst_cluster_idx) == Status::OK());
int src_cluster_idx;
bool src_clustered =
(GetNodeCluster(src, &src_cluster_idx) == Status::OK());
EncapsulateClusters
49. src/ngraph_encapsulate_clusters.cc
// Pass 4: Remap all non-clustered inputs that are reading from
// encapsulated edges, and all control edges that cross cluster
// boundaries.
// Copy the edge pointers, so as not to invalidate the iterator.
std::vector<Edge*> edges;
for (auto edge : graph->edges()) {
edges.push_back(edge);
}
EncapsulateClusters
50. src/ngraph_encapsulate_clusters.cc
for (auto edge : edges) {
int src_cluster_idx;
bool src_clustered =
(GetNodeCluster(edge->src(), &src_cluster_idx) == Status::OK());
int dst_cluster_idx;
bool dst_clustered =
(GetNodeCluster(edge->dst(), &dst_cluster_idx) == Status::OK());
if (src_cluster_idx == dst_cluster_idx) {
continue;
}
EncapsulateClusters
51. src/ngraph_encapsulate_clusters.cc
if (edge->IsControlEdge()) {
if (src_clustered && dst_clustered) {
graph->RemoveControlEdge(edge);
graph->AddControlEdge(cluster_node_map[src_cluster_idx],
cluster_node_map[dst_cluster_idx]);
} else if (src_clustered) {
Node* dst = edge->dst();
graph->RemoveControlEdge(edge);
graph->AddControlEdge(cluster_node_map[src_cluster_idx], dst);
EncapsulateClusters
52. src/ngraph_encapsulate_clusters.cc
} else if (dst_clustered) {
Node* src = edge->src();
graph->RemoveControlEdge(edge);
graph->AddControlEdge(src, cluster_node_map[dst_cluster_idx]);
}
} else {
// This is handled at a later stage (TODO(amprocte): explain)
if (dst_clustered) {
continue;
}
EncapsulateClusters
53. src/ngraph_encapsulate_clusters.cc
auto it = output_remap_map.find(
std::make_tuple(edge->src()->id(), edge->src_output()));
if (it == output_remap_map.end()) {
continue;
}
int cluster_idx, cluster_output;
std::tie(cluster_idx, cluster_output) = it->second;
graph->UpdateEdge(cluster_node_map[cluster_idx], cluster_output,
edge->dst(), edge->dst_input());
}
}
EncapsulateClusters
54. src/ngraph_encapsulate_clusters.cc
// Pass 5: Make copies of all clustered nodes inside the cluster graphs,
// rewiring the inputs in their NodeDefs as we go.
for (auto node : graph->op_nodes()) {
int cluster_idx;
if (GetNodeAttr(node->attrs(), "_ngraph_cluster", &cluster_idx) !=
Status::OK()) {
continue;
}
EncapsulateClusters
55. src/ngraph_encapsulate_clusters.cc
// Because the input names may have changed from the original node def,
// we will need to borrow some code from Graph::ToGraphDefSubRange in
// tensorflow/core/graph/graph.cc that rewrites the node's input list.
// begin code copied and pasted (and modified) from graph.cc...
NodeDef original_def = node->def();
EncapsulateClusters
58. src/ngraph_encapsulate_clusters.cc
for (auto& input : *(node_def->mutable_input())) {
TensorId tensor_id = ParseTensorName(input);
auto it = input_rename_map.find(std::make_tuple(
cluster_idx, tensor_id.first.ToString(), tensor_id.second));
if (it != input_rename_map.end()) {
input = it->second;
}
}
}
EncapsulateClusters
59. src/ngraph_encapsulate_clusters.cc
// Pass 6: Remove clustered nodes from the graph.
for (auto node : graph->op_nodes()) {
int cluster_idx;
if (GetNodeAttr(node->attrs(), "_ngraph_cluster", &cluster_idx) !=
Status::OK()) {
continue;
}
graph->RemoveNode(node);
}
EncapsulateClusters
60. src/ngraph_encapsulate_clusters.cc
// Pass 7 (optional, only run if environment variable <= デバッグ用?
// NGRAPH_TF_VALIDATE_CLUSTER_GRAPHS is set):
// validate the graph def, and
// make sure we can construct a graph from it.
if (std::getenv("NGRAPH_TF_VALIDATE_CLUSTER_GRAPHS")) {
for (auto& kv : device_name_map) {
int cluster_idx = kv.first;
TF_RETURN_IF_ERROR(graph::ValidateGraphDef(
*NGraphClusterManager::GetClusterGraph(cluster_idx),
*OpRegistry::Global()));
EncapsulateClusters
80. src/ngraph_encapsulate_op.cc
OP_REQUIRES(
ctx, ng_element_type == expected_elem_type,
errors::Internal("Element type inferred by nGraph does not match "
"the element type expected by TensorFlow"));
void* last_dst_ptr = output_caches[i].first;
std::shared_ptr<ng::runtime::TensorView> last_tv =
output_caches[i].second;
NGraphEncapsulateOp
89. src/ngraph_builder.cc
for (const auto n : ordered) {
if (n->IsSink() || n->IsSource()) { // 入力 か 出力の場合
continue;
}
if (n->IsControlFlow()) { // 制御フローはサポートしない
return errors::Unimplemented(
"Encountered a control flow op in the nGraph bridge: ",
n->DebugString());
}
Builder::TranslateGraph
90. src/ngraph_builder.cc
if (n->type_string() == "_Arg") { // パラメータ
tf_params.push_back(n);
} else if (n->type_string() == "_Retval") { // 戻り値
tf_ret_vals.push_back(n);
} else {
tf_ops.push_back(n); // Op
}
}
Builder::TranslateGraph
93. src/ngraph_builder.cc
// Op の処理
for (auto op : tf_ops) {
try {
// TensorFlow の Op を nGraph の Op にマッピング
TRANSLATE_OP_MAP.at(op->type_string())(op, ng_op_map);
} catch (const std::out_of_range&) {
return errors::InvalidArgument("Unsupported Op: ", op->name(), " (",
op->type_string(), ")");
}
}
Builder::TranslateGraph
94. src/ngraph_builder.cc
vector<shared_ptr<ng::Node>> ng_result_list(tf_ret_vals.size());
// 戻り値(出力データ)
for (auto n : tf_ret_vals) {
if (n->num_inputs() != 1) {
return errors::InvalidArgument("_Retval has ", n->num_inputs(),
" inputs, should have 1");
}
int index;
if (GetNodeAttr(n->attrs(), "index", &index) != Status::OK()) {
return errors::InvalidArgument("No index defined for _Retval");
}
Builder::TranslateGraph