|
1 | 1 | from sqlalchemy import select |
2 | 2 |
|
3 | | -from ..models import CourseAttribute |
| 3 | +from ..models import CourseAttribute, Courses |
4 | 4 | from ..async_session import async_session |
| 5 | +from rsptx.logging import rslogger |
5 | 6 |
|
6 | 7 |
|
7 | 8 | async def fetch_all_course_attributes(course_id: int) -> dict: |
@@ -48,27 +49,31 @@ async def copy_course_attributes(basecourse_id: int, new_course_id: int): |
48 | 49 | async with async_session() as session: |
49 | 50 | res = await session.execute(query) |
50 | 51 | for row in res.scalars().fetchall(): |
51 | | - print(row.attr, row.value) |
| 52 | + rslogger.debug(f"copy_course_attributes: {row.attr}={row.value}") |
52 | 53 | new_attr = CourseAttribute( |
53 | 54 | course_id=new_course_id, attr=row.attr, value=row.value |
54 | 55 | ) |
55 | 56 | session.add(new_attr) |
56 | 57 | await session.commit() |
57 | 58 |
|
58 | 59 |
|
59 | | -async def get_course_origin(base_course): |
| 60 | +async def get_course_origin(base_course: str): |
60 | 61 | """ |
61 | | - Retrieve the origin of a given course (base_course) |
| 62 | + Retrieve the origin (markup system) of a given course by its name. |
62 | 63 |
|
63 | | - :param base_course: str, the name of the base course |
64 | | - :return: str, the origin of the course |
| 64 | + :param base_course: str, the name of the base course (i.e. ``courses.course_name``) |
| 65 | + :return: str, the value of the ``markup_system`` course attribute, or None if not found |
65 | 66 | """ |
66 | | - query = select(CourseAttribute).where( |
67 | | - (CourseAttribute.course_id == base_course) |
68 | | - & (CourseAttribute.attr == "markup_system") |
| 67 | + query = ( |
| 68 | + select(CourseAttribute) |
| 69 | + .join(Courses, Courses.id == CourseAttribute.course_id) |
| 70 | + .where( |
| 71 | + (Courses.course_name == base_course) |
| 72 | + & (CourseAttribute.attr == "markup_system") |
| 73 | + ) |
69 | 74 | ) |
70 | 75 |
|
71 | 76 | async with async_session() as session: |
72 | 77 | res = await session.execute(query) |
73 | 78 | ca = res.scalars().first() |
74 | | - return ca.value |
| 79 | + return ca.value if ca else None |
0 commit comments