Skip to content

Commit c9848fc

Browse files
authored
SUBMIT-756 Invitation creation refactoring - avoid account project creation upfront (#778)
1 parent 07dce7a commit c9848fc

7 files changed

Lines changed: 102 additions & 18 deletions

File tree

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""unique constraint account project
2+
3+
Revision ID: 2656a9b67883
4+
Revises: e9ede8ab9185
5+
Create Date: 2026-02-02 09:38:20.712941
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
from sqlalchemy.dialects import postgresql
11+
12+
# revision identifiers, used by Alembic.
13+
revision = '2656a9b67883'
14+
down_revision = 'e9ede8ab9185'
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade():
20+
# ### commands auto generated by Alembic - please adjust! ###
21+
# Rename existing submissiontypestatus to submissiontype (used by type column)
22+
op.execute("ALTER TYPE submissiontypestatus RENAME TO submissiontype")
23+
24+
# Create new submissiontypestatus enum for status column
25+
op.execute("CREATE TYPE submissiontypestatus AS ENUM ('SUBMITTED', 'REJECTED', 'APPROVED', 'PENDING', 'PENDING_REPLACEMENT')")
26+
27+
with op.batch_alter_table('submissions', schema=None) as batch_op:
28+
# Change status column from submissionstatus to submissiontypestatus
29+
batch_op.alter_column('status',
30+
existing_type=postgresql.ENUM('SUBMITTED', 'REJECTED', 'APPROVED', 'PENDING', 'PENDING_REPLACEMENT', name='submissionstatus'),
31+
type_=sa.Enum('SUBMITTED', 'REJECTED', 'APPROVED', 'PENDING', 'PENDING_REPLACEMENT', name='submissiontypestatus'),
32+
existing_nullable=True,
33+
postgresql_using='status::text::submissiontypestatus')
34+
35+
with op.batch_alter_table('account_projects', schema=None) as batch_op:
36+
batch_op.create_unique_constraint('uq_account_project', ['account_id', 'project_id'])
37+
38+
# Drop old unused enum type
39+
op.execute("DROP TYPE IF EXISTS submissionstatus")
40+
41+
# ### end Alembic commands ###
42+
43+
44+
def downgrade():
45+
# ### commands auto generated by Alembic - please adjust! ###
46+
# Recreate the old enum type
47+
op.execute("CREATE TYPE submissionstatus AS ENUM ('SUBMITTED', 'REJECTED', 'APPROVED', 'PENDING', 'PENDING_REPLACEMENT')")
48+
49+
with op.batch_alter_table('submissions', schema=None) as batch_op:
50+
# Revert status column from submissiontypestatus back to submissionstatus
51+
batch_op.alter_column('status',
52+
existing_type=sa.Enum('SUBMITTED', 'REJECTED', 'APPROVED', 'PENDING', name='submissiontypestatus'),
53+
type_=postgresql.ENUM('SUBMITTED', 'REJECTED', 'APPROVED', 'PENDING', 'PENDING_REPLACEMENT', name='submissionstatus'),
54+
existing_nullable=True,
55+
postgresql_using='status::text::submissionstatus')
56+
57+
with op.batch_alter_table('account_projects', schema=None) as batch_op:
58+
batch_op.drop_constraint('uq_account_project', type_='unique')
59+
60+
# Drop the enum type created in upgrade and rename submissiontype back
61+
op.execute("DROP TYPE IF EXISTS submissiontypestatus")
62+
op.execute("ALTER TYPE submissiontype RENAME TO submissiontypestatus")
63+
64+
# ### end Alembic commands ###

submit-api/src/submit_api/models/account_project.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55
from __future__ import annotations
66

7-
from sqlalchemy import Column, ForeignKey
7+
from sqlalchemy import Column, ForeignKey, UniqueConstraint
88

99
from .base_model import BaseModel
1010
from .db import db
@@ -14,6 +14,9 @@ class AccountProject(BaseModel):
1414
"""Definition of the Account Project entity."""
1515

1616
__tablename__ = 'account_projects'
17+
__table_args__ = (
18+
UniqueConstraint('account_id', 'project_id', name='uq_account_project'),
19+
)
1720

1821
id = Column(db.Integer, primary_key=True, autoincrement=True)
1922
account_id = Column(db.Integer, ForeignKey('accounts.id', ondelete='CASCADE'), nullable=False)

submit-api/src/submit_api/models/submission.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class Submission(BaseModel):
3737
id = Column(db.Integer, primary_key=True, autoincrement=True)
3838
submitted_form_id = Column(db.Integer, ForeignKey('submitted_forms.id'), nullable=True)
3939
item_id = Column(db.Integer, ForeignKey('items.id', ondelete='CASCADE'), nullable=False)
40-
type = Column(Enum(SubmissionType), nullable=False)
40+
type = Column(Enum(SubmissionType, name='submissiontype'), nullable=False)
4141
submitted_document_id = Column(db.Integer, ForeignKey('submitted_documents.id'), nullable=True)
4242
submitted_form = db.relationship('SubmittedForm', foreign_keys=[submitted_form_id], lazy='joined')
4343
submitted_document = db.relationship('SubmittedDocument', foreign_keys=[submitted_document_id], lazy='joined')

submit-api/src/submit_api/resources/migration_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def post(): # pylint: disable=too-many-locals,too-many-statements
6565
# Query all foreign key constraints that reference users.auth_guid
6666
current_app.logger.info("Querying foreign key constraints on users.auth_guid...")
6767
fk_query = text("""
68-
SELECT
69-
tc.table_name,
68+
SELECT
69+
tc.table_name,
7070
tc.constraint_name,
7171
kcu.column_name
7272
FROM information_schema.table_constraints AS tc

submit-api/src/submit_api/services/invitation_service.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,12 @@ def create_invitation(invite_data):
139139
account_id = invite_data.get('account_id')
140140
role_name = invite_data.get('role_name')
141141
proponent_id = invite_data.get('proponent_id')
142-
project_ids = invite_data.get('project_ids')
143142

144143
role = InvitationService._validate_fetch_role(role_name)
145144

146145
with session_scope() as session:
147146
account = InvitationService._get_or_create_account(
148-
account_id, proponent_id, project_ids, session)
147+
account_id, proponent_id, session)
149148
session.flush()
150149

151150
token = InvitationService.generate_uuid_token()
@@ -200,6 +199,10 @@ def accept_invitation(token, payload):
200199
account_user = InvitationService._create_account_user(
201200
user.id, invitation.account_id, payload, session)
202201

202+
# Create account projects if they don't exist (handles concurrent invitations gracefully)
203+
InvitationService._create_account_projects(
204+
invitation.account_id, invitation.project_ids, session)
205+
203206
account_projects = AccountProjectModel.get_all_in_project_ids(invitation.project_ids)
204207
roles = []
205208
for account_project in account_projects:
@@ -260,13 +263,13 @@ def _validate_invitation_access(invitation):
260263
return True
261264

262265
@staticmethod
263-
def _get_or_create_account(account_id, proponent_id, project_ids, session):
266+
def _get_or_create_account(account_id, proponent_id, session):
264267
"""Retrieve or create an account based on proponent_id or account_id."""
265268
if account_id:
266269
return InvitationService._get_account_by_id(account_id)
267270

268271
if proponent_id:
269-
return InvitationService._get_or_create_account_by_proponent(proponent_id, project_ids, session)
272+
return InvitationService._get_or_create_account_by_proponent(proponent_id, session)
270273

271274
raise ResourceNotFoundError("No valid account found for the provided data.")
272275

@@ -281,14 +284,12 @@ def _get_proponent_by_id(proponent_id):
281284
return ProponentModel.find_by_id(proponent_id)
282285

283286
@staticmethod
284-
def _get_or_create_account_by_proponent(proponent_id, project_ids, session):
287+
def _get_or_create_account_by_proponent(proponent_id, session):
285288
"""Retrieve or create an account by proponent_id."""
286289
account = AccountModel.get_by_proponent_id(proponent_id)
287290
if not account:
288291
account_data = {'proponent_id': proponent_id}
289292
account = AccountModel.create_account(account_data, session)
290-
InvitationService._create_account_projects(
291-
account.id, project_ids, session)
292293
return account
293294

294295
@staticmethod
@@ -313,7 +314,7 @@ def _create_invitation_record(invite_data, role, account, token, session):
313314
role_id=role.id,
314315
package_ids=invite_data.get('package_ids'),
315316
original_package_ids=invite_data.get('original_package_ids'),
316-
expiry_date=datetime.datetime.utcnow() + datetime.timedelta(days=expiry_days),
317+
expiry_date=datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=expiry_days),
317318
is_first_time=is_first_time
318319
)
319320
session.add(invitation)
@@ -396,7 +397,12 @@ def get_valid_invitation(token):
396397
if invitation.status != InvitationStatus.PENDING.value:
397398
return {"error": "Invitation is not valid"}, False
398399

399-
if invitation.expiry_date < datetime.datetime.utcnow():
400+
# Ensure expiry_date is timezone-aware for comparison
401+
expiry_date = invitation.expiry_date
402+
if expiry_date.tzinfo is None:
403+
expiry_date = expiry_date.replace(tzinfo=datetime.timezone.utc)
404+
405+
if expiry_date < datetime.datetime.now(datetime.timezone.utc):
400406
return {"error": "Invitation has expired"}, False
401407

402408
# Update proponent status
@@ -428,7 +434,7 @@ def resend_invitation(token):
428434
InvitationService._check_action_authorized(invitation.project_ids)
429435

430436
# Extend expiry date by 1 week from current date
431-
invitation.expiry_date = datetime.datetime.utcnow() + datetime.timedelta(weeks=1)
437+
invitation.expiry_date = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(weeks=1)
432438

433439
InvitationService._create_email_queue_record(
434440
invitation.id, session)

submit-api/tests/unit/resources/test_invitation.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Tests for invitation resource endpoints.
44
"""
55

6-
from datetime import datetime, timedelta, UTC
6+
from datetime import datetime, timedelta, timezone
77
from http import HTTPStatus
88
from unittest.mock import patch
99

@@ -124,7 +124,7 @@ def test_get_invitation_expired(client, session, jwt):
124124
_, account_project = setup_authenticated_proponent(session, jwt)
125125
invitation = factory_invitation_model(
126126
account_id=account_project.account_id,
127-
expiry_date=datetime.now(UTC) - timedelta(days=1), # Expired
127+
expiry_date=datetime.now(timezone.utc) - timedelta(days=1), # Expired
128128
)
129129

130130
response = client.get(f"/api/invitations/{invitation.token}")
@@ -161,6 +161,7 @@ def test_accept_invitation(client, session, jwt):
161161
_, account_project = setup_authenticated_proponent(session, jwt)
162162
invitation = factory_invitation_model(
163163
account_id=account_project.account_id,
164+
project_ids=[account_project.project_id]
164165
)
165166

166167
payload = {
@@ -182,6 +183,12 @@ def test_accept_invitation(client, session, jwt):
182183
print(data)
183184
assert "user_id" in data
184185

186+
# Verify that account_projects were created during acceptance
187+
from submit_api.models import AccountProject as AccountProjectModel
188+
account_projects = AccountProjectModel.get_all_in_project_ids(invitation.project_ids)
189+
assert len(account_projects) > 0
190+
assert any(ap.account_id == invitation.account_id for ap in account_projects)
191+
185192

186193
def test_accept_invitation_invalid_token(client, session):
187194
"""Test accepting invitation with invalid token."""

submit-api/tests/utilities/factory_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
1717
Test Utility for creating model factory.
1818
"""
19-
from datetime import datetime, timedelta, UTC
19+
from datetime import datetime, timedelta, timezone
2020
import random
2121
import string
2222

@@ -190,11 +190,15 @@ def factory_invitation_model(
190190
token=fake.uuid4(),
191191
email=fake.email(),
192192
status=InvitationStatus.PENDING.value,
193-
expiry_date=datetime.now(UTC) + timedelta(days=7),
193+
expiry_date=datetime.now(timezone.utc) + timedelta(days=7),
194194
role_id=None,
195195
is_first_time=False
196196
):
197197
"""Create and return a mock invitation entry."""
198+
if role_id is None:
199+
role = Role.get_by_name(RoleEnum.PROJECT_ADMIN.value)
200+
role_id = role.id if role else None
201+
198202
invitation = Invitations(
199203
account_id=account_id,
200204
project_ids=project_ids,

0 commit comments

Comments
 (0)