From 583188cd4237b0aa89914aa0d264d9e5ec22ec04 Mon Sep 17 00:00:00 2001 From: Stefano Tabacco Date: Mon, 28 Jul 2025 23:02:35 +1000 Subject: [PATCH 1/7] Extended product with additional functionalities --- stake/product.py | 26 +++++++++++++++++++++----- stake/ratings.py | 2 +- stake/transaction.py | 6 +++--- tests/test_integration.py | 13 +++++++++++++ 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/stake/product.py b/stake/product.py index 938df45..50acdaa 100644 --- a/stake/product.py +++ b/stake/product.py @@ -1,12 +1,17 @@ import uuid -from datetime import datetime -from typing import Any, List, Optional +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING, Any, List, Optional -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, PrivateAttr from pydantic.fields import Field from stake.common import BaseClient, camelcase +if TYPE_CHECKING: + from stake.ratings import Rating + from stake.statement import Statement + from stake.client import StakeClient + __all__ = ["ProductSearchByName"] @@ -54,8 +59,19 @@ class Product(BaseModel): inception_date: Optional[datetime] = None instrument_tags: List[Any] child_instruments: List[Instrument] + _client: "StakeClient" = PrivateAttr() model_config = ConfigDict(alias_generator=camelcase) + def model_post_init(self, context: Any) -> None: + self._client = context.get("client") + + async def ratings(self) -> "List[Rating]": + from stake import RatingsRequest + return await self._client.ratings.list(RatingsRequest(symbols=[self.symbol])) + + async def statements(self, start_date: date | None = None) -> "List[Statement]": + from stake.statement import StatementRequest + return await self._client.statements.list(StatementRequest(symbol=self.symbol, start_date=start_date or (date.today() - timedelta(days=365)))) class ProductsClient(BaseClient): async def get(self, symbol: str) -> Optional[Product]: @@ -68,13 +84,13 @@ async def get(self, symbol: str) -> Optional[Product]: self._client.exchange.symbol.format(symbol=symbol) ) - return Product(**data["products"][0]) if data["products"] else None + return Product.model_validate(data["products"][0], context=dict(client=self._client)) if data["products"] else None async def search(self, request: ProductSearchByName) -> List[Instrument]: products = await self._client.get( self._client.exchange.products_suggestions.format(keyword=request.keyword) ) - return [Instrument(**product) for product in products["instruments"]] + return [Instrument.model_validate(product) for product in products["instruments"]] async def product_from_instrument( self, instrument: Instrument diff --git a/stake/ratings.py b/stake/ratings.py index dc0e6a7..33a133d 100644 --- a/stake/ratings.py +++ b/stake/ratings.py @@ -36,7 +36,7 @@ class Rating(pydantic.BaseModel): url_news: Optional[str] = None analyst_name: Optional[str] = None - @pydantic.field_validator("pt_prior", "rating_prior", mode="before") + @pydantic.field_validator("pt_prior", "rating_prior", "pt_current", "rating_current", mode="before") @classmethod def pt_prior_blank_string(cls, value, *args) -> Optional[str]: return None if value == "" else value diff --git a/stake/transaction.py b/stake/transaction.py index 36b0e46..b69f5e6 100644 --- a/stake/transaction.py +++ b/stake/transaction.py @@ -1,6 +1,6 @@ import enum import json -from datetime import datetime, timedelta +from datetime import datetime, timedelta, UTC from enum import Enum from typing import Dict, List, Optional @@ -18,9 +18,9 @@ class TransactionRecordEnumDirection(str, Enum): class TransactionRecordRequest(BaseModel): - to: datetime = Field(default_factory=datetime.utcnow) + to: datetime = Field(default_factory=lambda *_: datetime.now(UTC)) from_: datetime = Field( - default_factory=lambda *_: datetime.utcnow() - timedelta(days=365), alias="from" + default_factory=lambda *_: datetime.now(UTC) - timedelta(days=365), alias="from" ) limit: int = 1000 offset: Optional[datetime] = None diff --git a/tests/test_integration.py b/tests/test_integration.py index 034c12b..53d5ec5 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -53,3 +53,16 @@ async def test_integration_ASX(exchange): ) result = await session.transactions.list(request=request) assert len(result.transactions) == 10 + + + +@pytest.mark.parametrize("exchange", (constant.NYSE,)) +@pytest.mark.asyncio +async def test_integration_product(exchange): + async with stake.StakeClient(exchange=exchange) as session: + product = await session.products.get("AAPL") + assert product is not None + ratings = await product.ratings() + assert len(ratings) > 0 + statements = await product.statements() + assert len(statements) > 0 \ No newline at end of file From 125ced803bf84964740988c88ec84e662caf3f54 Mon Sep 17 00:00:00 2001 From: Stefano Tabacco Date: Mon, 28 Jul 2025 23:05:56 +1000 Subject: [PATCH 2/7] Extended product with additional functionalities --- stake/product.py | 26 +++++++++++++++++++++----- stake/ratings.py | 4 +++- stake/transaction.py | 2 +- tests/test_integration.py | 3 +-- 4 files changed, 26 insertions(+), 9 deletions(-) diff --git a/stake/product.py b/stake/product.py index 50acdaa..9c4654d 100644 --- a/stake/product.py +++ b/stake/product.py @@ -8,10 +8,10 @@ from stake.common import BaseClient, camelcase if TYPE_CHECKING: + from stake.client import StakeClient from stake.ratings import Rating from stake.statement import Statement - from stake.client import StakeClient - + __all__ = ["ProductSearchByName"] @@ -67,11 +67,19 @@ def model_post_init(self, context: Any) -> None: async def ratings(self) -> "List[Rating]": from stake import RatingsRequest + return await self._client.ratings.list(RatingsRequest(symbols=[self.symbol])) async def statements(self, start_date: date | None = None) -> "List[Statement]": from stake.statement import StatementRequest - return await self._client.statements.list(StatementRequest(symbol=self.symbol, start_date=start_date or (date.today() - timedelta(days=365)))) + + return await self._client.statements.list( + StatementRequest( + symbol=self.symbol, + start_date=start_date or (date.today() - timedelta(days=365)), + ) + ) + class ProductsClient(BaseClient): async def get(self, symbol: str) -> Optional[Product]: @@ -84,13 +92,21 @@ async def get(self, symbol: str) -> Optional[Product]: self._client.exchange.symbol.format(symbol=symbol) ) - return Product.model_validate(data["products"][0], context=dict(client=self._client)) if data["products"] else None + return ( + Product.model_validate( + data["products"][0], context=dict(client=self._client) + ) + if data["products"] + else None + ) async def search(self, request: ProductSearchByName) -> List[Instrument]: products = await self._client.get( self._client.exchange.products_suggestions.format(keyword=request.keyword) ) - return [Instrument.model_validate(product) for product in products["instruments"]] + return [ + Instrument.model_validate(product) for product in products["instruments"] + ] async def product_from_instrument( self, instrument: Instrument diff --git a/stake/ratings.py b/stake/ratings.py index 33a133d..ce4f517 100644 --- a/stake/ratings.py +++ b/stake/ratings.py @@ -36,7 +36,9 @@ class Rating(pydantic.BaseModel): url_news: Optional[str] = None analyst_name: Optional[str] = None - @pydantic.field_validator("pt_prior", "rating_prior", "pt_current", "rating_current", mode="before") + @pydantic.field_validator( + "pt_prior", "rating_prior", "pt_current", "rating_current", mode="before" + ) @classmethod def pt_prior_blank_string(cls, value, *args) -> Optional[str]: return None if value == "" else value diff --git a/stake/transaction.py b/stake/transaction.py index b69f5e6..9e29f79 100644 --- a/stake/transaction.py +++ b/stake/transaction.py @@ -1,6 +1,6 @@ import enum import json -from datetime import datetime, timedelta, UTC +from datetime import UTC, datetime, timedelta from enum import Enum from typing import Dict, List, Optional diff --git a/tests/test_integration.py b/tests/test_integration.py index 53d5ec5..787b749 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -55,7 +55,6 @@ async def test_integration_ASX(exchange): assert len(result.transactions) == 10 - @pytest.mark.parametrize("exchange", (constant.NYSE,)) @pytest.mark.asyncio async def test_integration_product(exchange): @@ -65,4 +64,4 @@ async def test_integration_product(exchange): ratings = await product.ratings() assert len(ratings) > 0 statements = await product.statements() - assert len(statements) > 0 \ No newline at end of file + assert len(statements) > 0 From 680594a04283ea194210eaab6a9d923899623208 Mon Sep 17 00:00:00 2001 From: Stefano Tabacco Date: Mon, 28 Jul 2025 23:07:11 +1000 Subject: [PATCH 3/7] Added assertion --- stake/product.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stake/product.py b/stake/product.py index 9c4654d..512b0ee 100644 --- a/stake/product.py +++ b/stake/product.py @@ -64,6 +64,7 @@ class Product(BaseModel): def model_post_init(self, context: Any) -> None: self._client = context.get("client") + assert self._client async def ratings(self) -> "List[Rating]": from stake import RatingsRequest From cd132d6078c0151b8c0692eb0744375b2131cd7b Mon Sep 17 00:00:00 2001 From: Stefano Tabacco Date: Mon, 28 Jul 2025 23:12:18 +1000 Subject: [PATCH 4/7] Updated failing tests --- stake/product.py | 7 ++++--- stake/ratings.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/stake/product.py b/stake/product.py index 512b0ee..b6a5ab5 100644 --- a/stake/product.py +++ b/stake/product.py @@ -62,9 +62,10 @@ class Product(BaseModel): _client: "StakeClient" = PrivateAttr() model_config = ConfigDict(alias_generator=camelcase) - def model_post_init(self, context: Any) -> None: - self._client = context.get("client") - assert self._client + def model_post_init(self, context: Any | None = None) -> None: + if context: + self._client = context.get("client") + assert self._client async def ratings(self) -> "List[Rating]": from stake import RatingsRequest diff --git a/stake/ratings.py b/stake/ratings.py index ce4f517..92f69c8 100644 --- a/stake/ratings.py +++ b/stake/ratings.py @@ -40,7 +40,7 @@ class Rating(pydantic.BaseModel): "pt_prior", "rating_prior", "pt_current", "rating_current", mode="before" ) @classmethod - def pt_prior_blank_string(cls, value, *args) -> Optional[str]: + def remove_blank_strings(cls, value, *args) -> Optional[str]: return None if value == "" else value From b8fe9187e86ef3a193dfd4aa307a120a5eb085d4 Mon Sep 17 00:00:00 2001 From: Stefano Tabacco Date: Mon, 28 Jul 2025 23:17:00 +1000 Subject: [PATCH 5/7] Updated import --- stake/transaction.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stake/transaction.py b/stake/transaction.py index 9e29f79..7b3ed57 100644 --- a/stake/transaction.py +++ b/stake/transaction.py @@ -1,6 +1,6 @@ import enum import json -from datetime import UTC, datetime, timedelta +from datetime import datetime, timedelta, timezone from enum import Enum from typing import Dict, List, Optional @@ -18,9 +18,9 @@ class TransactionRecordEnumDirection(str, Enum): class TransactionRecordRequest(BaseModel): - to: datetime = Field(default_factory=lambda *_: datetime.now(UTC)) + to: datetime = Field(default_factory=lambda *_: datetime.now(timezone.UTC)) from_: datetime = Field( - default_factory=lambda *_: datetime.now(UTC) - timedelta(days=365), alias="from" + default_factory=lambda *_: datetime.now(timezone.UTC) - timedelta(days=365), alias="from" ) limit: int = 1000 offset: Optional[datetime] = None From 9ac573d3ae6561a23f30dc4bfec8d8f87014d9af Mon Sep 17 00:00:00 2001 From: Stefano Tabacco Date: Mon, 28 Jul 2025 23:19:13 +1000 Subject: [PATCH 6/7] Trying to fix utc --- stake/transaction.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stake/transaction.py b/stake/transaction.py index 7b3ed57..bf5036c 100644 --- a/stake/transaction.py +++ b/stake/transaction.py @@ -18,9 +18,9 @@ class TransactionRecordEnumDirection(str, Enum): class TransactionRecordRequest(BaseModel): - to: datetime = Field(default_factory=lambda *_: datetime.now(timezone.UTC)) + to: datetime = Field(default_factory=lambda *_: datetime.now(timezone.utc)) from_: datetime = Field( - default_factory=lambda *_: datetime.now(timezone.UTC) - timedelta(days=365), alias="from" + default_factory=lambda *_: datetime.now(timezone.utc) - timedelta(days=365), alias="from" ) limit: int = 1000 offset: Optional[datetime] = None From 29c725a3264c5e7496b0cf10a7759b4642739bbe Mon Sep 17 00:00:00 2001 From: Stefano Tabacco Date: Mon, 28 Jul 2025 23:21:19 +1000 Subject: [PATCH 7/7] run black --- stake/transaction.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stake/transaction.py b/stake/transaction.py index bf5036c..75066b3 100644 --- a/stake/transaction.py +++ b/stake/transaction.py @@ -20,7 +20,8 @@ class TransactionRecordEnumDirection(str, Enum): class TransactionRecordRequest(BaseModel): to: datetime = Field(default_factory=lambda *_: datetime.now(timezone.utc)) from_: datetime = Field( - default_factory=lambda *_: datetime.now(timezone.utc) - timedelta(days=365), alias="from" + default_factory=lambda *_: datetime.now(timezone.utc) - timedelta(days=365), + alias="from", ) limit: int = 1000 offset: Optional[datetime] = None