1818
1919from abc import ABC , abstractmethod
2020from datetime import datetime , timedelta
21- from typing import Any
21+ from typing import TYPE_CHECKING , Any
2222
23+ from airflow ._shared .timezones .timezone import parse_timezone
2324from airflow .partition_mappers .base import PartitionMapper
2425
26+ if TYPE_CHECKING :
27+ from pendulum import FixedTimezone , Timezone
28+
2529
2630class _BaseTemporalMapper (PartitionMapper , ABC ):
2731 """Base class for Temporal Partition Mappers."""
@@ -30,14 +34,24 @@ class _BaseTemporalMapper(PartitionMapper, ABC):
3034
3135 def __init__ (
3236 self ,
33- input_format : str = "%Y-%m-%dT%H:%M:%S" ,
37+ * ,
38+ timezone : str | Timezone | FixedTimezone ,
39+ input_format : str = "%Y-%m-%dT%H:%M:%S%z" ,
3440 output_format : str | None = None ,
3541 ):
3642 self .input_format = input_format
3743 self .output_format = output_format or self .default_output_format
44+ if isinstance (timezone , str ):
45+ timezone = parse_timezone (timezone )
46+ self ._timezone : Timezone | FixedTimezone = timezone
3847
3948 def to_downstream (self , key : str ) -> str :
4049 dt = datetime .strptime (key , self .input_format )
50+ if dt .tzinfo is None :
51+ dt = dt .replace (tzinfo = self ._timezone )
52+ else :
53+ dt = dt .astimezone (self ._timezone )
54+
4155 normalized = self .normalize (dt )
4256 return self .format (normalized )
4357
@@ -50,14 +64,18 @@ def format(self, dt: datetime) -> str:
5064 return dt .strftime (self .output_format )
5165
5266 def serialize (self ) -> dict [str , Any ]:
67+ from airflow .serialization .encoders import encode_timezone
68+
5369 return {
70+ "timezone" : encode_timezone (self ._timezone ),
5471 "input_format" : self .input_format ,
5572 "output_format" : self .output_format ,
5673 }
5774
5875 @classmethod
5976 def deserialize (cls , data : dict [str , Any ]) -> PartitionMapper :
6077 return cls (
78+ timezone = parse_timezone (data ["timezone" ]),
6179 input_format = data ["input_format" ],
6280 output_format = data ["output_format" ],
6381 )
@@ -66,7 +84,7 @@ def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
class HourlyMapper(_BaseTemporalMapper):
    """Map a time-based partition key to hour."""

    # ``%z`` appends the UTC offset of the localized datetime. The pattern must
    # not end in whitespace, otherwise every downstream key carries a trailing
    # space.
    default_output_format = "%Y-%m-%dT%H%z"

    def normalize(self, dt: datetime) -> datetime:
        """Truncate *dt* to the start of its hour, keeping its ``tzinfo``."""
        return dt.replace(minute=0, second=0, microsecond=0)
@@ -75,7 +93,7 @@ def normalize(self, dt: datetime) -> datetime:
class DailyMapper(_BaseTemporalMapper):
    """Map a time-based partition key to day."""

    # ``%z`` appends the UTC offset of the localized datetime. The pattern must
    # not end in whitespace, otherwise every downstream key carries a trailing
    # space.
    default_output_format = "%Y-%m-%d%z"

    def normalize(self, dt: datetime) -> datetime:
        """Truncate *dt* to midnight of its day, keeping its ``tzinfo``."""
        return dt.replace(hour=0, minute=0, second=0, microsecond=0)
@@ -84,7 +102,7 @@ def normalize(self, dt: datetime) -> datetime:
84102class WeeklyMapper (_BaseTemporalMapper ):
85103 """Map a time-based partition key to week."""
86104
87- default_output_format = "%Y-%m-%d (W%V)"
105+ default_output_format = "%Y-%m-%d (W%V)%z "
88106
89107 def normalize (self , dt : datetime ) -> datetime :
90108 start = dt - timedelta (days = dt .weekday ())
@@ -94,7 +112,7 @@ def normalize(self, dt: datetime) -> datetime:
94112class MonthlyMapper (_BaseTemporalMapper ):
95113 """Map a time-based partition key to month."""
96114
97- default_output_format = "%Y-%m"
115+ default_output_format = "%Y-%m%z "
98116
99117 def normalize (self , dt : datetime ) -> datetime :
100118 return dt .replace (
@@ -109,7 +127,7 @@ def normalize(self, dt: datetime) -> datetime:
109127class QuarterlyMapper (_BaseTemporalMapper ):
110128 """Map a time-based partition key to quarter."""
111129
112- default_output_format = "%Y-Q{quarter}"
130+ default_output_format = "%Y-Q{quarter}%z "
113131
114132 def normalize (self , dt : datetime ) -> datetime :
115133 quarter = (dt .month - 1 ) // 3
@@ -131,7 +149,7 @@ def format(self, dt: datetime) -> str:
131149class YearlyMapper (_BaseTemporalMapper ):
132150 """Map a time-based partition key to year."""
133151
134- default_output_format = "%Y"
152+ default_output_format = "%Y%z "
135153
136154 def normalize (self , dt : datetime ) -> datetime :
137155 return dt .replace (
0 commit comments