diff --git a/python/otbtf.py b/python/otbtf.py
index c944728c86db5f7f96faa78c943e2bd10b742b88..2a48e01d6c3fdbb70919ccacb3d66282ef094ced 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -684,13 +684,15 @@ class TFRecords:
         self.save(output_shapes, self.output_shape_file)
 
     @staticmethod
-    def parse_tfrecord(example, features_types, target_keys, target_cropping=None):
+    def parse_tfrecord(example, features_types, target_keys, preprocessing_fn=None, **kwargs):
         """
         Parse example object to sample dict.
         :param example: Example object to parse
         :param features_types: List of types for each feature
         :param target_keys: list of keys of the targets
-        :param target_cropping: Optional. Number of pixels to be removed on each side of the target tensor.
+        :param preprocessing_fn: Optional. A preprocessing function that takes input and target as args and returns
+               a tuple (input_preprocessed, target_preprocessed)
+        :param kwargs: some keyword arguments for preprocessing_fn
         """
         read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types}
         example_parsed = tf.io.parse_single_example(example, read_features)
@@ -702,19 +704,17 @@ class TFRecords:
         input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys}
         target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys}
 
-        if target_cropping:
-            target_parsed = {key: value[target_cropping:-target_cropping, target_cropping:-target_cropping, :] for key, value in target_parsed.items()}
+        if preprocessing_fn:
+            input_parsed, target_parsed = preprocessing_fn(input_parsed, target_parsed, **kwargs)
 
         return input_parsed, target_parsed
 
-
-    def read(self, batch_size, target_keys, target_cropping=None, n_workers=1, drop_remainder=True, shuffle_buffer_size=None):
+    def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None,
+             preprocessing_fn=None, **kwargs):
         """
         Read all tfrecord files matching with pattern and convert data to tensorflow dataset.
         :param batch_size: Size of tensorflow batch
         :param target_keys: Keys of the target, e.g. ['s2_out']
-        :param target_cropping: Number of pixels to be removed on each side of the target. Must be used with a network
-               architecture coherent with this, i.e. that has a Cropping2D layer in the end
         :param n_workers: number of workers, e.g. 4 if using 4 GPUs
                           e.g. 12 if using 3 nodes of 4 GPUs
         :param drop_remainder: whether the last batch should be dropped in the case it has fewer than
@@ -722,12 +722,16 @@ class TFRecords:
                False is advisable when evaluating metrics so that all samples are used
         :param shuffle_buffer_size: if None, shuffle is not used. Else, blocks of shuffle_buffer_size elements are
                                     shuffled using uniform random.
+        :param preprocessing_fn: Optional. A preprocessing function that takes input and target as args and returns
+               a tuple (input_preprocessed, target_preprocessed)
+        :param kwargs: some keyword arguments for preprocessing_fn
         """
         options = tf.data.Options()
         if shuffle_buffer_size:
             options.experimental_deterministic = False  # disable order, increase speed
         options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO  # for multiworker
-        parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys, target_cropping=target_cropping)
+        parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys,
+                        preprocessing_fn=preprocessing_fn, **kwargs)
 
         # TODO: to be investigated :
         # 1/ num_parallel_reads useful ? I/O bottleneck of not ?
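For reference, the removed target_cropping behaviour can be reproduced through the new preprocessing_fn hook. Since parse_tfrecord calls tf.io.parse_single_example and is mapped before batching, the function runs per sample on unbatched [height, width, channels] tensors. A minimal sketch; the crop_target helper, the dataset path and the crop width of 16 are illustrative, not part of the patch:

from otbtf import TFRecords

def crop_target(inputs, targets, crop=None):
    # Hypothetical preprocessing_fn: remove `crop` pixels on each side of
    # every target tensor, mirroring the removed `target_cropping` option.
    # Called per (unbatched) sample, so tensors are [height, width, channels].
    if crop:
        targets = {key: value[crop:-crop, crop:-crop, :]
                   for key, value in targets.items()}
    return inputs, targets

tfrecords = TFRecords("/path/to/tfrecords")  # illustrative path
dataset = tfrecords.read(batch_size=8,
                         target_keys=["s2_out"],
                         preprocessing_fn=crop_target,
                         crop=16)  # forwarded to crop_target via **kwargs

As the removed docstring noted for target_cropping, the model consuming this dataset must still be coherent with the crop, e.g. end with a Cropping2D layer.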