-
Notifications
You must be signed in to change notification settings - Fork 512
/
multiclass_classification_pr.py
454 lines (393 loc) · 17.4 KB
/
multiclass_classification_pr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import argparse
from numbers import Number
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from sklearn.metrics import average_precision_score, precision_recall_curve
from torch import Tensor
from torch.nn import functional as F
from corenet.metrics import METRICS_REGISTRY
from corenet.metrics.metric_base import BaseMetric
from corenet.utils import logger, tensor_utils
from corenet.utils.file_logger import FileLogger
def get_recall_at_precision(
    precisions: np.ndarray,
    recalls: np.ndarray,
    precision_value: float,
    suppress_warnings: bool = False,
) -> float:
    """
    Compute the recall at the given @precision_value.

    The (precision, recall) pairs are sorted by precision, and the recall of the
    first pair whose precision is >= @precision_value is returned. If no pair
    reaches @precision_value, the pair with the highest precision is used
    instead. Unless @suppress_warnings is set, a warning is logged whenever the
    precision of the selected pair is not close to @precision_value.

    Args:
        precisions: An array of shape [num_elements] with precision values.
        recalls: An array of shape [num_elements] with recall values, aligned
            element-wise with @precisions.
        precision_value: The precision at which to obtain the recall.
        suppress_warnings: Suppress warnings.

    Returns: The recall at @precision_value.

    Raises:
        ValueError: If @precisions is empty, or if @precisions and @recalls
            have different shapes.
    """
    precisions = np.asarray(precisions)
    recalls = np.asarray(recalls)
    if precisions.size == 0:
        raise ValueError("Cannot compute recall at precision from empty arrays.")
    if precisions.shape != recalls.shape:
        raise ValueError(
            f"Mismatched shapes precisions.shape={precisions.shape}, "
            f"recalls.shape={recalls.shape}"
        )

    sort_indices = np.argsort(precisions)
    sorted_precisions = precisions[sort_indices]
    sorted_recalls = recalls[sort_indices]

    # searchsorted returns len(sorted_precisions) when every precision is below
    # @precision_value. Clamp to the highest-precision point instead of
    # indexing out of bounds; the warning below reports the mismatch.
    index = min(
        int(np.searchsorted(sorted_precisions, precision_value)),
        len(sorted_precisions) - 1,
    )
    if not suppress_warnings and not np.isclose(
        sorted_precisions[index], precision_value, rtol=0.01, atol=0.01
    ):
        # The difference between the requested and true precisions are higher than expected.
        logger.warning(
            f"Found recall at precision {sorted_precisions[index]} "
            f"when recall at precision {precision_value} was requested."
        )
    return float(sorted_recalls[index])
def compute_oi_f1(
    predictions: Union[Tensor, np.ndarray],
    targets: Union[Tensor, np.ndarray],
) -> Tuple[List[float], List[float]]:
    """
    Compute the "Optimal Instance" F1 score, separately for each class.

    The "official" computation corresponds to:
        1. Compute the threshold that maximizes individual input's F1 scores
           separately.
        2. Using those thresholds, count the true positives, false positives,
           and false negatives.
        3. Use these values to compute the F1 score.

    This function also computes a simple averaging of F1 scores individually
    calculated from each input's optimal thresholds.

    Args:
        predictions: A tensor/array of shape
            [batch_size, num_classes, predictions_per_sample] containing
            predictions.
        targets: A tensor/array with the same shape as @predictions containing
            binary targets (1 marks a positive, 0 a negative).

    Returns: A tuple (official, average) of lists with one entry per class:
        the "official" OIS F1 score and the simple average of individual best
        F1 scores, as described above.

    Raises:
        ValueError: If either input is not 3-dimensional, if their shapes
            differ, or if the batch is empty.
    """
    if predictions.ndim != 3:
        raise ValueError(f"Invalid shape {predictions.shape}")
    if targets.ndim != 3:
        raise ValueError(f"Invalid shape {targets.shape}")
    if tuple(predictions.shape) != tuple(targets.shape):
        raise ValueError(
            f"Mismatched shapes predictions.shape={predictions.shape}, "
            f"targets.shape={targets.shape}"
        )
    batch_size = predictions.shape[0]
    num_classes = predictions.shape[1]
    if batch_size == 0:
        # Without this guard, the per-class average below divides by zero.
        raise ValueError("Cannot compute OIS-F1 for an empty batch.")

    official_ois_f1_scores: List[float] = []
    avg_of_best_f1_scores: List[float] = []
    for class_id in range(num_classes):
        true_positives = 0
        false_positives = 0
        false_negatives = 0
        best_f1_scores = []
        for idx in range(batch_size):
            # 1. Find the threshold that maximizes this input's F1 score.
            prediction = predictions[idx, class_id]
            target = targets[idx, class_id]
            precisions, recalls, thresholds = precision_recall_curve(target, prediction)
            # Adding 1 to zero denominators turns 0/0 into 0 instead of NaN.
            f1_scores = (
                2
                * precisions
                * recalls
                / (precisions + recalls + ((precisions + recalls) == 0))
            )
            # sklearn appends a final (precision=1, recall=0) point without a
            # threshold; its F1 is 0, so argmax (first maximum) never picks it.
            max_idx = np.argmax(f1_scores)
            threshold = thresholds[max_idx]
            best_f1_scores.append(f1_scores[max_idx])

            # 2. Use the threshold to update counts.
            true_positives += ((prediction >= threshold) * (target == 1)).sum().item()
            false_positives += ((prediction >= threshold) * (target == 0)).sum().item()
            false_negatives += ((prediction < threshold) * (target == 1)).sum().item()

        # 3. Store the F1 score for this class (zero denominators map to 0).
        precision = true_positives / (
            true_positives + false_positives + ((true_positives + false_positives) == 0)
        )
        recall = true_positives / (
            true_positives + false_negatives + ((true_positives + false_negatives) == 0)
        )
        official_ois_f1_scores.append(
            float(
                (2 * precision * recall)
                / (precision + recall + ((precision + recall) == 0))
            )
        )
        avg_of_best_f1_scores.append(float(sum(best_f1_scores) / len(best_f1_scores)))
    return official_ois_f1_scores, avg_of_best_f1_scores
@METRICS_REGISTRY.register(name="multiclass_classification_pr")
class MulticlassClassificationPR(BaseMetric):
    """
    Computes multiclass precision/recall metrics.

    Predictions and targets are accumulated by @update and the metrics are
    computed by @compute over everything accumulated since the last @reset.

    Example .yaml configuration to use this metric (assuming your
    model outputs a dict with key "logits"):

        stats:
          val: ["multiclass_classification_pr(pred=logits)"]
          checkpoint_metric: "multiclass_classification_pr(pred=logits).macro"
          checkpoint_metric_max: true
          metrics:
            multiclass_classification_pr:
              include_curve: false
    """

    def __init__(
        self,
        opts: Optional[argparse.Namespace] = None,
        is_distributed: bool = False,
        pred: str = None,
        target: str = None,
    ) -> None:
        # Set before calling the base constructor so they already exist if it
        # invokes @reset.
        self.all_predictions: List[torch.Tensor] = []
        self.all_targets: List[torch.Tensor] = []
        # Default to False so that opts=None (the declared default) or opts
        # without these options does not raise AttributeError.
        self.include_curve = getattr(
            opts, "stats.metrics.multiclass_classification_pr.include_curve", False
        )
        self.suppress_warnings = getattr(
            opts, "stats.metrics.multiclass_classification_pr.suppress_warnings", False
        )
        super().__init__(opts, is_distributed, pred, target)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """
        Add metric specific arguments.

        Arguments are only registered when called on this exact class, not on
        subclasses.

        Args:
            parser: The parser to which to add the arguments.

        Returns:
            The parser.
        """
        if cls == MulticlassClassificationPR:
            parser.add_argument(
                "--stats.metrics.multiclass-classification-pr.include-curve",
                action="store_true",
                help="If set, PR curves will be stored.",
            )
            parser.add_argument(
                "--stats.metrics.multiclass-classification-pr.suppress-warnings",
                action="store_true",
                help="If set, warnings will be suppressed. This is useful to reduce the logs size during training.",
            )
        return parser

    def reset(self) -> None:
        """
        Resets all aggregated data.

        Called at the start of every epoch.
        """
        self.all_predictions.clear()
        self.all_targets.clear()

    def update(
        self,
        prediction: Union[Tensor, Dict],
        target: Union[Tensor, Dict],
        extras: Optional[Dict[str, Any]] = None,
        batch_size: Optional[int] = 1,
    ) -> None:
        """
        Processes a new batch of predictions and targets for computing the metric.

        Args:
            prediction: model outputs for the current batch. They must be a
                tensor of shape [batch_size, num_classes, ...], or a dictionary
                with key self.pred_key containing such a tensor.
            target: labels for the current batch. They may be a tensor of shape
                [batch_size, ...], or a dictionary with key self.target_key containing
                such a tensor. If so, the entries are assumed to be class indices.
                The target may also be a tensor of shape [batch_size, num_classes, ...],
                or a dictionary with key self.target_key containing such a tensor. If so,
                the entries are assumed to be binary class labels.
            extras: unused.
            batch_size: unused.

        Raises:
            KeyError: If a dict input lacks the configured key.
            ValueError: If the prediction/target shapes are incompatible.
        """
        if isinstance(prediction, dict):
            if self.pred_key not in prediction:
                raise KeyError(
                    f"Missing prediction key '{self.pred_key}'. Existing keys: {list(prediction.keys())}"
                )
            prediction = prediction[self.pred_key]

        if isinstance(target, dict):
            if self.target_key not in target:
                raise KeyError(
                    f"Missing target key '{self.target_key}'. Existing keys: {list(target.keys())}"
                )
            target = target[self.target_key]

        if (prediction.ndim - target.ndim) not in (0, 1):
            raise ValueError(
                f"Invalid dimensions prediction.shape={prediction.shape}, target.shape={target.shape}"
            )

        if target.ndim < prediction.ndim:
            # The target doesn't have a num_classes dimension because it has
            # class labels. Expand it. F.one_hot requires an int64 index tensor.
            num_classes = prediction.shape[1]
            target = F.one_hot(target.long(), num_classes=num_classes)
            # Change from [batch_size, ..., num_classes] to [batch_size, num_classes, ...].
            new_order = (0, target.ndim - 1) + tuple(range(1, target.ndim - 1))
            target = target.permute(*new_order)

        # Now, @target and @prediction must both be in [batch_size, num_classes, ...] order.
        if target.shape != prediction.shape:
            raise ValueError(
                f"Mismatched shapes prediction.shape={prediction.shape}, target.shape={target.shape}"
            )

        # Flatten any trailing dimensions into one: [batch_size, num_classes, -1].
        if prediction.dim() > 2:
            prediction = prediction.reshape(
                prediction.shape[0], prediction.shape[1], -1
            )
            target = target.reshape(target.shape[0], target.shape[1], -1)

        with torch.no_grad():
            if self.is_distributed:
                all_predictions = tensor_utils.all_gather_list(
                    prediction.detach().cpu().contiguous()
                )
                all_targets = tensor_utils.all_gather_list(
                    target.detach().cpu().contiguous()
                )
                all_predictions = torch.cat(
                    [p.detach().cpu() for p in all_predictions], dim=0
                )
                all_targets = torch.cat([t.detach().cpu() for t in all_targets], dim=0)
            else:
                all_predictions = prediction.detach().cpu()
                all_targets = target.detach().cpu()

            self.all_predictions.append(all_predictions)
            self.all_targets.append(all_targets)

    def compute(self) -> Dict[str, Union[Number, List[Number], List[List[Number]]]]:
        """
        Compute the multiclass classification Precision-Recall metrics over all
        predictions/targets accumulated since the last @reset.

        See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html#sklearn.metrics.average_precision_score
        for details.

        Returns:
            A dictionary containing:
            {
                "ODS-F1": A list with, for each class, the maximum F1 score
                    over the thresholds of that class's PR curve.
                "AP": A list with the average precision of each class.
                "Recall@P=50": A list with the recall at precision 0.5 for
                    each class (see get_recall_at_precision).
                "OIS-F1-official", "OIS-F1-avg": Per-class lists from
                    compute_oi_f1. Only present when the accumulated
                    predictions have shape
                    [batch_size, num_classes, predictions_per_element].
                "micro": The "micro"-averaged precision, as defined by SKLearn.
                    This corresponds to treating each element in the multiclass
                    prediction separately.
                "macro": The "macro"-averaged precision, as defined by SKLearn.
                    This corresponds to calculating precision-recall metrics
                    separately for each label, then computing an unweighted mean.
                "weighted": The "weighted"-averaged precision, as defined by SKLearn.
                    This corresponds to calculating precision-recall metrics
                    separately for each label, then computing a weighted mean.
                "precisions": (only if include_curve) A list of lists, where
                    element [i][j] is the j'th precision value for the i'th class.
                "recalls": (only if include_curve) A list of lists, where
                    element [i][j] is the j'th recall value for the i'th class.
                "thresholds": (only if include_curve) A list of lists, where
                    element [i][j] is the j'th threshold value for the i'th class.
            }
        """
        metrics: Dict[str, Any] = {}
        if self.include_curve:
            metrics["precisions"] = []
            metrics["recalls"] = []
            metrics["thresholds"] = []
        metrics["ODS-F1"] = []
        metrics["AP"] = []
        metrics["Recall@P=50"] = []

        predictions = (
            torch.cat(self.all_predictions, dim=0).float().numpy()
        )  # [batch_size, num_classes, ...]
        num_classes = predictions.shape[1]
        targets = (
            torch.cat(self.all_targets, dim=0).float().numpy()
        )  # [batch_size, num_classes, ...]

        if predictions.ndim == 3:
            if targets.ndim != 3:
                raise ValueError(
                    f"Invalid dimensions predictions.shape={predictions.shape}, targets.shape={targets.shape}"
                )
            # @predictions and @targets have shape [batch_size, num_classes, predictions_per_element]. Compute
            # the optimal instance score (OIS-F), which is the only metric that needs the predictions_per_element
            # dimension. Then, reshape to [batch_size * predictions_per_element, num_classes].
            official, avg = compute_oi_f1(predictions, targets)
            metrics["OIS-F1-official"] = official
            metrics["OIS-F1-avg"] = avg
            predictions = predictions.transpose(0, 2, 1).reshape(-1, num_classes)
            targets = targets.transpose(0, 2, 1).reshape(-1, num_classes)

        for class_id in range(num_classes):
            class_targets = targets[:, class_id]
            class_predictions = predictions[:, class_id]
            precisions, recalls, thresholds = precision_recall_curve(
                class_targets, class_predictions
            )
            # Adding 1 to zero denominators turns 0/0 into 0 instead of NaN.
            f1_scores = (2 * precisions * recalls) / (
                precisions + recalls + ((precisions + recalls) == 0)
            )
            metrics["ODS-F1"].append(f1_scores.max().item())

            if self.include_curve:
                metrics["precisions"].append(precisions.tolist())
                metrics["recalls"].append(recalls.tolist())
                metrics["thresholds"].append(thresholds.tolist())

            metrics["AP"].append(
                float(average_precision_score(class_targets, class_predictions))
            )
            metrics["Recall@P=50"].append(
                get_recall_at_precision(
                    precisions, recalls, 0.5, suppress_warnings=self.suppress_warnings
                )
            )

        for average in ["micro", "macro", "weighted"]:
            metrics[average] = float(
                average_precision_score(targets, predictions, average=average)
            )

        return metrics

    def is_epoch_summary_enabled_for_metric(
        self, metric_name: str, log_writer: Any
    ) -> bool:
        """
        Determines whether to log a metric with the given @metric_name when the
        given @log_writer is invoked.

        This is mainly used to prevent logs from becoming too large. For
        example, we might not want to display every value in a PR curve, even
        though we want to calculate and store the curve.

        Args:
            metric_name: The name of the metric.
            log_writer: An object that can be used as a log writer (for example,
                a TensorBoardLogger).

        Returns:
            True if the name of the metric should be logged. False otherwise.
        """
        if isinstance(log_writer, FileLogger):
            # For FileLoggers, we log everything, including the rather large
            # precisions/thresholds/recalls keys.
            return True
        # For other loggers, skip any metric whose (case-insensitive) name
        # contains precisions/thresholds/recalls.
        lowered_name = metric_name.lower()
        return not any(
            curve_key in lowered_name
            for curve_key in ("precisions", "thresholds", "recalls")
        )

    def flatten_metric(
        self, values: Union[Number, List, Dict[str, Any]], metric_name: str
    ) -> Dict[str, Union[Number, List, Dict[str, Any]]]:
        """
        Flatten the given metric @values, prepending @metric_name to the
        resulting dictionary's keys.

        Unlike the base class's method, we do not recursively flatten. This is
        because we have lists of PR curve values, and we don't want to generate
        an enormous number of keys to avoid inefficient storage.

        Args:
            values: The values, as output by @self.compute.
            metric_name: The metric name key prefix.

        Returns:
            A version of @values that has been flattened, with key names
            starting with @metric_name.
        """
        return {f"{metric_name}/{k}": v for k, v in values.items()}

    def summary_string(self, name: str, sep: str, values: Dict[str, Any]) -> str:
        """
        Get a string representation of the given metric values, suitable for
        printing to the terminal.

        We avoid printing precision/thresholds/recalls from PR curve
        computation, to avoid excessively long logs.

        Args:
            name: The name of the metric.
            sep: The separator used in the printout.
            values: The metric values, as output by @self.compute.

        Returns:
            A string representation of the metric.
        """
        filtered_keys = {"precisions", "thresholds", "recalls"}
        values = {k: v for k, v in values.items() if k not in filtered_keys}
        return super().summary_string(name, sep, values)