Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion guided_diffusion/dpm_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(
continuous_beta_1=20.,
dtype=torch.float32,
):
"""Create a wrapper class for the forward SDE (VP type).
r"""Create a wrapper class for the forward SDE (VP type).
***
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
Expand Down
4 changes: 2 additions & 2 deletions guided_diffusion/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1761,7 +1761,7 @@ def _internal_predict_3D_3Dconv(self, x: np.ndarray, min_size: Tuple[int, ...],

def _internal_maybe_mirror_and_pred_3D(self, x: Union[np.ndarray, torch.tensor], mirror_axes: tuple,
do_mirroring: bool = True,
mult: np.ndarray or torch.tensor = None) -> torch.tensor:
mult: Union[np.ndarray, torch.tensor] = None) -> torch.tensor:
assert len(x.shape) == 5, 'x must be (b, c, x, y, z)'

# if cuda available:
Expand Down Expand Up @@ -1828,7 +1828,7 @@ def _internal_maybe_mirror_and_pred_3D(self, x: Union[np.ndarray, torch.tensor],

def _internal_maybe_mirror_and_pred_2D(self, x: Union[np.ndarray, torch.tensor], mirror_axes: tuple,
do_mirroring: bool = True,
mult: np.ndarray or torch.tensor = None) -> torch.tensor:
mult: Union[np.ndarray, torch.tensor] = None) -> torch.tensor:
# if cuda available:
# everything in here takes place on the GPU. If x and mult are not yet on GPU this will be taken care of here
# we now return a cuda tensor! Not numpy array!
Expand Down
2 changes: 1 addition & 1 deletion scripts/segmentation_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def main():

ds = BRATSDataset3D(args.data_dir, transform_train, test_flag=False)
args.in_ch = 5
elif any(Path(args.data_dir).glob("*\*.nii.gz")):
elif any(Path(args.data_dir).glob("*/*.nii.gz")):
tran_list = [transforms.Resize((args.image_size,args.image_size)),]
transform_train = transforms.Compose(tran_list)
print("Your current directory : ",args.data_dir)
Expand Down