Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions configs/barcode/barcode-R-BC.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
MODEL:
OUT_PLANES: 2
TARGET_OPT: ["0", "4-0-1"]
LOSS_OPTION:
- - WeightedBCEWithLogitsLoss
- DiceLoss
- - WeightedBCEWithLogitsLoss
- DiceLoss
LOSS_WEIGHT: [[1.0, 0.5], [1.0, 0.5]]
WEIGHT_OPT: [["1", "0"], ["1", "0"]]
OUTPUT_ACT: [["none", "sigmoid"], ["none", "sigmoid"]]
INFERENCE:
OUTPUT_ACT: ["sigmoid", "sigmoid"]
OUTPUT_PATH: outputs/barcode_R_BC/test/
DATASET:
OUTPUT_PATH: outputs/barcode_R_BC/
17 changes: 17 additions & 0 deletions configs/barcode/barcode-R-BCS.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
MODEL:
OUT_PLANES: 3
TARGET_OPT: ["0", "4-0-1", "a-0-40-16-16"]
LOSS_OPTION:
- - WeightedBCEWithLogitsLoss
- DiceLoss
- - WeightedBCEWithLogitsLoss
- DiceLoss
- - WeightedMSE
LOSS_WEIGHT: [[1.0, 0.5], [1.0, 0.5], [4.0]]
WEIGHT_OPT: [["1", "0"], ["1", "0"], ["0"]]
OUTPUT_ACT: [["none", "sigmoid"], ["none", "sigmoid"], ["tanh"]]
INFERENCE:
OUTPUT_ACT: ["sigmoid", "sigmoid", "tanh"]
OUTPUT_PATH: outputs/barcode_R_BCS/test/
DATASET:
OUTPUT_PATH: outputs/barcode_R_BCS/
48 changes: 48 additions & 0 deletions configs/barcode/barcode-R-Base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
SYSTEM:
NUM_GPUS: 1
NUM_CPUS: 16
# NUM_GPUS: 4
# NUM_CPUS: 16
MODEL:
ARCHITECTURE: unet_3d
BLOCK_TYPE: residual_se
INPUT_SIZE: [33, 97, 97]
OUTPUT_SIZE: [33, 97, 97]
NORM_MODE: gn
IN_PLANES: 1
MIXED_PRECESION: False
FILTERS: [32, 64, 96, 128, 160]
LABEL_EROSION: 1
DATASET:
IMAGE_NAME: ["1-xri_deconvolved.tif", "2-xri_deconvolved.tif"]
LABEL_NAME: ["1-annotated_mask.tif", "2-annotated_mask.tif"]
INPUT_PATH: datasets/barcode_R/ # or your own dataset path
OUTPUT_PATH: outputs/barcode_R/
PAD_SIZE: [16, 32, 32]
DATA_SCALE: [1.0, 1.0, 1.0]
REJECT_SAMPLING:
SIZE_THRES: 1000
P: 1.0
DISTRIBUTED: True
SOLVER:
LR_SCHEDULER_NAME: WarmupCosineLR
BASE_LR: 0.02
ITERATION_STEP: 1
ITERATION_SAVE: 5000
ITERATION_TOTAL: 100000
SAMPLES_PER_BATCH: 2
MONITOR:
ITERATION_NUM: [40, 400]
INFERENCE:
INPUT_SIZE: [33, 257, 257]
OUTPUT_SIZE: [33, 257, 257]
INPUT_PATH: datasets/barcode_R/
IMAGE_NAME: ["1-xri_deconvolved.tif", "2-xri_deconvolved.tif", "3-xri_deconvolved.tif", "4_1-xri_deconvolved.tif", "4_2-xri_deconvolved.tif", "4_3-xri_deconvolved.tif", "5_1-xri_deconvolved.tif", "5_2-xri_deconvolved.tif", "6_1-xri_deconvolved.tif", "6_2-xri_deconvolved.tif"]
# IMAGE_NAME: 3-xri_deconvolved.tif
OUTPUT_PATH: outputs/barcode_R/test/
OUTPUT_NAME: result.h5
PAD_SIZE: [16, 32, 32]
AUG_MODE: "mean"
AUG_NUM: None
STRIDE: [26, 128, 128]
SAMPLES_PER_BATCH: 4
6 changes: 6 additions & 0 deletions connectomics/data/utils/data_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,12 @@ def seg_to_targets(
label = maximum_filter(label, dilation)
distance = seg2inst_edt(label, topt[2+index+1:])
out[tid] = distance[np.newaxis, :].astype(np.float32)
elif topt[0] == 'a':
# "8-{0 if no quantize, 1 if quantize}-{z_res}-{y_res}-{x_res}"
_, quantize, z_res, y_res, x_res = topt.split('-')
quantize = bool(int(quantize))
z_res, y_res, x_res = float(z_res), float(y_res), float(x_res)
out[tid] = sdt_instance(label, quantize=quantize, resolution=(z_res, y_res, x_res))
elif topt[0] == '9': # generic semantic segmentation
out[tid] = label.astype(np.int64)
else:
Expand Down
80 changes: 52 additions & 28 deletions connectomics/data/utils/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
from skimage.morphology import remove_small_holes, skeletonize, binary_erosion, disk, ball
from skimage.measure import label as label_cc # avoid namespace conflict
from skimage.filters import gaussian
from em_util.io import compute_bbox_all

from .data_misc import get_padsize, array_unpad


__all__ = [
'edt_semantic',
'edt_instance',
Expand Down Expand Up @@ -89,28 +91,19 @@ def edt_instance(label: np.ndarray,


def sdt_instance(label: np.ndarray,
mode: str = '2d',
mode: str = '3d',
quantize: bool = True,
resolution: Tuple[float] = (1.0, 1.0),
resolution: Tuple[float] = (1.0, 1.0, 1.0),
padding: bool = True):
"""Skeleton-based distance transform (SDT) for a stack of label images.

Lin, Zudi, et al. "Structure-Preserving Instance Segmentation via Skeleton-Aware
Distance Transform." International Conference on Medical Image Computing and
Computer-Assisted Intervention. Cham: Springer Nature Switzerland, 2023.
"""
assert mode == "2d", "Only 2d skeletonization is currently supported."
assert mode == "3d", "Only 3D mode is supported, revert to other branch"
vol_distance, vol_semantic = skeleton_aware_distance_transform(label, padding=padding, resolution=resolution)

vol_distance = []
vol_semantic = []
for i in range(label.shape[0]):
label_img = label[i].copy()
distance, semantic = skeleton_aware_distance_transform(label_img, padding=padding)
vol_distance.append(distance)
vol_semantic.append(semantic)

vol_distance = np.stack(vol_distance, 0)
vol_semantic = np.stack(vol_semantic, 0)
if quantize:
vol_distance = energy_quantize(vol_distance)

Expand Down Expand Up @@ -186,19 +179,36 @@ def smooth_edge(binary, smooth_sigma: float = 2.0, smooth_threshold: float = 0.5
return binary


def pad_bbox(bbox, shape, pad_size: int):
bbox[:, 1] = bbox[:, 1] - pad_size
bbox[:, 3] = bbox[:, 3] - pad_size
bbox[:, 5] = bbox[:, 5] - pad_size
bbox[:, 2] = bbox[:, 2] + pad_size
bbox[:, 4] = bbox[:, 4] + pad_size
bbox[:, 6] = bbox[:, 6] + pad_size

bbox[:, 1] = np.maximum(bbox[:, 1], 0)
bbox[:, 3] = np.maximum(bbox[:, 3], 0)
bbox[:, 5] = np.maximum(bbox[:, 5], 0)
bbox[:, 2] = np.minimum(bbox[:, 2], shape[0] - 1)
bbox[:, 4] = np.minimum(bbox[:, 4], shape[1] - 1)
bbox[:, 6] = np.minimum(bbox[:, 6], shape[2] - 1)

return bbox

def skeleton_aware_distance_transform(
label: np.ndarray,
bg_value: float = -1.0,
relabel: bool = True,
padding: bool = False,
resolution: Tuple[float] = (1.0, 1.0),
resolution: Tuple[float] = (1.0, 1.0, 1.0),
alpha: float = 0.8,
smooth: bool = True,
smooth_skeleton_only: bool = True,
):
"""Skeleton-based distance transform (SDT).

Lin, Zudi, et al. "Structure-Preserving Instance Segmentation via Skeleton-Aware
Lin, Zudi, et al. "Structure-Preserving Instance Segmentation via Skeleton-Aware
Distance Transform." International Conference on Medical Image Computing and
Computer-Assisted Intervention. Cham: Springer Nature Switzerland, 2023.
"""
Expand All @@ -212,7 +222,7 @@ def skeleton_aware_distance_transform(
# The distance_transform_edt function does not treat image border
# as background. If image border needs to be considered as background
# in distance calculation, set padding to True.
label = np.pad(label, pad_size, mode='constant', constant_values=0)
label = np.pad(label, pad_size, mode="constant", constant_values=0)

label_shape = label.shape
all_bg_sample = False
Expand All @@ -222,15 +232,28 @@ def skeleton_aware_distance_transform(
semantic = np.zeros(label_shape, dtype=np.uint8)

indices = np.unique(label)

# [N, 7]: [label, z0, z1, y0, y1, x0, x1]
bbox = compute_bbox_all(label, uid = indices)
# NOTE: maybe unnecessary, but just in case
bbox = pad_bbox(bbox, label_shape, pad_size)

if indices[0] == 0:
if len(indices) > 1: # exclude background
indices = indices[1:]
assert bbox[0, 0] == 0, "Missing background bbox"
bbox = bbox[1:]
else: # all-background sample
all_bg_sample = True

if not all_bg_sample:
for idx in indices:
temp2 = remove_small_holes(label == idx, 16, connectivity=1)
for i, idx in enumerate(indices):
assert bbox[i, 0] == idx, "Mismatched label and bbox"
assert np.all(bbox[i, 1:] >= 0), "Negative bbox coordinates"
z0, z1, y0, y1, x0, x1 = bbox[i, 1:]

temp1 = label[z0:z1+1, y0:y1+1, x0:x1+1].copy() == idx
temp2 = remove_small_holes(temp1, 16, connectivity=1)
binary = temp2.copy()

if smooth:
Expand All @@ -245,32 +268,33 @@ def skeleton_aware_distance_transform(
else:
temp2 = binary.copy()

semantic += temp2.astype(np.uint8)
semantic[z0:z1+1, y0:y1+1, x0:x1+1] += temp2.astype(np.uint8)

skeleton_mask = skeletonize(binary)
skeleton_mask = (skeleton_mask != 0).astype(np.uint8)
skeleton += skeleton_mask
skeleton[z0:z1+1, y0:y1+1, x0:x1+1] += skeleton_mask

skeleton_edt = distance_transform_edt(1-skeleton_mask, resolution)
skeleton_edt = distance_transform_edt(1 - skeleton_mask, resolution)
boundary_edt = distance_transform_edt(temp2, resolution)

energy = boundary_edt / (skeleton_edt + boundary_edt + eps) # normalize
energy = energy ** alpha
distance = np.maximum(distance, energy * temp2.astype(np.float32))
energy = boundary_edt / (skeleton_edt + boundary_edt + eps) # normalize
energy = energy**alpha
distance[z0:z1+1, y0:y1+1, x0:x1+1] = np.maximum(
distance[z0:z1+1, y0:y1+1, x0:x1+1], energy * temp2.astype(np.float32)
)

if bg_value != 0:
distance[distance==0] = bg_value

if padding:
# Unpad the output array to preserve original shape.
distance = array_unpad(distance, get_padsize(
pad_size, ndim=distance.ndim))
semantic = array_unpad(semantic, get_padsize(
pad_size, ndim=distance.ndim))
distance = array_unpad(distance, get_padsize(pad_size, ndim=distance.ndim))
semantic = array_unpad(semantic, get_padsize(pad_size, ndim=distance.ndim))

return distance, semantic



def energy_quantize(energy, levels=10):
"""Convert the continuous energy map into the quantized version.
"""
Expand Down
7 changes: 3 additions & 4 deletions connectomics/model/utils/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ def model_init(model, mode='orthogonal'):
'selu': selu_init,
'orthogonal': ortho_init,
}
# Applies fn recursively to every submodule (as returned by .children()) as well
# as self. See https://pytorch.org/docs/stable/generated/torch.nn.Module.html.
model.apply(model_init_dict[mode])
# no need to do model.apply since each init function already iterates through model.modules()
model_init_dict[mode](model)

def xavier_init(model):
# sxavier initialization
Expand All @@ -38,7 +37,7 @@ def selu_init(model):
nn.init.normal(m.weight, 0, sqrt(1. / fan_in))
elif isinstance(m, nn.Linear):
fan_in = m.in_features
nn.init.normal(m.weight, 0, sqrt(1. / fan_in))
nn.init.normal_(m.weight, 0, sqrt(1. / fan_in))

def ortho_init(model):
# orthogonal initialization
Expand Down
2 changes: 2 additions & 0 deletions connectomics/model/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class SplitActivation(object):
'5': 1, # instance edt (11 channels for quantized)
'6': 1, # semantic edt
'7': 2, # diffusion gradients (2d)
'8': 1, # skeleton dilation
'a': 1, # skeleton aware distance transform
'all': -1 # all remaining channels
}

Expand Down
4 changes: 4 additions & 0 deletions connectomics/utils/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,14 @@ def visualize(self, volume, label, output, weight, iter_total, writer,
:, np.newaxis]
label[idx] = temp_label / temp_label.max() + 1e-6


if topt[0]=='7': # diffusion gradient
output[idx] = dx_to_circ(output[idx])
label[idx] = dx_to_circ(label[idx])

if topt[0] == "a": # skeletonization-aware distance transform
label[idx] = label[idx][:, np.newaxis]

RGB = (topt[0] in ['1', '2', '7', '9'])
vis_name = self.cfg.MODEL.TARGET_OPT[idx] + '_' + str(idx)
if suffix is not None:
Expand Down