-
-
Notifications
You must be signed in to change notification settings - Fork 22
Expand file tree
/
Copy pathhooks.py
More file actions
481 lines (387 loc) · 17.3 KB
/
hooks.py
File metadata and controls
481 lines (387 loc) · 17.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
from __future__ import annotations
import asyncio
import logging
from collections import defaultdict
from collections.abc import Awaitable
from typing import (
TYPE_CHECKING,
Any,
Callable,
Union,
cast,
)
from uuid import uuid4
import orjson
from channels import DEFAULT_CHANNEL_LAYER
from channels import auth as channels_auth
from channels.layers import InMemoryChannelLayer, get_channel_layer
from reactpy import use_async_effect, use_callback, use_effect, use_memo, use_ref, use_state
from reactpy import use_connection as _use_connection
from reactpy import use_location as _use_location
from reactpy import use_scope as _use_scope
from reactpy_django.exceptions import UserNotFoundError
from reactpy_django.types import (
AsyncMessageReceiver,
AsyncMessageSender,
AsyncPostprocessor,
ConnectionType,
FuncParams,
Inferred,
Mutation,
Query,
SyncPostprocessor,
UseAuthTuple,
UserData,
)
from reactpy_django.utils import django_query_postprocessor, ensure_async, generate_obj_name, get_pk
if TYPE_CHECKING:
from collections.abc import Sequence
from channels_redis.core import RedisChannelLayer
from django.contrib.auth.models import AbstractUser
from reactpy.types import Location
# Module-level logger used by the hooks below to report failures.
_logger = logging.getLogger(__name__)
# Global registry mapping a query function to the `refetch` callbacks registered by
# `use_query`. `use_mutation` looks a query function up here to re-run it after a mutation.
_REFETCH_CALLBACKS: defaultdict[Callable[..., Any], set[Callable[[], None]]] = defaultdict(set)
def use_location() -> Location:
    """Return the current route as a `Location` object."""
    current_location = _use_location()
    return current_location
def use_origin() -> str | None:
    """Get the current origin as a string.

    For `websocket` scopes this is the decoded value of the `origin` header. For `http`
    scopes it is built as `<scheme>://<host>` from the scope's `scheme` and `host` header.

    Returns:
        The origin string, or `None` if the relevant header is missing (or, for `http`, \
        empty), the scope type is neither `websocket` nor `http`, or reading the scope failed.
    """
    scope = _use_scope()

    def header_value(name: bytes) -> str | None:
        # Headers are iterated as pairs of bytes; decode the value of the first match.
        # A default of `None` keeps a missing header from raising `StopIteration`.
        return next(
            (header[1].decode("utf-8") for header in scope["headers"] if header[0] == name),
            None,
        )

    try:
        if scope["type"] == "websocket":
            return header_value(b"origin")
        if scope["type"] == "http":
            host = header_value(b"host")
            return f"{scope['scheme']}://{host}" if host else None
    except Exception:
        # Best effort: a malformed scope should never break rendering.
        _logger.info("Failed to get origin")
    return None
def use_scope() -> dict[str, Any]:
    """Return the current ASGI scope dictionary.

    Raises:
        TypeError: If the underlying scope is not a `dict`.
    """
    current_scope = _use_scope()
    if not isinstance(current_scope, dict):
        msg = f"Expected scope to be a dict, got {type(current_scope)}"
        raise TypeError(msg)
    return current_scope
def use_connection() -> ConnectionType:
    """Return the current `Connection` object."""
    current_connection = _use_connection()
    return current_connection
def use_query(
    query: Callable[FuncParams, Awaitable[Inferred]] | Callable[FuncParams, Inferred],
    kwargs: dict[str, Any] | None = None,
    *,
    thread_sensitive: bool = True,
    postprocessor: (AsyncPostprocessor | SyncPostprocessor | None) = django_query_postprocessor,
    postprocessor_kwargs: dict[str, Any] | None = None,
) -> Query[Inferred]:
    """This hook is used to execute functions in the background and return the result, \
        typically to read data from the Django ORM.

    Args:
        query: A function that executes a query and returns some data. The same function \
            object must be passed on every render.

    Kwargs:
        kwargs: Keyword arguments passed into the `query` function.
        thread_sensitive: Whether to run the query in thread sensitive mode. \
            This setting only applies to sync functions, and is turned on by default \
            due to Django ORM limitations.
        postprocessor: A callable that processes the query `data` before it is returned. \
            The first argument of postprocessor function must be the query `data`. All \
            subsequent arguments are optional `postprocessor_kwargs`. This postprocessor \
            function must return the modified `data`. \
            \
            Defaults to `django_query_postprocessor`. By default, this \
            is used to prevent Django's lazy query execution and supports `many_to_many` \
            and `many_to_one` as `postprocessor_kwargs`. Pass `None` to skip postprocessing.
        postprocessor_kwargs: Keyword arguments passed into the `postprocessor` function.

    Returns:
        An object containing `loading`/`#!python error` states, your `data` (if the query \
        has successfully executed), and a `refetch` callable that can be used to re-run the query.

    Raises:
        ValueError: If `query` is a different function object than the one given on the \
            first render.
    """
    # `should_execute` starts as True so the query runs after the first render.
    should_execute, set_should_execute = use_state(True)
    data, set_data = use_state(cast(Inferred, None))
    loading, set_loading = use_state(True)
    error, set_error = use_state(cast(Union[Exception, None], None))
    # Remember the initial query function so a later change can be detected.
    query_ref = use_ref(query)
    async_task_refs = use_ref(set())
    kwargs = kwargs or {}
    postprocessor_kwargs = postprocessor_kwargs or {}

    # The refetch callback is registered under the original `query` key, so swapping the
    # function between renders is rejected.
    if query_ref.current is not query:
        msg = f"Query function changed from {query_ref.current} to {query}."
        raise ValueError(msg)

    async def execute_query() -> None:
        """The main running function for `use_query`"""
        try:
            # Run the query
            query_async = cast(
                Callable[..., Awaitable[Inferred]], ensure_async(query, thread_sensitive=thread_sensitive)
            )
            new_data = await query_async(**kwargs)

            # Run the postprocessor
            if postprocessor:
                async_postprocessor = cast(
                    Callable[..., Awaitable[Any]], ensure_async(postprocessor, thread_sensitive=thread_sensitive)
                )
                new_data = await async_postprocessor(new_data, **postprocessor_kwargs)

        # Log any errors and set the error state
        except Exception as e:
            set_data(cast(Inferred, None))
            set_loading(False)
            set_error(e)
            _logger.exception("Failed to execute query: %s", generate_obj_name(query))
            return

        # Query was successful
        else:
            set_data(new_data)
            set_loading(False)
            set_error(None)

    # NOTE(review): `dependencies=None` presumably makes this effect run after every
    # render; the `should_execute` flag is what prevents repeated execution - confirm.
    @use_effect(dependencies=None)
    def schedule_query() -> None:
        """Schedule the query to be run"""
        # Make sure we don't re-execute the query unless we're told to
        if not should_execute:
            return
        set_should_execute(False)

        # Execute the query in the background
        task = asyncio.create_task(execute_query())

        # Add the task to a set to prevent it from being garbage collected
        async_task_refs.current.add(task)
        task.add_done_callback(async_task_refs.current.remove)

    @use_callback
    def refetch() -> None:
        """Callable provided to the user, used to re-execute the query"""
        # Flip the flag read by `schedule_query` and reset the visible state to "loading".
        set_should_execute(True)
        set_loading(True)
        set_error(None)

    @use_effect(dependencies=[])
    def register_refetch_callback() -> Callable[[], None]:
        """Track the refetch callback so we can re-execute the query"""
        # By tracking callbacks globally, any usage of the query function will be re-run
        # if the user has told a mutation to refetch it.
        _REFETCH_CALLBACKS[query].add(refetch)
        # The returned callable is the effect's cleanup; it unregisters this callback.
        return lambda: _REFETCH_CALLBACKS[query].remove(refetch)

    # Return Query user API
    return Query(data, loading, error, refetch)
def use_mutation(
    mutation: (Callable[FuncParams, bool | None] | Callable[FuncParams, Awaitable[bool | None]]),
    *,
    thread_sensitive: bool = True,
    refetch: Callable[..., Any] | Sequence[Callable[..., Any]] | None = None,
) -> Mutation[FuncParams]:
    """Run a data-modifying function in the background, typically to create/update/delete \
        data through the Django ORM.

    A mutation function may `return False` to skip the `refetch` step. Every other return \
    value is ignored. The mutation function can be sync or async.

    Args:
        mutation: A callable that performs the create, update, or delete work. If it \
            returns `False`, the `refetch` functions are not re-run.

    Kwargs:
        thread_sensitive: Whether to run the mutation in thread sensitive mode. \
            This setting only applies to sync functions, and is turned on by default \
            due to Django ORM limitations.
        refetch: A query function (the function given to a `use_query` hook) or a \
            sequence of query functions whose registered `refetch` callbacks should be \
            called once the mutation succeeds.

    Returns:
        An object containing `#!python loading`/`#!python error` states, and a \
        `#!python reset` callable that will set `#!python loading`/`#!python error` \
        states to defaults. This object can be called to run the mutation.
    """
    loading, set_loading = use_state(False)
    error, set_error = use_state(cast(Union[Exception, None], None))
    async_task_refs = use_ref(set())

    async def execute_mutation(*exec_args: FuncParams.args, **exec_kwargs: FuncParams.kwargs) -> None:
        """Run the mutation, update the hook state, then trigger any requested refetches."""
        try:
            outcome = await ensure_async(mutation, thread_sensitive=thread_sensitive)(*exec_args, **exec_kwargs)
        except Exception as exc:
            # Record the failure for the user and log it with a traceback
            set_loading(False)
            set_error(exc)
            _logger.exception("Failed to execute mutation: %s", generate_obj_name(mutation))
            return

        # The mutation completed without raising
        set_loading(False)
        set_error(None)

        # Refetching is skipped when the mutation returned `False` or nothing was requested
        if outcome is False or not refetch:
            return

        if callable(refetch):
            query_functions = (refetch,)
        else:
            query_functions = refetch
        for query_function in query_functions:
            for notify in _REFETCH_CALLBACKS.get(query_function) or ():
                notify()

    @use_callback
    def schedule_mutation(*exec_args: FuncParams.args, **exec_kwargs: FuncParams.kwargs) -> None:
        """Start the mutation in the background."""
        # Re-running while already loading is allowed; callers can check `loading`
        # themselves if they want to prevent it.
        set_loading(True)

        pending = asyncio.ensure_future(execute_mutation(*exec_args, **exec_kwargs))

        # Hold a reference until the task finishes so it is not garbage collected
        async_task_refs.current.add(pending)
        pending.add_done_callback(async_task_refs.current.remove)

    @use_callback
    def reset() -> None:
        """Restore the `loading`/`error` states to their defaults."""
        set_loading(False)
        set_error(None)

    # The user-facing mutation API
    return Mutation(schedule_mutation, loading, error, reset)
def use_user() -> AbstractUser:
    """Return the current `User` object, taken from the connection scope or, failing \
        that, from the connection carrier.

    Raises:
        UserNotFoundError: If neither source provides a user.
    """
    connection = use_connection()
    current_user = connection.scope.get("user")
    if not current_user:
        # Fall back to the carrier when the scope has no usable `user` entry
        current_user = getattr(connection.carrier, "user", None)
    if current_user is not None:
        return current_user
    msg = "No user is available in the current environment."
    raise UserNotFoundError(msg)
def use_user_data(
    default_data: (None | dict[str, Callable[[], Any] | Callable[[], Awaitable[Any]] | Any]) = None,
    save_default_data: bool = False,
) -> UserData:
    """Read or write the data stored for the current user within the REACTPY_DATABASE.

    Kwargs:
        default_data: A dictionary of `{key: default_value}` pairs. For expensive \
            defaults, `default_value` may be a sync or async function that returns \
            the value to set.
        save_default_data: If True, `default_data` values are automatically stored \
            in the database when they do not exist yet.

    Returns:
        A `UserData` object pairing the query that loads the data with the mutation that \
        stores it.
    """
    from reactpy_django.models import UserDataModel

    user = use_user()

    async def _set_user_data(data: dict):
        """Replace the stored user data with `data`."""
        if not isinstance(data, dict):
            msg = f"Expected dict while setting user data, got {type(data)}"
            raise TypeError(msg)
        if user.is_anonymous:
            msg = "AnonymousUser cannot have user data."
            raise ValueError(msg)

        record, _ = await UserDataModel.objects.aget_or_create(user_pk=get_pk(user))
        record.data = orjson.dumps(data)
        await record.asave()

    query_kwargs = {
        "user": user,
        "default_data": default_data,
        "save_default_data": save_default_data,
    }
    query: Query[dict | None] = use_query(_get_user_data, kwargs=query_kwargs, postprocessor=None)
    # A successful write re-runs every registered `_get_user_data` query
    mutation = use_mutation(_set_user_data, refetch=_get_user_data)

    return UserData(query, mutation)
# Strong references to in-flight `group_discard` tasks, held until each task completes
# so that a scheduled cleanup cannot be garbage collected before it finishes.
_GROUP_DISCARD_TASKS: set[asyncio.Task[Any]] = set()


def use_channel_layer(
    *,
    channel: str | None = None,
    group: str | None = None,
    receiver: AsyncMessageReceiver | None = None,
    layer: str = DEFAULT_CHANNEL_LAYER,
) -> AsyncMessageSender:
    """
    Subscribe to a Django Channels layer to send/receive messages.

    Args:
        channel: The name of the channel this hook will send/receive messages on. If `group` is \
            defined and `channel` is `None`, ReactPy will automatically generate a unique channel name.

    Kwargs:
        group: If configured, the `channel` is added to a `group` and any messages sent by `AsyncMessageSender` \
            are broadcast to all channels within the `group`.
        receiver: An async function that receives a `message: dict` from a channel.
        layer: The Django Channels layer to use. This layer must be defined in `settings.py:CHANNEL_LAYERS`.

    Returns:
        An async callable that can send messages to the channel(s). This callable accepts a single \
        argument, `message: dict`, which is the data sent to the channel or group of channels.

    Raises:
        ValueError: If neither `channel` nor `group` is given, or if the requested channel \
            layer is not available.
    """
    channel_layer: InMemoryChannelLayer | RedisChannelLayer = get_channel_layer(layer)  # type: ignore
    # Memoized so an auto-generated channel name stays stable across renders.
    channel_name = use_memo(lambda: str(channel or uuid4()))

    if not (channel or group):
        msg = "You must either define a `channel` or `group` for this hook."
        raise ValueError(msg)

    if not channel_layer:
        msg = (
            f"Channel layer '{layer}' is not available. Are you sure you"
            " configured settings.py:CHANNEL_LAYERS properly?"
        )
        raise ValueError(msg)

    # Add/remove a group's channel during component mount/dismount respectively.
    @use_async_effect(dependencies=[])
    async def group_manager():
        if not group:
            return None

        await channel_layer.group_add(group, channel_name)

        def leave_group() -> None:
            discard = channel_layer.group_discard(group, channel_name)
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No event loop is running in this thread, so the coroutine can be
                # driven to completion synchronously.
                asyncio.run(discard)
                return
            # `asyncio.run` raises `RuntimeError` when called from a running event loop,
            # so the discard is scheduled on that loop instead.
            task = loop.create_task(discard)
            _GROUP_DISCARD_TASKS.add(task)
            task.add_done_callback(_GROUP_DISCARD_TASKS.discard)

        return leave_group

    # Listen for messages on the channel using the provided `receiver` function.
    @use_async_effect
    async def message_receiver():
        if not receiver:
            return

        while True:
            message = await channel_layer.receive(channel_name)
            await receiver(message)

    # User interface for sending messages to the channel
    async def message_sender(message: dict):
        if group:
            await channel_layer.group_send(group, message)
        else:
            await channel_layer.send(channel_name, message)

    return message_sender
def use_root_id() -> str:
    """Return the root element's ID. This value is guaranteed to be unique. Current \
        versions of ReactPy-Django return a `uuid4` string."""
    reactpy_data = use_scope()["reactpy"]
    return reactpy_data["id"]
def use_rerender() -> Callable[[], None]:
    """Return a callable that re-renders the entire component tree without disconnecting \
        the websocket."""
    current_scope = use_scope()

    def rerender():
        # Looked up at call time, not at hook time
        trigger = current_scope["reactpy"]["rerender"]
        trigger()

    return rerender
def use_auth() -> UseAuthTuple:
    """Return `login` and `logout` coroutines backed by Django's authentication framework.

    Both coroutines accept a `rerender` flag (default `True`) that triggers a re-render of \
    the component tree once the operation finishes.
    """
    from reactpy_django import config

    scope = use_scope()
    trigger_rerender = use_rerender()

    async def login(user: AbstractUser, rerender: bool = True) -> None:
        await channels_auth.login(scope, user, backend=config.REACTPY_AUTH_BACKEND)

        # Prefer the session's `asave` when it exists, otherwise fall back to `save`
        session = scope["session"]
        save_session = ensure_async(getattr(session, "asave", session.save))
        await save_session()

        await scope["reactpy"]["synchronize_auth"]()
        if rerender:
            trigger_rerender()

    async def logout(rerender: bool = True) -> None:
        await channels_auth.logout(scope)
        if rerender:
            trigger_rerender()

    return UseAuthTuple(login=login, logout=logout)
async def _get_user_data(user: AbstractUser, default_data: None | dict, save_default_data: bool) -> dict | None:
    """The query function for `use_user_data`.

    Loads the stored data for `user` and fills in any keys from `default_data` that are \
    missing. Callable defaults are invoked (sync or async) to produce the value.

    Returns:
        The user's data, or `None` when there is no user or the user is anonymous.

    Raises:
        TypeError: If the stored data does not deserialize to a `dict`.
    """
    from reactpy_django.models import UserDataModel

    if not user or user.is_anonymous:
        return None

    model, _ = await UserDataModel.objects.aget_or_create(user_pk=get_pk(user))
    data = orjson.loads(model.data) if model.data else {}
    if not isinstance(data, dict):
        msg = f"Expected dict while loading user data, got {type(data)}"
        raise TypeError(msg)

    # Nothing more to do when no defaults were requested
    if not default_data:
        return data

    changed = False
    for key, value in default_data.items():
        if key in data:
            continue
        data[key] = await ensure_async(value)() if callable(value) else value
        changed = True

    if changed:
        model.data = orjson.dumps(data)
        # Only persist the filled-in defaults when asked to
        if save_default_data:
            await model.asave()

    return data