Skip to content

Commit 3d6b29d

Browse files
fix: return types not always returning the correct type
1 parent aa3b4de commit 3d6b29d

File tree

2 files changed

+41
-29
lines changed

2 files changed

+41
-29
lines changed

tests/mocks.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
"HashableMixin",
5858
"CustomMockMixin",
5959
"ColourMixin",
60-
"MockAsyncWebhook",
6160
"MockAttachment",
6261
"MockBot",
6362
"MockCategoryChannel",
@@ -75,6 +74,7 @@
7574
"MockThread",
7675
"MockUser",
7776
"MockVoiceChannel",
77+
"MockWebhook",
7878
]
7979

8080

@@ -158,7 +158,7 @@ def generate_mock_message(content=unittest.mock.DEFAULT, *args, **kwargs):
158158
discord.Thread: lambda *args, **kwargs: MockThread(),
159159
discord.User: lambda *args, **kwargs: MockUser(),
160160
discord.VoiceChannel: lambda *args, **kwargs: MockVoiceChannel(),
161-
discord.Webhook: lambda *args, **kwargs: MockAsyncWebhook(),
161+
discord.Webhook: lambda *args, **kwargs: MockWebhook(),
162162
}
163163

164164

@@ -222,10 +222,8 @@ def __init__(self, **kwargs):
222222
# this list can be added to as methods are discovered
223223
if attr.__name__ == "send":
224224
hints["return"] = discord.Message
225-
elif self.__class__ == discord.Message and attr.__name__ == "edit":
226-
# set up message editing to return the same object
227-
mock_config[f"{attr.__name__}.return_value"] = self
228-
continue
225+
elif attr.__name__ == "edit":
226+
hints["return"] = type(self.spec_set)
229227

230228
if hints.get("return") is None:
231229
continue
@@ -822,7 +820,7 @@ def __init__(self, **kwargs) -> None:
822820
webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), session=unittest.mock.MagicMock())
823821

824822

825-
class MockAsyncWebhook(CustomMockMixin, unittest.mock.NonCallableMagicMock):
823+
class MockWebhook(CustomMockMixin, unittest.mock.NonCallableMagicMock):
826824
"""
827825
A MagicMock subclass to mock Webhook objects using an AsyncWebhookAdapter.
828826

tests/test_mocks.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -400,42 +400,56 @@ class TestReturnTypes:
400400
Eg, ctx.send should return a message object.
401401
"""
402402

403+
@pytest.mark.parametrize(
404+
"mock_cls",
405+
[
406+
mocks.MockClientUser,
407+
mocks.MockGuild,
408+
mocks.MockMember,
409+
mocks.MockMessage,
410+
mocks.MockTextChannel,
411+
mocks.MockVoiceChannel,
412+
mocks.MockWebhook,
413+
],
414+
)
403415
@pytest.mark.asyncio
404-
async def test_message_edit_returns_self(self):
405-
"""Message editing edits the message in place. We should be returning the message."""
406-
msg = mocks.MockMessage()
407-
408-
new_msg = await msg.edit()
416+
async def test_edit_returns_same_class(self, mock_cls):
417+
"""Edit methods return a new instance of the same type."""
418+
mock = mock_cls()
409419

410-
assert isinstance(new_msg, discord.Message)
420+
new_mock = await mock.edit()
411421

412-
assert msg is new_msg
422+
assert isinstance(new_mock, type(mock_cls.spec_set))
413423

424+
@pytest.mark.parametrize(
425+
"mock_cls",
426+
[
427+
mocks.MockMember,
428+
mocks.MockTextChannel,
429+
mocks.MockThread,
430+
mocks.MockUser,
431+
],
432+
)
414433
@pytest.mark.asyncio
415-
async def test_channel_send_returns_message(self):
434+
async def test_messageable_send_returns_message(self, mock_cls):
416435
"""Ensure that channel objects return mocked messages when sending messages."""
417-
channel = mocks.MockTextChannel()
436+
messageable = mock_cls()
418437

419-
msg = await channel.send("hi")
438+
msg = await messageable.send("hi")
420439

421440
print(type(msg))
422441
assert isinstance(msg, discord.Message)
423442

443+
@pytest.mark.parametrize(
444+
"mock_cls",
445+
[mocks.MockMessage, mocks.MockTextChannel],
446+
)
424447
@pytest.mark.asyncio
425-
async def test_message_thread_create_returns_thread(self):
426-
"""Thread create methods should return a MockThread."""
427-
msg = mocks.MockMessage()
428-
429-
thread = await msg.create_thread()
430-
431-
assert isinstance(thread, discord.Thread)
432-
433-
@pytest.mark.asyncio
434-
async def test_channel_thread_create_returns_thread(self):
448+
async def test_thread_create_returns_thread(self, mock_cls):
435449
"""Thread create methods should return a MockThread."""
436-
channel = mocks.MockTextChannel()
450+
mock = mock_cls()
437451

438-
thread = await channel.create_thread()
452+
thread = await mock.create_thread()
439453

440454
assert isinstance(thread, discord.Thread)
441455

0 commit comments

Comments
 (0)