Skip to content

Commit 137de92

Browse files
improve table.parents and table.children with the option to include foreign key information
1 parent ae767b7 commit 137de92

4 files changed

Lines changed: 67 additions & 59 deletions

File tree

datajoint/autopopulate.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,25 +32,19 @@ def key_source(self):
3232
The default value is the join of the parent relations.
3333
Users may override to change the granularity or the scope of populate() calls.
3434
"""
35-
def parent_gen(self):
36-
if self.target.full_table_name not in self.connection.dependencies:
37-
self.connection.dependencies.load()
38-
for parent_name, fk_props in self.target.parents(primary=True).items():
39-
if not parent_name.isdigit(): # simple foreign key
40-
yield FreeTable(self.connection, parent_name).proj()
41-
else:
42-
grandparent = list(self.connection.dependencies.in_edges(parent_name))[0][0]
43-
yield FreeTable(self.connection, grandparent).proj(**{
44-
attr: ref for attr, ref in fk_props['attr_map'].items() if ref != attr})
35+
def _rename_attributes(table, props):
36+
return (table.proj(
37+
**{attr: ref for attr, ref in props['attr_map'].items() if attr != ref})
38+
if props['aliased'] else table)
4539

4640
if self._key_source is None:
47-
parents = parent_gen(self)
48-
try:
49-
self._key_source = next(parents)
50-
except StopIteration:
51-
raise DataJointError('A relation must have primary dependencies for auto-populate to work') from None
52-
for q in parents:
53-
self._key_source *= q
41+
parents = self.target.parents(primary=True, as_objects=True, foreign_key_info=True)
42+
if not parents:
43+
raise DataJointError(
44+
'A relation must have primary dependencies for auto-populate to work') from None
45+
self._key_source = _rename_attributes(*parents[0])
46+
for q in parents[1:]:
47+
self._key_source *= _rename_attributes(*q)
5448
return self._key_source
5549

5650
def make(self, key):

datajoint/table.py

Lines changed: 54 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from .declare import declare, alter
1212
from .expression import QueryExpression
1313
from . import blob
14-
from .utils import user_choice
14+
from .utils import user_choice, OrderedDict
1515
from .heading import Heading
1616
from .errors import DuplicateError, AccessError, DataJointError, UnknownAttributeError
1717
from .version import __version__ as version
@@ -40,8 +40,7 @@ class Table(QueryExpression):
4040
@property
4141
def heading(self):
4242
"""
43-
Returns the table heading. If the table is not declared, attempts to declare it and return heading.
44-
:return: table heading
43+
:return: table heading. If the table is not declared, attempts to declare it first.
4544
"""
4645
if self._heading is None:
4746
self._heading = Heading() # instance-level heading
@@ -53,7 +52,8 @@ def heading(self):
5352
def declare(self, context=None):
5453
"""
5554
Declare the table in the schema based on self.definition.
56-
:param context: the context for foreign key resolution. If None, foreign keys are not allowed.
55+
:param context: the context for foreign key resolution. If None, foreign keys are
56+
not allowed.
5757
"""
5858
if self.connection.in_transaction:
5959
raise DataJointError('Cannot declare new tables inside a transaction, '
@@ -116,38 +116,59 @@ def get_select_fields(self, select_fields=None):
116116
"""
117117
return '*' if select_fields is None else self.heading.project(select_fields).as_sql
118118

119-
def parents(self, primary=None, as_objects=False):
119+
def parents(self, primary=None, as_objects=False, foreign_key_info=False):
120120
"""
121121
:param primary: if None, then all parents are returned. If True, then only foreign keys composed of
122-
primary key attributes are considered. If False, the only foreign keys including at least one non-primary
123-
attribute are considered.
124-
:param as_objects: if False (default), the output is a dict describing the foreign keys. If True, return table objects.
125-
:return: dict of tables referenced with self's foreign keys or list of table objects if as_objects=True
126-
"""
127-
parents = self.connection.dependencies.parents(self.full_table_name, primary)
122+
primary key attributes are considered. If False, return foreign keys including at least one
123+
secondary attribute.
124+
:param as_objects: if False, return table names. If True, return table objects.
125+
:param foreign_key_info: if True, each element in result also includes foreign key info.
126+
:return: list of parents as table names or table objects
127+
with (optional) foreign key information.
128+
"""
129+
get_edge = self.connection.dependencies.parents
130+
nodes = [next(iter(get_edge(name).items())) if name.isdigit() else (name, props)
131+
for name, props in get_edge(self.full_table_name, primary).items()]
128132
if as_objects:
129-
parents = [FreeTable(self.connection, c) for c in parents]
130-
return parents
133+
nodes = [(FreeTable(self.connection, name), props) for name, props in nodes]
134+
if not foreign_key_info:
135+
nodes = [name for name, props in nodes]
136+
return nodes
131137

132-
def children(self, primary=None, as_objects=False):
138+
def children(self, primary=None, as_objects=False, foreign_key_info=False):
133139
"""
134140
:param primary: if None, then all children are returned. If True, then only foreign keys composed of
135-
primary key attributes are considered. If False, the only foreign keys including at least one non-primary
136-
attribute are considered.
137-
:param as_objects: if False (default), the output is a dict describing the foreign keys. If True, return table objects.
138-
:return: dict of tables with foreign keys referencing self or list of table objects if as_objects=True
139-
"""
140-
nodes = dict((next(iter(self.connection.dependencies.children(k).items())) if k.isdigit() else (k, v))
141-
for k, v in self.connection.dependencies.children(self.full_table_name, primary).items())
141+
primary key attributes are considered. If False, return foreign keys including at least one
142+
secondary attribute.
143+
:param as_objects: if False, return table names. If True, return table objects.
144+
:param foreign_key_info: if True, each element in result also includes foreign key info.
145+
:return: list of children as table names or table objects
146+
with (optional) foreign key information.
147+
"""
148+
get_edge = self.connection.dependencies.children
149+
nodes = [next(iter(get_edge(name).items())) if name.isdigit() else (name, props)
150+
for name, props in get_edge(self.full_table_name, primary).items()]
142151
if as_objects:
143-
nodes = [FreeTable(self.connection, c) for c in nodes]
152+
nodes = [(FreeTable(self.connection, name), props) for name, props in nodes]
153+
if not foreign_key_info:
154+
nodes = [name for name, props in nodes]
144155
return nodes
145156

146157
def descendants(self, as_objects=False):
147-
nodes = [node for node in self.connection.dependencies.descendants(self.full_table_name) if not node.isdigit()]
148-
if as_objects:
149-
nodes = [FreeTable(self.connection, c) for c in nodes]
150-
return nodes
158+
"""
159+
:param as_objects: False - a list of table names; True - a list of table objects.
160+
:return: list of tables descendants in topological order.
161+
"""
162+
return [FreeTable(self.connection, node) if as_objects else node
163+
for node in self.connection.dependencies.descendants(self.full_table_name) if not node.isdigit()]
164+
165+
def ancestors(self, as_objects=False):
166+
"""
167+
:param as_objects: False - a list of table names; True - a list of table objects.
168+
:return: list of tables ancestors in topological order.
169+
"""
170+
return [FreeTable(self.connection, node) if as_objects else node
171+
for node in self.connection.dependencies.ancestors(self.full_table_name) if not node.isdigit()]
151172

152173
def parts(self, as_objects=False):
153174
"""
@@ -156,13 +177,7 @@ def parts(self, as_objects=False):
156177
"""
157178
nodes = [node for node in self.connection.dependencies.nodes
158179
if not node.isdigit() and node.startswith(self.full_table_name[:-1] + '__')]
159-
if as_objects:
160-
nodes = [FreeTable(self.connection, c) for c in nodes]
161-
return nodes
162-
163-
def ancestors(self, as_objects=False):
164-
return [FreeTable(self.connection, node) if as_objects else node
165-
for node in self.connection.dependencies.ancestors(self.full_table_name) if not node.isdigit()]
180+
return [FreeTable(self.connection, c) for c in nodes] if as_objects else nodes
166181

167182
@property
168183
def is_declared(self):
@@ -525,7 +540,7 @@ def describe(self, context=None, printout=True):
525540
del frame
526541
if self.full_table_name not in self.connection.dependencies:
527542
self.connection.dependencies.load()
528-
parents = self.parents()
543+
parents = self.parents(foreign_key_info=True)
529544
in_key = True
530545
definition = ('# ' + self.heading.table_info['comment'] + '\n'
531546
if self.heading.table_info['comment'] else '')
@@ -538,11 +553,10 @@ def describe(self, context=None, printout=True):
538553
in_key = False
539554
attributes_thus_far.add(attr.name)
540555
do_include = True
541-
for parent_name, fk_props in list(parents.items()): # need list() to force a copy
556+
for parent_name, fk_props in parents:
542557
if attr.name in fk_props['attr_map']:
543558
do_include = False
544559
if attributes_thus_far.issuperset(fk_props['attr_map']):
545-
parents.pop(parent_name)
546560
# foreign key properties
547561
try:
548562
index_props = indexes.pop(tuple(fk_props['attr_map']))
@@ -552,19 +566,19 @@ def describe(self, context=None, printout=True):
552566
index_props = [k for k, v in index_props.items() if v]
553567
index_props = ' [{}]'.format(', '.join(index_props)) if index_props else ''
554568

555-
if not parent_name.isdigit():
569+
if not fk_props['aliased']:
556570
# simple foreign key
557571
definition += '->{props} {class_name}\n'.format(
558572
props=index_props,
559573
class_name=lookup_class_name(parent_name, context) or parent_name)
560574
else:
561575
# projected foreign key
562-
parent_name = list(self.connection.dependencies.in_edges(parent_name))[0][0]
563-
lst = [(attr, ref) for attr, ref in fk_props['attr_map'].items() if ref != attr]
564576
definition += '->{props} {class_name}.proj({proj_list})\n'.format(
565577
props=index_props,
566578
class_name=lookup_class_name(parent_name, context) or parent_name,
567-
proj_list=','.join('{}="{}"'.format(a, b) for a, b in lst))
579+
proj_list=','.join(
580+
'{}="{}"'.format(attr, ref)
581+
for attr, ref in fk_props['attr_map'].items() if ref != attr))
568582
attributes_declared.update(fk_props['attr_map'])
569583
if do_include:
570584
attributes_declared.add(attr.name)

tests/test_fetch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def test_fetch1_step3(self):
201201
self.lang.fetch1('name')
202202

203203
def test_decimal(self):
204-
"""Tests that decimal fields are correctly fetched and used in restrictions, see issue #334"""
204+
""" Tests that decimal fields are correctly fetched and used in restrictions, see issue #334"""
205205
rel = schema.DecimalPrimaryKey()
206206
rel.insert1([decimal.Decimal('3.1415926')])
207207
keys = rel.fetch()

tests/test_foreign_keys.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from nose.tools import assert_equal, assert_false, assert_true, raises
1+
from nose.tools import assert_equal, assert_false, assert_true
22
from datajoint.declare import declare
33

44
from . import schema_advanced

0 commit comments

Comments
 (0)