Augmentation Models
This module contains the core logic for the PU learning process: the Spy Model classifier and the decision engine for identifying reliable negatives.
Spy Model (payn.AugmentationModels.SpyModel.SpyModel)
Wraps a CatBoostClassifier, selected for its native handling of categorical features and robust performance on tabular chemical data without extensive preprocessing. Other model architectures can be used as well, provided they yield a class probability score, either computed directly or estimated.
The spy model is trained on the spy_infused_training_data (from payn.SpySplitting) to distinguish between "known positives" (s = 1) and the "unlabeled/spy mixture" (s = 0).
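To illustrate how the s labels relate to spy infusion, the sketch below relabels a random share of known positives as spies. It is not the payn.SpySplitting implementation: the function name `infuse_spies`, the `spy_fraction` parameter, and the role string "unlabeled" are assumptions made for this example. Only the role strings "true positive" and "unlabeled spy" appear elsewhere on this page.

```python
import random


def infuse_spies(labels, spy_fraction=0.15, seed=0):
    """Return (s, roles) for a PU dataset (hypothetical helper).

    labels: list with 1 for known positives and 0 for unlabeled points.
    A random share of the known positives is hidden in the unlabeled
    set: those points get s = 0 and the role "unlabeled spy".
    """
    rng = random.Random(seed)
    positive_idx = [i for i, y in enumerate(labels) if y == 1]
    n_spies = round(len(positive_idx) * spy_fraction)
    spy_idx = set(rng.sample(positive_idx, n_spies))

    s, roles = [], []
    for i, y in enumerate(labels):
        if i in spy_idx:
            s.append(0)
            roles.append("unlabeled spy")
        elif y == 1:
            s.append(1)
            roles.append("true positive")
        else:
            s.append(0)
            roles.append("unlabeled")  # assumed role name
    return s, roles


labels = [1] * 20 + [0] * 30
s, roles = infuse_spies(labels, spy_fraction=0.2, seed=42)
```

The classifier then only ever sees s, so the spies act as positives with a known identity inside the s = 0 class.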
- Categorical Handling: The model automatically detects categorical features (e.g., specific bit positions or metadata tags) appended to the end of the feature vector, optimizing the split strategy for mixed data types.
- Parallelisation: Automatically detects SLURM cluster environments (`SLURM_CPUS_PER_TASK`) to adjust thread counts (`thread_count`), using the CPUs allocated on a cluster while defaulting to single-threaded execution locally.
- Determinism: Random seeds are propagated strictly from the global config to the CatBoost engine (`random_state`).
- Logging: `SpyModel` is tightly coupled with the `payn.Logging` system. It automatically logs hyperparameters, trained model artifacts, and evaluation metrics (on test sets) to the MLflow run immediately after training.
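The SLURM-aware thread selection described above could look roughly like the following. This is a minimal sketch, not the code in spymodel.py: the helper name `resolve_thread_count` and its `env` argument are hypothetical. Only the environment variable `SLURM_CPUS_PER_TASK` and CatBoost's `thread_count` parameter come from the text.

```python
import os


def resolve_thread_count(env=None, default=1):
    """Pick a thread count from SLURM, falling back to `default`.

    `env` is injectable for testing; when omitted, os.environ is used.
    """
    env = os.environ if env is None else env
    raw = env.get("SLURM_CPUS_PER_TASK")
    if raw is None:
        # Not running under SLURM (or no CPU request): stay single-threaded.
        return default
    try:
        count = int(raw)
    except ValueError:
        # Malformed value: ignore it rather than fail.
        return default
    return count if count > 0 else default


thread_count = resolve_thread_count({"SLURM_CPUS_PER_TASK": "8"})
```

The resulting value would then be passed to the classifier, for example as `CatBoostClassifier(thread_count=thread_count, ...)`.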
SpyModel encapsulates the CatBoostClassifier used in the Spy-based learning step.
Attributes:
| Name | Type | Description |
|---|---|---|
| `config_key` | `str` | The key in the config dict relevant to SpyModel. |
| `logger` | `Optional[Logger]` | Logger instance for logging model training and evaluation. |
| `fold_index` | `int` | Index of the current fold (for cross-validation purposes). |
| `random_state` | `int` | Random seed. |
| `eval_metric` | `str` | Evaluation metric to use. |
| `verbose` | `int` | Verbosity level. |
| `model` | `Optional[CatBoostClassifier]` | The trained CatBoost model. |
| `feature_column_name` | `Optional[str]` | Column name containing feature vectors. |
| `training_target_column_name` | `Optional[str]` | Target column name for training data. |
| `validation_target_column_name` | `Optional[str]` | Target column name for validation data. |
| `metrics_list` | `Optional[List[str]]` | List of additional metrics to evaluate. |
| `categorical_column_indices` | `List[int]` | Indices of features identified as categorical. |
Source code in payn\AugmentationModels\SpyModel\spymodel.py
__init__(eval_metric, random_state, verbose, fold_index=1, logger=None, feature_column_name=None, training_target_column_name=None, validation_target_column_name=None, metrics_list=None, categorical_column_indices=None)
Initialize the SpyModel class.
You can either pass a config dict via the alternative constructor from_config or pass parameters explicitly.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `eval_metric` | `str` | Metric for evaluation of model performance. | required |
| `random_state` | `int` | Random seed. | required |
| `verbose` | `int` | Verbosity level for CatBoost output. | required |
| `fold_index` | `int` | Index of the current fold. Defaults to 1. | `1` |
| `logger` | `Logger` | Logger instance for logging. | `None` |
| `feature_column_name` | `str` | Column name containing feature vectors. | `None` |
| `training_target_column_name` | `str` | Target column name for training data. | `None` |
| `validation_target_column_name` | `str` | Target column name for validation data. | `None` |
| `metrics_list` | `List[str]` | List of metrics to evaluate. | `None` |
| `categorical_column_indices` | `List[int]` | Indices of features that are categorical. | `None` |
Source code in payn\AugmentationModels\SpyModel\spymodel.py
evaluate(test_pool)
Evaluate the trained model on a test dataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `test_pool` | `Pool` | CatBoost Pool containing test features and labels. | required |

Returns:

| Type | Description |
|---|---|
| `Dict[str, Any]` | Dictionary of evaluation metrics. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If the model has not been trained yet. |
Source code in payn\AugmentationModels\SpyModel\spymodel.py
from_config(config, logger=None, fold_index=1)
classmethod
Alternative constructor that creates a SpyModel instance from a config object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `dict` | Configuration dictionary. | required |
| `logger` | `Logger` | Logger instance. | `None` |
| `fold_index` | `int` | Current fold index. | `1` |

Returns:

| Name | Type | Description |
|---|---|---|
| `SpyModel` | `SpyModel` | An instance of SpyModel with parameters extracted from the config. |
Source code in payn\AugmentationModels\SpyModel\spymodel.py
predict(data, feature_column=None)
Make predictions using the trained Spy model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `DataFrame` | Dataset containing features for prediction. | required |
| `feature_column` | `str` | Column name for features. | `None` |

Returns:

| Type | Description |
|---|---|
| `Series` | Predicted labels or probabilities. |
Source code in payn\AugmentationModels\SpyModel\spymodel.py
train(train_data, val_data, test_data=None, feature_column=None, training_label_column=None, validation_label_column=None, **kwargs)
Train the Spy model on the given datasets.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `train_data` | `DataFrame` | Training dataset with features and target labels. | required |
| `val_data` | `DataFrame` | Validation dataset for monitoring training progress. | required |
| `test_data` | `Optional[DataFrame]` | Optional test dataset for evaluation (default: None). | `None` |
| `feature_column` | `Optional[str]` | Column name for features. | `None` |
| `training_label_column` | `Optional[str]` | Column name for target labels in training data. | `None` |
| `validation_label_column` | `Optional[str]` | Column name for target labels in validation data. | `None` |
| `**kwargs` | `Any` | Additional hyperparameters (overriding defaults). | `{}` |

Returns:

| Name | Type | Description |
|---|---|---|
| `CatBoostClassifier` | `CatBoostClassifier` | Trained CatBoost model. |
Source code in payn\AugmentationModels\SpyModel\spymodel.py
Reliable Negative Identification (payn.AugmentationModels.SpyModel.augmen_negative_identifier)
This module is the decision-making engine of the PU learning workflow. It leverages the trained Spy Model to filter the unlabeled dataset, identifying a subset of reliable negatives that are statistically distinct from the positive class.
- Dynamic Thresholding: Instead of using a fixed probability threshold (e.g., 0.5), the module calculates a dynamic cutoff from the predicted-probability distribution of the spies (known positives injected into the unlabeled set). A user-defined `spy_tolerance` (default 5%) sets the threshold such that 95% of the spies are correctly recognized as positive by the model. This ensures that the identified negatives are unlikely to be latent positives. Unlabeled data points scoring below this threshold are classified as reliable negatives.
- Classification: The module segments the data into three distinct categories:
  - Known Positives: Original true positives and recovered spies.
  - Reliable Negatives: Unlabeled data points with predicted probabilities below the calculated threshold. These form the clean negative set for downstream applications such as regression model training.
  - Undecisives: Unlabeled data points with probabilities above the threshold but not labeled as positive. These are discarded to prevent "noisy negatives".
AugmenNegativeIdentifier
Identifies augmented (augmen_) reliable negatives using the Spy technique and an optimized threshold.
Attributes:
| Name | Type | Description |
|---|---|---|
| `model` | `CatBoostClassifier` | The trained Spy model. |
| `spy_tolerance` | `float` | The acceptable proportion of spies within the reliable negatives. |
| `logger` | `Logger` | Logger instance for logging messages and artifacts. |
| `feature_column_name` | `str` | Default column name for input features. |
| `mod_data_point_role_column_name` | `str` | Default column name indicating each data point's role. |
| `probability_class_1_column_name` | `str` | Default column name for the predicted probability of class 1. |
| `mod_prediction_class_column_name` | `str` | Default column name for the predicted class. |
| `augmented_bin_column_name` | `str` | Default column name for the binary augmented label. |
| `augmented_role_column_name` | `str` | Default column name for the augmented role label. |
Source code in payn\AugmentationModels\SpyModel\augmen_negative_identifier.py
__init__(model, spy_tolerance=0.05, logger=None, feature_column_name=None, mod_data_point_role_column_name=None, probability_class_1_column_name=None, mod_prediction_class_column_name=None, augmented_bin_column_name=None, augmented_role_column_name=None)
Initialize the AugmenNegativeIdentifier.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `CatBoostClassifier` | Trained Spy model. | required |
| `spy_tolerance` | `float` | Tolerance for spy inclusion in negatives. | `0.05` |
| `logger` | `Logger` | Logger instance for tracking and logging. | `None` |
| `feature_column_name` | `str` | Column name for input features. | `None` |
| `mod_data_point_role_column_name` | `str` | Column name for data point role. | `None` |
| `probability_class_1_column_name` | `str` | Column name for probability predictions for class 1. | `None` |
| `mod_prediction_class_column_name` | `str` | Column name for predicted class. | `None` |
| `augmented_bin_column_name` | `str` | Column name for binary augmented labels. | `None` |
| `augmented_role_column_name` | `str` | Column name for augmented role labels. | `None` |
Source code in payn\AugmentationModels\SpyModel\augmen_negative_identifier.py
filter_augmented_negatives_and_known_positives(spy_inf_data, augmented_role_column_name=None, mod_data_point_role_column_name=None)
Filter the augmented negatives from spy-infused data by excluding known positives. Also, return the set of "undecisive" datapoints.
Known positives are defined as rows where the data point role (from mod_data_point_role_column_name) is "unlabeled spy" or "true positive". For these rows, the augmented role is forcibly set to "known positive". The method returns three DataFrames:
- filtered_augmented_negatives: rows with augmented role "reliable negative"
- known_positives: rows with role in ["unlabeled spy", "true positive"]
- undecisives: rows with augmented role "undecisive"
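The filtering rule described above can be sketched in plain Python on a list of row dictionaries. This is an illustration under assumptions, not the method's actual pandas code: the function name `split_by_role` and the keys "data_point_role" and "augmented_role" are hypothetical stand-ins for the configurable column names.

```python
KNOWN_POSITIVE_ROLES = {"unlabeled spy", "true positive"}


def split_by_role(rows, role_key="data_point_role",
                  augmented_role_key="augmented_role"):
    """Split rows into (reliable negatives, known positives, undecisives).

    A missing key raises KeyError, mirroring the documented behaviour
    for missing columns. Input rows are copied, not modified.
    """
    negatives, positives, undecisives = [], [], []
    for row in rows:
        new_row = dict(row)
        if new_row[role_key] in KNOWN_POSITIVE_ROLES:
            # Known positives override whatever the threshold assigned.
            new_row[augmented_role_key] = "known positive"
            positives.append(new_row)
        elif new_row[augmented_role_key] == "reliable negative":
            negatives.append(new_row)
        elif new_row[augmented_role_key] == "undecisive":
            undecisives.append(new_row)
    return negatives, positives, undecisives


rows = [
    {"data_point_role": "true positive", "augmented_role": "undecisive"},
    {"data_point_role": "unlabeled spy", "augmented_role": "reliable negative"},
    {"data_point_role": "unlabeled", "augmented_role": "reliable negative"},
    {"data_point_role": "unlabeled", "augmented_role": "undecisive"},
]
negatives, positives, undecisives = split_by_role(rows)
```

Note that the spy in the second row lands in the known positives even though its score fell below the threshold, which is the point of forcing the role.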
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `spy_inf_data` | `DataFrame` | Spy-infused DataFrame containing meta columns. | required |
| `augmented_role_column_name` | `str` | Override for the augmented role column name. | `None` |
| `mod_data_point_role_column_name` | `str` | Override for the data point role column name. | `None` |

Returns:

| Type | Description |
|---|---|
| `Tuple[DataFrame, DataFrame, DataFrame]` | (filtered_negatives, known_positives, undecisives). |

Raises:

| Type | Description |
|---|---|
| `KeyError` | If expected columns are missing. |
Source code in payn\AugmentationModels\SpyModel\augmen_negative_identifier.py
find_augmen_threshold(spy_inf_data, mod_data_point_role_column_name=None, probability_class_1_column_name=None)
Find the optimal threshold for classifying augmented negatives.
The threshold is determined by sorting the predicted probabilities for examples with a data point role of "unlabeled spy" and selecting the value at an index defined by the spy tolerance. If the computed threshold exceeds 0.5, it is set to 0.5.
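A minimal sketch of this rule follows. The exact index formula used by find_augmen_threshold is not shown on this page, so the ascending sort and the index `int(spy_tolerance * n_spies)` are assumptions. The function name `find_spy_threshold` and the `cap` argument are likewise hypothetical.

```python
def find_spy_threshold(spy_probabilities, spy_tolerance=0.05, cap=0.5):
    """Derive a cutoff from the class-1 probabilities of the spies.

    Assumed rule: sort ascending and take the value at index
    int(spy_tolerance * n), so that roughly a `spy_tolerance` share of
    spies scores strictly below the cutoff. The result is capped at `cap`.
    """
    if not spy_probabilities:
        raise ValueError("No spy data found.")
    ordered = sorted(spy_probabilities)
    index = min(int(spy_tolerance * len(ordered)), len(ordered) - 1)
    return min(ordered[index], cap)


spies = [0.9, 0.2, 0.7, 0.8, 0.6, 0.95, 0.3, 0.85, 0.75, 0.65]
threshold = find_spy_threshold(spies, spy_tolerance=0.1)
capped = find_spy_threshold(spies, spy_tolerance=0.2)
```

With ten spies and a tolerance of 0.1, the second-lowest spy score (0.3) becomes the threshold. With a tolerance of 0.2 the selected value is 0.6, which exceeds the cap and is therefore replaced by 0.5.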
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `spy_inf_data` | `DataFrame` | Data with predicted probabilities. | required |
| `mod_data_point_role_column_name` | `str` | Override for the data point role column name. | `None` |
| `probability_class_1_column_name` | `str` | Override for the probability column name. | `None` |

Returns:

| Name | Type | Description |
|---|---|---|
| `float` | `float` | The determined threshold. |

Raises:

| Type | Description |
|---|---|
| `KeyError` | If the role column is not found in the data. |
| `ValueError` | If no spy data is found. |
Source code in payn\AugmentationModels\SpyModel\augmen_negative_identifier.py
from_config(config, model, logger=None)
classmethod
Alternative constructor that extracts the required parameters from a config object.
The configuration dictionary is expected to have keys "spy_splitting" and "meta_columns" with appropriate entries.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `Dict[str, Any]` | Configuration dictionary. | required |
| `model` | `CatBoostClassifier` | Trained Spy model. | required |
| `logger` | `Logger` | Logger instance. | `None` |

Returns:

| Name | Type | Description |
|---|---|---|
| `AugmenNegativeIdentifier` | `AugmenNegativeIdentifier` | A new instance configured from the provided config. |
Source code in payn\AugmentationModels\SpyModel\augmen_negative_identifier.py
get_augmen_negatives_and_known_positives(spy_inf_data, threshold, augmented_bin_column_name=None, augmented_role_column_name=None, probability_class_1_column_name=None, mod_data_point_role_column_name=None)
Extract augmented reliable negatives, known positives, and undecisive datapoints based on a threshold.
The method creates a new binary column (augmented_bin_column_name) for augmented labels based on whether the predicted probability (from probability_class_1_column_name) exceeds the threshold. It then assigns an augmented role ("reliable negative" if binary label is 0; otherwise "undecisive") and calls the filtering function to separate known positives from reliable negatives and to collect undecisive datapoints.
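The labelling step can be illustrated on plain dictionaries. This is a sketch under assumptions rather than the library code: the function name `assign_augmented_roles` and the default keys are hypothetical, and the strict `>` comparison is inferred from the word "exceeds" above.

```python
def assign_augmented_roles(rows, threshold,
                           prob_key="probability_class_1",
                           bin_key="augmented_bin",
                           role_key="augmented_role"):
    """Return copies of `rows` with a binary label and an augmented role.

    Binary label: 1 if the class-1 probability exceeds the threshold,
    otherwise 0. Role: "reliable negative" for 0, "undecisive" for 1.
    """
    labelled = []
    for row in rows:
        new_row = dict(row)
        new_row[bin_key] = int(row[prob_key] > threshold)
        new_row[role_key] = (
            "reliable negative" if new_row[bin_key] == 0 else "undecisive"
        )
        labelled.append(new_row)
    return labelled


scored = [
    {"probability_class_1": 0.10},
    {"probability_class_1": 0.40},
    {"probability_class_1": 0.30},
]
labelled = assign_augmented_roles(scored, threshold=0.3)
```

In this sketch a score exactly equal to the threshold does not exceed it and is therefore labelled a reliable negative. Whether the library treats the boundary the same way is not stated here.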
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `spy_inf_data` | `DataFrame` | Spy-infused training data with probability predictions. | required |
| `threshold` | `float` | Threshold for binary classification. | required |
| `augmented_bin_column_name` | `str` | Override for the binary augmented column name. | `None` |
| `augmented_role_column_name` | `str` | Override for the augmented role column name. | `None` |
| `probability_class_1_column_name` | `str` | Override for the probability column name. | `None` |
| `mod_data_point_role_column_name` | `str` | Override for the data point role column name. | `None` |

Returns:

| Type | Description |
|---|---|
| `Tuple[DataFrame, DataFrame, DataFrame]` | (augmen_reliable_negatives, known_positives, undecisives). |
Source code in payn\AugmentationModels\SpyModel\augmen_negative_identifier.py
predict_augmen_probabilities(spy_inf_data, feature_column_name=None, mod_prediction_class_column_name=None, probability_class_1_column_name=None)
Predict probabilities and labels for spy-infused training data.
The method adds two columns to a copy of the input DataFrame: one for predicted classes and one for the predicted probabilities for class 1.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `spy_inf_data` | `DataFrame` | Spy-infused training data. | required |
| `feature_column_name` | `str` | Name of the column containing input features. | `None` |
| `mod_prediction_class_column_name` | `Optional[str]` | Override for predicted class column name. | `None` |
| `probability_class_1_column_name` | `Optional[str]` | Override for probability column name. | `None` |

Returns:

| Type | Description |
|---|---|
| `DataFrame` | A new DataFrame with predicted class and probability columns appended. |

Raises:

| Type | Description |
|---|---|
| `KeyError` | If the feature column is not found in the data. |
| `Exception` | Propagates any exceptions raised during prediction. |
Source code in payn\AugmentationModels\SpyModel\augmen_negative_identifier.py