Source code for hypex.splitters.aa

from __future__ import annotations

from typing import Any

import numpy as np
import pandas as pd

from ..dataset import (
    AdditionalTreatmentRole,
    Dataset,
    ExperimentData,
    StratificationRole,
    TreatmentRole,
)
from ..dataset.roles import ConstGroupRole
from ..executor import Calculator
from ..utils import ExperimentDataEnum


class AASplitter(Calculator):
    """Randomly split a dataset into a control group and one or more test groups.

    The resulting labels (``"control"``, ``"test_1"``, ``"test_2"``, ...) are
    stored in the experiment's additional fields under this executor's id.

    Args:
        control_size: Target share of rows labelled ``"control"`` when
            ``groups_sizes`` is not given.
        random_state: Seed forwarded to the sampling call that shuffles rows.
        sample_size: Fraction of the assignable rows to sample before
            splitting; ``None`` means all of them.
        constant_key: When true, assignments to ``key`` after construction
            are ignored.
        save_groups: When true, per-group slices of the dataset are stored
            in ``data.groups`` after the split.
        groups_sizes: Optional list of group shares that must sum to 1. The
            first share is the control group, the rest are test groups.
        key: Executor key passed on to ``Calculator``.
    """

    def __init__(
        self,
        control_size: float = 0.5,
        random_state: int | None = None,
        sample_size: float | None = None,
        constant_key: bool = True,
        save_groups: bool = True,
        groups_sizes: list[float] | None = None,
        key: Any = "",
    ):
        self.control_size = control_size
        self.random_state = random_state
        self._key = key
        # ``constant_key`` is set before the base initialiser runs because the
        # ``key`` setter below reads it.
        self.constant_key = constant_key
        self.save_groups = save_groups
        self.sample_size = sample_size
        self.groups_sizes = groups_sizes
        super().__init__(key)

    def _generate_params_hash(self):
        """Build ``self._params_hash`` from the non-default split parameters.

        Parts are joined with ``"|"`` and use the prefixes ``cs`` (control
        size), ``rs`` (random state) and ``gs`` (groups sizes).
        """
        hash_parts: list[str] = []
        if self.control_size != 0.5:
            hash_parts.append(f"cs {self.control_size}")
        if self.random_state is not None:
            hash_parts.append(f"rs {self.random_state}")
        if self.groups_sizes is not None:
            hash_parts.append(f"gs {self.groups_sizes}")
        self._params_hash = "|".join(hash_parts)

    def init_from_hash(self, params_hash: str):
        """Restore split parameters from a string built by ``_generate_params_hash``.

        Args:
            params_hash: ``"|"``-separated parts prefixed with ``cs``, ``rs``
                or ``gs``; parts with any other prefix are ignored.
        """
        for hash_part in params_hash.split("|"):
            if hash_part.startswith("cs"):
                self.control_size = float(hash_part[hash_part.rfind(" ") + 1 :])
            elif hash_part.startswith("rs"):
                self.random_state = int(hash_part[hash_part.rfind(" ") + 1 :])
            elif hash_part.startswith("gs"):
                # The list repr contains spaces, so cut at the FIRST space
                # (after the prefix) rather than the last one.
                raw_sizes = hash_part[hash_part.find(" ") + 1 :].strip("[]").split(",")
                self.groups_sizes = [float(gs) for gs in raw_sizes]
        self._generate_id()

    @property
    def key(self) -> Any:
        """Return the executor key."""
        return self._key

    @key.setter
    def key(self, value: Any):
        """Update the key and regenerate the id, unless ``constant_key`` is set."""
        if not self.constant_key:
            self._key = value
            self._generate_id()

    def _set_value(self, data: ExperimentData, value, key=None) -> ExperimentData:
        """Store the split labels in ``data`` and optionally cache the groups.

        Args:
            data: Experiment data to update.
            value: Split labels produced by ``_inner_function``.
            key: Unused; kept for interface compatibility.

        Returns:
            The experiment data returned by ``data.set_value``.
        """
        data = data.set_value(
            ExperimentDataEnum.additional_fields,
            self._id,
            value,
            role=AdditionalTreatmentRole(),
        )
        if self.save_groups:
            data.groups[self.id] = {
                group: data.ds.loc[group_data.index]
                for group, group_data in data.additional_fields.groupby(self.id)
            }
        return data

    @staticmethod
    def _inner_function(
        data: Dataset,
        random_state: int | None = None,
        control_size: float = 0.5,
        groups_sizes: list[float] | None = None,
        sample_size: float | None = None,
        const_group_field: str | None = None,
        **kwargs,
    ) -> list[str]:
        """Assign every row of ``data`` a group label.

        Args:
            data: Dataset to split.
            random_state: Seed for the shuffling sample.
            control_size: Target share of control rows (ignored when
                ``groups_sizes`` is given).
            groups_sizes: Optional group shares summing to 1; the first one
                is the control group.
            sample_size: Fraction of assignable rows to sample; ``None``
                means 1.0.
            const_group_field: Optional column with pre-assigned groups. Rows
                with a non-null value there are excluded from the random
                assignment; those labelled ``"control"`` are forced into the
                control group.
            **kwargs: Ignored.

        Returns:
            One label per row, in the order of ``data``'s index. Rows that are
            neither control nor placed into a later test slice keep the
            initial code 1 and are therefore labelled ``"test_1"``.

        Raises:
            ValueError: If ``groups_sizes`` does not sum to 1.
        """
        sample_size = 1.0 if sample_size is None else sample_size
        control_indexes = []
        if const_group_field:
            const_data = dict(data.groupby(const_group_field))
            control_data = const_data.get("control")
            if control_data is not None:
                control_indexes = list(control_data.index)
            const_size = sum(len(cd) for cd in const_data.values())
            if len(data) <= const_size:
                control_size = 0
            else:
                # Rescale the target share to the rows that are still free.
                # ``len(control_indexes)`` is 0 when no row is pre-assigned to
                # control, which avoids a KeyError on ``const_data["control"]``.
                control_size = (len(data) * control_size - len(control_indexes)) / (
                    len(data) - const_size
                )
                # Keep the share inside [0, 1]: a negative value would become
                # a negative slice bound below and select rows from the end.
                control_size = min(max(control_size, 0.0), 1.0)

        experiment_data = (
            data[data[const_group_field].isna()] if const_group_field else data
        )
        experiment_data_index = experiment_data.sample(
            frac=sample_size, random_state=random_state
        ).index
        addition_indexes = list(experiment_data_index)

        edges = []
        if groups_sizes:
            # Compare with a tolerance: shares such as [0.1, 0.2, 0.7] do not
            # sum to exactly 1.0 in floating point.
            if not np.isclose(sum(groups_sizes), 1.0):
                raise ValueError("Groups sizes must sum to 1")
            for group_size in groups_sizes:
                size = int(len(addition_indexes) * group_size) + (
                    0 if not edges else edges[-1]
                )
                size = min(size, len(addition_indexes))
                if size not in edges:
                    edges += [size]
        else:
            edges = [int(len(addition_indexes) * control_size), len(addition_indexes)]

        control_indexes += addition_indexes[: edges[0]]
        test_indexes = [
            addition_indexes[edges[i - 1] : edges[i]] for i in range(1, len(edges))
        ]

        # Every row starts with code 1; control rows drop to 0 and the i-th
        # test slice (0-based) is shifted to code i + 1.
        split_series = pd.Series(
            np.ones(data.data.shape[0], dtype="int"), index=data.data.index
        )
        split_series[control_indexes] -= 1
        for i, test_index in enumerate(test_indexes):
            split_series[test_index] += i

        label_map = {0: "control"}
        label_map.update({i: f"test_{i}" for i in range(1, len(edges))})
        split_series = split_series.map(label_map)
        return split_series.to_list()

    def execute(self, data: ExperimentData) -> ExperimentData:
        """Split ``data.ds`` and store the labels in the experiment data.

        The first column found with ``ConstGroupRole`` (if any) is used as the
        pre-assigned group column.
        """
        const_group_fields = data.ds.search_columns(ConstGroupRole())
        const_group_field = const_group_fields[0] if len(const_group_fields) > 0 else None
        result = self.calc(
            data.ds,
            random_state=self.random_state,
            control_size=self.control_size,
            sample_size=self.sample_size,
            const_group_field=const_group_field,
            groups_sizes=self.groups_sizes,
        )
        return self._set_value(data, result)
class AASplitterWithStratification(AASplitter):
    """AA splitter that performs the split separately inside each stratum.

    Strata are defined by the columns carrying ``StratificationRole``. When
    there are no such columns the behaviour falls back to ``AASplitter``.
    """

    @staticmethod
    def _inner_function(
        data: Dataset,
        random_state: int | None = None,
        control_size: float = 0.5,
        grouping_fields=None,
        **kwargs,
    ) -> list[str] | Dataset:
        """Split ``data`` as a whole or per stratum.

        Args:
            data: Dataset to split.
            random_state: Seed forwarded to ``AASplitter._inner_function``.
            control_size: Target share of control rows.
            grouping_fields: Columns to group by; falsy means no
                stratification.
            **kwargs: Extra split options (for example ``groups_sizes``)
                forwarded to ``AASplitter._inner_function``.

        Returns:
            A list of labels when ``grouping_fields`` is falsy. Otherwise a
            ``Dataset`` with a single ``"split"`` column indexed by the
            concatenated indexes of the strata.
        """
        if not grouping_fields:
            return AASplitter._inner_function(
                data, random_state, control_size, **kwargs
            )

        split_labels: list[str] = []
        index = []
        for _, group_data in data.groupby(grouping_fields):
            # Forward the extra options here as well, so that settings such
            # as ``groups_sizes`` apply to stratified splits in the same way
            # as to the non-stratified fallback above.
            split_labels.extend(
                AASplitter._inner_function(
                    group_data, random_state, control_size, **kwargs
                )
            )
            index.extend(list(group_data.index))
        return Dataset.from_dict(
            {"split": split_labels}, index=index, roles={"split": TreatmentRole()}
        )

    def execute(self, data: ExperimentData) -> ExperimentData:
        """Run the stratified split on ``data.ds`` and store the labels."""
        grouping_fields = data.ds.search_columns(StratificationRole())
        result = self.calc(
            data.ds,
            random_state=self.random_state,
            control_size=self.control_size,
            grouping_fields=grouping_fields,
            groups_sizes=self.groups_sizes,
        )
        if isinstance(result, Dataset):
            result = result.replace_roles({"split": AdditionalTreatmentRole()})
        return self._set_value(data, result)
#
# class AASplitterWithStratification(AASplitter):
#     def __init__(
#         self,
#         control_size: float = 0.5,
#         random_state: Optional[int] = None,
#
#         key: Any = "",
#     ):
#         super().__init__(control_size, random_state, key)
#
#     def calc(self, data: Dataset):
#         stratification_columns = data.get_columns_by_roles(StratificationRole())
#
#         groups = data.groupby(stratification_columns)
#         result = Dataset._create_empty()
#         for _, gd in groups:
#             ged = ExperimentData(gd)
#             ged = super().execute(ged)
#
#             result = ged if result is None else result.append(ged)
#         return result["group"]


# As idea
# class SplitterAAMulti(ExperimentMulti):
#     def execute(self, data):
#         raise NotImplementedError