Knn graph code refactor #898
base: master
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #898      +/-   ##
==========================================
+ Coverage   96.72%   96.80%   +0.08%
==========================================
  Files          69       70       +1
  Lines        5402     5381      -21
  Branches      925      916       -9
==========================================
- Hits         5225     5209      -16
+ Misses         89       86       -3
+ Partials       88       86       -2
@@ -0,0 +1,28 @@
from typing import Dict, Union, Optional, Any
This looks good! It would be beneficial to add a file named `test_utils.py` in `tests/datalab/issue_manager` and write basic tests for `process_knn_graph_from_inputs`.
@@ -94,7 +95,7 @@ def find_issues(
         pred_probs: Optional[np.ndarray] = None,
         **kwargs,
     ) -> None:
-        knn_graph = self._process_knn_graph_from_inputs(kwargs)
+        knn_graph = ConstructedKNNGraph(self.datalab).process_knn_graph_from_inputs(kwargs)
For the changed issue managers, consider creating an internal attribute in `__init__` for the `ConstructedKNNGraph` object. Something like `self._constructed_knn_graph_util = ConstructedKNNGraph(self.datalab)`.
Then use the object to call methods: `knn_graph = self._constructed_knn_graph_util.process_knn_graph_from_inputs(kwargs)`.
This way, if more methods are added to the `ConstructedKNNGraph` class, you need not instantiate the object again to use those methods.
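The pattern suggested in this comment could look roughly like the sketch below. `ConstructedKNNGraph` and `process_knn_graph_from_inputs` are names from this PR, but the placeholder method body, the `ExampleIssueManager` class, and the bare `object()` passed as the datalab are hypothetical stand-ins so the snippet runs on its own; they are not the real cleanlab classes.

```python
from typing import Any, Dict, Optional


class ConstructedKNNGraph:
    """Simplified stand-in for the helper class added in this PR."""

    def __init__(self, datalab: Any) -> None:
        self.datalab = datalab

    def process_knn_graph_from_inputs(self, kwargs: Dict[str, Any]) -> Optional[Any]:
        # Placeholder body (assumption): return whatever graph was passed in.
        return kwargs.get("knn_graph")


class ExampleIssueManager:
    """Hypothetical issue manager, used only to illustrate the pattern."""

    def __init__(self, datalab: Any) -> None:
        self.datalab = datalab
        # Build the helper once so every method can reuse the same object.
        self._constructed_knn_graph_util = ConstructedKNNGraph(self.datalab)

    def find_issues(self, **kwargs: Any) -> Optional[Any]:
        # The graph is returned here only to make the example observable.
        return self._constructed_knn_graph_util.process_knn_graph_from_inputs(kwargs)


manager = ExampleIssueManager(datalab=object())
graph = manager.find_issues(knn_graph="placeholder-graph")
```

Holding the helper as an attribute also gives tests a single place to swap in a mock, which may or may not matter for this code base.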
    def __init__(self, datalab):
        self.datalab = datalab

    def process_knn_graph_from_inputs(self, kwargs: Dict[str, Any]) -> Union[csr_matrix, None]:
We can replace `Union[csr_matrix, None]` with `Optional[csr_matrix]` for consistency. Reference.
-    def process_knn_graph_from_inputs(self, kwargs: Dict[str, Any]) -> Union[csr_matrix, None]:
+    def process_knn_graph_from_inputs(self, kwargs: Dict[str, Any]) -> Optional[csr_matrix]:
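The two annotations in this suggestion are equivalent to the type system, so the change is purely stylistic. A small check of that equivalence, using a hypothetical `SparseMatrix` class in place of `csr_matrix` so the snippet needs only the standard library:

```python
from typing import Optional, Union, get_type_hints


class SparseMatrix:
    """Hypothetical stand-in for scipy.sparse.csr_matrix."""


def with_union() -> Union[SparseMatrix, None]:
    return None


def with_optional() -> Optional[SparseMatrix]:
    return None


# Both spellings resolve to the same typing object.
same_annotation = Optional[SparseMatrix] == Union[SparseMatrix, None]
# The resolved return hints of the two functions compare equal as well.
same_hints = get_type_hints(with_union) == get_type_hints(with_optional)
```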
from cleanlab.datalab.datalab import Datalab


class TestConstructedKNNGraph:
Here is the same class with some suggestions for the test cases. This should help you cover all relevant cases. Importantly, we want to check that the function
- returns `None` when an insufficient or invalid graph is passed
- returns a valid graph that is present in `kwargs` or stored in the datalab instance.
-class TestConstructedKNNGraph:
+import numpy as np
+import pytest
+from scipy.sparse import csr_matrix
+
+
+class TestConstructedKNNGraph:
+    @pytest.fixture
+    def sparse_matrix(self):
+        # Dense 5x5 random matrix stored in CSR format
+        X = np.random.RandomState(0).rand(5, 5)
+        return csr_matrix(X)
+
+    @pytest.fixture
+    def constructed_knn_graph_instance(self, lab):
+        return ConstructedKNNGraph(lab)
+
+    def test_process_knn_graph_from_inputs_valid_graph(self, constructed_knn_graph_instance, sparse_matrix):
+        # Check when knn_graph is present in kwargs
+        kwargs = {"knn_graph": sparse_matrix}
+        knn_graph = constructed_knn_graph_instance.process_knn_graph_from_inputs(kwargs)
+        assert isinstance(knn_graph, csr_matrix)  # Assert type
+        assert knn_graph is sparse_matrix  # Assert that passed sparse matrix is same as returned knn graph
+
+        # Check when knn_graph is present in "statistics"
+        lab = constructed_knn_graph_instance.datalab
+        lab.info["statistics"]["weighted_knn_graph"] = sparse_matrix  # Set key in statistics
+        knn_graph = constructed_knn_graph_instance.process_knn_graph_from_inputs(kwargs={})
+        assert isinstance(knn_graph, csr_matrix)  # Assert type
+        assert knn_graph is sparse_matrix
+
+        # Any other cases where valid knn_graph is returned
+
+    def test_process_knn_graph_from_inputs_return_None(self, constructed_knn_graph_instance, sparse_matrix):
+        # First check for no knn graph (not present in kwargs or "weighted_knn_graph")
+        kwargs = {}
+        knn_graph = constructed_knn_graph_instance.process_knn_graph_from_inputs(kwargs)
+        assert knn_graph is None
+
+        # Pass knn_graph with larger k
+        kwargs = {"knn_graph": sparse_matrix, "k": 10}
+        knn_graph = constructed_knn_graph_instance.process_knn_graph_from_inputs(kwargs)
+        assert knn_graph is None
+
+        # Any other cases where returned knn_graph is None
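Read together, the two suggested tests imply a selection rule: prefer a graph passed in `kwargs`, fall back to one stored under `"weighted_knn_graph"`, and return `None` when no graph exists or the graph is too small for the requested `k`. The following is a minimal sketch of logic that would satisfy those tests, not the PR's actual implementation. `FakeGraph` and `select_knn_graph` are hypothetical names, and the "stored entries per row" sufficiency check is an assumption.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass
class FakeGraph:
    """Hypothetical stand-in exposing the two csr_matrix attributes used below."""

    nnz: int  # number of stored entries
    shape: Tuple[int, int]  # (n_rows, n_cols)


def select_knn_graph(
    kwargs: Dict[str, Any], statistics: Dict[str, Any]
) -> Optional[FakeGraph]:
    # Prefer a graph passed explicitly, then fall back to a stored one.
    graph = kwargs.get("knn_graph")
    if graph is None:
        graph = statistics.get("weighted_knn_graph")
    if graph is None:
        return None
    # Assumption: a graph is insufficient when it stores fewer entries
    # per row than the requested k.
    neighbors_per_row = graph.nnz // graph.shape[0]
    if kwargs.get("k", 0) > neighbors_per_row:
        return None
    return graph


# Mirrors the 5x5 fixture, assuming all 25 entries are stored.
graph = FakeGraph(nnz=25, shape=(5, 5))
from_kwargs = select_knn_graph({"knn_graph": graph}, {})
from_stats = select_knn_graph({}, {"weighted_knn_graph": graph})
missing = select_knn_graph({}, {})
too_small = select_knn_graph({"knn_graph": graph, "k": 10}, {})
```

Under these assumptions the first two calls return the same `graph` object and the last two return `None`, matching the assertions in the suggested tests.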
Summary
[ ✏️ implement `process_knn_graph_from_inputs` as an instance method under the `ConstructedKNNGraph` class in the new file `utils.py` ]