Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions lead/features/wic.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, spacedeltas, dates, parallel=False):
parallel=parallel)

def get_aggregates(self, date, delta):
enroll = self.inputs[0].get_result()
enroll = self.inputs[0].result
aggregates = [
Aggregate('medical_risk', 'any', fname=False),
Aggregate(['household_size', 'household_income'],
Expand Down Expand Up @@ -89,7 +89,7 @@ def __init__(self, spacedeltas, dates, parallel=False):
parallel=parallel)

def get_aggregates(self, date, delta):
births = self.inputs[0].get_result()
births = self.inputs[0].result
aggregates = [
Aggregate('length', 'max', fname=False),
Aggregate('weight', 'max', fname=False),
Expand Down Expand Up @@ -125,7 +125,7 @@ def __init__(self, spacedeltas, dates, parallel=False):
parallel=parallel)

def get_aggregates(self, date, delta):
prenatal = self.inputs[0].get_result()
prenatal = self.inputs[0].result

aggregates = [
Count(),
Expand Down
4 changes: 2 additions & 2 deletions lead/model/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def run(self, acs, left, aux=None):
left = data.binarize(left, ['community_area_id', 'ward_id'], astype=self.dtype)

logging.info('Joining aggregations')
X = left.join([a.get_result() for a in self.aggregation_joins] + [acs])
X = left.join([a.result for a in self.aggregation_joins] + [acs])
# delete all aggregation inputs so that memory can be freed
for a in self.aggregation_joins: del a._result
for a in self.aggregation_joins: del a.result

if not self.address:
logging.info('Adding auxillary features')
Expand Down
10 changes: 8 additions & 2 deletions lead/model/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,23 @@ class LeadTransform(Step):
performing feature selection and creating sample weights.
"""
def __init__(self, inputs, outcome_expr, aggregations,
wic_sample_weight=0, exclude=[], include=[]):
outcome_where_expr=None, wic_sample_weight=0,
exclude=[], include=[]):
"""
Args:
inputs: list containing a LeadCrossValidate step
outcome_expr: the query to perform on the auxillary information to produce an outcome variable
aggregations: defines which of the SpacetimeAggregations to include
and which to drop
and which to drop
outcome_where_expr: where to evaluate the outcome_expr,
defaults to None, which means everywhere
wic_sample_weight: optional different sample weight for wic kids
"""
Step.__init__(self,
inputs=inputs,
outcome_expr=outcome_expr,
aggregations=aggregations,
outcome_where_expr=outcome_where_expr,
wic_sample_weight=wic_sample_weight,
exclude=exclude, include=include)

Expand All @@ -40,6 +44,8 @@ def run(self, X, aux, train, test):

"""
y = aux.eval(self.outcome_expr)
if self.outcome_where_expr is not None:
y = y.where(aux.eval(self.outcome_where_expr))

logging.info('Selecting aggregations')
aggregations = self.get_input(LeadData).aggregations
Expand Down
3 changes: 2 additions & 1 deletion lead/model/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def bll6_models(estimators, cv_search={}, transform_search={}):
transformd = dict(
wic_sample_weight=[0],
aggregations=aggregations.args,
outcome_expr=['max_bll0 >= 6']
outcome_expr='max_bll0 >= 6',
outcome_where_expr='max_bll0 == max_bll0' # this means max_bll0.notnull()
)
transformd.update(transform_search)
return models(estimators, cvd, transformd)
Expand Down