Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion xprof/convert/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ cc_library(
"@org_xprof//xprof/utils:hlo_proto_to_module",
"@tsl//tsl/profiler/protobuf:xplane_proto_cc",
"@xla//xla/hlo/ir:hlo",
"@xla//xla/tsl/platform:errors",
"@xla//xla/tsl/platform:statusor",
],
alwayslink = 1,
Expand Down Expand Up @@ -1637,6 +1636,7 @@ cc_library(
":file_utils",
":repository",
":tool_options",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_protobuf//:protobuf",
Expand All @@ -1650,6 +1650,27 @@ cc_library(
],
)

cc_test(
name = "xplane_to_hlo_test",
srcs = ["xplane_to_hlo_test.cc"],
deps = [
":file_utils",
":repository",
":tool_options",
":xplane_to_hlo",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:path",
"@xla//xla/service:hlo_proto_cc",
"@xla//xla/tsl/platform:env",
],
)

cc_library(
name = "op_profile_builder",
srcs = ["op_profile_builder.cc"],
Expand Down
33 changes: 26 additions & 7 deletions xprof/convert/graph_viewer_processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
#include "xprof/convert/graph_viewer_processor.h"

#include <memory>
#include <optional>
#include <string>

#include "absl/log/log.h"
Expand Down Expand Up @@ -47,11 +46,11 @@ limitations under the License.

namespace xprof {

namespace {

using ::tensorflow::profiler::ConvertHloProtoToGraph;
using ::tensorflow::profiler::ConvertHloProtoToStringView;
using ::tensorflow::profiler::GetAdjacentNodes;
using ::tensorflow::profiler::GetHloProtoByModuleName;
using ::tensorflow::profiler::GetParam;
using ::tensorflow::profiler::GraphViewerParams;
using ::tensorflow::profiler::kAdjacentNodes;
using ::tensorflow::profiler::kCustomCallGraphTypeName;
Expand All @@ -60,6 +59,22 @@ using ::tensorflow::profiler::ParseGraphViewerParams;
using ::tensorflow::profiler::SessionSnapshot;
using ::tensorflow::profiler::ToolOptions;

absl::StatusOr<xla::HloProto> GetHloProto(
const SessionSnapshot& session_snapshot, const ToolOptions& options) {
absl::StatusOr<xla::HloProto> hlo_proto =
GetHloProtoByOptions(session_snapshot, options);
if (hlo_proto.ok()) return hlo_proto;

// Fallback: If module not found/provided, try searching by node name.
absl::StatusOr<GraphViewerParams> params = ParseGraphViewerParams(options);
if (params.ok() && !params->node_name.empty()) {
hlo_proto = GetHloProtoByNodeName(session_snapshot, params->node_name);
}
return hlo_proto;
}

} // namespace

absl::StatusOr<std::string> ConvertHloProtoToGraphViewer(
const xla::HloProto& hlo_proto, const ToolOptions& options) {
TF_ASSIGN_OR_RETURN(GraphViewerParams params,
Expand Down Expand Up @@ -89,16 +104,20 @@ absl::StatusOr<std::string> ConvertHloProtoToGraphViewer(

absl::Status GraphViewerProcessor::ProcessSession(
const SessionSnapshot& session_snapshot, const ToolOptions& options) {
TF_ASSIGN_OR_RETURN(xla::HloProto hlo_proto,
GetHloProtoByOptions(session_snapshot, options));
absl::StatusOr<xla::HloProto> hlo_proto =
GetHloProto(session_snapshot, options);

if (!hlo_proto.ok()) {
return hlo_proto.status();
}

LOG(INFO) << "Processing graph viewer for hlo module: "
<< hlo_proto.hlo_module().name();
<< hlo_proto->hlo_module().name();

std::string graph_viewer_json;

TF_ASSIGN_OR_RETURN(graph_viewer_json,
ConvertHloProtoToGraphViewer(hlo_proto, options));
ConvertHloProtoToGraphViewer(*hlo_proto, options));

SetOutput(graph_viewer_json, "application/json");
return absl::OkStatus();
Expand Down
64 changes: 46 additions & 18 deletions xprof/convert/xplane_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
Expand Down Expand Up @@ -73,7 +74,8 @@ absl::StatusOr<bool> GetHloProtoFromMultiXSpaceAndSaveToFile(

// Save HLO protos to session run directory.
for (const absl::string_view module_name : module_list) {
auto hlo_proto_or = hlo_proto_map.GetHloProtoByModuleName(module_name);
absl::StatusOr<const xla::HloProto*> hlo_proto_or =
hlo_proto_map.GetHloProtoByModuleName(module_name);
if (!hlo_proto_or.ok()) {
return tsl::errors::Internal(hlo_proto_or.status().message());
}
Expand All @@ -91,8 +93,7 @@ absl::StatusOr<bool> GetHloProtoFromMultiXSpaceAndSaveToFile(
} // namespace

absl::StatusOr<xla::HloProto> GetHloProtoByModuleName(
const SessionSnapshot& session_snapshot,
const absl::string_view module_name) {
const SessionSnapshot& session_snapshot, absl::string_view module_name) {
std::string file_name =
ProfilerJoinPath(session_snapshot.GetSessionRunDir(),
absl::StrCat(module_name, kHloProtoSuffix));
Expand All @@ -101,33 +102,60 @@ absl::StatusOr<xla::HloProto> GetHloProtoByModuleName(
return hlo_proto;
}

absl::StatusOr<xla::HloProto> GetHloProtoByNodeName(
const SessionSnapshot& session_snapshot, absl::string_view node_name) {
std::vector<std::string> files;
TF_RETURN_IF_ERROR(tsl::Env::Default()->GetChildren(
std::string(session_snapshot.GetSessionRunDir()), &files));

for (absl::string_view module_name : files) {
if (!absl::ConsumeSuffix(&module_name, kHloProtoSuffix)) {
continue;
}
absl::StatusOr<xla::HloProto> hlo_proto_or =
GetHloProtoByModuleName(session_snapshot, module_name);
if (!hlo_proto_or.ok()) continue;
const xla::HloProto& hlo_proto = *hlo_proto_or;
if (!hlo_proto.has_hlo_module()) continue;
for (const xla::HloComputationProto& computation :
hlo_proto.hlo_module().computations()) {
for (const xla::HloInstructionProto& instruction :
computation.instructions()) {
if (instruction.name() == node_name) {
return hlo_proto_or;
}
}
}
}
return absl::NotFoundError(
absl::StrCat("HLO proto file containing node name ", node_name,
" not found in ", session_snapshot.GetSessionRunDir()));
}

absl::StatusOr<xla::HloProto> GetHloProtoByProgramId(
const SessionSnapshot& session_snapshot,
const absl::string_view program_id_str) {
const SessionSnapshot& session_snapshot, absl::string_view program_id_str) {
std::vector<std::string> files;
TF_RETURN_IF_ERROR(tsl::Env::Default()->GetChildren(
std::string(session_snapshot.GetSessionRunDir()), &files));

std::string target_module_name = "";

for (const std::string& file : files) {
if (absl::EndsWith(file, kHloProtoSuffix)) {
absl::string_view module_name = file;
if (!absl::ConsumeSuffix(&module_name, kHloProtoSuffix)) {
continue; // Should not happen based on the EndsWith check
}
absl::string_view module_name = file;
if (!absl::ConsumeSuffix(&module_name, kHloProtoSuffix)) {
continue;
}

// Fuzzy search: Check if the module name contains the program_id string.
if (absl::StrContains(module_name, program_id_str)) {
// Assuming the first match is the desired one.
target_module_name = std::string(module_name);
break;
}
// Fuzzy search: Check if the module name contains the program_id string.
if (absl::StrContains(module_name, program_id_str)) {
// Assuming the first match is the desired one.
target_module_name = std::string(module_name);
break;
}
}

if (target_module_name.empty()) {
return tsl::errors::NotFound(
return absl::NotFoundError(
absl::StrCat("HLO proto file containing program ID ", program_id_str,
" not found in ", session_snapshot.GetSessionRunDir()));
}
Expand All @@ -148,7 +176,7 @@ absl::StatusOr<xla::HloProto> GetHloProtoByOptions(
} else if (program_id.has_value() && !program_id->empty()) {
return GetHloProtoByProgramId(session_snapshot, *program_id);
} else {
return tsl::errors::InvalidArgument("Can not load hlo proto from options.");
return absl::InvalidArgumentError("Can not load hlo proto from options.");
}
}

Expand Down
4 changes: 4 additions & 0 deletions xprof/convert/xplane_to_hlo.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ absl::StatusOr<xla::HloProto> GetHloProtoByModuleName(
absl::StatusOr<xla::HloProto> GetHloProtoByProgramId(
const SessionSnapshot& session_snapshot, absl::string_view program_id);

// Get HLO proto by searching for node name in all HLO modules.
absl::StatusOr<xla::HloProto> GetHloProtoByNodeName(
const SessionSnapshot& session_snapshot, absl::string_view node_name);

absl::StatusOr<xla::HloProto> GetHloProtoByOptions(
const SessionSnapshot& session_snapshot, const ToolOptions& options);

Expand Down
Loading
Loading