Skip to content

Commit 3b6096d

Browse files
authored
Merge pull request #3 from giuvecchio/bugfix/tensor-conversion-channels
Bugfix/tensor conversion channels
2 parents ee7fe22 + 2deaebe commit 3b6096d

1 file changed

Lines changed: 13 additions & 1 deletion

File tree

pypbr/materials/base.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ def _to_tensor(
157157
tensor = tensor.unsqueeze(0) # Add channel dimension
158158
return tensor.to(self.device)
159159
else:
160+
if image.mode == "RGBA":
161+
# Convert RGBA to RGB
162+
image = image.convert("RGB")
160163
# For other modes, use torchvision transforms
161164
return TF.to_tensor(image).to(self.device)
162165
else:
@@ -314,7 +317,9 @@ def as_dict(self) -> Dict[str, torch.FloatTensor]:
314317
return {name: map_value for name, map_value in self._maps.items()}
315318

316319
def as_tensor(
317-
self, names: Optional[List[Union[str, Tuple[str, int]]]] = None
320+
self,
321+
names: Optional[List[Union[str, Tuple[str, int]]]] = None,
322+
normalize: Optional[bool] = False,
318323
) -> torch.FloatTensor:
319324
"""
320325
Get a subset of texture maps stacked in a tensor.
@@ -327,6 +332,7 @@ def as_tensor(
327332
- map name (str)
328333
- number of channels to include (int)
329334
- The list can contain a mix of strings and tuples.
335+
normalize (Optional[bool]): Wether to normalized in range [-1, 1].
330336
331337
Returns:
332338
torch.FloatTensor: A tensor containing the specified texture maps stacked along the channel dimension.
@@ -388,6 +394,9 @@ def as_tensor(
388394
)
389395
tensor = tensor[:channel_limit]
390396

397+
if normalize and name != "normal":
398+
tensor = (tensor - 0.5) / 0.5
399+
391400
tensors.append(tensor)
392401

393402
if not tensors:
@@ -794,6 +803,9 @@ def to_pil(
794803
for name, map_value in self._maps.items():
795804
if map_value is not None:
796805
if name == "normal":
806+
if map_value.shape[0] == 2:
807+
# Compute the Z-component
808+
map_value = self._compute_normal_map_z_component(map_value)
797809
# Scale the normal map from [-1, 1] to [0, 1] before converting to PIL
798810
map_value = (map_value + 1.0) * 0.5
799811

0 commit comments

Comments
 (0)