Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "serverish"
version = "1.2.3"
version = "1.3.0"
description = "helpers for server alike projects"
authors = ["Mikołaj Kałuszyński", "MMME team"]
readme = "README.md"
Expand Down
97 changes: 97 additions & 0 deletions serverish/base/iterators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import asyncio
from typing import List, Dict, Any, Tuple, Set


class AsyncRangeIter:
"""This class represent async range iterator (instead of sync).

WARNING: Noticed that end is included (not like in standard iterator).
Example 1: start=1, end=4, will return [1, 2, 3, 4]
Example 2: async for n in AsyncRangeIter(start=1, end=3): ...
"""

def __init__(self, start: int, end: int) -> None:
self.start = start
self.end = end

def __aiter__(self) -> Any:
self.current = self.start
return self

async def __anext__(self) -> int:
if self.current <= self.end:
value = self.current
self.current += 1
await asyncio.sleep(0)
return value
else:
raise StopAsyncIteration


class AsyncListIter:
"""This class represent async list iterator (instead of sync).

Example of use: async for n in AsyncListIter(some_list): ...
"""
def __init__(self, iterable: List | Tuple | Set | Any):
self.iterable = iterable
self.index: int = 0

def __aiter__(self) -> Any:
self.index = 0
return self

async def __anext__(self):
if self.index < len(self.iterable):
value = self.iterable[self.index]
self.index += 1
await asyncio.sleep(0)
return value
else:
raise StopAsyncIteration


class AsyncEnumerateIter:
"""This class represent async enumerate iterator (instead of sync).

Example of use: async for current_index, value in AsyncEnumerateIter(some_iterable): ...
"""

def __init__(self, iterable: List | Tuple | Set | Any) -> None:
self.iterable = iterable

def __aiter__(self):
self.index = 0
return self

async def __anext__(self):
if self.index < len(self.iterable):
value = self.iterable[self.index]
current_index = self.index
self.index += 1
await asyncio.sleep(0)
return current_index, value
else:
raise StopAsyncIteration


class AsyncDictItemsIter:
"""This class represent async dict items iterator (instead of sync).

Example of use: async for key, value in AsyncEnumerateIter(some_dict): ...
"""

def __init__(self, data_dict: Dict) -> None:
self.data_dict = data_dict

def __aiter__(self) -> Any:
self.iterator = iter(self.data_dict.items())
return self

async def __anext__(self) -> Tuple:
try:
n, m = next(self.iterator)
await asyncio.sleep(0)
return n, m
except StopIteration:
raise StopAsyncIteration
46 changes: 46 additions & 0 deletions tests/test_iterators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import unittest

from serverish.base.iterators import AsyncRangeIter, AsyncListIter, AsyncDictItemsIter, AsyncEnumerateIter


class TestIterAsync(unittest.IsolatedAsyncioTestCase):

async def test_async_range_iter(self):
target_list = [1, 2, 3, 4, 5]
new_list = []
async for n in AsyncRangeIter(1, 5):
new_list.append(n)
# print(target_list)
# print(new_list)
self.assertListEqual(target_list, new_list)

async def test_async_list_iter(self):
target_list = [1, 2, 3, 4, 5]
new_list = []
async for n in AsyncListIter(target_list):
new_list.append(n)
# print(target_list)
# print(new_list)
self.assertListEqual(target_list, new_list)

async def test_async_dict_items_iter(self):
target_dict = {'a': 2, 'b': 55}
new_dict = {}
async for n, m in AsyncDictItemsIter(target_dict):
new_dict[n] = m
# print(target_dict)
# print(new_dict)
self.assertDictEqual(target_dict, new_dict)

async def test_async_enumerate_items_iter(self):
target_dict = {0: 1, 1: 2, 2: 3}
new_dict = {}
async for n, m in AsyncEnumerateIter([m for n, m in target_dict.items()]):
new_dict[n] = m
# print(target_dict)
# print(new_dict)
self.assertDictEqual(target_dict, new_dict)


if __name__ == '__main__':
unittest.main()