diff --git a/lars/preprocessing/radar_preprocessing.py b/lars/preprocessing/radar_preprocessing.py index a5b9290..50c8e91 100644 --- a/lars/preprocessing/radar_preprocessing.py +++ b/lars/preprocessing/radar_preprocessing.py @@ -6,9 +6,10 @@ import cmweather # noqa -def preprocess_radar_data(file_path, output_path, date=None, +def preprocess_radar_data(file_path, output_path, date=None, radar_field='corrected_reflectivity', x_bounds=(-150000, 150000), y_bounds=(-150000, 150000), + size_px=256, dpi=150, **kwargs): """ Preprocess cf/Radial radar data from a given file path. This module will load the radar data, @@ -20,10 +21,12 @@ def preprocess_radar_data(file_path, output_path, date=None, output_path (str): Path to save the processed .png images. date (str or list): Optional date string to filter radar files, in the format 'YYYYMMDD'. - radar_field (str): The radar field to be processed, + radar_field (str): The radar field to be processed, default is 'corrected_reflectivity'. x_bounds (tuple): The x-axis bounds for plotting in meters. y_bounds (tuple): The y-axis bounds for plotting in meters. + size_px (int): Width and height of the output PNG in pixels. Default is 256. + dpi (int): Dots per inch for the saved figure. Default is 150. **kwargs: @@ -39,9 +42,10 @@ def preprocess_radar_data(file_path, output_path, date=None, if date is not None: if isinstance(date, str): date = [date] + file_list2 = [] for date_str in date: - - file_list = [f for f in file_list if date_str in f] + file_list2.extend([f for f in file_list if date_str in f]) + file_list = file_list2 out_df = pd.DataFrame(columns=['file_path', 'time', 'label', 'ref_min', 'ref_max']) if not "vmin" in kwargs: kwargs['vmin'] = -20 @@ -62,11 +66,12 @@ def preprocess_radar_data(file_path, output_path, date=None, if 'sweep_0' in radar: sweep = radar['sweep_0'] if sweep["sweep_mode"] == 'ppi' or sweep["sweep_mode"] == 'sector': - fig = plt.figure(figsize=(256/150, 256/150)) + fig = plt.figure(figsize=(size_px/dpi, size_px/dpi)) ax = plt.axes() sweep["corrected_reflectivity"].where( sweep["corrected_reflectivity"] > min_ref).plot(x="x", y="y", ax=ax, + add_colorbar=False, **kwargs) min_ref = sweep["corrected_reflectivity"].where( sweep["corrected_reflectivity"] > min_ref).values.min() @@ -75,10 +80,7 @@ def preprocess_radar_data(file_path, output_path, date=None, ax.set_xlim(x_bounds) ax.set_ylim(y_bounds) - ax.set_xlabel('X [m]') - ax.set_ylabel('Y [m]') - ax.set_xticks([-100000, -50000, 0, 50000, 100000]) - ax.set_yticks([-100000, -50000, 0, 50000, 100000]) + fig.tight_layout() file_name = os.path.join(output_path, os.path.basename(file).replace('.nc', '.png')) @@ -86,7 +88,7 @@ def preprocess_radar_data(file_path, output_path, date=None, label = "UNKNOWN" # Placeholder for actual label extraction logic fig.savefig(os.path.join(output_path, os.path.basename(file).replace('.nc', '.png')), - dpi=150) + dpi=dpi, bbox_inches='tight', pad_inches=0) plt.close(fig) out_df.loc[len(out_df)] = [file_name, time_str, label, min_ref, max_ref] diff --git a/pyproject.toml b/pyproject.toml index 4b69fc2..be50e70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", ] -dependencies = ["xradar"] +dependencies = ["xradar", "scikit-learn", "python-dotenv", "aiohttp", "asksageclient", "pip_system_certs", "requests"] [project.optional-dependencies] dev = ["pytest>=6.0", "pytest-asyncio>=0.21", "black", "flake8", "openai", "xradar", "python-dotenv",