diff --git a/docs/user-guide/index.md b/docs/user-guide/index.md index 24eee0c8..d7b25e4f 100644 --- a/docs/user-guide/index.md +++ b/docs/user-guide/index.md @@ -18,4 +18,5 @@ documentation/full_sensor_list.md documentation/copernicus_products.md documentation/pre_download_data.md documentation/example_copernicus_download.ipynb +documentation/full_sensor_list.md ``` diff --git a/pixi.toml b/pixi.toml index ba024968..240c4047 100644 --- a/pixi.toml +++ b/pixi.toml @@ -18,13 +18,12 @@ setuptools = "*" setuptools_scm = "*" [package.run-dependencies] # Keep in sync with `pyproject.toml` and feedstock recipe -python = ">=3.10" +python = "3.11.*" click = "*" -parcels = ">3.1.0" pyproj = ">=3,<4" sortedcontainers = "==2.4.0" opensimplex = "==0.4.5" -numpy = ">=1,<2" +numpy = ">=2.1.0" pydantic = ">=2,<3" pyyaml = "*" copernicusmarine = ">=2.2.2" @@ -33,15 +32,26 @@ textual = "*" [dependencies] virtualship = { path = "." } +# Pre-install as conda packages to avoid PyPI source builds +netcdf4 = "*" +numpy = ">=2.1.0" +dask = "*" +zarr = ">=3" +ipdb = ">=0.13.13,<0.14" +cmocean = ">=4.0.3,<5" -[feature.py310.dependencies] -python = "3.10.*" +[pypi-dependencies] +parcels = { git = "https://github.com/Parcels-code/Parcels", branch = "main" } -[feature.py311.dependencies] -python = "3.11.*" +# Commented out whilst parcels v4 alpha only supports Python 3.11 (?) +# [feature.py310.dependencies] +# python = "3.10.*" + +# [feature.py311.dependencies] +# python = "3.11.*" -[feature.py312.dependencies] -python = "3.12.*" +# [feature.py312.dependencies] +# python = "3.12.*" [feature.test.dependencies] pytest = "*" @@ -98,11 +108,8 @@ lxml = "*" typing = "mypy src/virtualship --install-types" [environments] -default = { features = ["test", "notebooks", "typing", "pre-commit", "analysis"] } +default = { features = ["test", "notebooks", "typing", "pre-commit", "analysis"] } test-latest = { features = ["test"], solve-group = "test" } -test-py310 = { features = ["test", "py310"] } -test-py311 = { features = ["test", "py311"] } -test-py312 = { features = ["test", "py312"] } test-notebooks = { features = ["test", "notebooks"], solve-group = "test" } analysis = { features = ["analysis"], solve-group = "analysis" } docs = { features = ["docs"], solve-group = "docs" } diff --git a/pyproject.toml b/pyproject.toml index 7f9a2108..fd0d612b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "virtualship" description = "Code for the Virtual Ship Classroom, where Marine Scientists can combine Copernicus Marine Data with an OceanParcels ship to go on a virtual expedition." readme = "README.md" dynamic = ["version"] -authors = [{ name = "oceanparcels.org team" }] +authors = [{ name = "parcels-code.org team" }] requires-python = ">=3.10" license = { file = "LICENSE" } classifiers = [ @@ -26,11 +26,11 @@ classifiers = [ ] dependencies = [ "click", - "parcels >3.1.0", + "parcels >=4.0.0alpha", "pyproj >= 3, < 4", "sortedcontainers == 2.4.0", "opensimplex == 0.4.5", - "numpy >=1, < 2", + "numpy >=2.1.0", "pydantic >=2, <3", "PyYAML", "copernicusmarine >= 2.2.2", @@ -40,7 +40,7 @@ dependencies = [ ] [project.urls] -Homepage = "https://oceanparcels.org/" # TODO: Update this to just be repo? +Homepage = "https://virtualship.parcels-code.org/" Repository = "https://github.com/OceanParcels/virtualship" Documentation = "https://virtualship.readthedocs.io/" "Bug Tracker" = "https://github.com/OceanParcels/virtualship/issues" @@ -69,7 +69,8 @@ filterwarnings = [ "error", "default::DeprecationWarning", "error::DeprecationWarning:virtualship", - "ignore:ParticleSet is empty.*:RuntimeWarning" # TODO: Probably should be ignored in the source code + "ignore:ParticleSet is empty.*:RuntimeWarning", # TODO: Probably should be ignored in the source code + "ignore:This is an alpha version of Parcels v4.*:UserWarning" # TODO: necessary whilst Parcels v4 is still alpha ] log_cli_level = "INFO" testpaths = [ diff --git a/src/virtualship/cli/_run.py b/src/virtualship/cli/_run.py index f2622be3..a320acad 100644 --- a/src/virtualship/cli/_run.py +++ b/src/virtualship/cli/_run.py @@ -35,11 +35,9 @@ get_instrument_class, ) -# parcels logger (suppress INFO messages to prevent log being flooded) -external_logger = logging.getLogger("parcels.tools.loggers") -external_logger.setLevel(logging.WARNING) - -# copernicusmarine logger (suppress INFO messages to prevent log being flooded) +# suppress INFO messages from copernicusmarine and parcels loggers; prevent log flooding +parcels_logger = logging.getLogger("parcels._logger") +parcels_logger.setLevel(logging.WARNING) logging.getLogger("copernicusmarine").setLevel("ERROR") @@ -204,7 +202,9 @@ def _run( # execute simulation instrument.execute( measurements=measurements, - out_path=expedition_dir.joinpath(RESULTS, f"{itype.name.lower()}.zarr"), + out_path=expedition_dir.joinpath( + RESULTS, f"{itype.name.lower()}.parquet" + ), ) except Exception as e: # clean up if unexpected error occurs diff --git a/src/virtualship/instruments/adcp.py b/src/virtualship/instruments/adcp.py index b2da6582..e9a55a8d 100644 --- a/src/virtualship/instruments/adcp.py +++ b/src/virtualship/instruments/adcp.py @@ -3,7 +3,7 @@ from typing import ClassVar import numpy as np -from parcels import ParticleSet, ScipyParticle +from parcels import ParticleFile, ParticleSet from virtualship.instruments.base import Instrument from virtualship.instruments.sensors import SensorType @@ -35,9 +35,13 @@ class ADCP: # ===================================================== -def _sample_velocity(particle, fieldset, time): - particle.U, particle.V = fieldset.UV.eval( - time, particle.depth, particle.lat, particle.lon, applyConversion=False +def _sample_velocity(particles, fieldset): + particles.U, particles.V = fieldset.UV.eval( + particles.time, + particles.z, + particles.lat, + particles.lon, + applyConversion=False, ) @@ -96,23 +100,22 @@ def simulate(self, measurements, out_path) -> None: # build dynamic particle class from the active sensors adcp_config = self.expedition.instruments_config.adcp_config _ADCPParticle = build_particle_class_from_sensors( - adcp_config.sensors, _ADCP_NONSENSOR_VARIABLES, ScipyParticle + adcp_config.sensors, _ADCP_NONSENSOR_VARIABLES ) bins = np.linspace(MAX_DEPTH, MIN_DEPTH, NUM_BINS) num_particles = len(bins) - particleset = ParticleSet.from_list( + particleset = ParticleSet( fieldset=fieldset, pclass=_ADCPParticle, lon=np.full( num_particles, 0.0 ), # initial lat/lon are irrelevant and will be overruled later.s lat=np.full(num_particles, 0.0), - depth=bins, - time=0, + z=bins, ) - out_file = particleset.ParticleFile(name=out_path, outputdt=np.inf) + out_file = ParticleFile(path=out_path, outputdt=np.inf) # build kernel list from active sensors only sampling_kernels = [ @@ -121,6 +124,9 @@ def simulate(self, measurements, out_path) -> None: if sc.enabled and sc.sensor_type in self.sensor_kernels ] + # TODO: need to overhaul ADCP/underway instruments generally... don't think this Parcels API works anymore + # TODO: a good time to implement https://github.com/Parcels-code/virtualship/issues/231 + for point in measurements: particleset.lon_nextloop[:] = point.location.lon particleset.lat_nextloop[:] = point.location.lat diff --git a/src/virtualship/instruments/argo_float.py b/src/virtualship/instruments/argo_float.py index 3e8e2ea6..96adcd74 100644 --- a/src/virtualship/instruments/argo_float.py +++ b/src/virtualship/instruments/argo_float.py @@ -1,11 +1,11 @@ -import math from collections.abc import Callable from dataclasses import dataclass from datetime import timedelta from typing import ClassVar import numpy as np -from parcels import AdvectionRK4, JITParticle, ParticleSet, StatusCode, Variable +from parcels import ParticleFile, ParticleSet, StatusCode, Variable +from parcels.kernels import AdvectionRK2 from virtualship.instruments.base import Instrument from virtualship.instruments.sensors import SensorType @@ -13,6 +13,11 @@ from virtualship.models.spacetime import Spacetime from virtualship.utils import build_particle_class_from_sensors, register_instrument +# mapping from StatusCode integer value to attribute name (e.g. 60 -> "ErrorOutOfBounds") +_STATUS_CODE_NAMES: dict[int, str] = { + v: k for k, v in vars(StatusCode).items() if not k.startswith("_") +} + # ===================================================== # SECTION: Dataclass # ===================================================== @@ -53,120 +58,131 @@ class ArgoFloat: # SECTION: Kernels # ===================================================== - -def _argo_float_vertical_movement(particle, fieldset, time): - if particle.cycle_phase == 0: - # Phase 0: Sinking with vertical_speed until depth is drift_depth - particle_ddepth += ( # noqa - particle.vertical_speed * particle.dt - ) - - # bathymetry at particle location - loc_bathy = fieldset.bathymetry.eval( - time, particle.depth, particle.lat, particle.lon - ) - if particle.depth + particle_ddepth <= loc_bathy: - particle_ddepth = loc_bathy - particle.depth + 50.0 # 50m above bathy - particle.cycle_phase = 1 - particle.grounded = 1 - print( - "Shallow bathymetry warning: Argo float grounded at bathymetry depth during sinking to drift depth. Raising by 50m above bathymetry and continuing cycle." - ) - - elif particle.depth + particle_ddepth <= particle.drift_depth: - particle_ddepth = particle.drift_depth - particle.depth - particle.cycle_phase = 1 - - elif particle.cycle_phase == 1: - # Phase 1: Drifting at depth for drifttime seconds - particle.drift_age += particle.dt - if particle.drift_age >= particle.drift_days * 86400: - particle.drift_age = 0 # reset drift_age for next cycle - particle.cycle_phase = 2 - - elif particle.cycle_phase == 2: - # Phase 2: Sinking further to max_depth - particle_ddepth += particle.vertical_speed * particle.dt - loc_bathy = fieldset.bathymetry.eval( - time, particle.depth, particle.lat, particle.lon +# TODO: need to add back in the shallow bathymetry checks (to phases 0 and 2?!) +# TODO: can this be refactored as well to a helper function? + + +def _argo_float_vertical_movement(particles, fieldset): + # Split particles based on their current cycle_phase + ptcls0 = particles[particles.cycle_phase == 0] + ptcls1 = particles[particles.cycle_phase == 1] + ptcls2 = particles[particles.cycle_phase == 2] + ptcls3 = particles[particles.cycle_phase == 3] + ptcls4 = particles[particles.cycle_phase == 4] + + # Phase 0: Sinking with vertical_speed until depth is driftdepth + ptcls0.dz += particles.vertical_speed * ptcls0.dt + loc_bathy = fieldset.bathymetry.eval(ptcls0.time, ptcls0.z, ptcls0.lat, ptcls0.lon) + driftdepth_mask = ptcls0.z + ptcls0.dz >= particles.drift_depth + bathy_mask = ptcls0.z + ptcls0.dz >= loc_bathy + next_phase = np.logical_and( + driftdepth_mask, bathy_mask + ) # combined mask; not at drift depth yet and not hitting bathymetry + ptcls0.cycle_phase[next_phase] = 1 + ptcls0.dz[next_phase] = ( + particles.drift_depth - ptcls0.z[next_phase] + ) # avoid overshoot + + # Phase 0.5: Check for grounding at bathymetry and raise if necessary + ptcls0.grounded[~bathy_mask] = 1 + if np.any(~bathy_mask): + print( + "Shallow bathymetry warning: Argo float grounded at bathymetry depth during sinking to drift depth. Raising by 50m above bathymetry and continuing cycle." ) - if particle.depth + particle_ddepth <= loc_bathy: - particle_ddepth = loc_bathy - particle.depth + 50.0 # 50m above bathy - particle.cycle_phase = 3 - particle.grounded = 1 - print( - "Shallow bathymetry warning: Argo float grounded at bathymetry depth during sinking to max depth. Raising by 50m above bathymetry and continuing cycle." - ) - elif particle.depth + particle_ddepth <= particle.max_depth: - particle_ddepth = particle.max_depth - particle.depth - particle.cycle_phase = 3 - - elif particle.cycle_phase == 3: - # Phase 3: Rising with vertical_speed until at surface - particle_ddepth -= particle.vertical_speed * particle.dt - particle.cycle_age += ( - particle.dt - ) # solve issue of not updating cycle_age during ascent - particle.grounded = 0 - if particle.depth + particle_ddepth >= particle.min_depth: - particle_ddepth = particle.min_depth - particle.depth - particle.cycle_phase = 4 - - elif particle.cycle_phase == 4: - # Phase 4: Transmitting at surface until cycletime is reached - if particle.cycle_age > particle.cycle_days * 86400: - particle.cycle_phase = 0 - particle.cycle_age = 0 - - if particle.state == StatusCode.Evaluate: - particle.cycle_age += particle.dt # update cycle_age - - -def _keep_at_surface(particle, fieldset, time): - # Prevent error when float reaches surface - if particle.state == StatusCode.ErrorThroughSurface: - particle.depth = particle.min_depth - particle.state = StatusCode.Success - - -def _check_error(particle, fieldset, time): - if particle.state >= 50: # This captures all Errors - if particle.state == 50: - print("WARNING: Error during Argo Float simulation...") - elif particle.state == 51: - print("WARNING: ErrorInterpolation during Argo Float simulation...") - elif particle.state == 60: - print("WARNING: ErrorOutOfBounds during Argo Float simulation...") - elif particle.state == 61: - print("WARNING: ErrorThroughSurface during Argo Float simulation...") - elif particle.state == 70: - print("WARNING: ErrorTimeExtrapolation during Argo Float simulation...") - else: - print("Unknown error during Argo Float simulation...") + ptcls0.dz[~bathy_mask] = ( + loc_bathy[~bathy_mask] - ptcls0.z[~bathy_mask] + 50.0 + ) # raise to 50m above bathymetry + ptcls0.cycle_phase[~bathy_mask] = 1 + + # Phase 1: Drifting at depth for drifttime seconds + ptcls1.drift_age += ptcls1.dt + next_phase = ptcls1.drift_age >= particles.drift_days * 86400 # [seconds] + ptcls1.cycle_phase[next_phase] = 2 + ptcls1.drift_age[next_phase] = 0 # reset drift_age for next cycle + + # Phase 2: Sinking further to maxdepth + ptcls2.dz += particles.vertical_speed * ptcls2.dt + loc_bathy = fieldset.bathymetry.eval(ptcls2.time, ptcls2.z, ptcls2.lat, ptcls2.lon) + maxdepth_mask = ptcls2.z + ptcls2.dz >= particles.max_depth + bathy_mask = ptcls2.z + ptcls2.dz >= loc_bathy + next_phase = np.logical_and( + maxdepth_mask, bathy_mask + ) # combined mask; not at max depth yet and not hitting bathymetry + ptcls2.cycle_phase[next_phase] = 3 + ptcls2.dz[next_phase] = ( + particles.max_depth - ptcls2.z[next_phase] + ) # avoid overshoot + + # Phase 2.5: Check for grounding at bathymetry and raise if necessary + ptcls2.grounded[~bathy_mask] = 1 + if np.any(~bathy_mask): print( - "WARNING: An error occured during simulation but the expedition will continue. If ErrorOutOfBounds, consider reducing the lifetime in Argo Float config (the fieldset spatial bounds are constrained under-the-hood). For further advice please contact the VirtualShip team via GitHub (https://github.com/Parcels-code/virtualship/issues) or email (virtualship@uu.nl). Carrying on with the expedition..." + "Shallow bathymetry warning: Argo float grounded at bathymetry depth during sinking to max depth. Raising by 50m above bathymetry and continuing cycle." ) - # TODO: warnings are a bit limited in Parcels v3, but v4 should allow more informative (+ not all these if statements) when e.g. f-strings are supported in kernels - - particle.delete() - - -def _argo_sample_temperature(particle, fieldset, time): + ptcls2.dz[~bathy_mask] = ( + loc_bathy[~bathy_mask] - ptcls2.z[~bathy_mask] + 50.0 + ) # raise to 50m above bathymetry + ptcls2.cycle_phase[~bathy_mask] = 3 + + # Phase 3: Rising with vertical_speed until at surface + ptcls3.dz -= particles.vertical_speed * ptcls3.dt + next_phase = ptcls3.z + ptcls3.dz <= particles.min_depth + ptcls3.cycle_phase[next_phase] = 4 + ptcls3.dz[next_phase] = ( + particles.min_depth - ptcls3.z[next_phase] + ) # avoid overshoot + + # Phase 4: Transmitting at surface until cycletime is reached + next_phase = ptcls4.cycle_age >= particles.cycle_days * 86400 + ptcls4.cycle_phase[next_phase] = 0 + ptcls4.cycle_age[next_phase] = 0 # reset cycle_age for next cycle + ptcls4.temperature = np.nan # no temperature measurement when at surface + + particles.cycle_age += particles.dt # update cycle_age + + +def _keep_at_surface(particles, fieldset): + through_surface = particles.state == StatusCode.ErrorThroughSurface + particles.z[through_surface] = particles.min_depth[through_surface] + particles.state[through_surface] = StatusCode.Success + + +def _check_error(particles, fieldset): + errors = particles.state >= 50 # captures all Errors + # TODO: check print statements are as expected + print( + "WARNING: Error(s) found during Argo Float simulation but the expedition will continue..." + f"\n\nError code(s): {', '.join(_STATUS_CODE_NAMES.get(error, str(error)) + 'at time: ' + str(particles.time[errors][i]) + ', lat: ' + str(particles.lat[errors][i]) + ', lon: ' + str(particles.lon[errors][i]) for i, error in enumerate(particles.state[errors]))}" + "\n\nIf ErrorOutOfBounds, consider reducing the lifetime in Argo Float config (the fieldset spatial bounds are constrained under-the-hood). For further advice please contact the VirtualShip team via GitHub (https://github.com/Parcels-code/virtualship/issues) or email (virtualship@uu.nl)." + "\nCarrying on with the expedition..." + ) + particles.state[errors] = StatusCode.Delete + + +def _argo_sample_temperature(particles, fieldset): # Phase 3: ascending — sample temperature; NaN otherwise - if particle.cycle_phase == 3 and particle.depth < particle.min_depth: - particle.temperature = fieldset.T[ - time, particle.depth, particle.lat, particle.lon - ] - else: - particle.temperature = math.nan - - -def _argo_sample_salinity(particle, fieldset, time): + phase_mask = particles.cycle_phase == 3 + depth_mask = particles.depth < particles.min_depth + sampling_particles = particles[np.logical_and(phase_mask, depth_mask)] + sampling_particles.temperature = fieldset.T[ + sampling_particles.time, + sampling_particles.depth, + sampling_particles.lat, + sampling_particles.lon, + ] + + +def _argo_sample_salinity(particles, fieldset): # Phase 3: ascending — sample salinity; NaN otherwise - if particle.cycle_phase == 3 and particle.depth < particle.min_depth: - particle.salinity = fieldset.S[time, particle.depth, particle.lat, particle.lon] - else: - particle.salinity = math.nan + phase_mask = particles.cycle_phase == 3 + depth_mask = particles.depth < particles.min_depth + sampling_particles = particles[np.logical_and(phase_mask, depth_mask)] + sampling_particles.salinity = fieldset.S[ + sampling_particles.time, + sampling_particles.depth, + sampling_particles.lat, + sampling_particles.lon, + ] # ===================================================== @@ -230,7 +246,7 @@ def simulate(self, measurements, out_path) -> None: shallow_waypoints = {} for i, m in enumerate(measurements): loc_bathy = fieldset.bathymetry.eval( - time=0, + time=np.float64(0), z=0, y=m.spacetime.location.lat, x=m.spacetime.location.lon, @@ -246,9 +262,7 @@ def simulate(self, measurements, out_path) -> None: # build dynamic particle class from the active sensors argo_float_config = self.expedition.instruments_config.argo_float_config _ArgoParticle = build_particle_class_from_sensors( - argo_float_config.sensors, - _ARGO_NONSENSOR_VARIABLES, - JITParticle, + argo_float_config.sensors, _ARGO_NONSENSOR_VARIABLES ) # define parcel particles @@ -257,8 +271,8 @@ def simulate(self, measurements, out_path) -> None: pclass=_ArgoParticle, lat=[argo.spacetime.location.lat for argo in measurements], lon=[argo.spacetime.location.lon for argo in measurements], - depth=[argo.min_depth for argo in measurements], - time=[argo.spacetime.time for argo in measurements], + z=[argo.min_depth for argo in measurements], + time=[np.datetime64(argo.spacetime.time) for argo in measurements], min_depth=[argo.min_depth for argo in measurements], max_depth=[argo.max_depth for argo in measurements], drift_depth=[argo.drift_depth for argo in measurements], @@ -268,14 +282,13 @@ def simulate(self, measurements, out_path) -> None: ) # define output file for the simulation - out_file = argo_float_particleset.ParticleFile( - name=out_path, + out_file = ParticleFile( + path=out_path, outputdt=OUTPUT_DT, - chunks=[len(argo_float_particleset), 100], ) # endtime - endtime = fieldset.time_origin.fulltime(fieldset.U.grid.time_full[-1]) + endtime = fieldset.U.data.time.isel(time=-1).values # build kernel list from active sensors only sampling_kernels = [ @@ -289,7 +302,7 @@ def simulate(self, measurements, out_path) -> None: [ _argo_float_vertical_movement, *sampling_kernels, - AdvectionRK4, + AdvectionRK2, _keep_at_surface, _check_error, ], diff --git a/src/virtualship/instruments/base.py b/src/virtualship/instruments/base.py index d4e078e6..39dd7419 100644 --- a/src/virtualship/instruments/base.py +++ b/src/virtualship/instruments/base.py @@ -8,8 +8,8 @@ from typing import TYPE_CHECKING, ClassVar import copernicusmarine +import parcels import xarray as xr -from parcels import FieldSet from yaspin import yaspin from virtualship.errors import CopernicusCatalogueError @@ -86,7 +86,7 @@ def __init__( self.min_lat, self.max_lat = min(wp_lats), max(wp_lats) self.min_lon, self.max_lon = min(wp_lons), max(wp_lons) - def load_input_data(self) -> FieldSet: + def load_input_data(self) -> parcels.FieldSet: """Load and return the input data as a FieldSet for the instrument.""" try: fieldset = self._generate_fieldset() @@ -97,21 +97,11 @@ def load_input_data(self) -> FieldSet: # interpolation methods for var in (v for v in self.variables if v not in ("U", "V")): - getattr(fieldset, var).interp_method = "linear_invdist_land_tracer" - - # depth negative - for g in fieldset.gridset.grids: - g.negate_depth() + getattr(fieldset, var).interp_method = parcels.interpolators.XLinear # bathymetry data if self.add_bathymetry: - bathymetry_field = _get_bathy_data( - self.min_lat, - self.max_lat, - self.min_lon, - self.max_lon, - from_data=self.from_data, - ).bathymetry + bathymetry_field = _get_bathy_data(from_data=self.from_data).bathymetry bathymetry_field.data = -bathymetry_field.data fieldset.add_field(bathymetry_field) @@ -128,18 +118,22 @@ def simulate( def execute(self, measurements: list, out_path: str | Path) -> None: """Run instrument simulation.""" + TMP = True # TODO: just for dev; remove before merging + instrument_name = self.__class__.__name__.split("Instrument")[0] + if not self.verbose_progress: - with yaspin( - text=f"Simulating {self.__class__.__name__.split('Instrument')[0]} measurements... ", - side="right", - spinner=ship_spinner, - ) as spinner: + if TMP: + with yaspin( + text=f"Simulating {instrument_name} measurements... ", + side="right", + spinner=ship_spinner, + ) as spinner: + self.simulate(measurements, out_path) + spinner.ok("✅\n") + else: self.simulate(measurements, out_path) - spinner.ok("✅\n") else: - print( - f"Simulating {self.__class__.__name__.split('Instrument')[0]} measurements... " - ) + print(f"Simulating {instrument_name} measurements... ") self.simulate(measurements, out_path) print("\n") @@ -183,11 +177,11 @@ def _get_copernicus_ds( coordinates_selection_method="outside", ) - def _generate_fieldset(self) -> FieldSet: + def _generate_fieldset(self) -> parcels.FieldSet: """ Create and combine FieldSets for each variable, supporting both local and Copernicus Marine data sources. - Per variable avoids issues when using copernicusmarine and creating directly one FieldSet of ds's sourced from different Copernicus Marine product IDs, which is often the case for BGC variables. + N.B. Per variable avoids issues when using copernicusmarine and creating directly one FieldSet of ds's sourced from different Copernicus Marine product IDs (which can also have different temporal resolutions), which is often the case for BGC variables. """ fieldsets_list = [] keys = list(self.variables.keys()) @@ -196,12 +190,11 @@ def _generate_fieldset(self) -> FieldSet: for key in keys: var = self.variables[key] + physical = var in COPERNICUSMARINE_PHYS_VARIABLES + + # TODO: do docs on pre-downloading data need to be updated for these changes? Anything about conventions etc.? if self.from_data is not None: # load from local data - physical = var in COPERNICUSMARINE_PHYS_VARIABLES - if physical: - data_dir = self.from_data.joinpath("phys") - else: - data_dir = self.from_data.joinpath("bgc") + data_dir = self.from_data.joinpath("phys" if physical else "bgc") files = _find_files_in_timerange( data_dir, @@ -209,36 +202,44 @@ def _generate_fieldset(self) -> FieldSet: self.max_time + timedelta(days=time_buffer), ) - _, full_var_name = _find_nc_file_with_variable( + _, field_var_name = _find_nc_file_with_variable( data_dir, var ) # get full variable name from one of the files; var may only appear as substring in variable name in file - ds = xr.open_mfdataset( - [data_dir.joinpath(f) for f in files] - ) # using: ds --> .from_xarray_dataset seems more robust than .from_netcdf for handling different temporal resolutions for different variables ... + ds = xr.open_mfdataset([data_dir.joinpath(f) for f in files]) - fs = FieldSet.from_xarray_dataset( - ds, - variables={key: full_var_name}, - dimensions=self.dimensions, - mesh="spherical", - ) else: # stream via Copernicus Marine Service - physical = var in COPERNICUSMARINE_PHYS_VARIABLES ds = self._get_copernicus_ds( time_buffer, physical=physical, var=var, ) - fs = FieldSet.from_xarray_dataset( - ds, {key: var}, self.dimensions, mesh="spherical" - ) + field_var_name = var + + # TODO: I think this is potentially slowing down simulations slightly... compared to v0.3 anyway for *drifters* + ds.load() # TODO: tmp step during v4 alpha stage... probably to be updated on the Parcels end + + fields = {key: ds[field_var_name]} + ds_fset = parcels.convert.copernicusmarine_to_sgrid(fields=fields) + fs = parcels.FieldSet.from_sgrid_conventions(ds_fset) + fieldsets_list.append(fs) base_fieldset = fieldsets_list[0] for fs, key in zip(fieldsets_list[1:], keys[1:], strict=False): base_fieldset.add_field(getattr(fs, key)) + # some instruments use AdvectionRKn kernels which require a combined UV vector field + # fieldsets are created per variable and thus are not seen by from_sgrid_conventions at the same time, therefore build combined VectorField here in FieldSet + if "U" in keys and "V" in keys: + uv = parcels.VectorField( + "UV", + base_fieldset.U, + base_fieldset.V, + vector_interp_method=parcels.interpolators.XLinear_Velocity, + ) + base_fieldset.add_field(uv) + return base_fieldset def _get_spec_value(self, spec_type: str, key: str, default=None): diff --git a/src/virtualship/instruments/ctd.py b/src/virtualship/instruments/ctd.py index 583a099c..9f93b6ac 100644 --- a/src/virtualship/instruments/ctd.py +++ b/src/virtualship/instruments/ctd.py @@ -4,13 +4,14 @@ from typing import TYPE_CHECKING, ClassVar import numpy as np -from parcels import JITParticle, ParticleSet, Variable +from parcels import ParticleFile, ParticleSet, Variable +from parcels._core.statuscodes import StatusCode from virtualship.instruments.base import Instrument from virtualship.instruments.sensors import SensorType from virtualship.instruments.types import InstrumentType from virtualship.utils import ( - add_dummy_UV, + _compute_max_depths, build_particle_class_from_sensors, register_instrument, ) @@ -52,60 +53,90 @@ class CTD: ## physical variables -def _sample_temperature(particle, fieldset, time): - particle.temperature = fieldset.T[time, particle.depth, particle.lat, particle.lon] +def _sample_temperature(particles, fieldset): + particles.temperature = fieldset.T[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _sample_salinity(particle, fieldset, time): - particle.salinity = fieldset.S[time, particle.depth, particle.lat, particle.lon] +def _sample_salinity(particles, fieldset): + particles.salinity = fieldset.S[ + particles.time, particles.z, particles.lat, particles.lon + ] ## bgc variables -def _sample_o2(particle, fieldset, time): - particle.o2 = fieldset.o2[time, particle.depth, particle.lat, particle.lon] +def _sample_o2(particles, fieldset): + particles.o2 = fieldset.o2[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _sample_chlorophyll(particle, fieldset, time): - particle.chl = fieldset.chl[time, particle.depth, particle.lat, particle.lon] +def _sample_chlorophyll(particles, fieldset): + particles.chl = fieldset.chl[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _sample_nitrate(particle, fieldset, time): - particle.no3 = fieldset.no3[time, particle.depth, particle.lat, particle.lon] +def _sample_nitrate(particles, fieldset): + particles.no3 = fieldset.no3[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _sample_phosphate(particle, fieldset, time): - particle.po4 = fieldset.po4[time, particle.depth, particle.lat, particle.lon] +def _sample_phosphate(particles, fieldset): + particles.po4 = fieldset.po4[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _sample_ph(particle, fieldset, time): - particle.ph = fieldset.ph[time, particle.depth, particle.lat, particle.lon] +def _sample_ph(particles, fieldset): + particles.ph = fieldset.ph[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _sample_phytoplankton(particle, fieldset, time): - particle.phyc = fieldset.phyc[time, particle.depth, particle.lat, particle.lon] +def _sample_phytoplankton(particles, fieldset): + particles.phyc = fieldset.phyc[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _sample_primary_production(particle, fieldset, time): - particle.nppv = fieldset.nppv[time, particle.depth, particle.lat, particle.lon] +def _sample_primary_production(particles, fieldset): + particles.nppv = fieldset.nppv[ + particles.time, particles.z, particles.lat, particles.lon + ] ## cast -def _ctd_cast(particle, fieldset, time): +def _ctd_cast(particles, fieldset): + particles_lowering = particles[particles.raising == 0] + particles_raising = particles[particles.raising == 1] + + # TODO: change to boolean masking, like with Argo Floats? + # TODO: different handling of positive down for z now?! Doing positive down now... think kernels need adjusting... + # TODO: need to check on all other instrument kernels as well... + # TODO: plus how the configs are inputted in e.g. expedition.yaml + # lowering - if particle.raising == 0: - particle_ddepth = -particle.winch_speed * particle.dt - if particle.depth + particle_ddepth < particle.max_depth: - particle.raising = 1 - particle_ddepth = -particle_ddepth + particles_lowering.dz = -particles_lowering.winch_speed * particles_lowering.dt + particles_lowering.raising = np.where( + particles_lowering.z + particles_lowering.dz < particles_lowering.max_depth, + 1, + particles_lowering.raising, + ) + # raising - else: - particle_ddepth = particle.winch_speed * particle.dt - if particle.depth + particle_ddepth > particle.min_depth: - particle.delete() + particles_raising.dz = particles_raising.winch_speed * particles_raising.dt + particles_raising.state = np.where( + particles_raising.z + particles_raising.dz > particles_raising.min_depth, + StatusCode.Delete, + particles_raising.state, + ) # ===================================================== @@ -162,18 +193,12 @@ def simulate(self, measurements, out_path) -> None: fieldset = self.load_input_data() - # add dummy U - add_dummy_UV(fieldset) # TODO: parcels v3 bodge; remove when parcels v4 is used - # use first active field for time reference _time_ref_key = next(iter(self.variables)) _time_ref_field = getattr(fieldset, _time_ref_key) - fieldset_starttime = _time_ref_field.grid.time_origin.fulltime( - _time_ref_field.grid.time_full[0] - ) - fieldset_endtime = _time_ref_field.grid.time_origin.fulltime( - _time_ref_field.grid.time_full[-1] - ) + + fieldset_starttime = _time_ref_field.data.time.isel(time=0).values + fieldset_endtime = _time_ref_field.data.time.isel(time=-1).values # deploy time for all ctds should be later than fieldset start time if not all( @@ -185,18 +210,7 @@ def simulate(self, measurements, out_path) -> None: raise ValueError("CTD deployed before fieldset starts.") # depth the ctd will go to. shallowest between ctd max depth and bathymetry. - max_depths = [ - max( - ctd.max_depth, - fieldset.bathymetry.eval( - z=0, - y=ctd.spacetime.location.lat, - x=ctd.spacetime.location.lon, - time=0, - ), - ) - for ctd in measurements - ] + max_depths = _compute_max_depths(measurements, fieldset) # CTD depth can not be too shallow, because kernel would break. # This shallow is not useful anyway, no need to support. @@ -208,7 +222,7 @@ def simulate(self, measurements, out_path) -> None: # build dynamic particle class from the active sensors ctd_config = self.expedition.instruments_config.ctd_config _CTDParticle = build_particle_class_from_sensors( - ctd_config.sensors, _CTD_NONSENSOR_VARIABLES, JITParticle + ctd_config.sensors, _CTD_NONSENSOR_VARIABLES ) # define parcel particles @@ -217,15 +231,15 @@ def simulate(self, measurements, out_path) -> None: pclass=_CTDParticle, lon=[ctd.spacetime.location.lon for ctd in measurements], lat=[ctd.spacetime.location.lat for ctd in measurements], - depth=[ctd.min_depth for ctd in measurements], - time=[ctd.spacetime.time for ctd in measurements], + z=[ctd.min_depth for ctd in measurements], + time=[np.datetime64(ctd.spacetime.time) for ctd in measurements], max_depth=max_depths, min_depth=[ctd.min_depth for ctd in measurements], winch_speed=[WINCH_SPEED for _ in measurements], ) # define output file for the simulation - out_file = ctd_particleset.ParticleFile(name=out_path, outputdt=OUTPUT_DT) + out_file = ParticleFile(path=out_path, outputdt=OUTPUT_DT) # build kernel list from active sensors only sampling_kernels = [ @@ -244,7 +258,7 @@ def simulate(self, measurements, out_path) -> None: ) # there should be no particles left, as they delete themselves when they resurface - if len(ctd_particleset.particledata) != 0: + if len(ctd_particleset.lon) != 0: raise ValueError( "Simulation ended before CTD resurfaced. This most likely means the field time dimension did not match the simulation time span." ) diff --git a/src/virtualship/instruments/drifter.py b/src/virtualship/instruments/drifter.py index 379334b3..c18088ee 100644 --- a/src/virtualship/instruments/drifter.py +++ b/src/virtualship/instruments/drifter.py @@ -4,7 +4,9 @@ from typing import ClassVar import numpy as np -from parcels import AdvectionRK4, JITParticle, ParticleSet, Variable +from parcels import ParticleFile, ParticleSet, Variable +from parcels._core.statuscodes import StatusCode +from parcels.kernels import AdvectionRK2 from virtualship.instruments.base import Instrument from virtualship.instruments.sensors import SensorType @@ -46,15 +48,21 @@ class Drifter: # ===================================================== -def _sample_temperature(particle, fieldset, time): - particle.temperature = fieldset.T[time, particle.depth, particle.lat, particle.lon] +def _sample_temperature(particles, fieldset): + particles.temperature = fieldset.T[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _check_lifetime(particle, fieldset, time): - if particle.has_lifetime == 1: - particle.age += particle.dt - if particle.age >= particle.lifetime: - particle.delete() +def _check_lifetime(particles, fieldset): + particles_wlifetime = particles[particles.has_lifetime == 1] + + particles_wlifetime.age += particles_wlifetime.dt + particles_wlifetime.state = np.where( + particles_wlifetime.age >= particles_wlifetime.lifetime, + StatusCode.Delete, + particles_wlifetime.state, + ) # ===================================================== @@ -123,7 +131,7 @@ def simulate(self, measurements, out_path) -> None: # build dynamic particle class from the active sensors drifter_config = self.expedition.instruments_config.drifter_config _DrifterParticle = build_particle_class_from_sensors( - drifter_config.sensors, _DRIFTER_NONSENSOR_VARIABLES, JITParticle + drifter_config.sensors, _DRIFTER_NONSENSOR_VARIABLES ) # define parcel particles @@ -139,8 +147,8 @@ def simulate(self, measurements, out_path) -> None: pclass=_DrifterParticle, lat=lat_release, lon=lon_release, - depth=[drifter.depth for drifter in measurements], - time=[drifter.spacetime.time for drifter in measurements], + z=[drifter.depth for drifter in measurements], + time=[np.datetime64(drifter.spacetime.time) for drifter in measurements], has_lifetime=[ 1 if drifter.lifetime is not None else 0 for drifter in measurements ], @@ -151,14 +159,13 @@ def simulate(self, measurements, out_path) -> None: ) # define output file for the simulation - out_file = drifter_particleset.ParticleFile( - name=out_path, + out_file = ParticleFile( + path=out_path, outputdt=OUTPUT_DT, - chunks=[len(drifter_particleset), 100], ) # determine end time for simulation, from fieldset (which itself is controlled by drifter lifetimes) - endtime = fieldset.time_origin.fulltime(fieldset.U.grid.time_full[-1]) + endtime = fieldset.U.data.time.isel(time=-1).values # build kernel list from active sensors only sampling_kernels = [ @@ -169,7 +176,7 @@ def simulate(self, measurements, out_path) -> None: # execute simulation drifter_particleset.execute( - [AdvectionRK4, *sampling_kernels, _check_lifetime], + [AdvectionRK2, *sampling_kernels, _check_lifetime], endtime=endtime, dt=DT, output_file=out_file, diff --git a/src/virtualship/instruments/ship_underwater_st.py b/src/virtualship/instruments/ship_underwater_st.py index 6a564cc0..c5149a7a 100644 --- a/src/virtualship/instruments/ship_underwater_st.py +++ b/src/virtualship/instruments/ship_underwater_st.py @@ -3,13 +3,12 @@ from typing import ClassVar import numpy as np -from parcels import ParticleSet, ScipyParticle +from parcels import ParticleFile, ParticleSet from virtualship.instruments.base import Instrument from virtualship.instruments.sensors import SensorType from virtualship.instruments.types import InstrumentType from virtualship.utils import ( - add_dummy_UV, build_particle_class_from_sensors, register_instrument, ) @@ -40,13 +39,13 @@ class Underwater_ST: # define function sampling Salinity -def _sample_salinity(particle, fieldset, time): - particle.salinity = fieldset.S[time, particle.depth, particle.lat, particle.lon] +def _sample_salinity(particles, fieldset): + particles.S = fieldset.S[particles.time, particles.z, particles.lat, particles.lon] # define function sampling Temperature -def _sample_temperature(particle, fieldset, time): - particle.temperature = fieldset.T[time, particle.depth, particle.lat, particle.lon] +def _sample_temperature(particles, fieldset): + particles.T = fieldset.T[particles.time, particles.z, particles.lat, particles.lon] # ===================================================== @@ -95,25 +94,21 @@ def simulate(self, measurements, out_path) -> None: fieldset = self.load_input_data() - # add dummy U - add_dummy_UV(fieldset) # TODO: parcels v3 bodge; remove when parcels v4 is used - # build dynamic particle class from the active sensors st_config = self.expedition.instruments_config.ship_underwater_st_config _ShipSTParticle = build_particle_class_from_sensors( - st_config.sensors, _ST_NONSENSOR_VARIABLES, ScipyParticle + st_config.sensors, _ST_NONSENSOR_VARIABLES ) - particleset = ParticleSet.from_list( + particleset = ParticleSet( fieldset=fieldset, pclass=_ShipSTParticle, lon=0.0, lat=0.0, depth=DEPTH, - time=0, ) - out_file = particleset.ParticleFile(name=out_path, outputdt=np.inf) + out_file = ParticleFile(path=out_path, outputdt=np.inf) # build kernel list from active sensors only sampling_kernels = [ @@ -122,6 +117,9 @@ def simulate(self, measurements, out_path) -> None: if sc.enabled and sc.sensor_type in self.sensor_kernels ] + # TODO: need to overhaul UNDERWATER_ST/underway instruments generally... don't think this Parcels API works anymore + # TODO: a good time to implement https://github.com/Parcels-code/virtualship/issues/231 + for point in measurements: particleset.lon_nextloop[:] = point.location.lon particleset.lat_nextloop[:] = point.location.lat diff --git a/src/virtualship/instruments/xbt.py b/src/virtualship/instruments/xbt.py index 051bf1fa..bc5f5ecf 100644 --- a/src/virtualship/instruments/xbt.py +++ b/src/virtualship/instruments/xbt.py @@ -4,14 +4,15 @@ from typing import ClassVar import numpy as np -from parcels import JITParticle, ParticleSet, Variable +from parcels import ParticleFile, ParticleSet, Variable +from parcels._core.statuscodes import StatusCode from virtualship.instruments.base import Instrument from virtualship.instruments.sensors import SensorType from virtualship.instruments.types import InstrumentType from virtualship.models.spacetime import Spacetime from virtualship.utils import ( - add_dummy_UV, + _compute_max_depths, build_particle_class_from_sensors, register_instrument, ) @@ -50,26 +51,32 @@ class XBT: # ===================================================== -def _sample_temperature(particle, fieldset, time): - particle.temperature = fieldset.T[time, particle.depth, particle.lat, particle.lon] +def _sample_temperature(particles, fieldset): + particles.temperature = fieldset.T[ + particles.time, particles.z, particles.lat, particles.lon + ] -def _xbt_cast(particle, fieldset, time): - particle_ddepth = -particle.fall_speed * particle.dt +def _xbt_cast(particles, fieldset): + particles.dz = -particles.fall_speed * particles.dt # update the fall speed from the quadractic fall-rate equation # check https://doi.org/10.5194/os-7-231-2011 - particle.fall_speed = ( - particle.fall_speed - 2 * particle.deceleration_coefficient * particle.dt + particles.fall_speed = ( + particles.fall_speed - 2 * particles.deceleration_coefficient * particles.dt ) # delete particle if depth is exactly max_depth - if particle.depth == particle.max_depth: - particle.delete() + particles.state = np.where( + particles.z == particles.max_depth, StatusCode.Delete, particles.state + ) # set particle depth to max depth if it's too deep - if particle.depth + particle_ddepth < particle.max_depth: - particle_ddepth = particle.max_depth - particle.depth + particles.dz = np.where( + particles.z + particles.dz < particles.max_depth, + particles.max_depth - particles.z, + particles.z, + ) # ===================================================== @@ -117,18 +124,12 @@ def simulate(self, measurements, out_path) -> None: fieldset = self.load_input_data() - # add dummy U - add_dummy_UV(fieldset) # TODO: parcels v3 bodge; remove when parcels v4 is used - # use first active field for time reference _time_ref_key = next(iter(self.variables)) _time_ref_field = getattr(fieldset, _time_ref_key) - fieldset_starttime = _time_ref_field.grid.time_origin.fulltime( - _time_ref_field.grid.time_full[0] - ) - fieldset_endtime = _time_ref_field.grid.time_origin.fulltime( - _time_ref_field.grid.time_full[-1] - ) + + fieldset_starttime = _time_ref_field.data.time.isel(time=0).values + fieldset_endtime = _time_ref_field.data.time.isel(time=-1).values # deploy time for all xbts should be later than fieldset start time if not all( @@ -140,18 +141,7 @@ def simulate(self, measurements, out_path) -> None: raise ValueError("XBT deployed before fieldset starts.") # depth the xbt will go to. shallowest between xbt max depth and bathymetry. - max_depths = [ - max( - xbt.max_depth, - fieldset.bathymetry.eval( - z=0, - y=xbt.spacetime.location.lat, - x=xbt.spacetime.location.lon, - time=0, - ), - ) - for xbt in measurements - ] + max_depths = _compute_max_depths(measurements, fieldset) # initial fall speeds initial_fall_speeds = [xbt.fall_speed for xbt in measurements] @@ -166,7 +156,7 @@ def simulate(self, measurements, out_path) -> None: # build dynamic particle class from the active sensors xbt_config = self.expedition.instruments_config.xbt_config _XBTParticle = build_particle_class_from_sensors( - xbt_config.sensors, _XBT_NONSENSOR_VARIABLES, JITParticle + xbt_config.sensors, _XBT_NONSENSOR_VARIABLES ) # define xbt particles @@ -175,14 +165,14 @@ def simulate(self, measurements, out_path) -> None: pclass=_XBTParticle, lon=[xbt.spacetime.location.lon for xbt in measurements], lat=[xbt.spacetime.location.lat for xbt in measurements], - depth=[xbt.min_depth for xbt in measurements], - time=[xbt.spacetime.time for xbt in measurements], + z=[xbt.min_depth for xbt in measurements], + time=[np.datetime64(xbt.spacetime.time) for xbt in measurements], max_depth=max_depths, min_depth=[xbt.min_depth for xbt in measurements], fall_speed=[xbt.fall_speed for xbt in measurements], ) - out_file = xbt_particleset.ParticleFile(name=out_path, outputdt=OUTPUT_DT) + out_file = ParticleFile(path=out_path, outputdt=OUTPUT_DT) # build kernel list from active sensors only sampling_kernels = [ diff --git a/src/virtualship/models/expedition.py b/src/virtualship/models/expedition.py index d7badd53..b7269373 100644 --- a/src/virtualship/models/expedition.py +++ b/src/virtualship/models/expedition.py @@ -17,7 +17,6 @@ _calc_sail_time, _calc_wp_stationkeeping_time, _get_bathy_data, - _get_waypoint_latlons, _validate_numeric_to_timedelta, get_supported_sensors, register_instrument_config, @@ -131,14 +130,7 @@ def verify( land_waypoints = [] if not ignore_land_test: try: - wp_lats, wp_lons = _get_waypoint_latlons(self.waypoints) - bathymetry_field = _get_bathy_data( - min(wp_lats), - max(wp_lats), - min(wp_lons), - max(wp_lons), - from_data=from_data, - ).bathymetry + bathymetry_field = _get_bathy_data(from_data=from_data).bathymetry except Exception as e: raise ScheduleError( f"Problem loading bathymetry data (used to verify waypoints are in water) directly via copernicusmarine. \n\n original message: {e}" @@ -147,7 +139,7 @@ def verify( for wp_i, wp in enumerate(self.waypoints): try: value = bathymetry_field.eval( - 0, # time + np.float64(0.0), # time 0, # depth (surface) wp.location.lat, wp.location.lon, diff --git a/src/virtualship/utils.py b/src/virtualship/utils.py index 7ad275cd..df5d153c 100644 --- a/src/virtualship/utils.py +++ b/src/virtualship/utils.py @@ -13,9 +13,10 @@ import copernicusmarine import numpy as np +import parcels import pyproj import xarray as xr -from parcels import FieldSet, Variable +from parcels import FieldSet, Particle, Variable from virtualship.errors import CopernicusCatalogueError @@ -340,29 +341,6 @@ def _get_expedition(expedition_dir: Path) -> Expedition: ) from e -def add_dummy_UV(fieldset: FieldSet): - """Add a dummy U and V field to a FieldSet to satisfy parcels FieldSet completeness checks.""" - if "U" not in fieldset.__dict__.keys(): - for uv_var in ["U", "V"]: - dummy_field = getattr( - FieldSet.from_data( - {"U": 0, "V": 0}, {"lon": 0, "lat": 0}, mesh="spherical" - ), - uv_var, - ) - fieldset.add_field(dummy_field) - try: - fieldset.time_origin = ( - fieldset.T.grid.time_origin - if "T" in fieldset.__dict__.keys() - else fieldset.o2.grid.time_origin - ) - except Exception: - raise ValueError( - "Cannot determine time_origin for dummy UV fields. Assert T or o2 exists in fieldset." - ) from None - - def _select_product_id( physical: bool, schedule_start, @@ -448,43 +426,36 @@ def _start_end_in_product_timerange( ) -def _get_bathy_data( - min_lat: float, - max_lat: float, - min_lon: float, - max_lon: float, - from_data: Path | None = None, -) -> FieldSet: +def _get_bathy_data(from_data: Path | None = None) -> FieldSet: """Bathymetry data from local or 'streamed' directly from Copernicus Marine.""" + VAR = "deptho" if from_data is not None: # load from local data - var = "deptho" bathy_dir = from_data.joinpath("bathymetry") try: - filename, _ = _find_nc_file_with_variable(bathy_dir, var) + filename, _ = _find_nc_file_with_variable(bathy_dir, VAR) except Exception as e: raise RuntimeError( - f"\n\n❗️ Could not find bathymetry variable '{var}' in data directory '{from_data}/bathymetry/'.\n\n❗️ Is the pre-downloaded data directory structure compliant with VirtualShip expectations?\n\n❗️ See the docs for more information on expectations: https://virtualship.readthedocs.io/en/latest/user-guide/index.html#documentation\n" + f"\n\n❗️ Could not find bathymetry variable '{VAR}' in data directory '{from_data}/bathymetry/'.\n\n❗️ Is the pre-downloaded data directory structure compliant with VirtualShip expectations?\n\n❗️ See the docs for more information on expectations: https://virtualship.readthedocs.io/en/latest/user-guide/index.html#documentation\n" ) from e ds_bathymetry = xr.open_dataset(bathy_dir.joinpath(filename)) - bathymetry_variables = {"bathymetry": "deptho"} - bathymetry_dimensions = {"lon": "longitude", "lat": "latitude"} - return FieldSet.from_xarray_dataset( - ds_bathymetry, bathymetry_variables, bathymetry_dimensions - ) else: # stream via Copernicus Marine Service ds_bathymetry = copernicusmarine.open_dataset( dataset_id=BATHYMETRY_ID, - variables=["deptho"], + variables=[VAR], coordinates_selection_method="outside", ) - bathymetry_variables = {"bathymetry": "deptho"} - bathymetry_dimensions = {"lon": "longitude", "lat": "latitude"} - return FieldSet.from_xarray_dataset( - ds_bathymetry, bathymetry_variables, bathymetry_dimensions + ds_bathymetry = ds_bathymetry.expand_dims( + {"depth": 1} + ) # TODO: bodge whilst parcels v4 does not support 2D fields and seeks depth dim; change when parcels v4 released + + ds_fset = parcels.convert.copernicusmarine_to_sgrid( + fields={"bathymetry": ds_bathymetry[VAR]} ) + return FieldSet.from_sgrid_conventions(ds_fset) + def expedition_cost(schedule_results: ScheduleOk, time_past: timedelta) -> float: """ @@ -589,6 +560,22 @@ def _find_files_in_timerange( return [fname for _, fname in files_with_dates] +def _compute_max_depths(measurements, fieldset) -> list[float]: + """Compute the effective max depth for each measurement, capped by bathymetry.""" + return [ + max( + m.max_depth, + fieldset.bathymetry.eval( + z=0, + y=m.spacetime.location.lat, + x=m.spacetime.location.lon, + time=np.float64(0), + )[0], + ) + for m in measurements + ] + + def _random_noise(scale: float = 0.05, limit: float = 0.1) -> float: """Generate a small random noise value for drifter seeding locations.""" value = np.random.normal(loc=0.0, scale=scale) @@ -677,14 +664,13 @@ def _make_hash(s: str, length: int) -> str: def build_particle_class_from_sensors( sensors: list[SensorConfig], nonsensor_variables: list[Variable], - particle_class: type, # generic type annotation needed for v3 particle class behaviour # TODO: Update with Parcels v4 ) -> type: - """Build a Particle class (JITParticle or ScipyParticle) from nonsensor variables and active sensors.""" + """Build a Particle class from nonsensor variables and active sensors.""" sensor_variables = [ variable for sc in sensors if sc.enabled for variable in sc.meta.particle_vars ] - return particle_class.add_variables(nonsensor_variables + sensor_variables) + return Particle.add_variable(nonsensor_variables + sensor_variables) # ===================================================== diff --git a/tests/instruments/test_base.py b/tests/instruments/test_base.py index bbcfea44..a17f95bf 100644 --- a/tests/instruments/test_base.py +++ b/tests/instruments/test_base.py @@ -33,7 +33,6 @@ def test_load_input_data(mock_copernicusmarine, mock_select_product_id, mock_Fie mock_fieldset = MagicMock() mock_FieldSet.from_netcdf.return_value = mock_fieldset mock_FieldSet.from_xarray_dataset.return_value = mock_fieldset - mock_fieldset.gridset.grids = [MagicMock(negate_depth=MagicMock())] mock_fieldset.__getitem__.side_effect = lambda k: MagicMock() mock_copernicusmarine.open_dataset.return_value = MagicMock() # Create a mock waypoint with latitude and longitude