Skip to content

Commit d8576b3

Browse files
authored
Fix create_user and usages to be more robust (#3330)
1 parent 052721f commit d8576b3

12 files changed

Lines changed: 148 additions & 50 deletions

File tree

b2b/serializers/v0/serializers_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010

1111
@pytest.mark.django_db
12+
@pytest.mark.disable_responses
1213
def test_contract_page_serializer_expands_embeds():
1314
"""Test that welcome_message_extra expands embed tags and preserves HTML."""
1415
contract = ContractPageFactory.create(
@@ -26,4 +27,6 @@ def test_contract_page_serializer_expands_embeds():
2627
r"</iframe>\s*"
2728
r"</div>"
2829
)
29-
assert re.match(pattern, result, re.DOTALL)
30+
assert re.match(pattern, result, re.DOTALL), (
31+
f"Content did not match expectation: {result}"
32+
)

courses/api.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
from main import features
7070
from openedx.api import (
7171
create_edx_course_mode,
72+
create_user,
7273
enroll_in_edx_course_runs,
7374
get_edx_api_course_detail_client,
7475
get_edx_api_course_list_client,
@@ -133,12 +134,33 @@ def get_user_relevant_program_course_run_qset(
133134
return enrollable_run_qset.order_by("enrollment_start")
134135

135136

136-
def create_run_enrollments( # noqa: C901
137+
def _enroll_learner_into_associated_programs(run, user):
138+
"""
139+
Enrolls the learner into all programs for which the course they are enrolling into
140+
is associated as a requirement or elective. If a program enrollment already exists
141+
then the change_status of that program_enrollment is checked to ensure it equals None.
142+
"""
143+
for program in run.course.programs:
144+
if not program.live:
145+
continue
146+
program_enrollment, _ = ProgramEnrollment.objects.get_or_create(
147+
user=user,
148+
program=program,
149+
defaults=dict( # noqa: C408
150+
change_status=None,
151+
),
152+
)
153+
if program_enrollment.change_status is not None:
154+
program_enrollment.reactivate_and_save()
155+
156+
157+
def create_run_enrollments( # noqa: C901,PLR0913
137158
user,
138159
runs,
139160
*,
140161
change_status=None,
141162
keep_failed_enrollments=None,
163+
create_courseware_user=False,
142164
mode=EDX_DEFAULT_ENROLLMENT_MODE,
143165
):
144166
"""
@@ -170,30 +192,14 @@ def create_run_enrollments( # noqa: C901
170192
features.IGNORE_EDX_FAILURES, False
171193
)
172194

173-
successful_enrollments = []
195+
if create_courseware_user:
196+
create_user(user)
197+
user.refresh_from_db()
174198

175-
def send_enrollment_emails():
199+
def _subscribe_to_edx_course_emails():
176200
subscribe_edx_course_emails.delay(enrollment.id)
177201

178-
def _enroll_learner_into_associated_programs():
179-
"""
180-
Enrolls the learner into all programs for which the course they are enrolling into
181-
is associated as a requirement or elective. If a program enrollment already exists
182-
then the change_status of that program_enrollment is checked to ensure it equals None.
183-
"""
184-
for program in run.course.programs:
185-
if not program.live:
186-
continue
187-
program_enrollment, _ = ProgramEnrollment.objects.get_or_create(
188-
user=user,
189-
program=program,
190-
defaults=dict( # noqa: C408
191-
change_status=None,
192-
),
193-
)
194-
if program_enrollment.change_status is not None:
195-
program_enrollment.reactivate_and_save()
196-
202+
successful_enrollments = []
197203
edx_request_success = True
198204
if not runs[0].is_fake_course_run:
199205
# Make the API call to enroll the user in edX only if the run is not a fake course run
@@ -233,7 +239,7 @@ def _enroll_learner_into_associated_programs():
233239
),
234240
)
235241

236-
_enroll_learner_into_associated_programs()
242+
_enroll_learner_into_associated_programs(run, user)
237243

238244
# If the run is associated with a B2B contract, add the contract
239245
# to the user's contract list and update their org memberships
@@ -267,7 +273,7 @@ def _enroll_learner_into_associated_programs():
267273
if enrollment_mode_changed:
268274
enrollment.enrollment_mode = mode
269275
enrollment.reactivate_and_save()
270-
transaction.on_commit(send_enrollment_emails)
276+
transaction.on_commit(_subscribe_to_edx_course_emails)
271277
except: # pylint: disable=bare-except # noqa: PERF203, E722
272278
mail_api.send_enrollment_failure_message(user, run, details=format_exc())
273279
log.exception(

courses/views/v1/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@
7777
from main.utils import encode_json_cookie_value, redirect_with_user_message
7878
from openapi.utils import extend_schema_get_queryset
7979
from openedx.api import (
80-
create_user,
8180
subscribe_to_edx_course_emails,
8281
sync_enrollments_with_edx,
8382
unsubscribe_from_edx_course_emails,
@@ -345,13 +344,10 @@ def post(self, request):
345344
if resp is not None:
346345
return resp
347346

348-
if not user.openedx_user_exists:
349-
create_user(user)
350-
user.refresh_from_db()
351-
352347
_, edx_request_success = create_run_enrollments(
353348
user=user,
354349
runs=[run],
350+
create_courseware_user=True,
355351
keep_failed_enrollments=settings.FEATURES.get(
356352
features.IGNORE_EDX_FAILURES, False
357353
),

courses/views/v1/views_test.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -615,10 +615,7 @@ def test_user_enrollment_delete_other_fail(mocker, settings, user_drf_client, us
615615

616616
@pytest.mark.parametrize("api_request", [True, False])
617617
@pytest.mark.parametrize("product_exists", [True, False])
618-
@pytest.mark.parametrize("has_openedx_user", [True, False])
619-
def test_create_enrollments(
620-
mocker, user, api_request, product_exists, has_openedx_user
621-
):
618+
def test_create_enrollments(mocker, user, api_request, product_exists):
622619
"""
623620
Create enrollment view should create an enrollment and include a user message in the response cookies.
624621
Unless api_request is set to True, in which case we should get a string back.
@@ -627,17 +624,11 @@ def test_create_enrollments(
627624
"courses.views.v1.create_run_enrollments",
628625
return_value=(None, True),
629626
)
630-
mock_fulfilled_order_filter = mocker.patch( # noqa: F841
631-
"ecommerce.models.FulfilledOrder.objects.filter", return_value=None
632-
)
633-
mock_create_user = mocker.patch("courses.views.v1.create_user")
627+
mocker.patch("ecommerce.models.FulfilledOrder.objects.filter", return_value=None)
634628
run = CourseRunFactory.create()
635629
if product_exists:
636630
with reversion.create_revision():
637-
product = ProductFactory.create(purchasable_object=run) # noqa: F841
638-
639-
if not has_openedx_user:
640-
user.openedx_users.all().delete()
631+
ProductFactory.create(purchasable_object=run)
641632

642633
client = Client()
643634
client.force_login(user)
@@ -665,7 +656,6 @@ def test_create_enrollments(
665656
}
666657
)
667658
patched_create_enrollments.assert_called_once()
668-
mock_create_user.assert_called_once() if not has_openedx_user else mock_create_user.assert_not_called()
669659

670660

671661
def test_create_enrollments_failed(mocker, settings, user_client):

fixtures/common.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,11 @@ def admin_drf_client(admin_user):
7070

7171

7272
@pytest.fixture
73-
def mocked_responses():
73+
def mocked_responses(request):
7474
"""Mocked responses for requests library"""
75-
with responses.RequestsMock() as rsps:
76-
yield rsps
75+
if not request.node.get_closest_marker("disable_responses"):
76+
with responses.RequestsMock() as rsps:
77+
yield rsps
7778

7879

7980
@pytest.fixture

main/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,7 @@ def exception_handler(exc, context):
1212
detail={"errors": exc.detail}, code=exc.status_code
1313
)
1414
return views.exception_handler(exc, context)
15+
16+
17+
class UnexpectedTransactionAtomicError(Exception):
18+
"""We attempted to run a function marked invalid for transactions while in a transaction"""

main/utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import contextlib
56
import json
67
from datetime import date, datetime
78
from typing import TYPE_CHECKING, Set, Tuple, TypeVar, Union # noqa: UP035
@@ -13,11 +14,13 @@
1314
from django.conf import settings
1415
from django.core.cache import caches
1516
from django.core.serializers import serialize
17+
from django.db import connection
1618
from django.http import HttpRequest, HttpResponseRedirect
1719
from mitol.common.utils.urls import remove_password_from_url
1820
from rest_framework import status
1921

2022
from main.constants import USER_MSG_COOKIE_MAX_AGE, USER_MSG_COOKIE_NAME
23+
from main.exceptions import UnexpectedTransactionAtomicError
2124
from main.settings import TIME_ZONE
2225

2326
if TYPE_CHECKING:
@@ -212,3 +215,34 @@ def get_redis_lock(name, **kwargs):
212215
redis_cache = caches["redis"]
213216
client = redis_cache.client.get_client()
214217
return redis_lock.Lock(client, name, **kwargs)
218+
219+
220+
def is_in_transaction() -> bool:
221+
"""Return True if we're in a transaction"""
222+
if settings.ENVIRONMENT == "pytest":
223+
# if we're running under pytest, tests that use django_db get wrapped in a
224+
# transaction.atomic(), so the logic for whether we're in a transaction inside
225+
# these tests is actually dependent on there being more than that one wrapper atomic block
226+
return len(connection.atomic_blocks) > 1
227+
228+
return connection.in_atomic_block
229+
230+
231+
class raise_on_transaction_atomic(contextlib.ContextDecorator): # noqa: N801
232+
"""
233+
Context manager / decorator that will raise the passed exception
234+
if this wrapped function or context executes while a
235+
transaction.atomic() is active.
236+
237+
If an exception type isn't passed UnexpectedTransactionAtomicError is used.
238+
"""
239+
240+
def __init__(self, exc=UnexpectedTransactionAtomicError):
241+
self.exc = exc
242+
243+
def __enter__(self):
244+
if is_in_transaction():
245+
raise self.exc
246+
247+
def __exit__(self, *_exc):
248+
pass

main/utils_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,20 @@
44
from zoneinfo import ZoneInfo
55

66
import pytest
7+
from django.db import transaction
78
from mitol.common.utils.urls import remove_password_from_url
89

10+
from main.exceptions import UnexpectedTransactionAtomicError
911
from main.models import AuditModel
1012
from main.settings import TIME_ZONE
1113
from main.utils import (
1214
date_to_datetime,
1315
get_field_names,
1416
get_js_settings,
1517
get_partitioned_set_difference,
18+
is_in_transaction,
1619
parse_supplied_date,
20+
raise_on_transaction_atomic,
1721
)
1822

1923

@@ -97,3 +101,32 @@ def test_date_to_datetime():
97101

98102
with pytest.raises(AttributeError):
99103
date_to_datetime("this date isn't a date at all", TIME_ZONE)
104+
105+
106+
@pytest.mark.django_db
107+
def test_is_in_transaction(settings):
108+
"""Test that is_in_transaction functions correctly"""
109+
assert is_in_transaction() is False
110+
111+
with transaction.atomic():
112+
assert is_in_transaction() is True
113+
114+
settings.ENVIRONMENT = "dev"
115+
116+
assert is_in_transaction() is True
117+
118+
119+
@pytest.mark.django_db
120+
def test_raise_on_transaction_atomic():
121+
"""Test that raise_on_transaction_atomic correctly raises exceptions when in a transaction"""
122+
123+
with raise_on_transaction_atomic():
124+
pass
125+
126+
with pytest.raises(UnexpectedTransactionAtomicError), transaction.atomic(): # noqa: SIM117
127+
with raise_on_transaction_atomic():
128+
pass
129+
130+
with pytest.raises(AttributeError), transaction.atomic(): # noqa: SIM117
131+
with raise_on_transaction_atomic(AttributeError):
132+
pass

openedx/api.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@
3030
from authentication import api as auth_api
3131
from courses.constants import ENROLL_CHANGE_STATUS_UNENROLLED
3232
from main import features
33-
from main.utils import get_partitioned_set_difference, get_redis_lock
33+
from main.utils import (
34+
get_partitioned_set_difference,
35+
get_redis_lock,
36+
raise_on_transaction_atomic,
37+
)
3438
from openedx.constants import (
3539
EDX_DEFAULT_ENROLLMENT_MODE,
3640
OPENEDX_REPAIR_GRACE_PERIOD_MINS,
@@ -47,6 +51,7 @@
4751
OpenEdxUserMissingError,
4852
UnknownEdxApiEmailSettingsException,
4953
UnknownEdxApiEnrollException,
54+
UserCreateInTransactionError,
5055
UserNameUpdateFailedException,
5156
)
5257
from openedx.models import OpenEdxApiAuth, OpenEdxUser
@@ -387,6 +392,7 @@ def _create_edx_user_request(open_edx_user, user, access_token): # noqa: C901,
387392
lock.release()
388393

389394

395+
@raise_on_transaction_atomic(UserCreateInTransactionError)
390396
def create_edx_user(user, edx_username=None):
391397
"""
392398
Makes a request to create an equivalent user in Open edX

0 commit comments

Comments
 (0)