diff --git a/splunklib/searchcommands/internals.py b/splunklib/searchcommands/internals.py index abceac30f..5f20c3faf 100644 --- a/splunklib/searchcommands/internals.py +++ b/splunklib/searchcommands/internals.py @@ -554,7 +554,7 @@ def write_record(self, record): def write_records(self, records): self._ensure_validity() - records = list(records) + records = [] if records is NotImplemented else list(records) write_record = self._write_record for record in records: write_record(record) diff --git a/splunklib/searchcommands/reporting_command.py b/splunklib/searchcommands/reporting_command.py index 5df3dc7e7..e455a159a 100644 --- a/splunklib/searchcommands/reporting_command.py +++ b/splunklib/searchcommands/reporting_command.py @@ -77,21 +77,26 @@ def map(self, records): """ return NotImplemented - def prepare(self): - - phase = self.phase + def _has_custom_method(self, method_name): + method = getattr(self.__class__, method_name, None) + base_method = getattr(ReportingCommand, method_name, None) + return callable(method) and (method is not base_method) - if phase == 'map': - # noinspection PyUnresolvedReferences - self._configuration = self.map.ConfigurationSettings(self) + def prepare(self): + if self.phase == 'map': + if self._has_custom_method('map'): + phase_method = getattr(self.__class__, 'map') + self._configuration = phase_method.ConfigurationSettings(self) + else: + self._configuration = self.ConfigurationSettings(self) return - if phase == 'reduce': + if self.phase == 'reduce': streaming_preop = chain((self.name, 'phase="map"', str(self._options)), self.fieldnames) self._configuration.streaming_preop = ' '.join(streaming_preop) return - raise RuntimeError(f'Unrecognized reporting command phase: {json_encode_string(str(phase))}') + raise RuntimeError(f'Unrecognized reporting command phase: {json_encode_string(str(self.phase))}') def reduce(self, records): """ Override this method to produce a reporting data structure.