diff --git a/cpp/src/arrow/csv/column_builder_test.cc b/cpp/src/arrow/csv/column_builder_test.cc index cb178c1d2b3..04c9cfe2482 100644 --- a/cpp/src/arrow/csv/column_builder_test.cc +++ b/cpp/src/arrow/csv/column_builder_test.cc @@ -342,6 +342,25 @@ TEST_F(InferringColumnBuilderTest, SingleChunkInteger) { {ArrayFromJSON(int64(), "[null, 123, 456]")}); } +TEST_F(InferringColumnBuilderTest, SingleChunkDefaultColumnTypeDoesNotOverrideInference) { + auto options = ConvertOptions::Defaults(); + options.default_column_type = utf8(); + auto tg = TaskGroup::MakeSerial(); + + CheckInferred(tg, {{"0000404", "0000505", "0000606"}}, options, + {ArrayFromJSON(int64(), "[404, 505, 606]")}); +} + +TEST_F(InferringColumnBuilderTest, + MultipleChunkDefaultColumnTypeDoesNotOverrideInference) { + auto options = ConvertOptions::Defaults(); + options.default_column_type = utf8(); + auto tg = TaskGroup::MakeSerial(); + + CheckInferred(tg, {{"0000404"}, {"0000505", "0000606"}}, options, + {ArrayFromJSON(int64(), "[404]"), ArrayFromJSON(int64(), "[505, 606]")}); +} + TEST_F(InferringColumnBuilderTest, MultipleChunkInteger) { auto options = ConvertOptions::Defaults(); auto tg = TaskGroup::MakeSerial(); diff --git a/cpp/src/arrow/csv/options.cc b/cpp/src/arrow/csv/options.cc index 365b5646b66..52daa9c5fc6 100644 --- a/cpp/src/arrow/csv/options.cc +++ b/cpp/src/arrow/csv/options.cc @@ -43,6 +43,7 @@ ConvertOptions ConvertOptions::Defaults() { "NULL", "NaN", "n/a", "nan", "null"}; options.true_values = {"1", "True", "TRUE", "true"}; options.false_values = {"0", "False", "FALSE", "false"}; + options.default_column_type = nullptr; return options; } diff --git a/cpp/src/arrow/csv/options.h b/cpp/src/arrow/csv/options.h index 10e55bf838c..839550c3f0c 100644 --- a/cpp/src/arrow/csv/options.h +++ b/cpp/src/arrow/csv/options.h @@ -76,6 +76,8 @@ struct ARROW_EXPORT ConvertOptions { bool check_utf8 = true; /// Optional per-column types (disabling type inference on those columns) std::unordered_map<std::string, std::shared_ptr<DataType>> 
column_types; + /// Default type to use for columns not in `column_types` + std::shared_ptr<DataType> default_column_type; /// Recognized spellings for null values std::vector<std::string> null_values; /// Recognized spellings for boolean true values diff --git a/cpp/src/arrow/csv/reader.cc b/cpp/src/arrow/csv/reader.cc index 3c4e7e3da0c..b6412673ebf 100644 --- a/cpp/src/arrow/csv/reader.cc +++ b/cpp/src/arrow/csv/reader.cc @@ -674,8 +674,14 @@ class ReaderMixin { // Does the named column have a fixed type? auto it = convert_options_.column_types.find(col_name); if (it == convert_options_.column_types.end()) { - conversion_schema_.columns.push_back( - ConversionSchema::InferredColumn(std::move(col_name), col_index)); + // If not explicitly typed, respect default_column_type when provided + if (convert_options_.default_column_type != nullptr) { + conversion_schema_.columns.push_back(ConversionSchema::TypedColumn( + std::move(col_name), col_index, convert_options_.default_column_type)); + } else { + conversion_schema_.columns.push_back( + ConversionSchema::InferredColumn(std::move(col_name), col_index)); + } } else { conversion_schema_.columns.push_back( ConversionSchema::TypedColumn(std::move(col_name), col_index, it->second)); diff --git a/cpp/src/arrow/csv/reader_test.cc b/cpp/src/arrow/csv/reader_test.cc index 57cc7d8efa5..deb5c6dfbd5 100644 --- a/cpp/src/arrow/csv/reader_test.cc +++ b/cpp/src/arrow/csv/reader_test.cc @@ -488,5 +488,92 @@ TEST(CountRowsAsync, Errors) { internal::GetCpuThreadPool(), read_options, parse_options)); } +TEST(ReaderTests, DefaultColumnTypePartialDefault) { + auto table_buffer = std::make_shared<Buffer>( + "id,name,value,date\n" + "0000101,apple,0003.1400,2024-01-15\n" + "00102,banana,001.6180,2024-02-20\n" + "0003,cherry,02.71800,2024-03-25\n"); + + auto input = std::make_shared<io::BufferReader>(table_buffer); + auto read_options = ReadOptions::Defaults(); + auto parse_options = ParseOptions::Defaults(); + auto convert_options = ConvertOptions::Defaults(); + 
convert_options.column_types["id"] = int64(); + convert_options.default_column_type = utf8(); + + ASSERT_OK_AND_ASSIGN(auto reader, + TableReader::Make(io::default_io_context(), input, read_options, + parse_options, convert_options)); + ASSERT_OK_AND_ASSIGN(auto table, reader->Read()); + + auto expected_schema = schema({field("id", int64()), field("name", utf8()), + field("value", utf8()), field("date", utf8())}); + AssertSchemaEqual(expected_schema, table->schema()); + + auto expected_table = TableFromJSON( + expected_schema, + {R"([{"id":101, "name":"apple", "value":"0003.1400", "date":"2024-01-15"}, + {"id":102, "name":"banana", "value":"001.6180", "date":"2024-02-20"}, + {"id":3, "name":"cherry", "value":"02.71800", "date":"2024-03-25"}])"}); + ASSERT_TRUE(table->Equals(*expected_table)); +} + +TEST(ReaderTests, DefaultColumnTypeForcesTypedColumns) { + auto table_buffer = std::make_shared<Buffer>( + "id,amount,code\n" + "0000404,000045.6700,001\n" + "0000505,000000.10,010\n"); + + auto input = std::make_shared<io::BufferReader>(table_buffer); + auto read_options = ReadOptions::Defaults(); + auto parse_options = ParseOptions::Defaults(); + auto convert_options = ConvertOptions::Defaults(); + convert_options.default_column_type = utf8(); + + ASSERT_OK_AND_ASSIGN(auto reader, + TableReader::Make(io::default_io_context(), input, read_options, + parse_options, convert_options)); + ASSERT_OK_AND_ASSIGN(auto table, reader->Read()); + + auto expected_schema = + schema({field("id", utf8()), field("amount", utf8()), field("code", utf8())}); + AssertSchemaEqual(expected_schema, table->schema()); + + auto expected_table = TableFromJSON( + expected_schema, {R"([{"id":"0000404", "amount":"000045.6700", "code":"001"}, + {"id":"0000505", "amount":"000000.10", "code":"010"}])"}); + ASSERT_TRUE(table->Equals(*expected_table)); +} + +TEST(ReaderTests, DefaultColumnTypeAllStringsNoHeader) { + // Input without header; autogenerate column names and default all to strings + auto table_buffer = 
std::make_shared<Buffer>("AB|000388907|000045.6700\n"); + + auto input = std::make_shared<io::BufferReader>(table_buffer); + auto read_options = ReadOptions::Defaults(); + read_options.autogenerate_column_names = true; // treat first row as data + auto parse_options = ParseOptions::Defaults(); + parse_options.delimiter = '|'; + auto convert_options = ConvertOptions::Defaults(); + convert_options.default_column_type = utf8(); + + ASSERT_OK_AND_ASSIGN(auto reader, + TableReader::Make(io::default_io_context(), input, read_options, + parse_options, convert_options)); + ASSERT_OK_AND_ASSIGN(auto table, reader->Read()); + + auto expected_schema = + schema({field("f0", utf8()), field("f1", utf8()), field("f2", utf8())}); + AssertSchemaEqual(expected_schema, table->schema()); + + auto expected_table = TableFromJSON(expected_schema, {R"([{ + "f0":"AB", + "f1":"000388907", + "f2":"000045.6700" + }])"}); + ASSERT_TRUE(table->Equals(*expected_table)); +} + } // namespace csv } // namespace arrow diff --git a/docs/source/python/csv.rst b/docs/source/python/csv.rst index 5eb68e9ccdc..27b740cdfd7 100644 --- a/docs/source/python/csv.rst +++ b/docs/source/python/csv.rst @@ -136,6 +136,7 @@ Available convert options are: ~ConvertOptions.check_utf8 ~ConvertOptions.column_types + ~ConvertOptions.default_column_type ~ConvertOptions.null_values ~ConvertOptions.true_values ~ConvertOptions.false_values diff --git a/python/pyarrow/_csv.pyx b/python/pyarrow/_csv.pyx index ed9d20beb6b..b878266ef8b 100644 --- a/python/pyarrow/_csv.pyx +++ b/python/pyarrow/_csv.pyx @@ -613,6 +613,9 @@ cdef class ConvertOptions(_Weakrefable): column_types : pyarrow.Schema or dict, optional Explicitly map column names to column types. Passing this argument disables type inference on the defined columns. + default_column_type : pyarrow.DataType, optional + Explicitly map columns not specified in column_types to a default type. + Passing this argument disables type inference on all columns. 
null_values : list, optional A sequence of strings that denote nulls in the data (defaults are appropriate in most cases). Note that by default, @@ -807,6 +810,40 @@ cdef class ConvertOptions(_Weakrefable): fast: bool ---- fast: [[true,true,false,false,null]] + + Set a default column type for all columns (disables type inference): + + >>> convert_options = csv.ConvertOptions(default_column_type=pa.string()) + >>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options) + pyarrow.Table + animals: string + n_legs: string + entry: string + fast: string + ---- + animals: [["Flamingo","Horse","Brittle stars","Centipede",""]] + n_legs: [["2","4","5","100","6"]] + entry: [["01/03/2022","02/03/2022","03/03/2022","04/03/2022","05/03/2022"]] + fast: [["Yes","Yes","No","No",""]] + + Combine default_column_type with column_types (specific column types override default): + + >>> convert_options = csv.ConvertOptions( + ... column_types={"n_legs": pa.int64(), "fast": pa.bool_()}, + ... default_column_type=pa.string(), + ... true_values=["Yes"], + ... 
false_values=["No"]) + >>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options) + pyarrow.Table + animals: string + n_legs: int64 + entry: string + fast: bool + ---- + animals: [["Flamingo","Horse","Brittle stars","Centipede",""]] + n_legs: [[2,4,5,100,6]] + entry: [["01/03/2022","02/03/2022","03/03/2022","04/03/2022","05/03/2022"]] + fast: [[true,true,false,false,null]] """ # Avoid mistakingly creating attributes @@ -816,7 +853,7 @@ cdef class ConvertOptions(_Weakrefable): self.options.reset( new CCSVConvertOptions(CCSVConvertOptions.Defaults())) - def __init__(self, *, check_utf8=None, column_types=None, null_values=None, + def __init__(self, *, check_utf8=None, column_types=None, default_column_type=None, null_values=None, true_values=None, false_values=None, decimal_point=None, strings_can_be_null=None, quoted_strings_can_be_null=None, include_columns=None, include_missing_columns=None, @@ -826,6 +863,8 @@ cdef class ConvertOptions(_Weakrefable): self.check_utf8 = check_utf8 if column_types is not None: self.column_types = column_types + if default_column_type is not None: + self.default_column_type = default_column_type if null_values is not None: self.null_values = null_values if true_values is not None: @@ -910,6 +949,27 @@ cdef class ConvertOptions(_Weakrefable): assert typ != NULL deref(self.options).column_types[tobytes(k)] = typ + @property + def default_column_type(self): + """ + Explicitly map columns not specified in column_types to a default type. 
+ """ + if deref(self.options).default_column_type != NULL: + return pyarrow_wrap_data_type(deref(self.options).default_column_type) + else: + return None + + @default_column_type.setter + def default_column_type(self, value): + cdef: + shared_ptr[CDataType] typ + if value is not None: + typ = pyarrow_unwrap_data_type(ensure_type(value)) + assert typ != NULL + deref(self.options).default_column_type = typ + else: + deref(self.options).default_column_type.reset() + @property def null_values(self): """ @@ -1071,6 +1131,7 @@ cdef class ConvertOptions(_Weakrefable): return ( self.check_utf8 == other.check_utf8 and self.column_types == other.column_types and + self.default_column_type == other.default_column_type and self.null_values == other.null_values and self.true_values == other.true_values and self.false_values == other.false_values and @@ -1087,17 +1148,17 @@ cdef class ConvertOptions(_Weakrefable): ) def __getstate__(self): - return (self.check_utf8, self.column_types, self.null_values, - self.true_values, self.false_values, self.decimal_point, - self.timestamp_parsers, self.strings_can_be_null, - self.quoted_strings_can_be_null, self.auto_dict_encode, - self.auto_dict_max_cardinality, self.include_columns, - self.include_missing_columns) + return (self.check_utf8, self.column_types, self.default_column_type, + self.null_values, self.true_values, self.false_values, + self.decimal_point, self.timestamp_parsers, + self.strings_can_be_null, self.quoted_strings_can_be_null, + self.auto_dict_encode, self.auto_dict_max_cardinality, + self.include_columns, self.include_missing_columns) def __setstate__(self, state): - (self.check_utf8, self.column_types, self.null_values, - self.true_values, self.false_values, self.decimal_point, - self.timestamp_parsers, self.strings_can_be_null, + (self.check_utf8, self.column_types, self.default_column_type, + self.null_values, self.true_values, self.false_values, + self.decimal_point, self.timestamp_parsers, 
self.strings_can_be_null, self.quoted_strings_can_be_null, self.auto_dict_encode, self.auto_dict_max_cardinality, self.include_columns, self.include_missing_columns) = state diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index f294ee4d50b..fa479391211 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2104,6 +2104,7 @@ cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil: cdef cppclass CCSVConvertOptions" arrow::csv::ConvertOptions": c_bool check_utf8 unordered_map[c_string, shared_ptr[CDataType]] column_types + shared_ptr[CDataType] default_column_type vector[c_string] null_values vector[c_string] true_values vector[c_string] false_values diff --git a/python/pyarrow/tests/test_csv.py b/python/pyarrow/tests/test_csv.py index f510c6dbe23..a4840bcb9f2 100644 --- a/python/pyarrow/tests/test_csv.py +++ b/python/pyarrow/tests/test_csv.py @@ -297,7 +297,8 @@ def test_convert_options(pickle_module): include_columns=['def', 'abc'], include_missing_columns=False, auto_dict_encode=True, - timestamp_parsers=[ISO8601, '%y-%m']) + timestamp_parsers=[ISO8601, '%y-%m'], + default_column_type=pa.int16()) with pytest.raises(ValueError): opts.decimal_point = '..' 
@@ -325,6 +326,17 @@ def test_convert_options(pickle_module): with pytest.raises(TypeError): opts.column_types = 0 + assert opts.default_column_type is None + opts.default_column_type = pa.string() + assert opts.default_column_type == pa.string() + opts.default_column_type = 'int32' + assert opts.default_column_type == pa.int32() + opts.default_column_type = None + assert opts.default_column_type is None + + with pytest.raises(TypeError, match='DataType expected'): + opts.default_column_type = 123 + assert isinstance(opts.null_values, list) assert '' in opts.null_values assert 'N/A' in opts.null_values @@ -1331,6 +1343,57 @@ def test_column_types_with_column_names(self): 'y': ['b', 'd', 'f'], } + def test_default_column_type(self): + rows = b"a,b,c,d\n001,2.5,hello,true\n4,3.14,world,false\n" + + # Test with default_column_type only - all columns should use the specified type. + opts = ConvertOptions(default_column_type=pa.string()) + table = self.read_bytes(rows, convert_options=opts) + schema = pa.schema([('a', pa.string()), + ('b', pa.string()), + ('c', pa.string()), + ('d', pa.string())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': ["001", "4"], + 'b': ["2.5", "3.14"], + 'c': ["hello", "world"], + 'd': ["true", "false"], + } + + # Test with both column_types and default_column_type + # Columns specified in column_types should override default_column_type + opts = ConvertOptions( + column_types={'b': pa.float64(), 'd': pa.bool_()}, + default_column_type=pa.string() + ) + table = self.read_bytes(rows, convert_options=opts) + schema = pa.schema([('a', pa.string()), + ('b', pa.float64()), + ('c', pa.string()), + ('d', pa.bool_())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': ["001", "4"], + 'b': [2.5, 3.14], + 'c': ["hello", "world"], + 'd': [True, False], + } + + # Test that default_column_type disables type inference + opts_no_default = ConvertOptions(column_types={'b': pa.float64()}) + table_no_default = 
self.read_bytes(rows, convert_options=opts_no_default) + + opts_with_default = ConvertOptions( + column_types={'b': pa.float64()}, + default_column_type=pa.string() + ) + table_with_default = self.read_bytes(rows, convert_options=opts_with_default) + + # Column 'a' should be int64 without default, string with default + assert table_no_default.schema.field('a').type == pa.int64() + assert table_with_default.schema.field('a').type == pa.string() + def test_no_ending_newline(self): # No \n after last line rows = b"a,b,c\n1,2,3\n4,5,6"