Skip to content

Commit a436cd5

Browse files
Sai Ganesh Muthuraman, copybara-github
authored and committed
Optimize TraceEvent serialization for LevelDB storage.
PiperOrigin-RevId: 881293680
1 parent c425d4e commit a436cd5

4 files changed

Lines changed: 316 additions & 53 deletions

File tree

xprof/convert/trace_viewer/BUILD

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -150,14 +150,15 @@ cc_library(
150150
"@com_google_absl//absl/container:flat_hash_map",
151151
"@com_google_absl//absl/container:flat_hash_set",
152152
"@com_google_absl//absl/functional:bind_front",
153+
"@com_google_absl//absl/functional:function_ref",
153154
"@com_google_absl//absl/log",
154155
"@com_google_absl//absl/log:check",
155156
"@com_google_absl//absl/numeric:bits",
156157
"@com_google_absl//absl/status",
157158
"@com_google_absl//absl/status:statusor",
158159
"@com_google_absl//absl/strings",
159160
"@com_google_absl//absl/time",
160-
"@com_google_absl//absl/types:optional",
161+
"@com_google_protobuf//:protobuf",
161162
"@org_xprof//plugin/xprof/protobuf:task_proto_cc",
162163
"@org_xprof//plugin/xprof/protobuf:trace_events_proto_cc",
163164
"@org_xprof//plugin/xprof/protobuf:trace_events_raw_proto_cc",
@@ -255,3 +256,15 @@ cc_test(
255256
"@xla//xla/tsl/platform:env",
256257
],
257258
)
259+
260+
cc_test(
261+
name = "trace_events_test",
262+
srcs = ["trace_events_test.cc"],
263+
deps = [
264+
":trace_events",
265+
"@com_google_absl//absl/status",
266+
"@com_google_googletest//:gtest_main",
267+
"@org_xprof//plugin/xprof/protobuf:trace_events_proto_cc",
268+
"@org_xprof//plugin/xprof/protobuf:trace_events_raw_proto_cc",
269+
],
270+
)

xprof/convert/trace_viewer/trace_events.cc

Lines changed: 135 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@ limitations under the License.
2222
#include <iterator>
2323
#include <limits>
2424
#include <memory>
25-
#include <optional>
2625
#include <string>
2726
#include <utility>
2827
#include <variant>
@@ -31,13 +30,19 @@ limitations under the License.
3130
#include "absl/algorithm/container.h"
3231
#include "absl/container/flat_hash_map.h"
3332
#include "absl/container/flat_hash_set.h"
33+
#include "absl/functional/function_ref.h"
3434
#include "absl/log/check.h"
3535
#include "absl/log/log.h"
3636
#include "absl/numeric/bits.h"
3737
#include "absl/status/status.h"
38+
#include "absl/strings/str_cat.h"
3839
#include "absl/strings/string_view.h"
3940
#include "absl/time/clock.h"
4041
#include "absl/time/time.h"
42+
#include "google/protobuf/descriptor.h"
43+
#include "google/protobuf/io/coded_stream.h"
44+
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
45+
#include "google/protobuf/wire_format_lite.h"
4146
#include "xla/tsl/lib/io/iterator.h"
4247
#include "xla/tsl/lib/io/table.h"
4348
#include "xla/tsl/lib/io/table_builder.h"
@@ -53,6 +58,9 @@ limitations under the License.
5358
#include "plugin/xprof/protobuf/trace_events.pb.h"
5459
#include "plugin/xprof/protobuf/trace_events_raw.pb.h"
5560

61+
namespace proto2_io = ::google::protobuf::io;
62+
namespace proto2_internal = ::google::protobuf::internal;
63+
5664
namespace tensorflow {
5765
namespace profiler {
5866

@@ -94,6 +102,79 @@ inline void AppendEvents(TraceEventTrack&& src, TraceEventTrack* dst) {
94102
}
95103
}
96104

105+
// Serializes 'event' into 'output' in protobuf wire format, omitting any field
106+
// whose number equals 'skip_field_1' or 'skip_field_2'. 'output' is cleared
107+
// first and then filled through a CodedOutputStream.
//
// Only the fields reported by Reflection::ListFields() are visited, and only
// these field types are handled: uint32, uint64, int64, fixed64, enum, string
// and bytes. Any other field type makes the function return an Unimplemented
// error whose message names the offending field.
//
// Callers in this file pass -1 for a skip slot they do not need.
// NOTE(review): values are read with the singular Get*() accessors, so this
// appears to assume TraceEvent has no repeated fields -- confirm.
108+
absl::Status SerializeTraceEventSkipping(const TraceEvent* event,
109+
int skip_field_1, int skip_field_2,
110+
std::string& output) {
111+
output.clear();
112+
proto2_io::StringOutputStream string_stream(&output);
113+
// The inner scope ends the CodedOutputStream's lifetime before the function
// returns. NOTE(review): presumably so that any bytes it buffers are flushed
// into 'output' on destruction -- confirm.
{
114+
proto2_io::CodedOutputStream coded_stream(&string_stream);
115+
const google::protobuf::Reflection* reflection = event->GetReflection();
116+
std::vector<const google::protobuf::FieldDescriptor*> fields;
117+
reflection->ListFields(*event, &fields);
118+
119+
using proto2_internal::WireFormatLite;
120+
for (const google::protobuf::FieldDescriptor* field : fields) {
121+
int field_number = field->number();
122+
// Leave out the fields the caller asked to skip.
if (field_number == skip_field_1 || field_number == skip_field_2) {
123+
continue;
124+
}
125+
126+
// Write the field with the WireFormatLite writer that matches its declared
// type.
switch (field->type()) {
127+
case google::protobuf::FieldDescriptor::TYPE_UINT32:
128+
WireFormatLite::WriteUInt32(field_number,
129+
reflection->GetUInt32(*event, field),
130+
&coded_stream);
131+
break;
132+
case google::protobuf::FieldDescriptor::TYPE_UINT64:
133+
WireFormatLite::WriteUInt64(field_number,
134+
reflection->GetUInt64(*event, field),
135+
&coded_stream);
136+
break;
137+
case google::protobuf::FieldDescriptor::TYPE_INT64:
138+
WireFormatLite::WriteInt64(
139+
field_number, reflection->GetInt64(*event, field), &coded_stream);
140+
break;
141+
case google::protobuf::FieldDescriptor::TYPE_FIXED64:
142+
WireFormatLite::WriteFixed64(field_number,
143+
reflection->GetUInt64(*event, field),
144+
&coded_stream);
145+
break;
146+
case google::protobuf::FieldDescriptor::TYPE_ENUM:
147+
WireFormatLite::WriteEnum(field_number,
148+
reflection->GetEnumValue(*event, field),
149+
&coded_stream);
150+
break;
151+
case google::protobuf::FieldDescriptor::TYPE_STRING:
152+
case google::protobuf::FieldDescriptor::TYPE_BYTES:
153+
WireFormatLite::WriteString(field_number,
154+
reflection->GetString(*event, field),
155+
&coded_stream);
156+
break;
157+
// Any other type aborts the whole serialization with an error.
// NOTE(review): the message says "field type" but appends the field's name.
default:
158+
return absl::UnimplementedError(
159+
absl::StrCat("Unsupported field type: ", field->name()));
160+
}
161+
}
162+
}
163+
return absl::OkStatus();
164+
}
165+
166+
// Runs 'serialize_fn' on 'event', writing into 'buffer', and on success
167+
// returns a string_view over 'buffer'. A non-OK status from 'serialize_fn' is
168+
// propagated to the caller unchanged (via TF_RETURN_IF_ERROR).
// NOTE: The returned string_view aliases 'buffer'. It must not outlive the
169+
// buffer and must be consumed before 'buffer' is modified again, e.g. by the
// next call that reuses it.
170+
absl::StatusOr<absl::string_view> SerializeToView(
171+
const TraceEvent* event, std::string& buffer,
172+
absl::FunctionRef<absl::Status(const TraceEvent*, std::string&)>
173+
serialize_fn) {
174+
TF_RETURN_IF_ERROR(serialize_fn(event, buffer));
175+
return absl::string_view(buffer);
176+
}
177+
97178
} // namespace
98179

99180
TraceEvent::EventType GetTraceEventType(const TraceEvent& event) {
@@ -259,20 +340,30 @@ absl::Status DoStoreAsLevelDbTables(
259340
const Trace& trace, std::unique_ptr<tsl::WritableFile>& trace_events_file,
260341
std::unique_ptr<tsl::WritableFile>& trace_events_metadata_file,
261342
std::unique_ptr<tsl::WritableFile>& trace_events_prefix_trie_file) {
262-
auto executor = std::make_unique<XprofThreadPoolExecutor>(
263-
"StoreAsLevelDbTables", /*num_threads=*/3);
343+
std::unique_ptr<XprofThreadPoolExecutor> executor =
344+
std::make_unique<XprofThreadPoolExecutor>("StoreAsLevelDbTables",
345+
/*num_threads=*/3);
264346
absl::Status trace_events_status, trace_events_metadata_status;
265347
executor->Execute(
266348
[&trace_events_file, &trace, &events_by_level, &trace_events_status]() {
349+
std::string buffer;
267350
trace_events_status = DoStoreAsLevelDbTable(
268351
trace_events_file, trace, events_by_level,
269-
GenerateTraceEventCopyForPersistingEventWithoutMetadata);
352+
[&buffer](const TraceEvent* event) {
353+
return SerializeToView(
354+
event, buffer,
355+
SerializeTraceEventForPersistingEventWithoutMetadata);
356+
});
270357
});
271358
executor->Execute([&trace_events_metadata_file, &events_by_level, &trace,
272359
&trace_events_metadata_status]() {
360+
std::string buffer;
273361
trace_events_metadata_status = DoStoreAsLevelDbTable(
274362
trace_events_metadata_file, trace, events_by_level,
275-
GenerateTraceEventCopyForPersistingOnlyMetadata);
363+
[&buffer](const TraceEvent* event) {
364+
return SerializeToView(event, buffer,
365+
SerializeTraceEventForPersistingOnlyMetadata);
366+
});
276367
});
277368
absl::Status trace_events_prefix_trie_status;
278369
executor->Execute([&trace_events_prefix_trie_file, &events_by_level,
@@ -286,43 +377,42 @@ absl::Status DoStoreAsLevelDbTables(
286377
return trace_events_status;
287378
}
288379

289-
std::optional<TraceEvent> GenerateTraceEventCopyForPersistingFullEvent(
290-
const TraceEvent* event) {
291-
TraceEvent event_copy = *event;
292-
// To reduce file size, clear the timestamp from the value. It is
293-
// redundant info because the timestamp is part of the key.
294-
event_copy.clear_timestamp_ps();
295-
return event_copy;
380+
// Serializes 'event' into 'output' with only the timestamp_ps field left out.
// The -1 passed as the second skip slot means no second field is skipped.
// NOTE(review): this presumes -1 never matches a real field number -- confirm.
absl::Status SerializeTraceEventForPersistingFullEvent(const TraceEvent* event,
381+
std::string& output) {
382+
return SerializeTraceEventSkipping(event, TraceEvent::kTimestampPsFieldNumber,
383+
/*skip_field_2=*/-1, output);
296384
}
297385

298-
std::optional<TraceEvent>
299-
GenerateTraceEventCopyForPersistingEventWithoutMetadata(
300-
const TraceEvent* event) {
301-
TraceEvent event_copy = *event;
302-
// To reduce file size, clear the timestamp from the value. It is
303-
// redundant info because the timestamp is part of the key.
304-
event_copy.clear_timestamp_ps();
386+
// Serializes 'event' into 'output' without timestamp_ps and, for every event
// type other than EVENT_TYPE_COUNTER, without raw_data as well.
absl::Status SerializeTraceEventForPersistingEventWithoutMetadata(
387+
const TraceEvent* event, std::string& output) {
388+
// Starts as -1, i.e. no second field is skipped unless set below.
int skip_2 = -1;
305389
// To reduce file size, clear the raw data from the value. It is
306390
// redundant info because the raw data is stored in the metadata file.
307391
// However, we still need to keep the raw data for counter events as they
308392
// are a special case and we need to return the args for the same during the
309393
// initial read.
310394
if (GetTraceEventType(*event) != TraceEvent::EVENT_TYPE_COUNTER) {
311-
event_copy.clear_raw_data();
395+
skip_2 = TraceEvent::kRawDataFieldNumber;
312396
}
313-
return event_copy;
397+
return SerializeTraceEventSkipping(event, TraceEvent::kTimestampPsFieldNumber,
398+
/*skip_field_2=*/skip_2, output);
314399
}
315400

316-
std::optional<TraceEvent> GenerateTraceEventCopyForPersistingOnlyMetadata(
317-
const TraceEvent* event) {
401+
// Serializes into 'output' a TraceEvent that carries only the raw_data of
// 'event'. Returns a NotFound error for counter events and an Internal error
// if SerializeToString fails. DoStoreAsLevelDbTable does not treat NotFound as
// a failure: it simply adds no entry for that event.
absl::Status SerializeTraceEventForPersistingOnlyMetadata(
402+
const TraceEvent* event, std::string& output) {
318403
if (GetTraceEventType(*event) == TraceEvent::EVENT_TYPE_COUNTER) {
319404
// Counter events are stored in the trace events file itself and do not
320405
// require a metadata copy.
321-
return std::nullopt;
406+
return absl::NotFoundError("No metadata found for counter event");
407+
}
408+
// To avoid redundant deep copies of the whole TraceEvent, we only copy
409+
// the raw_data field into a small individual TraceEvent.
410+
TraceEvent metadata_event;
411+
metadata_event.set_raw_data(event->raw_data());
412+
if (metadata_event.SerializeToString(&output)) {
413+
return absl::OkStatus();
322414
}
323-
TraceEvent event_copy;
324-
event_copy.set_raw_data(event->raw_data());
325-
return event_copy;
415+
return absl::InternalError("Failed to serialize trace metadata");
326416
}
327417

328418
// Store the contents of this container in an sstable file. The format is as
@@ -340,8 +430,7 @@ std::optional<TraceEvent> GenerateTraceEventCopyForPersistingOnlyMetadata(
340430
absl::Status DoStoreAsLevelDbTable(
341431
std::unique_ptr<tsl::WritableFile>& file, const Trace& trace,
342432
const std::vector<std::vector<const TraceEvent*>>& events_by_level,
343-
std::function<std::optional<TraceEvent>(const TraceEvent*)>
344-
generate_event_copy_fn) {
433+
SerializeEventFn serialize_event_fn) {
345434
absl::Time start_time = absl::Now();
346435
tsl::table::Options options;
347436
options.block_size = 20 * 1024 * 1024;
@@ -361,9 +450,12 @@ absl::Status DoStoreAsLevelDbTable(
361450
uint64_t timestamp = event->timestamp_ps();
362451
std::string key = LevelDbTableKey(zoom_level, timestamp, event->serial());
363452
if (!key.empty()) {
364-
auto event_copy = generate_event_copy_fn(event);
365-
if (event_copy.has_value()) {
366-
builder.Add(key, event_copy->SerializeAsString());
453+
absl::StatusOr<absl::string_view> status_or_view =
454+
serialize_event_fn(event);
455+
if (status_or_view.ok()) {
456+
builder.Add(key, status_or_view.value());
457+
} else if (!absl::IsNotFound(status_or_view.status())) {
458+
return status_or_view.status();
367459
}
368460
} else {
369461
++num_of_events_dropped;
@@ -477,6 +569,17 @@ void TraceEventsContainerBase<EventFactory, RawData, Hash>::Merge(
477569
other.trace_.Clear();
478570
}
479571

572+
// Stores the events in 'events_by_level' into 'file' by delegating to
// DoStoreAsLevelDbTable, serializing each event with
// SerializeTraceEventForPersistingFullEvent. Returns the status of that call.
absl::Status StoreAsLevelDbTableImpl(
573+
std::unique_ptr<tsl::WritableFile> file, const Trace& trace,
574+
const std::vector<std::vector<const TraceEvent*>>& events_by_level) {
575+
// One buffer is shared by every serialization so the same string is reused
// for each event. The lambda captures it by reference and is only used
// during the DoStoreAsLevelDbTable call below.
std::string buffer;
576+
return DoStoreAsLevelDbTable(
577+
file, trace, events_by_level, [&buffer](const TraceEvent* event) {
578+
return SerializeToView(event, buffer,
579+
SerializeTraceEventForPersistingFullEvent);
580+
});
581+
}
582+
480583
// Explicit instantiations for the common case.
481584
template class TraceEventsContainerBase<EventFactory, RawData>;
482585

xprof/convert/trace_viewer/trace_events.h

Lines changed: 22 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -33,14 +33,14 @@ limitations under the License.
3333
#include "absl/container/flat_hash_map.h"
3434
#include "absl/container/flat_hash_set.h"
3535
#include "absl/functional/bind_front.h"
36+
#include "absl/functional/function_ref.h"
3637
#include "absl/log/check.h"
3738
#include "absl/log/log.h"
3839
#include "absl/status/status.h"
3940
#include "absl/status/statusor.h"
4041
#include "absl/strings/string_view.h"
4142
#include "absl/time/clock.h"
4243
#include "absl/time/time.h"
43-
#include "absl/types/optional.h"
4444
#include "xla/tsl/lib/io/iterator.h"
4545
#include "xla/tsl/lib/io/table.h"
4646
#include "xla/tsl/lib/io/table_builder.h"
@@ -87,34 +87,38 @@ static constexpr int kSearchParallelizationThreshold = 100;
8787
std::vector<TraceEvent*> MergeEventTracks(
8888
const std::vector<const TraceEventTrack*>& event_tracks);
8989

90+
using SerializeEventFn =
91+
absl::FunctionRef<absl::StatusOr<absl::string_view>(const TraceEvent*)>;
92+
93+
absl::Status StoreAsLevelDbTableImpl(
94+
std::unique_ptr<tsl::WritableFile> file, const Trace& trace,
95+
const std::vector<std::vector<const TraceEvent*>>& events_by_level);
96+
9097
absl::Status DoStoreAsLevelDbTable(
9198
std::unique_ptr<tsl::WritableFile>& file, const Trace& trace,
9299
const std::vector<std::vector<const TraceEvent*>>& events_by_level,
93-
std::function<std::optional<TraceEvent>(const TraceEvent*)>
94-
generate_event_copy_fn);
100+
SerializeEventFn serialize_event_fn);
95101

96102
absl::Status DoStoreAsLevelDbTables(
97103
const std::vector<std::vector<const TraceEvent*>>& events_by_level,
98104
const Trace& trace, std::unique_ptr<tsl::WritableFile>& trace_events_file,
99105
std::unique_ptr<tsl::WritableFile>& trace_events_metadata_file,
100106
std::unique_ptr<tsl::WritableFile>& trace_events_prefix_trie_file);
101107

102-
// Generates a copy of the event to be persisted in the trace events file.
103-
// This is the copy of the passed event without the timestamp_ps field.
104-
std::optional<TraceEvent> GenerateTraceEventCopyForPersistingFullEvent(
105-
const TraceEvent* event);
108+
// Serializes the event to be persisted in the trace events file.
109+
// This is the passed event without the timestamp_ps field.
110+
absl::Status SerializeTraceEventForPersistingFullEvent(const TraceEvent* event,
111+
std::string& output);
106112

107-
// Generates a copy of the event to be persisted in the trace events file.
108-
// This is the copy of the passed event without the raw_data and timestamp_ps
109-
// fields.
110-
std::optional<TraceEvent>
111-
GenerateTraceEventCopyForPersistingEventWithoutMetadata(
112-
const TraceEvent* event);
113+
// Serializes the event to be persisted in the trace events file.
114+
// This is the passed event without the raw_data and timestamp_ps fields.
115+
absl::Status SerializeTraceEventForPersistingEventWithoutMetadata(
116+
const TraceEvent* event, std::string& output);
113117

114-
// It generates a copy of the event to be persisted in the trace events metadata
115-
// file. This only has the raw_data field set.
116-
std::optional<TraceEvent> GenerateTraceEventCopyForPersistingOnlyMetadata(
117-
const TraceEvent* event);
118+
// Serializes the event to be persisted in the trace events metadata file.
119+
// This only has the raw_data field set.
120+
absl::Status SerializeTraceEventForPersistingOnlyMetadata(
121+
const TraceEvent* event, std::string& output);
118122

119123
// Opens the level db table from the given filename. The table is owned by the
120124
// caller.
@@ -774,9 +778,7 @@ class TraceEventsContainerBase {
774778
std::unique_ptr<tsl::WritableFile> file) const {
775779
Trace trace = trace_;
776780
trace.set_num_events(NumEvents());
777-
auto events_by_level = EventsByLevel();
778-
return DoStoreAsLevelDbTable(file, trace, events_by_level,
779-
GenerateTraceEventCopyForPersistingFullEvent);
781+
return StoreAsLevelDbTableImpl(std::move(file), trace, EventsByLevel());
780782
}
781783

782784
// Stores the contents of this container in three level-db sstable files. The

0 commit comments

Comments
 (0)