Skip to content

Commit 8b1e811

Browse files
1 parent 3e5279d commit 8b1e811

17 files changed

Lines changed: 383 additions & 236 deletions

auth_backend/__main__.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,6 @@
1-
import logging
2-
31
import uvicorn
42

53
from auth_backend.routes import app
64

75
if __name__ == '__main__':
8-
9-
logging.basicConfig(
10-
filename=f'logger_{__name__}.log',
11-
level=logging.INFO,
12-
format='%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s',
13-
datefmt='%Y-%m-%d %H:%M:%S',
14-
)
15-
166
uvicorn.run(app)

auth_backend/auth_plugins/email.py

Lines changed: 125 additions & 79 deletions
Large diffs are not rendered by default.

auth_backend/models/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def create(cls, *, session: Session, **kwargs) -> BaseDbModel:
3939
return obj
4040

4141
@classmethod
42-
def get_all(cls, *, with_deleted: bool = False, session: Session) -> Query:
42+
def query(cls, *, with_deleted: bool = False, session: Session) -> Query:
4343
"""Get all objects with soft deletes"""
4444
objs = session.query(cls)
4545
if not with_deleted and hasattr(cls, "is_deleted"):

auth_backend/routes/groups.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from fastapi_sqlalchemy import db
55

66
from auth_backend.exceptions import ObjectNotFound, AlreadyExists
7-
from auth_backend.models.db import Group as DbGroup
8-
from .models.models import Group, GroupPost, GroupsGet, GroupPatch, GroupGet
9-
from ..base import ResponseModel
10-
from ..utils.security import UnionAuth
7+
from auth_backend.models.db import Group as DbGroup, UserSession
8+
from auth_backend.routes.models.models import Group, GroupPost, GroupsGet, GroupPatch, GroupGet
9+
from auth_backend.base import ResponseModel
10+
from auth_backend.utils.security import UnionAuth
1111

1212
auth = UnionAuth()
1313

@@ -25,20 +25,20 @@ async def get_group(id: int, info: list[Literal["child"]] = Query(default=[])) -
2525

2626

2727
@groups.post("", response_model=Group)
28-
async def create_group(group_inp: GroupPost, _: dict[str, str] = Depends(auth)) -> Group:
28+
async def create_group(group_inp: GroupPost, _: UserSession = Depends(auth)) -> Group:
2929
if group_inp.parent_id and not db.session.query(DbGroup).get(group_inp.parent_id):
3030
raise ObjectNotFound(Group, group_inp.parent_id)
31-
if DbGroup.get_all(session=db.session).filter(DbGroup.name == group_inp.name).one_or_none():
31+
if DbGroup.query(session=db.session).filter(DbGroup.name == group_inp.name).one_or_none():
3232
raise HTTPException(status_code=409, detail=ResponseModel(status="Error", message="Name already exists").json())
3333
group = DbGroup.create(session=db.session, **group_inp.dict())
3434
db.session.commit()
3535
return Group.from_orm(group)
3636

3737

3838
@groups.patch("/{id}", response_model=Group)
39-
async def patch_group(id: int, group_inp: GroupPatch, _: dict[str, str] = Depends(auth)) -> Group:
39+
async def patch_group(id: int, group_inp: GroupPatch, _: UserSession = Depends(auth)) -> Group:
4040
if (
41-
exists_check := DbGroup.get_all(session=db.session)
41+
exists_check := DbGroup.query(session=db.session)
4242
.filter(DbGroup.name == group_inp.name, DbGroup.id != id)
4343
.one_or_none()
4444
):
@@ -52,11 +52,11 @@ async def patch_group(id: int, group_inp: GroupPatch, _: dict[str, str] = Depend
5252

5353

5454
@groups.delete("/{id}", response_model=None)
55-
async def delete_group(id: int, _: dict[str, str] = Depends(auth)) -> None:
55+
async def delete_group(id: int, _: UserSession = Depends(auth)) -> None:
5656
group: DbGroup = DbGroup.get(id, session=db.session)
5757
if child := group.child:
5858
for children in child:
59-
children.parent = group.parent
59+
children.parent_id = group.parent_id
6060
db.session.flush()
6161
DbGroup.delete(id, session=db.session)
6262
db.session.commit()
@@ -65,4 +65,4 @@ async def delete_group(id: int, _: dict[str, str] = Depends(auth)) -> None:
6565

6666
@groups.get("", response_model=GroupsGet)
6767
async def get_groups() -> GroupsGet:
68-
return GroupsGet(items=DbGroup.get_all(session=db.session).all())
68+
return GroupsGet(items=DbGroup.query(session=db.session).all())

auth_backend/routes/models/models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from __future__ import annotations
2-
import datetime
32

43
from pydantic import Field
54

auth_backend/routes/user_groups.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from starlette.exceptions import HTTPException
44

55
from auth_backend.models.db import Group, UserGroup
6-
from .models.models import UserGroupGet, GroupUserListGet, UserGroupPost
7-
from ..base import ResponseModel
8-
from ..utils.security import UnionAuth
6+
from auth_backend.routes.models.models import UserGroupGet, GroupUserListGet, UserGroupPost
7+
from auth_backend.base import ResponseModel
8+
from auth_backend.utils.security import UnionAuth
99

1010
auth = UnionAuth()
1111
user_groups = APIRouter(prefix="/group/{id}/user", tags=["User Groups"])

auth_backend/routes/user_session.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,37 @@
11
from datetime import datetime
2-
from typing import Literal, Union
2+
from typing import Literal
33

4-
from fastapi import APIRouter, Header, HTTPException, Query
4+
from fastapi import APIRouter, Query, Depends
55
from fastapi_sqlalchemy import db
66
from starlette.responses import JSONResponse
77

88
from auth_backend.base import ResponseModel
9-
from auth_backend.exceptions import AuthFailed
109
from auth_backend.exceptions import SessionExpired
1110
from auth_backend.models.db import UserSession, Group
12-
from .models.models import UserGroups, UserIndirectGroups, UserInfo, UserGet
11+
from auth_backend.routes.models.models import UserGroups, UserIndirectGroups, UserInfo, UserGet
12+
from auth_backend.utils.security import UnionAuth
13+
14+
auth = UnionAuth()
1315

1416
logout_router = APIRouter(prefix="", tags=["Logout"])
1517

1618

1719
@logout_router.post("/logout", response_model=str)
18-
async def logout(token: str = Header(min_length=1)) -> JSONResponse:
19-
session = db.session.query(UserSession).filter(UserSession.token == token).one_or_none()
20-
if not session:
21-
raise AuthFailed(error="Session not found")
20+
async def logout(session: UserSession = Depends(auth)) -> JSONResponse:
2221
if session.expired:
2322
raise SessionExpired(session.token)
2423
session.expires = datetime.utcnow()
2524
db.session.commit()
2625
return JSONResponse(status_code=200, content=ResponseModel(status="Success", message="Logout successful").json())
2726

2827

29-
@logout_router.post("/me", response_model_exclude_unset=True, response_model=UserGet)
28+
@logout_router.get("/me", response_model_exclude_unset=True, response_model=UserGet)
3029
async def me(
31-
token: str = Header(min_length=1), info: list[Literal["groups", "indirect_groups", ""]] = Query(default=[])
30+
session: UserSession = Depends(auth), info: list[Literal["groups", "indirect_groups", ""]] = Query(default=[])
3231
) -> dict[str, str | int]:
33-
if not token:
34-
raise HTTPException(status_code=400, detail=ResponseModel(status="Error", message="Header missing").json())
35-
session: UserSession = db.session.query(UserSession).filter(UserSession.token == token).one_or_none()
36-
if not session:
37-
raise HTTPException(status_code=404, detail=ResponseModel(status="Error", message="Session not found").json())
3832
if session.expired:
39-
raise SessionExpired(token)
40-
result = {}
33+
raise SessionExpired(str(session.token))
34+
result: dict[str, str | int] = {}
4135
result = result | UserInfo(id=session.user_id, email=session.user.auth_methods.email.value).dict()
4236
if "groups" in info:
4337
result = result | UserGroups(groups=session.user.groups).dict()

auth_backend/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ class Settings(BaseSettings):
88

99
EMAIL: str | None
1010
APPLICATION_HOST: str = "localhost"
11-
EMAIL_PASS: str = None
11+
EMAIL_PASS: str | None
1212
SMTP_HOST: str = 'smtp.gmail.com'
1313
SMTP_PORT: int = 587
1414
ENABLED_AUTH_METHODS: list[str] | None

auth_backend/utils/security.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,29 +11,27 @@
1111
class UnionAuth(SecurityBase):
1212
model = APIKey.construct(in_=APIKeyIn.header, name="Authorization")
1313
scheme_name = "token"
14-
auth_url: str
1514

16-
def __init__(self, auth_url: str = "", auto_error=True) -> None:
15+
def __init__(self, auto_error=True) -> None:
1716
super().__init__()
1817
self.auto_error = auto_error
19-
self.auth_url = auth_url
2018

2119
def _except(self):
2220
if self.auto_error:
2321
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Not authenticated")
2422
else:
25-
return {}
23+
return None
2624

2725
async def __call__(
2826
self,
2927
request: Request,
30-
) -> dict[str, str | int]:
28+
) -> UserSession:
3129
token = request.headers.get("Authorization")
3230
if not token:
3331
return self._except()
3432
user_session: UserSession = (
35-
UserSession.get_all(session=db.session).filter(UserSession.token == token).one_or_none()
33+
UserSession.query(session=db.session).filter(UserSession.token == token).one_or_none()
3634
)
3735
if not user_session:
3836
self._except()
39-
return {"id": user_session.user_id, "email": user_session.user.auth_methods.email.value}
37+
return user_session

tests/conftest.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,64 @@
11
import datetime
2-
from unittest.mock import Mock
2+
from unittest.mock import patch
33

44
import pytest
55
from fastapi.testclient import TestClient
66
from sqlalchemy import create_engine
77
from sqlalchemy.orm import sessionmaker
88
from starlette import status
99

10-
import auth_backend.auth_plugins.email
1110
from auth_backend.models import AuthMethod, User
1211
from auth_backend.models.db import Group, UserSession, UserGroup
1312
from auth_backend.routes.base import app
1413
from auth_backend.settings import get_settings
15-
import auth_backend.utils.security
1614

1715

18-
@pytest.fixture(scope="session")
16+
@pytest.fixture
1917
def client():
20-
auth_backend.auth_plugins.email.send_confirmation_email = Mock(return_value=None)
21-
auth_backend.auth_plugins.email.send_change_password_confirmation = Mock(return_value=None)
22-
auth_backend.auth_plugins.email.send_changes_password_notification = Mock(return_value=None)
23-
auth_backend.auth_plugins.email.send_reset_email = Mock(return_value=None)
24-
auth_backend.utils.security.UnionAuth.__call__ = Mock(return_value={"id": 0, "email": ""})
18+
patcher1 = patch("auth_backend.auth_plugins.email.send_confirmation_email")
19+
patcher2 = patch("auth_backend.auth_plugins.email.send_change_password_confirmation")
20+
patcher3 = patch("auth_backend.auth_plugins.email.send_changes_password_notification")
21+
patcher4 = patch("auth_backend.auth_plugins.email.send_reset_email")
22+
patcher5 = patch("auth_backend.utils.security.UnionAuth.__call__")
23+
patcher1.start()
24+
patcher2.start()
25+
patcher3.start()
26+
patcher4.start()
27+
patcher5.start()
28+
patcher1.return_value = None
29+
patcher2.return_value = None
30+
patcher3.return_value = None
31+
patcher4.return_value = None
32+
patcher5.return_value = {"id": 0, "email": None}
2533
client = TestClient(app)
2634
yield client
35+
patcher1.stop()
36+
patcher2.stop()
37+
patcher3.stop()
38+
patcher4.stop()
39+
patcher5.stop()
40+
41+
42+
@pytest.fixture
43+
def client_auth():
44+
patcher1 = patch("auth_backend.auth_plugins.email.send_confirmation_email")
45+
patcher2 = patch("auth_backend.auth_plugins.email.send_change_password_confirmation")
46+
patcher3 = patch("auth_backend.auth_plugins.email.send_changes_password_notification")
47+
patcher4 = patch("auth_backend.auth_plugins.email.send_reset_email")
48+
patcher1.start()
49+
patcher2.start()
50+
patcher3.start()
51+
patcher4.start()
52+
patcher1.return_value = None
53+
patcher2.return_value = None
54+
patcher3.return_value = None
55+
patcher4.return_value = None
56+
client = TestClient(app)
57+
yield client
58+
patcher1.stop()
59+
patcher2.stop()
60+
patcher3.stop()
61+
patcher4.stop()
2762

2863

2964
@pytest.fixture(scope='session')
@@ -35,10 +70,10 @@ def dbsession():
3570

3671

3772
@pytest.fixture()
38-
def user_id(client: TestClient, dbsession):
73+
def user_id(client_auth: TestClient, dbsession):
3974
time = datetime.datetime.utcnow()
4075
body = {"email": f"user{time}@example.com", "password": "string"}
41-
client.post("/email/registration", json=body)
76+
client_auth.post("/email/registration", json=body)
4277
db_user: AuthMethod = (
4378
dbsession.query(AuthMethod).filter(AuthMethod.value == body['email'], AuthMethod.param == 'email').one()
4479
)
@@ -54,15 +89,15 @@ def user_id(client: TestClient, dbsession):
5489

5590

5691
@pytest.fixture()
57-
def user(client: TestClient, dbsession):
92+
def user(client_auth: TestClient, dbsession):
5893
url = "/email/login"
5994
time = datetime.datetime.utcnow()
6095
body = {"email": f"user{time}@example.com", "password": "string"}
61-
client.post("/email/registration", json=body)
96+
client_auth.post("/email/registration", json=body)
6297
db_user: AuthMethod = (
6398
dbsession.query(AuthMethod).filter(AuthMethod.value == body['email'], AuthMethod.param == 'email').one()
6499
)
65-
response = client.post(url, json=body)
100+
response = client_auth.post(url, json=body)
66101
assert response.status_code == status.HTTP_401_UNAUTHORIZED
67102
token = (
68103
dbsession.query(AuthMethod)
@@ -73,9 +108,9 @@ def user(client: TestClient, dbsession):
73108
)
74109
.one()
75110
)
76-
response = client.get(f"/email/approve?token={token.value}")
111+
response = client_auth.get(f"/email/approve?token={token.value}")
77112
assert response.status_code == status.HTTP_200_OK
78-
response = client.post(url, json=body)
113+
response = client_auth.post(url, json=body)
79114
assert response.status_code == status.HTTP_200_OK
80115
yield {"user_id": db_user.user_id, "body": body, "login_json": response.json()}
81116
session = dbsession.query(UserSession).filter(UserSession.user_id == db_user.user_id).all()
@@ -88,7 +123,7 @@ def user(client: TestClient, dbsession):
88123
dbsession.commit()
89124

90125

91-
@pytest.fixture(scope="module")
126+
@pytest.fixture
92127
def parent_id(client, dbsession):
93128
time = datetime.datetime.utcnow()
94129
body = {"name": f"group{time}", "parent_id": None}

0 commit comments

Comments
 (0)