Skip to content

Commit 4de52d8

Browse files
committed
updated njit in vec_utils
1 parent 3aacf84 commit 4de52d8

4 files changed

Lines changed: 27 additions & 31 deletions

File tree

openptv_python/track.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ def angle_acc(
289289
return float(angle), float(acc)
290290

291291

292+
292293
def candsearch_in_pix(
293294
next_frame: List[Target],
294295
num_targets: int,

openptv_python/vec_utils.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,87 +5,81 @@
55
# system to decide whether to invest in loop peeling etc. Here we write
66
# the logical structure, and allow optimizing for size as well.
77

8-
import math
9-
108
import numpy as np
11-
from numba import njit
12-
13-
# Define the np.ndarray type as an numpy array of 3 floats
14-
# vec3d = np.empty(3, dtype=float)
9+
from numba import float64, int32, njit
1510

16-
# and 2 floats
17-
# = np.empty(2, dtype=float)
1811

19-
@njit
20-
def norm(x: float, y: float, z: float) -> float:
21-
"""Return the norm of a 3D vector given by 3 float components."""
22-
return vec_norm(vec_set(x, y, z))
12+
@njit(float64(float64[:]),fastmath=True, cache=True, nogil=True)
13+
def vec_norm(vec: np.ndarray) -> float:
14+
"""vec_norm() gives the norm of a vector."""
15+
return np.sqrt(vec[0]**2 + vec[1]**2 + vec[2]**2)
2316

24-
@njit
17+
@njit(float64[:](float64,float64,float64),fastmath=True, cache=True, nogil=True)
2518
def vec_set(x: float, y: float, z: float) -> np.ndarray:
2619
"""Set the components of a 3D vector from separate doubles."""
2720
return np.array([x, y, z])
2821

22+
@njit(float64(float64,float64,float64),fastmath=True, cache=True, nogil=True)
23+
def norm(x: float, y: float, z: float) -> float:
24+
"""Return the norm of a 3D vector given by 3 float components."""
25+
return vec_norm(vec_set(x, y, z))
2926

27+
28+
@njit(float64[:](float64[:]),fastmath=True, cache=True, nogil=True)
3029
def vec_copy(src: np.ndarray) -> np.ndarray:
3130
"""Copy one 3D vector into another."""
3231
return src.copy()
3332

34-
@njit
33+
@njit(float64[:](float64[:],float64[:]),fastmath=True, cache=True, nogil=True)
3534
def vec_subt(from_: np.ndarray, sub: np.ndarray) -> np.ndarray:
3635
"""Subtract two 3D vectors."""
3736
return from_ - sub
3837

39-
@njit
38+
@njit(float64[:](float64[:],float64[:]),fastmath=True, cache=True, nogil=True)
4039
def vec_add(vec1: np.ndarray, vec2: np.ndarray) -> np.ndarray:
4140
"""Add two 3D vectors."""
4241
return vec1 + vec2
4342

44-
@njit
43+
@njit(float64[:](float64[:],float64),fastmath=True, cache=True, nogil=True)
4544
def vec_scalar_mul(vec: np.ndarray, scalar: float) -> np.ndarray:
4645
"""vec_scalar_mul(np.ndarray, scalar) multiplies a vector by a scalar."""
4746
return vec * scalar
4847

49-
@njit
48+
@njit(float64(float64[:],float64[:]),fastmath=True, cache=True, nogil=True)
5049
def vec_diff_norm(vec1: np.ndarray, vec2: np.ndarray) -> float:
5150
"""vec_diff_norm() gives the norm of the difference between two vectors."""
52-
# return np.linalg.norm(vec1 - vec2)
5351
vec = vec1 - vec2
54-
return math.sqrt(vec[0]**2 + vec[1]**2 + vec[2]**2)
52+
return np.sqrt(vec[0]**2 + vec[1]**2 + vec[2]**2)
5553

56-
@njit
57-
def vec_norm(vec: np.ndarray) -> float:
58-
"""vec_norm() gives the norm of a vector."""
59-
return math.sqrt(vec[0]**2 + vec[1]**2 + vec[2]**2)
6054

61-
@njit
55+
@njit(float64(float64[:],float64[:]),fastmath=True, cache=True, nogil=True)
6256
def vec_dot(vec1: np.ndarray, vec2: np.ndarray) -> float:
6357
"""vec_dot() gives the dot product of two vectors as lists of floats."""
6458
# return np.dot(vec1, vec2)
6559
return float(vec1[0]*vec2[0] + vec1[1]*vec2[1] + vec1[2]*vec2[2])
6660

67-
@njit
61+
@njit(float64[:](float64[:],float64[:]),fastmath=True, cache=True, nogil=True)
6862
def vec_cross(vec1: np.ndarray, vec2: np.ndarray) -> np.ndarray:
6963
"""Cross product of two vectors."""
7064
# return np.cross(vec1, vec2)
7165
return np.array([vec1[1]*vec2[2] - vec1[2]*vec2[1],
7266
vec1[2]*vec2[0] - vec1[0]*vec2[2],
7367
vec1[0]*vec2[1] - vec1[1]*vec2[0]])
7468

75-
@njit
69+
@njit("boolean(float64[:],float64[:],float64)",fastmath=True, cache=True, nogil=True)
7670
def vec_cmp(vec1: np.ndarray, vec2: np.ndarray, tol: float = 1e-6) -> bool:
7771
"""vec_cmp() checks whether two vectors are equal within a tolerance."""
7872
return np.allclose(vec1, vec2, atol=tol)
7973

80-
@njit
74+
@njit(float64[:](float64[:]),fastmath=True, cache=True, nogil=True)
8175
def unit_vector(vec: np.ndarray) -> np.ndarray:
8276
"""Normalize a vector to a unit vector."""
8377
magnitude = vec_norm(vec)
8478
if magnitude == 0:
8579
return vec # Avoid division by zero for zero vectors
8680
return vec / magnitude
8781

88-
@njit
89-
def vec_init(length=3) -> np.ndarray:
82+
@njit(float64[:](int32),fastmath=True, cache=True, nogil=True)
83+
def vec_init(length: int=3) -> np.ndarray:
9084
"""Initialize a vector to zero."""
9185
return np.zeros(length, dtype=float)

tests/test_tracking_run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,8 +285,8 @@ def test_trackback(self):
285285
trackcorr_c_finish(run, run.seq_par.last)
286286

287287

288-
run.tpar.dvxmin = run.tpar.dvymin = run.tpar.dvzmin = -50
289-
run.tpar.dvxmax = run.tpar.dvymax = run.tpar.dvzmax = 50
288+
run.tpar.dvxmin = run.tpar.dvymin = run.tpar.dvzmin = -50.0
289+
run.tpar.dvxmax = run.tpar.dvymax = run.tpar.dvzmax = 50.0
290290

291291

292292
run.lmax = vec_norm(

tests/test_vec_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
"""Tests for the vec_utils module."""
12
import numpy as np
23

34
from openptv_python.vec_utils import (

0 commit comments

Comments
 (0)