diff --git a/train_ddpm/main.py b/train_ddpm/main.py index 4ca3a3c..9ecb9c9 100755 --- a/train_ddpm/main.py +++ b/train_ddpm/main.py @@ -219,8 +219,12 @@ def main(): logging.info("Exp comment = {}".format(args.comment)) try: - runner = ConditionalDiffusion(args, config) - # runner = Diffusion(args, config) + if config.model.type == "simple": + runner = Diffusion(args, config) + elif config.model.type == "conditional": + runner = ConditionalDiffusion(args, config) + else: + raise AssertionError if args.sample: runner.sample() elif args.test: