Source code for py_eddy_tracker.tracking

# -*- coding: utf-8 -*-
"""
Class to store link between observations
"""

from datetime import datetime, timedelta
import json
import logging
import platform

from netCDF4 import Dataset, default_fillvals
from numba import njit, types as numba_types
from numpy import (
    arange,
    array,
    bool_,
    concatenate,
    empty,
    isin,
    ma,
    ones,
    setdiff1d,
    uint16,
    unique,
    where,
    zeros,
)

from py_eddy_tracker.observations.observation import (
    EddiesObservations,
    VirtualEddiesObservations,
)
from py_eddy_tracker.observations.tracking import TrackEddiesObservations

logger = logging.getLogger("pet")


@njit(cache=True)
def index(ar, items):
    """Return, for each element of ``items``, the position of its first match in ``ar``.

    Both arguments are scanned as one-dimensional sequences. When an item has
    no match in ``ar`` the corresponding output slot is left as allocated by
    ``empty`` (i.e. not initialised), so callers must only request values that
    are present in ``ar``.

    :param ar: array to search in
    :param items: values to locate in ``ar``
    :return: integer array with one position per element of ``items``
    """
    nb_items = items.shape[0]
    positions = empty(nb_items, dtype=numba_types.int_)
    for i_item in range(nb_items):
        wanted = items[i_item]
        for i_ar in range(ar.shape[0]):
            if ar[i_ar] == wanted:
                # Keep only the first match
                positions[i_item] = i_ar
                break
    return positions
class Correspondances(list):
    """Object to store correspondances and run tracking.

    The object is a list: each item is a structured array (dtype
    ``correspondance_dtype``) describing the links found between two
    consecutive datasets ("in" index in the previous dataset, "out" index in
    the current one, track "id" and association "cost_value", plus "virtual"
    and "virtual_length" when virtual observations are enabled).
    """

    # Sentinel used in id arrays to flag "no track id assigned yet"
    UINT32_MAX = 4294967295
    # Prolongation limit to 255
    VIRTUAL_DTYPE = "u1"
    # ID limit to 4294967295
    ID_DTYPE = "u4"
    # Track limit to 65535
    N_DTYPE = "u2"

    def __init__(
        self,
        datasets,
        virtual=0,
        class_method=None,
        class_kw=None,
        previous_correspondance=None,
        memory=False,
    ):
        """Initiate tracking

        :param list(str) datasets: A sorted list of filename which contains eddy observations to track
        :param int virtual: Maximal number of consecutive virtual observations
            used to prolongate a track (0 disables virtual observations)
        :param class class_method: A class which tell how to track
            (``EddiesObservations`` when None)
        :param dict class_kw: keyword argument to setup class
        :param str previous_correspondance: Filename of a previous correspondance
            file if you want continue tracking (loaded with :meth:`load_compatible`)
        :param bool memory: identification file are load in memory before to be open with netcdf
        """
        super().__init__()
        # Correspondance dtype
        self.correspondance_dtype = [
            ("in", "u2"),
            ("out", "u2"),
            ("id", self.ID_DTYPE),
            ("cost_value", "f4"),
        ]
        if class_method is None:
            self.class_method = EddiesObservations
        else:
            self.class_method = class_method
        self.class_kw = dict() if class_kw is None else class_kw
        self.memory = memory
        # To count ID
        self.current_id = 0
        # To know the number maximal of link between two state
        self.nb_link_max = 0
        # Dataset to iterate
        self.datasets = datasets
        self.previous2_obs = None
        self.previous_obs = None
        self.current_obs = None

        # To use virtual obs
        # Number of obs which can prolongate real observations
        self.nb_virtual = virtual
        # Activation or not
        self.virtual = virtual > 0
        self.virtual_obs = None
        self.previous_virtual_obs = None

        # Correspondance to prolongate
        self.filename_previous_correspondance = previous_correspondance
        self.previous_correspondance = self.load_compatible(
            self.filename_previous_correspondance
        )

        if self.virtual:
            # Add field to dtype to follow virtual observations
            self.correspondance_dtype += [
                # True if it isn't a real obs
                ("virtual", bool_),
                # Length of virtual segment
                ("virtual_length", self.VIRTUAL_DTYPE),
            ]

        # Array to simply merged
        self.nb_obs_by_tracks = None
        self.i_current_by_tracks = None
        self.nb_obs = 0
        self.eddies = None

    def _copy(self):
        """Build a new object with the same setup and the same correspondances.

        The structured arrays are shared with the original (not duplicated);
        :meth:`prepare_merging` is run on the copy before it is returned.

        :return: the new object
        :rtype: Correspondances
        """
        new = self.__class__(
            datasets=self.datasets,
            virtual=self.nb_virtual,
            class_method=self.class_method,
            class_kw=self.class_kw,
            previous_correspondance=self.filename_previous_correspondance,
            # Keep the same file loading mode as the original object
            memory=self.memory,
        )
        for i in self:
            new.append(i)
        new.current_id = self.current_id
        new.nb_link_max = self.nb_link_max
        new.nb_obs = self.nb_obs
        new.prepare_merging()
        logger.debug("Copy done")
        return new

    def reset_dataset_cache(self):
        """Forget the three observation sets kept in memory (N-2, N-1 and N)."""
        self.previous2_obs = None
        self.previous_obs = None
        self.current_obs = None

    @property
    def period(self):
        """Period covered by the observations (to rethink).

        First and last datasets are loaded with ``class_method.load_file`` and
        their first ``time`` value is added, as a number of days, to 1950-01-01.

        :return: (date_start, date_stop)
        :rtype: (datetime, datetime)
        """
        date_start = datetime(1950, 1, 1) + timedelta(
            self.class_method.load_file(self.datasets[0]).time[0]
        )
        date_stop = datetime(1950, 1, 1) + timedelta(
            self.class_method.load_file(self.datasets[-1]).time[0]
        )
        return date_start, date_stop

    def swap_dataset(self, dataset, *args, **kwargs):
        """Swap to next dataset

        Observations are shifted (current -> previous -> previous2) and
        ``dataset`` is loaded as the new current observations. Entries of
        ``class_kw`` override the keyword arguments given here.

        :param dataset: dataset to load; opened in binary mode first when
            ``memory`` is set
        """
        self.previous2_obs = self.previous_obs
        self.previous_obs = self.current_obs
        kwargs = kwargs.copy()
        kwargs.update(self.class_kw)
        if self.memory:
            with open(dataset, "rb") as h:
                self.current_obs = self.class_method.load_file(h, *args, **kwargs)
        else:
            self.current_obs = self.class_method.load_file(dataset, *args, **kwargs)

    def merge_correspondance(self, other):
        """Append the correspondances of ``other`` which follow the last dataset of this object.

        Track ids of ``other`` are translated (in place) to be consistent
        with the ids of this object.

        :param Correspondances other: object to merge, it must contain the
            last dataset of this object exactly once
        :raises Exception: if virtual settings differ or if the junction
            between both dataset lists is not unique
        """
        # Verify compliance of file
        if self.nb_virtual != other.nb_virtual:
            raise Exception("Different method of tracking")
        # Determine junction
        i = where(other.datasets == array(self.datasets[-1]))[0]
        # NOTE: this also fires when no junction is found at all
        if len(i) != 1:
            raise Exception("More than one intersection")

        # Merge
        # Create a hash table
        translate = empty(other.current_id, dtype="u4")
        translate[:] = self.UINT32_MAX

        translate[other[i[0] - 1]["id"]] = self[-1]["id"]

        nb_max = other[i[0] - 1]["id"].max()
        mask = translate == self.UINT32_MAX
        # We won't translate previous id
        mask[:nb_max] = False
        # Next id will be shifted
        translate[mask] = arange(mask.sum()) + self.current_id

        # Translate
        for items in other[i[0] :]:
            items["id"] = translate[items["id"]]
        # Extend with other obs
        self.extend(other[i[0] :])
        # Extend datasets list, which are bounds so we add one
        self.datasets.extend(other.datasets[i[0] + 1 :])
        # We set new id available
        self.current_id = translate[-1] + 1

    def store_correspondance(
        self, i_previous, i_current, nb_real_obs, association_cost
    ):
        """Storing correspondance in an array

        :param array i_previous: index of matched observations in previous
            dataset (an index >= ``nb_real_obs`` is a virtual observation)
        :param array i_current: index of matched observations in current dataset
        :param int nb_real_obs: number of real observations in previous dataset
        :param array association_cost: cost of each association
        :return: True for the very first step (``previous2_obs`` is None),
            False otherwise
        :rtype: bool
        """
        # Create array to store correspondance data
        correspondance = array(i_previous, dtype=self.correspondance_dtype)
        if self.virtual:
            correspondance["virtual_length"][:] = 255
        # index from current_obs
        correspondance["out"] = i_current
        correspondance["cost_value"] = association_cost

        if self.virtual:
            # if index in previous dataset is bigger than real obs number
            # it's a virtual data
            correspondance["virtual"] = i_previous >= nb_real_obs

        if self.previous2_obs is None:
            # First time we set ID (Program starting)
            nb_match = i_previous.shape[0]
            # Set an id for each match
            correspondance["id"] = self.id_generator(nb_match)
            self.append(correspondance)
            return True

        # We set all id to UINT32_MAX
        id_previous = (
            ones(len(self.previous_obs), dtype=self.ID_DTYPE) * self.UINT32_MAX
        )
        # We get old id for previously eddies tracked
        id_previous[self[-1]["out"]] = self[-1]["id"]
        # We store ID in correspondance if the ID is UINT32_MAX, we never
        # track it before
        correspondance["id"] = id_previous[correspondance["in"]]

        # We set correspondance data for virtual obs : ID/LENGTH
        if self.previous2_obs is not None and self.virtual:
            nb_rebirth = correspondance["virtual"].sum()
            if nb_rebirth != 0:
                logger.debug(
                    "%d re-birth due to prolongation with" " virtual observations",
                    nb_rebirth,
                )
                # Set id for virtual
                # get correspondance mask using virtual obs
                m_virtual = correspondance["virtual"]
                # index of virtual in virtual obs
                i_virtual = correspondance["in"][m_virtual] - nb_real_obs
                correspondance["id"][m_virtual] = self.virtual_obs["track"][i_virtual]
                correspondance["virtual_length"][m_virtual] = self.virtual_obs[
                    "segment_size"
                ][i_virtual]

        # new_id is equal to UINT32_MAX we must add a new ones
        # we count the number of new
        mask_new_id = correspondance["id"] == self.UINT32_MAX
        nb_new_tracks = mask_new_id.sum()
        logger.debug("%d birth in this step", nb_new_tracks)
        # Set new id
        correspondance["id"][mask_new_id] = self.id_generator(nb_new_tracks)
        self.append(correspondance)

        return False

    def append(self, *args, **kwargs):
        """Append a correspondance array and keep track of the biggest number of links."""
        self.nb_link_max = max(self.nb_link_max, len(args[0]))
        super().append(*args, **kwargs)

    def id_generator(self, nb_id):
        """Generation id and incrementation

        :param int nb_id: number of new ids needed
        :return: ``nb_id`` consecutive ids starting at ``current_id``
        :rtype: array
        """
        values = arange(self.current_id, self.current_id + nb_id)
        self.current_id += nb_id
        return values

    def recense_dead_id_to_extend(self):
        """Recense dead id to extend in virtual observation

        Does nothing more than logging when virtual observations are disabled;
        otherwise ``virtual_obs`` is replaced by the result of
        ``previous_obs.propagate`` and the former value is kept in
        ``previous_virtual_obs``.
        """
        # List previous id which are not use in the next step
        dead_id = setdiff1d(self[-2]["id"], self[-1]["id"])
        nb_dead = dead_id.shape[0]
        logger.debug("%d death of real obs in this step", nb_dead)
        if not self.virtual:
            return

        # get id already dead from few time
        nb_virtual_extend = 0
        if self.virtual_obs is not None:
            virtual_dead_id = setdiff1d(self.virtual_obs["track"], self[-1]["id"])
            i_virtual_dead_id = index(self.virtual_obs["track"], virtual_dead_id)
            # Virtual obs which can be prolongate
            alive_virtual_obs = (
                self.virtual_obs["segment_size"][i_virtual_dead_id] < self.nb_virtual
            )
            nb_virtual_extend = alive_virtual_obs.sum()
            logger.debug(
                "%d virtual obs will be prolongate on the next step", nb_virtual_extend
            )

        # Save previous state to count virtual obs
        self.previous_virtual_obs = self.virtual_obs

        # Find mask/index on previous correspondance to extrapolate
        # position
        i_dead_id = index(self[-2]["id"], dead_id)

        # Selection of observations on N-2 and N-1
        obs_a = self.previous2_obs.obs[self[-2][i_dead_id]["in"]]
        obs_b = self.previous_obs.obs[self[-2][i_dead_id]["out"]]

        self.virtual_obs = self.previous_obs.propagate(
            obs_a,
            obs_b,
            self.previous_virtual_obs.obs[i_virtual_dead_id][alive_virtual_obs]
            if nb_virtual_extend > 0
            else None,
            dead_track=dead_id,
            nb_next=nb_dead + nb_virtual_extend,
            model=self.previous_obs,
        )

    def load_state(self):
        """Restore the state saved in a previous correspondance file, if any.

        :return: index of the first dataset to process and a flag which tell
            if virtual observations must be merged with previous observations;
            ``(1, False)`` without previous correspondance
        :rtype: (int, bool)
        """
        # If we have a previous file of correspondance, we will replay only recent part
        if self.previous_correspondance is not None:
            first_dataset = len(self.previous_correspondance.datasets)
            for correspondance in self.previous_correspondance[:first_dataset]:
                self.append(correspondance)
            self.current_obs = self.class_method.load_file(
                self.datasets[first_dataset - 2], **self.class_kw
            )
            flg_virtual = self.previous_correspondance.virtual
            with Dataset(self.filename_previous_correspondance) as general_handler:
                self.current_id = general_handler.last_current_id
                if flg_virtual:
                    # Load last virtual obs
                    self.virtual_obs = VirtualEddiesObservations.from_netcdf(
                        general_handler.groups["LastVirtualObs"]
                    )
                    self.previous_virtual_obs = VirtualEddiesObservations.from_netcdf(
                        general_handler.groups["LastPreviousVirtualObs"]
                    )
                    # Load and last previous virtual obs to be merge with current => will be previous2_obs
                    # TODO : Need to rethink this line ??
                    self.current_obs = self.current_obs.merge(
                        VirtualEddiesObservations.from_netcdf(
                            general_handler.groups["LastPreviousVirtualObs"]
                        )
                    )
            return first_dataset, flg_virtual
        return 1, False

    def track(self):
        """Run tracking"""
        self.reset_dataset_cache()
        first_dataset, flg_virtual = self.load_state()

        kwargs = dict()
        needed_variable = self.class_method.needed_variable()
        if needed_variable is not None:
            kwargs["include_vars"] = needed_variable
        self.swap_dataset(self.datasets[first_dataset - 1], **kwargs)
        # We begin with second file, first one is in previous
        for file_name in self.datasets[first_dataset:]:
            self.swap_dataset(file_name, **kwargs)
            logger.info("%s match with previous state", file_name)
            logger.debug("%d obs to match", len(self.current_obs))

            nb_real_obs = len(self.previous_obs)
            if flg_virtual:
                logger.debug(
                    "%d virtual obs will be add to previous", len(self.virtual_obs)
                )
                self.previous_obs = self.previous_obs.merge(self.virtual_obs)
            i_previous, i_current, association_cost = self.previous_obs.tracking(
                self.current_obs
            )

            # return true if the first time (previous2obs is none)
            if self.store_correspondance(
                i_previous, i_current, nb_real_obs, association_cost
            ):
                continue

            self.recense_dead_id_to_extend()

            if self.virtual:
                flg_virtual = True

    @staticmethod
    def _get_filename(dataset):
        """Return the name to store for ``dataset`` in the correspondance file.

        :param dataset: an item of ``datasets``
        :return: ``dataset`` itself when it is a ``str`` or ``bytes`` name,
            "In memory file" for anything else
        """
        if isinstance(dataset, (str, bytes)):
            return dataset
        return "In memory file"

    def to_netcdf(self, handler):
        """Write correspondances and tracking state in a netCDF handler.

        :param handler: netCDF dataset (or group) opened for writing
        """
        nb_step = len(self.datasets) - 1
        logger.info("Create correspondance file")
        # Create dimensions
        logger.debug('Create Dimensions "Nlink" : %d', self.nb_link_max)
        handler.createDimension("Nlink", self.nb_link_max)

        logger.debug('Create Dimensions "Nstep" : %d', nb_step)
        handler.createDimension("Nstep", nb_step)
        var_file_in = handler.createVariable(
            zlib=True,
            complevel=1,
            varname="FileIn",
            datatype="S1024",
            dimensions="Nstep",
        )
        var_file_out = handler.createVariable(
            zlib=True,
            complevel=1,
            varname="FileOut",
            datatype="S1024",
            dimensions="Nstep",
        )

        # Each step links dataset i (FileIn) with dataset i + 1 (FileOut)
        for i, dataset in enumerate(self.datasets[:-1]):
            var_file_in[i] = self._get_filename(dataset)
            var_file_out[i] = self._get_filename(self.datasets[i + 1])

        var_nb_link = handler.createVariable(
            zlib=True,
            complevel=1,
            varname="nb_link",
            datatype="u2",
            dimensions="Nstep",
        )

        datas = dict()
        for name, dtype in self.correspondance_dtype:
            # Boolean field is stored as unsigned byte
            if dtype is bool_:
                dtype = "u1"
            kwargs_cv = dict()
            if "u1" in dtype:
                kwargs_cv["fill_value"] = (255,)
            handler.createVariable(
                zlib=True,
                complevel=1,
                varname=name,
                datatype=dtype,
                dimensions=("Nstep", "Nlink"),
                **kwargs_cv
            )
            datas[name] = ma.empty((nb_step, self.nb_link_max), dtype=dtype)
            datas[name].mask = datas[name] == datas[name]

        for i, correspondance in enumerate(self):
            logger.debug("correspondance %d", i)
            nb_elt = correspondance.shape[0]
            var_nb_link[i] = nb_elt
            for name, _ in self.correspondance_dtype:
                datas[name][i, :nb_elt] = correspondance[name]
        for name, data in datas.items():
            h_v = handler.variables[name]
            h_v[:] = data
            if "File" not in name:
                h_v.min = h_v[:].min()
                h_v.max = h_v[:].max()

        handler.virtual_use = str(self.virtual)
        handler.virtual_max_segment = self.nb_virtual
        handler.last_current_id = self.current_id
        if self.virtual_obs is not None:
            group = handler.createGroup("LastVirtualObs")
            self.virtual_obs.to_netcdf(group)
            group = handler.createGroup("LastPreviousVirtualObs")
            self.previous_virtual_obs.to_netcdf(group)
        handler.module = self.class_method.__module__
        handler.classname = self.class_method.__qualname__
        handler.class_kw = json.dumps(self.class_kw)
        handler.node = platform.node()
        logger.info("Create correspondance file done")

    def save(self, filename, dict_completion=None):
        """Save correspondances in a NETCDF4 file.

        :param str filename: output path, used as a format string when
            ``dict_completion`` is a dict
        :param dict,None dict_completion: values used to format ``filename``
        """
        self.prepare_merging()
        if isinstance(dict_completion, dict):
            filename = filename.format(**dict_completion)
        with Dataset(filename, "w", format="NETCDF4") as h_nc:
            self.to_netcdf(h_nc)

    def load_compatible(self, filename):
        """Load a previous correspondance file and check it matches this setup.

        :param str,None filename: correspondance file to load
        :return: loaded object, None when ``filename`` is None
        :rtype: Correspondances,None
        :raises Exception: if the number of virtual observations or the
            class method of the file differ from those of this object
        """
        if filename is None:
            return None
        previous_correspondance = Correspondances.load(filename)
        if self.nb_virtual != previous_correspondance.nb_virtual:
            raise Exception(
                "File of correspondance IN contains a different virtual segment size : file(%d), yaml(%d)"
                % (previous_correspondance.nb_virtual, self.nb_virtual)
            )

        if self.class_method != previous_correspondance.class_method:
            raise Exception(
                "File of correspondance IN contains a different class method: file(%s), yaml(%s)"
                % (previous_correspondance.class_method, self.class_method)
            )
        return previous_correspondance

    @classmethod
    def from_netcdf(cls, handler):
        """Build an object from a netCDF handler written by :meth:`to_netcdf`.

        :param handler: netCDF dataset opened for reading
        :return: object filled with the correspondances found in the file
        :rtype: Correspondances
        """
        datas = {varname: data[:] for varname, data in handler.variables.items()}

        # Datasets are bounds of steps: every FileIn plus the last FileOut
        datasets = list(datas["FileIn"])
        datasets.append(datas["FileOut"][-1])

        if hasattr(handler, "module"):
            class_method = getattr(
                __import__(handler.module, globals(), locals(), handler.classname),
                handler.classname,
            )
            class_kw = getattr(handler, "class_kw", dict())
            if isinstance(class_kw, str):
                class_kw = json.loads(class_kw)
        else:
            class_method = None
            class_kw = dict()
        logger.info("File load with class %s(%s)", class_method, class_kw)
        obj = cls(
            datasets,
            handler.virtual_max_segment,
            class_method=class_method,
            class_kw=class_kw,
        )

        id_max = 0
        for i, nb_elt in enumerate(datas["nb_link"][:]):
            logger.debug(
                "Link between %s and %s", datas["FileIn"][i], datas["FileOut"][i]
            )
            correspondance = array(
                datas["in"][i, :nb_elt], dtype=obj.correspondance_dtype
            )
            for name, _ in obj.correspondance_dtype:
                if name == "in":
                    continue
                if name == "virtual_length":
                    correspondance[name] = 255
                correspondance[name] = datas[name][i, :nb_elt]
            id_max = max(id_max, correspondance["id"].max())
            obj.append(correspondance)
        obj.current_id = id_max + 1
        return obj

    @classmethod
    def load(cls, filename):
        """Load an object from a correspondance file.

        :param str filename: path of a file written by :meth:`save`
        :rtype: Correspondances
        """
        logger.info("Loading %s", filename)
        with Dataset(filename, "r", format="NETCDF4") as h_nc:
            obj = cls.from_netcdf(h_nc)
        return obj

    def prepare_merging(self):
        """Count observations of each track and compute where each track starts in merged arrays.

        Set ``nb_obs_by_tracks``, ``i_current_by_tracks`` and ``nb_obs``.
        """
        # count obs by tracks (we add directly one, because correspondance
        # is an interval)
        self.nb_obs_by_tracks = ones(self.current_id, dtype=self.N_DTYPE)
        for correspondance in self:
            self.nb_obs_by_tracks[correspondance["id"]] += 1
            if self.virtual:
                # When start is virtual, we don't have a previous
                # correspondance
                self.nb_obs_by_tracks[
                    correspondance["id"][correspondance["virtual"]]
                ] += correspondance["virtual_length"][correspondance["virtual"]]

        # Compute index of each tracks
        self.i_current_by_tracks = (
            self.nb_obs_by_tracks.cumsum() - self.nb_obs_by_tracks
        )
        # Number of global obs
        self.nb_obs = self.nb_obs_by_tracks.sum()
        logger.info("%d tracks identified", self.current_id)
        logger.info("%d observations will be join", self.nb_obs)

    def _select_tracks(self, i_keep_track):
        """Keep only the given tracks and renumber them from 0.

        :param array i_keep_track: current ids of tracks to keep (may be empty,
            in which case every correspondance becomes empty)
        """
        # Reduce array
        self.nb_obs_by_tracks = self.nb_obs_by_tracks[i_keep_track]
        self.i_current_by_tracks = (
            self.nb_obs_by_tracks.cumsum() - self.nb_obs_by_tracks
        )
        self.nb_obs = self.nb_obs_by_tracks.sum()
        # Give the last id used
        self.current_id = self.nb_obs_by_tracks.shape[0]
        # Old id -> new id table; only kept ids are filled, only kept ids are read.
        # ``max`` is not defined on an empty selection, so size it to 0 then.
        nb_old_id = i_keep_track.max() + 1 if i_keep_track.size else 0
        translate = empty(nb_old_id, dtype="u4")
        translate[i_keep_track] = arange(self.current_id)
        for i, correspondance in enumerate(self):
            m_keep = isin(correspondance["id"], i_keep_track)
            self[i] = correspondance[m_keep]
            self[i]["id"] = translate[self[i]["id"]]

    def longer_than(self, size_min):
        """Remove from correspondance table all association for shorter eddies than size_min

        :param int size_min: minimal number of observations of a kept track
        :return: False, without any change, if no track is long enough;
            None otherwise
        """
        # Identify eddies longer than
        mask = self.nb_obs_by_tracks >= size_min
        if not mask.any():
            return False
        self._select_tracks(where(mask)[0])
        logger.debug("Select longer than %d done", size_min)

    def shorter_than(self, size_max):
        """Remove from correspondance table all association for longer eddies than size_max

        If no track is short enough, every correspondance is emptied.

        :param int size_max: kept tracks have strictly less observations
        """
        # Identify eddies shorter than
        self._select_tracks(where(self.nb_obs_by_tracks < size_max)[0])
        logger.debug("Select shorter than %d done", size_max)

    def merge(self, until=-1, raw_data=True):
        """Merge all the correspondance in one array with all fields

        ``i_current_by_tracks`` is consumed (incremented) during the merge, so
        :meth:`prepare_merging` state is required beforehand.

        :param int until: number of steps to merge, -1 to merge all
        :param bool raw_data: forwarded to dataset loading and to the
            ``TrackEddiesObservations`` constructor
        :return: observations grouped by track
        :rtype: TrackEddiesObservations
        """
        # Start loading identification again to save in the finals tracks
        # Load first file
        self.reset_dataset_cache()
        self.swap_dataset(self.datasets[0], raw_data=raw_data)

        # Start create netcdf to agglomerate all eddy
        logger.debug("We will create an array (size %d)", self.nb_obs)
        eddies = TrackEddiesObservations(
            size=self.nb_obs,
            track_extra_variables=self.current_obs.track_extra_variables,
            track_array_variables=self.current_obs.track_array_variables,
            array_variables=self.current_obs.array_variables,
            raw_data=raw_data,
        )

        # All the value put at nan, necessary only for all end of track
        eddies["cost_association"][:] = default_fillvals["f4"]
        # Calculate the index in each tracks, we compute in u4 and translate
        # in u2 (which are limited to 65535)
        logger.debug("Compute global index array (N)")
        eddies["n"][:] = uint16(
            arange(self.nb_obs, dtype="u4")
            - self.i_current_by_tracks.repeat(self.nb_obs_by_tracks)
        )
        logger.debug("Compute global track array")
        eddies["track"][:] = arange(self.current_id).repeat(self.nb_obs_by_tracks)

        # Set type of eddy with first file
        eddies.sign_type = self.current_obs.sign_type
        # Fields to copy
        fields = self.current_obs.fields

        # To know if the track start
        first_obs_save_in_tracks = zeros(self.i_current_by_tracks.shape, dtype=bool_)

        for i, file_name in enumerate(self.datasets[1:]):
            if until != -1 and i >= until:
                break
            logger.debug("Merge data from %s", file_name)
            # Load current file (we begin with second one)
            self.swap_dataset(file_name, raw_data=raw_data)
            # We select the list of id which are involve in the correspondance
            i_id = self[i]["id"]
            # Index where we will write in the final object
            index_final = self.i_current_by_tracks[i_id]

            # First obs of eddies
            m_first_obs = ~first_obs_save_in_tracks[i_id]
            if m_first_obs.any():
                # Index in the previous file
                index_in = self[i]["in"][m_first_obs]
                # Copy all variable
                for field in fields:
                    if field == "cost_association":
                        continue
                    eddies[field][index_final[m_first_obs]] = self.previous_obs[field][
                        index_in
                    ]
                # Increment
                self.i_current_by_tracks[i_id[m_first_obs]] += 1
                # Active this flag, we have only one first by tracks
                first_obs_save_in_tracks[i_id] = True
                index_final = self.i_current_by_tracks[i_id]

            if self.virtual:
                # If the flag virtual in correspondance is active,
                # the previous is virtual
                m_virtual = self[i]["virtual"]
                if m_virtual.any():
                    # Incrementing index
                    self.i_current_by_tracks[i_id[m_virtual]] += self[i][
                        "virtual_length"
                    ][m_virtual]
                    # Get new index
                    index_final = self.i_current_by_tracks[i_id]

            # Index in the current file
            index_current = self[i]["out"]

            # Cost is stored on the observation which precedes the link
            if "cost_association" in eddies.fields:
                eddies["cost_association"][index_final - 1] = self[i]["cost_value"]
            # Copy all variable
            for field in fields:
                eddies[field][index_final] = self.current_obs[field][index_current]

            # Add increment for each index used
            self.i_current_by_tracks[i_id] += 1
            self.previous_obs = self.current_obs
        return eddies

    def get_unused_data(self, raw_data=False):
        """
        Add in track object all the observations which aren't selected

        Every dataset is loaded again and observations which are not
        referenced by the correspondances are extracted.

        :param bool raw_data: forwarded to dataset loading
        :return: Unused Eddies
        """
        nb_dataset = len(self.datasets)
        has_virtual = "virtual" in self[0].dtype.names
        eddies = list()
        for i, dataset in enumerate(self.datasets):
            last_dataset = i == (nb_dataset - 1)
            # "in" index of a virtual link doesn't refer to a real observation
            if has_virtual and not last_dataset:
                m_in = ~self[i]["virtual"]
            else:
                m_in = slice(None)
            if i == 0:
                index_used = self[i]["in"]
            elif last_dataset:
                index_used = self[i - 1]["out"]
            else:
                index_used = unique(
                    concatenate((self[i - 1]["out"], self[i]["in"][m_in]))
                )

            logger.debug("Load file : %s", dataset)
            if self.memory:
                with open(dataset, "rb") as h:
                    current_obs = self.class_method.load_file(h, raw_data=raw_data)
            else:
                current_obs = self.class_method.load_file(dataset, raw_data=raw_data)
            eddies.append(current_obs.index(index_used, reverse=True))
        return EddiesObservations.concatenate(eddies)