Skip to content

Commit 0ade7e2

Browse files
authored
Merge pull request #103 from skoudoro/fix-local-header
BF: Improve TRX loading when local file headers have extra bytes
2 parents f436d76 + a521859 commit 0ade7e2

2 files changed

Lines changed: 130 additions & 4 deletions

File tree

trx/tests/test_memmap.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# -*- coding: utf-8 -*-
22

3+
import json
34
import os
5+
import struct
46
import tempfile
57
import zipfile
68

@@ -476,6 +478,112 @@ def test__ensure_little_endian_big_endian_input():
476478
assert result[0] == 0x12345678
477479

478480

481+
def test_load_zip_with_local_header_extra_field():
482+
"""Test loading ZIP where local header has extra field not in central dir.
483+
484+
Regression test for a bug where zip_info.FileHeader() was used to calculate
485+
data offset. The ZIP spec allows local headers to have different extra
486+
fields than central directory entries. The fix reads the actual local
487+
file header to get the correct offset.
488+
"""
489+
positions = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
490+
offsets = np.array([0, 2], dtype=np.uint64)
491+
header = {
492+
"DIMENSIONS": [10, 10, 10],
493+
"VOXEL_TO_RASMM": np.eye(4).tolist(),
494+
"NB_VERTICES": 2,
495+
"NB_STREAMLINES": 1,
496+
}
497+
498+
with tempfile.TemporaryDirectory() as tmp_dir:
499+
trx_path = os.path.join(tmp_dir, "test.trx")
500+
501+
# Build ZIP with extra bytes in local headers but not central directory
502+
with open(trx_path, "wb") as f:
503+
local_info = []
504+
extra = b"\x00\x00\x04\x00TEST" # 8-byte extra field
505+
506+
for name, data in [
507+
("header.json", json.dumps(header).encode()),
508+
("positions.3.float32", positions.tobytes()),
509+
("offsets.uint64", offsets.tobytes()),
510+
]:
511+
offset = f.tell()
512+
fname = name.encode()
513+
crc = zipfile.crc32(data)
514+
# Local header WITH extra field
515+
f.write(
516+
struct.pack(
517+
"<4sHHHHHIIIHH",
518+
b"PK\x03\x04",
519+
20,
520+
0,
521+
0,
522+
0,
523+
0,
524+
crc,
525+
len(data),
526+
len(data),
527+
len(fname),
528+
len(extra),
529+
)
530+
)
531+
f.write(fname)
532+
f.write(extra)
533+
f.write(data)
534+
local_info.append((name, offset, crc, len(data)))
535+
536+
cd_start = f.tell()
537+
for name, offset, crc, size in local_info:
538+
fname = name.encode()
539+
# Central directory WITHOUT extra field (mismatch!)
540+
f.write(
541+
struct.pack(
542+
"<4sHHHHHHIIIHHHHHII",
543+
b"PK\x01\x02",
544+
20,
545+
20,
546+
0,
547+
0,
548+
0,
549+
0,
550+
crc,
551+
size,
552+
size,
553+
len(fname),
554+
0,
555+
0,
556+
0,
557+
0,
558+
0,
559+
offset,
560+
)
561+
)
562+
f.write(fname)
563+
564+
# End of central directory
565+
f.write(
566+
struct.pack(
567+
"<4sHHHHIIH",
568+
b"PK\x05\x06",
569+
0,
570+
0,
571+
3,
572+
3,
573+
f.tell() - cd_start,
574+
cd_start,
575+
0,
576+
)
577+
)
578+
579+
trx = tmm.load_from_zip(trx_path)
580+
np.testing.assert_array_almost_equal(trx.streamlines._data, positions)
581+
assert trx.header["NB_VERTICES"] == 2
582+
assert trx.header["NB_STREAMLINES"] == 1
583+
584+
trx.close()
585+
586+
479587
def test_endianness_roundtrip():
480588
"""Test that data survives write/read cycle with correct endianness."""
481589
with get_trx_tmp_dir() as dirname:

trx/trx_file_memmap.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
import os
88
import shutil
9+
import struct
910
from typing import Any, List, Optional, Tuple, Type, Union
1011
import zipfile
1112

@@ -407,13 +408,30 @@ def load_from_zip(filename: str) -> Type["TrxFile"]:
407408
if ext == ".bit":
408409
ext = ".bool"
409410

410-
mem_adress = zip_info.header_offset + len(zip_info.FileHeader())
411+
# Read actual local file header to get correct data offset.
412+
# We can't use zip_info.FileHeader() because ZIP spec allows local
413+
# headers to differ from central directory entries.
414+
# See: https://pkware.cachefly.net/webdocs/casestudies/APPNOTE.TXT
415+
_ZIP_LOCAL_HEADER_SIZE = 30
416+
_ZIP_LOCAL_HEADER_SIGNATURE = b"PK\x03\x04"
417+
418+
zf.fp.seek(zip_info.header_offset)
419+
local_header = zf.fp.read(_ZIP_LOCAL_HEADER_SIZE)
420+
if len(local_header) < _ZIP_LOCAL_HEADER_SIZE:
421+
raise ValueError(f"Truncated local file header for {elem_filename}")
422+
if local_header[:4] != _ZIP_LOCAL_HEADER_SIGNATURE:
423+
raise ValueError(
424+
f"Invalid local file header signature for {elem_filename}"
425+
)
426+
fname_len, extra_len = struct.unpack("<HH", local_header[26:30])
427+
428+
mem_adress = (
429+
zip_info.header_offset + _ZIP_LOCAL_HEADER_SIZE + fname_len + extra_len
430+
)
431+
411432
dtype_size = np.dtype(ext[1:]).itemsize
412433
size = zip_info.file_size / dtype_size
413434

414-
if len(zip_info.extra):
415-
mem_adress -= len(zip_info.extra)
416-
417435
if size.is_integer():
418436
files_pointer_size[elem_filename] = mem_adress, int(size)
419437
else:

0 commit comments

Comments
 (0)