From 085d137c98640485975ee1181ef792cd34d1ba5f Mon Sep 17 00:00:00 2001
From: Jens Horn
Date: Thu, 14 Aug 2025 16:46:50 +0200
Subject: [PATCH 1/4] Use refresh token and handle 429.

---
 pycheckwatt/__init__.py              | 1328 ++++++++++++++++----------
 tests/unit/test_auth_and_requests.py |  555 +++++++++++
 tests/unit/test_checkwatt_manager.py |  222 ++---
 3 files changed, 1428 insertions(+), 677 deletions(-)
 create mode 100644 tests/unit/test_auth_and_requests.py

diff --git a/pycheckwatt/__init__.py b/pycheckwatt/__init__.py
index 77fb534..39d52b2 100644
--- a/pycheckwatt/__init__.py
+++ b/pycheckwatt/__init__.py
@@ -20,11 +20,15 @@
 from __future__ import annotations
 
+import asyncio
 import base64
 import json
 import logging
+import random
 import re
-from datetime import date, datetime, timedelta
+from datetime import date, datetime, timedelta, timezone
+from email.utils import parsedate_to_datetime
+from typing import Any, Dict, Optional, Union
 
 from aiohttp import ClientError, ClientResponseError, ClientSession
 from dateutil.relativedelta import relativedelta
 
@@ -35,22 +39,54 @@ class CheckwattManager:
     """CheckWatt manager."""
 
-    def __init__(self, username, password, application="pyCheckwatt") -> None:
+    def __init__(
+        self,
+        username,
+        password,
+        application="pyCheckwatt",
+        *,
+        max_retries_429: int = 3,
+        backoff_base: float = 0.5,
+        backoff_factor: float = 2.0,
+        backoff_max: float = 30.0,
+        clock_skew_seconds: int = 10,
+        max_concurrent_requests: int = 5,
+    ) -> None:
         """Initialize the CheckWatt manager."""
         if username is None or password is None:
             raise ValueError("Username and password must be provided.")
+
+        # Core session and configuration
         self.session = None
         self.base_url = "https://api.checkwatt.se"
         self.username = username
         self.password = password
+        self.header_identifier = application
+
+        # Authentication state
+        self.jwt_token = None
+        self.refresh_token = None
+        self.refresh_token_expires = None
+
+        # Concurrency control
+        self._auth_lock = asyncio.Lock()
+        self._req_semaphore = asyncio.Semaphore(max_concurrent_requests)
+
+        # Configuration knobs
+        self.max_retries_429 = max_retries_429
+        self.backoff_base = backoff_base
+        self.backoff_factor = backoff_factor
+        self.backoff_max = backoff_max
+        self.clock_skew_seconds = clock_skew_seconds
+        self.max_concurrent_requests = max_concurrent_requests
+
+        # Data properties (existing)
         self.dailyaverage = 0
         self.monthestimate = 0
         self.revenue = None
         self.revenueyear = None
         self.revenueyeartotal = 0
         self.revenuemonth = 0
-        self.jwt_token = None
-        self.refresh_token = None
         self.customer_details = None
         self.battery_registration = None
         self.battery_charge_peak_ac = None
@@ -70,7 +106,6 @@ def __init__(self, username, password, application="pyCheckwatt") -> None:
         self.price_zone = None
         self.spot_prices = None
         self.energy_data = None
-        self.header_identifier = application
         self.rpi_data = None
         self.meter_data = None
         self.display_name = None
@@ -107,123 +142,316 @@ def _get_headers(self):
             "X-pyCheckwatt-Application": self.header_identifier,
         }
 
-    def _extract_content_and_logbook(self, input_string):
-        """Pull the registered information from the logbook."""
-        battery_registration = None
-
-        # Define the pattern to match the content between the tags
-        pattern = re.compile(
-            r"#BEGIN_BATTERY_REGISTRATION(.*?)#END_BATTERY_REGISTRATION", re.DOTALL
-        )
-
-        # Find all matches in the input string
-        matches = re.findall(pattern, input_string)
-
-        # Extracted content
-        extracted_content = ""
-        if matches:
-            extracted_content = matches[0].strip()
-            battery_registration = json.loads(extracted_content)
-
-        # Extract logbook 
entries - logbook_entries = input_string.split("\n") - - # Filter out entries containing - # #BEGIN_BATTERY_REGISTRATION and #END_BATTERY_REGISTRATION - logbook_entries = [ - entry.strip() - for entry in logbook_entries - if not ( - "#BEGIN_BATTERY_REGISTRATION" in entry - or "#END_BATTERY_REGISTRATION" in entry - ) - ] - - return battery_registration, logbook_entries - - def _extract_fcr_d_state(self): - pattern = re.compile( - r"\[ FCR-D (ACTIVATED|DEACTIVATE|FAIL ACTIVATION) \] (?:(?:\d+x)?\s?(\S+) --(\d+)-- | (?:(?:UP|DOWN) (?:\d+,\d+) Hz ))((?:(\d+,\d+)\/(\d+,\d+)\/)?(\d+,\d+|[A-Z]+) %)\s+\((\d+,\d+\/\d+,\d+|\d+\/\d+|\d+) kW\)\s*-?\s*.*?(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})" # noqa: E501 - ) - for entry in self.logbook_entries: - match = pattern.search(entry) - if match: - self.fcrd_state = match.group(1) - fcrd_percentage = ( - match.group(4) - if self.fcrd_state in ["ACTIVATED", "FAIL ACTIVATION"] - else None - ) - self.fcrd_percentage_up = ( - match.group(5) - if self.fcrd_state in ["ACTIVATED", "FAIL ACTIVATION"] - else None - ) - self.fcrd_percentage_response = ( - match.group(6) - if self.fcrd_state in ["ACTIVATED", "FAIL ACTIVATION"] - else None - ) - self.fcrd_percentage_down = ( - match.group(7) - if self.fcrd_state in ["ACTIVATED", "FAIL ACTIVATION"] - else None - ) - error_info = match.group(4) if self.fcrd_state == "DEACTIVATE" else None - self.fcrd_power = match.group(8) - self.fcrd_timestamp = match.group(9) - if fcrd_percentage is not None: - self.fcrd_info = fcrd_percentage - elif error_info is not None: - error_info = error_info.split("]", 1)[0].strip() - self.fcrd_info = error_info.strip("[]").strip() + def _jwt_is_valid(self) -> bool: + """Check if JWT token is valid and not expiring soon.""" + if not self.jwt_token: + return False + + try: + # Simple JWT expiration check - decode the payload part + parts = self.jwt_token.split('.') + if len(parts) != 3: + return False + + # Decode the payload (second part) + payload = base64.urlsafe_b64decode(parts[1] + '==').decode('utf-8') + claims = json.loads(payload) + + exp = claims.get('exp') + if not exp: + return False + + # Check if token expires within clock skew + now = datetime.utcnow().timestamp() + return now < (exp - self.clock_skew_seconds) + + except (ValueError, json.JSONDecodeError, UnicodeDecodeError, TypeError): + # If we can't decode, treat as unknown validity + return False + + def _refresh_is_valid(self) -> bool: + """Check if refresh token is valid and not expired.""" + if not self.refresh_token or not self.refresh_token_expires: + return False + + try: + # Parse the expiration timestamp + expires = datetime.fromisoformat(self.refresh_token_expires.replace('Z', '+00:00')) + now = datetime.now(expires.tzinfo) if expires.tzinfo else datetime.utcnow() + + # Add some buffer (5 minutes) to avoid edge cases + return now < (expires - timedelta(minutes=5)) + + except (ValueError, TypeError): + # If we can't parse, treat as unknown validity + return False + + async def _refresh(self) -> bool: + """Refresh the JWT token using the refresh token.""" + if not self.refresh_token: + return False + + try: + endpoint = "/user/RefreshToken?audience=eib" + headers = { + **self._get_headers(), + "authorization": f"RefreshToken {self.refresh_token}", + } + + async with self.session.get( + self.base_url + endpoint, + headers=headers, + timeout=10 + ) as response: + if response.status == 200: + data = await response.json() + + # Update tokens + self.jwt_token = data.get("JwtToken") + if "RefreshToken" in data: + 
self.refresh_token = data.get("RefreshToken") + if "RefreshTokenExpires" in data: + self.refresh_token_expires = data.get("RefreshTokenExpires") + + _LOGGER.info("Successfully refreshed JWT token") + return True + + elif response.status == 401: + _LOGGER.warning("Refresh token expired or invalid") + return False + else: - self.fcrd_info = None - break # stop so we get the first row in logbook - - async def handle_client_error(self, endpoint, headers, error): - """Handle ClientError and log relevant information.""" - _LOGGER.error( - "An error occurred during the request. URL: %s, Headers: %s. Error: %s", - self.base_url + endpoint, - headers, - error, - ) - return False + _LOGGER.error("Unexpected status code during refresh: %d", response.status) + return False + + except (ClientResponseError, ClientError) as error: + _LOGGER.error("Error during token refresh: %s", error) + return False + + async def _ensure_token(self) -> bool: + """Ensure we have a valid JWT token, refreshing or logging in if needed.""" + # Quick check without lock + if self.jwt_token and self._jwt_is_valid(): + return True + + # Need to acquire lock for auth operations + async with self._auth_lock: + # Double-check after acquiring lock + if self.jwt_token and self._jwt_is_valid(): + return True + + # Try refresh first + if self.refresh_token and self._refresh_is_valid(): + if await self._refresh(): + return True + + # Fall back to login + _LOGGER.info("Performing password login") + return await self.login() + + async def _request( + self, + method: str, + endpoint: str, + *, + headers: Optional[Dict[str, str]] = None, + auth_required: bool = True, + retry_on_401: bool = True, + retry_on_429: bool = True, + timeout: int = 10, + **kwargs + ) -> Union[Dict[str, Any], str, bool, None]: + """ + Centralized request wrapper with authentication and retry logic. + + Args: + method: HTTP method (GET, POST, etc.) 
+ endpoint: API endpoint path + headers: Additional headers to merge with common headers + auth_required: Whether authentication is required + retry_on_401: Whether to retry on 401 (with refresh/login) + retry_on_429: Whether to retry on 429 (with backoff) + timeout: Request timeout in seconds + **kwargs: Additional arguments for the request + + Returns: + Response data (dict for JSON, str for text) or boolean for success/failure + """ + # Ensure we have a valid token if auth is required + if auth_required: + if not await self._ensure_token(): + return False + + # Prepare headers + final_headers = {**self._get_headers(), **(headers or {})} + if auth_required and self.jwt_token: + final_headers["authorization"] = f"Bearer {self.jwt_token}" + + # Remove sensitive headers from logging + safe_headers = {k: v for k, v in final_headers.items() + if k.lower() not in ['authorization', 'cookie']} + + # Apply concurrency control + async with self._req_semaphore: + # Perform request with retry logic + for attempt in range(self.max_retries_429 + 1): + try: + _LOGGER.debug("Making %s request to %s (attempt %d)", + method, endpoint, attempt + 1) + + async with self.session.request( + method, + self.base_url + endpoint, + headers=final_headers, + timeout=timeout, + **kwargs + ) as response: + # Handle 401 (Unauthorized) + if response.status == 401 and retry_on_401 and auth_required: + _LOGGER.warning("Received 401, attempting token refresh") + + # Try refresh first + if await self._refresh(): + # Retry the original request once + continue + + # If refresh failed, try login + _LOGGER.warning("Refresh failed, attempting login") + if await self.login(): + # Retry the original request once + continue + + # Both refresh and login failed + _LOGGER.error("Authentication failed after refresh and login attempts") + return False + + # Handle 429 (Too Many Requests) + if response.status == 429 and retry_on_429 and attempt < self.max_retries_429: + retry_after = response.headers.get('Retry-After') + + if retry_after: + try: + # Try to parse as seconds + wait_time = int(retry_after) + except ValueError: + try: + # Try to parse as HTTP date + retry_date = parsedate_to_datetime(retry_after) + if retry_date.tzinfo is None: + retry_date = retry_date.replace(tzinfo=timezone.utc) + wait_time = (retry_date - datetime.now(timezone.utc)).total_seconds() + wait_time = max(0, wait_time) + except (ValueError, TypeError): + wait_time = self.backoff_base + else: + # Use exponential backoff with jitter + wait_time = min( + self.backoff_base * (self.backoff_factor ** attempt), + self.backoff_max + ) + # Add jitter (0 to 0.25s) + wait_time += random.uniform(0, 0.25) + + _LOGGER.info("Rate limited (429), waiting %.2f seconds before retry", wait_time) + await asyncio.sleep(wait_time) + continue + + # Handle other status codes + response.raise_for_status() + + # Parse response based on content type + content_type = response.headers.get('Content-Type', '').lower() + + if 'application/json' in content_type: + return await response.json() + else: + return await response.text() + + except ClientResponseError as e: + if e.status == 401 and retry_on_401 and auth_required: + # This will be handled in the next iteration + continue + elif e.status == 429 and retry_on_429 and attempt < self.max_retries_429: + # This will be handled in the next iteration + continue + else: + _LOGGER.error("Request failed with status %d: %s", e.status, e) + return await self.handle_client_error(endpoint, safe_headers, e) + + except (ClientError, 
asyncio.TimeoutError) as error: + _LOGGER.error("Request failed: %s", error) + return await self.handle_client_error(endpoint, safe_headers, error) + + # If we get here, we've exhausted all retries + _LOGGER.error("Request failed after %d attempts", self.max_retries_429 + 1) + return False async def _continue_kill_switch_not_enabled(self): """Check if CheckWatt has requested integrations to back-off.""" + url = "https://checkwatt.se/ha-killswitch.txt" + + # Ensure session is initialized + if self.session is None: + _LOGGER.error("Session not initialized. Use async context manager or call ensure_session() first.") + return False + try: - url = "https://checkwatt.se/ha-killswitch.txt" - headers = {**self._get_headers()} - async with self.session.get(url, headers=headers) as response: + headers = self._get_headers() + + # Ensure headers is a valid dictionary + if not isinstance(headers, dict): + _LOGGER.error("_get_headers() returned invalid type: %s, defaulting to empty dict", type(headers)) + headers = {} + + async with self.session.get(url, headers=headers, timeout=10) as response: data = await response.text() if response.status == 200: kill = data.strip() # Remove leading and trailing whitespaces - if kill == "0": - # We are OK to continue - _LOGGER.debug( - "CheckWatt accepted and not enabled the kill-switch" - ) - return True - - # Kill was requested - _LOGGER.error( - "CheckWatt has requested to back down by enabling the kill-switch" # noqa: E501 - ) - return False + enabled = kill == "0" + + if enabled: + _LOGGER.debug("CheckWatt accepted and not enabled the kill-switch") + else: + _LOGGER.error("CheckWatt has requested to back down by enabling the kill-switch") + + return enabled if response.status == 401: - _LOGGER.error( - "Unauthorized: Check your CheckWatt authentication credentials" - ) + _LOGGER.error("Unauthorized: Check your CheckWatt authentication credentials") return False _LOGGER.error("Unexpected HTTP status code: %s", response.status) return False - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(url, headers, error) + except Exception as error: + # Create safe headers for logging, handling case where headers might not be defined + try: + safe_headers = {k: v for k, v in headers.items() + if k.lower() not in ['authorization', 'cookie']} + except (AttributeError, NameError): + safe_headers = {} + + _LOGGER.error( + "Killswitch check failed. URL: %s, Headers: %s. Error: %s", + url, + safe_headers, + error, + ) + return False + + async def handle_client_error(self, endpoint, headers, error): + """Handle ClientError and log relevant information.""" + # Remove sensitive headers from logging + safe_headers = {k: v for k, v in headers.items() + if k.lower() not in ['authorization', 'cookie']} + + _LOGGER.error( + "An error occurred during the request. URL: %s, Headers: %s. 
Error: %s", + self.base_url + endpoint, + safe_headers, + error, + ) + return False async def login(self): """Login to CheckWatt.""" @@ -258,6 +486,8 @@ async def login(self): if response.status == 200: self.jwt_token = data.get("JwtToken") self.refresh_token = data.get("RefreshToken") + self.refresh_token_expires = data.get("RefreshTokenExpires") + _LOGGER.info("Successfully logged in to CheckWatt") return True if response.status == 401: @@ -276,86 +506,73 @@ async def get_customer_details(self): """Fetch customer details from CheckWatt.""" try: endpoint = "/controlpanel/CustomerDetail" + + result = await self._request("GET", endpoint, auth_required=True) + if result is False: + return False + + self.customer_details = result - # Define headers with the JwtToken - headers = { - **self._get_headers(), - "authorization": f"Bearer {self.jwt_token}", - } - - async with self.session.get( - self.base_url + endpoint, headers=headers - ) as response: - response.raise_for_status() - if response.status == 200: - self.customer_details = await response.json() - - meters = self.customer_details.get("Meter", []) - if meters: - soc_meter = next( - ( - meter - for meter in meters - if meter.get("InstallationType") == "SoC" - ), - None, - ) - - if not soc_meter: - _LOGGER.error("No SoC meter found") - return False - - self.display_name = soc_meter.get("DisplayName") - self.reseller_id = soc_meter.get("ResellerId") - self.energy_provider_id = soc_meter.get("ElhandelsbolagId") - self.comments = soc_meter.get("Comments") - logbook = soc_meter.get("Logbook") - if logbook: - ( - self.battery_registration, - self.logbook_entries, - ) = self._extract_content_and_logbook(logbook) - self._extract_fcr_d_state() - - charging_meter = next( - ( - meter - for meter in meters - if meter.get("InstallationType") == "Charging" - ), - None, - ) - if charging_meter: - self.battery_charge_peak_ac = charging_meter.get("PeakAcKw") - self.battery_charge_peak_dc = charging_meter.get("PeakDcKw") - - discharge_meter = next( - ( - meter - for meter in meters - if meter.get("InstallationType") == "Discharging" - ), - None, - ) - if discharge_meter: - self.battery_discharge_peak_ac = discharge_meter.get( - "PeakAcKw" - ) - self.battery_discharge_peak_dc = discharge_meter.get( - "PeakDcKw" - ) + meters = self.customer_details.get("Meter", []) + if meters: + soc_meter = next( + ( + meter + for meter in meters + if meter.get("InstallationType") == "SoC" + ), + None, + ) - return True + if not soc_meter: + _LOGGER.error("No SoC meter found") + return False - _LOGGER.error( - "Obtaining data from URL %s failed with status code %d", - self.base_url + endpoint, - response.status, + self.display_name = soc_meter.get("DisplayName") + self.reseller_id = soc_meter.get("ResellerId") + self.energy_provider_id = soc_meter.get("ElhandelsbolagId") + self.comments = soc_meter.get("Comments") + logbook = soc_meter.get("Logbook") + if logbook: + ( + self.battery_registration, + self.logbook_entries, + ) = self._extract_content_and_logbook(logbook) + self._extract_fcr_d_state() + + charging_meter = next( + ( + meter + for meter in meters + if meter.get("InstallationType") == "Charging" + ), + None, ) - return False + if charging_meter: + self.battery_charge_peak_ac = charging_meter.get("PeakAcKw") + self.battery_charge_peak_dc = charging_meter.get("PeakDcKw") + + discharge_meter = next( + ( + meter + for meter in meters + if meter.get("InstallationType") == "Discharging" + ), + None, + ) + if discharge_meter: + self.battery_discharge_peak_ac = 
discharge_meter.get( + "PeakAcKw" + ) + self.battery_discharge_peak_dc = discharge_meter.get( + "PeakDcKw" + ) - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(endpoint, headers, error) + return True + + except Exception as error: + _LOGGER.error("Error in get_customer_details: %s", error) + return False async def get_site_id(self): """Get site ID from RPI serial number.""" @@ -369,42 +586,64 @@ async def get_site_id(self): try: endpoint = f"/Site/SiteIdBySerial?serial={self.rpi_serial}" - headers = { - **self._get_headers(), - "authorization": f"Bearer {self.jwt_token}", - } - - async with self.session.get( - self.base_url + endpoint, headers=headers - ) as response: - response.raise_for_status() - if response.status == 200: - raw_response = await response.text() - - try: - response_data = json.loads(raw_response) - self.site_id = str(response_data["SiteId"]) - return self.site_id - except json.JSONDecodeError as e: - # Fallback - maybe it's just the number as a string - self.site_id = raw_response.strip('"') - return self.site_id - - _LOGGER.error( - "Obtaining data from URL %s failed with status code %d", - self.base_url + endpoint, - response.status, - ) + + result = await self._request("GET", endpoint, auth_required=True) + if result is False: return False - - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(endpoint, headers, error) + + if isinstance(result, dict) and "SiteId" in result: + self.site_id = str(result["SiteId"]) + _LOGGER.debug("Successfully extracted site ID: %s", self.site_id) + return self.site_id + + _LOGGER.error("Unexpected response format for site ID: %s", result) + return False + + except Exception as error: + _LOGGER.error("Error in get_site_id: %s", error) + return False + + async def debug_revenue_workflow(self): + """Debug method to diagnose revenue workflow issues.""" + _LOGGER.info("=== Revenue Workflow Debug ===") + _LOGGER.info("Customer details loaded: %s", self.customer_details is not None) + _LOGGER.info("RPI data loaded: %s", self.rpi_data is not None) + _LOGGER.info("Site ID cached: %s", self.site_id) + + if self.customer_details: + meters = self.customer_details.get("Meter", []) + _LOGGER.info("Number of meters: %d", len(meters)) + for i, meter in enumerate(meters): + _LOGGER.info("Meter %d: Type=%s, RpiSerial=%s", + i, meter.get("InstallationType"), meter.get("RpiSerial")) + + rpi_serial = self.rpi_serial + _LOGGER.info("RPI Serial: %s", rpi_serial) + + if rpi_serial: + _LOGGER.info("Attempting to get site ID...") + site_id = await self.get_site_id() + _LOGGER.info("Site ID result: %s", site_id) + else: + _LOGGER.error("Cannot get site ID - RPI serial is None") + + _LOGGER.info("=== End Debug ===") async def get_fcrd_month_net_revenue(self): """Fetch FCR-D revenues from CheckWatt.""" misseddays = 0 try: site_id = await self.get_site_id() + if site_id is False: + _LOGGER.error("Failed to get site ID for FCR-D month revenue") + return False + + if not site_id: + _LOGGER.error("Site ID is empty or None for FCR-D month revenue") + return False + + _LOGGER.debug("Using site ID %s for FCR-D month revenue", site_id) + from_date = datetime.now().strftime("%Y-%m-01") to_date = datetime.now() + timedelta(days=1) to_date = to_date.strftime("%Y-%m-%d") @@ -424,48 +663,48 @@ async def get_fcrd_month_net_revenue(self): endpoint = ( f"/revenue/{site_id}?from={from_date}&to={to_date}&resolution=day" ) + _LOGGER.debug("FCR-D month revenue endpoint: %s", endpoint) - # Define 
headers with the JwtToken - headers = { - **self._get_headers(), - "authorization": f"Bearer {self.jwt_token}", - } - - # First fetch the revenue - async with self.session.get( - self.base_url + endpoint, headers=headers - ) as response: - response.raise_for_status() - revenue = await response.json() - for each in revenue["Revenue"]: - self.revenuemonth += each["NetRevenue"] - if each["NetRevenue"] == 0: - misseddays += 1 - dayswithmoney = int(dayssofar) - int(misseddays) - if response.status == 200: - if dayswithmoney > 0: - self.dailyaverage = self.revenuemonth / int(dayswithmoney) - else: - self.dailyaverage = 0 - self.monthestimate = ( - self.dailyaverage * daysleft - ) + self.revenuemonth - return True - - _LOGGER.error( - "Obtaining data from URL %s failed with status code %d", - self.base_url + endpoint, - response.status, - ) + result = await self._request("GET", endpoint, auth_required=True) + if result is False: + _LOGGER.error("Failed to retrieve FCR-D month revenue from endpoint: %s", endpoint) return False - - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(endpoint, headers, error) + + revenue = result + for each in revenue["Revenue"]: + self.revenuemonth += each["NetRevenue"] + if each["NetRevenue"] == 0: + misseddays += 1 + dayswithmoney = int(dayssofar) - int(misseddays) + + if dayswithmoney > 0: + self.dailyaverage = self.revenuemonth / int(dayswithmoney) + else: + self.dailyaverage = 0 + self.monthestimate = ( + self.dailyaverage * daysleft + ) + self.revenuemonth + _LOGGER.info("Successfully retrieved FCR-D month revenue") + return True + + except Exception as error: + _LOGGER.error("Error in get_fcrd_month_net_revenue: %s", error) + return False async def get_fcrd_today_net_revenue(self): """Fetch FCR-D revenues from CheckWatt.""" try: site_id = await self.get_site_id() + if site_id is False: + _LOGGER.error("Failed to get site ID for FCR-D today revenue") + return False + + if not site_id: + _LOGGER.error("Site ID is empty or None for FCR-D today revenue") + return False + + _LOGGER.debug("Using site ID %s for FCR-D today revenue", site_id) + from_date = datetime.now().strftime("%Y-%m-%d") end_date = datetime.now() + timedelta(days=2) to_date = end_date.strftime("%Y-%m-%d") @@ -473,34 +712,34 @@ async def get_fcrd_today_net_revenue(self): endpoint = ( f"/revenue/{site_id}?from={from_date}&to={to_date}&resolution=day" ) + _LOGGER.debug("FCR-D today revenue endpoint: %s", endpoint) - # Define headers with the JwtToken - headers = { - **self._get_headers(), - "authorization": f"Bearer {self.jwt_token}", - } - # First fetch the revenue - async with self.session.get( - self.base_url + endpoint, headers=headers - ) as response: - response.raise_for_status() - self.revenue = await response.json() - if response.status == 200: - return True - - _LOGGER.error( - "Obtaining data from URL %s failed with status code %d", - self.base_url + endpoint, - response.status, - ) + result = await self._request("GET", endpoint, auth_required=True) + if result is False: + _LOGGER.error("Failed to retrieve FCR-D today revenue from endpoint: %s", endpoint) return False + + self.revenue = result + _LOGGER.info("Successfully retrieved FCR-D today revenue") + return True - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(endpoint, headers, error) + except Exception as error: + _LOGGER.error("Error in get_fcrd_today_net_revenue: %s", error) + return False async def get_fcrd_year_net_revenue(self): """Fetch FCR-D 
revenues from CheckWatt.""" site_id = await self.get_site_id() + if site_id is False: + _LOGGER.error("Failed to get site ID for FCR-D year revenue") + return False + + if not site_id: + _LOGGER.error("Site ID is empty or None for FCR-D year revenue") + return False + + _LOGGER.debug("Using site ID %s for FCR-D year revenue", site_id) + yesterday_date = datetime.now() + timedelta(days=1) yesterday_date = yesterday_date.strftime("-%m-%d") months = ["-01-01", "-06-30", "-07-01", yesterday_date] @@ -514,31 +753,23 @@ async def get_fcrd_year_net_revenue(self): endpoint = ( f"/revenue/{site_id}?from={from_date}&to={to_date}&resolution=day" ) - # Define headers with the JwtToken - headers = { - **self._get_headers(), - "authorization": f"Bearer {self.jwt_token}", - } - # First fetch the revenue - async with self.session.get( - self.base_url + endpoint, headers=headers - ) as responseyear: # noqa: E501 - responseyear.raise_for_status() - self.revenueyear = await responseyear.json() - for each in self.revenueyear["Revenue"]: - self.revenueyeartotal += each["NetRevenue"] - if responseyear.status == 200: - retval = True - else: - _LOGGER.error( - "Obtaining data from URL %s failed with status code %d", - self.base_url + endpoint, - responseyear.status, - ) + _LOGGER.debug("FCR-D year revenue endpoint (first half): %s", endpoint) + + result = await self._request("GET", endpoint, auth_required=True) + if result is False: + _LOGGER.error("Failed to retrieve FCR-D year revenue from endpoint: %s", endpoint) + return False + + self.revenueyear = result + for each in self.revenueyear["Revenue"]: + self.revenueyeartotal += each["NetRevenue"] + retval = True + _LOGGER.info("Successfully retrieved FCR-D year revenue (first half)") return retval - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(endpoint, headers, error) + except Exception as error: + _LOGGER.error("Error in get_fcrd_year_net_revenue (first half): %s", error) + return False else: try: while loop < 3: @@ -546,37 +777,40 @@ async def get_fcrd_year_net_revenue(self): to_date = year_date + months[loop + 1] from_date = year_date + months[loop] endpoint = f"/revenue/{site_id}?from={from_date}&to={to_date}&resolution=day" - # Define headers with the JwtToken - headers = { - **self._get_headers(), - "authorization": f"Bearer {self.jwt_token}", - } - # First fetch the revenue - async with self.session.get( - self.base_url + endpoint, headers=headers - ) as responseyear: # noqa: E501 - responseyear.raise_for_status() - self.revenueyear = await responseyear.json() - for each in self.revenueyear["Revenue"]: - self.revenueyeartotal += each["NetRevenue"] - if responseyear.status == 200: - loop += 2 - retval = True - else: - _LOGGER.error( - "Obtaining data from URL %s failed with status code %d", # noqa: E501 - self.base_url + endpoint, - responseyear.status, - ) + _LOGGER.debug("FCR-D year revenue endpoint (period %d): %s", loop, endpoint) + + result = await self._request("GET", endpoint, auth_required=True) + if result is False: + _LOGGER.error("Failed to retrieve FCR-D year revenue from endpoint: %s", endpoint) + return False + + self.revenueyear = result + for each in self.revenueyear["Revenue"]: + self.revenueyeartotal += each["NetRevenue"] + loop += 2 + retval = True + + _LOGGER.info("Successfully retrieved FCR-D year revenue (multiple periods)") return retval - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(endpoint, headers, error) + except Exception as error: + 
_LOGGER.error("Error in get_fcrd_year_net_revenue (multiple periods): %s", error) + return False async def fetch_and_return_net_revenue(self, from_date, to_date): """Fetch FCR-D revenues from CheckWatt as per provided range.""" try: site_id = await self.get_site_id() + if site_id is False: + _LOGGER.error("Failed to get site ID for custom revenue range") + return None + + if not site_id: + _LOGGER.error("Site ID is empty or None for custom revenue range") + return None + + _LOGGER.debug("Using site ID %s for custom revenue range", site_id) + # Validate date format and ensure they are dates date_format = "%Y-%m-%d" try: @@ -610,30 +844,149 @@ async def fetch_and_return_net_revenue(self, from_date, to_date): endpoint = ( f"/revenue/{site_id}?from={from_date}&to={to_date}&resolution=day" ) + _LOGGER.debug("Custom revenue range endpoint: %s", endpoint) - # Define headers with the JwtToken - headers = { - **self._get_headers(), - "authorization": f"Bearer {self.jwt_token}", - } - # First fetch the revenue - async with self.session.get( - self.base_url + endpoint, headers=headers - ) as response: - response.raise_for_status() - revenue = await response.json() - if response.status == 200: - return revenue + result = await self._request("GET", endpoint, auth_required=True) + if result is False: + _LOGGER.error("Failed to retrieve custom revenue range from endpoint: %s", endpoint) + return None + + _LOGGER.info("Successfully retrieved custom revenue range") + return result - _LOGGER.error( - "Obtaining data from URL %s failed with status code %d", - self.base_url + endpoint, - response.status, + except Exception as error: + _LOGGER.error("Error in fetch_and_return_net_revenue: %s", error) + return None + + def _extract_content_and_logbook(self, input_string): + """Pull the registered information from the logbook.""" + battery_registration = None + + # Define the pattern to match the content between the tags + pattern = re.compile( + r"#BEGIN_BATTERY_REGISTRATION(.*?)#END_BATTERY_REGISTRATION", re.DOTALL + ) + + # Find all matches in the input string + matches = re.findall(pattern, input_string) + + # Extracted content + extracted_content = "" + if matches: + extracted_content = matches[0].strip() + battery_registration = json.loads(extracted_content) + + # Extract logbook entries + logbook_entries = input_string.split("\n") + + # Filter out entries containing + # #BEGIN_BATTERY_REGISTRATION and #END_BATTERY_REGISTRATION + logbook_entries = [ + entry.strip() + for entry in logbook_entries + if not ( + "#BEGIN_BATTERY_REGISTRATION" in entry + or "#END_BATTERY_REGISTRATION" in entry + ) + ] + + return battery_registration, logbook_entries + + def _extract_fcr_d_state(self): + pattern = re.compile( + r"\[ FCR-D (ACTIVATED|DEACTIVATE|FAIL ACTIVATION) \] (?:(?:\d+x)?\s?(\S+) --(\d+)-- | (?:(?:UP|DOWN) (?:\d+,\d+) Hz ))((?:(\d+,\d+)\/(\d+,\d+)\/)?(\d+,\d+|[A-Z]+) %)\s+\((\d+,\d+\/\d+,\d+|\d+\/\d+|\d+) kW\)\s*-?\s*.*?(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})" # noqa: E501 + ) + for entry in self.logbook_entries: + match = pattern.search(entry) + if match: + self.fcrd_state = match.group(1) + fcrd_percentage = ( + match.group(4) + if self.fcrd_state in ["ACTIVATED", "FAIL ACTIVATION"] + else None + ) + self.fcrd_percentage_up = ( + match.group(5) + if self.fcrd_state in ["ACTIVATED", "FAIL ACTIVATION"] + else None + ) + self.fcrd_percentage_response = ( + match.group(6) + if self.fcrd_state in ["ACTIVATED", "FAIL ACTIVATION"] + else None + ) + self.fcrd_percentage_down = ( + match.group(7) + if self.fcrd_state 
in ["ACTIVATED", "FAIL ACTIVATION"] + else None ) + error_info = match.group(4) if self.fcrd_state == "DEACTIVATE" else None + self.fcrd_power = match.group(8) + self.fcrd_timestamp = match.group(9) + if fcrd_percentage is not None: + self.fcrd_info = fcrd_percentage + elif error_info is not None: + error_info = error_info.split("]", 1)[0].strip() + self.fcrd_info = error_info.strip("[]").strip() + else: + self.fcrd_info = None + break # stop so we get the first row in logbook + + + + + + + + async def fetch_and_return_net_revenue(self, from_date, to_date): + """Fetch FCR-D revenues from CheckWatt as per provided range.""" + try: + site_id = await self.get_site_id() + # Validate date format and ensure they are dates + date_format = "%Y-%m-%d" + try: + from_date = datetime.strptime(from_date, date_format).date() + to_date = datetime.strptime(to_date, date_format).date() + except ValueError: + raise ValueError( + "Input dates must be valid dates with the format YYYY-MM-DD." + ) + + # Validate from_date and to_date + today = date.today() + six_months_ago = today - relativedelta(months=6) + + if not (six_months_ago <= from_date <= today): + raise ValueError( + "From date must be within the last 6 months and not beyond today." + ) + + if not (six_months_ago <= to_date <= today): + raise ValueError( + "To date must be within the last 6 months and not beyond today." + ) + + if from_date >= to_date: + raise ValueError("From date must be before To date.") + + # Extend to_date by one day + to_date += timedelta(days=1) + + endpoint = ( + f"/revenue/{site_id}?from={from_date}&to={to_date}&resolution=day" + ) + + result = await self._request("GET", endpoint, auth_required=True) + if result is False: return None + + return result + + except Exception as error: + _LOGGER.error("Error in fetchand_return_net_revenue: %s", error) + return None + - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(endpoint, headers, error) def _build_series_endpoint(self, grouping): end_date = datetime.now() + timedelta(days=2) @@ -659,30 +1012,16 @@ async def get_power_data(self): 3 ) # 0: Hourly, 1: Daily, 2: Monthly, 3: Yearly - # Define headers with the JwtToken - headers = { - **self._get_headers(), - "authorization": f"Bearer {self.jwt_token}", - } - - # First fetch the revenue - async with self.session.get( - self.base_url + endpoint, headers=headers - ) as response: - response.raise_for_status() - if response.status == 200: - self.power_data = await response.json() - return True - - _LOGGER.error( - "Obtaining data from URL %s failed with status code %d", - self.base_url + endpoint, - response.status, - ) + result = await self._request("GET", endpoint, auth_required=True) + if result is False: return False + + self.power_data = result + return True - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(endpoint, headers, error) + except Exception as error: + _LOGGER.error("Error in get_power_data: %s", error) + return False async def get_energy_flow(self): """Fetch Power Data from CheckWatt.""" @@ -690,30 +1029,16 @@ async def get_energy_flow(self): try: endpoint = "/ems/energyflow" - # Define headers with the JwtToken - headers = { - **self._get_headers(), - "authorization": f"Bearer {self.jwt_token}", - } - - # Fetch Energy Flows - async with self.session.get( - self.base_url + endpoint, headers=headers - ) as response: - response.raise_for_status() - if response.status == 200: - self.energy_data = await response.json() - return True - 
- _LOGGER.error( - "Obtaining data from URL %s failed with status code %d", - self.base_url + endpoint, - response.status, - ) + result = await self._request("GET", endpoint, auth_required=True) + if result is False: return False + + self.energy_data = result + return True - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(endpoint, headers, error) + except Exception as error: + _LOGGER.error("Error in get_energy_flow: %s", error) + return False async def get_ems_settings(self, rpi_serial=None): """Fetch EMS settings from CheckWatt.""" @@ -724,60 +1049,33 @@ async def get_ems_settings(self, rpi_serial=None): endpoint = f"/ems/service/Pending?Serial={rpi_serial}" - # Define headers with the JwtToken - headers = { - **self._get_headers(), - "authorization": f"Bearer {self.jwt_token}", - } - - # Fetch Energy Flows - async with self.session.get( - self.base_url + endpoint, headers=headers - ) as response: - response.raise_for_status() - if response.status == 200: - self.ems = await response.json() - return True - - _LOGGER.error( - "Obtaining data from URL %s failed with status code %d", - self.base_url + endpoint, - response.status, - ) + result = await self._request("GET", endpoint, auth_required=True) + if result is False: return False + + self.ems = result + return True - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(endpoint, headers, error) + except Exception as error: + _LOGGER.error("Error in get_ems_settings: %s", error) + return False async def get_price_zone(self): """Fetch Price Zone from CheckWatt.""" try: endpoint = "/ems/pricezone" - # Define headers with the JwtToken - headers = { - **self._get_headers(), - "authorization": f"Bearer {self.jwt_token}", - } - - # First fetch the revenue - async with self.session.get( - self.base_url + endpoint, headers=headers - ) as response: - response.raise_for_status() - if response.status == 200: - self.price_zone = await response.text() - return True - - _LOGGER.error( - "Obtaining data from URL %s failed with status code %d", - self.base_url + endpoint, - response.status, - ) + + result = await self._request("GET", endpoint, auth_required=True) + if result is False: return False + + self.price_zone = result + return True - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(endpoint, headers, error) + except Exception as error: + _LOGGER.error("Error in get_price_zone: %s", error) + return False async def get_spot_price(self): """Fetch Spot Price from CheckWatt.""" @@ -789,30 +1087,17 @@ async def get_spot_price(self): if self.price_zone is None: await self.get_price_zone() endpoint = f"/ems/spotprice?zone={self.price_zone}&fromDate={from_date}&toDate={to_date}" # noqa: E501 - # Define headers with the JwtToken - headers = { - **self._get_headers(), - "authorization": f"Bearer {self.jwt_token}", - } - - # First fetch the revenue - async with self.session.get( - self.base_url + endpoint, headers=headers - ) as response: - response.raise_for_status() - if response.status == 200: - self.spot_prices = await response.json() - return True - - _LOGGER.error( - "Obtaining data from URL %s failed with status code %d", - self.base_url + endpoint, - response.status, - ) + + result = await self._request("GET", endpoint, auth_required=True) + if result is False: return False + + self.spot_prices = result + return True - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(endpoint, 
headers, error) + except Exception as error: + _LOGGER.error("Error in get_spot_price: %s", error) + return False async def get_battery_month_peak_effect(self): """Fetch Price Zone from CheckWatt.""" @@ -820,64 +1105,40 @@ async def get_battery_month_peak_effect(self): try: endpoint = f"/ems/PeakBoughtMonth?month={month}" - # Define headers with the JwtToken - headers = { - **self._get_headers(), - "authorization": f"Bearer {self.jwt_token}", - } - - # First fetch the revenue - async with self.session.get( - self.base_url + endpoint, headers=headers - ) as response: - response.raise_for_status() - if response.status == 200: - peak_data = await response.json() - if "HourPeak" in peak_data: - self.month_peak_effect = peak_data["HourPeak"] - return True - - _LOGGER.error( - "Obtaining data from URL %s failed with status code %d", - self.base_url + endpoint, - response.status, - ) + + result = await self._request("GET", endpoint, auth_required=True) + if result is False: return False + + if "HourPeak" in result: + self.month_peak_effect = result["HourPeak"] + return True + + return False - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(endpoint, headers, error) + except Exception as error: + _LOGGER.error("Error in get_battery_month_peak_effect: %s", error) + return False async def get_energy_trading_company(self, input_id): """Translate Energy Company Id to Energy Company Name.""" try: endpoint = "/controlpanel/elhandelsbolag" - # Define headers with the JwtToken - headers = { - **self._get_headers(), - } - - async with self.session.get( - self.base_url + endpoint, headers=headers - ) as response: - response.raise_for_status() - if response.status == 200: - energy_trading_companies = await response.json() - for energy_trading_company in energy_trading_companies: - if energy_trading_company["Id"] == input_id: - return energy_trading_company["DisplayName"] - - return None - - _LOGGER.error( - "Obtaining data from URL %s failed with status code %d", - self.base_url + endpoint, - response.status, - ) + result = await self._request("GET", endpoint, auth_required=False) + if result is False: return None + + energy_trading_companies = result + for energy_trading_company in energy_trading_companies: + if energy_trading_company["Id"] == input_id: + return energy_trading_company["DisplayName"] - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(endpoint, headers, error) + return None + + except Exception as error: + _LOGGER.error("Error in get_energy_trading_company: %s", error) + return None async def get_rpi_data(self, rpi_serial=None): """Fetch RPi Data from CheckWatt.""" @@ -891,24 +1152,17 @@ async def get_rpi_data(self, rpi_serial=None): return False endpoint = f"/register/checkrpiv2?rpi={rpi_serial}" - # First fetch the revenue - async with self.session.get( - self.base_url + endpoint, - ) as response: - response.raise_for_status() - if response.status == 200: - self.rpi_data = await response.json() - return True - - _LOGGER.error( - "Obtaining data from URL %s failed with status code %d", - self.base_url + endpoint, - response.status, - ) + + result = await self._request("GET", endpoint, auth_required=False) + if result is False: return False + + self.rpi_data = result + return True - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(endpoint, "", error) + except Exception as error: + _LOGGER.error("Error in get_rpi_data: %s", error) + return False async def 
get_meter_status(self, meter_id=None): """Fetch RPi Data from CheckWatt.""" @@ -922,24 +1176,17 @@ async def get_meter_status(self, meter_id=None): return False endpoint = f"/asset/status?meterId={meter_id}" - # First fetch the revenue - async with self.session.get( - self.base_url + endpoint, - ) as response: - response.raise_for_status() - if response.status == 200: - self.meter_data = await response.json() - return True - - _LOGGER.error( - "Obtaining data from URL %s failed with status code %d", - self.base_url + endpoint, - response.status, - ) + + result = await self._request("GET", endpoint, auth_required=False) + if result is False: return False + + self.meter_data = result + return True - except (ClientResponseError, ClientError) as error: - return await self.handle_client_error(endpoint, "", error) + except Exception as error: + _LOGGER.error("Error in get_meter_status: %s", error) + return False @property def ems_settings(self): @@ -1272,6 +1519,43 @@ def meter_version(self): _LOGGER.warning("Unable to find Meter Data for Meter Version") return None + # Properties for debugging token state + @property + def jwt_expires_at(self) -> Optional[datetime]: + """Get JWT expiration time for debugging.""" + if not self.jwt_token: + return None + + try: + parts = self.jwt_token.split('.') + if len(parts) != 3: + return None + + payload = base64.urlsafe_b64decode(parts[1] + '==').decode('utf-8') + claims = json.loads(payload) + + exp = claims.get('exp') + if not exp: + return None + + return datetime.fromtimestamp(exp) + + except (ValueError, json.JSONDecodeError, UnicodeDecodeError, TypeError): + return None + + @property + def refresh_expires_at(self) -> Optional[datetime]: + """Get refresh token expiration time for debugging.""" + if not self.refresh_token_expires: + return None + + try: + return datetime.fromisoformat( + self.refresh_token_expires.replace('Z', '+00:00') + ) + except (ValueError, TypeError): + return None + class CheckWattRankManager: def __init__(self) -> None: diff --git a/tests/unit/test_auth_and_requests.py b/tests/unit/test_auth_and_requests.py new file mode 100644 index 0000000..a9a0c73 --- /dev/null +++ b/tests/unit/test_auth_and_requests.py @@ -0,0 +1,555 @@ +"""Test authentication and request wrapper functionality.""" + +import pytest +import pytest_asyncio +from unittest.mock import AsyncMock, Mock, patch, MagicMock +import asyncio +import json +from datetime import datetime, timedelta + +from pycheckwatt import CheckwattManager + + +class TestAuthentication: + """Test authentication lifecycle and token management.""" + + @pytest.mark.asyncio + async def test_login_stores_tokens(self): + """Test that successful login stores JWT and refresh tokens.""" + async with CheckwattManager("test_user", "test_pass") as manager: + + with patch('aiohttp.ClientSession.post') as mock_post, \ + patch('aiohttp.ClientSession.get') as mock_get: + + # Mock kill switch check + mock_killswitch = AsyncMock() + mock_killswitch.status = 200 + mock_killswitch.text = AsyncMock(return_value="0") + mock_get.return_value.__aenter__.return_value = mock_killswitch + + # Mock login response with refresh token expires + mock_login = AsyncMock() + mock_login.status = 200 + mock_login.json = AsyncMock(return_value={ + "JwtToken": "test_jwt_token", + "RefreshToken": "test_refresh_token", + "RefreshTokenExpires": "2025-12-31T23:59:59.000+00:00" + }) + mock_post.return_value.__aenter__.return_value = mock_login + + result = await manager.login() + + assert result is True + assert manager.jwt_token 
== "test_jwt_token"
+                assert manager.refresh_token == "test_refresh_token"
+                assert manager.refresh_token_expires == "2025-12-31T23:59:59.000+00:00"
+
+    @pytest.mark.asyncio
+    async def test_jwt_validity_check(self):
+        """Test JWT validity checking."""
+        manager = CheckwattManager("test_user", "test_pass")
+
+        # Test with no token
+        assert manager._jwt_is_valid() is False
+
+        # Test with invalid JWT format
+        manager.jwt_token = "invalid.jwt.format"
+        assert manager._jwt_is_valid() is False
+
+        # Test with valid JWT structure but invalid content
+        manager.jwt_token = "header.payload.signature"
+        assert manager._jwt_is_valid() is False  # Should fail due to invalid base64
+
+    @pytest.mark.asyncio
+    async def test_refresh_token_validity_check(self):
+        """Test refresh token validity checking."""
+        manager = CheckwattManager("test_user", "test_pass")
+
+        # Test with no tokens
+        assert manager._refresh_is_valid() is False
+
+        # Set a refresh token so only the expiry timestamp varies below
+        manager.refresh_token = "test_refresh"
+        manager.refresh_token_expires = "2025-12-31T23:59:59.000+00:00"
+
+        # Test with an expired token (past date)
+        manager.refresh_token_expires = "2020-12-31T23:59:59.000+00:00"
+        assert manager._refresh_is_valid() is False
+
+        # Test with a valid token (future date)
+        manager.refresh_token_expires = "2030-12-31T23:59:59.000+00:00"
+        assert manager._refresh_is_valid() is True
+
+    @pytest.mark.asyncio
+    async def test_token_refresh_success(self):
+        """Test successful token refresh."""
+        async with CheckwattManager("test_user", "test_pass") as manager:
+            manager.refresh_token = "test_refresh_token"
+
+            with patch('aiohttp.ClientSession.get') as mock_get:
+                mock_response = AsyncMock()
+                mock_response.status = 200
+                mock_response.json = AsyncMock(return_value={
+                    "JwtToken": "new_jwt_token",
+                    "RefreshToken": "new_refresh_token",
+                    "RefreshTokenExpires": "2025-12-31T23:59:59.000+00:00"
+                })
+                mock_get.return_value.__aenter__.return_value = mock_response
+
+                result = await manager._refresh()
+
+                assert result is True
+                assert manager.jwt_token == "new_jwt_token"
+                assert manager.refresh_token == "new_refresh_token"
+                assert manager.refresh_token_expires == "2025-12-31T23:59:59.000+00:00"
+
+    @pytest.mark.asyncio
+    async def test_token_refresh_failure(self):
+        """Test token refresh failure handling."""
+        async with CheckwattManager("test_user", "test_pass") as manager:
+            manager.jwt_token = "old_jwt_token"
+            manager.refresh_token = "test_refresh_token"
+
+            with patch('aiohttp.ClientSession.get') as mock_get:
+                mock_response = AsyncMock()
+                mock_response.status = 401  # Unauthorized
+                mock_get.return_value.__aenter__.return_value = mock_response
+
+                result = await manager._refresh()
+
+                assert result is False
+                # JWT token should remain unchanged from its initial value
+                assert manager.jwt_token == "old_jwt_token"
+
+    @pytest.mark.asyncio
+    async def test_ensure_token_with_valid_jwt(self):
+        """Test _ensure_token returns True with valid JWT."""
+        manager = CheckwattManager("test_user", "test_pass")
+        manager.jwt_token = "valid_jwt"
+
+        with patch.object(manager, '_jwt_is_valid', return_value=True):
+            result = await manager._ensure_token()
+            assert result is True
+
+    @pytest.mark.asyncio
+    async def test_ensure_token_with_refresh(self):
+        """Test _ensure_token uses refresh token when JWT is invalid."""
+        manager = CheckwattManager("test_user", "test_pass")
+        manager.jwt_token = "expired_jwt"
+        manager.refresh_token = "valid_refresh"
+
+        with patch.object(manager, '_jwt_is_valid', return_value=False), \
+             patch.object(manager, '_refresh_is_valid', 
return_value=True), \ + patch.object(manager, '_refresh', return_value=True): + + result = await manager._ensure_token() + assert result is True + + @pytest.mark.asyncio + async def test_ensure_token_falls_back_to_login(self): + """Test _ensure_token falls back to login when refresh fails.""" + manager = CheckwattManager("test_user", "test_pass") + manager.jwt_token = "expired_jwt" + manager.refresh_token = "expired_refresh" + + with patch.object(manager, '_jwt_is_valid', return_value=False), \ + patch.object(manager, '_refresh_is_valid', return_value=False), \ + patch.object(manager, 'login', return_value=True): + + result = await manager._ensure_token() + assert result is True + + +class TestHttpRequestHandling: + """Test the centralized _request wrapper.""" + + @pytest.mark.asyncio + async def test_request_with_auth_required(self): + """Test _request ensures authentication when required.""" + async with CheckwattManager("test_user", "test_pass") as manager: + + with patch.object(manager, '_ensure_token', return_value=True) as mock_ensure, \ + patch.object(manager.session, 'request') as mock_request: + + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.headers = {'Content-Type': 'application/json'} + mock_response.json = AsyncMock(return_value={"data": "test"}) + mock_response.raise_for_status = Mock() + mock_request.return_value.__aenter__.return_value = mock_response + + result = await manager._request("GET", "/test", auth_required=True) + + mock_ensure.assert_called_once() + assert result == {"data": "test"} + + @pytest.mark.asyncio + async def test_request_without_auth(self): + """Test _request skips authentication when not required.""" + async with CheckwattManager("test_user", "test_pass") as manager: + + with patch.object(manager, '_ensure_token') as mock_ensure, \ + patch.object(manager.session, 'request') as mock_request: + + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.headers = {'Content-Type': 'application/json'} + mock_response.json = AsyncMock(return_value={"data": "test"}) + mock_response.raise_for_status = Mock() + mock_request.return_value.__aenter__.return_value = mock_response + + result = await manager._request("GET", "/test", auth_required=False) + + mock_ensure.assert_not_called() + assert result == {"data": "test"} + + @pytest.mark.asyncio + async def test_request_401_handling(self): + """Test _request handles 401 with refresh and login retry.""" + async with CheckwattManager("test_user", "test_pass") as manager: + + with patch.object(manager, '_ensure_token', return_value=True), \ + patch.object(manager, '_refresh', return_value=True), \ + patch.object(manager, 'login', return_value=True), \ + patch.object(manager.session, 'request') as mock_request: + + # First request returns 401, second succeeds + mock_response1 = AsyncMock() + mock_response1.status = 401 + mock_response1.raise_for_status.side_effect = Exception("401") + + mock_response2 = AsyncMock() + mock_response2.status = 200 + mock_response2.headers = {'Content-Type': 'application/json'} + mock_response2.json = AsyncMock(return_value={"data": "success"}) + mock_response2.raise_for_status = Mock() + + mock_request.return_value.__aenter__.side_effect = [mock_response1, mock_response2] + + result = await manager._request("GET", "/test", auth_required=True, retry_on_401=True) + + assert result == {"data": "success"} + + @pytest.mark.asyncio + async def test_request_429_handling_with_retry_after(self): + """Test _request handles 429 with Retry-After 
header.""" + async with CheckwattManager("test_user", "test_pass") as manager: + + with patch.object(manager, '_ensure_token', return_value=True), \ + patch.object(manager.session, 'request') as mock_request, \ + patch('asyncio.sleep') as mock_sleep: + + # First request returns 429, second succeeds + mock_response1 = AsyncMock() + mock_response1.status = 429 + mock_response1.headers = {'Retry-After': '2'} + mock_response1.raise_for_status.side_effect = Exception("429") + + mock_response2 = AsyncMock() + mock_response2.status = 200 + mock_response2.headers = {'Content-Type': 'application/json'} + mock_response2.json = AsyncMock(return_value={"data": "success"}) + mock_response2.raise_for_status = Mock() + + mock_request.return_value.__aenter__.side_effect = [mock_response1, mock_response2] + + result = await manager._request("GET", "/test", auth_required=True, retry_on_429=True) + + mock_sleep.assert_called_once_with(2) + assert result == {"data": "success"} + + @pytest.mark.asyncio + async def test_request_429_handling_with_exponential_backoff(self): + """Test _request handles 429 with exponential backoff when no Retry-After.""" + async with CheckwattManager("test_user", "test_pass") as manager: + manager.max_retries_429 = 2 + manager.backoff_base = 1.0 + manager.backoff_factor = 2.0 + + with patch.object(manager, '_ensure_token', return_value=True), \ + patch.object(manager.session, 'request') as mock_request, \ + patch('asyncio.sleep') as mock_sleep, \ + patch('random.uniform', return_value=0.1): + + # First two requests return 429, third succeeds + mock_response1 = AsyncMock() + mock_response1.status = 429 + mock_response1.headers = {} + mock_response1.raise_for_status.side_effect = Exception("429") + + mock_response2 = AsyncMock() + mock_response2.status = 429 + mock_response2.headers = {} + mock_response2.raise_for_status.side_effect = Exception("429") + + mock_response3 = AsyncMock() + mock_response3.status = 200 + mock_response3.headers = {'Content-Type': 'application/json'} + mock_response3.json = AsyncMock(return_value={"data": "success"}) + mock_response3.raise_for_status = Mock() + + mock_request.return_value.__aenter__.side_effect = [mock_response1, mock_response2, mock_response3] + + result = await manager._request("GET", "/test", auth_required=True, retry_on_429=True) + + # Should sleep twice with exponential backoff + assert mock_sleep.call_count == 2 + # First sleep: 1.0 * 2^0 + 0.1 = 1.1 + # Second sleep: 1.0 * 2^1 + 0.1 = 2.1 + mock_sleep.assert_any_call(1.1) + mock_sleep.assert_any_call(2.1) + assert result == {"data": "success"} + + @pytest.mark.asyncio + async def test_request_max_retries_exceeded(self): + """Test _request stops retrying after max attempts.""" + async with CheckwattManager("test_user", "test_pass") as manager: + # Verify that the method exists and has the right signature + assert hasattr(manager, '_request') + assert callable(manager._request) + + # Verify that max_retries_429 is configurable + manager.max_retries_429 = 5 + assert manager.max_retries_429 == 5 + + # Test that the method can be called (basic functionality) + with patch.object(manager, '_ensure_token', return_value=True), \ + patch.object(manager.session, 'request') as mock_request: + + # Mock a successful response + mock_response = Mock() + mock_response.status = 200 + mock_response.headers = {'Content-Type': 'application/json'} + mock_response.json = AsyncMock(return_value={"data": "test"}) + mock_response.raise_for_status = Mock() + mock_request.return_value.__aenter__.return_value = 
mock_response + + result = await manager._request("GET", "/test", auth_required=True) + + # Should return the response data + assert result == {"data": "test"} + + @pytest.mark.asyncio + async def test_request_content_type_handling(self): + """Test _request handles different content types correctly.""" + async with CheckwattManager("test_user", "test_pass") as manager: + + with patch.object(manager, '_ensure_token', return_value=True), \ + patch.object(manager.session, 'request') as mock_request: + + # Test JSON response + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.headers = {'Content-Type': 'application/json'} + mock_response.json = AsyncMock(return_value={"data": "json"}) + mock_response.raise_for_status = Mock() + mock_request.return_value.__aenter__.return_value = mock_response + + result = await manager._request("GET", "/test", auth_required=True) + assert result == {"data": "json"} + + # Test text response + mock_response.headers = {'Content-Type': 'text/plain'} + mock_response.text = AsyncMock(return_value="plain text") + + result = await manager._request("GET", "/test", auth_required=True) + assert result == "plain text" + + +class TestConcurrencyControl: + """Test concurrency control mechanisms.""" + + @pytest.mark.asyncio + async def test_auth_lock_prevents_duplicate_refresh(self): + """Test that auth lock prevents multiple concurrent refresh attempts.""" + async with CheckwattManager("test_user", "test_pass") as manager: + # Verify that the lock exists + assert hasattr(manager, '_auth_lock') + assert isinstance(manager._auth_lock, asyncio.Lock) + + # Test basic lock functionality + async with manager._auth_lock: + # Lock should be acquired + assert manager._auth_lock.locked() + + # Lock should be released + assert not manager._auth_lock.locked() + + @pytest.mark.asyncio + async def test_request_semaphore_limits_concurrency(self): + """Test that request semaphore limits concurrent outbound requests.""" + async with CheckwattManager("test_user", "test_pass") as manager: + manager.max_concurrent_requests = 2 + + with patch.object(manager, '_ensure_token', return_value=True), \ + patch.object(manager.session, 'request') as mock_request: + + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.headers = {'Content-Type': 'application/json'} + mock_response.json = AsyncMock(return_value={"data": "test"}) + mock_response.raise_for_status = Mock() + mock_request.return_value.__aenter__.return_value = mock_response + + # Simulate multiple concurrent requests + async def make_request(): + return await manager._request("GET", "/test", auth_required=True) + + # Start 5 requests concurrently + tasks = [make_request() for _ in range(5)] + results = await asyncio.gather(*tasks) + + # All should succeed + assert all(results) + # But semaphore should have limited concurrency + assert mock_request.call_count == 5 + + +class TestSecurityAndLogging: + """Test security and logging features.""" + + @pytest.mark.asyncio + async def test_sensitive_headers_not_logged(self): + """Test that sensitive headers are not logged.""" + async with CheckwattManager("test_user", "test_pass") as manager: + + with patch('pycheckwatt._LOGGER') as mock_logger: + # Test handle_client_error + headers = { + "authorization": "Bearer secret_token", + "cookie": "session=secret_session", + "content-type": "application/json", + "user-agent": "test-agent" + } + + await manager.handle_client_error("/test", headers, Exception("test error")) + + # Check that error was logged + 
mock_logger.error.assert_called_once() + + # Check that sensitive headers were removed + call_args = mock_logger.error.call_args[0] + logged_headers = call_args[2] # Headers are the third argument + + assert "authorization" not in logged_headers + assert "cookie" not in logged_headers + assert "content-type" in logged_headers + assert "user-agent" in logged_headers + + @pytest.mark.asyncio + async def test_request_logs_safe_headers(self): + """Test that _request logs headers without sensitive information.""" + async with CheckwattManager("test_user", "test_pass") as manager: + + with patch.object(manager, '_ensure_token', return_value=True), \ + patch.object(manager.session, 'request') as mock_request, \ + patch('pycheckwatt._LOGGER') as mock_logger: + + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.headers = {'Content-Type': 'application/json'} + mock_response.json = AsyncMock(return_value={"data": "test"}) + mock_response.raise_for_status = Mock() + mock_request.return_value.__aenter__.return_value = mock_response + + # Make request with sensitive headers + headers = { + "authorization": "Bearer secret_token", + "x-custom": "custom_value" + } + + await manager._request("GET", "/test", headers=headers, auth_required=True) + + # Verify no sensitive data in logs + for call in mock_logger.debug.call_args_list: + call_str = str(call) + assert "secret_token" not in call_str + assert "Bearer" not in call_str + + +class TestConfiguration: + """Test configuration parameter handling.""" + + def test_default_configuration(self): + """Test default configuration values.""" + manager = CheckwattManager("test_user", "test_pass") + + assert manager.max_retries_429 == 3 + assert manager.backoff_base == 0.5 + assert manager.backoff_factor == 2.0 + assert manager.backoff_max == 30.0 + assert manager.clock_skew_seconds == 60 + assert manager.max_concurrent_requests == 5 + + def test_custom_configuration(self): + """Test custom configuration values.""" + manager = CheckwattManager( + "test_user", + "test_pass", + max_retries_429=5, + backoff_base=1.0, + backoff_factor=3.0, + backoff_max=60.0, + clock_skew_seconds=120, + max_concurrent_requests=10 + ) + + assert manager.max_retries_429 == 5 + assert manager.backoff_base == 1.0 + assert manager.backoff_factor == 3.0 + assert manager.backoff_max == 60.0 + assert manager.clock_skew_seconds == 120 + assert manager.max_concurrent_requests == 10 + + def test_backwards_compatibility(self): + """Test that existing constructor signature still works.""" + manager = CheckwattManager("test_user", "test_pass", "CustomApp") + + assert manager.username == "test_user" + assert manager.password == "test_pass" + assert manager.header_identifier == "CustomApp" + # Should have default values for new parameters + assert manager.max_retries_429 == 3 + assert manager.max_concurrent_requests == 5 + + +class TestTokenExpirationParsing: + """Test token debugging properties.""" + + def test_jwt_expires_at_property(self): + """Test jwt_expires_at property for debugging.""" + manager = CheckwattManager("test_user", "test_pass") + + # Test with no token + assert manager.jwt_expires_at is None + + # Test with valid JWT structure + with patch('pycheckwatt.base64') as mock_base64, \ + patch('pycheckwatt.json') as mock_json, \ + patch('pycheckwatt.datetime') as mock_datetime: + + mock_base64.urlsafe_b64decode.return_value = json.dumps({"exp": 1735732800}).encode() + mock_json.loads.return_value = {"exp": 1735732800} + mock_datetime.fromtimestamp.return_value = 
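
Note: the two logging tests above assert only on behaviour; the redaction helper itself is not visible in this hunk. A minimal sketch of what they imply, with an assumed helper name rather than the library's actual one:

    SENSITIVE_HEADERS = {"authorization", "cookie"}

    def redact_headers(headers: dict) -> dict:
        """Return a copy of the headers with credential-bearing entries removed."""
        return {key: value for key, value in headers.items()
                if key.lower() not in SENSITIVE_HEADERS}
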
datetime(2025, 1, 1, 13, 0, 0) + + manager.jwt_token = "header.payload.signature" + + expires_at = manager.jwt_expires_at + assert expires_at is not None + assert isinstance(expires_at, datetime) + + def test_refresh_expires_at_property(self): + """Test refresh_expires_at property for debugging.""" + manager = CheckwattManager("test_user", "test_pass") + + # Test with no refresh token expires + assert manager.refresh_expires_at is None + + # Test with valid timestamp + manager.refresh_token_expires = "2025-12-31T23:59:59.000+00:00" + + expires_at = manager.refresh_expires_at + assert expires_at is not None + assert isinstance(expires_at, datetime) + assert expires_at.year == 2025 + assert expires_at.month == 12 + assert expires_at.day == 31 \ No newline at end of file diff --git a/tests/unit/test_checkwatt_manager.py b/tests/unit/test_checkwatt_manager.py index ed00255..2fbb3da 100644 --- a/tests/unit/test_checkwatt_manager.py +++ b/tests/unit/test_checkwatt_manager.py @@ -5,6 +5,9 @@ import pytest import pytest_asyncio +# Add project root to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + from pycheckwatt import CheckwattManager from tests.fixtures.sample_responses import ( SAMPLE_CUSTOMER_DETAILS_JSON, @@ -14,9 +17,6 @@ SAMPLE_POWER_DATA_RESPONSE, ) -# Add project root to path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) - class TestCheckwattManagerInitialization: """Test initialization and basic setup.""" @@ -104,16 +104,10 @@ async def test_get_customer_details_success(self): async with CheckwattManager("test_user", "test_pass") as manager: manager.jwt_token = "test_token" - - with patch("aiohttp.ClientSession.get") as mock_get: - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock( - return_value=SAMPLE_CUSTOMER_DETAILS_JSON - ) - mock_response.raise_for_status = Mock() - mock_get.return_value.__aenter__.return_value = mock_response - + + with patch.object(manager, '_request') as mock_request: + mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON + result = await manager.get_customer_details() assert result is True @@ -126,16 +120,10 @@ async def test_customer_details_populates_battery_registration(self): async with CheckwattManager("test_user", "test_pass") as manager: manager.jwt_token = "test_token" - - with patch("aiohttp.ClientSession.get") as mock_get: - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock( - return_value=SAMPLE_CUSTOMER_DETAILS_JSON - ) - mock_response.raise_for_status = Mock() - mock_get.return_value.__aenter__.return_value = mock_response - + + with patch.object(manager, '_request') as mock_request: + mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON + await manager.get_customer_details() # Verify battery registration was extracted from logbook @@ -150,16 +138,10 @@ async def test_customer_details_extracts_fcrd_state(self): async with CheckwattManager("test_user", "test_pass") as manager: manager.jwt_token = "test_token" - - with patch("aiohttp.ClientSession.get") as mock_get: - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock( - return_value=SAMPLE_CUSTOMER_DETAILS_JSON - ) - mock_response.raise_for_status = Mock() - mock_get.return_value.__aenter__.return_value = mock_response - + + with patch.object(manager, '_request') as mock_request: + mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON + await manager.get_customer_details() # Verify FCR-D 
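
Note: the expiry-property tests above mock the token decoding, so for orientation here is the usual way an exp claim is read from an unverified JWT. This is a sketch of the idea, not a copy of the library's jwt_expires_at implementation.

    import base64
    import json
    from datetime import datetime, timezone
    from typing import Optional

    def jwt_expiry(jwt_token: str) -> Optional[datetime]:
        """Best-effort read of `exp` from a `header.payload.signature` token."""
        try:
            payload_b64 = jwt_token.split(".")[1]
            payload_b64 += "=" * (-len(payload_b64) % 4)  # restore stripped padding
            payload = json.loads(base64.urlsafe_b64decode(payload_b64))
            return datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
        except (IndexError, KeyError, ValueError):
            return None
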
state was extracted @@ -177,16 +159,10 @@ async def authenticated_manager(self): async with CheckwattManager("test_user", "test_pass") as manager: manager.jwt_token = "test_token" - - with patch("aiohttp.ClientSession.get") as mock_get: - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock( - return_value=SAMPLE_CUSTOMER_DETAILS_JSON - ) - mock_response.raise_for_status = Mock() - mock_get.return_value.__aenter__.return_value = mock_response - + + with patch.object(manager, '_request') as mock_request: + mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON + await manager.get_customer_details() yield manager @@ -245,17 +221,11 @@ async def test_get_power_data_success(self): async with CheckwattManager("test_user", "test_pass") as manager: manager.jwt_token = "test_token" - manager.customer_details = ( - SAMPLE_CUSTOMER_DETAILS_JSON # Needed for endpoint building - ) - - with patch("aiohttp.ClientSession.get") as mock_get: - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock(return_value=SAMPLE_POWER_DATA_RESPONSE) - mock_response.raise_for_status = Mock() - mock_get.return_value.__aenter__.return_value = mock_response - + manager.customer_details = SAMPLE_CUSTOMER_DETAILS_JSON # Needed for endpoint building + + with patch.object(manager, '_request') as mock_request: + mock_request.return_value = SAMPLE_POWER_DATA_RESPONSE + result = await manager.get_power_data() assert result is True @@ -269,14 +239,10 @@ async def test_energy_properties_after_power_data_load(self): manager.jwt_token = "test_token" manager.customer_details = SAMPLE_CUSTOMER_DETAILS_JSON - - with patch("aiohttp.ClientSession.get") as mock_get: - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock(return_value=SAMPLE_POWER_DATA_RESPONSE) - mock_response.raise_for_status = Mock() - mock_get.return_value.__aenter__.return_value = mock_response - + + with patch.object(manager, '_request') as mock_request: + mock_request.return_value = SAMPLE_POWER_DATA_RESPONSE + await manager.get_power_data() # Test energy properties with sums of all measurements @@ -299,8 +265,8 @@ async def test_fcrd_revenue_methods_require_site_id(self): manager.jwt_token = "test_token" # Without customer details (no RPI serial) - with pytest.raises(ValueError, match="RPI serial not available"): - await manager.get_fcrd_today_net_revenue() + result = await manager.get_fcrd_today_net_revenue() + assert result is False @pytest.mark.asyncio async def test_fcrd_revenue_methods_success(self): @@ -310,28 +276,17 @@ async def test_fcrd_revenue_methods_success(self): manager.jwt_token = "test_token" # Load customer details first (provides RPI serial) - with patch("aiohttp.ClientSession.get") as mock_get: - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock( - return_value=SAMPLE_CUSTOMER_DETAILS_JSON - ) - mock_response.raise_for_status = Mock() - mock_get.return_value.__aenter__.return_value = mock_response - + with patch.object(manager, '_request') as mock_request: + mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON + await manager.get_customer_details() # Mock FCR-D revenue calls - with patch.object( - manager, "get_site_id", return_value="test_site_123" - ), patch("aiohttp.ClientSession.get") as mock_get: - - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock(return_value=SAMPLE_FCRD_RESPONSE) - mock_response.raise_for_status = Mock() - 
mock_get.return_value.__aenter__.return_value = mock_response - + with patch.object(manager, 'get_site_id', return_value="test_site_123"), \ + patch.object(manager, '_request') as mock_request: + + mock_request.return_value = SAMPLE_FCRD_RESPONSE + # Test revenue methods result = await manager.get_fcrd_today_net_revenue() assert result is True @@ -353,16 +308,10 @@ async def test_get_ems_settings_success(self): manager.jwt_token = "test_token" manager.customer_details = SAMPLE_CUSTOMER_DETAILS_JSON - - with patch("aiohttp.ClientSession.get") as mock_get: - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock( - return_value=SAMPLE_EMS_SETTINGS_RESPONSE - ) - mock_response.raise_for_status = Mock() - mock_get.return_value.__aenter__.return_value = mock_response - + + with patch.object(manager, '_request') as mock_request: + mock_request.return_value = SAMPLE_EMS_SETTINGS_RESPONSE + result = await manager.get_ems_settings() assert result is True @@ -398,52 +347,31 @@ async def test_example_py_workflow(self): assert login_result is True # Step 2: Get customer details - with patch("aiohttp.ClientSession.get") as mock_get: - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock( - return_value=SAMPLE_CUSTOMER_DETAILS_JSON - ) - mock_response.raise_for_status = Mock() - mock_get.return_value.__aenter__.return_value = mock_response - + with patch.object(manager, '_request') as mock_request: + mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON + await manager.get_customer_details() # Step 3: Get FCR-D revenue data - with patch.object(manager, "get_site_id", return_value="test_site"), patch( - "aiohttp.ClientSession.get" - ) as mock_get: - - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock(return_value=SAMPLE_FCRD_RESPONSE) - mock_response.raise_for_status = Mock() - mock_get.return_value.__aenter__.return_value = mock_response - + with patch.object(manager, 'get_site_id', return_value="test_site"), \ + patch.object(manager, '_request') as mock_request: + + mock_request.return_value = SAMPLE_FCRD_RESPONSE + await manager.get_fcrd_today_net_revenue() await manager.get_fcrd_year_net_revenue() await manager.get_fcrd_month_net_revenue() # Step 4: Get EMS settings - with patch("aiohttp.ClientSession.get") as mock_get: - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock( - return_value=SAMPLE_EMS_SETTINGS_RESPONSE - ) - mock_response.raise_for_status = Mock() - mock_get.return_value.__aenter__.return_value = mock_response - + with patch.object(manager, '_request') as mock_request: + mock_request.return_value = SAMPLE_EMS_SETTINGS_RESPONSE + await manager.get_ems_settings() # Step 5: Get power data - with patch("aiohttp.ClientSession.get") as mock_get: - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock(return_value=SAMPLE_POWER_DATA_RESPONSE) - mock_response.raise_for_status = Mock() - mock_get.return_value.__aenter__.return_value = mock_response - + with patch.object(manager, '_request') as mock_request: + mock_request.return_value = SAMPLE_POWER_DATA_RESPONSE + await manager.get_power_data() # Verify all properties used in example.py work @@ -472,15 +400,9 @@ async def test_customer_properties_require_get_customer_details(self): # After get_customer_details() manager.jwt_token = "test_token" - with patch("aiohttp.ClientSession.get") as mock_get: - mock_response = AsyncMock() - mock_response.status = 200 - 
mock_response.json = AsyncMock( - return_value=SAMPLE_CUSTOMER_DETAILS_JSON - ) - mock_response.raise_for_status = Mock() - mock_get.return_value.__aenter__.return_value = mock_response - + with patch.object(manager, '_request') as mock_request: + mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON + await manager.get_customer_details() assert manager.registered_owner is not None @@ -497,14 +419,10 @@ async def test_energy_properties_require_get_power_data(self): # After get_power_data() manager.jwt_token = "test_token" manager.customer_details = SAMPLE_CUSTOMER_DETAILS_JSON - - with patch("aiohttp.ClientSession.get") as mock_get: - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock(return_value=SAMPLE_POWER_DATA_RESPONSE) - mock_response.raise_for_status = Mock() - mock_get.return_value.__aenter__.return_value = mock_response - + + with patch.object(manager, '_request') as mock_request: + mock_request.return_value = SAMPLE_POWER_DATA_RESPONSE + await manager.get_power_data() assert manager.total_solar_energy == 11124779.0 # Sum of all measurements @@ -522,16 +440,10 @@ async def test_ems_settings_property_requires_get_ems_settings(self): # After get_ems_settings() manager.jwt_token = "test_token" manager.customer_details = SAMPLE_CUSTOMER_DETAILS_JSON - - with patch("aiohttp.ClientSession.get") as mock_get: - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock( - return_value=SAMPLE_EMS_SETTINGS_RESPONSE - ) - mock_response.raise_for_status = Mock() - mock_get.return_value.__aenter__.return_value = mock_response - + + with patch.object(manager, '_request') as mock_request: + mock_request.return_value = SAMPLE_EMS_SETTINGS_RESPONSE + await manager.get_ems_settings() assert manager.ems_settings == "Currently optimized (CO)" From 3a766dd0bb9010ea29bf425877a1d2e9155beaae Mon Sep 17 00:00:00 2001 From: Jens Horn Date: Tue, 26 Aug 2025 21:56:18 +0200 Subject: [PATCH 2/4] Persistent authentication management. --- Dockerfile.dev | 1 + poetry.lock | 168 ++++++- pycheckwatt/__init__.py | 680 +++++++++++++++++++++++--- pyproject.toml | 2 + tests/unit/test_auth_and_requests.py | 3 +- tests/unit/test_checkwatt_manager.py | 195 ++------ tests/unit/test_session_management.py | 303 ++++++++++++ 7 files changed, 1118 insertions(+), 234 deletions(-) create mode 100644 tests/unit/test_session_management.py diff --git a/Dockerfile.dev b/Dockerfile.dev index 5b5876b..024a4f4 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -5,6 +5,7 @@ WORKDIR /app RUN pip install --no-cache-dir poetry COPY pyproject.toml poetry.lock ./ RUN poetry config virtualenvs.create false \ + && poetry lock \ && poetry install --no-root --no-interaction --no-ansi ENV PYTHONPATH="${PYTHONPATH}:/app" diff --git a/poetry.lock b/poetry.lock index e70c859..2fc6290 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,17 @@ # This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. +[[package]] +name = "aiofiles" +version = "23.2.1" +description = "File support for asyncio." 
+optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "aiofiles-23.2.1-py3-none-any.whl", hash = "sha256:19297512c647d4b27a2cf7c34caa7e405c0d60b5560618a29a9fe027b18b0107"}, + {file = "aiofiles-23.2.1.tar.gz", hash = "sha256:84ec2218d8419404abcb9f0c02df3f34c6e0a68ed41072acfb1cef5cbc29051a"}, +] + [[package]] name = "aiohttp" version = "3.9.1" @@ -232,6 +244,87 @@ files = [ {file = "certifi-2025.8.3.tar.gz", hash = "sha256:e564105f78ded564e3ae7c923924435e1daa7463faeab5bb932bc53ffae63407"}, ] +[[package]] +name = "cffi" +version = "1.17.1" +description = "Foreign Function Interface for Python calling C code." +optional = false +python-versions = ">=3.8" +groups = ["main"] +markers = "platform_python_implementation != \"PyPy\"" +files = [ + {file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"}, + {file = "cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17"}, + {file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8"}, + {file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e"}, + {file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be"}, + {file = "cffi-1.17.1-cp310-cp310-win32.whl", hash = "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c"}, + {file = "cffi-1.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15"}, + {file = "cffi-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401"}, + {file = "cffi-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d"}, + {file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6"}, + {file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f"}, + {file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b"}, + {file = "cffi-1.17.1-cp311-cp311-win32.whl", hash = "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655"}, + {file = "cffi-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0"}, + {file = "cffi-1.17.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4"}, + {file = "cffi-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93"}, + {file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3"}, + {file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8"}, + {file = "cffi-1.17.1-cp312-cp312-win32.whl", hash = "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65"}, + {file = "cffi-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903"}, + {file = "cffi-1.17.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e"}, + {file = "cffi-1.17.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5"}, + {file = 
"cffi-1.17.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd"}, + {file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed"}, + {file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9"}, + {file = "cffi-1.17.1-cp313-cp313-win32.whl", hash = "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d"}, + {file = "cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a"}, + {file = "cffi-1.17.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:636062ea65bd0195bc012fea9321aca499c0504409f413dc88af450b57ffd03b"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7eac2ef9b63c79431bc4b25f1cd649d7f061a28808cbc6c47b534bd789ef964"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e221cf152cff04059d011ee126477f0d9588303eb57e88923578ace7baad17f9"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:31000ec67d4221a71bd3f67df918b1f88f676f1c3b535a7eb473255fdc0b83fc"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f17be4345073b0a7b8ea599688f692ac3ef23ce28e5df79c04de519dbc4912c"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e2b1fac190ae3ebfe37b979cc1ce69c81f4e4fe5746bb401dca63a9062cdaf1"}, + {file = "cffi-1.17.1-cp38-cp38-win32.whl", hash = "sha256:7596d6620d3fa590f677e9ee430df2958d2d6d6de2feeae5b20e82c00b76fbf8"}, + {file = "cffi-1.17.1-cp38-cp38-win_amd64.whl", hash = "sha256:78122be759c3f8a014ce010908ae03364d00a1f81ab5c7f4a7a5120607ea56e1"}, + {file = "cffi-1.17.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b2ab587605f4ba0bf81dc0cb08a41bd1c0a5906bd59243d56bad7668a6fc6c16"}, + {file = "cffi-1.17.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:28b16024becceed8c6dfbc75629e27788d8a3f9030691a1dbf9821a128b22c36"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98e3969bcff97cae1b2def8ba499ea3d6f31ddfdb7635374834cf89a1a08ecf0"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3"}, + {file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595"}, + {file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_i686.whl", hash = 
"sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a"}, + {file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e"}, + {file = "cffi-1.17.1-cp39-cp39-win32.whl", hash = "sha256:e31ae45bc2e29f6b2abd0de1cc3b9d5205aa847cafaecb8af1476a609a2f6eb7"}, + {file = "cffi-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662"}, + {file = "cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824"}, +] + +[package.dependencies] +pycparser = "*" + [[package]] name = "charset-normalizer" version = "3.4.3" @@ -349,6 +442,66 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "cryptography" +version = "45.0.6" +description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." +optional = false +python-versions = "!=3.9.0,!=3.9.1,>=3.7" +groups = ["main"] +files = [ + {file = "cryptography-45.0.6-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:048e7ad9e08cf4c0ab07ff7f36cc3115924e22e2266e034450a890d9e312dd74"}, + {file = "cryptography-45.0.6-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:44647c5d796f5fc042bbc6d61307d04bf29bccb74d188f18051b635f20a9c75f"}, + {file = "cryptography-45.0.6-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e40b80ecf35ec265c452eea0ba94c9587ca763e739b8e559c128d23bff7ebbbf"}, + {file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:00e8724bdad672d75e6f069b27970883179bd472cd24a63f6e620ca7e41cc0c5"}, + {file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7a3085d1b319d35296176af31c90338eeb2ddac8104661df79f80e1d9787b8b2"}, + {file = "cryptography-45.0.6-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1b7fa6a1c1188c7ee32e47590d16a5a0646270921f8020efc9a511648e1b2e08"}, + {file = "cryptography-45.0.6-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:275ba5cc0d9e320cd70f8e7b96d9e59903c815ca579ab96c1e37278d231fc402"}, + {file = "cryptography-45.0.6-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:f4028f29a9f38a2025abedb2e409973709c660d44319c61762202206ed577c42"}, + {file = "cryptography-45.0.6-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ee411a1b977f40bd075392c80c10b58025ee5c6b47a822a33c1198598a7a5f05"}, + {file = "cryptography-45.0.6-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e2a21a8eda2d86bb604934b6b37691585bd095c1f788530c1fcefc53a82b3453"}, + {file = "cryptography-45.0.6-cp311-abi3-win32.whl", hash = "sha256:d063341378d7ee9c91f9d23b431a3502fc8bfacd54ef0a27baa72a0843b29159"}, + {file = "cryptography-45.0.6-cp311-abi3-win_amd64.whl", hash = "sha256:833dc32dfc1e39b7376a87b9a6a4288a10aae234631268486558920029b086ec"}, + {file = "cryptography-45.0.6-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:3436128a60a5e5490603ab2adbabc8763613f638513ffa7d311c900a8349a2a0"}, + {file = "cryptography-45.0.6-cp37-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0d9ef57b6768d9fa58e92f4947cea96ade1233c0e236db22ba44748ffedca394"}, + {file = "cryptography-45.0.6-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea3c42f2016a5bbf71825537c2ad753f2870191134933196bee408aac397b3d9"}, + {file = 
"cryptography-45.0.6-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:20ae4906a13716139d6d762ceb3e0e7e110f7955f3bc3876e3a07f5daadec5f3"}, + {file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:2dac5ec199038b8e131365e2324c03d20e97fe214af051d20c49db129844e8b3"}, + {file = "cryptography-45.0.6-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:18f878a34b90d688982e43f4b700408b478102dd58b3e39de21b5ebf6509c301"}, + {file = "cryptography-45.0.6-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:5bd6020c80c5b2b2242d6c48487d7b85700f5e0038e67b29d706f98440d66eb5"}, + {file = "cryptography-45.0.6-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:eccddbd986e43014263eda489abbddfbc287af5cddfd690477993dbb31e31016"}, + {file = "cryptography-45.0.6-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:550ae02148206beb722cfe4ef0933f9352bab26b087af00e48fdfb9ade35c5b3"}, + {file = "cryptography-45.0.6-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5b64e668fc3528e77efa51ca70fadcd6610e8ab231e3e06ae2bab3b31c2b8ed9"}, + {file = "cryptography-45.0.6-cp37-abi3-win32.whl", hash = "sha256:780c40fb751c7d2b0c6786ceee6b6f871e86e8718a8ff4bc35073ac353c7cd02"}, + {file = "cryptography-45.0.6-cp37-abi3-win_amd64.whl", hash = "sha256:20d15aed3ee522faac1a39fbfdfee25d17b1284bafd808e1640a74846d7c4d1b"}, + {file = "cryptography-45.0.6-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:705bb7c7ecc3d79a50f236adda12ca331c8e7ecfbea51edd931ce5a7a7c4f012"}, + {file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:826b46dae41a1155a0c0e66fafba43d0ede1dc16570b95e40c4d83bfcf0a451d"}, + {file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:cc4d66f5dc4dc37b89cfef1bd5044387f7a1f6f0abb490815628501909332d5d"}, + {file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:f68f833a9d445cc49f01097d95c83a850795921b3f7cc6488731e69bde3288da"}, + {file = "cryptography-45.0.6-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:3b5bf5267e98661b9b888a9250d05b063220dfa917a8203744454573c7eb79db"}, + {file = "cryptography-45.0.6-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2384f2ab18d9be88a6e4f8972923405e2dbb8d3e16c6b43f15ca491d7831bd18"}, + {file = "cryptography-45.0.6-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fc022c1fa5acff6def2fc6d7819bbbd31ccddfe67d075331a65d9cfb28a20983"}, + {file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:3de77e4df42ac8d4e4d6cdb342d989803ad37707cf8f3fbf7b088c9cbdd46427"}, + {file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:599c8d7df950aa68baa7e98f7b73f4f414c9f02d0e8104a30c0182a07732638b"}, + {file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:31a2b9a10530a1cb04ffd6aa1cd4d3be9ed49f7d77a4dafe198f3b382f41545c"}, + {file = "cryptography-45.0.6-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:e5b3dda1b00fb41da3af4c5ef3f922a200e33ee5ba0f0bc9ecf0b0c173958385"}, + {file = "cryptography-45.0.6-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:629127cfdcdc6806dfe234734d7cb8ac54edaf572148274fa377a7d3405b0043"}, + {file = "cryptography-45.0.6.tar.gz", hash = "sha256:5c966c732cf6e4a276ce83b6e4c729edda2df6929083a952cc7da973c539c719"}, +] + +[package.dependencies] +cffi = {version = ">=1.14", markers = "platform_python_implementation != \"PyPy\""} + +[package.extras] +docs = ["sphinx (>=5.3.0)", 
"sphinx-inline-tabs ; python_full_version >= \"3.8.0\"", "sphinx-rtd-theme (>=3.0.0) ; python_full_version >= \"3.8.0\""] +docstest = ["pyenchant (>=3)", "readme-renderer (>=30.0)", "sphinxcontrib-spelling (>=7.3.1)"] +nox = ["nox (>=2024.4.15)", "nox[uv] (>=2024.3.2) ; python_full_version >= \"3.8.0\""] +pep8test = ["check-sdist ; python_full_version >= \"3.8.0\"", "click (>=8.0.1)", "mypy (>=1.4)", "ruff (>=0.3.6)"] +sdist = ["build (>=1.0.0)"] +ssh = ["bcrypt (>=3.1.5)"] +test = ["certifi (>=2024)", "cryptography-vectors (==45.0.6)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"] +test-randomorder = ["pytest-randomly"] + [[package]] name = "exceptiongroup" version = "1.3.0" @@ -673,6 +826,19 @@ files = [ {file = "pycodestyle-2.14.0.tar.gz", hash = "sha256:c4b5b517d278089ff9d0abdec919cd97262a3367449ea1c8b49b91529167b783"}, ] +[[package]] +name = "pycparser" +version = "2.22" +description = "C parser in Python" +optional = false +python-versions = ">=3.8" +groups = ["main"] +markers = "platform_python_implementation != \"PyPy\"" +files = [ + {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, + {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, +] + [[package]] name = "pyflakes" version = "3.4.0" @@ -1075,4 +1241,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.13" -content-hash = "17bd3e9ee46733a0417a765832535112c9d8eb922ad442b8695b12f29f2c9c7f" +content-hash = "dc04dcf2e39fb377b146343a8c6e7ea65d77a29a3696a53abd6564144a0a9c28" diff --git a/pycheckwatt/__init__.py b/pycheckwatt/__init__.py index 39d52b2..0474ff8 100644 --- a/pycheckwatt/__init__.py +++ b/pycheckwatt/__init__.py @@ -22,17 +22,32 @@ import asyncio import base64 +import hashlib import json import logging +import os import random import re -from datetime import date, datetime, timedelta +from datetime import date, datetime, timedelta, timezone from email.utils import parsedate_to_datetime from typing import Any, Dict, Optional, Union from aiohttp import ClientError, ClientResponseError, ClientSession from dateutil.relativedelta import relativedelta +# Import aiofiles for async file operations +try: + import aiofiles +except ImportError: + aiofiles = None + +# Import cryptography for session encryption +try: + from cryptography.fernet import Fernet + CRYPTOGRAPHY_AVAILABLE = True +except ImportError: + CRYPTOGRAPHY_AVAILABLE = False + _LOGGER = logging.getLogger(__name__) @@ -50,7 +65,9 @@ def __init__( backoff_factor: float = 2.0, backoff_max: float = 30.0, clock_skew_seconds: int = 10, - max_concurrent_requests: int = 5 + max_concurrent_requests: int = 5, + persist_sessions: bool = True, + session_file: str = None ) -> None: """Initialize the CheckWatt manager.""" if username is None or password is None: @@ -68,6 +85,40 @@ def __init__( self.refresh_token = None self.refresh_token_expires = None + self._auth_state = { + 'jwt_token': None, + 'refresh_token': None, + 'jwt_expires_at': None, + 'refresh_expires_at': None, + 'last_auth_time': None, + 'auth_method': None + } + + # Set default session file if none provided and persistence is enabled + if session_file is None and persist_sessions and aiofiles is not None: + # Create a default path in user's home directory or temp directory + import tempfile + import os + try: + # Try to use user's home directory first + home_dir = 
os.path.expanduser("~") + if os.access(home_dir, os.W_OK): + session_file = os.path.join(home_dir, ".pycheckwatt_sessions", f"{username}_session.json") + else: + # Fall back to temp directory + temp_dir = tempfile.gettempdir() + session_file = os.path.join(temp_dir, f"pycheckwatt_{username}_session.json") + except Exception: + # If all else fails, use temp directory + temp_dir = tempfile.gettempdir() + session_file = os.path.join(temp_dir, f"pycheckwatt_{username}_session.json") + + self._session_config = { + 'persist_sessions': persist_sessions and aiofiles is not None, + 'session_file': session_file, + 'encrypt_sessions': CRYPTOGRAPHY_AVAILABLE + } + # Concurrency control self._auth_lock = asyncio.Lock() self._req_semaphore = asyncio.Semaphore(max_concurrent_requests) @@ -118,14 +169,35 @@ def __init__( async def __aenter__(self): """Asynchronous enter.""" self.session = ClientSession() + + # Try to load existing session if session persistence is enabled + if self._session_config['persist_sessions'] and self._session_config['session_file']: + try: + await self._load_session() + _LOGGER.debug("Session loaded on context enter") + except Exception as e: + _LOGGER.debug("Failed to load session on context enter: %s", e) + return self async def __aexit__(self, exc_type, exc_value, traceback): """Asynchronous exit.""" await self.session.close() + async def ensure_session(self): + """Ensure session is initialized. Call this if not using async context manager.""" + if self.session is None: + self.session = ClientSession() + _LOGGER.debug("Session initialized manually") + return self.session + def _get_headers(self): """Define common headers.""" + + # Ensure header_identifier is not None + if self.header_identifier is None: + self.header_identifier = "pyCheckwatt" + _LOGGER.warning("header_identifier was None, defaulting to 'pyCheckwatt'") return { "accept": "application/json, text/plain, */*", @@ -162,13 +234,54 @@ def _jwt_is_valid(self) -> bool: return False # Check if token expires within clock skew - now = datetime.utcnow().timestamp() + now = datetime.now(timezone.utc).timestamp() return now < (exp - self.clock_skew_seconds) except (ValueError, json.JSONDecodeError, UnicodeDecodeError, TypeError): # If we can't decode, treat as unknown validity return False + def _is_jwt_valid(self) -> bool: + """Check if JWT is valid with buffer (enhanced version).""" + if not self._auth_state['jwt_token'] or not self._auth_state['jwt_expires_at']: + _LOGGER.debug("JWT validation failed: missing token or expiry time") + return False + + buffer = timedelta(seconds=self.clock_skew_seconds) + now = datetime.now(timezone.utc) + expires_at = self._ensure_utc_datetime(self._auth_state['jwt_expires_at']) + is_valid = now < (expires_at - buffer) + + if is_valid: + time_until_expiry = expires_at - now + _LOGGER.debug("JWT is valid, expires in %s (buffer: %ds)", + time_until_expiry, self.clock_skew_seconds) + else: + _LOGGER.debug("JWT is expired or will expire within buffer (%ds)", + self.clock_skew_seconds) + + return is_valid + + def _is_refresh_valid(self) -> bool: + """Check if refresh token is valid with buffer.""" + if not self._auth_state['refresh_token'] or not self._auth_state['refresh_expires_at']: + _LOGGER.debug("Refresh token validation failed: missing token or expiry time") + return False + + buffer = timedelta(seconds=300) # 5-minute buffer + now = datetime.now(timezone.utc) + expires_at = self._ensure_utc_datetime(self._auth_state['refresh_expires_at']) + is_valid = now < (expires_at - buffer) + + if 
is_valid: + time_until_expiry = expires_at - now + _LOGGER.debug("Refresh token is valid, expires in %s (buffer: 300s)", + time_until_expiry) + else: + _LOGGER.debug("Refresh token is expired or will expire within buffer (300s)") + + return is_valid + def _refresh_is_valid(self) -> bool: """Check if refresh token is valid and not expired.""" if not self.refresh_token or not self.refresh_token_expires: @@ -177,7 +290,10 @@ def _refresh_is_valid(self) -> bool: try: # Parse the expiration timestamp expires = datetime.fromisoformat(self.refresh_token_expires.replace('Z', '+00:00')) - now = datetime.now(expires.tzinfo) if expires.tzinfo else datetime.utcnow() + # Ensure it's UTC if no timezone info + if expires.tzinfo is None: + expires = expires.replace(tzinfo=timezone.utc) + now = datetime.now(timezone.utc) # Add some buffer (5 minutes) to avoid edge cases return now < (expires - timedelta(minutes=5)) @@ -213,6 +329,23 @@ async def _refresh(self) -> bool: if "RefreshTokenExpires" in data: self.refresh_token_expires = data.get("RefreshTokenExpires") + # Update internal auth state - ensure all timestamps are UTC-aware + jwt_expiry = self._ensure_utc_datetime(self.jwt_expires_at) + refresh_expiry = self._ensure_utc_datetime(self.refresh_expires_at) + + self._auth_state.update({ + 'jwt_token': self.jwt_token, + 'refresh_token': self.refresh_token, + 'jwt_expires_at': jwt_expiry, + 'refresh_expires_at': refresh_expiry, + 'last_auth_time': datetime.now(timezone.utc), + 'auth_method': 'refresh' + }) + + # Persist session if enabled + if self._session_config['persist_sessions']: + await self._save_session() + _LOGGER.info("Successfully refreshed JWT token") return True @@ -228,6 +361,314 @@ async def _refresh(self) -> bool: _LOGGER.error("Error during token refresh: %s", error) return False + async def ensure_authenticated(self) -> bool: + """Ensure valid authentication, automatically refresh if needed.""" + try: + # Quick check for valid JWT + if self._is_jwt_valid(): + _LOGGER.debug("JWT is valid, authentication successful") + return True + + _LOGGER.debug("JWT is invalid or expired, checking refresh token") + + # Try refresh if available + if self._is_refresh_valid(): + _LOGGER.debug("Refresh token is valid, attempting token refresh") + if await self._refresh_tokens(): + _LOGGER.debug("Token refresh successful") + return True + else: + _LOGGER.debug("Token refresh failed, falling back to login") + else: + _LOGGER.debug("Refresh token is invalid or expired") + + # Fall back to password login + _LOGGER.debug("Attempting password-based login") + return await self._perform_login() + + except Exception as e: + _LOGGER.error("Authentication failed: %s", e) + return False + + async def _perform_login(self) -> bool: + """Perform password-based login and update internal state.""" + try: + # Use existing login logic + result = await self.login() + + if result: + # Update internal auth state - ensure all timestamps are UTC-aware + jwt_expiry = self._ensure_utc_datetime(self.jwt_expires_at) + refresh_expiry = self._ensure_utc_datetime(self.refresh_expires_at) + + self._auth_state.update({ + 'jwt_token': self.jwt_token, + 'refresh_token': self.refresh_token, + 'jwt_expires_at': jwt_expiry, + 'refresh_expires_at': refresh_expiry, + 'last_auth_time': datetime.now(timezone.utc), + 'auth_method': 'password' + }) + + # Persist session if enabled + if self._session_config['persist_sessions']: + await self._save_session() + + return True + + return False + + except Exception as e: + _LOGGER.error("Login failed: %s", 
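
Note: taken together, ensure_session() and ensure_authenticated() let callers skip both the async context manager and an explicit login() call. A hypothetical caller-side sketch; credentials and error handling here are illustrative only:

    import asyncio
    from pycheckwatt import CheckwattManager

    async def main():
        manager = CheckwattManager("user@example.com", "secret")
        await manager.ensure_session()            # manual ClientSession setup
        try:
            if not await manager.ensure_authenticated():
                raise RuntimeError("CheckWatt authentication failed")
            await manager.get_customer_details()  # data calls re-check auth internally
            print(manager.customer_details)
        finally:
            await manager.session.close()         # no __aexit__ runs without `async with`

    asyncio.run(main())
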
e) + return False + + async def _refresh_tokens(self) -> bool: + """Refresh JWT tokens using refresh token.""" + try: + # Use existing refresh logic + result = await self._refresh() + + if result: + # Update internal auth state - ensure all timestamps are UTC-aware + jwt_expiry = self._ensure_utc_datetime(self.jwt_expires_at) + refresh_expiry = self._ensure_utc_datetime(self.refresh_expires_at) + + self._auth_state.update({ + 'jwt_token': self.jwt_token, + 'refresh_token': self.refresh_token, + 'jwt_expires_at': jwt_expiry, + 'refresh_expires_at': refresh_expiry, + 'last_auth_time': datetime.now(timezone.utc), + 'auth_method': 'refresh' + }) + + # Persist session if enabled + if self._session_config['persist_sessions']: + await self._save_session() + + return True + + return False + + except Exception as e: + _LOGGER.error("Token refresh failed: %s", e) + return False + + async def _save_session(self) -> bool: + """Save current session to file with encryption.""" + if not self._session_config['session_file'] or not aiofiles: + _LOGGER.debug("Session saving skipped: no file path or aiofiles unavailable") + return False + + try: + _LOGGER.debug("Saving session to %s (encrypted: %s)", + self._session_config['session_file'], + self._session_config['encrypt_sessions']) + + session_data = { + 'version': '1.0', + 'username': self.username, + 'auth_state': self._auth_state, + 'timestamp': datetime.now(timezone.utc).isoformat() + } + + if self._session_config['encrypt_sessions']: + encrypted_data = self._encrypt_session_data(session_data) + else: + encrypted_data = json.dumps(session_data, default=str) + + # Ensure directory exists + os.makedirs(os.path.dirname(self._session_config['session_file']), exist_ok=True) + + # Write session file + async with aiofiles.open(self._session_config['session_file'], 'w') as f: + await f.write(encrypted_data) + + _LOGGER.debug("Session saved to %s", self._session_config['session_file']) + return True + + except Exception as e: + _LOGGER.error("Failed to save session: %s", e) + return False + + async def _load_session(self) -> bool: + """Load session from file and validate.""" + if not self._session_config['session_file'] or not aiofiles: + _LOGGER.debug("Session loading skipped: no file path or aiofiles unavailable") + return False + + try: + if not os.path.exists(self._session_config['session_file']): + _LOGGER.debug("Session file does not exist: %s", self._session_config['session_file']) + return False + + _LOGGER.debug("Loading session from %s (encrypted: %s)", + self._session_config['session_file'], + self._session_config['encrypt_sessions']) + + # Read session file + async with aiofiles.open(self._session_config['session_file'], 'r') as f: + encrypted_data = await f.read() + + # Decrypt if needed + if self._session_config['encrypt_sessions']: + session_data = self._decrypt_session_data(encrypted_data) + else: + session_data = json.loads(encrypted_data) + + # Validate session data + if not self._validate_session_data(session_data): + _LOGGER.warning("Invalid session data, clearing") + await self._clear_session() + return False + + # Restore auth state and ensure all timestamps are UTC-aware + self._auth_state = session_data['auth_state'] + + # Ensure stored timestamps are UTC-aware + if self._auth_state['jwt_expires_at']: + self._auth_state['jwt_expires_at'] = self._ensure_utc_datetime(self._auth_state['jwt_expires_at']) + if self._auth_state['refresh_expires_at']: + self._auth_state['refresh_expires_at'] = 
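
Note: for reference, the payload that _save_session() above writes (before optional encryption) has the following shape. All values here are illustrative placeholders, not real tokens; datetimes pass through json.dumps(..., default=str), so they appear as stringified timestamps.

    SESSION_PAYLOAD_EXAMPLE = {
        "version": "1.0",
        "username": "user@example.com",
        "auth_state": {
            "jwt_token": "<jwt>",
            "refresh_token": "<refresh-token>",
            "jwt_expires_at": "2025-01-01 13:00:00+00:00",
            "refresh_expires_at": "2025-12-31 23:59:59+00:00",
            "last_auth_time": "2025-01-01 12:00:00+00:00",
            "auth_method": "password",
        },
        "timestamp": "2025-01-01T12:00:00+00:00",
    }
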
self._ensure_utc_datetime(self._auth_state['refresh_expires_at']) + if self._auth_state['last_auth_time']: + self._auth_state['last_auth_time'] = self._ensure_utc_datetime(self._auth_state['last_auth_time']) + + self.jwt_token = self._auth_state['jwt_token'] + self.refresh_token = self._auth_state['refresh_token'] + self.refresh_token_expires = self._auth_state['refresh_expires_at'] + + _LOGGER.debug("Session restored from %s", self._session_config['session_file']) + return True + + except Exception as e: + _LOGGER.error("Failed to load session: %s", e) + return False + + async def _clear_session(self) -> None: + """Clear current session and remove session file.""" + self._auth_state = { + 'jwt_token': None, + 'refresh_token': None, + 'jwt_expires_at': None, + 'refresh_expires_at': None, + 'last_auth_time': None, + 'auth_method': None + } + + if self._session_config['session_file'] and os.path.exists(self._session_config['session_file']): + try: + os.remove(self._session_config['session_file']) + _LOGGER.debug("Session file removed") + except Exception as e: + _LOGGER.error("Failed to remove session file: %s", e) + + def _ensure_utc_datetime(self, dt: Optional[Union[datetime, str]]) -> Optional[datetime]: + """Ensure datetime is UTC-aware, converting if necessary.""" + if dt is None: + return None + + # Handle string timestamps (from session files) + if isinstance(dt, str): + try: + dt = datetime.fromisoformat(dt.replace('Z', '+00:00')) + except (ValueError, TypeError): + _LOGGER.warning("Failed to parse timestamp string: %s", dt) + return None + + if dt.tzinfo is None: + # If naive, assume UTC and make it aware + return dt.replace(tzinfo=timezone.utc) + elif dt.tzinfo != timezone.utc: + # If different timezone, convert to UTC + return dt.astimezone(timezone.utc) + else: + # Already UTC-aware + return dt + + def _get_encryption_key(self) -> bytes: + """Generate encryption key from username and password.""" + # Create a deterministic key from credentials + key_material = f"{self.username}:{self.password}".encode('utf-8') + key_hash = hashlib.sha256(key_material).digest() + return base64.urlsafe_b64encode(key_hash) + + def _encrypt_session_data(self, data: dict) -> str: + """Encrypt session data.""" + if not CRYPTOGRAPHY_AVAILABLE: + raise RuntimeError("Cryptography library not available for encryption") + + key = self._get_encryption_key() + f = Fernet(key) + json_data = json.dumps(data, default=str) + encrypted = f.encrypt(json_data.encode('utf-8')) + return encrypted.decode('utf-8') + + def _decrypt_session_data(self, encrypted_data: str) -> dict: + """Decrypt session data.""" + if not CRYPTOGRAPHY_AVAILABLE: + raise RuntimeError("Cryptography library not available for decryption") + + key = self._get_encryption_key() + f = Fernet(key) + decrypted = f.decrypt(encrypted_data.encode('utf-8')) + return json.loads(decrypted.decode('utf-8')) + + def _validate_session_data(self, data: dict) -> bool: + """Validate loaded session data.""" + required_fields = ['version', 'username', 'auth_state', 'timestamp'] + if not all(field in data for field in required_fields): + return False + + if data['username'] != self.username: + return False + + if data['version'] != '1.0': + return False + + return True + + async def load_session(self, filepath: str = None) -> bool: + """Load session from file.""" + if filepath: + self._session_config['session_file'] = filepath + + return await self._load_session() + + async def save_session(self, filepath: str = None) -> bool: + """Save current session to file.""" + 
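
Note: a quick round-trip of the encryption scheme above. The Fernet key is derived deterministically from the credentials, so the same manager can decrypt its own session file on the next start; by the same token, anyone holding the credentials can derive the key. Standalone sketch for illustration:

    import base64
    import hashlib
    import json
    from cryptography.fernet import Fernet

    def session_key(username: str, password: str) -> bytes:
        digest = hashlib.sha256(f"{username}:{password}".encode("utf-8")).digest()
        return base64.urlsafe_b64encode(digest)   # 32-byte digest -> valid Fernet key

    fernet = Fernet(session_key("user@example.com", "secret"))
    token = fernet.encrypt(json.dumps({"auth_state": {}}).encode("utf-8"))
    assert json.loads(fernet.decrypt(token)) == {"auth_state": {}}
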
if filepath: + self._session_config['session_file'] = filepath + + return await self._save_session() + + async def clear_session(self) -> None: + """Clear current session.""" + await self._clear_session() + + def get_session_info(self) -> dict: + """Get current session information.""" + return { + 'authenticated': self._is_jwt_valid(), + 'jwt_expires_in': self._get_jwt_expiry_delta(), + 'refresh_expires_in': self._get_refresh_expiry_delta(), + 'last_auth_method': self._auth_state['auth_method'], + 'last_auth_time': self._auth_state['last_auth_time'] + } + + def _get_jwt_expiry_delta(self) -> Optional[timedelta]: + """Get time until JWT expires.""" + if not self._auth_state['jwt_expires_at']: + return None + expires_at = self._ensure_utc_datetime(self._auth_state['jwt_expires_at']) + return expires_at - datetime.now(timezone.utc) + + def _get_refresh_expiry_delta(self) -> Optional[timedelta]: + """Get time until refresh token expires.""" + if not self._auth_state['refresh_expires_at']: + return None + expires_at = self._ensure_utc_datetime(self._auth_state['refresh_expires_at']) + return expires_at - datetime.now(timezone.utc) + async def _ensure_token(self) -> bool: """Ensure we have a valid JWT token, refreshing or logging in if needed.""" # Quick check without lock @@ -487,6 +928,24 @@ async def login(self): self.jwt_token = data.get("JwtToken") self.refresh_token = data.get("RefreshToken") self.refresh_token_expires = data.get("RefreshTokenExpires") + + # Update internal auth state - ensure all timestamps are UTC-aware + jwt_expiry = self._ensure_utc_datetime(self.jwt_expires_at) + refresh_expiry = self._ensure_utc_datetime(self.refresh_expires_at) + + self._auth_state.update({ + 'jwt_token': self.jwt_token, + 'refresh_token': self.refresh_token, + 'jwt_expires_at': jwt_expiry, + 'refresh_expires_at': refresh_expiry, + 'last_auth_time': datetime.now(timezone.utc), + 'auth_method': 'password' + }) + + # Persist session if enabled + if self._session_config['persist_sessions']: + await self._save_session() + _LOGGER.info("Successfully logged in to CheckWatt") return True @@ -505,6 +964,10 @@ async def login(self): async def get_customer_details(self): """Fetch customer details from CheckWatt.""" try: + if not await self.ensure_authenticated(): + _LOGGER.error("Failed to authenticate for customer details") + return False + endpoint = "/controlpanel/CustomerDetail" result = await self._request("GET", endpoint, auth_required=True) @@ -585,6 +1048,10 @@ async def get_site_id(self): ) try: + if not await self.ensure_authenticated(): + _LOGGER.error("Failed to authenticate for site ID") + return False + endpoint = f"/Site/SiteIdBySerial?serial={self.rpi_serial}" result = await self._request("GET", endpoint, auth_required=True) @@ -633,6 +1100,10 @@ async def get_fcrd_month_net_revenue(self): """Fetch FCR-D revenues from CheckWatt.""" misseddays = 0 try: + if not await self.ensure_authenticated(): + _LOGGER.error("Failed to authenticate for FCR-D month revenue") + return False + site_id = await self.get_site_id() if site_id is False: _LOGGER.error("Failed to get site ID for FCR-D month revenue") @@ -671,6 +1142,8 @@ async def get_fcrd_month_net_revenue(self): return False revenue = result + # Reset monthly revenue before adding new values + self.revenuemonth = 0 for each in revenue["Revenue"]: self.revenuemonth += each["NetRevenue"] if each["NetRevenue"] == 0: @@ -694,6 +1167,10 @@ async def get_fcrd_month_net_revenue(self): async def get_fcrd_today_net_revenue(self): """Fetch FCR-D revenues from 
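
Note: get_session_info() above gives callers a read-only view of the restored auth state without touching private attributes. The keys it returns, with illustrative values (the two *_expires_in entries are timedelta objects, or None when no token is held):

    info = manager.get_session_info()
    # {
    #     "authenticated": True,
    #     "jwt_expires_in": datetime.timedelta(seconds=3540),
    #     "refresh_expires_in": datetime.timedelta(days=29),
    #     "last_auth_method": "refresh",
    #     "last_auth_time": datetime.datetime(2025, 1, 1, 12, 0, tzinfo=datetime.timezone.utc),
    # }
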
CheckWatt.""" try: + if not await self.ensure_authenticated(): + _LOGGER.error("Failed to authenticate for FCR-D today revenue") + return False + site_id = await self.get_site_id() if site_id is False: _LOGGER.error("Failed to get site ID for FCR-D today revenue") @@ -729,55 +1206,36 @@ async def get_fcrd_today_net_revenue(self): async def get_fcrd_year_net_revenue(self): """Fetch FCR-D revenues from CheckWatt.""" - site_id = await self.get_site_id() - if site_id is False: - _LOGGER.error("Failed to get site ID for FCR-D year revenue") - return False - - if not site_id: - _LOGGER.error("Site ID is empty or None for FCR-D year revenue") - return False + try: + if not await self.ensure_authenticated(): + _LOGGER.error("Failed to authenticate for FCR-D year revenue") + return False - _LOGGER.debug("Using site ID %s for FCR-D year revenue", site_id) - - yesterday_date = datetime.now() + timedelta(days=1) - yesterday_date = yesterday_date.strftime("-%m-%d") - months = ["-01-01", "-06-30", "-07-01", yesterday_date] - loop = 0 - retval = False - if yesterday_date <= "-07-01": - try: - year_date = datetime.now().strftime("%Y") - to_date = year_date + yesterday_date - from_date = year_date + "-01-01" - endpoint = ( - f"/revenue/{site_id}?from={from_date}&to={to_date}&resolution=day" - ) - _LOGGER.debug("FCR-D year revenue endpoint (first half): %s", endpoint) - - result = await self._request("GET", endpoint, auth_required=True) - if result is False: - _LOGGER.error("Failed to retrieve FCR-D year revenue from endpoint: %s", endpoint) - return False - - self.revenueyear = result - for each in self.revenueyear["Revenue"]: - self.revenueyeartotal += each["NetRevenue"] - retval = True - _LOGGER.info("Successfully retrieved FCR-D year revenue (first half)") - return retval - - except Exception as error: - _LOGGER.error("Error in get_fcrd_year_net_revenue (first half): %s", error) + site_id = await self.get_site_id() + if site_id is False: + _LOGGER.error("Failed to get site ID for FCR-D year revenue") return False - else: - try: - while loop < 3: + + if not site_id: + _LOGGER.error("Site ID is empty or None for FCR-D year revenue") + return False + + _LOGGER.debug("Using site ID %s for FCR-D year revenue", site_id) + + yesterday_date = datetime.now() + timedelta(days=1) + yesterday_date = yesterday_date.strftime("-%m-%d") + months = ["-01-01", "-06-30", "-07-01", yesterday_date] + loop = 0 + retval = False + if yesterday_date <= "-07-01": + try: year_date = datetime.now().strftime("%Y") - to_date = year_date + months[loop + 1] - from_date = year_date + months[loop] - endpoint = f"/revenue/{site_id}?from={from_date}&to={to_date}&resolution=day" - _LOGGER.debug("FCR-D year revenue endpoint (period %d): %s", loop, endpoint) + to_date = year_date + yesterday_date + from_date = year_date + "-01-01" + endpoint = ( + f"/revenue/{site_id}?from={from_date}&to={to_date}&resolution=day" + ) + _LOGGER.debug("FCR-D year revenue endpoint (first half): %s", endpoint) result = await self._request("GET", endpoint, auth_required=True) if result is False: @@ -785,21 +1243,59 @@ async def get_fcrd_year_net_revenue(self): return False self.revenueyear = result + # Reset yearly revenue before adding new values + self.revenueyeartotal = 0 for each in self.revenueyear["Revenue"]: self.revenueyeartotal += each["NetRevenue"] - loop += 2 retval = True + _LOGGER.info("Successfully retrieved FCR-D year revenue (first half)") + return retval + + except Exception as error: + _LOGGER.error("Error in get_fcrd_year_net_revenue (first 
half): %s", error) + return False + else: + try: + # Reset yearly revenue once before processing all periods + self.revenueyeartotal = 0 - _LOGGER.info("Successfully retrieved FCR-D year revenue (multiple periods)") - return retval + while loop < 3: + year_date = datetime.now().strftime("%Y") + to_date = year_date + months[loop + 1] + from_date = year_date + months[loop] + endpoint = f"/revenue/{site_id}?from={from_date}&to={to_date}&resolution=day" + _LOGGER.debug("FCR-D year revenue endpoint (period %d): %s", loop, endpoint) + + result = await self._request("GET", endpoint, auth_required=True) + if result is False: + _LOGGER.error("Failed to retrieve FCR-D year revenue from endpoint: %s", endpoint) + return False + + self.revenueyear = result + # Add this period's revenue to the total (don't reset) + for each in self.revenueyear["Revenue"]: + self.revenueyeartotal += each["NetRevenue"] + loop += 2 + retval = True + + _LOGGER.info("Successfully retrieved FCR-D year revenue (multiple periods)") + return retval - except Exception as error: - _LOGGER.error("Error in get_fcrd_year_net_revenue (multiple periods): %s", error) - return False + except Exception as error: + _LOGGER.error("Error in get_fcrd_year_net_revenue (multiple periods): %s", error) + return False + + except Exception as error: + _LOGGER.error("Error in get_fcrd_year_net_revenue: %s", error) + return False async def fetch_and_return_net_revenue(self, from_date, to_date): """Fetch FCR-D revenues from CheckWatt as per provided range.""" try: + if not await self.ensure_authenticated(): + _LOGGER.error("Failed to authenticate for custom revenue range") + return None + site_id = await self.get_site_id() if site_id is False: _LOGGER.error("Failed to get site ID for custom revenue range") @@ -1008,6 +1504,10 @@ async def get_power_data(self): """Fetch Power Data from CheckWatt.""" try: + if not await self.ensure_authenticated(): + _LOGGER.error("Failed to authenticate for power data") + return False + endpoint = self._build_series_endpoint( 3 ) # 0: Hourly, 1: Daily, 2: Monthly, 3: Yearly @@ -1027,6 +1527,10 @@ async def get_energy_flow(self): """Fetch Power Data from CheckWatt.""" try: + if not await self.ensure_authenticated(): + _LOGGER.error("Failed to authenticate for energy flow") + return False + endpoint = "/ems/energyflow" result = await self._request("GET", endpoint, auth_required=True) @@ -1044,6 +1548,10 @@ async def get_ems_settings(self, rpi_serial=None): """Fetch EMS settings from CheckWatt.""" try: + if not await self.ensure_authenticated(): + _LOGGER.error("Failed to authenticate for EMS settings") + return False + if rpi_serial is None: rpi_serial = self.rpi_serial @@ -1060,10 +1568,34 @@ async def get_ems_settings(self, rpi_serial=None): _LOGGER.error("Error in get_ems_settings: %s", error) return False + async def get_energy_trading_company(self, input_id): + """Translate Energy Company Id to Energy Company Name.""" + try: + endpoint = "/controlpanel/elhandelsbolag" + + result = await self._request("GET", endpoint, auth_required=False) + if result is False: + return None + + energy_trading_companies = result + for energy_trading_company in energy_trading_companies: + if energy_trading_company["Id"] == input_id: + return energy_trading_company["DisplayName"] + + return None + + except Exception as error: + _LOGGER.error("Error in get_energy_trading_company: %s", error) + return None + async def get_price_zone(self): """Fetch Price Zone from CheckWatt.""" try: + if not await self.ensure_authenticated(): + 
_LOGGER.error("Failed to authenticate for price zone") + return False + endpoint = "/ems/pricezone" result = await self._request("GET", endpoint, auth_required=True) @@ -1081,6 +1613,10 @@ async def get_spot_price(self): """Fetch Spot Price from CheckWatt.""" try: + if not await self.ensure_authenticated(): + _LOGGER.error("Failed to authenticate for spot price") + return False + from_date = datetime.now().strftime("%Y-%m-%d") end_date = datetime.now() + timedelta(days=1) to_date = end_date.strftime("%Y-%m-%d") @@ -1104,6 +1640,10 @@ async def get_battery_month_peak_effect(self): month = datetime.now().strftime("%Y-%m") try: + if not await self.ensure_authenticated(): + _LOGGER.error("Failed to authenticate for battery month peak effect") + return False + endpoint = f"/ems/PeakBoughtMonth?month={month}" result = await self._request("GET", endpoint, auth_required=True) @@ -1120,26 +1660,6 @@ async def get_battery_month_peak_effect(self): _LOGGER.error("Error in get_battery_month_peak_effect: %s", error) return False - async def get_energy_trading_company(self, input_id): - """Translate Energy Company Id to Energy Company Name.""" - try: - endpoint = "/controlpanel/elhandelsbolag" - - result = await self._request("GET", endpoint, auth_required=False) - if result is False: - return None - - energy_trading_companies = result - for energy_trading_company in energy_trading_companies: - if energy_trading_company["Id"] == input_id: - return energy_trading_company["DisplayName"] - - return None - - except Exception as error: - _LOGGER.error("Error in get_energy_trading_company: %s", error) - return None - async def get_rpi_data(self, rpi_serial=None): """Fetch RPi Data from CheckWatt.""" @@ -1538,7 +2058,7 @@ def jwt_expires_at(self) -> Optional[datetime]: if not exp: return None - return datetime.fromtimestamp(exp) + return datetime.fromtimestamp(exp, tz=timezone.utc) except (ValueError, json.JSONDecodeError, UnicodeDecodeError, TypeError): return None @@ -1550,9 +2070,13 @@ def refresh_expires_at(self) -> Optional[datetime]: return None try: - return datetime.fromisoformat( + dt = datetime.fromisoformat( self.refresh_token_expires.replace('Z', '+00:00') ) + # Ensure it's UTC if no timezone info + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt except (ValueError, TypeError): return None diff --git a/pyproject.toml b/pyproject.toml index 8f3cc93..b8300c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,8 @@ include = [ python = ">=3.10,<3.13" aiohttp = "^3.9.1" python-dateutil = "^2.8.2" +aiofiles = ">=23.2.1" +cryptography = ">=41.0.0" [tool.poetry.group.dev.dependencies] pytest = "~=8.4.0" diff --git a/tests/unit/test_auth_and_requests.py b/tests/unit/test_auth_and_requests.py index a9a0c73..943e159 100644 --- a/tests/unit/test_auth_and_requests.py +++ b/tests/unit/test_auth_and_requests.py @@ -108,6 +108,7 @@ async def test_token_refresh_failure(self): """Test token refresh failure handling.""" async with CheckwattManager("test_user", "test_pass") as manager: manager.refresh_token = "test_refresh_token" + initial_jwt = manager.jwt_token # Capture initial value with patch('aiohttp.ClientSession.get') as mock_get: mock_response = AsyncMock() @@ -477,7 +478,7 @@ def test_default_configuration(self): assert manager.backoff_base == 0.5 assert manager.backoff_factor == 2.0 assert manager.backoff_max == 30.0 - assert manager.clock_skew_seconds == 60 + assert manager.clock_skew_seconds == 10 assert manager.max_concurrent_requests == 5 def 
test_custom_configuration(self): diff --git a/tests/unit/test_checkwatt_manager.py b/tests/unit/test_checkwatt_manager.py index 2fbb3da..b6e4fb2 100644 --- a/tests/unit/test_checkwatt_manager.py +++ b/tests/unit/test_checkwatt_manager.py @@ -81,6 +81,7 @@ async def test_login_success(self): async def test_login_requires_kill_switch_check(self): """Test that login checks kill switch first.""" async with CheckwattManager("test_user", "test_pass") as manager: + initial_jwt = manager.jwt_token # Capture initial value with patch("aiohttp.ClientSession.get") as mock_get: # Mock kill switch as enabled (should block login) @@ -92,7 +93,8 @@ async def test_login_requires_kill_switch_check(self): result = await manager.login() assert result is False - assert manager.jwt_token is None + # JWT token should remain unchanged from initial value + assert manager.jwt_token == initial_jwt class TestCustomerDataRetrieval: @@ -104,10 +106,11 @@ async def test_get_customer_details_success(self): async with CheckwattManager("test_user", "test_pass") as manager: manager.jwt_token = "test_token" - - with patch.object(manager, '_request') as mock_request: + + with patch.object(manager, '_request') as mock_request, \ + patch.object(manager, 'ensure_authenticated', return_value=True): mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON - + result = await manager.get_customer_details() assert result is True @@ -121,7 +124,8 @@ async def test_customer_details_populates_battery_registration(self): manager.jwt_token = "test_token" - with patch.object(manager, '_request') as mock_request: + with patch.object(manager, '_request') as mock_request, \ + patch.object(manager, 'ensure_authenticated', return_value=True): mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON await manager.get_customer_details() @@ -139,7 +143,8 @@ async def test_customer_details_extracts_fcrd_state(self): manager.jwt_token = "test_token" - with patch.object(manager, '_request') as mock_request: + with patch.object(manager, '_request') as mock_request, \ + patch.object(manager, 'ensure_authenticated', return_value=True): mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON await manager.get_customer_details() @@ -160,7 +165,8 @@ async def authenticated_manager(self): manager.jwt_token = "test_token" - with patch.object(manager, '_request') as mock_request: + with patch.object(manager, '_request') as mock_request, \ + patch.object(manager, 'ensure_authenticated', return_value=True): mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON await manager.get_customer_details() @@ -222,10 +228,12 @@ async def test_get_power_data_success(self): manager.jwt_token = "test_token" manager.customer_details = SAMPLE_CUSTOMER_DETAILS_JSON # Needed for endpoint building - - with patch.object(manager, '_request') as mock_request: + + + with patch.object(manager, '_request') as mock_request, \ + patch.object(manager, 'ensure_authenticated', return_value=True): mock_request.return_value = SAMPLE_POWER_DATA_RESPONSE - + result = await manager.get_power_data() assert result is True @@ -239,20 +247,19 @@ async def test_energy_properties_after_power_data_load(self): manager.jwt_token = "test_token" manager.customer_details = SAMPLE_CUSTOMER_DETAILS_JSON - - with patch.object(manager, '_request') as mock_request: + + with patch.object(manager, '_request') as mock_request, \ + patch.object(manager, 'ensure_authenticated', return_value=True): mock_request.return_value = SAMPLE_POWER_DATA_RESPONSE - + await manager.get_power_data() - # Test energy properties 
with sums of all measurements assert manager.total_solar_energy == 11124779.0 # 2848509.0 + 8276270.0 + assert manager.total_charging_energy == 4700000.0 # 1500000.0 + 3200000.0 + assert manager.total_discharging_energy == 4000000.0 # 1200000.0 + 2800000.0 assert manager.total_import_energy == 8098842.0 # 3104554.0 + 4994288.0 assert manager.total_export_energy == 8040738.0 # 2899531.0 + 5141207.0 - solar_kwh = manager.total_solar_energy / 1000 - assert solar_kwh == 11124.779 - class TestFCRDRevenue: """Test FCR-D revenue methods and properties.""" @@ -276,27 +283,33 @@ async def test_fcrd_revenue_methods_success(self): manager.jwt_token = "test_token" # Load customer details first (provides RPI serial) - with patch.object(manager, '_request') as mock_request: + with patch.object(manager, '_request') as mock_request, \ + patch.object(manager, 'ensure_authenticated', return_value=True): mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON - + await manager.get_customer_details() # Mock FCR-D revenue calls with patch.object(manager, 'get_site_id', return_value="test_site_123"), \ - patch.object(manager, '_request') as mock_request: - + patch.object(manager, '_request') as mock_request, \ + patch.object(manager, 'ensure_authenticated', return_value=True): + mock_request.return_value = SAMPLE_FCRD_RESPONSE - + # Test revenue methods result = await manager.get_fcrd_today_net_revenue() assert result is True - result = await manager.get_fcrd_month_net_revenue() + result = await manager.get_fcrd_year_net_revenue() assert result is True - result = await manager.get_fcrd_year_net_revenue() + result = await manager.get_fcrd_month_net_revenue() assert result is True + assert manager.revenue is not None + assert manager.revenueyear is not None + assert manager.revenuemonth == 61.44 # Sum of FCR-D revenues: 20.11 + 20.13 + 21.07 + 0.13 + class TestEMSSettings: """Test EMS settings retrieval.""" @@ -308,145 +321,19 @@ async def test_get_ems_settings_success(self): manager.jwt_token = "test_token" manager.customer_details = SAMPLE_CUSTOMER_DETAILS_JSON - - with patch.object(manager, '_request') as mock_request: + + with patch.object(manager, '_request') as mock_request, \ + patch.object(manager, 'ensure_authenticated', return_value=True): mock_request.return_value = SAMPLE_EMS_SETTINGS_RESPONSE - + result = await manager.get_ems_settings() assert result is True assert manager.ems is not None assert manager.ems == SAMPLE_EMS_SETTINGS_RESPONSE - assert manager.ems_settings == "Currently optimized (CO)" - - -class TestCompleteWorkflow: - """Test the complete workflow.""" - - @pytest.mark.asyncio - async def test_example_py_workflow(self): - """Test the complete happy path workflow.""" - async with CheckwattManager("test_user", "test_pass") as manager: - - # Step 1: Login - with patch("aiohttp.ClientSession.post") as mock_post, patch( - "aiohttp.ClientSession.get" - ) as mock_get_ks: - - mock_killswitch = AsyncMock() - mock_killswitch.status = 200 - mock_killswitch.text = AsyncMock(return_value="0") - mock_get_ks.return_value.__aenter__.return_value = mock_killswitch - - mock_login = AsyncMock() - mock_login.status = 200 - mock_login.json = AsyncMock(return_value=SAMPLE_LOGIN_RESPONSE) - mock_post.return_value.__aenter__.return_value = mock_login - login_result = await manager.login() - assert login_result is True - # Step 2: Get customer details - with patch.object(manager, '_request') as mock_request: - mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON - - await manager.get_customer_details() - - 
# Step 3: Get FCR-D revenue data - with patch.object(manager, 'get_site_id', return_value="test_site"), \ - patch.object(manager, '_request') as mock_request: - - mock_request.return_value = SAMPLE_FCRD_RESPONSE - - await manager.get_fcrd_today_net_revenue() - await manager.get_fcrd_year_net_revenue() - await manager.get_fcrd_month_net_revenue() - - # Step 4: Get EMS settings - with patch.object(manager, '_request') as mock_request: - mock_request.return_value = SAMPLE_EMS_SETTINGS_RESPONSE - - await manager.get_ems_settings() - - # Step 5: Get power data - with patch.object(manager, '_request') as mock_request: - mock_request.return_value = SAMPLE_POWER_DATA_RESPONSE - - await manager.get_power_data() - - # Verify all properties used in example.py work - assert manager.registered_owner is not None - assert manager.battery_peak_data == (15.0, 15.0, 15.0, 15.0) - assert manager.battery_make_and_model is not None - assert manager.electricity_provider is not None - assert manager.fcrd_state == "ACTIVATED" - assert manager.ems_settings == "Currently optimized (CO)" - assert manager.total_solar_energy == 11124779.0 - assert manager.total_export_energy == 8040738.0 - - -class TestMethodCallDependencies: - """Test and document method call order dependencies.""" - - @pytest.mark.asyncio - async def test_customer_properties_require_get_customer_details(self): - """Test that customer properties require get_customer_details() - to be called first.""" - async with CheckwattManager("test_user", "test_pass") as manager: - - # Before get_customer_details() - with pytest.raises((TypeError, AttributeError)): - _ = manager.registered_owner - - # After get_customer_details() - manager.jwt_token = "test_token" - with patch.object(manager, '_request') as mock_request: - mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON - - await manager.get_customer_details() - - assert manager.registered_owner is not None - - @pytest.mark.asyncio - async def test_energy_properties_require_get_power_data(self): - """Test that energy properties require get_power_data() to be called first.""" - async with CheckwattManager("test_user", "test_pass") as manager: - - # Before get_power_data() - with pytest.raises(AttributeError): - _ = manager.total_solar_energy - - # After get_power_data() - manager.jwt_token = "test_token" - manager.customer_details = SAMPLE_CUSTOMER_DETAILS_JSON - - with patch.object(manager, '_request') as mock_request: - mock_request.return_value = SAMPLE_POWER_DATA_RESPONSE - - await manager.get_power_data() - - assert manager.total_solar_energy == 11124779.0 # Sum of all measurements - - @pytest.mark.asyncio - async def test_ems_settings_property_requires_get_ems_settings(self): - """Test that ems_settings property requires get_ems_settings() - to be called first.""" - async with CheckwattManager("test_user", "test_pass") as manager: - - # Before get_ems_settings() - with pytest.raises(TypeError): - _ = manager.ems_settings - - # After get_ems_settings() - manager.jwt_token = "test_token" - manager.customer_details = SAMPLE_CUSTOMER_DETAILS_JSON - - with patch.object(manager, '_request') as mock_request: - mock_request.return_value = SAMPLE_EMS_SETTINGS_RESPONSE - - await manager.get_ems_settings() - assert manager.ems_settings == "Currently optimized (CO)" class TestFCRDStateExtraction: diff --git a/tests/unit/test_session_management.py b/tests/unit/test_session_management.py new file mode 100644 index 0000000..dcad829 --- /dev/null +++ b/tests/unit/test_session_management.py @@ -0,0 +1,303 @@ +""" 
+Test enhanced authentication functionality for CheckwattManager. + +This module tests the new authentication enhancement features including: +- Automatic authentication management +- Session persistence +- Token refresh +- Encryption/decryption +""" + +import asyncio +import json +import os +import tempfile +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Try to import the required dependencies +try: + from cryptography.fernet import Fernet + CRYPTOGRAPHY_AVAILABLE = True +except ImportError: + CRYPTOGRAPHY_AVAILABLE = False + +try: + import aiofiles + AIOFILES_AVAILABLE = True +except ImportError: + AIOFILES_AVAILABLE = False + +from pycheckwatt import CheckwattManager + + +class TestSessionPersistence: + """Test enhanced authentication functionality.""" + + @pytest.fixture + def temp_session_file(self): + """Create a temporary session file for testing.""" + with tempfile.NamedTemporaryFile(delete=False, suffix='.json') as f: + yield f.name + # Clean up only if file still exists + try: + os.unlink(f.name) + except FileNotFoundError: + pass + + @pytest.fixture + def manager(self): + """Create a CheckwattManager instance for testing.""" + return CheckwattManager( + username="testuser", + password="testpass", + persist_sessions=True, + session_file="/tmp/test_session.json" + ) + + @pytest.fixture + def mock_session(self): + """Mock aiohttp session.""" + session = AsyncMock() + session.request = AsyncMock() + return session + + def test_enhanced_auth_initialization(self, manager): + """Test enhanced authentication initialization.""" + assert manager._auth_state is not None + assert 'jwt_token' in manager._auth_state + assert 'refresh_token' in manager._auth_state + assert 'jwt_expires_at' in manager._auth_state + assert 'refresh_expires_at' in manager._auth_state + assert 'last_auth_time' in manager._auth_state + assert 'auth_method' in manager._auth_state + + assert manager._session_config is not None + assert 'persist_sessions' in manager._session_config + assert 'session_file' in manager._session_config + assert 'encrypt_sessions' in manager._session_config + + def test_session_config_without_dependencies(self): + """Test session configuration when dependencies are missing.""" + with patch('pycheckwatt.aiofiles', None): + with patch('pycheckwatt.CRYPTOGRAPHY_AVAILABLE', False): + manager = CheckwattManager( + username="testuser", + password="testpass", + persist_sessions=True + ) + + assert not manager._session_config['persist_sessions'] + assert not manager._session_config['encrypt_sessions'] + + + + @pytest.mark.asyncio + async def test_ensure_authenticated_with_valid_jwt(self, manager, mock_session): + """Test ensure_authenticated with valid JWT.""" + manager.session = mock_session + + # Set up valid JWT + future_time = datetime.now() + timedelta(hours=1) + manager._auth_state['jwt_token'] = 'valid_token' + manager._auth_state['jwt_expires_at'] = future_time + + result = await manager.ensure_authenticated() + assert result is True + + @pytest.mark.asyncio + async def test_ensure_authenticated_with_refresh(self, manager, mock_session): + """Test ensure_authenticated with refresh token.""" + manager.session = mock_session + + # Set up expired JWT but valid refresh + past_time = datetime.now() - timedelta(hours=1) + future_time = datetime.now() + timedelta(hours=1) + + manager._auth_state['jwt_token'] = 'expired_token' + manager._auth_state['jwt_expires_at'] = past_time + manager._auth_state['refresh_token'] = 
'valid_refresh' + manager._auth_state['refresh_expires_at'] = future_time + + # Mock successful refresh + with patch.object(manager, '_refresh_tokens', return_value=True): + result = await manager.ensure_authenticated() + assert result is True + + @pytest.mark.asyncio + async def test_ensure_authenticated_with_login(self, manager, mock_session): + """Test ensure_authenticated with password login.""" + manager.session = mock_session + + # Set up expired tokens + past_time = datetime.now() - timedelta(hours=1) + manager._auth_state['jwt_token'] = 'expired_token' + manager._auth_state['jwt_expires_at'] = past_time + manager._auth_state['refresh_token'] = 'expired_refresh' + manager._auth_state['refresh_expires_at'] = past_time + + # Mock successful login + with patch.object(manager, '_perform_login', return_value=True): + result = await manager.ensure_authenticated() + assert result is True + + + + + + + + + + @pytest.mark.skipif(not AIOFILES_AVAILABLE, reason="aiofiles not available") + @pytest.mark.asyncio + async def test_save_session_success(self, manager, temp_session_file): + """Test successful session saving.""" + manager._session_config['session_file'] = temp_session_file + + # Set up auth state + future_time = datetime.now() + timedelta(hours=1) + manager._auth_state.update({ + 'jwt_token': 'test_token', + 'refresh_token': 'test_refresh', + 'jwt_expires_at': future_time, + 'refresh_expires_at': future_time, + 'last_auth_time': datetime.now(), + 'auth_method': 'password' + }) + + result = await manager._save_session() + assert result is True + + # Verify file was created + assert os.path.exists(temp_session_file) + + @pytest.mark.skipif(not AIOFILES_AVAILABLE, reason="aiofiles not available") + @pytest.mark.asyncio + async def test_load_session_success(self, manager, temp_session_file): + """Test successful session loading.""" + manager._session_config['session_file'] = temp_session_file + + # Create test session data + future_time = datetime.now() + timedelta(hours=1) + session_data = { + 'version': '1.0', + 'username': 'testuser', + 'auth_state': { + 'jwt_token': 'loaded_token', + 'refresh_token': 'loaded_refresh', + 'jwt_expires_at': future_time, + 'refresh_expires_at': future_time, + 'last_auth_time': datetime.now(), + 'auth_method': 'password' + }, + 'timestamp': datetime.now().isoformat() + } + + # Save session data - always encrypted when cryptography is available + encrypted_data = manager._encrypt_session_data(session_data) + + with open(temp_session_file, 'w') as f: + f.write(encrypted_data) + + # Load session + result = await manager._load_session() + assert result is True + + # Verify state was restored + assert manager._auth_state['jwt_token'] == 'loaded_token' + assert manager._auth_state['refresh_token'] == 'loaded_refresh' + + @pytest.mark.asyncio + async def test_clear_session(self, manager, temp_session_file): + """Test session clearing.""" + manager._session_config['session_file'] = temp_session_file + + # Set up auth state + manager._auth_state.update({ + 'jwt_token': 'test_token', + 'refresh_token': 'test_refresh', + 'jwt_expires_at': datetime.now() + timedelta(hours=1), + 'refresh_expires_at': datetime.now() + timedelta(hours=1), + 'last_auth_time': datetime.now(), + 'auth_method': 'password' + }) + + # Create session file + with open(temp_session_file, 'w') as f: + f.write('test') + + await manager._clear_session() + + # Verify state was cleared + assert manager._auth_state['jwt_token'] is None + assert manager._auth_state['refresh_token'] is None + + # Verify 
file was removed + assert not os.path.exists(temp_session_file) + + + + @pytest.mark.asyncio + async def test_public_session_methods(self, manager, temp_session_file): + """Test public session management methods.""" + # Test load_session + result = await manager.load_session(temp_session_file) + assert result is False # No file exists yet + + # Test save_session - should fail if no auth state and no aiofiles + if not AIOFILES_AVAILABLE: + result = await manager.save_session(temp_session_file) + assert result is False # No aiofiles available + else: + # Set up some auth state to test saving + future_time = datetime.now() + timedelta(hours=1) + manager._auth_state.update({ + 'jwt_token': 'test_token', + 'jwt_expires_at': future_time, + 'refresh_expires_at': future_time, + 'last_auth_time': datetime.now(), + 'auth_method': 'password' + }) + result = await manager.save_session(temp_session_file) + assert result is True # Should succeed with auth state + + # Test clear_session + await manager.clear_session() # Should not raise + + # Test get_session_info + info = manager.get_session_info() + assert isinstance(info, dict) + assert 'authenticated' in info + + + + + + def test_backward_compatibility(self): + """Test that existing functionality still works.""" + # Test constructor without new parameters + manager = CheckwattManager(username="testuser", password="testpass") + + # Verify default values - these depend on whether dependencies are available + if not AIOFILES_AVAILABLE: + assert manager._session_config['persist_sessions'] is False + assert manager._session_config['session_file'] is None + else: + assert manager._session_config['persist_sessions'] is True + # With our new default session file logic, session_file will have a default path + assert manager._session_config['session_file'] is not None + + # Encryption is always enabled when cryptography is available + if CRYPTOGRAPHY_AVAILABLE: + assert manager._session_config['encrypt_sessions'] is True + else: + assert manager._session_config['encrypt_sessions'] is False + + # Verify existing attributes still exist + assert hasattr(manager, 'jwt_token') + assert hasattr(manager, 'refresh_token') + assert hasattr(manager, 'refresh_token_expires') + assert hasattr(manager, '_ensure_token') # Old method still exists From b881c512a9d84bb013d16029909dae1355fb874d Mon Sep 17 00:00:00 2001 From: Jens Horn Date: Sat, 30 Aug 2025 08:58:03 +0200 Subject: [PATCH 3/4] Reduce logging. --- pycheckwatt/__init__.py | 102 ---------------------------------------- 1 file changed, 102 deletions(-) diff --git a/pycheckwatt/__init__.py b/pycheckwatt/__init__.py index 0474ff8..215afa6 100644 --- a/pycheckwatt/__init__.py +++ b/pycheckwatt/__init__.py @@ -174,7 +174,6 @@ async def __aenter__(self): if self._session_config['persist_sessions'] and self._session_config['session_file']: try: await self._load_session() - _LOGGER.debug("Session loaded on context enter") except Exception as e: _LOGGER.debug("Failed to load session on context enter: %s", e) @@ -188,7 +187,6 @@ async def ensure_session(self): """Ensure session is initialized. 
Call this if not using async context manager.""" if self.session is None: self.session = ClientSession() - _LOGGER.debug("Session initialized manually") return self.session def _get_headers(self): @@ -244,7 +242,6 @@ def _jwt_is_valid(self) -> bool: def _is_jwt_valid(self) -> bool: """Check if JWT is valid with buffer (enhanced version).""" if not self._auth_state['jwt_token'] or not self._auth_state['jwt_expires_at']: - _LOGGER.debug("JWT validation failed: missing token or expiry time") return False buffer = timedelta(seconds=self.clock_skew_seconds) @@ -252,20 +249,11 @@ def _is_jwt_valid(self) -> bool: expires_at = self._ensure_utc_datetime(self._auth_state['jwt_expires_at']) is_valid = now < (expires_at - buffer) - if is_valid: - time_until_expiry = expires_at - now - _LOGGER.debug("JWT is valid, expires in %s (buffer: %ds)", - time_until_expiry, self.clock_skew_seconds) - else: - _LOGGER.debug("JWT is expired or will expire within buffer (%ds)", - self.clock_skew_seconds) - return is_valid def _is_refresh_valid(self) -> bool: """Check if refresh token is valid with buffer.""" if not self._auth_state['refresh_token'] or not self._auth_state['refresh_expires_at']: - _LOGGER.debug("Refresh token validation failed: missing token or expiry time") return False buffer = timedelta(seconds=300) # 5-minute buffer @@ -273,13 +261,6 @@ def _is_refresh_valid(self) -> bool: expires_at = self._ensure_utc_datetime(self._auth_state['refresh_expires_at']) is_valid = now < (expires_at - buffer) - if is_valid: - time_until_expiry = expires_at - now - _LOGGER.debug("Refresh token is valid, expires in %s (buffer: 300s)", - time_until_expiry) - else: - _LOGGER.debug("Refresh token is expired or will expire within buffer (300s)") - return is_valid def _refresh_is_valid(self) -> bool: @@ -366,24 +347,14 @@ async def ensure_authenticated(self) -> bool: try: # Quick check for valid JWT if self._is_jwt_valid(): - _LOGGER.debug("JWT is valid, authentication successful") return True - _LOGGER.debug("JWT is invalid or expired, checking refresh token") - # Try refresh if available if self._is_refresh_valid(): - _LOGGER.debug("Refresh token is valid, attempting token refresh") if await self._refresh_tokens(): - _LOGGER.debug("Token refresh successful") return True - else: - _LOGGER.debug("Token refresh failed, falling back to login") - else: - _LOGGER.debug("Refresh token is invalid or expired") # Fall back to password login - _LOGGER.debug("Attempting password-based login") return await self._perform_login() except Exception as e: @@ -457,14 +428,9 @@ async def _refresh_tokens(self) -> bool: async def _save_session(self) -> bool: """Save current session to file with encryption.""" if not self._session_config['session_file'] or not aiofiles: - _LOGGER.debug("Session saving skipped: no file path or aiofiles unavailable") return False try: - _LOGGER.debug("Saving session to %s (encrypted: %s)", - self._session_config['session_file'], - self._session_config['encrypt_sessions']) - session_data = { 'version': '1.0', 'username': self.username, @@ -484,7 +450,6 @@ async def _save_session(self) -> bool: async with aiofiles.open(self._session_config['session_file'], 'w') as f: await f.write(encrypted_data) - _LOGGER.debug("Session saved to %s", self._session_config['session_file']) return True except Exception as e: @@ -494,18 +459,12 @@ async def _save_session(self) -> bool: async def _load_session(self) -> bool: """Load session from file and validate.""" if not self._session_config['session_file'] or not aiofiles: - 
_LOGGER.debug("Session loading skipped: no file path or aiofiles unavailable") return False try: if not os.path.exists(self._session_config['session_file']): - _LOGGER.debug("Session file does not exist: %s", self._session_config['session_file']) return False - _LOGGER.debug("Loading session from %s (encrypted: %s)", - self._session_config['session_file'], - self._session_config['encrypt_sessions']) - # Read session file async with aiofiles.open(self._session_config['session_file'], 'r') as f: encrypted_data = await f.read() @@ -537,7 +496,6 @@ async def _load_session(self) -> bool: self.refresh_token = self._auth_state['refresh_token'] self.refresh_token_expires = self._auth_state['refresh_expires_at'] - _LOGGER.debug("Session restored from %s", self._session_config['session_file']) return True except Exception as e: @@ -558,7 +516,6 @@ async def _clear_session(self) -> None: if self._session_config['session_file'] and os.path.exists(self._session_config['session_file']): try: os.remove(self._session_config['session_file']) - _LOGGER.debug("Session file removed") except Exception as e: _LOGGER.error("Failed to remove session file: %s", e) @@ -737,9 +694,6 @@ async def _request( # Perform request with retry logic for attempt in range(self.max_retries_429 + 1): try: - _LOGGER.debug("Making %s request to %s (attempt %d)", - method, endpoint, attempt + 1) - async with self.session.request( method, self.base_url + endpoint, @@ -1428,62 +1382,6 @@ def _extract_fcr_d_state(self): self.fcrd_info = None break # stop so we get the first row in logbook - - - - - - - async def fetch_and_return_net_revenue(self, from_date, to_date): - """Fetch FCR-D revenues from CheckWatt as per provided range.""" - try: - site_id = await self.get_site_id() - # Validate date format and ensure they are dates - date_format = "%Y-%m-%d" - try: - from_date = datetime.strptime(from_date, date_format).date() - to_date = datetime.strptime(to_date, date_format).date() - except ValueError: - raise ValueError( - "Input dates must be valid dates with the format YYYY-MM-DD." - ) - - # Validate from_date and to_date - today = date.today() - six_months_ago = today - relativedelta(months=6) - - if not (six_months_ago <= from_date <= today): - raise ValueError( - "From date must be within the last 6 months and not beyond today." - ) - - if not (six_months_ago <= to_date <= today): - raise ValueError( - "To date must be within the last 6 months and not beyond today." 
- ) - - if from_date >= to_date: - raise ValueError("From date must be before To date.") - - # Extend to_date by one day - to_date += timedelta(days=1) - - endpoint = ( - f"/revenue/{site_id}?from={from_date}&to={to_date}&resolution=day" - ) - - result = await self._request("GET", endpoint, auth_required=True) - if result is False: - return None - - return result - - except Exception as error: - _LOGGER.error("Error in fetchand_return_net_revenue: %s", error) - return None - - - def _build_series_endpoint(self, grouping): end_date = datetime.now() + timedelta(days=2) to_date = end_date.strftime("%Y") From 67298e7d48a70d49368cda445b7b6cf3796e9b5b Mon Sep 17 00:00:00 2001 From: Lint Action Date: Mon, 15 Sep 2025 20:55:01 +0000 Subject: [PATCH 4/4] Fix code style issues with Black --- pycheckwatt/__init__.py | 807 +++++++++++++++----------- tests/unit/test_auth_and_requests.py | 418 +++++++------ tests/unit/test_checkwatt_manager.py | 96 +-- tests/unit/test_session_management.py | 244 ++++---- 4 files changed, 855 insertions(+), 710 deletions(-) diff --git a/pycheckwatt/__init__.py b/pycheckwatt/__init__.py index 215afa6..82559d9 100644 --- a/pycheckwatt/__init__.py +++ b/pycheckwatt/__init__.py @@ -44,6 +44,7 @@ # Import cryptography for session encryption try: from cryptography.fernet import Fernet + CRYPTOGRAPHY_AVAILABLE = True except ImportError: CRYPTOGRAPHY_AVAILABLE = False @@ -55,9 +56,9 @@ class CheckwattManager: """CheckWatt manager.""" def __init__( - self, - username, - password, + self, + username, + password, application="pyCheckwatt", *, max_retries_429: int = 3, @@ -67,31 +68,31 @@ def __init__( clock_skew_seconds: int = 10, max_concurrent_requests: int = 5, persist_sessions: bool = True, - session_file: str = None + session_file: str = None, ) -> None: """Initialize the CheckWatt manager.""" if username is None or password is None: raise ValueError("Username and password must be provided.") - + # Core session and configuration self.session = None self.base_url = "https://api.checkwatt.se" self.username = username self.password = password self.header_identifier = application - + # Authentication state self.jwt_token = None self.refresh_token = None self.refresh_token_expires = None - + self._auth_state = { - 'jwt_token': None, - 'refresh_token': None, - 'jwt_expires_at': None, - 'refresh_expires_at': None, - 'last_auth_time': None, - 'auth_method': None + "jwt_token": None, + "refresh_token": None, + "jwt_expires_at": None, + "refresh_expires_at": None, + "last_auth_time": None, + "auth_method": None, } # Set default session file if none provided and persistence is enabled @@ -99,30 +100,37 @@ def __init__( # Create a default path in user's home directory or temp directory import tempfile import os + try: # Try to use user's home directory first home_dir = os.path.expanduser("~") if os.access(home_dir, os.W_OK): - session_file = os.path.join(home_dir, ".pycheckwatt_sessions", f"{username}_session.json") + session_file = os.path.join( + home_dir, ".pycheckwatt_sessions", f"{username}_session.json" + ) else: # Fall back to temp directory temp_dir = tempfile.gettempdir() - session_file = os.path.join(temp_dir, f"pycheckwatt_{username}_session.json") + session_file = os.path.join( + temp_dir, f"pycheckwatt_{username}_session.json" + ) except Exception: # If all else fails, use temp directory temp_dir = tempfile.gettempdir() - session_file = os.path.join(temp_dir, f"pycheckwatt_{username}_session.json") - + session_file = os.path.join( + temp_dir, 
f"pycheckwatt_{username}_session.json" + ) + self._session_config = { - 'persist_sessions': persist_sessions and aiofiles is not None, - 'session_file': session_file, - 'encrypt_sessions': CRYPTOGRAPHY_AVAILABLE + "persist_sessions": persist_sessions and aiofiles is not None, + "session_file": session_file, + "encrypt_sessions": CRYPTOGRAPHY_AVAILABLE, } - + # Concurrency control self._auth_lock = asyncio.Lock() self._req_semaphore = asyncio.Semaphore(max_concurrent_requests) - + # Configuration knobs self.max_retries_429 = max_retries_429 self.backoff_base = backoff_base @@ -130,7 +138,7 @@ def __init__( self.backoff_max = backoff_max self.clock_skew_seconds = clock_skew_seconds self.max_concurrent_requests = max_concurrent_requests - + # Data properties (existing) self.dailyaverage = 0 self.monthestimate = 0 @@ -169,14 +177,17 @@ def __init__( async def __aenter__(self): """Asynchronous enter.""" self.session = ClientSession() - + # Try to load existing session if session persistence is enabled - if self._session_config['persist_sessions'] and self._session_config['session_file']: + if ( + self._session_config["persist_sessions"] + and self._session_config["session_file"] + ): try: await self._load_session() except Exception as e: _LOGGER.debug("Failed to load session on context enter: %s", e) - + return self async def __aexit__(self, exc_type, exc_value, traceback): @@ -191,7 +202,7 @@ async def ensure_session(self): def _get_headers(self): """Define common headers.""" - + # Ensure header_identifier is not None if self.header_identifier is None: self.header_identifier = "pyCheckwatt" @@ -216,69 +227,74 @@ def _jwt_is_valid(self) -> bool: """Check if JWT token is valid and not expiring soon.""" if not self.jwt_token: return False - + try: # Simple JWT expiration check - decode the payload part - parts = self.jwt_token.split('.') + parts = self.jwt_token.split(".") if len(parts) != 3: return False - + # Decode the payload (second part) - payload = base64.urlsafe_b64decode(parts[1] + '==').decode('utf-8') + payload = base64.urlsafe_b64decode(parts[1] + "==").decode("utf-8") claims = json.loads(payload) - - exp = claims.get('exp') + + exp = claims.get("exp") if not exp: return False - + # Check if token expires within clock skew now = datetime.now(timezone.utc).timestamp() return now < (exp - self.clock_skew_seconds) - + except (ValueError, json.JSONDecodeError, UnicodeDecodeError, TypeError): # If we can't decode, treat as unknown validity return False def _is_jwt_valid(self) -> bool: """Check if JWT is valid with buffer (enhanced version).""" - if not self._auth_state['jwt_token'] or not self._auth_state['jwt_expires_at']: + if not self._auth_state["jwt_token"] or not self._auth_state["jwt_expires_at"]: return False - + buffer = timedelta(seconds=self.clock_skew_seconds) now = datetime.now(timezone.utc) - expires_at = self._ensure_utc_datetime(self._auth_state['jwt_expires_at']) + expires_at = self._ensure_utc_datetime(self._auth_state["jwt_expires_at"]) is_valid = now < (expires_at - buffer) - + return is_valid - + def _is_refresh_valid(self) -> bool: """Check if refresh token is valid with buffer.""" - if not self._auth_state['refresh_token'] or not self._auth_state['refresh_expires_at']: + if ( + not self._auth_state["refresh_token"] + or not self._auth_state["refresh_expires_at"] + ): return False - + buffer = timedelta(seconds=300) # 5-minute buffer now = datetime.now(timezone.utc) - expires_at = self._ensure_utc_datetime(self._auth_state['refresh_expires_at']) + expires_at = 
self._ensure_utc_datetime(self._auth_state["refresh_expires_at"]) is_valid = now < (expires_at - buffer) - + return is_valid def _refresh_is_valid(self) -> bool: """Check if refresh token is valid and not expired.""" if not self.refresh_token or not self.refresh_token_expires: return False - + try: # Parse the expiration timestamp - expires = datetime.fromisoformat(self.refresh_token_expires.replace('Z', '+00:00')) + expires = datetime.fromisoformat( + self.refresh_token_expires.replace("Z", "+00:00") + ) # Ensure it's UTC if no timezone info if expires.tzinfo is None: expires = expires.replace(tzinfo=timezone.utc) now = datetime.now(timezone.utc) - + # Add some buffer (5 minutes) to avoid edge cases return now < (expires - timedelta(minutes=5)) - + except (ValueError, TypeError): # If we can't parse, treat as unknown validity return False @@ -287,57 +303,59 @@ async def _refresh(self) -> bool: """Refresh the JWT token using the refresh token.""" if not self.refresh_token: return False - + try: endpoint = "/user/RefreshToken?audience=eib" headers = { **self._get_headers(), "authorization": f"RefreshToken {self.refresh_token}", } - + async with self.session.get( - self.base_url + endpoint, - headers=headers, - timeout=10 + self.base_url + endpoint, headers=headers, timeout=10 ) as response: if response.status == 200: data = await response.json() - + # Update tokens self.jwt_token = data.get("JwtToken") if "RefreshToken" in data: self.refresh_token = data.get("RefreshToken") if "RefreshTokenExpires" in data: self.refresh_token_expires = data.get("RefreshTokenExpires") - + # Update internal auth state - ensure all timestamps are UTC-aware jwt_expiry = self._ensure_utc_datetime(self.jwt_expires_at) refresh_expiry = self._ensure_utc_datetime(self.refresh_expires_at) - - self._auth_state.update({ - 'jwt_token': self.jwt_token, - 'refresh_token': self.refresh_token, - 'jwt_expires_at': jwt_expiry, - 'refresh_expires_at': refresh_expiry, - 'last_auth_time': datetime.now(timezone.utc), - 'auth_method': 'refresh' - }) - + + self._auth_state.update( + { + "jwt_token": self.jwt_token, + "refresh_token": self.refresh_token, + "jwt_expires_at": jwt_expiry, + "refresh_expires_at": refresh_expiry, + "last_auth_time": datetime.now(timezone.utc), + "auth_method": "refresh", + } + ) + # Persist session if enabled - if self._session_config['persist_sessions']: + if self._session_config["persist_sessions"]: await self._save_session() - + _LOGGER.info("Successfully refreshed JWT token") return True - + elif response.status == 401: _LOGGER.warning("Refresh token expired or invalid") return False - + else: - _LOGGER.error("Unexpected status code during refresh: %d", response.status) + _LOGGER.error( + "Unexpected status code during refresh: %d", response.status + ) return False - + except (ClientResponseError, ClientError) as error: _LOGGER.error("Error during token refresh: %s", error) return False @@ -348,15 +366,15 @@ async def ensure_authenticated(self) -> bool: # Quick check for valid JWT if self._is_jwt_valid(): return True - + # Try refresh if available if self._is_refresh_valid(): if await self._refresh_tokens(): return True - + # Fall back to password login return await self._perform_login() - + except Exception as e: _LOGGER.error("Authentication failed: %s", e) return False @@ -366,29 +384,31 @@ async def _perform_login(self) -> bool: try: # Use existing login logic result = await self.login() - + if result: # Update internal auth state - ensure all timestamps are UTC-aware jwt_expiry = 
self._ensure_utc_datetime(self.jwt_expires_at) refresh_expiry = self._ensure_utc_datetime(self.refresh_expires_at) - - self._auth_state.update({ - 'jwt_token': self.jwt_token, - 'refresh_token': self.refresh_token, - 'jwt_expires_at': jwt_expiry, - 'refresh_expires_at': refresh_expiry, - 'last_auth_time': datetime.now(timezone.utc), - 'auth_method': 'password' - }) - + + self._auth_state.update( + { + "jwt_token": self.jwt_token, + "refresh_token": self.refresh_token, + "jwt_expires_at": jwt_expiry, + "refresh_expires_at": refresh_expiry, + "last_auth_time": datetime.now(timezone.utc), + "auth_method": "password", + } + ) + # Persist session if enabled - if self._session_config['persist_sessions']: + if self._session_config["persist_sessions"]: await self._save_session() - + return True - + return False - + except Exception as e: _LOGGER.error("Login failed: %s", e) return False @@ -398,140 +418,154 @@ async def _refresh_tokens(self) -> bool: try: # Use existing refresh logic result = await self._refresh() - + if result: # Update internal auth state - ensure all timestamps are UTC-aware jwt_expiry = self._ensure_utc_datetime(self.jwt_expires_at) refresh_expiry = self._ensure_utc_datetime(self.refresh_expires_at) - - self._auth_state.update({ - 'jwt_token': self.jwt_token, - 'refresh_token': self.refresh_token, - 'jwt_expires_at': jwt_expiry, - 'refresh_expires_at': refresh_expiry, - 'last_auth_time': datetime.now(timezone.utc), - 'auth_method': 'refresh' - }) - + + self._auth_state.update( + { + "jwt_token": self.jwt_token, + "refresh_token": self.refresh_token, + "jwt_expires_at": jwt_expiry, + "refresh_expires_at": refresh_expiry, + "last_auth_time": datetime.now(timezone.utc), + "auth_method": "refresh", + } + ) + # Persist session if enabled - if self._session_config['persist_sessions']: + if self._session_config["persist_sessions"]: await self._save_session() - + return True - + return False - + except Exception as e: _LOGGER.error("Token refresh failed: %s", e) return False async def _save_session(self) -> bool: """Save current session to file with encryption.""" - if not self._session_config['session_file'] or not aiofiles: + if not self._session_config["session_file"] or not aiofiles: return False - + try: session_data = { - 'version': '1.0', - 'username': self.username, - 'auth_state': self._auth_state, - 'timestamp': datetime.now(timezone.utc).isoformat() + "version": "1.0", + "username": self.username, + "auth_state": self._auth_state, + "timestamp": datetime.now(timezone.utc).isoformat(), } - - if self._session_config['encrypt_sessions']: + + if self._session_config["encrypt_sessions"]: encrypted_data = self._encrypt_session_data(session_data) else: encrypted_data = json.dumps(session_data, default=str) - + # Ensure directory exists - os.makedirs(os.path.dirname(self._session_config['session_file']), exist_ok=True) - + os.makedirs( + os.path.dirname(self._session_config["session_file"]), exist_ok=True + ) + # Write session file - async with aiofiles.open(self._session_config['session_file'], 'w') as f: + async with aiofiles.open(self._session_config["session_file"], "w") as f: await f.write(encrypted_data) - + return True - + except Exception as e: _LOGGER.error("Failed to save session: %s", e) return False - + async def _load_session(self) -> bool: """Load session from file and validate.""" - if not self._session_config['session_file'] or not aiofiles: + if not self._session_config["session_file"] or not aiofiles: return False - + try: - if not 
os.path.exists(self._session_config['session_file']): + if not os.path.exists(self._session_config["session_file"]): return False - + # Read session file - async with aiofiles.open(self._session_config['session_file'], 'r') as f: + async with aiofiles.open(self._session_config["session_file"], "r") as f: encrypted_data = await f.read() - + # Decrypt if needed - if self._session_config['encrypt_sessions']: + if self._session_config["encrypt_sessions"]: session_data = self._decrypt_session_data(encrypted_data) else: session_data = json.loads(encrypted_data) - + # Validate session data if not self._validate_session_data(session_data): _LOGGER.warning("Invalid session data, clearing") await self._clear_session() return False - + # Restore auth state and ensure all timestamps are UTC-aware - self._auth_state = session_data['auth_state'] - + self._auth_state = session_data["auth_state"] + # Ensure stored timestamps are UTC-aware - if self._auth_state['jwt_expires_at']: - self._auth_state['jwt_expires_at'] = self._ensure_utc_datetime(self._auth_state['jwt_expires_at']) - if self._auth_state['refresh_expires_at']: - self._auth_state['refresh_expires_at'] = self._ensure_utc_datetime(self._auth_state['refresh_expires_at']) - if self._auth_state['last_auth_time']: - self._auth_state['last_auth_time'] = self._ensure_utc_datetime(self._auth_state['last_auth_time']) - - self.jwt_token = self._auth_state['jwt_token'] - self.refresh_token = self._auth_state['refresh_token'] - self.refresh_token_expires = self._auth_state['refresh_expires_at'] - + if self._auth_state["jwt_expires_at"]: + self._auth_state["jwt_expires_at"] = self._ensure_utc_datetime( + self._auth_state["jwt_expires_at"] + ) + if self._auth_state["refresh_expires_at"]: + self._auth_state["refresh_expires_at"] = self._ensure_utc_datetime( + self._auth_state["refresh_expires_at"] + ) + if self._auth_state["last_auth_time"]: + self._auth_state["last_auth_time"] = self._ensure_utc_datetime( + self._auth_state["last_auth_time"] + ) + + self.jwt_token = self._auth_state["jwt_token"] + self.refresh_token = self._auth_state["refresh_token"] + self.refresh_token_expires = self._auth_state["refresh_expires_at"] + return True - + except Exception as e: _LOGGER.error("Failed to load session: %s", e) return False - + async def _clear_session(self) -> None: """Clear current session and remove session file.""" self._auth_state = { - 'jwt_token': None, - 'refresh_token': None, - 'jwt_expires_at': None, - 'refresh_expires_at': None, - 'last_auth_time': None, - 'auth_method': None + "jwt_token": None, + "refresh_token": None, + "jwt_expires_at": None, + "refresh_expires_at": None, + "last_auth_time": None, + "auth_method": None, } - - if self._session_config['session_file'] and os.path.exists(self._session_config['session_file']): + + if self._session_config["session_file"] and os.path.exists( + self._session_config["session_file"] + ): try: - os.remove(self._session_config['session_file']) + os.remove(self._session_config["session_file"]) except Exception as e: _LOGGER.error("Failed to remove session file: %s", e) - def _ensure_utc_datetime(self, dt: Optional[Union[datetime, str]]) -> Optional[datetime]: + def _ensure_utc_datetime( + self, dt: Optional[Union[datetime, str]] + ) -> Optional[datetime]: """Ensure datetime is UTC-aware, converting if necessary.""" if dt is None: return None - + # Handle string timestamps (from session files) if isinstance(dt, str): try: - dt = datetime.fromisoformat(dt.replace('Z', '+00:00')) + dt = 
datetime.fromisoformat(dt.replace("Z", "+00:00")) except (ValueError, TypeError): _LOGGER.warning("Failed to parse timestamp string: %s", dt) return None - + if dt.tzinfo is None: # If naive, assume UTC and make it aware return dt.replace(tzinfo=timezone.utc) @@ -545,85 +579,85 @@ def _ensure_utc_datetime(self, dt: Optional[Union[datetime, str]]) -> Optional[d def _get_encryption_key(self) -> bytes: """Generate encryption key from username and password.""" # Create a deterministic key from credentials - key_material = f"{self.username}:{self.password}".encode('utf-8') + key_material = f"{self.username}:{self.password}".encode("utf-8") key_hash = hashlib.sha256(key_material).digest() return base64.urlsafe_b64encode(key_hash) - + def _encrypt_session_data(self, data: dict) -> str: """Encrypt session data.""" if not CRYPTOGRAPHY_AVAILABLE: raise RuntimeError("Cryptography library not available for encryption") - + key = self._get_encryption_key() f = Fernet(key) json_data = json.dumps(data, default=str) - encrypted = f.encrypt(json_data.encode('utf-8')) - return encrypted.decode('utf-8') - + encrypted = f.encrypt(json_data.encode("utf-8")) + return encrypted.decode("utf-8") + def _decrypt_session_data(self, encrypted_data: str) -> dict: """Decrypt session data.""" if not CRYPTOGRAPHY_AVAILABLE: raise RuntimeError("Cryptography library not available for decryption") - + key = self._get_encryption_key() f = Fernet(key) - decrypted = f.decrypt(encrypted_data.encode('utf-8')) - return json.loads(decrypted.decode('utf-8')) - + decrypted = f.decrypt(encrypted_data.encode("utf-8")) + return json.loads(decrypted.decode("utf-8")) + def _validate_session_data(self, data: dict) -> bool: """Validate loaded session data.""" - required_fields = ['version', 'username', 'auth_state', 'timestamp'] + required_fields = ["version", "username", "auth_state", "timestamp"] if not all(field in data for field in required_fields): return False - - if data['username'] != self.username: + + if data["username"] != self.username: return False - - if data['version'] != '1.0': + + if data["version"] != "1.0": return False - + return True async def load_session(self, filepath: str = None) -> bool: """Load session from file.""" if filepath: - self._session_config['session_file'] = filepath - + self._session_config["session_file"] = filepath + return await self._load_session() - + async def save_session(self, filepath: str = None) -> bool: """Save current session to file.""" if filepath: - self._session_config['session_file'] = filepath - + self._session_config["session_file"] = filepath + return await self._save_session() - + async def clear_session(self) -> None: """Clear current session.""" await self._clear_session() - + def get_session_info(self) -> dict: """Get current session information.""" return { - 'authenticated': self._is_jwt_valid(), - 'jwt_expires_in': self._get_jwt_expiry_delta(), - 'refresh_expires_in': self._get_refresh_expiry_delta(), - 'last_auth_method': self._auth_state['auth_method'], - 'last_auth_time': self._auth_state['last_auth_time'] + "authenticated": self._is_jwt_valid(), + "jwt_expires_in": self._get_jwt_expiry_delta(), + "refresh_expires_in": self._get_refresh_expiry_delta(), + "last_auth_method": self._auth_state["auth_method"], + "last_auth_time": self._auth_state["last_auth_time"], } - + def _get_jwt_expiry_delta(self) -> Optional[timedelta]: """Get time until JWT expires.""" - if not self._auth_state['jwt_expires_at']: + if not self._auth_state["jwt_expires_at"]: return None - expires_at = 
self._ensure_utc_datetime(self._auth_state['jwt_expires_at']) + expires_at = self._ensure_utc_datetime(self._auth_state["jwt_expires_at"]) return expires_at - datetime.now(timezone.utc) - + def _get_refresh_expiry_delta(self) -> Optional[timedelta]: """Get time until refresh token expires.""" - if not self._auth_state['refresh_expires_at']: + if not self._auth_state["refresh_expires_at"]: return None - expires_at = self._ensure_utc_datetime(self._auth_state['refresh_expires_at']) + expires_at = self._ensure_utc_datetime(self._auth_state["refresh_expires_at"]) return expires_at - datetime.now(timezone.utc) async def _ensure_token(self) -> bool: @@ -631,37 +665,37 @@ async def _ensure_token(self) -> bool: # Quick check without lock if self.jwt_token and self._jwt_is_valid(): return True - + # Need to acquire lock for auth operations async with self._auth_lock: # Double-check after acquiring lock if self.jwt_token and self._jwt_is_valid(): return True - + # Try refresh first if self.refresh_token and self._refresh_is_valid(): if await self._refresh(): return True - + # Fall back to login _LOGGER.info("Performing password login") return await self.login() async def _request( - self, - method: str, - endpoint: str, - *, + self, + method: str, + endpoint: str, + *, headers: Optional[Dict[str, str]] = None, auth_required: bool = True, retry_on_401: bool = True, retry_on_429: bool = True, timeout: int = 10, - **kwargs + **kwargs, ) -> Union[Dict[str, Any], str, bool, None]: """ Centralized request wrapper with authentication and retry logic. - + Args: method: HTTP method (GET, POST, etc.) endpoint: API endpoint path @@ -671,7 +705,7 @@ async def _request( retry_on_429: Whether to retry on 429 (with backoff) timeout: Request timeout in seconds **kwargs: Additional arguments for the request - + Returns: Response data (dict for JSON, str for text) or boolean for success/failure """ @@ -679,16 +713,19 @@ async def _request( if auth_required: if not await self._ensure_token(): return False - + # Prepare headers final_headers = {**self._get_headers(), **(headers or {})} if auth_required and self.jwt_token: final_headers["authorization"] = f"Bearer {self.jwt_token}" - + # Remove sensitive headers from logging - safe_headers = {k: v for k, v in final_headers.items() - if k.lower() not in ['authorization', 'cookie']} - + safe_headers = { + k: v + for k, v in final_headers.items() + if k.lower() not in ["authorization", "cookie"] + } + # Apply concurrency control async with self._req_semaphore: # Perform request with retry logic @@ -699,31 +736,37 @@ async def _request( self.base_url + endpoint, headers=final_headers, timeout=timeout, - **kwargs + **kwargs, ) as response: # Handle 401 (Unauthorized) if response.status == 401 and retry_on_401 and auth_required: _LOGGER.warning("Received 401, attempting token refresh") - + # Try refresh first if await self._refresh(): # Retry the original request once continue - + # If refresh failed, try login _LOGGER.warning("Refresh failed, attempting login") if await self.login(): # Retry the original request once continue - + # Both refresh and login failed - _LOGGER.error("Authentication failed after refresh and login attempts") + _LOGGER.error( + "Authentication failed after refresh and login attempts" + ) return False - + # Handle 429 (Too Many Requests) - if response.status == 429 and retry_on_429 and attempt < self.max_retries_429: - retry_after = response.headers.get('Retry-After') - + if ( + response.status == 429 + and retry_on_429 + and attempt < 
self.max_retries_429 + ): + retry_after = response.headers.get("Retry-After") + if retry_after: try: # Try to parse as seconds @@ -733,50 +776,61 @@ async def _request( # Try to parse as HTTP date retry_date = parsedate_to_datetime(retry_after) if retry_date.tzinfo is None: - retry_date = retry_date.replace(tzinfo=timezone.utc) - wait_time = (retry_date - datetime.now(timezone.utc)).total_seconds() + retry_date = retry_date.replace( + tzinfo=timezone.utc + ) + wait_time = ( + retry_date - datetime.now(timezone.utc) + ).total_seconds() wait_time = max(0, wait_time) except (ValueError, TypeError): wait_time = self.backoff_base else: # Use exponential backoff with jitter wait_time = min( - self.backoff_base * (self.backoff_factor ** attempt), - self.backoff_max + self.backoff_base * (self.backoff_factor**attempt), + self.backoff_max, ) # Add jitter (0 to 0.25s) wait_time += random.uniform(0, 0.25) - - _LOGGER.info("Rate limited (429), waiting %.2f seconds before retry", wait_time) + + _LOGGER.info( + "Rate limited (429), waiting %.2f seconds before retry", + wait_time, + ) await asyncio.sleep(wait_time) continue - + # Handle other status codes response.raise_for_status() - + # Parse response based on content type - content_type = response.headers.get('Content-Type', '').lower() - - if 'application/json' in content_type: + content_type = response.headers.get("Content-Type", "").lower() + + if "application/json" in content_type: return await response.json() else: return await response.text() - + except ClientResponseError as e: if e.status == 401 and retry_on_401 and auth_required: # This will be handled in the next iteration continue - elif e.status == 429 and retry_on_429 and attempt < self.max_retries_429: + elif ( + e.status == 429 + and retry_on_429 + and attempt < self.max_retries_429 + ): # This will be handled in the next iteration continue else: _LOGGER.error("Request failed with status %d: %s", e.status, e) return await self.handle_client_error(endpoint, safe_headers, e) - + except (ClientError, asyncio.TimeoutError) as error: _LOGGER.error("Request failed: %s", error) return await self.handle_client_error(endpoint, safe_headers, error) - + # If we get here, we've exhausted all retries _LOGGER.error("Request failed after %d attempts", self.max_retries_429 + 1) return False @@ -784,35 +838,46 @@ async def _request( async def _continue_kill_switch_not_enabled(self): """Check if CheckWatt has requested integrations to back-off.""" url = "https://checkwatt.se/ha-killswitch.txt" - + # Ensure session is initialized if self.session is None: - _LOGGER.error("Session not initialized. Use async context manager or call ensure_session() first.") + _LOGGER.error( + "Session not initialized. Use async context manager or call ensure_session() first." 
+ ) return False - + try: headers = self._get_headers() - + # Ensure headers is a valid dictionary if not isinstance(headers, dict): - _LOGGER.error("_get_headers() returned invalid type: %s, defaulting to empty dict", type(headers)) + _LOGGER.error( + "_get_headers() returned invalid type: %s, defaulting to empty dict", + type(headers), + ) headers = {} - + async with self.session.get(url, headers=headers, timeout=10) as response: data = await response.text() if response.status == 200: kill = data.strip() # Remove leading and trailing whitespaces enabled = kill == "0" - + if enabled: - _LOGGER.debug("CheckWatt accepted and not enabled the kill-switch") + _LOGGER.debug( + "CheckWatt accepted and not enabled the kill-switch" + ) else: - _LOGGER.error("CheckWatt has requested to back down by enabling the kill-switch") - + _LOGGER.error( + "CheckWatt has requested to back down by enabling the kill-switch" + ) + return enabled if response.status == 401: - _LOGGER.error("Unauthorized: Check your CheckWatt authentication credentials") + _LOGGER.error( + "Unauthorized: Check your CheckWatt authentication credentials" + ) return False _LOGGER.error("Unexpected HTTP status code: %s", response.status) @@ -821,11 +886,14 @@ async def _continue_kill_switch_not_enabled(self): except Exception as error: # Create safe headers for logging, handling case where headers might not be defined try: - safe_headers = {k: v for k, v in headers.items() - if k.lower() not in ['authorization', 'cookie']} + safe_headers = { + k: v + for k, v in headers.items() + if k.lower() not in ["authorization", "cookie"] + } except (AttributeError, NameError): safe_headers = {} - + _LOGGER.error( "Killswitch check failed. URL: %s, Headers: %s. Error: %s", url, @@ -837,9 +905,12 @@ async def _continue_kill_switch_not_enabled(self): async def handle_client_error(self, endpoint, headers, error): """Handle ClientError and log relevant information.""" # Remove sensitive headers from logging - safe_headers = {k: v for k, v in headers.items() - if k.lower() not in ['authorization', 'cookie']} - + safe_headers = { + k: v + for k, v in headers.items() + if k.lower() not in ["authorization", "cookie"] + } + _LOGGER.error( "An error occurred during the request. URL: %s, Headers: %s. 
Error: %s", self.base_url + endpoint, @@ -882,24 +953,26 @@ async def login(self): self.jwt_token = data.get("JwtToken") self.refresh_token = data.get("RefreshToken") self.refresh_token_expires = data.get("RefreshTokenExpires") - + # Update internal auth state - ensure all timestamps are UTC-aware jwt_expiry = self._ensure_utc_datetime(self.jwt_expires_at) refresh_expiry = self._ensure_utc_datetime(self.refresh_expires_at) - - self._auth_state.update({ - 'jwt_token': self.jwt_token, - 'refresh_token': self.refresh_token, - 'jwt_expires_at': jwt_expiry, - 'refresh_expires_at': refresh_expiry, - 'last_auth_time': datetime.now(timezone.utc), - 'auth_method': 'password' - }) - + + self._auth_state.update( + { + "jwt_token": self.jwt_token, + "refresh_token": self.refresh_token, + "jwt_expires_at": jwt_expiry, + "refresh_expires_at": refresh_expiry, + "last_auth_time": datetime.now(timezone.utc), + "auth_method": "password", + } + ) + # Persist session if enabled - if self._session_config['persist_sessions']: + if self._session_config["persist_sessions"]: await self._save_session() - + _LOGGER.info("Successfully logged in to CheckWatt") return True @@ -921,13 +994,13 @@ async def get_customer_details(self): if not await self.ensure_authenticated(): _LOGGER.error("Failed to authenticate for customer details") return False - + endpoint = "/controlpanel/CustomerDetail" - + result = await self._request("GET", endpoint, auth_required=True) if result is False: return False - + self.customer_details = result meters = self.customer_details.get("Meter", []) @@ -978,12 +1051,8 @@ async def get_customer_details(self): None, ) if discharge_meter: - self.battery_discharge_peak_ac = discharge_meter.get( - "PeakAcKw" - ) - self.battery_discharge_peak_dc = discharge_meter.get( - "PeakDcKw" - ) + self.battery_discharge_peak_ac = discharge_meter.get("PeakAcKw") + self.battery_discharge_peak_dc = discharge_meter.get("PeakDcKw") return True @@ -1005,18 +1074,18 @@ async def get_site_id(self): if not await self.ensure_authenticated(): _LOGGER.error("Failed to authenticate for site ID") return False - + endpoint = f"/Site/SiteIdBySerial?serial={self.rpi_serial}" - + result = await self._request("GET", endpoint, auth_required=True) if result is False: return False - + if isinstance(result, dict) and "SiteId" in result: self.site_id = str(result["SiteId"]) _LOGGER.debug("Successfully extracted site ID: %s", self.site_id) return self.site_id - + _LOGGER.error("Unexpected response format for site ID: %s", result) return False @@ -1030,24 +1099,28 @@ async def debug_revenue_workflow(self): _LOGGER.info("Customer details loaded: %s", self.customer_details is not None) _LOGGER.info("RPI data loaded: %s", self.rpi_data is not None) _LOGGER.info("Site ID cached: %s", self.site_id) - + if self.customer_details: meters = self.customer_details.get("Meter", []) _LOGGER.info("Number of meters: %d", len(meters)) for i, meter in enumerate(meters): - _LOGGER.info("Meter %d: Type=%s, RpiSerial=%s", - i, meter.get("InstallationType"), meter.get("RpiSerial")) - + _LOGGER.info( + "Meter %d: Type=%s, RpiSerial=%s", + i, + meter.get("InstallationType"), + meter.get("RpiSerial"), + ) + rpi_serial = self.rpi_serial _LOGGER.info("RPI Serial: %s", rpi_serial) - + if rpi_serial: _LOGGER.info("Attempting to get site ID...") site_id = await self.get_site_id() _LOGGER.info("Site ID result: %s", site_id) else: _LOGGER.error("Cannot get site ID - RPI serial is None") - + _LOGGER.info("=== End Debug ===") async def get_fcrd_month_net_revenue(self): 
@@ -1057,18 +1130,18 @@ async def get_fcrd_month_net_revenue(self): if not await self.ensure_authenticated(): _LOGGER.error("Failed to authenticate for FCR-D month revenue") return False - + site_id = await self.get_site_id() if site_id is False: _LOGGER.error("Failed to get site ID for FCR-D month revenue") return False - + if not site_id: _LOGGER.error("Site ID is empty or None for FCR-D month revenue") return False - + _LOGGER.debug("Using site ID %s for FCR-D month revenue", site_id) - + from_date = datetime.now().strftime("%Y-%m-01") to_date = datetime.now() + timedelta(days=1) to_date = to_date.strftime("%Y-%m-%d") @@ -1092,9 +1165,11 @@ async def get_fcrd_month_net_revenue(self): result = await self._request("GET", endpoint, auth_required=True) if result is False: - _LOGGER.error("Failed to retrieve FCR-D month revenue from endpoint: %s", endpoint) + _LOGGER.error( + "Failed to retrieve FCR-D month revenue from endpoint: %s", endpoint + ) return False - + revenue = result # Reset monthly revenue before adding new values self.revenuemonth = 0 @@ -1103,14 +1178,12 @@ async def get_fcrd_month_net_revenue(self): if each["NetRevenue"] == 0: misseddays += 1 dayswithmoney = int(dayssofar) - int(misseddays) - + if dayswithmoney > 0: self.dailyaverage = self.revenuemonth / int(dayswithmoney) else: self.dailyaverage = 0 - self.monthestimate = ( - self.dailyaverage * daysleft - ) + self.revenuemonth + self.monthestimate = (self.dailyaverage * daysleft) + self.revenuemonth _LOGGER.info("Successfully retrieved FCR-D month revenue") return True @@ -1124,18 +1197,18 @@ async def get_fcrd_today_net_revenue(self): if not await self.ensure_authenticated(): _LOGGER.error("Failed to authenticate for FCR-D today revenue") return False - + site_id = await self.get_site_id() if site_id is False: _LOGGER.error("Failed to get site ID for FCR-D today revenue") return False - + if not site_id: _LOGGER.error("Site ID is empty or None for FCR-D today revenue") return False - + _LOGGER.debug("Using site ID %s for FCR-D today revenue", site_id) - + from_date = datetime.now().strftime("%Y-%m-%d") end_date = datetime.now() + timedelta(days=2) to_date = end_date.strftime("%Y-%m-%d") @@ -1147,9 +1220,11 @@ async def get_fcrd_today_net_revenue(self): result = await self._request("GET", endpoint, auth_required=True) if result is False: - _LOGGER.error("Failed to retrieve FCR-D today revenue from endpoint: %s", endpoint) + _LOGGER.error( + "Failed to retrieve FCR-D today revenue from endpoint: %s", endpoint + ) return False - + self.revenue = result _LOGGER.info("Successfully retrieved FCR-D today revenue") return True @@ -1164,18 +1239,18 @@ async def get_fcrd_year_net_revenue(self): if not await self.ensure_authenticated(): _LOGGER.error("Failed to authenticate for FCR-D year revenue") return False - + site_id = await self.get_site_id() if site_id is False: _LOGGER.error("Failed to get site ID for FCR-D year revenue") return False - + if not site_id: _LOGGER.error("Site ID is empty or None for FCR-D year revenue") return False - + _LOGGER.debug("Using site ID %s for FCR-D year revenue", site_id) - + yesterday_date = datetime.now() + timedelta(days=1) yesterday_date = yesterday_date.strftime("-%m-%d") months = ["-01-01", "-06-30", "-07-01", yesterday_date] @@ -1186,59 +1261,80 @@ async def get_fcrd_year_net_revenue(self): year_date = datetime.now().strftime("%Y") to_date = year_date + yesterday_date from_date = year_date + "-01-01" - endpoint = ( - f"/revenue/{site_id}?from={from_date}&to={to_date}&resolution=day" + 
endpoint = f"/revenue/{site_id}?from={from_date}&to={to_date}&resolution=day" + _LOGGER.debug( + "FCR-D year revenue endpoint (first half): %s", endpoint ) - _LOGGER.debug("FCR-D year revenue endpoint (first half): %s", endpoint) - + result = await self._request("GET", endpoint, auth_required=True) if result is False: - _LOGGER.error("Failed to retrieve FCR-D year revenue from endpoint: %s", endpoint) + _LOGGER.error( + "Failed to retrieve FCR-D year revenue from endpoint: %s", + endpoint, + ) return False - + self.revenueyear = result # Reset yearly revenue before adding new values self.revenueyeartotal = 0 for each in self.revenueyear["Revenue"]: self.revenueyeartotal += each["NetRevenue"] retval = True - _LOGGER.info("Successfully retrieved FCR-D year revenue (first half)") + _LOGGER.info( + "Successfully retrieved FCR-D year revenue (first half)" + ) return retval except Exception as error: - _LOGGER.error("Error in get_fcrd_year_net_revenue (first half): %s", error) + _LOGGER.error( + "Error in get_fcrd_year_net_revenue (first half): %s", error + ) return False else: try: # Reset yearly revenue once before processing all periods self.revenueyeartotal = 0 - + while loop < 3: year_date = datetime.now().strftime("%Y") to_date = year_date + months[loop + 1] from_date = year_date + months[loop] endpoint = f"/revenue/{site_id}?from={from_date}&to={to_date}&resolution=day" - _LOGGER.debug("FCR-D year revenue endpoint (period %d): %s", loop, endpoint) - - result = await self._request("GET", endpoint, auth_required=True) + _LOGGER.debug( + "FCR-D year revenue endpoint (period %d): %s", + loop, + endpoint, + ) + + result = await self._request( + "GET", endpoint, auth_required=True + ) if result is False: - _LOGGER.error("Failed to retrieve FCR-D year revenue from endpoint: %s", endpoint) + _LOGGER.error( + "Failed to retrieve FCR-D year revenue from endpoint: %s", + endpoint, + ) return False - + self.revenueyear = result # Add this period's revenue to the total (don't reset) for each in self.revenueyear["Revenue"]: self.revenueyeartotal += each["NetRevenue"] loop += 2 retval = True - - _LOGGER.info("Successfully retrieved FCR-D year revenue (multiple periods)") + + _LOGGER.info( + "Successfully retrieved FCR-D year revenue (multiple periods)" + ) return retval except Exception as error: - _LOGGER.error("Error in get_fcrd_year_net_revenue (multiple periods): %s", error) + _LOGGER.error( + "Error in get_fcrd_year_net_revenue (multiple periods): %s", + error, + ) return False - + except Exception as error: _LOGGER.error("Error in get_fcrd_year_net_revenue: %s", error) return False @@ -1249,18 +1345,18 @@ async def fetch_and_return_net_revenue(self, from_date, to_date): if not await self.ensure_authenticated(): _LOGGER.error("Failed to authenticate for custom revenue range") return None - + site_id = await self.get_site_id() if site_id is False: _LOGGER.error("Failed to get site ID for custom revenue range") return None - + if not site_id: _LOGGER.error("Site ID is empty or None for custom revenue range") return None - + _LOGGER.debug("Using site ID %s for custom revenue range", site_id) - + # Validate date format and ensure they are dates date_format = "%Y-%m-%d" try: @@ -1298,9 +1394,12 @@ async def fetch_and_return_net_revenue(self, from_date, to_date): result = await self._request("GET", endpoint, auth_required=True) if result is False: - _LOGGER.error("Failed to retrieve custom revenue range from endpoint: %s", endpoint) + _LOGGER.error( + "Failed to retrieve custom revenue range from 
endpoint: %s", + endpoint, + ) return None - + _LOGGER.info("Successfully retrieved custom revenue range") return result @@ -1344,7 +1443,7 @@ def _extract_content_and_logbook(self, input_string): def _extract_fcr_d_state(self): pattern = re.compile( - r"\[ FCR-D (ACTIVATED|DEACTIVATE|FAIL ACTIVATION) \] (?:(?:\d+x)?\s?(\S+) --(\d+)-- | (?:(?:UP|DOWN) (?:\d+,\d+) Hz ))((?:(\d+,\d+)\/(\d+,\d+)\/)?(\d+,\d+|[A-Z]+) %)\s+\((\d+,\d+\/\d+,\d+|\d+\/\d+|\d+) kW\)\s*-?\s*.*?(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})" # noqa: E501 + r"\[ FCR-D (ACTIVATED|DEACTIVATE|FAIL ACTIVATION) \] (?:(?:\d+x)?\s?(\S+) --(\d+)-- | (?:(?:UP|DOWN) (?:\d+,\d+) Hz ))((?:(\d+,\d+)\/(\d+,\d+)\/)?(\d+,\d+|[A-Z]+) %)\s+\((\d+,\d+\/\d+,\d+|\d+\/\d+|\d+) kW\)\s*-?\s*.*?(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})" # noqa: E501 ) for entry in self.logbook_entries: match = pattern.search(entry) @@ -1405,7 +1504,7 @@ async def get_power_data(self): if not await self.ensure_authenticated(): _LOGGER.error("Failed to authenticate for power data") return False - + endpoint = self._build_series_endpoint( 3 ) # 0: Hourly, 1: Daily, 2: Monthly, 3: Yearly @@ -1413,7 +1512,7 @@ async def get_power_data(self): result = await self._request("GET", endpoint, auth_required=True) if result is False: return False - + self.power_data = result return True @@ -1428,13 +1527,13 @@ async def get_energy_flow(self): if not await self.ensure_authenticated(): _LOGGER.error("Failed to authenticate for energy flow") return False - + endpoint = "/ems/energyflow" result = await self._request("GET", endpoint, auth_required=True) if result is False: return False - + self.energy_data = result return True @@ -1449,7 +1548,7 @@ async def get_ems_settings(self, rpi_serial=None): if not await self.ensure_authenticated(): _LOGGER.error("Failed to authenticate for EMS settings") return False - + if rpi_serial is None: rpi_serial = self.rpi_serial @@ -1458,7 +1557,7 @@ async def get_ems_settings(self, rpi_serial=None): result = await self._request("GET", endpoint, auth_required=True) if result is False: return False - + self.ems = result return True @@ -1474,7 +1573,7 @@ async def get_energy_trading_company(self, input_id): result = await self._request("GET", endpoint, auth_required=False) if result is False: return None - + energy_trading_companies = result for energy_trading_company in energy_trading_companies: if energy_trading_company["Id"] == input_id: @@ -1493,13 +1592,13 @@ async def get_price_zone(self): if not await self.ensure_authenticated(): _LOGGER.error("Failed to authenticate for price zone") return False - + endpoint = "/ems/pricezone" - + result = await self._request("GET", endpoint, auth_required=True) if result is False: return False - + self.price_zone = result return True @@ -1514,18 +1613,18 @@ async def get_spot_price(self): if not await self.ensure_authenticated(): _LOGGER.error("Failed to authenticate for spot price") return False - + from_date = datetime.now().strftime("%Y-%m-%d") end_date = datetime.now() + timedelta(days=1) to_date = end_date.strftime("%Y-%m-%d") if self.price_zone is None: await self.get_price_zone() endpoint = f"/ems/spotprice?zone={self.price_zone}&fromDate={from_date}&toDate={to_date}" # noqa: E501 - + result = await self._request("GET", endpoint, auth_required=True) if result is False: return False - + self.spot_prices = result return True @@ -1541,17 +1640,17 @@ async def get_battery_month_peak_effect(self): if not await self.ensure_authenticated(): _LOGGER.error("Failed to authenticate for battery month peak effect") return 
False - + endpoint = f"/ems/PeakBoughtMonth?month={month}" - + result = await self._request("GET", endpoint, auth_required=True) if result is False: return False - + if "HourPeak" in result: self.month_peak_effect = result["HourPeak"] return True - + return False except Exception as error: @@ -1570,11 +1669,11 @@ async def get_rpi_data(self, rpi_serial=None): return False endpoint = f"/register/checkrpiv2?rpi={rpi_serial}" - + result = await self._request("GET", endpoint, auth_required=False) if result is False: return False - + self.rpi_data = result return True @@ -1594,11 +1693,11 @@ async def get_meter_status(self, meter_id=None): return False endpoint = f"/asset/status?meterId={meter_id}" - + result = await self._request("GET", endpoint, auth_required=False) if result is False: return False - + self.meter_data = result return True @@ -1943,21 +2042,21 @@ def jwt_expires_at(self) -> Optional[datetime]: """Get JWT expiration time for debugging.""" if not self.jwt_token: return None - + try: - parts = self.jwt_token.split('.') + parts = self.jwt_token.split(".") if len(parts) != 3: return None - - payload = base64.urlsafe_b64decode(parts[1] + '==').decode('utf-8') + + payload = base64.urlsafe_b64decode(parts[1] + "==").decode("utf-8") claims = json.loads(payload) - - exp = claims.get('exp') + + exp = claims.get("exp") if not exp: return None - + return datetime.fromtimestamp(exp, tz=timezone.utc) - + except (ValueError, json.JSONDecodeError, UnicodeDecodeError, TypeError): return None @@ -1966,10 +2065,10 @@ def refresh_expires_at(self) -> Optional[datetime]: """Get refresh token expiration time for debugging.""" if not self.refresh_token_expires: return None - + try: dt = datetime.fromisoformat( - self.refresh_token_expires.replace('Z', '+00:00') + self.refresh_token_expires.replace("Z", "+00:00") ) # Ensure it's UTC if no timezone info if dt.tzinfo is None: diff --git a/tests/unit/test_auth_and_requests.py b/tests/unit/test_auth_and_requests.py index 943e159..d4ec801 100644 --- a/tests/unit/test_auth_and_requests.py +++ b/tests/unit/test_auth_and_requests.py @@ -12,252 +12,274 @@ class TestAuthentication: """Test authentication lifecycle and token management.""" - + @pytest.mark.asyncio async def test_login_stores_tokens(self): """Test that successful login stores JWT and refresh tokens.""" async with CheckwattManager("test_user", "test_pass") as manager: - - with patch('aiohttp.ClientSession.post') as mock_post, \ - patch('aiohttp.ClientSession.get') as mock_get: - + + with patch("aiohttp.ClientSession.post") as mock_post, patch( + "aiohttp.ClientSession.get" + ) as mock_get: + # Mock kill switch check mock_killswitch = AsyncMock() mock_killswitch.status = 200 mock_killswitch.text = AsyncMock(return_value="0") mock_get.return_value.__aenter__.return_value = mock_killswitch - + # Mock login response with refresh token expires mock_login = AsyncMock() mock_login.status = 200 - mock_login.json = AsyncMock(return_value={ - "JwtToken": "test_jwt_token", - "RefreshToken": "test_refresh_token", - "RefreshTokenExpires": "2025-12-31T23:59:59.000+00:00" - }) + mock_login.json = AsyncMock( + return_value={ + "JwtToken": "test_jwt_token", + "RefreshToken": "test_refresh_token", + "RefreshTokenExpires": "2025-12-31T23:59:59.000+00:00", + } + ) mock_post.return_value.__aenter__.return_value = mock_login - + result = await manager.login() - + assert result is True assert manager.jwt_token == "test_jwt_token" assert manager.refresh_token == "test_refresh_token" assert manager.refresh_token_expires == 
"2025-12-31T23:59:59.000+00:00" - + @pytest.mark.asyncio async def test_jwt_validity_check(self): """Test JWT validity checking.""" manager = CheckwattManager("test_user", "test_pass") - + # Test with no token assert manager._jwt_is_valid() is False - + # Test with invalid JWT format manager.jwt_token = "invalid.jwt.format" assert manager._jwt_is_valid() is False - + # Test with valid JWT structure but invalid content manager.jwt_token = "header.payload.signature" assert manager._jwt_is_valid() is False # Should fail due to invalid base64 - + @pytest.mark.asyncio async def test_refresh_token_validity_check(self): """Test refresh token validity checking.""" manager = CheckwattManager("test_user", "test_pass") - + # Test with no tokens assert manager._refresh_is_valid() is False - + # Test with valid refresh token manager.refresh_token = "test_refresh" manager.refresh_token_expires = "2025-12-31T23:59:59.000+00:00" - + # Test with expired token (future date) manager.refresh_token_expires = "2020-12-31T23:59:59.000+00:00" assert manager._refresh_is_valid() is False - + # Test with valid token (future date) manager.refresh_token_expires = "2030-12-31T23:59:59.000+00:00" assert manager._refresh_is_valid() is True - + @pytest.mark.asyncio async def test_token_refresh_success(self): """Test successful token refresh.""" async with CheckwattManager("test_user", "test_pass") as manager: manager.refresh_token = "test_refresh_token" - - with patch('aiohttp.ClientSession.get') as mock_get: + + with patch("aiohttp.ClientSession.get") as mock_get: mock_response = AsyncMock() mock_response.status = 200 - mock_response.json = AsyncMock(return_value={ - "JwtToken": "new_jwt_token", - "RefreshToken": "new_refresh_token", - "RefreshTokenExpires": "2025-12-31T23:59:59.000+00:00" - }) + mock_response.json = AsyncMock( + return_value={ + "JwtToken": "new_jwt_token", + "RefreshToken": "new_refresh_token", + "RefreshTokenExpires": "2025-12-31T23:59:59.000+00:00", + } + ) mock_get.return_value.__aenter__.return_value = mock_response - + result = await manager._refresh() - + assert result is True assert manager.jwt_token == "new_jwt_token" assert manager.refresh_token == "new_refresh_token" assert manager.refresh_token_expires == "2025-12-31T23:59:59.000+00:00" - + @pytest.mark.asyncio async def test_token_refresh_failure(self): """Test token refresh failure handling.""" async with CheckwattManager("test_user", "test_pass") as manager: manager.refresh_token = "test_refresh_token" initial_jwt = manager.jwt_token # Capture initial value - - with patch('aiohttp.ClientSession.get') as mock_get: + + with patch("aiohttp.ClientSession.get") as mock_get: mock_response = AsyncMock() mock_response.status = 401 # Unauthorized mock_get.return_value.__aenter__.return_value = mock_response - + result = await manager._refresh() - + assert result is False # Tokens should remain unchanged from initial value assert manager.jwt_token == initial_jwt - + @pytest.mark.asyncio async def test_ensure_token_with_valid_jwt(self): """Test _ensure_token returns True with valid JWT.""" manager = CheckwattManager("test_user", "test_pass") manager.jwt_token = "valid_jwt" - - with patch.object(manager, '_jwt_is_valid', return_value=True): + + with patch.object(manager, "_jwt_is_valid", return_value=True): result = await manager._ensure_token() assert result is True - + @pytest.mark.asyncio async def test_ensure_token_with_refresh(self): """Test _ensure_token uses refresh token when JWT is invalid.""" manager = CheckwattManager("test_user", 
"test_pass") manager.jwt_token = "expired_jwt" manager.refresh_token = "valid_refresh" - - with patch.object(manager, '_jwt_is_valid', return_value=False), \ - patch.object(manager, '_refresh_is_valid', return_value=True), \ - patch.object(manager, '_refresh', return_value=True): - + + with patch.object(manager, "_jwt_is_valid", return_value=False), patch.object( + manager, "_refresh_is_valid", return_value=True + ), patch.object(manager, "_refresh", return_value=True): + result = await manager._ensure_token() assert result is True - + @pytest.mark.asyncio async def test_ensure_token_falls_back_to_login(self): """Test _ensure_token falls back to login when refresh fails.""" manager = CheckwattManager("test_user", "test_pass") manager.jwt_token = "expired_jwt" manager.refresh_token = "expired_refresh" - - with patch.object(manager, '_jwt_is_valid', return_value=False), \ - patch.object(manager, '_refresh_is_valid', return_value=False), \ - patch.object(manager, 'login', return_value=True): - + + with patch.object(manager, "_jwt_is_valid", return_value=False), patch.object( + manager, "_refresh_is_valid", return_value=False + ), patch.object(manager, "login", return_value=True): + result = await manager._ensure_token() assert result is True class TestHttpRequestHandling: """Test the centralized _request wrapper.""" - + @pytest.mark.asyncio async def test_request_with_auth_required(self): """Test _request ensures authentication when required.""" async with CheckwattManager("test_user", "test_pass") as manager: - - with patch.object(manager, '_ensure_token', return_value=True) as mock_ensure, \ - patch.object(manager.session, 'request') as mock_request: - + + with patch.object( + manager, "_ensure_token", return_value=True + ) as mock_ensure, patch.object(manager.session, "request") as mock_request: + mock_response = AsyncMock() mock_response.status = 200 - mock_response.headers = {'Content-Type': 'application/json'} + mock_response.headers = {"Content-Type": "application/json"} mock_response.json = AsyncMock(return_value={"data": "test"}) mock_response.raise_for_status = Mock() mock_request.return_value.__aenter__.return_value = mock_response - + result = await manager._request("GET", "/test", auth_required=True) - + mock_ensure.assert_called_once() assert result == {"data": "test"} - + @pytest.mark.asyncio async def test_request_without_auth(self): """Test _request skips authentication when not required.""" async with CheckwattManager("test_user", "test_pass") as manager: - - with patch.object(manager, '_ensure_token') as mock_ensure, \ - patch.object(manager.session, 'request') as mock_request: - + + with patch.object(manager, "_ensure_token") as mock_ensure, patch.object( + manager.session, "request" + ) as mock_request: + mock_response = AsyncMock() mock_response.status = 200 - mock_response.headers = {'Content-Type': 'application/json'} + mock_response.headers = {"Content-Type": "application/json"} mock_response.json = AsyncMock(return_value={"data": "test"}) mock_response.raise_for_status = Mock() mock_request.return_value.__aenter__.return_value = mock_response - + result = await manager._request("GET", "/test", auth_required=False) - + mock_ensure.assert_not_called() assert result == {"data": "test"} - + @pytest.mark.asyncio async def test_request_401_handling(self): """Test _request handles 401 with refresh and login retry.""" async with CheckwattManager("test_user", "test_pass") as manager: - - with patch.object(manager, '_ensure_token', return_value=True), \ - patch.object(manager, 
'_refresh', return_value=True), \ - patch.object(manager, 'login', return_value=True), \ - patch.object(manager.session, 'request') as mock_request: - + + with patch.object( + manager, "_ensure_token", return_value=True + ), patch.object(manager, "_refresh", return_value=True), patch.object( + manager, "login", return_value=True + ), patch.object( + manager.session, "request" + ) as mock_request: + # First request returns 401, second succeeds mock_response1 = AsyncMock() mock_response1.status = 401 mock_response1.raise_for_status.side_effect = Exception("401") - + mock_response2 = AsyncMock() mock_response2.status = 200 - mock_response2.headers = {'Content-Type': 'application/json'} + mock_response2.headers = {"Content-Type": "application/json"} mock_response2.json = AsyncMock(return_value={"data": "success"}) mock_response2.raise_for_status = Mock() - - mock_request.return_value.__aenter__.side_effect = [mock_response1, mock_response2] - - result = await manager._request("GET", "/test", auth_required=True, retry_on_401=True) - + + mock_request.return_value.__aenter__.side_effect = [ + mock_response1, + mock_response2, + ] + + result = await manager._request( + "GET", "/test", auth_required=True, retry_on_401=True + ) + assert result == {"data": "success"} - + @pytest.mark.asyncio async def test_request_429_handling_with_retry_after(self): """Test _request handles 429 with Retry-After header.""" async with CheckwattManager("test_user", "test_pass") as manager: - - with patch.object(manager, '_ensure_token', return_value=True), \ - patch.object(manager.session, 'request') as mock_request, \ - patch('asyncio.sleep') as mock_sleep: - + + with patch.object( + manager, "_ensure_token", return_value=True + ), patch.object(manager.session, "request") as mock_request, patch( + "asyncio.sleep" + ) as mock_sleep: + # First request returns 429, second succeeds mock_response1 = AsyncMock() mock_response1.status = 429 - mock_response1.headers = {'Retry-After': '2'} + mock_response1.headers = {"Retry-After": "2"} mock_response1.raise_for_status.side_effect = Exception("429") - + mock_response2 = AsyncMock() mock_response2.status = 200 - mock_response2.headers = {'Content-Type': 'application/json'} + mock_response2.headers = {"Content-Type": "application/json"} mock_response2.json = AsyncMock(return_value={"data": "success"}) mock_response2.raise_for_status = Mock() - - mock_request.return_value.__aenter__.side_effect = [mock_response1, mock_response2] - - result = await manager._request("GET", "/test", auth_required=True, retry_on_429=True) - + + mock_request.return_value.__aenter__.side_effect = [ + mock_response1, + mock_response2, + ] + + result = await manager._request( + "GET", "/test", auth_required=True, retry_on_429=True + ) + mock_sleep.assert_called_once_with(2) assert result == {"data": "success"} - + @pytest.mark.asyncio async def test_request_429_handling_with_exponential_backoff(self): """Test _request handles 429 with exponential backoff when no Retry-After.""" @@ -265,33 +287,42 @@ async def test_request_429_handling_with_exponential_backoff(self): manager.max_retries_429 = 2 manager.backoff_base = 1.0 manager.backoff_factor = 2.0 - - with patch.object(manager, '_ensure_token', return_value=True), \ - patch.object(manager.session, 'request') as mock_request, \ - patch('asyncio.sleep') as mock_sleep, \ - patch('random.uniform', return_value=0.1): - + + with patch.object( + manager, "_ensure_token", return_value=True + ), patch.object(manager.session, "request") as mock_request, patch( + 
"asyncio.sleep" + ) as mock_sleep, patch( + "random.uniform", return_value=0.1 + ): + # First two requests return 429, third succeeds mock_response1 = AsyncMock() mock_response1.status = 429 mock_response1.headers = {} mock_response1.raise_for_status.side_effect = Exception("429") - + mock_response2 = AsyncMock() mock_response2.status = 429 mock_response2.headers = {} mock_response2.raise_for_status.side_effect = Exception("429") - + mock_response3 = AsyncMock() mock_response3.status = 200 - mock_response3.headers = {'Content-Type': 'application/json'} + mock_response3.headers = {"Content-Type": "application/json"} mock_response3.json = AsyncMock(return_value={"data": "success"}) mock_response3.raise_for_status = Mock() - - mock_request.return_value.__aenter__.side_effect = [mock_response1, mock_response2, mock_response3] - - result = await manager._request("GET", "/test", auth_required=True, retry_on_429=True) - + + mock_request.return_value.__aenter__.side_effect = [ + mock_response1, + mock_response2, + mock_response3, + ] + + result = await manager._request( + "GET", "/test", auth_required=True, retry_on_429=True + ) + # Should sleep twice with exponential backoff assert mock_sleep.call_count == 2 # First sleep: 1.0 * 2^0 + 0.1 = 1.1 @@ -299,106 +330,109 @@ async def test_request_429_handling_with_exponential_backoff(self): mock_sleep.assert_any_call(1.1) mock_sleep.assert_any_call(2.1) assert result == {"data": "success"} - + @pytest.mark.asyncio async def test_request_max_retries_exceeded(self): """Test _request stops retrying after max attempts.""" async with CheckwattManager("test_user", "test_pass") as manager: # Verify that the method exists and has the right signature - assert hasattr(manager, '_request') + assert hasattr(manager, "_request") assert callable(manager._request) - + # Verify that max_retries_429 is configurable manager.max_retries_429 = 5 assert manager.max_retries_429 == 5 - + # Test that the method can be called (basic functionality) - with patch.object(manager, '_ensure_token', return_value=True), \ - patch.object(manager.session, 'request') as mock_request: - + with patch.object( + manager, "_ensure_token", return_value=True + ), patch.object(manager.session, "request") as mock_request: + # Mock a successful response mock_response = Mock() mock_response.status = 200 - mock_response.headers = {'Content-Type': 'application/json'} + mock_response.headers = {"Content-Type": "application/json"} mock_response.json = AsyncMock(return_value={"data": "test"}) mock_response.raise_for_status = Mock() mock_request.return_value.__aenter__.return_value = mock_response - + result = await manager._request("GET", "/test", auth_required=True) - + # Should return the response data assert result == {"data": "test"} - + @pytest.mark.asyncio async def test_request_content_type_handling(self): """Test _request handles different content types correctly.""" async with CheckwattManager("test_user", "test_pass") as manager: - - with patch.object(manager, '_ensure_token', return_value=True), \ - patch.object(manager.session, 'request') as mock_request: - + + with patch.object( + manager, "_ensure_token", return_value=True + ), patch.object(manager.session, "request") as mock_request: + # Test JSON response mock_response = AsyncMock() mock_response.status = 200 - mock_response.headers = {'Content-Type': 'application/json'} + mock_response.headers = {"Content-Type": "application/json"} mock_response.json = AsyncMock(return_value={"data": "json"}) mock_response.raise_for_status = Mock() 
mock_request.return_value.__aenter__.return_value = mock_response - + result = await manager._request("GET", "/test", auth_required=True) assert result == {"data": "json"} - + # Test text response - mock_response.headers = {'Content-Type': 'text/plain'} + mock_response.headers = {"Content-Type": "text/plain"} mock_response.text = AsyncMock(return_value="plain text") - + result = await manager._request("GET", "/test", auth_required=True) assert result == "plain text" class TestConcurrencyControl: """Test concurrency control mechanisms.""" - + @pytest.mark.asyncio async def test_auth_lock_prevents_duplicate_refresh(self): """Test that auth lock prevents multiple concurrent refresh attempts.""" async with CheckwattManager("test_user", "test_pass") as manager: # Verify that the lock exists - assert hasattr(manager, '_auth_lock') + assert hasattr(manager, "_auth_lock") assert isinstance(manager._auth_lock, asyncio.Lock) - + # Test basic lock functionality async with manager._auth_lock: # Lock should be acquired assert manager._auth_lock.locked() - + # Lock should be released assert not manager._auth_lock.locked() - + @pytest.mark.asyncio async def test_request_semaphore_limits_concurrency(self): """Test that request semaphore limits concurrent outbound requests.""" async with CheckwattManager("test_user", "test_pass") as manager: manager.max_concurrent_requests = 2 - - with patch.object(manager, '_ensure_token', return_value=True), \ - patch.object(manager.session, 'request') as mock_request: - + + with patch.object( + manager, "_ensure_token", return_value=True + ), patch.object(manager.session, "request") as mock_request: + mock_response = AsyncMock() mock_response.status = 200 - mock_response.headers = {'Content-Type': 'application/json'} + mock_response.headers = {"Content-Type": "application/json"} mock_response.json = AsyncMock(return_value={"data": "test"}) mock_response.raise_for_status = Mock() mock_request.return_value.__aenter__.return_value = mock_response - + # Simulate multiple concurrent requests async def make_request(): return await manager._request("GET", "/test", auth_required=True) - + # Start 5 requests concurrently tasks = [make_request() for _ in range(5)] results = await asyncio.gather(*tasks) - + # All should succeed assert all(results) # But semaphore should have limited concurrency @@ -407,59 +441,65 @@ async def make_request(): class TestSecurityAndLogging: """Test security and logging features.""" - + @pytest.mark.asyncio async def test_sensitive_headers_not_logged(self): """Test that sensitive headers are not logged.""" async with CheckwattManager("test_user", "test_pass") as manager: - - with patch('pycheckwatt._LOGGER') as mock_logger: + + with patch("pycheckwatt._LOGGER") as mock_logger: # Test handle_client_error headers = { "authorization": "Bearer secret_token", "cookie": "session=secret_session", "content-type": "application/json", - "user-agent": "test-agent" + "user-agent": "test-agent", } - - await manager.handle_client_error("/test", headers, Exception("test error")) - + + await manager.handle_client_error( + "/test", headers, Exception("test error") + ) + # Check that error was logged mock_logger.error.assert_called_once() - + # Check that sensitive headers were removed call_args = mock_logger.error.call_args[0] logged_headers = call_args[2] # Headers are the third argument - + assert "authorization" not in logged_headers assert "cookie" not in logged_headers assert "content-type" in logged_headers assert "user-agent" in logged_headers - + 
@pytest.mark.asyncio async def test_request_logs_safe_headers(self): """Test that _request logs headers without sensitive information.""" async with CheckwattManager("test_user", "test_pass") as manager: - - with patch.object(manager, '_ensure_token', return_value=True), \ - patch.object(manager.session, 'request') as mock_request, \ - patch('pycheckwatt._LOGGER') as mock_logger: - + + with patch.object( + manager, "_ensure_token", return_value=True + ), patch.object(manager.session, "request") as mock_request, patch( + "pycheckwatt._LOGGER" + ) as mock_logger: + mock_response = AsyncMock() mock_response.status = 200 - mock_response.headers = {'Content-Type': 'application/json'} + mock_response.headers = {"Content-Type": "application/json"} mock_response.json = AsyncMock(return_value={"data": "test"}) mock_response.raise_for_status = Mock() mock_request.return_value.__aenter__.return_value = mock_response - + # Make request with sensitive headers headers = { "authorization": "Bearer secret_token", - "x-custom": "custom_value" + "x-custom": "custom_value", } - - await manager._request("GET", "/test", headers=headers, auth_required=True) - + + await manager._request( + "GET", "/test", headers=headers, auth_required=True + ) + # Verify no sensitive data in logs for call in mock_logger.debug.call_args_list: call_str = str(call) @@ -469,42 +509,42 @@ async def test_request_logs_safe_headers(self): class TestConfiguration: """Test configuration parameter handling.""" - + def test_default_configuration(self): """Test default configuration values.""" manager = CheckwattManager("test_user", "test_pass") - + assert manager.max_retries_429 == 3 assert manager.backoff_base == 0.5 assert manager.backoff_factor == 2.0 assert manager.backoff_max == 30.0 assert manager.clock_skew_seconds == 10 assert manager.max_concurrent_requests == 5 - + def test_custom_configuration(self): """Test custom configuration values.""" manager = CheckwattManager( - "test_user", + "test_user", "test_pass", max_retries_429=5, backoff_base=1.0, backoff_factor=3.0, backoff_max=60.0, clock_skew_seconds=120, - max_concurrent_requests=10 + max_concurrent_requests=10, ) - + assert manager.max_retries_429 == 5 assert manager.backoff_base == 1.0 assert manager.backoff_factor == 3.0 assert manager.backoff_max == 60.0 assert manager.clock_skew_seconds == 120 assert manager.max_concurrent_requests == 10 - + def test_backwards_compatibility(self): """Test that existing constructor signature still works.""" manager = CheckwattManager("test_user", "test_pass", "CustomApp") - + assert manager.username == "test_user" assert manager.password == "test_pass" assert manager.header_identifier == "CustomApp" @@ -515,42 +555,44 @@ def test_backwards_compatibility(self): class TestTokenExpirationParsing: """Test token debugging properties.""" - + def test_jwt_expires_at_property(self): """Test jwt_expires_at property for debugging.""" manager = CheckwattManager("test_user", "test_pass") - + # Test with no token assert manager.jwt_expires_at is None - + # Test with valid JWT structure - with patch('pycheckwatt.base64') as mock_base64, \ - patch('pycheckwatt.json') as mock_json, \ - patch('pycheckwatt.datetime') as mock_datetime: - - mock_base64.urlsafe_b64decode.return_value = json.dumps({"exp": 1735732800}).encode() + with patch("pycheckwatt.base64") as mock_base64, patch( + "pycheckwatt.json" + ) as mock_json, patch("pycheckwatt.datetime") as mock_datetime: + + mock_base64.urlsafe_b64decode.return_value = json.dumps( + {"exp": 1735732800} + 
).encode() mock_json.loads.return_value = {"exp": 1735732800} mock_datetime.fromtimestamp.return_value = datetime(2025, 1, 1, 13, 0, 0) - + manager.jwt_token = "header.payload.signature" - + expires_at = manager.jwt_expires_at assert expires_at is not None assert isinstance(expires_at, datetime) - + def test_refresh_expires_at_property(self): """Test refresh_expires_at property for debugging.""" manager = CheckwattManager("test_user", "test_pass") - + # Test with no refresh token expires assert manager.refresh_expires_at is None - + # Test with valid timestamp manager.refresh_token_expires = "2025-12-31T23:59:59.000+00:00" - + expires_at = manager.refresh_expires_at assert expires_at is not None assert isinstance(expires_at, datetime) assert expires_at.year == 2025 assert expires_at.month == 12 - assert expires_at.day == 31 \ No newline at end of file + assert expires_at.day == 31 diff --git a/tests/unit/test_checkwatt_manager.py b/tests/unit/test_checkwatt_manager.py index b6e4fb2..61fd2af 100644 --- a/tests/unit/test_checkwatt_manager.py +++ b/tests/unit/test_checkwatt_manager.py @@ -107,8 +107,9 @@ async def test_get_customer_details_success(self): manager.jwt_token = "test_token" - with patch.object(manager, '_request') as mock_request, \ - patch.object(manager, 'ensure_authenticated', return_value=True): + with patch.object(manager, "_request") as mock_request, patch.object( + manager, "ensure_authenticated", return_value=True + ): mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON result = await manager.get_customer_details() @@ -123,11 +124,12 @@ async def test_customer_details_populates_battery_registration(self): async with CheckwattManager("test_user", "test_pass") as manager: manager.jwt_token = "test_token" - - with patch.object(manager, '_request') as mock_request, \ - patch.object(manager, 'ensure_authenticated', return_value=True): + + with patch.object(manager, "_request") as mock_request, patch.object( + manager, "ensure_authenticated", return_value=True + ): mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON - + await manager.get_customer_details() # Verify battery registration was extracted from logbook @@ -142,11 +144,12 @@ async def test_customer_details_extracts_fcrd_state(self): async with CheckwattManager("test_user", "test_pass") as manager: manager.jwt_token = "test_token" - - with patch.object(manager, '_request') as mock_request, \ - patch.object(manager, 'ensure_authenticated', return_value=True): + + with patch.object(manager, "_request") as mock_request, patch.object( + manager, "ensure_authenticated", return_value=True + ): mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON - + await manager.get_customer_details() # Verify FCR-D state was extracted @@ -164,11 +167,12 @@ async def authenticated_manager(self): async with CheckwattManager("test_user", "test_pass") as manager: manager.jwt_token = "test_token" - - with patch.object(manager, '_request') as mock_request, \ - patch.object(manager, 'ensure_authenticated', return_value=True): + + with patch.object(manager, "_request") as mock_request, patch.object( + manager, "ensure_authenticated", return_value=True + ): mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON - + await manager.get_customer_details() yield manager @@ -227,11 +231,13 @@ async def test_get_power_data_success(self): async with CheckwattManager("test_user", "test_pass") as manager: manager.jwt_token = "test_token" - manager.customer_details = SAMPLE_CUSTOMER_DETAILS_JSON # Needed for endpoint building - + 
manager.customer_details = ( + SAMPLE_CUSTOMER_DETAILS_JSON # Needed for endpoint building + ) - with patch.object(manager, '_request') as mock_request, \ - patch.object(manager, 'ensure_authenticated', return_value=True): + with patch.object(manager, "_request") as mock_request, patch.object( + manager, "ensure_authenticated", return_value=True + ): mock_request.return_value = SAMPLE_POWER_DATA_RESPONSE result = await manager.get_power_data() @@ -248,15 +254,20 @@ async def test_energy_properties_after_power_data_load(self): manager.jwt_token = "test_token" manager.customer_details = SAMPLE_CUSTOMER_DETAILS_JSON - with patch.object(manager, '_request') as mock_request, \ - patch.object(manager, 'ensure_authenticated', return_value=True): + with patch.object(manager, "_request") as mock_request, patch.object( + manager, "ensure_authenticated", return_value=True + ): mock_request.return_value = SAMPLE_POWER_DATA_RESPONSE await manager.get_power_data() assert manager.total_solar_energy == 11124779.0 # 2848509.0 + 8276270.0 - assert manager.total_charging_energy == 4700000.0 # 1500000.0 + 3200000.0 - assert manager.total_discharging_energy == 4000000.0 # 1200000.0 + 2800000.0 + assert ( + manager.total_charging_energy == 4700000.0 + ) # 1500000.0 + 3200000.0 + assert ( + manager.total_discharging_energy == 4000000.0 + ) # 1200000.0 + 2800000.0 assert manager.total_import_energy == 8098842.0 # 3104554.0 + 4994288.0 assert manager.total_export_energy == 8040738.0 # 2899531.0 + 5141207.0 @@ -283,16 +294,19 @@ async def test_fcrd_revenue_methods_success(self): manager.jwt_token = "test_token" # Load customer details first (provides RPI serial) - with patch.object(manager, '_request') as mock_request, \ - patch.object(manager, 'ensure_authenticated', return_value=True): + with patch.object(manager, "_request") as mock_request, patch.object( + manager, "ensure_authenticated", return_value=True + ): mock_request.return_value = SAMPLE_CUSTOMER_DETAILS_JSON await manager.get_customer_details() # Mock FCR-D revenue calls - with patch.object(manager, 'get_site_id', return_value="test_site_123"), \ - patch.object(manager, '_request') as mock_request, \ - patch.object(manager, 'ensure_authenticated', return_value=True): + with patch.object( + manager, "get_site_id", return_value="test_site_123" + ), patch.object(manager, "_request") as mock_request, patch.object( + manager, "ensure_authenticated", return_value=True + ): mock_request.return_value = SAMPLE_FCRD_RESPONSE @@ -308,7 +322,9 @@ async def test_fcrd_revenue_methods_success(self): assert manager.revenue is not None assert manager.revenueyear is not None - assert manager.revenuemonth == 61.44 # Sum of FCR-D revenues: 20.11 + 20.13 + 21.07 + 0.13 + assert ( + manager.revenuemonth == 61.44 + ) # Sum of FCR-D revenues: 20.11 + 20.13 + 21.07 + 0.13 class TestEMSSettings: @@ -322,8 +338,9 @@ async def test_get_ems_settings_success(self): manager.jwt_token = "test_token" manager.customer_details = SAMPLE_CUSTOMER_DETAILS_JSON - with patch.object(manager, '_request') as mock_request, \ - patch.object(manager, 'ensure_authenticated', return_value=True): + with patch.object(manager, "_request") as mock_request, patch.object( + manager, "ensure_authenticated", return_value=True + ): mock_request.return_value = SAMPLE_EMS_SETTINGS_RESPONSE result = await manager.get_ems_settings() @@ -333,9 +350,6 @@ async def test_get_ems_settings_success(self): assert manager.ems == SAMPLE_EMS_SETTINGS_RESPONSE - - - class TestFCRDStateExtraction: """Test FCR-D state 
extraction from logbook entries.""" @@ -346,10 +360,10 @@ def setup_method(self): def test_fail_activation_with_retry_count_and_complex_power(self): """Test parsing of FAIL ACTIVATION entries with retry count and complex power format.""" log_entry = "[ FCR-D FAIL ACTIVATION ] 54x test@example.com --12345-- 85,9/0,6/97,0 % (10,0/10,0 kW) 2025-04-24 00:02:57 API-BACKEND" - + self.manager.logbook_entries = [log_entry] self.manager._extract_fcr_d_state() - + assert self.manager.fcrd_state == "FAIL ACTIVATION" assert self.manager.fcrd_percentage_up == "85,9" assert self.manager.fcrd_percentage_response == "0,6" @@ -360,10 +374,10 @@ def test_fail_activation_with_retry_count_and_complex_power(self): def test_activated_with_complex_power_format(self): """Test parsing of ACTIVATED entries with complex power format.""" log_entry = "[ FCR-D ACTIVATED ] test@example.com --12345-- 96,5/4,0/106,3 % (10,0/10,0 kW) 2025-08-07 00:04:45 API-BACKEND" - + self.manager.logbook_entries = [log_entry] self.manager._extract_fcr_d_state() - + assert self.manager.fcrd_state == "ACTIVATED" assert self.manager.fcrd_percentage_up == "96,5" assert self.manager.fcrd_percentage_response == "4,0" @@ -374,10 +388,10 @@ def test_activated_with_complex_power_format(self): def test_deactivate_with_frequency_up_hz(self): """Test parsing of DEACTIVATE entries with UP frequency.""" log_entry = "[ FCR-D DEACTIVATE ] UP 49,83 Hz 0,0 % (10 kW) - 2025-08-06 17:58:07 API-BACKEND" - + self.manager.logbook_entries = [log_entry] self.manager._extract_fcr_d_state() - + assert self.manager.fcrd_state == "DEACTIVATE" # For DEACTIVATE, the percentage info goes to fcrd_info assert self.manager.fcrd_power == "10" @@ -389,10 +403,10 @@ def test_multiple_entries_first_match_used(self): "[ FCR-D ACTIVATED ] test@example.com --12345-- 97,7/0,5/99,3 % (7 kW) 2024-07-07 00:08:19 API-BACKEND", "[ FCR-D FAIL ACTIVATION ] 54x test@example.com --12345-- 85,9/0,6/97,0 % (10,0/10,0 kW) 2025-04-24 00:02:57 API-BACKEND", ] - + self.manager.logbook_entries = log_entries self.manager._extract_fcr_d_state() - + # Should use the first entry (ACTIVATED) assert self.manager.fcrd_state == "ACTIVATED" assert self.manager.fcrd_power == "7" diff --git a/tests/unit/test_session_management.py b/tests/unit/test_session_management.py index dcad829..592bd86 100644 --- a/tests/unit/test_session_management.py +++ b/tests/unit/test_session_management.py @@ -20,12 +20,14 @@ # Try to import the required dependencies try: from cryptography.fernet import Fernet + CRYPTOGRAPHY_AVAILABLE = True except ImportError: CRYPTOGRAPHY_AVAILABLE = False try: import aiofiles + AIOFILES_AVAILABLE = True except ImportError: AIOFILES_AVAILABLE = False @@ -39,7 +41,7 @@ class TestSessionPersistence: @pytest.fixture def temp_session_file(self): """Create a temporary session file for testing.""" - with tempfile.NamedTemporaryFile(delete=False, suffix='.json') as f: + with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as f: yield f.name # Clean up only if file still exists try: @@ -54,7 +56,7 @@ def manager(self): username="testuser", password="testpass", persist_sessions=True, - session_file="/tmp/test_session.json" + session_file="/tmp/test_session.json", ) @pytest.fixture @@ -67,43 +69,39 @@ def mock_session(self): def test_enhanced_auth_initialization(self, manager): """Test enhanced authentication initialization.""" assert manager._auth_state is not None - assert 'jwt_token' in manager._auth_state - assert 'refresh_token' in manager._auth_state - assert 'jwt_expires_at' in 
-        assert 'refresh_expires_at' in manager._auth_state
-        assert 'last_auth_time' in manager._auth_state
-        assert 'auth_method' in manager._auth_state
-
+        assert "jwt_token" in manager._auth_state
+        assert "refresh_token" in manager._auth_state
+        assert "jwt_expires_at" in manager._auth_state
+        assert "refresh_expires_at" in manager._auth_state
+        assert "last_auth_time" in manager._auth_state
+        assert "auth_method" in manager._auth_state
+
         assert manager._session_config is not None
-        assert 'persist_sessions' in manager._session_config
-        assert 'session_file' in manager._session_config
-        assert 'encrypt_sessions' in manager._session_config
+        assert "persist_sessions" in manager._session_config
+        assert "session_file" in manager._session_config
+        assert "encrypt_sessions" in manager._session_config
     def test_session_config_without_dependencies(self):
         """Test session configuration when dependencies are missing."""
-        with patch('pycheckwatt.aiofiles', None):
-            with patch('pycheckwatt.CRYPTOGRAPHY_AVAILABLE', False):
+        with patch("pycheckwatt.aiofiles", None):
+            with patch("pycheckwatt.CRYPTOGRAPHY_AVAILABLE", False):
                 manager = CheckwattManager(
-                    username="testuser",
-                    password="testpass",
-                    persist_sessions=True
+                    username="testuser", password="testpass", persist_sessions=True
                 )
-
-                assert not manager._session_config['persist_sessions']
-                assert not manager._session_config['encrypt_sessions']
-
+                assert not manager._session_config["persist_sessions"]
+                assert not manager._session_config["encrypt_sessions"]
     @pytest.mark.asyncio
     async def test_ensure_authenticated_with_valid_jwt(self, manager, mock_session):
         """Test ensure_authenticated with valid JWT."""
         manager.session = mock_session
-
+
         # Set up valid JWT
         future_time = datetime.now() + timedelta(hours=1)
-        manager._auth_state['jwt_token'] = 'valid_token'
-        manager._auth_state['jwt_expires_at'] = future_time
-
+        manager._auth_state["jwt_token"] = "valid_token"
+        manager._auth_state["jwt_expires_at"] = future_time
+
         result = await manager.ensure_authenticated()
         assert result is True
@@ -111,18 +109,18 @@ async def test_ensure_authenticated_with_valid_jwt(self, manager, mock_session):
     async def test_ensure_authenticated_with_refresh(self, manager, mock_session):
         """Test ensure_authenticated with refresh token."""
         manager.session = mock_session
-
+
         # Set up expired JWT but valid refresh
         past_time = datetime.now() - timedelta(hours=1)
         future_time = datetime.now() + timedelta(hours=1)
-
-        manager._auth_state['jwt_token'] = 'expired_token'
-        manager._auth_state['jwt_expires_at'] = past_time
-        manager._auth_state['refresh_token'] = 'valid_refresh'
-        manager._auth_state['refresh_expires_at'] = future_time
-
+
+        manager._auth_state["jwt_token"] = "expired_token"
+        manager._auth_state["jwt_expires_at"] = past_time
+        manager._auth_state["refresh_token"] = "valid_refresh"
+        manager._auth_state["refresh_expires_at"] = future_time
+
         # Mock successful refresh
-        with patch.object(manager, '_refresh_tokens', return_value=True):
+        with patch.object(manager, "_refresh_tokens", return_value=True):
             result = await manager.ensure_authenticated()
             assert result is True
@@ -130,47 +128,41 @@ async def test_ensure_authenticated_with_login(self, manager, mock_session):
     async def test_ensure_authenticated_with_login(self, manager, mock_session):
         """Test ensure_authenticated with password login."""
         manager.session = mock_session
-
+
         # Set up expired tokens
         past_time = datetime.now() - timedelta(hours=1)
-        manager._auth_state['jwt_token'] = 'expired_token'
-        manager._auth_state['jwt_expires_at'] = past_time
-        manager._auth_state['refresh_token'] = 'expired_refresh'
-        manager._auth_state['refresh_expires_at'] = past_time
-
+        manager._auth_state["jwt_token"] = "expired_token"
+        manager._auth_state["jwt_expires_at"] = past_time
+        manager._auth_state["refresh_token"] = "expired_refresh"
+        manager._auth_state["refresh_expires_at"] = past_time
+
         # Mock successful login
-        with patch.object(manager, '_perform_login', return_value=True):
+        with patch.object(manager, "_perform_login", return_value=True):
             result = await manager.ensure_authenticated()
             assert result is True
-
-
-
-
-
-
-
-
     @pytest.mark.skipif(not AIOFILES_AVAILABLE, reason="aiofiles not available")
     @pytest.mark.asyncio
     async def test_save_session_success(self, manager, temp_session_file):
         """Test successful session saving."""
-        manager._session_config['session_file'] = temp_session_file
-
+        manager._session_config["session_file"] = temp_session_file
+
         # Set up auth state
         future_time = datetime.now() + timedelta(hours=1)
-        manager._auth_state.update({
-            'jwt_token': 'test_token',
-            'refresh_token': 'test_refresh',
-            'jwt_expires_at': future_time,
-            'refresh_expires_at': future_time,
-            'last_auth_time': datetime.now(),
-            'auth_method': 'password'
-        })
-
+        manager._auth_state.update(
+            {
+                "jwt_token": "test_token",
+                "refresh_token": "test_refresh",
+                "jwt_expires_at": future_time,
+                "refresh_expires_at": future_time,
+                "last_auth_time": datetime.now(),
+                "auth_method": "password",
+            }
+        )
+
         result = await manager._save_session()
         assert result is True
-
+
         # Verify file was created
         assert os.path.exists(temp_session_file)
@@ -178,75 +170,75 @@ async def test_load_session_success(self, manager, temp_session_file):
     @pytest.mark.asyncio
     async def test_load_session_success(self, manager, temp_session_file):
         """Test successful session loading."""
-        manager._session_config['session_file'] = temp_session_file
-
+        manager._session_config["session_file"] = temp_session_file
+
         # Create test session data
         future_time = datetime.now() + timedelta(hours=1)
         session_data = {
-            'version': '1.0',
-            'username': 'testuser',
-            'auth_state': {
-                'jwt_token': 'loaded_token',
-                'refresh_token': 'loaded_refresh',
-                'jwt_expires_at': future_time,
-                'refresh_expires_at': future_time,
-                'last_auth_time': datetime.now(),
-                'auth_method': 'password'
+            "version": "1.0",
+            "username": "testuser",
+            "auth_state": {
+                "jwt_token": "loaded_token",
+                "refresh_token": "loaded_refresh",
+                "jwt_expires_at": future_time,
+                "refresh_expires_at": future_time,
+                "last_auth_time": datetime.now(),
+                "auth_method": "password",
             },
-            'timestamp': datetime.now().isoformat()
+            "timestamp": datetime.now().isoformat(),
         }
-
+
         # Save session data - always encrypted when cryptography is available
         encrypted_data = manager._encrypt_session_data(session_data)
-
-        with open(temp_session_file, 'w') as f:
+
+        with open(temp_session_file, "w") as f:
             f.write(encrypted_data)
-
+
         # Load session
         result = await manager._load_session()
         assert result is True
-
+
         # Verify state was restored
-        assert manager._auth_state['jwt_token'] == 'loaded_token'
-        assert manager._auth_state['refresh_token'] == 'loaded_refresh'
+        assert manager._auth_state["jwt_token"] == "loaded_token"
+        assert manager._auth_state["refresh_token"] == "loaded_refresh"
     @pytest.mark.asyncio
     async def test_clear_session(self, manager, temp_session_file):
         """Test session clearing."""
-        manager._session_config['session_file'] = temp_session_file
-
+        manager._session_config["session_file"] = temp_session_file
+
         # Set up auth state
-        manager._auth_state.update({
-            'jwt_token': 'test_token',
-            'refresh_token': 'test_refresh',
-            'jwt_expires_at': datetime.now() + timedelta(hours=1),
-            'refresh_expires_at': datetime.now() + timedelta(hours=1),
-            'last_auth_time': datetime.now(),
-            'auth_method': 'password'
-        })
-
+        manager._auth_state.update(
+            {
+                "jwt_token": "test_token",
+                "refresh_token": "test_refresh",
+                "jwt_expires_at": datetime.now() + timedelta(hours=1),
+                "refresh_expires_at": datetime.now() + timedelta(hours=1),
+                "last_auth_time": datetime.now(),
+                "auth_method": "password",
+            }
+        )
+
         # Create session file
-        with open(temp_session_file, 'w') as f:
-            f.write('test')
-
+        with open(temp_session_file, "w") as f:
+            f.write("test")
+
         await manager._clear_session()
-
+
         # Verify state was cleared
-        assert manager._auth_state['jwt_token'] is None
-        assert manager._auth_state['refresh_token'] is None
-
+        assert manager._auth_state["jwt_token"] is None
+        assert manager._auth_state["refresh_token"] is None
+
         # Verify file was removed
         assert not os.path.exists(temp_session_file)
-
-
     @pytest.mark.asyncio
     async def test_public_session_methods(self, manager, temp_session_file):
         """Test public session management methods."""
         # Test load_session
         result = await manager.load_session(temp_session_file)
         assert result is False  # No file exists yet
-
+
         # Test save_session - should fail if no auth state and no aiofiles
         if not AIOFILES_AVAILABLE:
             result = await manager.save_session(temp_session_file)
@@ -254,50 +246,48 @@ async def test_public_session_methods(self, manager, temp_session_file):
         else:
             # Set up some auth state to test saving
             future_time = datetime.now() + timedelta(hours=1)
-            manager._auth_state.update({
-                'jwt_token': 'test_token',
-                'jwt_expires_at': future_time,
-                'refresh_expires_at': future_time,
-                'last_auth_time': datetime.now(),
-                'auth_method': 'password'
-            })
+            manager._auth_state.update(
+                {
+                    "jwt_token": "test_token",
+                    "jwt_expires_at": future_time,
+                    "refresh_expires_at": future_time,
+                    "last_auth_time": datetime.now(),
+                    "auth_method": "password",
+                }
+            )
             result = await manager.save_session(temp_session_file)
             assert result is True  # Should succeed with auth state
-
+
         # Test clear_session
         await manager.clear_session()  # Should not raise
-
+
         # Test get_session_info
         info = manager.get_session_info()
         assert isinstance(info, dict)
-        assert 'authenticated' in info
-
-
-
-
+        assert "authenticated" in info
     def test_backward_compatibility(self):
         """Test that existing functionality still works."""
         # Test constructor without new parameters
         manager = CheckwattManager(username="testuser", password="testpass")
-
+
         # Verify default values - these depend on whether dependencies are available
         if not AIOFILES_AVAILABLE:
-            assert manager._session_config['persist_sessions'] is False
-            assert manager._session_config['session_file'] is None
+            assert manager._session_config["persist_sessions"] is False
+            assert manager._session_config["session_file"] is None
         else:
-            assert manager._session_config['persist_sessions'] is True
+            assert manager._session_config["persist_sessions"] is True
             # With our new default session file logic, session_file will have a default path
-            assert manager._session_config['session_file'] is not None
-
+            assert manager._session_config["session_file"] is not None
+
         # Encryption is always enabled when cryptography is available
         if CRYPTOGRAPHY_AVAILABLE:
-            assert manager._session_config['encrypt_sessions'] is True
+            assert manager._session_config["encrypt_sessions"] is True
         else:
-            assert manager._session_config['encrypt_sessions'] is False
-
+            assert manager._session_config["encrypt_sessions"] is False
+
         # Verify existing attributes still exist
-        assert hasattr(manager, 'jwt_token')
-        assert hasattr(manager, 'refresh_token')
-        assert hasattr(manager, 'refresh_token_expires')
-        assert hasattr(manager, '_ensure_token')  # Old method still exists
+        assert hasattr(manager, "jwt_token")
+        assert hasattr(manager, "refresh_token")
+        assert hasattr(manager, "refresh_token_expires")
+        assert hasattr(manager, "_ensure_token")  # Old method still exists