1212# limitations under the License.
1313
1414from __future__ import annotations
15- from typing import Optional , TYPE_CHECKING
15+ from typing import Optional , TYPE_CHECKING , cast
1616import math
1717from sys import version_info
1818
2828if TYPE_CHECKING :
2929 # prevent circular dependenacy by skipping import at runtime
3030 from .project_config import ProjectConfig
31- from .entities import Experiment , Variation
31+ from .entities import Experiment , Variation , Holdout
3232 from .helpers .types import TrafficAllocation
3333
3434
@@ -104,8 +104,8 @@ def find_bucket(
104104
105105 def bucket (
106106 self , project_config : ProjectConfig ,
107- experiment : Experiment , user_id : str , bucketing_id : str
108- ) -> tuple [Optional [ Variation ] , list [str ]]:
107+ experiment : Experiment | Holdout , user_id : str , bucketing_id : str
108+ ) -> tuple [Variation | None , list [str ]]:
109109 """ For a given experiment and bucketing ID determines variation to be shown to user.
110110
111111 Args:
@@ -125,14 +125,9 @@ def bucket(
125125 project_config .logger .debug (message )
126126 return None , []
127127
128- if isinstance (experiment , dict ):
129- # This is a holdout dictionary
130- experiment_key = experiment .get ('key' , '' )
131- experiment_id = experiment .get ('id' , '' )
132- else :
133- # This is an Experiment object
134- experiment_key = experiment .key
135- experiment_id = experiment .id
128+ # Handle both Experiment and Holdout entities
129+ experiment_key = experiment .key
130+ experiment_id = experiment .id
136131
137132 if not experiment_key or not experiment_key .strip ():
138133 message = 'Invalid entity key provided for bucketing. Returning nil.'
@@ -141,14 +136,9 @@ def bucket(
141136
142137 variation_id , decide_reasons = self .bucket_to_entity_id (project_config , experiment , user_id , bucketing_id )
143138 if variation_id :
144- if isinstance (experiment , dict ):
145- # For holdouts, find the variation in the holdout's variations array
146- variations = experiment .get ('variations' , [])
147- variation = next ((v for v in variations if v .get ('id' ) == variation_id ), None )
148- else :
149- # For experiments, use the existing method
150- variation = project_config .get_variation_from_id_by_experiment_id (experiment_id , variation_id )
151- return variation , decide_reasons
139+ variation = project_config .get_variation_from_id_by_experiment_id (experiment_id , variation_id )
140+ # Cast is safe here because experiments always use Variation entities, not VariationDict
141+ return cast ('Optional[Variation]' , variation ), decide_reasons
152142
153143 # No variation found - log message for empty traffic range
154144 message = 'Bucketed into an empty traffic range. Returning nil.'
@@ -158,7 +148,7 @@ def bucket(
158148
159149 def bucket_to_entity_id (
160150 self , project_config : ProjectConfig ,
161- experiment : Experiment , user_id : str , bucketing_id : str
151+ experiment : Experiment | Holdout , user_id : str , bucketing_id : str
162152 ) -> tuple [Optional [str ], list [str ]]:
163153 """
164154 For a given experiment and bucketing ID determines variation ID to be shown to user.
@@ -176,58 +166,52 @@ def bucket_to_entity_id(
176166 if not experiment :
177167 return None , decide_reasons
178168
179- # Handle both Experiment objects and holdout dictionaries
180- if isinstance (experiment , dict ):
181- # This is a holdout dictionary - holdouts don't have groups
182- experiment_key = experiment .get ('key' , '' )
183- experiment_id = experiment .get ('id' , '' )
184- traffic_allocations = experiment .get ('trafficAllocation' , [])
185- has_cmab = False
186- group_policy = None
187- else :
188- # This is an Experiment object
189- experiment_key = experiment .key
190- experiment_id = experiment .id
191- traffic_allocations = experiment .trafficAllocation
192- has_cmab = bool (experiment .cmab )
169+ # Handle both Experiment and Holdout entities
170+ # Both entities have key, id, and trafficAllocation attributes
171+ from . import entities
172+
173+ experiment_key = experiment .key
174+ experiment_id = experiment .id
175+ traffic_allocations = experiment .trafficAllocation
176+
177+ # Determine if experiment is in a mutually exclusive group
178+ # Holdouts don't have groupId or groupPolicy - use isinstance for type narrowing
179+ if isinstance (experiment , entities .Experiment ):
193180 group_policy = getattr (experiment , 'groupPolicy' , None )
181+ if group_policy and group_policy in GROUP_POLICIES :
182+ group = project_config .get_group (experiment .groupId )
194183
195- # Determine if experiment is in a mutually exclusive group.
196- # This will not affect evaluation of rollout rules or holdouts.
197- if group_policy and group_policy in GROUP_POLICIES :
198- group = project_config .get_group (experiment .groupId )
184+ if not group :
185+ return None , decide_reasons
199186
200- if not group :
201- return None , decide_reasons
187+ user_experiment_id = self .find_bucket (
188+ project_config , bucketing_id , experiment .groupId , group .trafficAllocation ,
189+ )
202190
203- user_experiment_id = self .find_bucket (
204- project_config , bucketing_id , experiment .groupId , group .trafficAllocation ,
205- )
191+ if not user_experiment_id :
192+ message = f'User "{ user_id } " is in no experiment.'
193+ project_config .logger .info (message )
194+ decide_reasons .append (message )
195+ return None , decide_reasons
206196
207- if not user_experiment_id :
208- message = f'User "{ user_id } " is in no experiment.'
209- project_config .logger .info (message )
210- decide_reasons .append (message )
211- return None , decide_reasons
197+ if user_experiment_id != experiment_id :
198+ message = f'User "{ user_id } " is not in experiment " { experiment_key } " of group { experiment . groupId } .'
199+ project_config .logger .info (message )
200+ decide_reasons .append (message )
201+ return None , decide_reasons
212202
213- if user_experiment_id != experiment_id :
214- message = f'User "{ user_id } " is not in experiment "{ experiment_key } " of group { experiment .groupId } .'
203+ message = f'User "{ user_id } " is in experiment { experiment_key } of group { experiment .groupId } .'
215204 project_config .logger .info (message )
216205 decide_reasons .append (message )
217- return None , decide_reasons
218-
219- message = f'User "{ user_id } " is in experiment { experiment_key } of group { experiment .groupId } .'
220- project_config .logger .info (message )
221- decide_reasons .append (message )
222-
223- if has_cmab :
224- if experiment .cmab :
225- traffic_allocations = [
226- {
227- "entityId" : "$" ,
228- "endOfRange" : experiment .cmab ['trafficAllocation' ]
229- }
230- ]
206+
207+ # Holdouts don't have cmab - use isinstance for type narrowing
208+ if isinstance (experiment , entities .Experiment ) and experiment .cmab :
209+ traffic_allocations = [
210+ {
211+ "entityId" : "$" ,
212+ "endOfRange" : experiment .cmab ['trafficAllocation' ]
213+ }
214+ ]
231215
232216 # Bucket user if not in white-list and in group (if any)
233217 variation_id = self .find_bucket (project_config , bucketing_id ,
0 commit comments