Retrieval Metrics: Updating HitRate and MRR for Evaluation@K documents retrieved. Also adding RR as a separate metric (#12997)

* Updated metrics: renamed MRR to RR, updated HitRate for multi-doc evaluation, and added a new, separate MRR implementation

* Updated MRR and HitRate with requested changes

* Iterated with a class-attribute implementation for selecting the calculation option
AgenP committed May 1, 2024
1 parent d0ffd01 commit b5a57ca
Showing 2 changed files with 176 additions and 20 deletions.
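In short: for expected IDs {id1, id2, id3} and the retrieved ranking [id5, id1, id2], the default HitRate scores 1.0 (at least one match) and the granular variant scores 2/3, while the default MRR scores 1/2 (first relevant document at rank 2) and the granular variant scores (1/2 + 1/3) / 2 ≈ 0.417.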
119 changes: 99 additions & 20 deletions llama-index-core/llama_index/core/evaluation/retrieval/metrics.py
@@ -12,9 +12,18 @@


 class HitRate(BaseRetrievalMetric):
-    """Hit rate metric."""
+    """Hit rate metric: Compute hit rate with two calculation options.
+
+    - The default method checks for a single match between any of the retrieved docs and expected docs.
+    - The more granular method checks for all potential matches between retrieved docs and expected docs.
+
+    Attributes:
+        use_granular_hit_rate (bool): Determines whether to use the granular method for calculation.
+        metric_name (str): The name of the metric.
+    """

     metric_name: str = "hit_rate"
+    use_granular_hit_rate: bool = False

     def compute(
         self,
@@ -23,21 +32,57 @@ def compute(
         retrieved_ids: Optional[List[str]] = None,
         expected_texts: Optional[List[str]] = None,
         retrieved_texts: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> RetrievalMetricResult:
-        """Compute metric."""
-        if retrieved_ids is None or expected_ids is None:
+        """Compute metric based on the provided inputs.
+
+        Parameters:
+            query (Optional[str]): The query string (not used in the current implementation).
+            expected_ids (Optional[List[str]]): Expected document IDs.
+            retrieved_ids (Optional[List[str]]): Retrieved document IDs.
+            expected_texts (Optional[List[str]]): Expected texts (not used in the current implementation).
+            retrieved_texts (Optional[List[str]]): Retrieved texts (not used in the current implementation).
+
+        Raises:
+            ValueError: If the necessary IDs are not provided.
+
+        Returns:
+            RetrievalMetricResult: The result with the computed hit rate score.
+        """
+        # Checking for the required arguments
+        if (
+            retrieved_ids is None
+            or expected_ids is None
+            or not retrieved_ids
+            or not expected_ids
+        ):
             raise ValueError("Retrieved ids and expected ids must be provided")
-        is_hit = any(id in expected_ids for id in retrieved_ids)
-        return RetrievalMetricResult(
-            score=1.0 if is_hit else 0.0,
-        )
+
+        if self.use_granular_hit_rate:
+            # Granular HitRate calculation: Calculate all hits and divide by the number of expected docs
+            expected_set = set(expected_ids)
+            hits = sum(1 for doc_id in retrieved_ids if doc_id in expected_set)
+            score = hits / len(expected_ids) if expected_ids else 0.0
+        else:
+            # Default HitRate calculation: Check if there is a single hit
+            is_hit = any(id in expected_ids for id in retrieved_ids)
+            score = 1.0 if is_hit else 0.0
+
+        return RetrievalMetricResult(score=score)


 class MRR(BaseRetrievalMetric):
-    """MRR metric."""
+    """MRR (Mean Reciprocal Rank) metric with two calculation options.
+
+    - The default method calculates the reciprocal rank of the first relevant retrieved document.
+    - The more granular method sums the reciprocal ranks of all relevant retrieved documents and divides by the count of relevant documents.
+
+    Attributes:
+        use_granular_mrr (bool): Determines whether to use the granular method for calculation.
+        metric_name (str): The name of the metric.
+    """

     metric_name: str = "mrr"
+    use_granular_mrr: bool = False

     def compute(
         self,
@@ -46,19 +91,53 @@ def compute(
         retrieved_ids: Optional[List[str]] = None,
         expected_texts: Optional[List[str]] = None,
         retrieved_texts: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> RetrievalMetricResult:
-        """Compute metric."""
-        if retrieved_ids is None or expected_ids is None:
+        """Compute MRR based on the provided inputs and selected method.
+
+        Parameters:
+            query (Optional[str]): The query string (not used in the current implementation).
+            expected_ids (Optional[List[str]]): Expected document IDs.
+            retrieved_ids (Optional[List[str]]): Retrieved document IDs.
+            expected_texts (Optional[List[str]]): Expected texts (not used in the current implementation).
+            retrieved_texts (Optional[List[str]]): Retrieved texts (not used in the current implementation).
+
+        Raises:
+            ValueError: If the necessary IDs are not provided.
+
+        Returns:
+            RetrievalMetricResult: The result with the computed MRR score.
+        """
+        # Checking for the required arguments
+        if (
+            retrieved_ids is None
+            or expected_ids is None
+            or not retrieved_ids
+            or not expected_ids
+        ):
             raise ValueError("Retrieved ids and expected ids must be provided")
-        for i, id in enumerate(retrieved_ids):
-            if id in expected_ids:
-                return RetrievalMetricResult(
-                    score=1.0 / (i + 1),
-                )
-        return RetrievalMetricResult(
-            score=0.0,
-        )
+
+        if self.use_granular_mrr:
+            # Granular MRR calculation: All relevant retrieved docs have their reciprocal ranks summed and averaged
+            expected_set = set(expected_ids)
+            reciprocal_rank_sum = 0.0
+            relevant_docs_count = 0
+            for index, doc_id in enumerate(retrieved_ids):
+                if doc_id in expected_set:
+                    relevant_docs_count += 1
+                    reciprocal_rank_sum += 1.0 / (index + 1)
+            mrr_score = (
+                reciprocal_rank_sum / relevant_docs_count
+                if relevant_docs_count > 0
+                else 0.0
+            )
+        else:
+            # Default MRR calculation: Reciprocal rank of the first relevant document retrieved
+            for i, id in enumerate(retrieved_ids):
+                if id in expected_ids:
+                    return RetrievalMetricResult(score=1.0 / (i + 1))
+            mrr_score = 0.0
+
+        return RetrievalMetricResult(score=mrr_score)


 class CohereRerankRelevancyMetric(BaseRetrievalMetric):
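A minimal usage sketch of the two calculation modes for both metrics (the IDs and scores below are illustrative, not part of the commit; assigning the option attribute after construction mirrors the tests further down):

from llama_index.core.evaluation.retrieval.metrics import HitRate, MRR

expected_ids = ["id1", "id2", "id3"]
retrieved_ids = ["id5", "id1", "id2"]

# Default HitRate: 1.0 as soon as any retrieved ID matches an expected ID
hit_rate = HitRate()
print(hit_rate.compute(expected_ids=expected_ids, retrieved_ids=retrieved_ids).score)  # 1.0

# Granular HitRate: fraction of expected IDs retrieved -> 2 of 3
granular_hit_rate = HitRate()
granular_hit_rate.use_granular_hit_rate = True
print(granular_hit_rate.compute(expected_ids=expected_ids, retrieved_ids=retrieved_ids).score)  # 0.666...

# Default MRR: first relevant document ("id1") sits at rank 2 -> 1/2
mrr = MRR()
print(mrr.compute(expected_ids=expected_ids, retrieved_ids=retrieved_ids).score)  # 0.5

# Granular MRR: relevant documents at ranks 2 and 3 -> (1/2 + 1/3) / 2
granular_mrr = MRR()
granular_mrr.use_granular_mrr = True
print(granular_mrr.compute(expected_ids=expected_ids, retrieved_ids=retrieved_ids).score)  # ~0.4167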
77 changes: 77 additions & 0 deletions llama-index-core/tests/evaluation/test_rr_mrr_hitrate.py
@@ -0,0 +1,77 @@
import pytest
from llama_index.core.evaluation.retrieval.metrics import HitRate, MRR


# Test cases for the updated HitRate class using instance attribute
@pytest.mark.parametrize(
    ("expected_ids", "retrieved_ids", "use_granular", "expected_result"),
    [
        (["id1", "id2", "id3"], ["id3", "id1", "id2", "id4"], False, 1.0),
        (["id1", "id2", "id3", "id4"], ["id1", "id5", "id2"], True, 2 / 4),
        (["id1", "id2"], ["id3", "id4"], False, 0.0),
        (["id1", "id2"], ["id2", "id1", "id7"], True, 2 / 2),
    ],
)
def test_hit_rate(expected_ids, retrieved_ids, use_granular, expected_result):
    hr = HitRate()
    hr.use_granular_hit_rate = use_granular
    result = hr.compute(expected_ids=expected_ids, retrieved_ids=retrieved_ids)
    assert result.score == pytest.approx(expected_result)


# Test cases for the updated MRR class using instance attribute
@pytest.mark.parametrize(
    ("expected_ids", "retrieved_ids", "use_granular", "expected_result"),
    [
        (["id1", "id2", "id3"], ["id3", "id1", "id2", "id4"], False, 1 / 1),
        (["id1", "id2", "id3", "id4"], ["id5", "id1"], False, 1 / 2),
        (["id1", "id2"], ["id3", "id4"], False, 0.0),
        (["id1", "id2"], ["id2", "id1", "id7"], False, 1 / 1),
        (
            ["id1", "id2", "id3"],
            ["id3", "id1", "id2", "id4"],
            True,
            (1 / 1 + 1 / 2 + 1 / 3) / 3,
        ),
        (
            ["id1", "id2", "id3", "id4"],
            ["id1", "id2", "id5"],
            True,
            (1 / 1 + 1 / 2) / 2,
        ),
        (["id1", "id2"], ["id1", "id7", "id15", "id2"], True, (1 / 1 + 1 / 4) / 2),
    ],
)
def test_mrr(expected_ids, retrieved_ids, use_granular, expected_result):
    mrr = MRR()
    mrr.use_granular_mrr = use_granular
    result = mrr.compute(expected_ids=expected_ids, retrieved_ids=retrieved_ids)
    assert result.score == pytest.approx(expected_result)


# Test cases for exception handling for both HitRate and MRR
@pytest.mark.parametrize(
    ("expected_ids", "retrieved_ids", "use_granular"),
    [
        (
            None,
            ["id3", "id1", "id2", "id4"],
            False,
        ),  # None expected_ids should trigger ValueError
        (
            ["id1", "id2", "id3"],
            None,
            True,
        ),  # None retrieved_ids should trigger ValueError
        ([], [], False),  # Empty IDs should trigger ValueError
    ],
)
def test_exceptions(expected_ids, retrieved_ids, use_granular):
    # Use a separate raises-context per metric so the MRR call is still
    # exercised after the HitRate call raises
    hr = HitRate()
    hr.use_granular_hit_rate = use_granular
    with pytest.raises(ValueError):
        hr.compute(expected_ids=expected_ids, retrieved_ids=retrieved_ids)

    mrr = MRR()
    mrr.use_granular_mrr = use_granular
    with pytest.raises(ValueError):
        mrr.compute(expected_ids=expected_ids, retrieved_ids=retrieved_ids)
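Assuming a development install with pytest available, the new tests can be run from the repository root with:

python -m pytest llama-index-core/tests/evaluation/test_rr_mrr_hitrate.py -v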
