Splitting Strategy

Splitting of Dataset (payn.Splitting.DataSplitting)

Provides reproducible mechanisms for partitioning data into training, validation, and test sets. It prevents data leakage by ensuring total isolation of indices across splits.

The following splitting strategies are implemented:

  • Random Split (K-Fold/Single): Shuffled random splitting (seeded from the config) for general-purpose evaluation. The splits are shuffled but not stratified. 5-fold splitting is the default in this work.
  • Scaffold Split (Leave-One-Group-Out): Partitions data based on molecular scaffolds (core structures). This evaluates the model's ability to generalize to structurally distinct families of molecules.
  • Butina Clustering Split: Uses RDKit's Butina algorithm to cluster molecules based on Tanimoto similarity (using fingerprints). Entire clusters are assigned to either train or test sets to ensure that the test set is chemically distinct from the training set, enforcing a rigorous test of out-of-domain generalization.
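
The Tanimoto distances that drive the Butina clustering can be illustrated without RDKit. The sketch below is a pure-Python stand-in, not the library's implementation: `tanimoto_similarity` and `condensed_distances` are hypothetical helper names, the fingerprints are made-up 0/1 lists, and the handling of two all-zero fingerprints is a convention chosen for this sketch. It builds the lower-triangle distance list (for each molecule, the distances to all earlier molecules), which is the layout the source passes to the clustering step.

```python
from typing import List, Sequence


def tanimoto_similarity(a: Sequence[int], b: Sequence[int]) -> float:
    """Tanimoto similarity of two equal-length 0/1 fingerprints."""
    both = sum(1 for x, y in zip(a, b) if x == 1 and y == 1)
    either = sum(1 for x, y in zip(a, b) if x == 1 or y == 1)
    # Convention for this sketch: two all-zero fingerprints count as identical.
    return both / either if either else 1.0


def condensed_distances(fps: List[Sequence[int]]) -> List[float]:
    """Lower-triangle distances: for each i >= 1, distances to fingerprints 0..i-1."""
    dists: List[float] = []
    for i in range(1, len(fps)):
        dists.extend(1.0 - tanimoto_similarity(fps[i], fps[j]) for j in range(i))
    return dists


# Three toy fingerprints: the first two share bits, the third is unrelated to the first.
fps = [
    [1, 1, 0, 0],
    [1, 1, 1, 0],
    [0, 0, 1, 1],
]
dists = condensed_distances(fps)
print(dists)
```

A distance of 0 means identical fingerprints and 1 means no shared on-bits, so a smaller clustering cutoff yields tighter, more numerous clusters.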

All splitting methods integrate with the payn.DataSchema validators: when a schema is passed to a splitting method, each split is checked for integrity (conservation of row counts, mutual exclusivity of indices) before any training occurs.
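
What conservation of row counts and mutual exclusivity of indices amount to can be sketched with plain Python sets. This is an illustration under assumptions, not payn's validator code: `check_split_integrity` is a hypothetical helper, and row identifiers are represented as plain integers.

```python
from typing import Hashable, Iterable, List


def check_split_integrity(all_ids: Iterable[Hashable], *splits: Iterable[Hashable]) -> None:
    """Raise ValueError if the splits overlap or do not cover the input exactly."""
    all_ids = list(all_ids)
    split_lists: List[list] = [list(s) for s in splits]

    # Mutual exclusivity: no identifier may appear in two different splits.
    seen = set()
    for ids in split_lists:
        overlap = seen.intersection(ids)
        if overlap:
            raise ValueError(f"Leakage: {sorted(overlap)} appear in more than one split.")
        seen.update(ids)

    # Conservation: together, the splits hold exactly the input rows.
    total = sum(len(ids) for ids in split_lists)
    if total != len(all_ids) or seen != set(all_ids):
        raise ValueError("Row counts are not conserved across splits.")


# A clean 3/1/2 partition of six rows passes silently.
check_split_integrity(range(6), [0, 1, 2], [3], [4, 5])
```

Passing a split that reuses an identifier, for example `[2, 3]` as the validation set above, raises `ValueError`.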

Class for performing data splitting operations.

Supports multiple splitting strategies including:

  • K-Fold Cross-Validation
  • Random Train/Validation/Test Split
  • Scaffold-based (Leave-One-Group-Out) Split
  • Butina Clustering Split (Chemical Similarity)

Attributes:

Name               Type               Description
data               DataFrame          The full dataset to be split.
validation_size    Optional[float]    Fraction of the training data to use for validation.
test_size          Optional[float]    Fraction of the data to use for testing.
random_state       int                Seed for reproducibility.
logger             Optional[Logger]   Logger instance for tracking split stats.
application_mode   str                Usage mode ('training' or 'inference').
n_splits           Optional[int]      Default number of folds for K-Fold.
meta_column_name   str                Column name used to store split labels (e.g., 'role').
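
Because validation_size is a fraction of what remains after the test rows are removed, the two fractions compose rather than add. A small arithmetic sketch, assuming a single random split and ignoring integer rounding of row counts (`effective_fractions` is a hypothetical helper, not part of payn):

```python
def effective_fractions(test_size: float, validation_size: float) -> dict:
    """Share of the full dataset that lands in each role."""
    remaining = 1.0 - test_size  # what is left after the test rows are removed
    return {
        "train": remaining * (1.0 - validation_size),
        "validation": remaining * validation_size,
        "test": test_size,
    }


fractions = effective_fractions(test_size=0.2, validation_size=0.1)
print(fractions)  # roughly 72% train, 8% validation, 20% test
```

So a validation_size of 0.1 combined with a test_size of 0.2 reserves about 8% of all rows for validation, not 10%.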

Source code in payn\Splitting\splitting.py
class DataSplitting:
    """
    Class for performing data splitting operations.

    Supports multiple splitting strategies including:
    - K-Fold Cross-Validation
    - Random Train/Validation/Test Split
    - Scaffold-based (Leave-One-Group-Out) Split
    - Butina Clustering Split (Chemical Similarity)

    Attributes:
        data (pd.DataFrame): The full dataset to be split.
        validation_size (Optional[float]): Fraction of the training data to use for validation.
        test_size (Optional[float]): Fraction of the data to use for testing.
        random_state (int): Seed for reproducibility.
        logger (Optional[Logger]): Logger instance for tracking split stats.
        application_mode (str): Usage mode ('training' or 'inference').
        n_splits (Optional[int]): Default number of folds for K-Fold.
        meta_column_name (str): Column name used to store split labels (e.g., 'role').
    """

    def __init__(self, data: pd.DataFrame, validation_size: float = None, test_size: float = None, random_state: int = 42,
                 logger: Logger = None, application_mode: str="training", n_splits: int = None, meta_column_name: str = "role") -> None:
        """
        Initialize the DataSplitting class.

        Args:
            data (pd.DataFrame): Dataframe to be split.
            validation_size (float): Default proportion of validation data within training data.
            test_size (float): Default proportion of data reserved for testing.
            random_state (int): Random state for reproducibility.
            logger (Logger): Instance of Logger class for logging purposes.
            application_mode (str): Context of usage ('training', 'inference').
            n_splits (int): Default number of folds for cross-validation.
            meta_column_name (str): Name of the column to store split roles.
        """
        self.data = data.copy()
        self.validation_size = validation_size
        self.test_size = test_size
        self.random_state = random_state
        self.logger = logger
        self.application_mode = application_mode
        # Optional parameters for kfold splitting and train-test-val split
        self.n_splits = n_splits
        self.meta_column_name = meta_column_name

    @classmethod
    def from_config(cls, config: dict, data: pd.DataFrame, logger: Optional[Logger] = None) -> "DataSplitting":
        """
        Initialize the DataSplitting class from a config object.

        Args:
            config (dict): Configuration dictionary containing splitting parameters.
            data (pd.DataFrame): Data to be split.
            logger (Logger, optional): Instance of Logger class for logging purposes.

        Returns:
            DataSplitting: Instance of the DataSplitting class.
        """
        return cls(
            data=data,
            validation_size= config["splitting"]["validation_size"],
            test_size=config["splitting"]["test_size"],
            random_state=config["general"]["random_seed"],
            application_mode = config["general"]["usage_mode"],
            logger=logger,
            n_splits=config["splitting"]["cross_validation_folds"],
            meta_column_name=config["meta_columns"]["meta_data_point_role"]
        )


    def kfold_train_val_test_split(self, n_splits: int = None, meta_column_name: str = None,
                                   application_mode: str = None,  schema: DataSchema = None) -> Generator[Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame], None, None]:
        """
        Perform K-Fold train-validation-test splits for cross-validation.
        Splits the data into N folds. In each iteration, one fold is the test set,
        and the remaining data is split into training and validation sets based
        on `self.validation_size`.

        Args:
            n_splits (int, optional): Number of folds for cross-validation. Defaults to `self.n_splits`.
            meta_column_name (str, optional): Name of the column used to label the role of each datapoint.
            application_mode (str, optional): Mode passed to schema validation. Defaults to `self.application_mode`.
            schema (DataSchema, optional): Schema to validate correct splitting into training, validation and test set.

        Yields:
            Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple (train_data, val_data, test_data) for each fold.
        """
        n_splits = n_splits or self.n_splits
        meta_column_name = meta_column_name or self.meta_column_name
        application_mode = application_mode or self.application_mode


        kf = KFold(n_splits=n_splits, shuffle=True, random_state=self.random_state)

        for fold_index, (train_val_idx, test_idx) in enumerate(kf.split(self.data)):
            # Copy each split so that labelling it never mutates self.data.
            train_val_data = self.data.iloc[train_val_idx].copy()
            test_data = self.data.iloc[test_idx].copy()

            # Further split train_val_data into train and validation sets.
            train_data, val_data = train_test_split(
                train_val_data,
                test_size=self.validation_size,
                random_state=self.random_state
            )
            # Label the datapoint role for traceability.
            test_data[meta_column_name] = f"test_fold_{fold_index}"
            train_data[meta_column_name] = f"train_fold_{fold_index}"
            val_data[meta_column_name] = f"validation_fold_{fold_index}"

            # Embedded schema validation if a schema is provided
            if schema:
                for df, name in zip([train_data, val_data, test_data], ["train", "validation", "test"]):
                    validate_dataframe(df, schema, mode=application_mode)
                    self.logger and self.logger.log_message(
                        f"Fold {fold_index}: {name} split validated against schema.")

                # Leakage checks (index sets must be disjoint); these run only when a schema is given.
                verify_no_leakage(train=train_data, val=val_data, test=test_data, fold_index=fold_index, meta_column_name=meta_column_name)
                validate_split_integrity(input_dfs=[self.data], output_dfs=[train_data, val_data, test_data])

            # Optionally log the datasets.
            if self.logger:
                self.logger.log_fold_data(train_data, val_data, test_data, fold_index)

            yield train_data, val_data, test_data


    def kfold_train_val_scaffold_test_split(self, column_for_scaffolds: str, val_size: float = None, meta_column_name: str = None,
                                             schema: DataSchema = None) -> Generator[Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame], None, None]:
        """
        Perform Leave-One-Group-Out cross-validation based on a scaffold column.

        Iterates through each unique value (scaffold) in `column_for_scaffolds`.
        In each fold, all rows with that scaffold become the test set.

        Args:
            val_size (float): Optional override for the validation set fraction.
            meta_column_name (str): Optional override for the metadata column name.
            column_for_scaffolds (str): The column defining the groups/scaffolds.
            schema (DataSchema): DataSchema object to validate the splits against.

        Yields:
            A tuple (train_data, val_data, test_data) for each scaffold fold.

        Raises:
            ValueError: If the scaffold column is missing.
        """
        val_size = val_size or self.validation_size
        meta_column_name = meta_column_name or self.meta_column_name

        if column_for_scaffolds not in self.data.columns:
            raise ValueError(f"Scaffold column '{column_for_scaffolds}' not found in data.")

        scaffolds = self.data[column_for_scaffolds].unique()
        n_splits = len(scaffolds)

        for fold_index, scaffold_group in enumerate(scaffolds):
            self.logger and self.logger.log_message(
                f"Processing scaffold fold {fold_index + 1}/{n_splits} (Group: {scaffold_group})")

            # Identify index labels for the current test set (all rows with this scaffold)
            test_idx = self.data.index[self.data[column_for_scaffolds] == scaffold_group]
            # All other index labels form the training and validation pool
            train_val_idx = self.data.index.difference(test_idx)

            # Select by label (.loc): both index objects hold labels, not positions,
            # so .iloc would only be correct for a default RangeIndex.
            train_val_data = self.data.loc[train_val_idx].copy()
            test_data = self.data.loc[test_idx].copy()

            train_data, val_data = train_test_split(
                train_val_data, test_size=val_size, random_state=self.random_state
            )

            # Labeling, Validation, and Logging (from SOP)
            test_data[meta_column_name] = f"test_fold_{fold_index}"
            train_data[meta_column_name] = f"train_fold_{fold_index}"
            val_data[meta_column_name] = f"validation_fold_{fold_index}"

            if schema:
                for df, name in zip([train_data, val_data, test_data], ["train", "validation", "test"]):
                    validate_dataframe(df, schema, mode="training")
                    self.logger and self.logger.log_message(f"Fold {fold_index}: {name} split validated against schema.")

                verify_no_leakage(train_data, val_data, test_data, fold_index=fold_index, meta_column_name=meta_column_name)
                validate_split_integrity(input_dfs=[self.data], output_dfs=[train_data, val_data, test_data])

            # Optionally log the splits.
            if self.logger:
                self.logger.log_fold_data(train_data, val_data, test_data, fold_index=fold_index)

            yield train_data, val_data, test_data


    def train_val_random_test_split(self, test_size: float = None, val_size: float = None,
                                    meta_column_name: str = None, schema: DataSchema = None) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Generates a single, validated train-validation-test split randomly.
        Args:
            test_size (float, optional): Proportion of the dataset to be allocated to the test split.
            val_size (float, optional): Proportion of the training set (after removing the test split) to be used as validation.
            meta_column_name (str, optional): Name of the column used to label datapoint roles.
            schema (DataSchema, optional): Schema to validate correct splitting into training, validation and test set.

        Returns:
        A tuple containing the (train_data, val_data, test_data) DataFrames.
        """
        meta_column_name = meta_column_name or self.meta_column_name
        val_size = val_size or self.validation_size
        test_size = test_size or self.test_size

        # Perform the splits, creating safe copies
        train_val_data, test_data = train_test_split(
            self.data,
            test_size=test_size,
            random_state=self.random_state
        )
        train_val_data = train_val_data.copy()
        test_data = test_data.copy()

        # The validation size is a fraction of the remaining train_val set.
        train_data, val_data = train_test_split(
            train_val_data,
            test_size=val_size,
            random_state=self.random_state
        )
        train_data = train_data.copy()
        val_data = val_data.copy()

        train_data[meta_column_name] = "train"
        val_data[meta_column_name] = "validation"
        test_data[meta_column_name] = "test"

        if schema:
            for df, name in zip([train_data, val_data, test_data], ["train", "validation", "test"]):
                validate_dataframe(df, schema, mode="training")
                self.logger and self.logger.log_message(f"Single split: {name} set validated against schema.")

            verify_no_leakage(train_data, val_data, test_data, fold_index=-1, meta_column_name=meta_column_name)
            validate_split_integrity(input_dfs=[self.data], output_dfs=[train_data, val_data, test_data])

        # Optionally log the splits.
        if self.logger:
            self.logger.log_fold_data(train_data, val_data, test_data, fold_index=-1)

        return train_data, val_data, test_data



    def train_val_butina_test_split(self, fp_clustering_column: str, cutoff: float, test_size: float = None, val_size: float = None, meta_column_name: str = None,
                                    schema: DataSchema = None) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Generates a single train-val-test split using Butina clustering on pre-computed fingerprints.

            This method partitions the dataset into a training/validation set and a
            chemically distinct test set based on the similarity of existing feature vectors.
            It ensures that entire clusters are kept within the same set.

            Reference:
            Based on the deepchem implementation of Butina splitting.

            Args:
                schema (DataSchema): An optional DataSchema object to validate the splits against.
                test_size (float): Optional override for the test set fraction.
                val_size (float): Optional override for the validation set fraction.
                meta_column_name (str): Optional override for the metadata column name.
                fp_clustering_column (str): Column name containing pre-computed fingerprints for clustering.
                cutoff (float): Butina clustering cutoff.

            Returns:
                A tuple containing the (train_data, val_data, test_data) DataFrames.

            Raises:
                ValueError: If the fingerprint column is missing or split results in empty sets.
        """

        test_size = test_size or self.test_size
        val_size = val_size or self.validation_size
        meta_column_name = meta_column_name or self.meta_column_name

        if fp_clustering_column not in self.data.columns:
            raise ValueError(f"Fingerprint column '{fp_clustering_column}' not found for Butina splitting.")

        # Convert pre-computed fingerprints into RDKit BitVect objects necessary for the Tanimoto similarity calculation.
        fps_as_lists = self.data[fp_clustering_column].tolist()
        bit_vects = []
        for fp_list in fps_as_lists:
            bit_vector = DataStructs.ExplicitBitVect(len(fp_list))
            on_bits = [i for i, bit in enumerate(fp_list) if bit == 1]
            bit_vector.SetBitsFromList(on_bits)
            bit_vects.append(bit_vector)

        # Perform Butina clustering
        dists = []
        num_fps = len(bit_vects)
        for i in range(1, num_fps):
            sims = DataStructs.BulkTanimotoSimilarity(bit_vects[i], bit_vects[:i])
            dists.extend([1 - x for x in sims])

        clusters = Butina.ClusterData(data=dists, nPts=num_fps, distThresh=cutoff, isDistData=True)
        clusters = sorted(clusters, key=len, reverse=True)

        # Use the resolved test_size so that a per-call override is honoured.
        intended_train_val_size = int(len(self.data) * (1 - test_size))
        # Assign clusters to train/val and test sets
        train_val_indices, test_indices = [], []

        for cluster in clusters:
            if len(train_val_indices) + len(cluster) <= intended_train_val_size:
                train_val_indices.extend(cluster)
            else:
                test_indices.extend(cluster)

        train_val_data = self.data.iloc[train_val_indices].copy()
        test_data = self.data.iloc[test_indices].copy()

        # Perform a standard random split of the train/val pool into train and validation sets
        train_data, val_data = train_test_split(
            train_val_data, test_size=val_size, random_state=self.random_state
        )
        train_data = train_data.copy()
        val_data = val_data.copy()

        train_data[meta_column_name] = "train"
        val_data[meta_column_name] = "validation"
        test_data[meta_column_name] = "test"

        if train_data.empty or val_data.empty or test_data.empty:
            raise ValueError(
                "A data split resulted in one or more empty DataFrames. "
                "This can happen with cluster-based splits on small or highly clustered datasets. "
                "Please check your dataset, test_size, and Butina cutoff."
            )

        if schema:
            for df, name in zip([train_data, val_data, test_data], ["train", "validation", "test"]):
                validate_dataframe(df, schema, mode="training")
                self.logger and self.logger.log_message(f"Single split: {name} set validated against schema.")

            verify_no_leakage(train_data, val_data, test_data, fold_index=-1, meta_column_name=meta_column_name)
            validate_split_integrity(input_dfs=[self.data], output_dfs=[train_data, val_data, test_data])

        # Optionally log the splits.
        if self.logger:
            self.logger.log_fold_data(train_data, val_data, test_data, fold_index=-1)

        return train_data, val_data, test_data

__init__(data, validation_size=None, test_size=None, random_state=42, logger=None, application_mode='training', n_splits=None, meta_column_name='role')

Initialize the DataSplitting class.

Parameters:

Name               Type        Default      Description
data               DataFrame   required     Dataframe to be split.
validation_size    float       None         Default proportion of validation data within training data.
test_size          float       None         Default proportion of data reserved for testing.
random_state       int         42           Random state for reproducibility.
logger             Logger      None         Instance of Logger class for logging purposes.
application_mode   str         'training'   Context of usage ('training', 'inference').
n_splits           int         None         Default number of folds for cross-validation.
meta_column_name   str         'role'       Name of the column to store split roles.

Source code in payn\Splitting\splitting.py
def __init__(self, data: pd.DataFrame, validation_size: float = None, test_size: float = None, random_state: int = 42,
             logger: Logger = None, application_mode: str="training", n_splits: int = None, meta_column_name: str = "role") -> None:
    """
    Initialize the DataSplitting class.

    Args:
        data (pd.DataFrame): Dataframe to be split.
        validation_size (float): Default proportion of validation data within training data.
        test_size (float): Default proportion of data reserved for testing.
        random_state (int): Random state for reproducibility.
        logger (Logger): Instance of Logger class for logging purposes.
        application_mode (str): Context of usage ('training', 'inference').
        n_splits (int): Default number of folds for cross-validation.
        meta_column_name (str): Name of the column to store split roles.
    """
    self.data = data.copy()
    self.validation_size = validation_size
    self.test_size = test_size
    self.random_state = random_state
    self.logger = logger
    self.application_mode = application_mode
    # Optional parameters for kfold splitting and train-test-val split
    self.n_splits = n_splits
    self.meta_column_name = meta_column_name

from_config(config, data, logger=None) classmethod

Initialize the DataSplitting class from a config object.

Parameters:

Name     Type        Default    Description
config   dict        required   Configuration dictionary containing splitting parameters.
data     DataFrame   required   Data to be split.
logger   Logger      None       Instance of Logger class for logging purposes.

Returns:

Name            Type            Description
DataSplitting   DataSplitting   Instance of the DataSplitting class.
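
The shape of the configuration that from_config reads can be sketched as a nested dictionary. The key names below are the ones accessed in the source; the values are illustrative only, and `REQUIRED_KEYS` and `missing_config_keys` are hypothetical helpers for checking a config before use, not part of payn.

```python
# Illustrative values only; the key names are the ones from_config reads.
config = {
    "general": {"random_seed": 42, "usage_mode": "training"},
    "splitting": {
        "validation_size": 0.1,
        "test_size": 0.2,
        "cross_validation_folds": 5,
    },
    "meta_columns": {"meta_data_point_role": "role"},
}

# (section, key) pairs that from_config looks up.
REQUIRED_KEYS = [
    ("splitting", "validation_size"),
    ("splitting", "test_size"),
    ("splitting", "cross_validation_folds"),
    ("general", "random_seed"),
    ("general", "usage_mode"),
    ("meta_columns", "meta_data_point_role"),
]


def missing_config_keys(cfg: dict) -> list:
    """Return the (section, key) pairs that are absent from cfg."""
    return [(section, key) for section, key in REQUIRED_KEYS if key not in cfg.get(section, {})]


print(missing_config_keys(config))  # → []
```

Since from_config indexes the dictionary directly, a missing key surfaces as a `KeyError`; a pre-check like this one gives a clearer message.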

Source code in payn\Splitting\splitting.py
@classmethod
def from_config(cls, config: dict, data: pd.DataFrame, logger: Optional[Logger] = None) -> "DataSplitting":
    """
    Initialize the DataSplitting class from a config object.

    Args:
        config (dict): Configuration dictionary containing splitting parameters.
        data (pd.DataFrame): Data to be split.
        logger (Logger, optional): Instance of Logger class for logging purposes.

    Returns:
        DataSplitting: Instance of the DataSplitting class.
    """
    return cls(
        data=data,
        validation_size= config["splitting"]["validation_size"],
        test_size=config["splitting"]["test_size"],
        random_state=config["general"]["random_seed"],
        application_mode = config["general"]["usage_mode"],
        logger=logger,
        n_splits=config["splitting"]["cross_validation_folds"],
        meta_column_name=config["meta_columns"]["meta_data_point_role"]
    )

kfold_train_val_scaffold_test_split(column_for_scaffolds, val_size=None, meta_column_name=None, schema=None)

Perform Leave-One-Group-Out cross-validation based on a scaffold column.

Iterates through each unique value (scaffold) in column_for_scaffolds. In each fold, all rows with that scaffold become the test set.
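
The leave-one-group-out iteration can be sketched in pure Python. This is a simplified stand-in that works on row positions rather than DataFrames; `leave_one_group_out` is a hypothetical name and the scaffold labels are made up.

```python
from typing import Hashable, Iterator, List, Sequence, Tuple


def leave_one_group_out(
    groups: Sequence[Hashable],
) -> Iterator[Tuple[Hashable, List[int], List[int]]]:
    """Yield (group, train_val_positions, test_positions), one fold per unique group."""
    # dict.fromkeys keeps the unique labels in first-seen order.
    for group in dict.fromkeys(groups):
        test_pos = [i for i, g in enumerate(groups) if g == group]
        train_val_pos = [i for i, g in enumerate(groups) if g != group]
        yield group, train_val_pos, test_pos


scaffolds = ["benzene", "pyridine", "benzene", "indole", "pyridine"]
folds = list(leave_one_group_out(scaffolds))
for group, train_val_pos, test_pos in folds:
    print(group, train_val_pos, test_pos)
```

The number of folds equals the number of unique scaffolds, so a dataset with many rare scaffolds produces many folds with very small test sets.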

Parameters:

Name                   Type         Default    Description
val_size               float        None       Optional override for the validation set fraction.
meta_column_name       str          None       Optional override for the metadata column name.
column_for_scaffolds   str          required   The column defining the groups/scaffolds.
schema                 DataSchema   None       DataSchema object to validate the splits against.

Yields:

Type                                     Description
Tuple[DataFrame, DataFrame, DataFrame]   A tuple (train_data, val_data, test_data) for each scaffold fold.

Raises:

Type         Description
ValueError   If the scaffold column is missing.

Source code in payn\Splitting\splitting.py
def kfold_train_val_scaffold_test_split(self, column_for_scaffolds: str, val_size: float = None, meta_column_name: str = None,
                                         schema: DataSchema = None) -> Generator[Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame], None, None]:
    """
    Perform Leave-One-Group-Out cross-validation based on a scaffold column.

    Iterates through each unique value (scaffold) in `column_for_scaffolds`.
    In each fold, all rows with that scaffold become the test set.

    Args:
        val_size (float): Optional override for the validation set fraction.
        meta_column_name (str): Optional override for the metadata column name.
        column_for_scaffolds (str): The column defining the groups/scaffolds.
        schema (DataSchema): DataSchema object to validate the splits against.

    Yields:
        A tuple (train_data, val_data, test_data) for each scaffold fold.

    Raises:
        ValueError: If the scaffold column is missing.
    """
    val_size = val_size or self.validation_size
    meta_column_name = meta_column_name or self.meta_column_name

    if column_for_scaffolds not in self.data.columns:
        raise ValueError(f"Scaffold column '{column_for_scaffolds}' not found in data.")

    scaffolds = self.data[column_for_scaffolds].unique()
    n_splits = len(scaffolds)

    for fold_index, scaffold_group in enumerate(scaffolds):
        self.logger and self.logger.log_message(
            f"Processing scaffold fold {fold_index + 1}/{n_splits} (Group: {scaffold_group})")

        # Identify index labels for the current test set (all rows with this scaffold)
        test_idx = self.data.index[self.data[column_for_scaffolds] == scaffold_group]
        # All other index labels form the training and validation pool
        train_val_idx = self.data.index.difference(test_idx)

        # Select by label (.loc): both index objects hold labels, not positions,
        # so .iloc would only be correct for a default RangeIndex.
        train_val_data = self.data.loc[train_val_idx].copy()
        test_data = self.data.loc[test_idx].copy()

        train_data, val_data = train_test_split(
            train_val_data, test_size=val_size, random_state=self.random_state
        )

        # Labeling, Validation, and Logging (from SOP)
        test_data[meta_column_name] = f"test_fold_{fold_index}"
        train_data[meta_column_name] = f"train_fold_{fold_index}"
        val_data[meta_column_name] = f"validation_fold_{fold_index}"

        if schema:
            for df, name in zip([train_data, val_data, test_data], ["train", "validation", "test"]):
                validate_dataframe(df, schema, mode="training")
                self.logger and self.logger.log_message(f"Fold {fold_index}: {name} split validated against schema.")

            verify_no_leakage(train_data, val_data, test_data, fold_index=fold_index, meta_column_name=meta_column_name)
            validate_split_integrity(input_dfs=[self.data], output_dfs=[train_data, val_data, test_data])

        # Optionally log the splits.
        if self.logger:
            self.logger.log_fold_data(train_data, val_data, test_data, fold_index=fold_index)

        yield train_data, val_data, test_data

kfold_train_val_test_split(n_splits=None, meta_column_name=None, application_mode=None, schema=None)

Perform K-Fold train-validation-test splits for cross-validation. Splits the data into N folds. In each iteration, one fold is the test set, and the remaining data is split into training and validation sets based on self.validation_size.
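
The fold rotation can be sketched in pure Python. This is a simplified stand-in, not scikit-learn's KFold: it shuffles row positions with a seeded `random.Random` and deals them into folds round-robin, whereas the source delegates to `KFold(shuffle=True)`. `kfold_positions` is a hypothetical name.

```python
import random
from typing import Iterator, List, Tuple


def kfold_positions(n_rows: int, n_splits: int, seed: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_val_positions, test_positions) once per fold."""
    positions = list(range(n_rows))
    random.Random(seed).shuffle(positions)  # seeded, so the folds are reproducible
    # Deal shuffled positions round-robin; fold sizes differ by at most one.
    folds = [positions[i::n_splits] for i in range(n_splits)]
    for i, test_pos in enumerate(folds):
        train_val_pos = [p for j, fold in enumerate(folds) if j != i for p in fold]
        yield train_val_pos, test_pos


folds = list(kfold_positions(n_rows=10, n_splits=5, seed=42))
for fold_index, (train_val_pos, test_pos) in enumerate(folds):
    print(f"test_fold_{fold_index}", sorted(test_pos))
```

Every row serves as test data in exactly one fold; the train/validation pool of each fold is then split again according to validation_size.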

Parameters:

Name               Type         Default   Description
n_splits           int          None      Number of folds for cross-validation. Falls back to self.n_splits when None.
meta_column_name   str          None      Name of the column used to label the role of each datapoint.
application_mode   str          None      Mode passed to schema validation. Falls back to self.application_mode when None.
schema             DataSchema   None      Schema to validate correct splitting into training, validation and test set.

Yields:

Type                                     Description
Tuple[DataFrame, DataFrame, DataFrame]   A tuple (train_data, val_data, test_data) for each fold.

Source code in payn\Splitting\splitting.py
def kfold_train_val_test_split(self, n_splits: int = None, meta_column_name: str = None,
                               application_mode: str = None,  schema: DataSchema = None) -> Generator[Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame], None, None]:
    """
    Perform K-Fold train-validation-test splits for cross-validation.
    Splits the data into N folds. In each iteration, one fold is the test set,
    and the remaining data is split into training and validation sets based
    on `self.validation_size`.

    Args:
        n_splits (int, optional): Number of folds for cross-validation. Defaults to `self.n_splits`.
        meta_column_name (str, optional): Name of the column used to label the role of each datapoint.
        application_mode (str, optional): Mode passed to schema validation. Defaults to `self.application_mode`.
        schema (DataSchema, optional): Schema to validate correct splitting into training, validation and test set.

    Yields:
        Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple (train_data, val_data, test_data) for each fold.
    """
    n_splits = n_splits or self.n_splits
    meta_column_name = meta_column_name or self.meta_column_name
    application_mode = application_mode or self.application_mode


    kf = KFold(n_splits=n_splits, shuffle=True, random_state=self.random_state)
    splits = []

    for fold_index, (train_val_idx, test_idx) in enumerate(kf.split(self.data)):
        # Create copies for each split to ensure immutability.
        train_val_data = self.data.iloc[train_val_idx].copy()
        test_data = self.data.iloc[test_idx].copy()

        # Further split train_val_data into train and validation sets.
        train_data, val_data = train_test_split(
            train_val_data,
            test_size=self.validation_size,
            random_state=self.random_state
        )
        # Label the datapoint role for traceability.
        test_data[meta_column_name] = f"test_fold_{fold_index}"
        train_data[meta_column_name] = f"train_fold_{fold_index}"
        val_data[meta_column_name] = f"validation_fold_{fold_index}"

        # Embedded schema validation if a schema is provided
        if schema:
            for df, name in zip([train_data, val_data, test_data], ["train", "validation", "test"]):
                validate_dataframe(df, schema, mode=application_mode)
                self.logger and self.logger.log_message(
                    f"Fold {fold_index}: {name} split validated against schema.")

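            # Perform leakage checks (ensuring index sets are disjoint).
            # Note: these checks sit inside the `if schema:` block, so they only run when a schema is provided.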
            verify_no_leakage(train=train_data, val=val_data, test=test_data, fold_index=fold_index, meta_column_name=meta_column_name)
            validate_split_integrity(input_dfs=[self.data], output_dfs=[train_data, val_data, test_data])

        splits.append((train_data, val_data, test_data))

        # Optionally log the datasets.
        if self.logger:
            self.logger.log_fold_data(train_data, val_data, test_data, fold_index)

        yield train_data, val_data, test_data
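
The helpers verify_no_leakage and validate_split_integrity are called in the listing above but are not shown on this page. As a rough illustration of the two properties they are described as checking (mutual exclusivity of indices and conservation of row counts), here is a minimal pure-Python sketch. The function find_split_problems is hypothetical and written only for this example; it is not part of payn and may differ from the real validators.

```python
from itertools import combinations
from typing import Dict, List, Sequence


def find_split_problems(original: Sequence[int], splits: Dict[str, Sequence[int]]) -> List[str]:
    """Hypothetical check: return human-readable problems (empty list if none)."""
    problems = []

    # Mutual exclusivity: no index label may appear in two different splits.
    for (name_a, idx_a), (name_b, idx_b) in combinations(splits.items(), 2):
        overlap = set(idx_a) & set(idx_b)
        if overlap:
            problems.append(f"{name_a}/{name_b} share indices {sorted(overlap)}")

    # Conservation: the splits together must contain as many rows as the input.
    total = sum(len(idx) for idx in splits.values())
    if total != len(original):
        problems.append(f"row count changed: {len(original)} -> {total}")

    return problems


clean = find_split_problems(range(6), {"train": [0, 1, 2], "validation": [3], "test": [4, 5]})
leaky = find_split_problems(range(6), {"train": [0, 1, 2], "validation": [2, 3], "test": [4, 5]})
print(clean)   # no problems for a proper partition
print(leaky)   # index 2 is shared, and the row total no longer matches
```

Returning a list of problems rather than raising on the first one is a choice made for this sketch, so that a single call reports every violation.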

train_val_butina_test_split(fp_clustering_column, cutoff, test_size=None, val_size=None, meta_column_name=None, schema=None)

Generates a single train-val-test split using Butina clustering on pre-computed fingerprints.

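This method partitions the dataset into a training/validation set and a chemically distinct test set based on the similarity of existing feature vectors. Each Butina cluster is kept intact on one side of that boundary, so no cluster is shared between the training/validation set and the test set. The training/validation set is then split randomly.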

Reference: Based on the deepchem implementation of Butina splitting.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| schema | DataSchema | An optional DataSchema object to validate the splits against. | None |
| test_size | float | Optional override for the test set fraction. | None |
| val_size | float | Optional override for the validation set fraction. | None |
| meta_column_name | str | Optional override for the metadata column name. | None |
| fp_clustering_column | str | Column name containing pre-computed fingerprints for clustering. | required |
| cutoff | float | Butina clustering cutoff. | required |

Returns:

| Type | Description |
| --- | --- |
| Tuple[DataFrame, DataFrame, DataFrame] | A tuple containing the (train_data, val_data, test_data) DataFrames. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If the fingerprint column is missing or split results in empty sets. |
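
The clustering step works on pairwise Tanimoto distances between fingerprints, passed to the Butina algorithm as a flat lower-triangle list. The sketch below reproduces that layout in pure Python so it can be inspected without RDKit. The helper names (on_bits, tanimoto, condensed_distances) are illustrative and not part of payn, and treating two all-zero fingerprints as similarity 0.0 is an assumption of this sketch.

```python
from typing import List, Sequence, Set


def on_bits(fp: Sequence[int]) -> Set[int]:
    """Positions of the set bits in a 0/1 fingerprint."""
    return {i for i, bit in enumerate(fp) if bit == 1}


def tanimoto(a: Set[int], b: Set[int]) -> float:
    """Tanimoto similarity |a & b| / |a | b| (0.0 if both are empty: an assumption)."""
    union = len(a | b)
    return len(a & b) / union if union else 0.0


def condensed_distances(fps: Sequence[Sequence[int]]) -> List[float]:
    """Distances in lower-triangle order: d(1,0), d(2,0), d(2,1), d(3,0), ..."""
    bits = [on_bits(fp) for fp in fps]
    dists = []
    for i in range(1, len(bits)):
        for j in range(i):
            dists.append(1.0 - tanimoto(bits[i], bits[j]))
    return dists


fingerprints = [
    [1, 1, 0, 0],  # identical to the next fingerprint
    [1, 1, 0, 0],
    [0, 0, 1, 1],  # shares no bits with the first two
]
distances = condensed_distances(fingerprints)
print(distances)  # [0.0, 1.0, 1.0]
```

For n fingerprints the list holds n(n-1)/2 entries, so memory and time grow quadratically with dataset size.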

Source code in payn\Splitting\splitting.py
def train_val_butina_test_split(self, fp_clustering_column: str, cutoff: float, test_size: float = None, val_size: float = None, meta_column_name: str = None,
                                schema: DataSchema = None) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Generates a single train-val-test split using Butina clustering on pre-computed fingerprints.

    This method partitions the dataset into a training/validation set and a
    chemically distinct test set based on the similarity of existing feature vectors.
    Each Butina cluster is kept intact on one side of that boundary; the
    training/validation set is then split randomly.

    Reference:
        Based on the deepchem implementation of Butina splitting.

    Args:
        schema (DataSchema): An optional DataSchema object to validate the splits against.
        test_size (float): Optional override for the test set fraction.
        val_size (float): Optional override for the validation set fraction.
        meta_column_name (str): Optional override for the metadata column name.
        fp_clustering_column (str): Column name containing pre-computed fingerprints for clustering.
        cutoff (float): Butina clustering cutoff, applied as a Tanimoto distance threshold.

    Returns:
        A tuple containing the (train_data, val_data, test_data) DataFrames.

    Raises:
        ValueError: If the fingerprint column is missing or split results in empty sets.
    """

    test_size = test_size or self.test_size
    val_size = val_size or self.validation_size
    meta_column_name = meta_column_name or self.meta_column_name

    if fp_clustering_column not in self.data.columns:
        raise ValueError(f"Fingerprint column '{fp_clustering_column}' not found for Butina splitting.")

    # Convert pre-computed fingerprints into RDKit BitVect objects necessary for the Tanimoto similarity calculation.
    fps_as_lists = self.data[fp_clustering_column].tolist()
    bit_vects = []
    for fp_list in fps_as_lists:
        bit_vector = DataStructs.ExplicitBitVect(len(fp_list))
        on_bits = [i for i, bit in enumerate(fp_list) if bit == 1]
        bit_vector.SetBitsFromList(on_bits)
        bit_vects.append(bit_vector)

    # Perform Butina clustering
    dists = []
    num_fps = len(bit_vects)
    for i in range(1, num_fps):
        sims = DataStructs.BulkTanimotoSimilarity(bit_vects[i], bit_vects[:i])
        dists.extend([1 - x for x in sims])

    clusters = Butina.ClusterData(data=dists, nPts=num_fps, distThresh=cutoff, isDistData=True)
    clusters = sorted(clusters, key=len, reverse=True)

    # Use the resolved test_size so that a caller-supplied override is honoured.
    intended_train_val_size = int(len(self.data) * (1 - test_size))
    # Assign clusters to train/val and test sets
    train_val_indices, test_indices = [], []

    for cluster in clusters:
        if len(train_val_indices) + len(cluster) <= intended_train_val_size:
            train_val_indices.extend(cluster)
        else:
            test_indices.extend(cluster)

    train_val_data = self.data.iloc[train_val_indices].copy()
    test_data = self.data.iloc[test_indices].copy()

    # Perform a standard random split of train_val_data into train and validation sets
    train_data, val_data = train_test_split(
        train_val_data, test_size=val_size, random_state=self.random_state
    )
    train_data = train_data.copy()
    val_data = val_data.copy()

    train_data[meta_column_name] = "train"
    val_data[meta_column_name] = "validation"
    test_data[meta_column_name] = "test"

    if train_data.empty or val_data.empty or test_data.empty:
        raise ValueError(
            "A data split resulted in one or more empty DataFrames. "
            "This can happen with Butina splits on small or highly clustered datasets. "
            "Please check your dataset, test_size, and Butina cutoff."
        )

    if schema:
        for df, name in zip([train_data, val_data, test_data], ["train", "validation", "test"]):
            validate_dataframe(df, schema, mode="training")
            self.logger and self.logger.log_message(f"Single split: {name} set validated against schema.")

        verify_no_leakage(train_data, val_data, test_data, fold_index=-1, meta_column_name=meta_column_name)
        validate_split_integrity(input_dfs=[self.data], output_dfs=[train_data, val_data, test_data])

    # Optionally log the splits.
    if self.logger:
        self.logger.log_fold_data(train_data, val_data, test_data, fold_index=-1)

    return train_data, val_data, test_data
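
The greedy cluster assignment in the listing above can be exercised in isolation. The sketch below mirrors that loop on made-up cluster tuples. The function name assign_clusters and the example clusters are hypothetical; real clusters would come from the Butina clustering step.

```python
from typing import List, Sequence, Tuple


def assign_clusters(clusters: Sequence[Sequence[int]], n_rows: int,
                    test_size: float) -> Tuple[List[int], List[int]]:
    """Fill train/val with the largest clusters first; overflow goes to test."""
    ordered = sorted(clusters, key=len, reverse=True)
    budget = int(n_rows * (1 - test_size))  # intended train/val row count
    train_val: List[int] = []
    test: List[int] = []
    for cluster in ordered:
        if len(train_val) + len(cluster) <= budget:
            train_val.extend(cluster)   # the whole cluster still fits
        else:
            test.extend(cluster)        # the whole cluster goes to test
    return train_val, test


# Made-up clusters of positional row indices (sizes 5, 3, 3, 1; 12 rows total).
clusters = [(0, 1, 2, 3, 4), (5, 6, 7), (8, 9, 10), (11,)]
train_val, test = assign_clusters(clusters, n_rows=12, test_size=0.25)
print(train_val)  # [0, 1, 2, 3, 4, 5, 6, 7, 11]
print(test)       # [8, 9, 10]
```

Note that a cluster that does not fit is sent to the test set, but smaller clusters later in the order can still be added to train/val, as the singleton (11,) is here. The resulting test fraction therefore only approximates test_size and depends on the cluster size distribution.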

train_val_random_test_split(test_size=None, val_size=None, meta_column_name=None, schema=None)

Generates a single, validated train-validation-test split randomly.

Args:

  • test_size (float, optional): Proportion of the dataset to be allocated to the test split.
  • val_size (float, optional): Proportion of the training set (after removing the test split) to be used as validation.
  • meta_column_name (str, optional): Name of the column used to label datapoint roles.
  • schema (DataSchema, optional): Schema to validate correct splitting into training, validation and test set.

Returns:

A tuple containing the (train_data, val_data, test_data) DataFrames.
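
Because val_size is applied to what remains after the test split, the validation share of the full dataset is smaller than val_size itself. The following sketch works through that arithmetic. Both helper names are hypothetical, and the row counts assume the held-out side of each split is rounded up, as scikit-learn's train_test_split is understood to do for fractional sizes.

```python
import math
from typing import Tuple


def effective_fractions(test_size: float, val_size: float) -> Tuple[float, float, float]:
    """Share of the full dataset ending up in (train, validation, test)."""
    train_val = 1.0 - test_size
    return train_val * (1.0 - val_size), train_val * val_size, test_size


def expected_counts(n_rows: int, test_size: float, val_size: float) -> Tuple[int, int, int]:
    """Approximate row counts, assuming the held-out part is rounded up."""
    n_test = math.ceil(test_size * n_rows)
    n_train_val = n_rows - n_test
    n_val = math.ceil(val_size * n_train_val)
    return n_train_val - n_val, n_val, n_test


fractions = effective_fractions(test_size=0.2, val_size=0.1)
counts = expected_counts(n_rows=1000, test_size=0.2, val_size=0.1)
print(fractions)  # roughly 72% train, 8% validation, 20% test
print(counts)     # (720, 80, 200)
```

With test_size=0.2 and val_size=0.1, validation is about 8% of all rows rather than 10%.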

Source code in payn\Splitting\splitting.py
def train_val_random_test_split(self, test_size: float = None, val_size: float = None,
                                meta_column_name: str = None, schema: DataSchema = None) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Generates a single, validated train-validation-test split randomly.

    Args:
        test_size (float, optional): Proportion of the dataset to be allocated to the test split.
        val_size (float, optional): Proportion of the training set (after removing the test split) to be used as validation.
        meta_column_name (str, optional): Name of the column used to label datapoint roles.
        schema (DataSchema, optional): Schema to validate correct splitting into training, validation and test set.

    Returns:
        A tuple containing the (train_data, val_data, test_data) DataFrames.
    """
    meta_column_name = meta_column_name or self.meta_column_name
    val_size = val_size or self.validation_size
    test_size = test_size or self.test_size

    # Perform the splits, creating safe copies
    train_val_data, test_data = train_test_split(
        self.data,
        test_size=test_size,
        random_state=self.random_state
    )
    train_val_data = train_val_data.copy()
    test_data = test_data.copy()

    # The validation size is a fraction of the remaining train_val set.
    train_data, val_data = train_test_split(
        train_val_data,
        test_size=val_size,
        random_state=self.random_state
    )
    train_data = train_data.copy()
    val_data = val_data.copy()

    train_data[meta_column_name] = "train"
    val_data[meta_column_name] = "validation"
    test_data[meta_column_name] = "test"

    if schema:
        for df, name in zip([train_data, val_data, test_data], ["train", "validation", "test"]):
            validate_dataframe(df, schema, mode="training")
            self.logger and self.logger.log_message(f"Single split: {name} set validated against schema.")

        verify_no_leakage(train_data, val_data, test_data, fold_index=-1, meta_column_name=meta_column_name)
        validate_split_integrity(input_dfs=[self.data], output_dfs=[train_data, val_data, test_data])

    # Optionally log the splits.
    if self.logger:
        self.logger.log_fold_data(train_data, val_data, test_data, fold_index=-1)

    return train_data, val_data, test_data