@@ -22,7 +22,6 @@ limitations under the License.
2222#include < iterator>
2323#include < limits>
2424#include < memory>
25- #include < optional>
2625#include < string>
2726#include < utility>
2827#include < variant>
@@ -31,13 +30,19 @@ limitations under the License.
3130#include " absl/algorithm/container.h"
3231#include " absl/container/flat_hash_map.h"
3332#include " absl/container/flat_hash_set.h"
33+ #include " absl/functional/function_ref.h"
3434#include " absl/log/check.h"
3535#include " absl/log/log.h"
3636#include " absl/numeric/bits.h"
3737#include " absl/status/status.h"
38+ #include " absl/strings/str_cat.h"
3839#include " absl/strings/string_view.h"
3940#include " absl/time/clock.h"
4041#include " absl/time/time.h"
42+ #include " google/protobuf/descriptor.h"
43+ #include " google/protobuf/io/coded_stream.h"
44+ #include " google/protobuf/io/zero_copy_stream_impl_lite.h"
45+ #include " google/protobuf/wire_format_lite.h"
4146#include " xla/tsl/lib/io/iterator.h"
4247#include " xla/tsl/lib/io/table.h"
4348#include " xla/tsl/lib/io/table_builder.h"
@@ -53,6 +58,9 @@ limitations under the License.
5358#include " plugin/xprof/protobuf/trace_events.pb.h"
5459#include " plugin/xprof/protobuf/trace_events_raw.pb.h"
5560
61+ namespace proto2_io = ::google::protobuf::io;
62+ namespace proto2_internal = ::google::protobuf::internal;
63+
5664namespace tensorflow {
5765namespace profiler {
5866
@@ -94,6 +102,79 @@ inline void AppendEvents(TraceEventTrack&& src, TraceEventTrack* dst) {
94102 }
95103}
96104
105+ // Serializes 'event' to 'output' while skipping fields 'skip_field_1' and
106+ // 'skip_field_2'. Uses reflection for efficiency and type-agnostic field
107+ // iteration.
108+ absl::Status SerializeTraceEventSkipping (const TraceEvent* event,
109+ int skip_field_1, int skip_field_2,
110+ std::string& output) {
111+ output.clear ();
112+ proto2_io::StringOutputStream string_stream (&output);
113+ {
114+ proto2_io::CodedOutputStream coded_stream (&string_stream);
115+ const google::protobuf::Reflection* reflection = event->GetReflection ();
116+ std::vector<const google::protobuf::FieldDescriptor*> fields;
117+ reflection->ListFields (*event, &fields);
118+
119+ using proto2_internal::WireFormatLite;
120+ for (const google::protobuf::FieldDescriptor* field : fields) {
121+ int field_number = field->number ();
122+ if (field_number == skip_field_1 || field_number == skip_field_2) {
123+ continue ;
124+ }
125+
126+ switch (field->type ()) {
127+ case google::protobuf::FieldDescriptor::TYPE_UINT32:
128+ WireFormatLite::WriteUInt32 (field_number,
129+ reflection->GetUInt32 (*event, field),
130+ &coded_stream);
131+ break ;
132+ case google::protobuf::FieldDescriptor::TYPE_UINT64:
133+ WireFormatLite::WriteUInt64 (field_number,
134+ reflection->GetUInt64 (*event, field),
135+ &coded_stream);
136+ break ;
137+ case google::protobuf::FieldDescriptor::TYPE_INT64:
138+ WireFormatLite::WriteInt64 (
139+ field_number, reflection->GetInt64 (*event, field), &coded_stream);
140+ break ;
141+ case google::protobuf::FieldDescriptor::TYPE_FIXED64:
142+ WireFormatLite::WriteFixed64 (field_number,
143+ reflection->GetUInt64 (*event, field),
144+ &coded_stream);
145+ break ;
146+ case google::protobuf::FieldDescriptor::TYPE_ENUM:
147+ WireFormatLite::WriteEnum (field_number,
148+ reflection->GetEnumValue (*event, field),
149+ &coded_stream);
150+ break ;
151+ case google::protobuf::FieldDescriptor::TYPE_STRING:
152+ case google::protobuf::FieldDescriptor::TYPE_BYTES:
153+ WireFormatLite::WriteString (field_number,
154+ reflection->GetString (*event, field),
155+ &coded_stream);
156+ break ;
157+ default :
158+ return absl::UnimplementedError (
159+ absl::StrCat (" Unsupported field type: " , field->name ()));
160+ }
161+ }
162+ }
163+ return absl::OkStatus ();
164+ }
165+
166+ // Helper to wrap a serialization function into the required functor format.
167+ // Reuses the given 'buffer' for efficiency.
168+ // NOTE: The returned string_view is not thread-safe and must not outlive the
169+ // buffer.
170+ absl::StatusOr<absl::string_view> SerializeToView (
171+ const TraceEvent* event, std::string& buffer,
172+ absl::FunctionRef<absl::Status(const TraceEvent*, std::string&)>
173+ serialize_fn) {
174+ TF_RETURN_IF_ERROR (serialize_fn (event, buffer));
175+ return absl::string_view (buffer);
176+ }
177+
97178} // namespace
98179
99180TraceEvent::EventType GetTraceEventType (const TraceEvent& event) {
@@ -259,20 +340,30 @@ absl::Status DoStoreAsLevelDbTables(
259340 const Trace& trace, std::unique_ptr<tsl::WritableFile>& trace_events_file,
260341 std::unique_ptr<tsl::WritableFile>& trace_events_metadata_file,
261342 std::unique_ptr<tsl::WritableFile>& trace_events_prefix_trie_file) {
262- auto executor = std::make_unique<XprofThreadPoolExecutor>(
263- " StoreAsLevelDbTables" , /* num_threads=*/ 3 );
343+ std::unique_ptr<XprofThreadPoolExecutor> executor =
344+ std::make_unique<XprofThreadPoolExecutor>(" StoreAsLevelDbTables" ,
345+ /* num_threads=*/ 3 );
264346 absl::Status trace_events_status, trace_events_metadata_status;
265347 executor->Execute (
266348 [&trace_events_file, &trace, &events_by_level, &trace_events_status]() {
349+ std::string buffer;
267350 trace_events_status = DoStoreAsLevelDbTable (
268351 trace_events_file, trace, events_by_level,
269- GenerateTraceEventCopyForPersistingEventWithoutMetadata);
352+ [&buffer](const TraceEvent* event) {
353+ return SerializeToView (
354+ event, buffer,
355+ SerializeTraceEventForPersistingEventWithoutMetadata);
356+ });
270357 });
271358 executor->Execute ([&trace_events_metadata_file, &events_by_level, &trace,
272359 &trace_events_metadata_status]() {
360+ std::string buffer;
273361 trace_events_metadata_status = DoStoreAsLevelDbTable (
274362 trace_events_metadata_file, trace, events_by_level,
275- GenerateTraceEventCopyForPersistingOnlyMetadata);
363+ [&buffer](const TraceEvent* event) {
364+ return SerializeToView (event, buffer,
365+ SerializeTraceEventForPersistingOnlyMetadata);
366+ });
276367 });
277368 absl::Status trace_events_prefix_trie_status;
278369 executor->Execute ([&trace_events_prefix_trie_file, &events_by_level,
@@ -286,43 +377,42 @@ absl::Status DoStoreAsLevelDbTables(
286377 return trace_events_status;
287378}
288379
289- std::optional<TraceEvent> GenerateTraceEventCopyForPersistingFullEvent (
290- const TraceEvent* event) {
291- TraceEvent event_copy = *event;
292- // To reduce file size, clear the timestamp from the value. It is
293- // redundant info because the timestamp is part of the key.
294- event_copy.clear_timestamp_ps ();
295- return event_copy;
380+ absl::Status SerializeTraceEventForPersistingFullEvent (const TraceEvent* event,
381+ std::string& output) {
382+ return SerializeTraceEventSkipping (event, TraceEvent::kTimestampPsFieldNumber ,
383+ /* skip_field_2=*/ -1 , output);
296384}
297385
298- std::optional<TraceEvent>
299- GenerateTraceEventCopyForPersistingEventWithoutMetadata (
300- const TraceEvent* event) {
301- TraceEvent event_copy = *event;
302- // To reduce file size, clear the timestamp from the value. It is
303- // redundant info because the timestamp is part of the key.
304- event_copy.clear_timestamp_ps ();
386+ absl::Status SerializeTraceEventForPersistingEventWithoutMetadata (
387+ const TraceEvent* event, std::string& output) {
388+ int skip_2 = -1 ;
305389 // To reduce file size, clear the raw data from the value. It is
306390 // redundant info because the raw data is stored in the metadata file.
307391 // However, we still need to keep the raw data for counter events as they
308392 // are a special case and we need to return the args for the same during the
309393 // initial read.
310394 if (GetTraceEventType (*event) != TraceEvent::EVENT_TYPE_COUNTER) {
311- event_copy. clear_raw_data () ;
395+ skip_2 = TraceEvent:: kRawDataFieldNumber ;
312396 }
313- return event_copy;
397+ return SerializeTraceEventSkipping (event, TraceEvent::kTimestampPsFieldNumber ,
398+ /* skip_field_2=*/ skip_2, output);
314399}
315400
316- std::optional<TraceEvent> GenerateTraceEventCopyForPersistingOnlyMetadata (
317- const TraceEvent* event) {
401+ absl::Status SerializeTraceEventForPersistingOnlyMetadata (
402+ const TraceEvent* event, std::string& output ) {
318403 if (GetTraceEventType (*event) == TraceEvent::EVENT_TYPE_COUNTER) {
319404 // Counter events are stored in the trace events file itself and do not
320405 // require a metadata copy.
321- return std::nullopt ;
406+ return absl::NotFoundError (" No metadata found for counter event" );
407+ }
408+ // To avoid redundant deep copies of the whole TraceEvent, we only copy
409+ // the raw_data field into a small individual TraceEvent.
410+ TraceEvent metadata_event;
411+ metadata_event.set_raw_data (event->raw_data ());
412+ if (metadata_event.SerializeToString (&output)) {
413+ return absl::OkStatus ();
322414 }
323- TraceEvent event_copy;
324- event_copy.set_raw_data (event->raw_data ());
325- return event_copy;
415+ return absl::InternalError (" Failed to serialize trace metadata" );
326416}
327417
328418// Store the contents of this container in an sstable file. The format is as
@@ -340,8 +430,7 @@ std::optional<TraceEvent> GenerateTraceEventCopyForPersistingOnlyMetadata(
340430absl::Status DoStoreAsLevelDbTable (
341431 std::unique_ptr<tsl::WritableFile>& file, const Trace& trace,
342432 const std::vector<std::vector<const TraceEvent*>>& events_by_level,
343- std::function<std::optional<TraceEvent>(const TraceEvent*)>
344- generate_event_copy_fn) {
433+ SerializeEventFn serialize_event_fn) {
345434 absl::Time start_time = absl::Now ();
346435 tsl::table::Options options;
347436 options.block_size = 20 * 1024 * 1024 ;
@@ -361,9 +450,12 @@ absl::Status DoStoreAsLevelDbTable(
361450 uint64_t timestamp = event->timestamp_ps ();
362451 std::string key = LevelDbTableKey (zoom_level, timestamp, event->serial ());
363452 if (!key.empty ()) {
364- auto event_copy = generate_event_copy_fn (event);
365- if (event_copy.has_value ()) {
366- builder.Add (key, event_copy->SerializeAsString ());
453+ absl::StatusOr<absl::string_view> status_or_view =
454+ serialize_event_fn (event);
455+ if (status_or_view.ok ()) {
456+ builder.Add (key, status_or_view.value ());
457+ } else if (!absl::IsNotFound (status_or_view.status ())) {
458+ return status_or_view.status ();
367459 }
368460 } else {
369461 ++num_of_events_dropped;
@@ -477,6 +569,17 @@ void TraceEventsContainerBase<EventFactory, RawData, Hash>::Merge(
477569 other.trace_ .Clear ();
478570}
479571
572+ absl::Status StoreAsLevelDbTableImpl (
573+ std::unique_ptr<tsl::WritableFile> file, const Trace& trace,
574+ const std::vector<std::vector<const TraceEvent*>>& events_by_level) {
575+ std::string buffer;
576+ return DoStoreAsLevelDbTable (
577+ file, trace, events_by_level, [&buffer](const TraceEvent* event) {
578+ return SerializeToView (event, buffer,
579+ SerializeTraceEventForPersistingFullEvent);
580+ });
581+ }
582+
480583// Explicit instantiations for the common case.
481584template class TraceEventsContainerBase <EventFactory, RawData>;
482585
0 commit comments