Skip to content

Commit

Permalink
Add point postprocessing (#27)
Browse files Browse the repository at this point in the history
Fixes # .

### Description

A few sentences describing the changes proposed in this pull request.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] In-line docstrings updated.

---------

Signed-off-by: heyufan1995 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
heyufan1995 and pre-commit-ci[bot] authored Jul 16, 2024
1 parent 5aa1472 commit d71a2d1
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
9 changes: 7 additions & 2 deletions scripts/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from .sliding_window import point_based_window_inferer, sliding_window_inference
from .train import CONFIG
from .utils.trans_utils import VistaPostTransform
from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point

rearrange, _ = optional_import("einops", name="rearrange")
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
Expand Down Expand Up @@ -168,7 +168,8 @@ def infer(
batch_data = self.batch_data
else:
batch_data = self.infer_transforms(image_file)
batch_data["label_prompt"] = label_prompt
if label_prompt is not None:
batch_data["label_prompt"] = label_prompt
batch_data = list_data_collate([batch_data])
self.batch_data = batch_data
if point is not None:
Expand Down Expand Up @@ -231,6 +232,10 @@ def infer(
meta=batch_data["image"].meta,
)
self.prev_mask = batch_data["pred"]
if label_prompt is None and point is not None:
batch_data["pred"] = get_largest_connected_component_point(
batch_data["pred"], point_coords=point, point_labels=point_label
)
batch_data["image"] = batch_data["image"].to("cpu")
batch_data["pred"] = batch_data["pred"].to("cpu")
torch.cuda.empty_cache()
Expand Down
2 changes: 1 addition & 1 deletion scripts/utils/trans_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def dilate3d(input_tensor, erosion=3):


def get_largest_connected_component_point(
img: NdarrayTensor, point_coords=None, point_labels=None, post_idx=3
img: NdarrayTensor, point_coords=None, point_labels=None
) -> NdarrayTensor:
"""
Gets the largest connected component mask of an image. img is before post process! And will include NaN values.
Expand Down

0 comments on commit d71a2d1

Please sign in to comment.