@@ -219,14 +219,33 @@ def create_invite_tokens(
219219 _check_election_is_not_ended (get_election (db , election_ref ))
220220 now = datetime .now ()
221221 params = {"date_created" : now , "date_modified" : now , "election_ref" : election_ref }
222- db_votes = [models .Vote (** params ) for _ in range (num_voters * num_candidates )]
223- db .bulk_save_objects (db_votes , return_defaults = True )
224- db .commit ()
222+
223+ try :
224+ db_ballots = [models .Ballot (election_ref = election_ref ) for _ in range (num_voters )]
225+ db .bulk_save_objects (db_ballots , return_defaults = True )
226+
227+ db_votes = []
228+
229+ for ballot in db_ballots :
230+ for _ in range (num_candidates ):
231+ db_votes .append (models .Vote (** params , ballot_id = ballot .id ))
232+
233+ db .bulk_save_objects (db_votes , return_defaults = True )
234+ db .commit ()
235+ except Exception as e :
236+ db .rollback ()
237+ raise e
238+
239+ tokens = []
225240 vote_ids = [int (str (v .id )) for v in db_votes ]
226- tokens = [
227- create_ballot_token (vote_ids [i ::num_voters ], election_ref )
228- for i in range (num_voters )
229- ]
241+
242+ for i , ballot in enumerate (db_ballots ):
243+ start = i * num_candidates
244+ end = start + num_candidates
245+ tokens .append (
246+ create_ballot_token (vote_ids [start :end ], election_ref , int (str (ballot .id )))
247+ )
248+
230249 return tokens
231250
232251
@@ -417,18 +436,29 @@ def create_ballot(db: Session, ballot: schemas.BallotCreate) -> schemas.BallotGe
417436 )
418437 _check_ballot_is_consistent (election , ballot )
419438
420- # Ideally, we would use RETURNING but it does not work yet for SQLite
421- db_votes = [
422- models .Vote (** v .model_dump (), election_ref = ballot .election_ref ) for v in ballot .votes
423- ]
424- db .add_all (db_votes )
425- db .commit ()
426- for v in db_votes :
427- db .refresh (v )
439+ try :
440+ db_ballot = models .Ballot (election_ref = ballot .election_ref )
441+ db .add (db_ballot )
442+ db .flush ()
443+
444+ # Create votes and associate them with the ballot
445+ db_votes = [
446+ models .Vote (** v .model_dump (), election_ref = ballot .election_ref , ballot_id = db_ballot .id )
447+ for v in ballot .votes
448+ ]
449+ db .add_all (db_votes )
450+ db .commit ()
451+ db .refresh (db_ballot )
452+
453+ for v in db_votes :
454+ db .refresh (v )
455+ except Exception as e :
456+ db .rollback ()
457+ raise e
428458
429459 votes_get = [schemas .VoteGet .model_validate (v ) for v in db_votes ]
430460 vote_ids = [v .id for v in votes_get ]
431- token = create_ballot_token (vote_ids , ballot .election_ref )
461+ token = create_ballot_token (vote_ids , ballot .election_ref , int ( db_ballot . id ) )
432462 return schemas .BallotGet (votes = votes_get , token = token , election = election )
433463
434464
@@ -523,6 +553,13 @@ def update_ballot(
523553 if len (db_votes ) != len (vote_ids ):
524554 raise errors .NotFoundError ("votes" )
525555
556+ # Verify all votes belong to the same ballot
557+ ballot_ids = {int (v .ballot_id ) for v in db_votes if v .ballot_id is not None }
558+
559+ if len (ballot_ids ) > 1 :
560+ raise errors .ForbiddenError ("All votes must belong to the same ballot" )
561+
562+ # old API does not contains ballot id in the token
526563 election = schemas .ElectionGet .model_validate (db_votes [0 ].election )
527564
528565 for vote , db_vote in zip (ballot .votes , db_votes ):
@@ -533,7 +570,6 @@ def update_ballot(
533570 db .commit ()
534571
535572 votes_get = [schemas .VoteGet .model_validate (v ) for v in db_votes ]
536- token = create_ballot_token (vote_ids , election_ref )
537573 return schemas .BallotGet (votes = votes_get , token = token , election = election )
538574
539575
0 commit comments