"""Helpers to work with inputs from multiple nodes."""
import dataclasses
import logging
import typing
import numpy as np
log = logging.getLogger(__name__)
[docs]
@dataclasses.dataclass
class ExcludeIds:
"""Remove entries from a dataset."""
data: typing.Union[list, dict]
ids: typing.Union[list, dict]
def __post_init__(self):
if self.ids is None:
return
if isinstance(self.ids, list):
log.debug("ids is list")
if isinstance(self.ids[0], dict):
log.debug("ids is list of dicts")
# we assume list[dict]. IF mixed it will raise some error
ids = {}
for data in self.ids:
for key, value in data.items():
if key in ids:
ids[key].extend(value)
else:
if not isinstance(value, list):
raise ValueError(
f"Ids can not be {type(value)} but must be int Found"
f" {value} instead."
)
ids[key] = value
self.ids = {}
for key, val in ids.items():
self.ids[key] = np.sort(val).astype(int).tolist()
else:
log.debug("ids is list of ints")
self.ids = np.sort(self.ids).astype(int).tolist()
else:
log.debug("ids is dict")
for key, ids in self.ids.items():
self.ids[key] = np.sort(ids).astype(int).tolist()
[docs]
def get_clean_data(self, flatten: bool = False) -> list:
"""Remove the 'ids' from the 'data'."""
# TODO do we need a dict return here or could we just return a flat list?
if self.ids is None:
if isinstance(self.data, list):
return self.data
elif isinstance(self.data, dict):
if flatten:
return get_flat_data_from_dict(self.data)
return self.data
if isinstance(self.data, list) and isinstance(self.ids, list):
return [x for i, x in enumerate(self.data) if i not in self.ids]
elif isinstance(self.data, dict) and isinstance(self.ids, dict):
clean_data = {}
for key, data in self.data.items():
if key in self.ids:
clean_data[key] = [
x for i, x in enumerate(data) if i not in self.ids[key]
]
else:
clean_data[key] = data
if flatten:
return get_flat_data_from_dict(clean_data)
return clean_data
else:
raise TypeError(
"ids and data must be of the same type. "
f"ids is {type(self.ids)} and data is {type(self.data)}"
)
[docs]
def get_original_ids(self, ids: list, per_key: bool = False) -> list:
"""Shift the 'ids' such that they are valid for the initial data."""
ids = np.array(ids).astype(int)
ids = np.sort(ids)
if isinstance(self.ids, list):
for removed_id in self.ids:
ids[ids >= removed_id] += 1
elif isinstance(self.ids, dict):
for removed_id in self.ids_as_list:
ids[ids >= removed_id] += 1
if per_key:
return get_ids_per_key(self.data, ids, silent_ignore=True)
return ids.tolist()
@property
def ids_as_list(self) -> list:
# {a: [1, 2], b: [1, 3]}
# {a: list(10), b:list(10)}
# [1, 2, 1+10, 3+10]
ids = []
size = 0
for key in self.data:
# we iterate through data, not ids, because ids must not contain all keys
if key in self.ids:
ids.append(np.array(self.ids[key]) + size)
size += len(self.data[key])
if len(ids):
ids = np.concatenate(ids)
ids = np.sort(ids)
return ids.astype(int).tolist()
return []
[docs]
def get_flat_data_from_dict(data: dict, silent_ignore: bool = False) -> list:
"""Flatten a dictionary of lists into a single list.
Parameters
----------
data : dict
Dictionary of lists.
silent_ignore : bool, optional
If True, the function will return the input if it is not a
dictionary. If False, it will raise a TypeError.
Example
-------
>>> data = {'a': [1, 2, 3], 'b': [4, 5, 6]}
>>> get_flat_data_from_dict(data)
[1, 2, 3, 4, 5, 6]
"""
if not isinstance(data, dict):
if silent_ignore:
return data
else:
raise TypeError(f"data must be a dictionary and not {type(data)}")
flat_data = []
for x in data.values():
flat_data.extend(x)
return flat_data
[docs]
def get_ids_per_key(
data: dict, ids: list, silent_ignore: bool = False
) -> typing.Dict[str, list]:
"""Get the ids per key from a dictionary of lists.
Parameters
----------
data : dict
Dictionary of lists.
ids : list
List of ids. The ids are assumed to be taken from the flattened
'get_flat_data_from_dict(data)' data. If the ids aren't sorted,
they will be sorted.
silent_ignore : bool, optional
If True, the function will return the input if it is not a
dictionary. If False, it will raise a TypeError.
Example
-------
>>> data = {'a': [1, 2, 3], 'b': [4, 5, 6]}
>>> get_ids_per_key(data, [0, 1, 3, 5])
{'a': [0, 1], 'b': [0, 2]}
"""
if not isinstance(data, dict):
if silent_ignore:
return np.array(ids).tolist()
else:
raise TypeError(f"data must be a dictionary and not {type(data)}")
ids_per_key = {}
ids = np.array(ids).astype(int)
ids = np.sort(ids)
start = 0
for key, val in data.items():
condition = ids - start
condition = np.logical_and(condition < len(val), condition >= 0)
ids_per_key[key] = np.array(ids[condition] - start).tolist()
start += len(val)
return ids_per_key