diff --git a/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder.py b/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder.py index 03966569..09ff5844 100644 --- a/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder.py +++ b/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder.py @@ -134,7 +134,12 @@ def decode_fn(encoded_structure): encoded_structure = py_utils.merge_dicts(encoded_structure, encoded_py_structure['flat_py']) encoded_structure = tf.nest.pack_sequence_as( - encoded_py_structure['full'], tf.nest.flatten(encoded_structure)) + encoded_py_structure['full'], + [ + encoded_structure[k] for k, _ in + py_utils.flatten_with_joined_string_paths( + encoded_py_structure['full']) + ]) return encoder.decode(encoded_structure[_TENSORS], encoded_structure[_PARAMS], encoded_structure[_SHAPES])