Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions solar_consumer/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def app(
db_url: str,
save_method: str,
csv_dir: str = None,
country: str = "uk",
country: str = "gb",
historic_or_forecast: str = "generation",
):
"""
Expand All @@ -44,20 +44,23 @@ async def app(
db_url (str): Database connection URL from an environment variable.
save_method (str): Method to save the forecast data. Options are "db" or "csv".
csv_dir (str, optional): Directory to save CSV files if save_method is "csv".
country (str): Country code for fetching data. Default is "uk".
country (str): Country code for fetching data. Default is "gb".
historic_or_forecast: (str): Type of data to fetch. Default is "generation".
"""
logger.info(f"Starting the NESO Solar Forecast pipeline (version: {__version__}).")

# Use the `Neso` class for hardcoded configuration]
if country == "uk":
if country == "gb":
model_tag = "neso-solar-forecast"
elif country == "nl":
model_tag = "ned-nl-national"
elif country == "de":
model_tag = "entsoe-de"
elif country == "be":
model_tag = "elia-be-forecast"
else:
raise ValueError(f"Unsupported country code: {country}")



# Step 1: Fetch forecast data (returns as pd.Dataframe)
Expand All @@ -77,8 +80,8 @@ async def app(

with connection.get_session() as session:

# Step 2: Formate and save the forecast data
# A. Format forecast to database object and save
# Step 2: Formate and save the forecast data
# A. Format forecast to database object and save
logger.info(f"Formatting {len(forecast_data)} rows of forecast data.")
forecasts = format_to_forecast_sql(
data=forecast_data,
Expand Down Expand Up @@ -153,7 +156,7 @@ async def app(
if __name__ == "__main__":
# Step 1: Fetch the database URL from the environment variable
db_url = os.getenv("DB_URL") # Change from "DATABASE_URL" to "DB_URL"
country = os.getenv("COUNTRY", "uk")
country = os.getenv("COUNTRY", "gb")
save_method = os.getenv("SAVE_METHOD", "db").lower() # Default to "db"
csv_dir = os.getenv("CSV_DIR")
historic_or_forecast = os.getenv("HISTORIC_OR_FORECAST", "generation").lower()
Expand Down