diff --git a/stake/product.py b/stake/product.py index 938df45..b6a5ab5 100644 --- a/stake/product.py +++ b/stake/product.py @@ -1,12 +1,17 @@ import uuid -from datetime import datetime -from typing import Any, List, Optional +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING, Any, List, Optional -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, PrivateAttr from pydantic.fields import Field from stake.common import BaseClient, camelcase +if TYPE_CHECKING: + from stake.client import StakeClient + from stake.ratings import Rating + from stake.statement import Statement + __all__ = ["ProductSearchByName"] @@ -54,8 +59,29 @@ class Product(BaseModel): inception_date: Optional[datetime] = None instrument_tags: List[Any] child_instruments: List[Instrument] + _client: "StakeClient" = PrivateAttr() model_config = ConfigDict(alias_generator=camelcase) + def model_post_init(self, context: Any | None = None) -> None: + if context: + self._client = context.get("client") + assert self._client + + async def ratings(self) -> "List[Rating]": + from stake import RatingsRequest + + return await self._client.ratings.list(RatingsRequest(symbols=[self.symbol])) + + async def statements(self, start_date: date | None = None) -> "List[Statement]": + from stake.statement import StatementRequest + + return await self._client.statements.list( + StatementRequest( + symbol=self.symbol, + start_date=start_date or (date.today() - timedelta(days=365)), + ) + ) + class ProductsClient(BaseClient): async def get(self, symbol: str) -> Optional[Product]: @@ -68,13 +94,21 @@ async def get(self, symbol: str) -> Optional[Product]: self._client.exchange.symbol.format(symbol=symbol) ) - return Product(**data["products"][0]) if data["products"] else None + return ( + Product.model_validate( + data["products"][0], context=dict(client=self._client) + ) + if data["products"] + else None + ) async def search(self, request: ProductSearchByName) -> List[Instrument]: products = await self._client.get( self._client.exchange.products_suggestions.format(keyword=request.keyword) ) - return [Instrument(**product) for product in products["instruments"]] + return [ + Instrument.model_validate(product) for product in products["instruments"] + ] async def product_from_instrument( self, instrument: Instrument diff --git a/stake/ratings.py b/stake/ratings.py index dc0e6a7..92f69c8 100644 --- a/stake/ratings.py +++ b/stake/ratings.py @@ -36,9 +36,11 @@ class Rating(pydantic.BaseModel): url_news: Optional[str] = None analyst_name: Optional[str] = None - @pydantic.field_validator("pt_prior", "rating_prior", mode="before") + @pydantic.field_validator( + "pt_prior", "rating_prior", "pt_current", "rating_current", mode="before" + ) @classmethod - def pt_prior_blank_string(cls, value, *args) -> Optional[str]: + def remove_blank_strings(cls, value, *args) -> Optional[str]: return None if value == "" else value diff --git a/stake/transaction.py b/stake/transaction.py index 36b0e46..75066b3 100644 --- a/stake/transaction.py +++ b/stake/transaction.py @@ -1,6 +1,6 @@ import enum import json -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from enum import Enum from typing import Dict, List, Optional @@ -18,9 +18,10 @@ class TransactionRecordEnumDirection(str, Enum): class TransactionRecordRequest(BaseModel): - to: datetime = Field(default_factory=datetime.utcnow) + to: datetime = Field(default_factory=lambda *_: datetime.now(timezone.utc)) from_: datetime = Field( - default_factory=lambda *_: datetime.utcnow() - timedelta(days=365), alias="from" + default_factory=lambda *_: datetime.now(timezone.utc) - timedelta(days=365), + alias="from", ) limit: int = 1000 offset: Optional[datetime] = None diff --git a/tests/test_integration.py b/tests/test_integration.py index 034c12b..787b749 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -53,3 +53,15 @@ async def test_integration_ASX(exchange): ) result = await session.transactions.list(request=request) assert len(result.transactions) == 10 + + +@pytest.mark.parametrize("exchange", (constant.NYSE,)) +@pytest.mark.asyncio +async def test_integration_product(exchange): + async with stake.StakeClient(exchange=exchange) as session: + product = await session.products.get("AAPL") + assert product is not None + ratings = await product.ratings() + assert len(ratings) > 0 + statements = await product.statements() + assert len(statements) > 0