Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 20 additions & 61 deletions app/ldap_protocol/ldap_requests/modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,11 @@
from ldap_protocol.policies.password import PasswordPolicyUseCases
from ldap_protocol.session_storage import SessionStorage
from ldap_protocol.utils.cte import check_root_group_membership_intersection
from ldap_protocol.utils.helpers import (
create_user_name,
ft_to_dt,
is_dn_in_base_directory,
validate_entry,
)
from ldap_protocol.utils.helpers import ft_to_dt, validate_entry
from ldap_protocol.utils.queries import (
add_lock_and_expire_attributes,
clear_group_membership,
extend_group_membership,
get_base_directories,
get_directories,
get_directory_by_rid,
get_filter_from_path,
Expand Down Expand Up @@ -275,9 +269,7 @@ async def handle(
await self._add(*add_args)

await ctx.session.flush()
await ctx.session.execute(
update(Directory).filter_by(id=directory.id),
)

except MODIFY_EXCEPTION_STACK as err:
await ctx.session.rollback()
result_code, message = self._match_bad_response(err)
Expand Down Expand Up @@ -857,17 +849,12 @@ async def _add( # noqa: C901

await session.execute(
delete(Attribute)
.filter_by(
name="nsAccountLock",
directory=directory,
),
) # fmt: skip

await session.execute(
delete(Attribute)
.filter_by(
name="shadowExpire",
directory=directory,
.where(
or_(
qa(Attribute.name) == "nsAccountLock",
qa(Attribute.name) == "shadowExpire",
),
qa(Attribute.directory) == directory,
),
) # fmt: skip

Expand All @@ -891,47 +878,19 @@ async def _add( # noqa: C901
)

elif name in User.search_fields:
if not directory.user:
path_dn = directory.path_dn
for base_directory in await get_base_directories(session):
if is_dn_in_base_directory(base_directory, path_dn):
base_dn = base_directory
break

sam_account_name = create_user_name(directory.id)
user_principal_name = f"{sam_account_name}@{base_dn.name}"
user = User(
sam_account_name=sam_account_name,
user_principal_name=user_principal_name,
directory_id=directory.id,
)
uac_attr = Attribute(
name="userAccountControl",
value=str(UserAccountControlFlag.NORMAL_ACCOUNT),
directory_id=directory.id,
)

session.add_all([user, uac_attr])
await session.flush()
await session.refresh(directory)

if name == "accountexpires":
new_value = ft_to_dt(int(value)) if value != "0" else None
else:
new_value = value # type: ignore

await session.execute(
update(User)
.filter_by(directory=directory)
.values({name: new_value}),
)
if directory.user:
if name == "accountexpires":
new_value = (
ft_to_dt(int(value)) if value != "0" else None
)
else:
new_value = value # type: ignore

elif name in Group.search_fields and directory.group:
await session.execute(
update(Group)
.filter_by(directory=directory)
.values({name: value}),
)
await session.execute(
update(User)
.filter_by(directory=directory)
.values({name: new_value}),
)

elif name in ("userpassword", "unicodepwd") and directory.user:
if not settings.USE_CORE_TLS:
Expand Down
2 changes: 1 addition & 1 deletion interface
Loading