Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 41 additions & 4 deletions src/linkml_map/loaders/data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,25 @@ def __init__(
self,
source: str | Path,
skip_empty_rows: bool = True,
schema_path: str | Path | None = None,
target_class: str | None = None,
) -> None:
"""Initialize TSV loader."""
super().__init__(source)
self.skip_empty_rows = skip_empty_rows
self.schema_path = schema_path
self.target_class = target_class

def iter_instances(self) -> Iterator[dict[str, Any]]:
    """
    Iterate over rows from the TSV file.

    A single linkml ``TsvLoader`` is built for ``self.source`` and receives
    ``skip_empty_rows``, ``schema_path`` and ``target_class`` from this
    instance; its instances are yielded unchanged.

    :return: Iterator over the instances produced by the linkml loader
    """
    # Imported at call time rather than at module import.
    from linkml.validator.loaders import TsvLoader

    # Build the loader exactly once, with the schema options included.
    loader = TsvLoader(
        str(self.source),
        skip_empty_rows=self.skip_empty_rows,
        schema_path=self.schema_path,
        target_class=self.target_class,
    )
    yield from loader.iter_instances()


Expand All @@ -103,29 +112,43 @@ def __init__(
self,
source: str | Path,
skip_empty_rows: bool = True,
schema_path: str | Path | None = None,
target_class: str | None = None,
) -> None:
"""Initialize CSV loader."""
super().__init__(source)
self.skip_empty_rows = skip_empty_rows
self.schema_path = schema_path
self.target_class = target_class

def iter_instances(self) -> Iterator[dict[str, Any]]:
    """
    Iterate over rows from the CSV file.

    A single linkml ``CsvLoader`` is built for ``self.source`` and receives
    ``skip_empty_rows``, ``schema_path`` and ``target_class`` from this
    instance; its instances are yielded unchanged.

    :return: Iterator over the instances produced by the linkml loader
    """
    # Imported at call time rather than at module import.
    from linkml.validator.loaders import CsvLoader

    # Build the loader exactly once, with the schema options included.
    loader = CsvLoader(
        str(self.source),
        skip_empty_rows=self.skip_empty_rows,
        schema_path=self.schema_path,
        target_class=self.target_class,
    )
    yield from loader.iter_instances()


def get_file_loader(
path: str | Path,
file_format: FileFormat | None = None,
*,
schema_path: str | Path | None = None,
target_class: str | None = None,
**kwargs: Any,
) -> BaseFileLoader:
"""
Get the appropriate file loader for a given path.

:param path: Path to the file
:param file_format: Explicit file format (auto-detected from extension if not provided)
:param schema_path: Path to the LinkML schema (enables schema-aware type coercion for TSV/CSV)
:param target_class: Target class name within the schema
:param kwargs: Additional arguments passed to the loader
:return: Appropriate file loader instance
"""
Expand All @@ -144,6 +167,10 @@ def get_file_loader(
msg = f"No loader available for format: {file_format}"
raise ValueError(msg)

if file_format in (FileFormat.TSV, FileFormat.CSV):
kwargs["schema_path"] = schema_path
kwargs["target_class"] = target_class

return loader_class(path, **kwargs)


Expand Down Expand Up @@ -175,13 +202,17 @@ def __init__(
base_path: str | Path,
default_format: FileFormat | None = None,
skip_empty_rows: bool = True,
schema_path: str | Path | None = None,
target_class: str | None = None,
) -> None:
"""
Initialize the data loader.

:param base_path: Base directory containing data files, or a single file path
:param default_format: Default format to use when extension is ambiguous
:param skip_empty_rows: Skip empty rows in tabular files (default: True)
:param schema_path: Path to the LinkML schema (enables schema-aware type coercion for TSV/CSV)
:param target_class: Target class name within the schema
:raises FileNotFoundError: If the path does not exist
"""
self.base_path = Path(base_path)
Expand All @@ -190,6 +221,8 @@ def __init__(
raise FileNotFoundError(msg)
self.default_format = default_format
self.skip_empty_rows = skip_empty_rows
self.schema_path = schema_path
self.target_class = target_class

@property
def is_single_file(self) -> bool:
Expand Down Expand Up @@ -282,7 +315,9 @@ def __getitem__(self, identifier: str) -> Iterator[dict[str, Any]]:
if file_format in (FileFormat.TSV, FileFormat.CSV):
loader_kwargs["skip_empty_rows"] = self.skip_empty_rows

loader = get_file_loader(file_path, **loader_kwargs)
loader = get_file_loader(
file_path, schema_path=self.schema_path, target_class=self.target_class, **loader_kwargs
)
return loader.iter_instances()

def __iter__(self) -> Iterator[dict[str, Any]]:
Expand All @@ -296,7 +331,9 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
if file_format in (FileFormat.TSV, FileFormat.CSV):
loader_kwargs["skip_empty_rows"] = self.skip_empty_rows

loader = get_file_loader(self.base_path, **loader_kwargs)
loader = get_file_loader(
self.base_path, schema_path=self.schema_path, target_class=self.target_class, **loader_kwargs
)
yield from loader.iter_instances()

def get_available_identifiers(self) -> list[str]:
Expand Down
154 changes: 154 additions & 0 deletions tests/test_loaders/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,34 @@
import yaml

from linkml_map.loaders import DataLoader, FileFormat, load_data_file
from linkml_map.loaders.data_loaders import CsvFileLoader, TsvFileLoader, get_file_loader

# Attributes of the "Record" class: one integer identifier, one string,
# one enum-ranged and one float-ranged column.
_RECORD_ATTRIBUTES = {
    "id": {"range": "integer", "identifier": True},
    "zipcode": {"range": "string"},
    "score": {"range": "score_enum"},
    "weight": {"range": "float"},
}

# Permissible values of "score_enum"; the keys are digit strings "1".."3".
_SCORE_PERMISSIBLE_VALUES = {
    str(rank): {"description": label}
    for rank, label in enumerate(("Low", "Medium", "High"), start=1)
}

# Minimal schema used by the schema-aware loader tests.
SCHEMA_WITH_ENUM = {
    "id": "https://example.org/test",
    "name": "test",
    "prefixes": {"linkml": "https://w3id.org/linkml/"},
    "imports": ["linkml:types"],
    "default_range": "string",
    "classes": {"Record": {"attributes": _RECORD_ATTRIBUTES}},
    "enums": {"score_enum": {"permissible_values": _SCORE_PERMISSIBLE_VALUES}},
}


@pytest.fixture
Expand Down Expand Up @@ -313,3 +341,129 @@ def test_skip_empty_rows_false(self, tmp_path: Path) -> None:
assert rows[1]["id"] == "P:002"
# Empty string values are not included in the dict by linkml's loader
assert "name" not in rows[1] or rows[1].get("name") == ""


# --- Schema-aware loading tests ---
# These verify that schema_path/target_class flow through to the underlying
# linkml loader so that string-ranged and enum-ranged columns are not
# coerced to int/float.


@pytest.fixture()
def schema_file(tmp_path: Path) -> Path:
    """Dump ``SCHEMA_WITH_ENUM`` as YAML into ``schema.yaml`` under ``tmp_path`` and return that path."""
    schema_yaml = yaml.dump(SCHEMA_WITH_ENUM)
    target = tmp_path / "schema.yaml"
    target.write_text(schema_yaml)
    return target


@pytest.fixture()
def schema_aware_tsv(tmp_path: Path) -> Path:
    """Create ``Record.tsv`` in ``tmp_path``: a header and one row whose zipcode and score look numeric."""
    contents = "id\tzipcode\tscore\tweight\n1\t90210\t2\t3.5\n"
    tsv_file = tmp_path / "Record.tsv"
    tsv_file.write_text(contents)
    return tsv_file


@pytest.fixture()
def schema_aware_csv(tmp_path: Path) -> Path:
    """Create ``Record.csv`` in ``tmp_path``: a header and one row whose zipcode and score look numeric."""
    contents = "id,zipcode,score,weight\n1,90210,2,3.5\n"
    csv_file = tmp_path / "Record.csv"
    csv_file.write_text(contents)
    return csv_file


def _assert_schema_aware_row(row: dict) -> None:
"""Shared assertions for schema-aware loading: string/enum columns stay strings."""
assert row["id"] == 1
assert isinstance(row["id"], int)
assert row["zipcode"] == "90210"
assert isinstance(row["zipcode"], str)
assert row["score"] == "2"
assert isinstance(row["score"], str)
assert row["weight"] == 3.5
assert isinstance(row["weight"], float)


class TestSchemaAwareTsvFileLoader:
    """Schema-aware behaviour of ``TsvFileLoader``: string/enum columns stay strings."""

    def test_with_schema(self, schema_aware_tsv: Path, schema_file: Path) -> None:
        """With a schema and target class, the first row passes the shared type checks."""
        tsv_loader = TsvFileLoader(
            schema_aware_tsv,
            schema_path=schema_file,
            target_class="Record",
        )
        first_row = next(tsv_loader.iter_instances())
        _assert_schema_aware_row(first_row)

    def test_without_schema_coerces(self, schema_aware_tsv: Path) -> None:
        """Without a schema, the numeric-looking zipcode and score come back as ints."""
        first_row = next(TsvFileLoader(schema_aware_tsv).iter_instances())
        for column in ("zipcode", "score"):
            assert isinstance(first_row[column], int)


class TestSchemaAwareCsvFileLoader:
    """Schema-aware behaviour of ``CsvFileLoader``: string/enum columns stay strings."""

    def test_with_schema(self, schema_aware_csv: Path, schema_file: Path) -> None:
        """With a schema and target class, the first row passes the shared type checks."""
        csv_loader = CsvFileLoader(
            schema_aware_csv,
            schema_path=schema_file,
            target_class="Record",
        )
        first_row = next(csv_loader.iter_instances())
        _assert_schema_aware_row(first_row)

    def test_without_schema_coerces(self, schema_aware_csv: Path) -> None:
        """Without a schema, the numeric-looking zipcode and score come back as ints."""
        first_row = next(CsvFileLoader(schema_aware_csv).iter_instances())
        for column in ("zipcode", "score"):
            assert isinstance(first_row[column], int)


class TestSchemaAwareGetFileLoader:
    """``get_file_loader`` hands schema parameters on to the TSV/CSV loaders."""

    @pytest.mark.parametrize("fixture_name", ["schema_aware_tsv", "schema_aware_csv"])
    def test_with_schema(self, fixture_name: str, schema_file: Path, request: pytest.FixtureRequest) -> None:
        """For both tabular fixtures, the first row passes the shared type checks."""
        tabular_path = request.getfixturevalue(fixture_name)
        file_loader = get_file_loader(
            tabular_path,
            schema_path=schema_file,
            target_class="Record",
        )
        _assert_schema_aware_row(next(file_loader.iter_instances()))

    def test_ignored_for_yaml(self, tmp_path: Path, schema_file: Path) -> None:
        """schema_path/target_class are accepted but ignored for non-tabular formats."""
        document = {"id": 1, "zipcode": "90210"}
        yaml_file = tmp_path / "data.yaml"
        yaml_file.write_text(yaml.dump(document))

        file_loader = get_file_loader(yaml_file, schema_path=schema_file, target_class="Record")
        first_instance = next(file_loader.iter_instances())
        assert first_instance["id"] == 1


class TestSchemaAwareDataLoader:
    """``DataLoader`` passes schema parameters down to the loaders it creates."""

    def test_single_file_with_schema(
        self, schema_aware_tsv: Path, schema_file: Path
    ) -> None:
        """Iterating a single-file loader yields a row that passes the shared type checks."""
        data_loader = DataLoader(
            schema_aware_tsv,
            schema_path=schema_file,
            target_class="Record",
        )
        first_row = next(iter(data_loader))
        _assert_schema_aware_row(first_row)

    def test_directory_with_schema(self, tmp_path: Path, schema_file: Path) -> None:
        """Looking up "Record" in a directory loader yields a row that passes the shared type checks."""
        record_file = tmp_path / "Record.tsv"
        record_file.write_text("id\tzipcode\tscore\tweight\n1\t90210\t2\t3.5\n")

        data_loader = DataLoader(tmp_path, schema_path=schema_file, target_class="Record")
        _assert_schema_aware_row(next(data_loader["Record"]))

    def test_directory_without_schema_coerces(self, tmp_path: Path) -> None:
        """Without a schema, the directory loader returns the zipcode as an int."""
        record_file = tmp_path / "Record.tsv"
        record_file.write_text("id\tzipcode\tscore\tweight\n1\t90210\t2\t3.5\n")

        first_row = next(DataLoader(tmp_path)["Record"])
        assert isinstance(first_row["zipcode"], int)

    def test_iter_sources_with_schema(
        self, schema_aware_tsv: Path, schema_file: Path
    ) -> None:
        """``iter_sources`` reports one source whose first row passes the shared type checks."""
        data_loader = DataLoader(
            schema_aware_tsv,
            schema_path=schema_file,
            target_class="Record",
        )
        collected = list(data_loader.iter_sources())
        assert len(collected) == 1

        _identifier, row_iterator = collected[0]
        _assert_schema_aware_row(next(row_iterator))