2424from monai .losses .focal_loss import FocalLoss
2525from monai .losses .spatial_mask import MaskedLoss
2626from monai .networks import one_hot
27- from monai .utils import DiceCEReduction , LossReduction , Weight , look_up_option , pytorch_after
27+ from monai .utils import DiceCEReduction , LossReduction , Weight , deprecated_arg , look_up_option , pytorch_after
2828
2929
3030class DiceLoss (_Loss ):
@@ -57,6 +57,7 @@ def __init__(
5757 smooth_nr : float = 1e-5 ,
5858 smooth_dr : float = 1e-5 ,
5959 batch : bool = False ,
60+ weight : Sequence [float ] | float | int | torch .Tensor | None = None ,
6061 ) -> None :
6162 """
6263 Args:
@@ -83,6 +84,11 @@ def __init__(
8384 batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
8485 Defaults to False, a Dice loss value is computed independently from each item in the batch
8586 before any `reduction`.
87+ weight: weights to apply to the voxels of each class. If None no weights are applied.
88+ The input can be a single value (same weight for all classes), a sequence of values (the length
89+ of the sequence should be the same as the number of classes. If not ``include_background``,
90+ the number of classes should not include the background category class 0).
91+ The value/values should be no less than 0. Defaults to None.
8692
8793 Raises:
8894 TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -105,6 +111,8 @@ def __init__(
105111 self .smooth_nr = float (smooth_nr )
106112 self .smooth_dr = float (smooth_dr )
107113 self .batch = batch
114+ self .weight = weight
115+ self .register_buffer ("class_weight" , torch .ones (1 ))
108116
109117 def forward (self , input : torch .Tensor , target : torch .Tensor ) -> torch .Tensor :
110118 """
@@ -181,6 +189,24 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
181189
182190 f : torch .Tensor = 1.0 - (2.0 * intersection + self .smooth_nr ) / (denominator + self .smooth_dr )
183191
192+ if self .weight is not None and target .shape [1 ] != 1 :
193+ # make sure the lengths of weights are equal to the number of classes
194+ num_of_classes = target .shape [1 ]
195+ if isinstance (self .weight , (float , int )):
196+ self .class_weight = torch .as_tensor ([self .weight ] * num_of_classes )
197+ else :
198+ self .class_weight = torch .as_tensor (self .weight )
199+ if self .class_weight .shape [0 ] != num_of_classes :
200+ raise ValueError (
201+ """the length of the `weight` sequence should be the same as the number of classes.
202+ If `include_background=False`, the weight should not include
203+ the background category class 0."""
204+ )
205+ if self .class_weight .min () < 0 :
206+ raise ValueError ("the value/values of the `weight` should be no less than 0." )
207+ # apply class_weight to loss
208+ f = f * self .class_weight .to (f )
209+
184210 if self .reduction == LossReduction .MEAN .value :
185211 f = torch .mean (f ) # the batch and channel average
186212 elif self .reduction == LossReduction .SUM .value :
@@ -620,6 +646,9 @@ class DiceCELoss(_Loss):
620646
621647 """
622648
649+ @deprecated_arg (
650+ "ce_weight" , since = "1.2" , removed = "1.4" , new_name = "weight" , msg_suffix = "please use `weight` instead."
651+ )
623652 def __init__ (
624653 self ,
625654 include_background : bool = True ,
@@ -634,13 +663,14 @@ def __init__(
634663 smooth_dr : float = 1e-5 ,
635664 batch : bool = False ,
636665 ce_weight : torch .Tensor | None = None ,
666+ weight : torch .Tensor | None = None ,
637667 lambda_dice : float = 1.0 ,
638668 lambda_ce : float = 1.0 ,
639669 ) -> None :
640670 """
641671 Args:
642- ``ce_weight`` and `` lambda_ce`` are only used for cross entropy loss.
643- ``reduction`` is used for both losses and other parameters are only used for dice loss.
672+ ``lambda_ce`` are only used for cross entropy loss.
673+ ``reduction`` and ``weight`` is used for both losses and other parameters are only used for dice loss.
644674
645675 include_background: if False channel index 0 (background category) is excluded from the calculation.
646676 to_onehot_y: whether to convert the ``target`` into the one-hot format,
@@ -666,9 +696,10 @@ def __init__(
666696 batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
667697 Defaults to False, a Dice loss value is computed independently from each item in the batch
668698 before any `reduction`.
669- ce_weight : a rescaling weight given to each class for cross entropy loss for `CrossEntropyLoss`.
670- or a rescaling weight given to the loss of each batch element for `BCEWithLogitsLoss`.
699+ weight : a rescaling weight given to each class for cross entropy loss for `CrossEntropyLoss`.
700+ or a weight of positive examples to be broadcasted with target used as `pos_weight` for `BCEWithLogitsLoss`.
671701 See ``torch.nn.CrossEntropyLoss()`` or ``torch.nn.BCEWithLogitsLoss()`` for more information.
702+ The weight is also used in `DiceLoss`.
672703 lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
673704 Defaults to 1.0.
674705 lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
@@ -677,6 +708,12 @@ def __init__(
677708 """
678709 super ().__init__ ()
679710 reduction = look_up_option (reduction , DiceCEReduction ).value
711+ weight = ce_weight if ce_weight is not None else weight
712+ dice_weight : torch .Tensor | None
713+ if weight is not None and not include_background :
714+ dice_weight = weight [1 :]
715+ else :
716+ dice_weight = weight
680717 self .dice = DiceLoss (
681718 include_background = include_background ,
682719 to_onehot_y = to_onehot_y ,
@@ -689,9 +726,10 @@ def __init__(
689726 smooth_nr = smooth_nr ,
690727 smooth_dr = smooth_dr ,
691728 batch = batch ,
729+ weight = dice_weight ,
692730 )
693- self .cross_entropy = nn .CrossEntropyLoss (weight = ce_weight , reduction = reduction )
694- self .binary_cross_entropy = nn .BCEWithLogitsLoss (weight = ce_weight , reduction = reduction )
731+ self .cross_entropy = nn .CrossEntropyLoss (weight = weight , reduction = reduction )
732+ self .binary_cross_entropy = nn .BCEWithLogitsLoss (pos_weight = weight , reduction = reduction )
695733 if lambda_dice < 0.0 :
696734 raise ValueError ("lambda_dice should be no less than 0.0." )
697735 if lambda_ce < 0.0 :
@@ -762,12 +800,15 @@ class DiceFocalLoss(_Loss):
762800 The details of Dice loss is shown in ``monai.losses.DiceLoss``.
763801 The details of Focal Loss is shown in ``monai.losses.FocalLoss``.
764802
765- ``gamma``, ``focal_weight`` and ``lambda_focal`` are only used for the focal loss.
766- ``include_background`` and ``reduction`` are used for both losses
803+ ``gamma`` and ``lambda_focal`` are only used for the focal loss.
804+ ``include_background``, ``weight`` and ``reduction`` are used for both losses
767805 and other parameters are only used for dice loss.
768806
769807 """
770808
809+ @deprecated_arg (
810+ "focal_weight" , since = "1.2" , removed = "1.4" , new_name = "weight" , msg_suffix = "please use `weight` instead."
811+ )
771812 def __init__ (
772813 self ,
773814 include_background : bool = True ,
@@ -783,6 +824,7 @@ def __init__(
783824 batch : bool = False ,
784825 gamma : float = 2.0 ,
785826 focal_weight : Sequence [float ] | float | int | torch .Tensor | None = None ,
827+ weight : Sequence [float ] | float | int | torch .Tensor | None = None ,
786828 lambda_dice : float = 1.0 ,
787829 lambda_focal : float = 1.0 ,
788830 ) -> None :
@@ -812,7 +854,7 @@ def __init__(
812854 Defaults to False, a Dice loss value is computed independently from each item in the batch
813855 before any `reduction`.
814856 gamma: value of the exponent gamma in the definition of the Focal loss.
815- focal_weight : weights to apply to the voxels of each class. If None no weights are applied.
857+ weight : weights to apply to the voxels of each class. If None no weights are applied.
816858 The input can be a single value (same weight for all classes), a sequence of values (the length
817859 of the sequence should be the same as the number of classes).
818860 lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
@@ -822,6 +864,7 @@ def __init__(
822864
823865 """
824866 super ().__init__ ()
867+ weight = focal_weight if focal_weight is not None else weight
825868 self .dice = DiceLoss (
826869 include_background = include_background ,
827870 to_onehot_y = False ,
@@ -834,13 +877,10 @@ def __init__(
834877 smooth_nr = smooth_nr ,
835878 smooth_dr = smooth_dr ,
836879 batch = batch ,
880+ weight = weight ,
837881 )
838882 self .focal = FocalLoss (
839- include_background = include_background ,
840- to_onehot_y = False ,
841- gamma = gamma ,
842- weight = focal_weight ,
843- reduction = reduction ,
883+ include_background = include_background , to_onehot_y = False , gamma = gamma , weight = weight , reduction = reduction
844884 )
845885 if lambda_dice < 0.0 :
846886 raise ValueError ("lambda_dice should be no less than 0.0." )
@@ -879,7 +919,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
879919 return total_loss
880920
881921
882- class GeneralizedDiceFocalLoss (torch . nn . modules . loss . _Loss ):
922+ class GeneralizedDiceFocalLoss (_Loss ):
883923 """Compute both Generalized Dice Loss and Focal Loss, and return their weighted average. The details of Generalized Dice Loss
884924 and Focal Loss are available at ``monai.losses.GeneralizedDiceLoss`` and ``monai.losses.FocalLoss``.
885925
@@ -905,7 +945,7 @@ class GeneralizedDiceFocalLoss(torch.nn.modules.loss._Loss):
905945 batch (bool, optional): whether to sum the intersection and union areas over the batch dimension before the dividing.
906946 Defaults to False, i.e., the areas are computed for each item in the batch.
907947 gamma (float, optional): value of the exponent gamma in the definition of the Focal loss. Defaults to 2.0.
908- focal_weight (Optional[Union[Sequence[float], float, int, torch.Tensor]], optional): weights to apply to
948+ weight (Optional[Union[Sequence[float], float, int, torch.Tensor]], optional): weights to apply to
909949 the voxels of each class. If None no weights are applied. The input can be a single value
910950 (same weight for all classes), a sequence of values (the length of the sequence hould be the same as
911951 the number of classes). Defaults to None.
@@ -918,6 +958,9 @@ class GeneralizedDiceFocalLoss(torch.nn.modules.loss._Loss):
918958 ValueError: if either `lambda_gdl` or `lambda_focal` is less than 0.
919959 """
920960
961+ @deprecated_arg (
962+ "focal_weight" , since = "1.2" , removed = "1.4" , new_name = "weight" , msg_suffix = "please use `weight` instead."
963+ )
921964 def __init__ (
922965 self ,
923966 include_background : bool = True ,
@@ -932,6 +975,7 @@ def __init__(
932975 batch : bool = False ,
933976 gamma : float = 2.0 ,
934977 focal_weight : Sequence [float ] | float | int | torch .Tensor | None = None ,
978+ weight : Sequence [float ] | float | int | torch .Tensor | None = None ,
935979 lambda_gdl : float = 1.0 ,
936980 lambda_focal : float = 1.0 ,
937981 ) -> None :
@@ -948,11 +992,12 @@ def __init__(
948992 smooth_dr = smooth_dr ,
949993 batch = batch ,
950994 )
995+ weight = focal_weight if focal_weight is not None else weight
951996 self .focal = FocalLoss (
952997 include_background = include_background ,
953998 to_onehot_y = to_onehot_y ,
954999 gamma = gamma ,
955- weight = focal_weight ,
1000+ weight = weight ,
9561001 reduction = reduction ,
9571002 )
9581003 if lambda_gdl < 0.0 :
0 commit comments