|
62 | 62 | MOCK = os.getenv("MOCK") == "1" |
63 | 63 | MOCK_DIR = os.getenv("MOCK_DIR") |
64 | 64 | LLMS_MODE = os.getenv("LLMS_MODE", "local") |
| 65 | +LLMS_AUTH = os.getenv("LLMS_AUTH", "credentials") |
65 | 66 | DISABLE_EXTENSIONS = (os.getenv("LLMS_DISABLE") or "").split(",") |
66 | 67 | DEFAULT_LIMITS = { |
67 | 68 | "client_timeout": 120, |
@@ -2889,6 +2890,7 @@ def __init__(self, cli_args: argparse.Namespace, extra_args: Dict[str, Any]): |
2889 | 2890 | self.config = None |
2890 | 2891 | self.limits = DEFAULT_LIMITS |
2891 | 2892 | self.mode = LLMS_MODE |
| 2893 | + self.auth_extension = LLMS_AUTH |
2892 | 2894 | self.is_local = self.mode == "local" |
2893 | 2895 | self.error_auth_required = create_error_response("Authentication required", "Unauthorized") |
2894 | 2896 | self.ui_extensions = [] |
@@ -2991,6 +2993,10 @@ def get_allowed_directories(self) -> List[str]: |
2991 | 2993 | """Get the list of allowed directories.""" |
2992 | 2994 | return self.allowed_directories |
2993 | 2995 |
|
| 2996 | + def enabled_auth(self) -> str: |
| 2997 | + """Get the enabled auth extension.""" |
| 2998 | + return self.auth_extension |
| 2999 | + |
2994 | 3000 | def set_auth_provider(self, auth_provider: AuthProvider) -> None: |
2995 | 3001 | """Add an authentication provider.""" |
2996 | 3002 | self.auth_provider = auth_provider |
@@ -3161,6 +3167,9 @@ def __init__(self, app: AppExtensions, path: str): |
3161 | 3167 | def get_client_timeout(self): |
3162 | 3168 | return self.app.get_client_timeout() |
3163 | 3169 |
|
| 3170 | + def enabled_auth(self) -> bool: |
| 3171 | + return self.app.enabled_auth() |
| 3172 | + |
3164 | 3173 | def set_auth_provider(self, auth_provider: AuthProvider) -> None: |
3165 | 3174 | """Add an authentication provider.""" |
3166 | 3175 | self.app.set_auth_provider(auth_provider) |
@@ -3785,6 +3794,7 @@ def create_arg_parser(): |
3785 | 3794 | metavar="TYPE", |
3786 | 3795 | ) |
3787 | 3796 |
|
| 3797 | + parser.add_argument("--auth", default=None, help="Which Auth Provider to use", metavar="EXTENSION") |
3788 | 3798 | parser.add_argument("--logprefix", default="", help="Prefix used in log messages", metavar="PREFIX") |
3789 | 3799 | parser.add_argument("--verbose", action="store_true", help="Verbose output") |
3790 | 3800 |
|
@@ -3817,7 +3827,7 @@ def create_arg_parser(): |
3817 | 3827 |
|
3818 | 3828 |
|
3819 | 3829 | def cli_exec(cli_args, extra_args): |
3820 | | - global _ROOT, g_verbose, g_default_model, g_logprefix, g_providers, g_config, g_config_path, g_app |
| 3830 | + global _ROOT, LLMS_AUTH, g_verbose, g_default_model, g_logprefix, g_providers, g_config, g_config_path, g_app |
3821 | 3831 |
|
3822 | 3832 | verify_root_path() |
3823 | 3833 |
|
@@ -3857,6 +3867,9 @@ def cli_exec(cli_args, extra_args): |
3857 | 3867 | print(f"Created default extra providers config at {home_providers_extra_path}") |
3858 | 3868 | return ExitCode.SUCCESS |
3859 | 3869 |
|
| 3870 | + if cli_args.auth: |
| 3871 | + LLMS_AUTH = cli_args.auth |
| 3872 | + |
3860 | 3873 | if cli_args.providers: |
3861 | 3874 | if not os.path.exists(cli_args.providers): |
3862 | 3875 | print(f"providers.json not found at {cli_args.providers}") |
@@ -4691,7 +4704,7 @@ async def start_background_tasks(app): |
4691 | 4704 | try: |
4692 | 4705 | stdin_chat = json.loads(stdin_data) |
4693 | 4706 | except json.JSONDecodeError: |
4694 | | - print(f"Invalid JSON from stdin") |
| 4707 | + print("Invalid JSON from stdin") |
4695 | 4708 | return ExitCode.FAILED |
4696 | 4709 |
|
4697 | 4710 | if ( |
@@ -4728,7 +4741,7 @@ async def start_background_tasks(app): |
4728 | 4741 | chat_json = f.read() |
4729 | 4742 | chat = json.loads(chat_json) |
4730 | 4743 | elif stdin_chat is not None: |
4731 | | - _log(f"Using chat from stdin") |
| 4744 | + _log("Using chat from stdin") |
4732 | 4745 | chat = stdin_chat |
4733 | 4746 |
|
4734 | 4747 | if cli_args.system is not None: |
|
0 commit comments