Skip to content

Commit 74f5c13

Browse files
SQLAlchemy 2.0 (#26)
* SQLAlchemy 2.0
1 parent b5fbdb1 commit 74f5c13

9 files changed

Lines changed: 47 additions & 37 deletions

File tree

auth_backend/auth_plugins/email.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
send_confirmation_email,
1717
send_change_password_confirmation,
1818
send_changes_password_notification,
19+
send_reset_email,
1920
)
2021
from .auth_method import AuthMethodMeta, Session
2122
from fastapi.background import BackgroundTasks
@@ -133,7 +134,7 @@ async def _login(user_inp: EmailLogin) -> Session:
133134
):
134135
raise AuthFailed(error="Incorrect login or password")
135136
db.session.add(user_session := UserSession(user_id=query.user.id, token=random_string()))
136-
db.session.flush()
137+
db.session.commit()
137138
return Session(
138139
user_id=user_session.user_id, token=user_session.token, id=user_session.id, expires=user_session.expires
139140
)
@@ -155,13 +156,14 @@ async def _add_to_db(user_inp: EmailRegister, confirmation_token: str, user: Use
155156
)
156157
db.session.flush()
157158

159+
158160
@staticmethod
159161
async def _change_confirmation_link(user: User, confirmation_token: str) -> None:
160162
if user.auth_methods.confirmed.value == "true":
161163
raise AlreadyExists(User, user.id)
162164
else:
163165
user.auth_methods.confirmation_token.value = confirmation_token
164-
db.session.flush()
166+
165167

166168
@staticmethod
167169
async def _get_user_by_token_and_id(id: int, token: str) -> User:
@@ -198,6 +200,7 @@ async def _register(
198200
to_addr=user_inp.email,
199201
link=f"{settings.APPLICATION_HOST}/email/approve?token={confirmation_token}",
200202
)
203+
db.session.commit()
201204
return ResponseModel(status="Success", message="Email confirmation link sent")
202205
if user_inp.user_id and token:
203206
user = await Email._get_user_by_token_and_id(user_inp.user_id, token)
@@ -211,6 +214,7 @@ async def _register(
211214
to_addr=user_inp.email,
212215
link=f"{settings.APPLICATION_HOST}/email/approve?token={confirmation_token}",
213216
)
217+
db.session.commit()
214218
raise HTTPException(
215219
status_code=201, detail=ResponseModel(status="Success", message="Email confirmation link sent").json()
216220
)
@@ -238,8 +242,8 @@ async def _approve_email(token: str) -> ResponseModel:
238242
)
239243
if not auth_method:
240244
raise HTTPException(status_code=403, detail=ResponseModel(status="Error", message="Incorrect link").json())
241-
auth_method.user.auth_methods.confirmed.value = True
242-
db.session.flush()
245+
auth_method.user.auth_methods.confirmed.value = "true"
246+
db.session.commit()
243247
return ResponseModel(status="Success", message="Email approved")
244248

245249
@staticmethod
@@ -269,12 +273,12 @@ async def _request_reset_email(
269273
user_id=session.user_id, auth_method=Email.get_name(), param="tmp_email_confirmation_token", value=token
270274
)
271275
db.session.add_all([tmp_email, tmp_email_confirmation_token])
272-
db.session.flush()
273276
background_tasks.add_task(
274-
send_confirmation_email,
277+
send_reset_email,
275278
to_addr=scheme.email,
276279
link=f"{settings.APPLICATION_HOST}/email/reset/email/{session.user_id}?token={token}&email={scheme.email}",
277280
)
281+
db.session.commit()
278282
return ResponseModel(status="Success", message="Email confirmation link sent")
279283

280284
@staticmethod
@@ -295,7 +299,7 @@ async def _reset_email(user_id: int, token: str, email: str):
295299
user.auth_methods.email.value = user.auth_methods.tmp_email.value
296300
db.session.delete(user.auth_methods.tmp_email_confirmation_token)
297301
db.session.delete(user.auth_methods.tmp_email)
298-
db.session.flush()
302+
db.session.commit()
299303
return ResponseModel(status="Success", message="Email successfully changed")
300304

301305
@staticmethod
@@ -326,8 +330,8 @@ async def _request_reset_password(
326330
)
327331
session.user.auth_methods.hashed_password.value = Email._hash_password(schema.new_password, salt)
328332
session.user.auth_methods.salt.value = salt
329-
db.session.flush()
330333
background_tasks.add_task(send_changes_password_notification, session.user.auth_methods.email.value)
334+
db.session.commit()
331335
return ResponseModel(status="Success", message="Password has been successfully changed")
332336
elif not token and not schema.password and not schema.new_password:
333337
user: User = db.session.query(User).get(user_id)
@@ -345,7 +349,7 @@ async def _request_reset_password(
345349
db.session.add(
346350
AuthMethod(user_id=user_id, auth_method=Email.get_name(), param="reset_token", value=random_string())
347351
)
348-
db.session.flush()
352+
db.session.commit()
349353
background_tasks.add_task(
350354
send_change_password_confirmation,
351355
user.auth_methods.email.value,
@@ -372,5 +376,5 @@ async def _reset_password(
372376
user.auth_methods.hashed_password.value = Email._hash_password(schema.new_password, salt)
373377
user.auth_methods.salt.value = salt
374378
db.session.delete(user.auth_methods.reset_token)
375-
db.session.flush()
379+
db.session.commit()
376380
return ResponseModel(status="Success", message="Password has been successfully changed")

auth_backend/models/db.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import datetime
44

55
import sqlalchemy.orm
6+
from sqlalchemy.orm import Mapped, mapped_column, relationship
7+
from sqlalchemy import String, Integer, ForeignKey, DateTime
68
from sqlalchemy.ext.hybrid import hybrid_property
79

810
from auth_backend.models.base import Base
@@ -32,10 +34,10 @@ def __new__(cls, methods: list[AuthMethod], *args, **kwargs):
3234

3335
class User(Base):
3436

35-
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
37+
id: Mapped[int] = mapped_column(Integer, primary_key=True)
3638

37-
_auth_methods: list[AuthMethod] = sqlalchemy.orm.relationship("AuthMethod", foreign_keys="AuthMethod.user_id")
38-
sessions: list[UserSession] = sqlalchemy.orm.relationship("UserSession", foreign_keys="UserSession.user_id")
39+
_auth_methods: Mapped[list["AuthMethod"]] = relationship("AuthMethod", foreign_keys="AuthMethod.user_id")
40+
sessions: Mapped[list["UserSession"]] = relationship("UserSession", foreign_keys="UserSession.user_id")
3941

4042
@hybrid_property
4143
def auth_methods(self) -> ParamDict:
@@ -48,22 +50,22 @@ def auth_methods(self) -> ParamDict:
4850

4951

5052
class AuthMethod(Base):
51-
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
52-
user_id = sqlalchemy.Column(sqlalchemy.Integer, sqlalchemy.ForeignKey("user.id"))
53-
auth_method = sqlalchemy.Column(sqlalchemy.String)
54-
param = sqlalchemy.Column(sqlalchemy.String)
55-
value = sqlalchemy.Column(sqlalchemy.String)
53+
id: Mapped[int] = mapped_column(Integer, primary_key=True)
54+
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
55+
auth_method: Mapped[str] = mapped_column(String)
56+
param: Mapped[str] = mapped_column(String)
57+
value: Mapped[str] = mapped_column(String)
5658

57-
user: User = sqlalchemy.orm.relationship("User", foreign_keys=[user_id], back_populates="_auth_methods")
59+
user: Mapped["User"] = relationship("User", foreign_keys=[user_id], back_populates="_auth_methods")
5860

5961

6062
class UserSession(Base):
61-
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
62-
user_id = sqlalchemy.Column(sqlalchemy.Integer, sqlalchemy.ForeignKey("user.id"))
63-
expires = sqlalchemy.Column(sqlalchemy.DateTime, default=datetime.datetime.utcnow() + datetime.timedelta(days=7))
64-
token = sqlalchemy.Column(sqlalchemy.String, unique=True)
63+
id: Mapped[int] = mapped_column(Integer, primary_key=True)
64+
user_id: Mapped[int] = mapped_column(Integer, sqlalchemy.ForeignKey("user.id"))
65+
expires: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow() + datetime.timedelta(days=7))
66+
token: Mapped[str] = mapped_column(String, unique=True)
6567

66-
user: User = sqlalchemy.orm.relationship("User", foreign_keys=[user_id], back_populates="sessions")
68+
user: Mapped["User"] = relationship("User", foreign_keys=[user_id], back_populates="sessions")
6769

6870
@hybrid_property
6971
def expired(self):

auth_backend/routes/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
app.add_middleware(
15-
DBSessionMiddleware, db_url=settings.DB_DSN, session_args={"autocommit": True}, engine_args={"pool_pre_ping": True}
15+
DBSessionMiddleware, db_url=settings.DB_DSN, engine_args={"pool_pre_ping": True}
1616
)
1717

1818
app.add_middleware(

auth_backend/routes/user_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ async def logout(token: str = Header(min_length=1)) -> JSONResponse:
2121
if session.expired:
2222
raise SessionExpired(session.token)
2323
session.expires = datetime.utcnow()
24-
db.session.flush()
24+
db.session.commit()
2525
return JSONResponse(status_code=200, content=ResponseModel(status="Success", message="Logout successful").json())
2626

2727

tests/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def client():
1818
auth_backend.auth_plugins.email.send_confirmation_email = Mock(return_value=None)
1919
auth_backend.auth_plugins.email.send_change_password_confirmation = Mock(return_value=None)
2020
auth_backend.auth_plugins.email.send_changes_password_notification = Mock(return_value=None)
21+
auth_backend.auth_plugins.email.send_reset_email = Mock(return_value=None)
2122
client = TestClient(app)
2223
yield client
2324

@@ -26,7 +27,7 @@ def client():
2627
def dbsession():
2728
settings = get_settings()
2829
engine = create_engine(settings.DB_DSN)
29-
TestingSessionLocal = sessionmaker(autocommit=True, autoflush=False, bind=engine)
30+
TestingSessionLocal = sessionmaker(bind=engine)
3031
return TestingSessionLocal()
3132

3233

@@ -42,7 +43,7 @@ def user_id(client: TestClient, dbsession):
4243
for row in dbsession.query(AuthMethod).filter(AuthMethod.user_id == db_user.user_id).all():
4344
dbsession.delete(row)
4445
dbsession.delete(dbsession.query(User).filter(User.id == db_user.user_id).one())
45-
dbsession.flush()
46+
dbsession.commit()
4647

4748

4849
@pytest.fixture()
@@ -73,4 +74,4 @@ def user(client: TestClient, dbsession):
7374
for row in dbsession.query(AuthMethod).filter(AuthMethod.user_id == db_user.user_id).all():
7475
dbsession.delete(row)
7576
dbsession.delete(dbsession.query(User).filter(User.id == db_user.user_id).one())
76-
dbsession.flush()
77+
dbsession.commit()

tests/test_routes/test_change_email.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import datetime
2+
13
from fastapi.testclient import TestClient
24
from sqlalchemy.orm import Session
35
from starlette import status
@@ -15,7 +17,8 @@ def test_main_scenario(client: TestClient, dbsession: Session, user):
1517
.one()
1618
.value
1719
)
18-
response = client.post(f"{url}request", json={"email": "changed@mail.com"}, headers={"token": login["token"]})
20+
tmp_email = f"changed{datetime.datetime.utcnow()}@mail.com"
21+
response = client.post(f"{url}request", json={"email": tmp_email}, headers={"token": login["token"]})
1922
assert response.status_code == status.HTTP_200_OK
2023

2124
conf_token_2 = (
@@ -38,7 +41,7 @@ def test_main_scenario(client: TestClient, dbsession: Session, user):
3841
response = client.post(f"/email/login", json=body)
3942
assert response.status_code == status.HTTP_200_OK
4043

41-
response = client.post(f"/email/login", json={"email": "changed@mail.com", "password": body["password"]})
44+
response = client.post(f"/email/login", json={"email": tmp_email, "password": body["password"]})
4245
assert response.status_code == status.HTTP_401_UNAUTHORIZED
4346

4447
response = client.get(f"{url}{user_id}?token={conf_token_1}&email=changed@mail.com")
@@ -47,13 +50,13 @@ def test_main_scenario(client: TestClient, dbsession: Session, user):
4750
response = client.get(f"{url}{user_id}?token={tmp_token}&email=wrong@mail.com")
4851
assert response.status_code == status.HTTP_403_FORBIDDEN
4952

50-
response = client.get(f"{url}{user_id}?token={tmp_token}&email=changed@mail.com")
53+
response = client.get(f"{url}{user_id}?token={tmp_token}&email={tmp_email}")
5154
assert response.status_code == status.HTTP_200_OK
5255

5356
response = client.post(f"/email/login", json=body)
5457
assert response.status_code == status.HTTP_401_UNAUTHORIZED
5558

56-
response = client.post(f"/email/login", json={"email": "changed@mail.com", "password": body["password"]})
59+
response = client.post(f"/email/login", json={"email": tmp_email, "password": body["password"]})
5760
assert response.status_code == status.HTTP_200_OK
5861

5962

tests/test_routes/test_login.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def test_incorrect_data(client: TestClient, dbsession: Session):
6767
for row in dbsession.query(AuthMethod).filter(AuthMethod.user_id == id).all():
6868
dbsession.delete(row)
6969
dbsession.delete(dbsession.query(User).filter(User.id == id).one())
70-
dbsession.flush()
70+
dbsession.commit()
7171

7272

7373
def test_check_token(client: TestClient, user, dbsession: Session):

tests/test_routes/test_logout.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_main_scenario(client: TestClient, dbsession: Session):
4141
dbsession.delete(row)
4242
dbsession.delete(dbsession.query(UserSession).filter(UserSession.user_id == id).one())
4343
dbsession.delete(dbsession.query(User).filter(User.id == id).one())
44-
dbsession.flush()
44+
dbsession.commit()
4545

4646

4747
def test_without_token(client: TestClient, dbsession: Session):

tests/test_routes/test_registration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_main_scenario(client: TestClient, dbsession: Session):
6363
for row in dbsession.query(AuthMethod).filter(AuthMethod.user_id == db_user.user_id).all():
6464
dbsession.delete(row)
6565
dbsession.delete(dbsession.query(User).filter(User.id == db_user.user_id).one())
66-
dbsession.flush()
66+
dbsession.commit()
6767

6868

6969
def test_repeated_registration_case(client: TestClient, dbsession: Session):
@@ -105,4 +105,4 @@ def test_repeated_registration_case(client: TestClient, dbsession: Session):
105105
for row in dbsession.query(AuthMethod).filter(AuthMethod.user_id == user_id).all():
106106
dbsession.delete(row)
107107
dbsession.delete(dbsession.query(User).filter(User.id == user_id).one())
108-
dbsession.flush()
108+
dbsession.commit()

0 commit comments

Comments
 (0)