|
14 | 14 | # limitations under the License. |
15 | 15 | """ |
16 | 16 |
|
| 17 | +import argparse |
17 | 18 | import json |
18 | 19 | from dataclasses import asdict, dataclass |
19 | 20 | from dataclasses import fields as dataclass_fields |
@@ -190,7 +191,7 @@ class EngineArgs: |
190 | 191 | """ |
191 | 192 | Flag to indicate whether to use warm-up before inference. |
192 | 193 | """ |
193 | | - enable_prefix_caching: bool = False |
| 194 | + enable_prefix_caching: bool = True |
194 | 195 | """ |
195 | 196 | Flag to enable prefix caching. |
196 | 197 | """ |
@@ -387,6 +388,16 @@ def __post_init__(self): |
387 | 388 | """ |
388 | 389 | if not self.tokenizer: |
389 | 390 | self.tokenizer = self.model |
| 391 | + if self.splitwise_role == "decode": |
| 392 | + self.enable_prefix_caching = False |
| 393 | + if self.speculative_config is not None: |
| 394 | + self.enable_prefix_caching = False |
| 395 | + if self.enable_mm: |
| 396 | + self.enable_prefix_caching = False |
| 397 | + if not current_platform.is_cuda(): |
| 398 | + self.enable_prefix_caching = False |
| 399 | + if self.dynamic_load_weight: |
| 400 | + self.enable_prefix_caching = False |
390 | 401 | if self.enable_logprob: |
391 | 402 | if self.speculative_config is not None: |
392 | 403 | raise NotImplementedError("Logprob does not support speculation_config.") |
@@ -725,7 +736,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: |
725 | 736 | perf_group = parser.add_argument_group("Performance Tuning") |
726 | 737 | perf_group.add_argument( |
727 | 738 | "--enable-prefix-caching", |
728 | | - action="store_true", |
| 739 | + action=argparse.BooleanOptionalAction, |
729 | 740 | default=EngineArgs.enable_prefix_caching, |
730 | 741 | help="Flag to enable prefix caching.", |
731 | 742 | ) |
|
0 commit comments