Skip to content

Commit 05df18f

Browse files
committed
Only allow a single auth provider that's controlled by enabled_auth
1 parent 39888ab commit 05df18f

2 files changed

Lines changed: 22 additions & 3 deletions

File tree

llms/extensions/github_auth/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
def install(ctx):
1313
g_app = ctx.app
1414

15+
enabled_auth = ctx.enabled_auth()
16+
if enabled_auth != "github_auth":
17+
ctx.log(f"{enabled_auth} is enabled, skipping github_auth auth provider.")
18+
ctx.disabled = True
19+
return
20+
1521
auth_config_file = os.path.join(ctx.get_user_path(), "github_auth", "config.json")
1622

1723
auth_config = None

llms/main.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
MOCK = os.getenv("MOCK") == "1"
6363
MOCK_DIR = os.getenv("MOCK_DIR")
6464
LLMS_MODE = os.getenv("LLMS_MODE", "local")
65+
LLMS_AUTH = os.getenv("LLMS_AUTH", "credentials")
6566
DISABLE_EXTENSIONS = (os.getenv("LLMS_DISABLE") or "").split(",")
6667
DEFAULT_LIMITS = {
6768
"client_timeout": 120,
@@ -2889,6 +2890,7 @@ def __init__(self, cli_args: argparse.Namespace, extra_args: Dict[str, Any]):
28892890
self.config = None
28902891
self.limits = DEFAULT_LIMITS
28912892
self.mode = LLMS_MODE
2893+
self.auth_extension = LLMS_AUTH
28922894
self.is_local = self.mode == "local"
28932895
self.error_auth_required = create_error_response("Authentication required", "Unauthorized")
28942896
self.ui_extensions = []
@@ -2991,6 +2993,10 @@ def get_allowed_directories(self) -> List[str]:
29912993
"""Get the list of allowed directories."""
29922994
return self.allowed_directories
29932995

2996+
def enabled_auth(self) -> str:
2997+
"""Get the enabled auth extension."""
2998+
return self.auth_extension
2999+
29943000
def set_auth_provider(self, auth_provider: AuthProvider) -> None:
29953001
"""Add an authentication provider."""
29963002
self.auth_provider = auth_provider
@@ -3161,6 +3167,9 @@ def __init__(self, app: AppExtensions, path: str):
31613167
def get_client_timeout(self):
31623168
return self.app.get_client_timeout()
31633169

3170+
def enabled_auth(self) -> bool:
3171+
return self.app.enabled_auth()
3172+
31643173
def set_auth_provider(self, auth_provider: AuthProvider) -> None:
31653174
"""Add an authentication provider."""
31663175
self.app.set_auth_provider(auth_provider)
@@ -3785,6 +3794,7 @@ def create_arg_parser():
37853794
metavar="TYPE",
37863795
)
37873796

3797+
parser.add_argument("--auth", default=None, help="Which Auth Provider to use", metavar="EXTENSION")
37883798
parser.add_argument("--logprefix", default="", help="Prefix used in log messages", metavar="PREFIX")
37893799
parser.add_argument("--verbose", action="store_true", help="Verbose output")
37903800

@@ -3817,7 +3827,7 @@ def create_arg_parser():
38173827

38183828

38193829
def cli_exec(cli_args, extra_args):
3820-
global _ROOT, g_verbose, g_default_model, g_logprefix, g_providers, g_config, g_config_path, g_app
3830+
global _ROOT, LLMS_AUTH, g_verbose, g_default_model, g_logprefix, g_providers, g_config, g_config_path, g_app
38213831

38223832
verify_root_path()
38233833

@@ -3857,6 +3867,9 @@ def cli_exec(cli_args, extra_args):
38573867
print(f"Created default extra providers config at {home_providers_extra_path}")
38583868
return ExitCode.SUCCESS
38593869

3870+
if cli_args.auth:
3871+
LLMS_AUTH = cli_args.auth
3872+
38603873
if cli_args.providers:
38613874
if not os.path.exists(cli_args.providers):
38623875
print(f"providers.json not found at {cli_args.providers}")
@@ -4691,7 +4704,7 @@ async def start_background_tasks(app):
46914704
try:
46924705
stdin_chat = json.loads(stdin_data)
46934706
except json.JSONDecodeError:
4694-
print(f"Invalid JSON from stdin")
4707+
print("Invalid JSON from stdin")
46954708
return ExitCode.FAILED
46964709

46974710
if (
@@ -4728,7 +4741,7 @@ async def start_background_tasks(app):
47284741
chat_json = f.read()
47294742
chat = json.loads(chat_json)
47304743
elif stdin_chat is not None:
4731-
_log(f"Using chat from stdin")
4744+
_log("Using chat from stdin")
47324745
chat = stdin_chat
47334746

47344747
if cli_args.system is not None:

0 commit comments

Comments
 (0)