from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from typing import Any, Literal
from ..dataset import (
ABCRole,
AdditionalTargetRole,
Dataset,
DatasetAdapter,
ExperimentData,
GroupingRole,
InfoRole,
PreTargetRole,
StatisticRole,
TargetRole,
TempTargetRole,
)
from ..executor import Calculator
from ..utils import (
NAME_BORDER_SYMBOL,
BackendsEnum,
ExperimentDataEnum,
FromDictTypes,
GroupingDataType,
)
from ..utils.errors import (
AbstractMethodError,
NoColumnsError,
NoRequiredArgumentError,
NotSuitableFieldError,
)
class Comparator(Calculator, ABC):
    """Abstract calculator that compares baseline data with compared data.

    Subclasses implement :meth:`_inner_function`, which receives one baseline
    dataset and one compared dataset and returns a comparison result. This
    class locates the relevant fields in an :class:`ExperimentData`, splits
    them into ``(name, Dataset)`` buckets according to ``compare_by``, runs
    :meth:`_inner_function` for each compared bucket and stores the collected
    results as an analysis table.

    Supported ``compare_by`` modes:

    * ``"groups"``: the target field is grouped by the grouping field; the
      first group (by sorted group value) is the baseline and every other
      group is compared with it.
    * ``"columns"``: the baseline field column is compared with each target
      column.
    * ``"columns_in_groups"``: baseline and target fields are grouped by the
      grouping field and paired group by group.
    * ``"cross"``: the baseline field of the first group (by sorted group
      value) is compared with the target field of every other group.
    * ``"matched_pairs"``: the baseline field holds row indexes of matched
      rows (negative meaning "no match"); within each group, the target rows
      at those indexes are compared with the group's target rows.
    """

    def __init__(
        self,
        compare_by: Literal[
            "groups", "columns", "columns_in_groups", "cross", "matched_pairs"
        ],
        grouping_role: ABCRole | None = None,
        target_roles: ABCRole | list[ABCRole] | None = None,
        baseline_role: ABCRole | None = None,
        key: Any = "",
        calc_kwargs: dict[str, Any] | None = None,
    ):
        """Configure the comparator.

        Args:
            compare_by: Comparison mode, see the class docstring.
            grouping_role: Role used to find the grouping field. Falls back to
                ``GroupingRole()`` when falsy.
            target_roles: Role(s) used to find the target fields. Falls back
                to ``TargetRole()`` when falsy.
            baseline_role: Role used to find the baseline field. Falls back to
                ``PreTargetRole()`` when falsy.
            key: Forwarded to ``Calculator.__init__``.
            calc_kwargs: Extra keyword arguments that :meth:`execute` forwards
                to :meth:`calc`. ``None`` (the default) means no extra
                arguments.
        """
        super().__init__(key=key)
        self.grouping_role = grouping_role or GroupingRole()
        self.compare_by = compare_by
        self.target_roles = target_roles or TargetRole()
        self.baseline_role = baseline_role or PreTargetRole()
        # A fresh dict per instance: a shared ``{}`` default would leak
        # mutations from one comparator into every other one, and ``None``
        # would break the ``**self.calc_kwargs`` expansion in ``execute``.
        self.calc_kwargs = {} if calc_kwargs is None else calc_kwargs

    @property
    def search_types(self) -> list[type] | None:
        """Value passed as ``search_types`` when searching for target fields.

        Always ``None`` here; subclasses may override it to restrict the
        column types that are searched.
        """
        return None

    def _local_extract_dataset(
        self, compare_result: dict[Any, Any], roles: dict[Any, ABCRole]
    ) -> Dataset:
        """Turn the raw comparison dict into a Dataset.

        Delegates to :meth:`_extract_dataset`. It is an instance method,
        presumably so that subclasses can override it.
        """
        return self._extract_dataset(compare_result, roles)

    @classmethod
    @abstractmethod
    def _inner_function(
        cls, data: Dataset, test_data: Dataset | None = None, **kwargs
    ) -> Any:
        """Compare a baseline bucket (``data``) with a compared bucket.

        Args:
            data: Baseline dataset.
            test_data: Compared dataset.
            **kwargs: Extra arguments forwarded from :meth:`calc`.

        Returns:
            Anything ``DatasetAdapter.to_dataset`` can convert to a Dataset.
        """
        raise AbstractMethodError

    def _get_fields_data(self, data: ExperimentData) -> dict[str, Dataset]:
        """Look up the grouping, target and baseline fields in ``data``.

        When temporary roles exist on the main dataset or on the additional
        fields, the target search uses ``TempTargetRole`` or
        ``AdditionalTargetRole`` instead of ``self.target_roles``.

        Returns:
            Dict with the keys ``"group_field"``, ``"target_fields"`` and
            ``"baseline_field"``.
        """
        tmp_role = bool(data.ds.tmp_roles or data.additional_fields.tmp_roles)
        group_field_data = data.field_data_search(roles=self.grouping_role)
        if tmp_role:
            # Temporary roles on the main dataset take precedence over the
            # ones on the additional fields.
            target_search_roles = (
                TempTargetRole() if data.ds.tmp_roles else AdditionalTargetRole()
            )
        else:
            target_search_roles = self.target_roles
        target_fields_data = data.field_data_search(
            roles=target_search_roles,
            tmp_role=tmp_role,
            search_types=self.search_types,
        )
        baseline_field_data = data.field_data_search(roles=self.baseline_role)
        return {
            "group_field": group_field_data,
            "target_fields": target_fields_data,
            "baseline_field": baseline_field_data,
        }

    @classmethod
    def _execute_inner_function(
        cls,
        baseline_data: list[tuple[str, Dataset]],
        compared_data: list[tuple[str, Dataset]],
        compare_by: Literal[
            "groups", "columns", "columns_in_groups", "cross", "matched_pairs"
        ],
        **kwargs,
    ) -> dict:
        """Run :meth:`_inner_function` once for every compared bucket.

        A single baseline bucket is reused for every compared bucket.
        Otherwise baseline and compared buckets are paired by position.

        Args:
            baseline_data: ``(name, Dataset)`` baseline buckets.
            compared_data: ``(name, Dataset)`` compared buckets.
            compare_by: Comparison mode. In ``"groups"`` mode results are
                keyed by the bucket name; in every other mode they are keyed
                by the bucket name, ``NAME_BORDER_SYMBOL`` and the bucket's
                first column name.
            **kwargs: Forwarded to :meth:`_inner_function`.

        Returns:
            Dict mapping result names to the ``InfoRole`` datasets built from
            each :meth:`_inner_function` result.
        """
        result = {}
        single_baseline = len(baseline_data) == 1
        for position, (name, compared) in enumerate(compared_data):
            res_name = (
                name
                if compare_by == "groups"
                else f"{name}{NAME_BORDER_SYMBOL}{compared.columns[0]}"
            )
            baseline = baseline_data[0 if single_baseline else position][1]
            result[res_name] = DatasetAdapter.to_dataset(
                cls._inner_function(baseline, compared, **kwargs),
                InfoRole(),
            )
        return result

    @staticmethod
    def _check_test_data(test_data: Dataset | None = None) -> Dataset:
        """Return ``test_data``, raising ``ValueError`` if it is ``None``."""
        if test_data is None:
            raise ValueError("test_data is needed for evaluation")
        return test_data

    def _set_value(
        self, data: ExperimentData, value: Dataset | None = None, key: Any = None
    ) -> ExperimentData:
        """Store ``value`` in ``data``'s analysis tables under ``self.id``.

        ``key`` is accepted for signature compatibility but is not used.
        Returns the same ``data`` object.
        """
        data.set_value(
            ExperimentDataEnum.analysis_tables,
            self.id,
            value,
        )
        return data

    @staticmethod
    def _extract_dataset(
        compare_result: FromDictTypes, roles: dict[Any, ABCRole]
    ) -> Dataset:
        """Build one Dataset from a dict of comparison results.

        If the values are Datasets, which is decided by checking the first
        value only, they are joined with ``Dataset.append`` and the result is
        re-indexed by the dict keys. This presumes one row per value. In any
        other case the dict is passed to ``Dataset.from_dict`` with the
        pandas backend.
        """
        if isinstance(next(iter(compare_result.values())), Dataset):
            datasets: list[Dataset] = list(compare_result.values())
            result = datasets[0]
            if len(datasets) > 1:
                result = result.append(datasets[1:])
            result.index = list(compare_result.keys())
            return result
        return Dataset.from_dict(compare_result, roles, BackendsEnum.pandas)

    @staticmethod
    def _grouping_data_split(
        grouping_data: dict[str, Dataset],
        compare_by: Literal[
            "groups", "columns", "columns_in_groups", "cross", "matched_pairs"
        ],
        target_fields: list[str],
        baseline_field: str | None = None,
    ) -> GroupingDataType:
        """Build baseline/compared buckets from already grouped datasets.

        The first entry of ``grouping_data`` (in dict order) becomes the
        baseline. All remaining entries are compared against it.

        Args:
            grouping_data: Mapping of group name to that group's dataset.
            compare_by: Comparison mode. In ``"groups"`` mode the baseline
                bucket keeps ``target_fields``; otherwise it keeps
                ``baseline_field``.
            target_fields: Columns kept in the compared buckets, and in the
                baseline bucket in ``"groups"`` mode.
            baseline_field: Column kept in the baseline bucket outside of
                ``"groups"`` mode.

        Returns:
            ``(baseline_data, compared_data)`` lists of ``(name, Dataset)``.

        Raises:
            TypeError: If ``grouping_data`` is not a dict.
        """
        if not isinstance(grouping_data, dict):
            raise TypeError(
                f"Grouping data must be dict of strings and datasets, but got {type(grouping_data)}"
            )
        buckets = list(grouping_data.items())
        baseline_name, baseline_ds = buckets.pop(0)
        baseline_columns = target_fields if compare_by == "groups" else baseline_field
        baseline_data = [(baseline_name, baseline_ds[baseline_columns])]
        compared_data = [(name, ds[target_fields]) for name, ds in buckets]
        return baseline_data, compared_data

    @staticmethod
    def _split_ds_into_columns(
        data: list[tuple[str, Dataset]],
    ) -> list[tuple[str, Dataset]]:
        """Expand each ``(name, Dataset)`` bucket into one bucket per column."""
        return [
            (name, ds[column]) for name, ds in data for column in ds.columns
        ]

    @staticmethod
    def _field_validity_check(
        field_data: Dataset,
        comparison_role: Literal[
            "group_field_data", "target_fields_data", "baseline_field_data"
        ],
        compare_by: Literal[
            "groups", "columns", "columns_in_groups", "cross", "matched_pairs"
        ],
    ) -> Dataset:
        """Reduce ``field_data`` to exactly one column.

        Raises ``NoRequiredArgumentError(comparison_role)`` when there are no
        columns. When there are several columns, a warning is emitted and only
        the first column is kept.
        """
        if len(field_data.columns) == 0:
            raise NoRequiredArgumentError(comparison_role)
        if len(field_data.columns) > 1:
            warnings.warn(
                f"{comparison_role} must have only one column when the comparison is done by {compare_by}. {len(field_data.columns)} passed. {field_data.columns[0]} will be used.",
            )
        return field_data[field_data.columns[0]]

    @classmethod
    def _split_for_groups_mode(
        cls,
        group_field_data: Dataset,
        target_fields_data: Dataset,
    ) -> GroupingDataType:
        """``"groups"`` mode: first sorted group is the baseline, the rest are compared."""
        target_fields_data = cls._field_validity_check(
            target_fields_data, "target_fields_data", "groups"
        )
        group_field_data = cls._field_validity_check(
            group_field_data, "group_field_data", "groups"
        )
        data_buckets = sorted(
            target_fields_data.groupby(by=group_field_data), key=lambda tup: tup[0]
        )
        baseline_data = cls._split_ds_into_columns([data_buckets.pop(0)])
        compared_data = cls._split_ds_into_columns(data=data_buckets)
        return baseline_data, compared_data

    @classmethod
    def _split_for_columns_mode(
        cls,
        baseline_field_data: Dataset,
        target_fields_data: Dataset,
    ) -> GroupingDataType:
        """``"columns"`` mode: the baseline column versus each target column.

        Raises:
            NoRequiredArgumentError: If the baseline or target field has no
                columns.
        """
        baseline_field_data = cls._field_validity_check(
            baseline_field_data, "baseline_field_data", "columns"
        )
        if len(target_fields_data.columns) == 0:
            # Report the field name, consistent with ``_field_validity_check``
            # (the original passed the Dataset object itself).
            raise NoRequiredArgumentError("target_fields_data")
        baseline_data = [(f"{baseline_field_data.columns[0]}", baseline_field_data)]
        compared_data = [
            (f"{column}", target_fields_data[column])
            for column in target_fields_data.columns
        ]
        return baseline_data, compared_data

    @classmethod
    def _split_for_columns_in_groups_mode(
        cls,
        group_field_data: Dataset,
        baseline_field_data: Dataset,
        target_fields_data: Dataset,
    ) -> GroupingDataType:
        """``"columns_in_groups"`` mode: baseline and target grouped by the same field."""
        baseline_field_data = cls._field_validity_check(
            baseline_field_data, "baseline_field_data", "columns_in_groups"
        )
        target_fields_data = cls._field_validity_check(
            target_fields_data, "target_fields_data", "columns_in_groups"
        )
        group_field_data = cls._field_validity_check(
            group_field_data, "group_field_data", "columns_in_groups"
        )
        baseline_data = baseline_field_data.groupby(by=group_field_data)
        compared_data = cls._split_ds_into_columns(
            target_fields_data.groupby(by=group_field_data)
        )
        return baseline_data, compared_data

    @classmethod
    def _split_for_cross_mode(
        cls,
        group_field_data: Dataset,
        baseline_field_data: Dataset,
        target_fields_data: Dataset,
    ) -> GroupingDataType:
        """``"cross"`` mode: baseline of the first sorted group versus the target of every other group."""
        baseline_field_data = cls._field_validity_check(
            baseline_field_data, "baseline_field_data", "cross"
        )
        target_fields_data = cls._field_validity_check(
            target_fields_data, "target_fields_data", "cross"
        )
        group_field_data = cls._field_validity_check(
            group_field_data, "group_field_data", "cross"
        )
        baseline_data = [
            sorted(
                baseline_field_data.groupby(by=group_field_data), key=lambda tup: tup[0]
            ).pop(0)
        ]
        compared_data = sorted(
            target_fields_data.groupby(by=group_field_data), key=lambda tup: tup[0]
        )
        # The first group already serves as the baseline.
        compared_data.pop(0)
        compared_data = cls._split_ds_into_columns(data=compared_data)
        return baseline_data, compared_data

    @classmethod
    def _split_for_matched_pairs_mode(
        cls,
        group_field_data: Dataset,
        baseline_field_data: Dataset,
        target_fields_data: Dataset,
    ) -> GroupingDataType:
        """``"matched_pairs"`` mode: matched target rows versus each group's own target rows.

        The baseline field is read as row indexes into the target data. For
        each group, the baseline bucket is the target rows at those indexes.
        """
        group_field_data = cls._field_validity_check(
            group_field_data, "group_field_data", "matched_pairs"
        )
        baseline_field_data = cls._field_validity_check(
            baseline_field_data, "baseline_field_data", "matched_pairs"
        )
        target_fields_data = cls._field_validity_check(
            target_fields_data, "target_fields_data", "matched_pairs"
        )
        compared_data = target_fields_data.groupby(by=group_field_data)
        baseline_indexes = baseline_field_data.groupby(by=group_field_data)
        baseline_data = []
        # Map each group's baseline indexes to the matching target rows. An
        # unmatched row is marked with a negative index (-1). It is replaced
        # by the last target index so that ``.loc`` stays valid.
        for name, group_ds in baseline_indexes:
            indexes = group_ds.iget_values(column=0)
            dummy_index = target_fields_data.index[-1]
            indexes = [dummy_index if x < 0 else x for x in indexes]
            baseline_data.append((name, target_fields_data.loc[indexes, :]))
        return baseline_data, compared_data

    @classmethod
    def _split_data_to_buckets(
        cls,
        compare_by: Literal[
            "groups", "columns", "columns_in_groups", "cross", "matched_pairs"
        ],
        target_fields_data: Dataset,
        baseline_field_data: Dataset,
        group_field_data: Dataset,
    ) -> GroupingDataType:
        """Split field data into baseline and compared buckets for ``compare_by``.

        Args:
            compare_by: Comparison mode, see the class docstring.
            target_fields_data: Dataset with the target column(s).
            baseline_field_data: Dataset with the baseline column. It is used
                by every mode except ``"groups"``.
            group_field_data: Dataset with the grouping column. It is used by
                every mode except ``"columns"``.

        Returns:
            A ``(baseline_data, compared_data)`` tuple.

        Raises:
            NoRequiredArgumentError: If a field required by the chosen mode
                has no columns.
            ValueError: If ``compare_by`` is not one of the supported modes.
        """
        if compare_by == "groups":
            return cls._split_for_groups_mode(group_field_data, target_fields_data)
        if compare_by == "columns":
            return cls._split_for_columns_mode(baseline_field_data, target_fields_data)
        if compare_by == "columns_in_groups":
            return cls._split_for_columns_in_groups_mode(
                group_field_data, baseline_field_data, target_fields_data
            )
        if compare_by == "cross":
            return cls._split_for_cross_mode(
                group_field_data, baseline_field_data, target_fields_data
            )
        if compare_by == "matched_pairs":
            return cls._split_for_matched_pairs_mode(
                group_field_data, baseline_field_data, target_fields_data
            )
        raise ValueError(
            f"Wrong compare_by argument passed {compare_by}. It can be only one of the following modes: 'groups', 'columns', 'columns_in_groups', 'cross', 'matched_pairs'."
        )

    @classmethod
    def calc(
        cls,
        compare_by: (
            Literal["groups", "columns", "columns_in_groups", "cross", "matched_pairs"]
            | None
        ) = None,
        target_fields_data: Dataset | None = None,
        baseline_field_data: Dataset | None = None,
        group_field_data: Dataset | None = None,
        grouping_data: (
            tuple[list[tuple[str, Dataset]], list[tuple[str, Dataset]]]
            | list[tuple[str, Dataset]]
            | None
        ) = None,
        **kwargs,
    ) -> dict:
        """Split the data into buckets (unless already split) and compare them.

        Args:
            compare_by: Comparison mode.
            target_fields_data: Target field(s). Used to build the buckets
                when ``grouping_data`` is not given.
            baseline_field_data: Baseline field. Used to build the buckets
                when ``grouping_data`` is not given.
            group_field_data: Grouping field. Used to build the buckets when
                ``grouping_data`` is not given.
            grouping_data: Precomputed ``(baseline_data, compared_data)``.
            **kwargs: Forwarded to :meth:`_inner_function`.

        Returns:
            Dict of result name to result dataset, as produced by
            :meth:`_execute_inner_function`.

        Raises:
            ValueError: If neither ``compare_by`` nor ``target_fields_data``
                is given.
        """
        if compare_by is None and target_fields_data is None:
            raise ValueError(
                "You should pass either compare_by or target_fields argument."
            )
        if grouping_data is None:
            grouping_data = cls._split_data_to_buckets(
                compare_by=compare_by,
                target_fields_data=target_fields_data,
                baseline_field_data=baseline_field_data,
                group_field_data=group_field_data,
            )
        baseline_data, compared_data = grouping_data
        return cls._execute_inner_function(
            baseline_data=baseline_data,
            compared_data=compared_data,
            compare_by=compare_by,
            **kwargs,
        )

    def execute(self, data: ExperimentData) -> ExperimentData:
        """Run the comparator on ``data`` and store the result table.

        The fields are located by role. The data is split into baseline and
        compared buckets according to ``self.compare_by``, reusing
        ``data.groups`` when it is already grouped by the grouping field and
        filling it otherwise. The buckets are then compared with :meth:`calc`.

        Args:
            data: The experiment data to run on.

        Returns:
            ``data`` with the comparison table stored under ``self.id``, or
            ``data`` unchanged when no target column is found while temporary
            roles are set.

        Raises:
            NoColumnsError: If no target column is found and no temporary
                roles are set.
            NotSuitableFieldError: If the grouping field is unsuitable for
                ``self.compare_by`` or yields empty buckets.
        """
        fields = self._get_fields_data(data)
        group_field_data = fields["group_field"]
        target_fields_data = fields["target_fields"]
        baseline_field_data = fields["baseline_field"]

        self.key = str(
            target_fields_data.columns[0]
            if len(target_fields_data.columns) == 1
            else (list(target_fields_data.columns) or "")
        )
        if len(target_fields_data.columns) == 0:
            # With a temporary role set, a column that is unsuitable for this
            # test leaves the target empty; that is expected, not an error.
            if data.ds.tmp_roles:
                return data
            raise NoColumnsError(TargetRole().role_name)
        if len(group_field_data.columns) != 1 and self.compare_by != "columns":
            raise NotSuitableFieldError(group_field_data, "Grouping")

        if len(group_field_data.columns) == 0:
            # Only reachable in "columns" mode (see the check above), which
            # needs no grouping field. The group-based paths below would fail
            # with IndexError on ``group_field_data.columns[0]``.
            grouping_data = self._split_data_to_buckets(
                compare_by=self.compare_by,
                target_fields_data=target_fields_data,
                baseline_field_data=baseline_field_data,
                group_field_data=group_field_data,
            )
        elif (
            group_field_data.columns[0] in data.groups
            and self.compare_by != "matched_pairs"
        ):  # TODO: proper split between groups and columns
            grouping_data = self._grouping_data_split(
                grouping_data=data.groups[group_field_data.columns[0]],
                compare_by=self.compare_by,
                # NOTE(review): when the grouping and target columns coincide,
                # the first column of ``data.ds`` is used; confirm intent.
                target_fields=(
                    [data.ds.columns[0]]
                    if group_field_data.columns[0] == target_fields_data.columns[0]
                    else list(target_fields_data.columns)
                ),
                baseline_field=(
                    baseline_field_data.columns[0]
                    if len(baseline_field_data.columns) > 0
                    else None
                ),
            )
        else:
            additional_target_columns = [
                column
                for column in data.additional_fields.columns
                if isinstance(
                    data.additional_fields.roles[column], AdditionalTargetRole
                )
            ]
            # Attach the additional target columns (outer join on index) so
            # that the cached groups contain them as well.
            combined_data = (
                data.ds.merge(
                    data.additional_fields[additional_target_columns],
                    left_index=True,
                    right_index=True,
                    how="outer",
                )
                if additional_target_columns
                else data.ds
            )
            data.groups[group_field_data.columns[0]] = {
                f"{group}": ds for group, ds in combined_data.groupby(group_field_data)
            }
            grouping_data = self._split_data_to_buckets(
                compare_by=self.compare_by,
                target_fields_data=target_fields_data,
                baseline_field_data=baseline_field_data,
                group_field_data=group_field_data,
            )

        if len(grouping_data[0]) < 1 or len(grouping_data[1]) < 1:
            raise NotSuitableFieldError(group_field_data, "Grouping")
        compare_result = self.calc(
            **self.calc_kwargs,
            compare_by=self.compare_by,
            target_fields_data=target_fields_data,
            baseline_field_data=baseline_field_data,
            group_field_data=group_field_data,
            grouping_data=grouping_data,
        )
        result_dataset = self._local_extract_dataset(
            compare_result, {name: StatisticRole() for name in compare_result}
        )
        return self._set_value(data, result_dataset)
class StatHypothesisTesting(Comparator, ABC):
    """Comparator base for statistical hypothesis tests.

    Adds a ``reliability`` threshold on top of ``Comparator``. This class only
    stores it; presumably subclasses use it as the significance level (the
    default is 0.05). Confirm this against the concrete tests.
    """

    def __init__(
        self,
        compare_by: Literal[
            "groups", "columns", "columns_in_groups", "cross", "matched_pairs"
        ],
        grouping_role: ABCRole | None = None,
        target_role: ABCRole | None = None,
        baseline_role: ABCRole | None = None,
        reliability: float = 0.05,
        key: Any = "",
        calc_kwargs: dict[str, Any] | None = None,
    ):
        """Configure the test.

        Args:
            compare_by: Comparison mode, forwarded to ``Comparator``.
            grouping_role: Role of the grouping field.
            target_role: Role of the target field(s). Forwarded as
                ``target_roles``.
            baseline_role: Role of the baseline field.
            reliability: Threshold stored on ``self.reliability``.
            key: Forwarded to ``Comparator``.
            calc_kwargs: Extra keyword arguments for ``calc``. ``None`` (the
                default) means none.
        """
        super().__init__(
            compare_by=compare_by,
            grouping_role=grouping_role,
            target_roles=target_role,
            baseline_role=baseline_role,
            key=key,
            # Resolve the default here so each instance gets its own dict
            # instead of sharing one mutable ``{}`` default.
            calc_kwargs={} if calc_kwargs is None else calc_kwargs,
        )
        self.reliability = reliability