Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions stake/product.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import uuid
from datetime import datetime
from typing import Any, List, Optional
from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING, Any, List, Optional

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, PrivateAttr
from pydantic.fields import Field

from stake.common import BaseClient, camelcase

if TYPE_CHECKING:
from stake.client import StakeClient
from stake.ratings import Rating
from stake.statement import Statement

__all__ = ["ProductSearchByName"]


Expand Down Expand Up @@ -54,8 +59,29 @@ class Product(BaseModel):
inception_date: Optional[datetime] = None
instrument_tags: List[Any]
child_instruments: List[Instrument]
_client: "StakeClient" = PrivateAttr()
model_config = ConfigDict(alias_generator=camelcase)

def model_post_init(self, context: Any | None = None) -> None:
if context:
self._client = context.get("client")
assert self._client

async def ratings(self) -> "List[Rating]":
from stake import RatingsRequest

return await self._client.ratings.list(RatingsRequest(symbols=[self.symbol]))

async def statements(self, start_date: date | None = None) -> "List[Statement]":
from stake.statement import StatementRequest

return await self._client.statements.list(
StatementRequest(
symbol=self.symbol,
start_date=start_date or (date.today() - timedelta(days=365)),
)
)


class ProductsClient(BaseClient):
async def get(self, symbol: str) -> Optional[Product]:
Expand All @@ -68,13 +94,21 @@ async def get(self, symbol: str) -> Optional[Product]:
self._client.exchange.symbol.format(symbol=symbol)
)

return Product(**data["products"][0]) if data["products"] else None
return (
Product.model_validate(
data["products"][0], context=dict(client=self._client)
)
if data["products"]
else None
)

async def search(self, request: ProductSearchByName) -> List[Instrument]:
products = await self._client.get(
self._client.exchange.products_suggestions.format(keyword=request.keyword)
)
return [Instrument(**product) for product in products["instruments"]]
return [
Instrument.model_validate(product) for product in products["instruments"]
]

async def product_from_instrument(
self, instrument: Instrument
Expand Down
6 changes: 4 additions & 2 deletions stake/ratings.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ class Rating(pydantic.BaseModel):
url_news: Optional[str] = None
analyst_name: Optional[str] = None

@pydantic.field_validator("pt_prior", "rating_prior", mode="before")
@pydantic.field_validator(
"pt_prior", "rating_prior", "pt_current", "rating_current", mode="before"
)
@classmethod
def pt_prior_blank_string(cls, value, *args) -> Optional[str]:
def remove_blank_strings(cls, value, *args) -> Optional[str]:
return None if value == "" else value


Expand Down
7 changes: 4 additions & 3 deletions stake/transaction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import enum
import json
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Dict, List, Optional

Expand All @@ -18,9 +18,10 @@ class TransactionRecordEnumDirection(str, Enum):


class TransactionRecordRequest(BaseModel):
to: datetime = Field(default_factory=datetime.utcnow)
to: datetime = Field(default_factory=lambda *_: datetime.now(timezone.utc))
from_: datetime = Field(
default_factory=lambda *_: datetime.utcnow() - timedelta(days=365), alias="from"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Same UTC concern applies to from_ default_factory.

Ensure that UTC is properly defined and accessible in this context.

default_factory=lambda *_: datetime.now(timezone.utc) - timedelta(days=365),
alias="from",
)
limit: int = 1000
offset: Optional[datetime] = None
Expand Down
12 changes: 12 additions & 0 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,15 @@ async def test_integration_ASX(exchange):
)
result = await session.transactions.list(request=request)
assert len(result.transactions) == 10


@pytest.mark.parametrize("exchange", (constant.NYSE,))
@pytest.mark.asyncio
async def test_integration_product(exchange):
async with stake.StakeClient(exchange=exchange) as session:
product = await session.products.get("AAPL")
assert product is not None
ratings = await product.ratings()
assert len(ratings) > 0
statements = await product.statements()
assert len(statements) > 0
Loading