Skip to content

Commit 26edd95

Browse files
Groups (#29)
1 parent 74f5c13 commit 26edd95

17 files changed

Lines changed: 817 additions & 38 deletions

File tree

auth_backend/auth_plugins/auth_method.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ def __init_subclass__(cls, **kwargs):
4040

4141
@staticmethod
4242
@abstractmethod
43-
async def _register(**kwargs) -> object:
43+
async def _register(*args, **kwargs) -> object:
4444
raise NotImplementedError()
4545

4646
@staticmethod
4747
@abstractmethod
48-
async def _login(**kwargs) -> Session:
48+
async def _login(*args, **kwargs) -> Session:
4949
raise NotImplementedError()

auth_backend/auth_plugins/email.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,15 +156,13 @@ async def _add_to_db(user_inp: EmailRegister, confirmation_token: str, user: Use
156156
)
157157
db.session.flush()
158158

159-
160159
@staticmethod
161160
async def _change_confirmation_link(user: User, confirmation_token: str) -> None:
162161
if user.auth_methods.confirmed.value == "true":
163162
raise AlreadyExists(User, user.id)
164163
else:
165164
user.auth_methods.confirmation_token.value = confirmation_token
166165

167-
168166
@staticmethod
169167
async def _get_user_by_token_and_id(id: int, token: str) -> User:
170168
user: User = db.session.query(User).get(id)

auth_backend/models/base.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,76 @@
1+
from __future__ import annotations
12
import re
3+
4+
from sqlalchemy import not_, Integer
5+
from sqlalchemy.exc import NoResultFound
26
from sqlalchemy.ext.declarative import as_declarative, declared_attr
7+
from sqlalchemy.orm import Session, Mapped, mapped_column, Query
8+
9+
from auth_backend.exceptions import ObjectNotFound
310

411

512
@as_declarative()
613
class Base:
714
"""Base class for all database entities"""
815

9-
@classmethod
1016
@declared_attr
1117
def __tablename__(cls) -> str: # pylint: disable=no-self-argument
1218
"""Generate database table name automatically.
1319
Convert CamelCase class name to snake_case db table name.
1420
"""
1521
return re.sub(r"(?<!^)(?=[A-Z])", "_", cls.__name__).lower()
1622

17-
def __repr__(self) -> str:
23+
def __repr__(self):
1824
attrs = []
1925
for c in self.__table__.columns:
2026
attrs.append(f"{c.name}={getattr(self, c.name)}")
21-
return "{}({})".format(self.__class__.__name__, ', '.join(attrs))
27+
return "{}({})".format(c.__class__.__name__, ', '.join(attrs))
28+
29+
30+
class BaseDbModel(Base):
31+
__abstract__ = True
32+
id: Mapped[int] = mapped_column(Integer, primary_key=True)
33+
34+
@classmethod
35+
def create(cls, *, session: Session, **kwargs) -> BaseDbModel:
36+
obj = cls(**kwargs)
37+
session.add(obj)
38+
session.flush()
39+
return obj
40+
41+
@classmethod
42+
def get_all(cls, *, with_deleted: bool = False, session: Session) -> Query:
43+
"""Get all objects with soft deletes"""
44+
objs = session.query(cls)
45+
if not with_deleted and hasattr(cls, "is_deleted"):
46+
objs = objs.filter(not_(cls.is_deleted))
47+
return objs
48+
49+
@classmethod
50+
def get(cls, id: int, *, with_deleted=False, session: Session) -> BaseDbModel:
51+
"""Get object with soft deletes"""
52+
objs = session.query(cls)
53+
if not with_deleted and hasattr(cls, "is_deleted"):
54+
objs = objs.filter(not_(cls.is_deleted))
55+
try:
56+
return objs.filter(cls.id == id).one()
57+
except NoResultFound:
58+
raise ObjectNotFound(cls, id)
59+
60+
@classmethod
61+
def update(cls, id: int, *, session: Session, **kwargs) -> BaseDbModel:
62+
obj = cls.get(id, session=session)
63+
for k, v in kwargs.items():
64+
setattr(obj, k, v)
65+
session.flush()
66+
return obj
67+
68+
@classmethod
69+
def delete(cls, id: int, *, session: Session) -> None:
70+
"""Soft delete object if possible, else hard delete"""
71+
obj = cls.get(id, session=session)
72+
if hasattr(obj, "is_deleted"):
73+
obj.is_deleted = True
74+
else:
75+
session.delete(obj)
76+
session.flush()

auth_backend/models/db.py

Lines changed: 76 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
from __future__ import annotations
22

33
import datetime
4+
from typing import Iterator
45

56
import sqlalchemy.orm
6-
from sqlalchemy.orm import Mapped, mapped_column, relationship
7-
from sqlalchemy import String, Integer, ForeignKey, DateTime
7+
from sqlalchemy import String, Integer, ForeignKey, DateTime, Boolean
88
from sqlalchemy.ext.hybrid import hybrid_property
9+
from sqlalchemy.orm import Mapped, mapped_column, relationship, backref
910

10-
from auth_backend.models.base import Base
11+
from auth_backend.models.base import BaseDbModel
1112

1213

1314
class ParamDict:
14-
1515
# Type hints
1616
email: AuthMethod
1717
hashed_password: AuthMethod
@@ -32,12 +32,28 @@ def __new__(cls, methods: list[AuthMethod], *args, **kwargs):
3232
return obj
3333

3434

35-
class User(Base):
36-
37-
id: Mapped[int] = mapped_column(Integer, primary_key=True)
35+
class User(BaseDbModel):
36+
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
37+
38+
_auth_methods: Mapped[list[AuthMethod]] = relationship(
39+
"AuthMethod",
40+
foreign_keys="AuthMethod.user_id",
41+
primaryjoin="and_(User.id==AuthMethod.user_id, not_(AuthMethod.is_deleted))",
42+
)
43+
sessions: Mapped[list[UserSession]] = relationship(
44+
"UserSession", foreign_keys="UserSession.user_id", back_populates="user"
45+
)
46+
groups: Mapped[list[Group]] = relationship(
47+
"Group",
48+
secondary="user_group",
49+
back_populates="users",
50+
primaryjoin="and_(User.id==UserGroup.user_id, not_(UserGroup.is_deleted))",
51+
secondaryjoin="and_(Group.id==UserGroup.group_id, not_(Group.is_deleted))",
52+
)
3853

39-
_auth_methods: Mapped[list["AuthMethod"]] = relationship("AuthMethod", foreign_keys="AuthMethod.user_id")
40-
sessions: Mapped[list["UserSession"]] = relationship("UserSession", foreign_keys="UserSession.user_id")
54+
@hybrid_property
55+
def active_sessions(self) -> list:
56+
return [row for row in self.sessions if not row.expired]
4157

4258
@hybrid_property
4359
def auth_methods(self) -> ParamDict:
@@ -49,24 +65,68 @@ def auth_methods(self) -> ParamDict:
4965
return ParamDict.__new__(ParamDict, self._auth_methods)
5066

5167

52-
class AuthMethod(Base):
68+
class Group(BaseDbModel):
5369
id: Mapped[int] = mapped_column(Integer, primary_key=True)
70+
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
71+
parent_id: Mapped[int] = mapped_column(Integer, ForeignKey("group.id"), nullable=True)
72+
create_ts: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow)
73+
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
74+
75+
child: Mapped[list[Group]] = relationship(
76+
"Group",
77+
backref=backref("parent", remote_side=[id]),
78+
primaryjoin="and_(Group.id==Group.parent_id, not_(Group.is_deleted))",
79+
)
80+
users: Mapped[list[User]] = relationship(
81+
"User",
82+
secondary="user_group",
83+
back_populates="groups",
84+
primaryjoin="and_(Group.id==UserGroup.group_id, not_(UserGroup.is_deleted))",
85+
secondaryjoin="and_(User.id==UserGroup.user_id, not_(User.is_deleted))",
86+
)
87+
88+
@hybrid_property
89+
def parents(self) -> Iterator[Group]:
90+
parent = self
91+
while parent := parent.parent:
92+
yield parent
93+
94+
95+
class UserGroup(BaseDbModel):
96+
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
97+
group_id: Mapped[int] = mapped_column(Integer, ForeignKey("group.id"))
98+
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
99+
100+
101+
class AuthMethod(BaseDbModel):
54102
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
55103
auth_method: Mapped[str] = mapped_column(String)
56104
param: Mapped[str] = mapped_column(String)
57105
value: Mapped[str] = mapped_column(String)
106+
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
58107

59-
user: Mapped["User"] = relationship("User", foreign_keys=[user_id], back_populates="_auth_methods")
108+
user: Mapped[User] = relationship(
109+
"User",
110+
foreign_keys=[user_id],
111+
back_populates="_auth_methods",
112+
primaryjoin="and_(AuthMethod.user_id==User.id, not_(User.is_deleted))",
113+
)
60114

61115

62-
class UserSession(Base):
63-
id: Mapped[int] = mapped_column(Integer, primary_key=True)
116+
class UserSession(BaseDbModel):
64117
user_id: Mapped[int] = mapped_column(Integer, sqlalchemy.ForeignKey("user.id"))
65-
expires: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow() + datetime.timedelta(days=7))
118+
expires: Mapped[datetime.datetime] = mapped_column(
119+
DateTime, default=datetime.datetime.utcnow() + datetime.timedelta(days=7)
120+
)
66121
token: Mapped[str] = mapped_column(String, unique=True)
67122

68-
user: Mapped["User"] = relationship("User", foreign_keys=[user_id], back_populates="sessions")
123+
user: Mapped[User] = relationship(
124+
"User",
125+
foreign_keys=[user_id],
126+
back_populates="sessions",
127+
primaryjoin="and_(UserSession.user_id==User.id, not_(User.is_deleted))",
128+
)
69129

70130
@hybrid_property
71-
def expired(self):
131+
def expired(self) -> bool:
72132
return self.expires <= datetime.datetime.utcnow()

auth_backend/routes/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
from auth_backend.auth_plugins.auth_method import AUTH_METHODS
66
from auth_backend.settings import get_settings
77
from .user_session import logout_router
8+
from .user_groups import user_groups
9+
from .groups import groups
810

911
settings = get_settings()
1012

1113
app = FastAPI()
1214

1315

14-
app.add_middleware(
15-
DBSessionMiddleware, db_url=settings.DB_DSN, engine_args={"pool_pre_ping": True}
16-
)
16+
app.add_middleware(DBSessionMiddleware, db_url=settings.DB_DSN, engine_args={"pool_pre_ping": True})
1717

1818
app.add_middleware(
1919
CORSMiddleware,
@@ -24,6 +24,8 @@
2424
)
2525

2626
app.include_router(logout_router)
27+
app.include_router(user_groups)
28+
app.include_router(groups)
2729
if not settings.ENABLED_AUTH_METHODS:
2830
for method in AUTH_METHODS.values():
2931
app.include_router(router := method().router, prefix=router.prefix, tags=[method.get_name()])

auth_backend/routes/groups.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from typing import Literal
2+
3+
from fastapi import APIRouter, HTTPException, Depends, Query
4+
from fastapi_sqlalchemy import db
5+
6+
from auth_backend.exceptions import ObjectNotFound, AlreadyExists
7+
from auth_backend.models.db import Group as DbGroup
8+
from .models.models import Group, GroupPost, GroupsGet, GroupPatch, GroupGet
9+
from ..base import ResponseModel
10+
from ..utils.security import UnionAuth
11+
12+
auth = UnionAuth()
13+
14+
groups = APIRouter(prefix="/group", tags=["Groups"])
15+
16+
17+
@groups.get("/{id}", response_model=GroupGet, response_model_exclude_unset=True)
18+
async def get_group(id: int, info: list[Literal["child"]] = Query(default=[])) -> dict[str, str | int]:
19+
group = DbGroup.get(id, session=db.session)
20+
result = {}
21+
result = result | Group.from_orm(group).dict()
22+
if "child" in info:
23+
result = result | {"child": group.child}
24+
return GroupGet(**result).dict(exclude_unset=True)
25+
26+
27+
@groups.post("", response_model=Group)
28+
async def create_group(group_inp: GroupPost, _: dict[str, str] = Depends(auth)) -> Group:
29+
if group_inp.parent_id and not db.session.query(DbGroup).get(group_inp.parent_id):
30+
raise ObjectNotFound(Group, group_inp.parent_id)
31+
if DbGroup.get_all(session=db.session).filter(DbGroup.name == group_inp.name).one_or_none():
32+
raise HTTPException(status_code=409, detail=ResponseModel(status="Error", message="Name already exists").json())
33+
group = DbGroup.create(session=db.session, **group_inp.dict())
34+
db.session.commit()
35+
return Group.from_orm(group)
36+
37+
38+
@groups.patch("/{id}", response_model=Group)
39+
async def patch_group(id: int, group_inp: GroupPatch, _: dict[str, str] = Depends(auth)) -> Group:
40+
if (
41+
exists_check := DbGroup.get_all(session=db.session)
42+
.filter(DbGroup.name == group_inp.name, DbGroup.id != id)
43+
.one_or_none()
44+
):
45+
raise AlreadyExists(Group, exists_check.id)
46+
group = DbGroup.get(id, session=db.session)
47+
if group_inp.parent_id in (row.id for row in group.child):
48+
raise HTTPException(status_code=400, detail=ResponseModel(status="Error", message="Cycle detected").json())
49+
patched = DbGroup.update(id, session=db.session, **group_inp.dict(exclude_unset=True))
50+
db.session.commit()
51+
return Group.from_orm(patched)
52+
53+
54+
@groups.delete("/{id}", response_model=None)
55+
async def delete_group(id: int, _: dict[str, str] = Depends(auth)) -> None:
56+
group: DbGroup = DbGroup.get(id, session=db.session)
57+
if child := group.child:
58+
for children in child:
59+
children.parent = group.parent
60+
db.session.flush()
61+
DbGroup.delete(id, session=db.session)
62+
db.session.commit()
63+
return None
64+
65+
66+
@groups.get("", response_model=GroupsGet)
67+
async def get_groups() -> GroupsGet:
68+
return GroupsGet(items=DbGroup.get_all(session=db.session).all())

0 commit comments

Comments
 (0)