Skip to content

Commit 32e5626

Browse files
committed
Updated notebooks and created utils
1 parent 5349e94 commit 32e5626

9 files changed

Lines changed: 634 additions & 891 deletions

compress_rtp/compress_rtp_optimization.py

Lines changed: 1 addition & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,6 @@
99
import cvxpy as cp
1010
import numpy as np
1111
from copy import deepcopy
12-
try:
13-
from sklearn.utils.extmath import randomized_svd
14-
except ImportError:
15-
pass
16-
import scipy
17-
try:
18-
import pywt
19-
except ImportError:
20-
pass
2112

2213

2314
class CompressRTPOptimization(Optimization):
@@ -165,82 +156,4 @@ def create_cvxpy_problem_compressed(self, S=None, H=None, W=None):
165156
<= limit / num_fractions]
166157

167158
constraints += [Wx == (W @ x)]
168-
print('Constraints done')
169-
170-
def get_sparse_plus_low_rank(self, A=None, thresold_perc=1, rank=5):
171-
"""
172-
:param A: dose influence matrix
173-
:param thresold_perc: thresold percentage. Default to 1% of max(A)
174-
:type rank: rank of L = A-S.
175-
:returns: S, H, W using randomized svd
176-
"""
177-
if A is None:
178-
A = deepcopy(self.inf_matrix.A)
179-
tol = np.max(A) * thresold_perc * 0.01
180-
# S = S*0
181-
S = np.where(A > tol, A, 0)
182-
if rank == 0:
183-
H = np.zeros((A.shape[0], 1))
184-
W = np.zeros((1, A.shape[1]))
185-
else:
186-
print('Running svd..')
187-
[U, svd_S, V] = randomized_svd(A - S, n_components=rank + 1, random_state=0)
188-
print('svd done!')
189-
H = U[:, :rank]
190-
W = np.diag(svd_S[:rank]) @ V[:rank, :]
191-
S = scipy.sparse.csr_matrix(S)
192-
return S, H, W
193-
194-
def get_low_dim_basis(self, inf_matrix: InfluenceMatrix = None, compression: str = 'wavelet'):
195-
"""
196-
:param inf_matrix: an object of class InfluenceMatrix for the specified plan
197-
:param compression: the compression method
198-
:type compression: str
199-
:return: a list that contains the dimension reduction basis in the format of array(float)
200-
"""
201-
if inf_matrix is None:
202-
inf_matrix = self.inf_matrix
203-
low_dim_basis = {}
204-
num_of_beams = len(inf_matrix.beamlets_dict)
205-
num_of_beamlets = inf_matrix.beamlets_dict[num_of_beams - 1]['end_beamlet_idx'] + 1
206-
beam_id = [inf_matrix.beamlets_dict[i]['beam_id'] for i in range(num_of_beams)]
207-
beamlets = inf_matrix.get_bev_2d_grid(beam_id=beam_id)
208-
index_position = list()
209-
for ind in range(num_of_beams):
210-
low_dim_basis[beam_id[ind]] = []
211-
for i in range(inf_matrix.beamlets_dict[ind]['start_beamlet_idx'],
212-
inf_matrix.beamlets_dict[ind]['end_beamlet_idx'] + 1):
213-
index_position.append((np.where(beamlets[ind] == i)[0][0], np.where(beamlets[ind] == i)[1][0]))
214-
if compression == 'wavelet':
215-
max_dim_0 = np.max([beamlets[ind].shape[0] for ind in range(num_of_beams)])
216-
max_dim_1 = np.max([beamlets[ind].shape[1] for ind in range(num_of_beams)])
217-
beamlet_2d_grid = np.zeros((int(np.ceil(max_dim_0 / 2)), int(np.ceil(max_dim_1 / 2))))
218-
for row in range(beamlet_2d_grid.shape[0]):
219-
for col in range(beamlet_2d_grid.shape[1]):
220-
beamlet_2d_grid[row][col] = 1
221-
approximation_coeffs = pywt.idwt2((beamlet_2d_grid, (None, None, None)), 'sym4',
222-
mode='periodization')
223-
horizontal_coeffs = pywt.idwt2((None, (beamlet_2d_grid, None, None)), 'sym4', mode='periodization')
224-
for b in range(num_of_beams):
225-
if ((2 * row + 1 < beamlets[b].shape[0] and 2 * col + 1 < beamlets[b].shape[1] and
226-
beamlets[b][2 * row + 1][2 * col + 1] != -1) or
227-
(2 * row + 1 < beamlets[b].shape[0] and 2 * col < beamlets[b].shape[1] and
228-
beamlets[b][2 * row + 1][2 * col] != -1) or
229-
(2 * row < beamlets[b].shape[0] and 2 * col + 1 < beamlets[b].shape[1] and
230-
beamlets[b][2 * row][2 * col + 1] != -1) or
231-
(2 * row < beamlets[b].shape[0] and 2 * col < beamlets[b].shape[1] and
232-
beamlets[b][2 * row][2 * col] != -1)):
233-
approximation = np.zeros(num_of_beamlets)
234-
horizontal = np.zeros(num_of_beamlets)
235-
for ind in range(inf_matrix.beamlets_dict[b]['start_beamlet_idx'],
236-
inf_matrix.beamlets_dict[b]['end_beamlet_idx'] + 1):
237-
approximation[ind] = approximation_coeffs[index_position[ind]]
238-
horizontal[ind] = horizontal_coeffs[index_position[ind]]
239-
low_dim_basis[beam_id[b]].append(np.transpose(np.stack([approximation, horizontal])))
240-
beamlet_2d_grid[row][col] = 0
241-
for b in beam_id:
242-
low_dim_basis[b] = np.concatenate(low_dim_basis[b], axis=1)
243-
u, s, vh = scipy.sparse.linalg.svds(low_dim_basis[b], k=min(low_dim_basis[b].shape[0], low_dim_basis[b].shape[1]) - 1)
244-
ind = np.where(s > 0.0001)
245-
low_dim_basis[b] = u[:, ind[0]]
246-
return np.concatenate([low_dim_basis[b] for b in beam_id], axis=1)
159+
print('Constraints done')

compress_rtp/utils/__init__.py

Whitespace-only changes.
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
2+
from typing import TYPE_CHECKING
3+
if TYPE_CHECKING:
4+
from portpy.photon.influence_matrix import InfluenceMatrix
5+
import numpy as np
6+
import scipy
7+
try:
8+
import pywt
9+
except ImportError:
10+
pass
11+
12+
13+
def get_low_dim_basis(inf_matrix: InfluenceMatrix, compression: str = 'wavelet'):
14+
"""
15+
:param inf_matrix: an object of class InfluenceMatrix for the specified plan
16+
:param compression: the compression method
17+
:type compression: str
18+
:return: a list that contains the dimension reduction basis in the format of array(float)
19+
"""
20+
21+
low_dim_basis = {}
22+
num_of_beams = len(inf_matrix.beamlets_dict)
23+
num_of_beamlets = inf_matrix.beamlets_dict[num_of_beams - 1]['end_beamlet_idx'] + 1
24+
beam_id = [inf_matrix.beamlets_dict[i]['beam_id'] for i in range(num_of_beams)]
25+
beamlets = inf_matrix.get_bev_2d_grid(beam_id=beam_id)
26+
index_position = list()
27+
for ind in range(num_of_beams):
28+
low_dim_basis[beam_id[ind]] = []
29+
for i in range(inf_matrix.beamlets_dict[ind]['start_beamlet_idx'],
30+
inf_matrix.beamlets_dict[ind]['end_beamlet_idx'] + 1):
31+
index_position.append((np.where(beamlets[ind] == i)[0][0], np.where(beamlets[ind] == i)[1][0]))
32+
if compression == 'wavelet':
33+
max_dim_0 = np.max([beamlets[ind].shape[0] for ind in range(num_of_beams)])
34+
max_dim_1 = np.max([beamlets[ind].shape[1] for ind in range(num_of_beams)])
35+
beamlet_2d_grid = np.zeros((int(np.ceil(max_dim_0 / 2)), int(np.ceil(max_dim_1 / 2))))
36+
for row in range(beamlet_2d_grid.shape[0]):
37+
for col in range(beamlet_2d_grid.shape[1]):
38+
beamlet_2d_grid[row][col] = 1
39+
approximation_coeffs = pywt.idwt2((beamlet_2d_grid, (None, None, None)), 'sym4',
40+
mode='periodization')
41+
horizontal_coeffs = pywt.idwt2((None, (beamlet_2d_grid, None, None)), 'sym4', mode='periodization')
42+
for b in range(num_of_beams):
43+
if ((2 * row + 1 < beamlets[b].shape[0] and 2 * col + 1 < beamlets[b].shape[1] and
44+
beamlets[b][2 * row + 1][2 * col + 1] != -1) or
45+
(2 * row + 1 < beamlets[b].shape[0] and 2 * col < beamlets[b].shape[1] and
46+
beamlets[b][2 * row + 1][2 * col] != -1) or
47+
(2 * row < beamlets[b].shape[0] and 2 * col + 1 < beamlets[b].shape[1] and
48+
beamlets[b][2 * row][2 * col + 1] != -1) or
49+
(2 * row < beamlets[b].shape[0] and 2 * col < beamlets[b].shape[1] and
50+
beamlets[b][2 * row][2 * col] != -1)):
51+
approximation = np.zeros(num_of_beamlets)
52+
horizontal = np.zeros(num_of_beamlets)
53+
for ind in range(inf_matrix.beamlets_dict[b]['start_beamlet_idx'],
54+
inf_matrix.beamlets_dict[b]['end_beamlet_idx'] + 1):
55+
approximation[ind] = approximation_coeffs[index_position[ind]]
56+
horizontal[ind] = horizontal_coeffs[index_position[ind]]
57+
low_dim_basis[beam_id[b]].append(np.transpose(np.stack([approximation, horizontal])))
58+
beamlet_2d_grid[row][col] = 0
59+
for b in beam_id:
60+
low_dim_basis[b] = np.concatenate(low_dim_basis[b], axis=1)
61+
u, s, vh = scipy.sparse.linalg.svds(low_dim_basis[b], k=min(low_dim_basis[b].shape[0], low_dim_basis[b].shape[1]) - 1)
62+
ind = np.where(s > 0.0001)
63+
low_dim_basis[b] = u[:, ind[0]]
64+
return np.concatenate([low_dim_basis[b] for b in beam_id], axis=1)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import numpy as np
2+
import scipy
3+
4+
try:
5+
from sklearn.utils.extmath import randomized_svd
6+
except ImportError:
7+
pass
8+
9+
10+
def get_sparse_plus_low_rank(A: np.ndarray, thresold_perc: float = 1, rank: int = 5):
11+
"""
12+
:param A: dose influence matrix
13+
:param thresold_perc: thresold percentage. Default to 1% of max(A)
14+
:type rank: rank of L = A-S.
15+
:returns: S, H, W using randomized svd
16+
"""
17+
tol = np.max(A) * thresold_perc * 0.01
18+
S = np.where(A > tol, A, 0)
19+
if rank == 0:
20+
S = scipy.sparse.csr_matrix(S)
21+
return S
22+
else:
23+
print('Running svd..')
24+
[U, svd_S, V] = randomized_svd(A - S, n_components=rank + 1, random_state=0)
25+
print('svd done!')
26+
H = U[:, :rank]
27+
W = np.diag(svd_S[:rank]) @ V[:rank, :]
28+
S = scipy.sparse.csr_matrix(S)
29+
return S, H, W

examples/fluence_map_compress_wavelets.ipynb renamed to examples/fluence_wavelets.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
"outputs": [],
2727
"source": [
2828
"import portpy.photon as pp\n",
29-
"from compress_rtp.compress_rtp_optimization import CompressRTPOptimization\n",
29+
"from compress_rtp.utils.get_low_dim_basis import get_low_dim_basis\n",
3030
"import cvxpy as cp\n",
3131
"import matplotlib.pyplot as plt"
3232
]
@@ -124,7 +124,7 @@
124124
"my_plan = pp.Plan(ct=ct, structs=structs, beams=beams, inf_matrix=inf_matrix, clinical_criteria=clinical_criteria)\n",
125125
"\n",
126126
"# create cvxpy problem using the clinical criteria and optimization parameters\n",
127-
"opt = CompressRTPOptimization(my_plan, opt_params=opt_params)\n",
127+
"opt = pp.Optimization(my_plan, opt_params=opt_params)\n",
128128
"opt.create_cvxpy_problem()\n",
129129
"sol_no_quad_no_wav = opt.solve(solver='MOSEK', verbose=False)"
130130
]
@@ -145,7 +145,7 @@
145145
"outputs": [],
146146
"source": [
147147
"# creating the wavelet incomplete basis representing a low dimensional subspace for dimension reduction\n",
148-
"wavelet_basis = opt.get_low_dim_basis(inf_matrix=inf_matrix, 'wavelet')\n",
148+
"wavelet_basis = get_low_dim_basis(inf_matrix=inf_matrix, compression='wavelet')\n",
149149
"# Smoothness Constraint\n",
150150
"y = cp.Variable(wavelet_basis.shape[1])\n",
151151
"opt.constraints += [wavelet_basis @ y == opt.vars['x']]\n",
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
import portpy.photon as pp
8-
from compress_rtp.compress_rtp_optimization import CompressRTPOptimization
8+
from compress_rtp.utils.get_low_dim_basis import get_low_dim_basis
99
import cvxpy as cp
1010
import matplotlib.pyplot as plt
1111

@@ -64,14 +64,14 @@ def ex_wavelet():
6464
my_plan = pp.Plan(ct=ct, structs=structs, beams=beams, inf_matrix=inf_matrix, clinical_criteria=clinical_criteria)
6565

6666
# create cvxpy problem using the clinical criteria and optimization parameters
67-
opt = CompressRTPOptimization(my_plan, opt_params=opt_params)
67+
opt = pp.Optimization(my_plan, opt_params=opt_params)
6868
opt.create_cvxpy_problem()
6969
sol_no_quad_no_wav = opt.solve(solver='MOSEK', verbose=False)
7070

7171
# - With wavelet constraint
7272

7373
# creating the wavelet incomplete basis representing a low dimensional subspace for dimension reduction
74-
wavelet_basis = opt.get_low_dim_basis()
74+
wavelet_basis = get_low_dim_basis(inf_matrix=inf_matrix, compression='wavelet')
7575
# Smoothness Constraint
7676
y = cp.Variable(wavelet_basis.shape[1])
7777
opt.constraints += [wavelet_basis @ y == opt.vars['x']]
@@ -88,7 +88,7 @@ def ex_wavelet():
8888
opt_params['objective_functions'][i]['weight'] = 10
8989

9090
# create cvxpy problem using the clinical criteria and optimization parameters
91-
opt = CompressRTPOptimization(my_plan, opt_params=opt_params)
91+
opt = pp.Optimization(my_plan, opt_params=opt_params)
9292
opt.create_cvxpy_problem()
9393

9494
sol_quad_no_wav = opt.solve(solver='MOSEK', verbose=False)
@@ -154,7 +154,7 @@ def ex_wavelet():
154154
155155
'''
156156
# visualize plan metrics based upon clinical criteria
157-
pp.Evaluation.plan_metrics(my_plan,
157+
pp.Evaluation.display_clinical_criteria(my_plan,
158158
sol=[sol_no_quad_no_wav, sol_no_quad_with_wav, sol_quad_no_wav, sol_quad_with_wav],
159159
sol_names=['no_quad_no_wav', 'no_quad_with_wav', 'quad_no_wav', 'quad_with_wav'])
160160

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
import os
77
import portpy.photon as pp
88
from compress_rtp.compress_rtp_optimization import CompressRTPOptimization
9+
from compress_rtp.utils.get_sparse_plus_low_rank import get_sparse_plus_low_rank
910
import matplotlib.pyplot as plt
1011
from copy import deepcopy
11-
import numpy as np
12-
import scipy
1312

1413

1514
def sparse_plus_low_rank():
@@ -25,7 +24,7 @@ def sparse_plus_low_rank():
2524
data = pp.DataExplorer(data_dir=data_dir)
2625

2726
# pick a patient from the existing patient list to get detailed info (e.g., beam angles, structures).
28-
data.patient_id = 'Lung_Patient_3'
27+
data.patient_id = 'Lung_Patient_2'
2928

3029
ct = pp.CT(data)
3130
structs = pp.Structures(data)
@@ -71,18 +70,16 @@ def sparse_plus_low_rank():
7170
# run optimization with naive thresold of 1% of max(A) and no low rank
7271
# create cvxpy problem using the clinical criteria and optimization parameters
7372
A = deepcopy(inf_matrix.A)
74-
tol = np.max(A) * 1 * 0.01
75-
S = np.where(A > tol, A, 0)
76-
S = scipy.sparse.csr_matrix(S)
73+
S = get_sparse_plus_low_rank(A=A, thresold_perc=1, rank=0)
7774
inf_matrix.A = S
7875
opt = pp.Optimization(my_plan, inf_matrix=inf_matrix, opt_params=opt_params)
7976
opt.create_cvxpy_problem()
8077
sol_sparse = opt.solve(solver='MOSEK', verbose=True)
8178

8279
# run optimization with thresold of 1% and rank 5
8380
# create cvxpy problem using the clinical criteria and optimization parameters
81+
S, H, W = get_sparse_plus_low_rank(A=A, thresold_perc=1, rank=5)
8482
opt = CompressRTPOptimization(my_plan, opt_params=opt_params)
85-
S, H, W = opt.get_sparse_plus_low_rank(A=A, thresold_perc=1, rank=5)
8683
opt.create_cvxpy_problem_compressed(S=S, H=H, W=W)
8784

8885
# run imrt fluence map optimization using cvxpy and one of the supported solvers and save the optimal solution in sol
@@ -95,22 +92,19 @@ def sparse_plus_low_rank():
9592
9693
"""
9794

98-
fig, ax = plt.subplots(figsize=(12, 8))
95+
fig, ax = plt.subplots(1, 2, figsize=(20, 8))
9996
struct_names = ['PTV', 'ESOPHAGUS', 'HEART', 'CORD', 'LUNGS_NOT_GTV']
10097
dose_1d_sparse = (S @ sol_sparse['optimal_intensity']) * my_plan.get_num_of_fractions()
101-
dose_1d_full = (A @ sol_sparse['optimal_intensity']) * my_plan.get_num_of_fractions()
102-
ax = pp.Visualization.plot_dvh(my_plan, dose_1d=dose_1d_sparse, struct_names=struct_names, style='dotted', ax=ax, norm_flag=True)
103-
ax = pp.Visualization.plot_dvh(my_plan, dose_1d=dose_1d_full, struct_names=struct_names, style='solid', ax=ax, norm_flag=True)
104-
ax.set_title("sparse_vs_full")
105-
plt.show(block=False)
98+
dose_1d_full_sparse = (A @ sol_sparse['optimal_intensity']) * my_plan.get_num_of_fractions()
99+
ax0 = pp.Visualization.plot_dvh(my_plan, dose_1d=dose_1d_sparse, struct_names=struct_names, style='dotted', ax=ax[0], norm_flag=True)
100+
ax0 = pp.Visualization.plot_dvh(my_plan, dose_1d=dose_1d_full_sparse, struct_names=struct_names, style='solid', ax=ax0, norm_flag=True)
101+
ax0.set_title("sparse_vs_full")
106102

107-
fig, ax = plt.subplots(figsize=(12, 8))
108-
struct_names = ['PTV', 'ESOPHAGUS', 'HEART', 'CORD', 'LUNGS_NOT_GTV']
109103
dose_1d_slr = (S @ sol_slr['optimal_intensity'] + H @ (W @ sol_slr['optimal_intensity'])) * my_plan.get_num_of_fractions()
110-
dose_1d_full = (A @ sol_slr['optimal_intensity']) * my_plan.get_num_of_fractions()
111-
ax = pp.Visualization.plot_dvh(my_plan, dose_1d=dose_1d_slr, struct_names=struct_names, style='dashed', ax=ax, norm_flag=True)
112-
ax = pp.Visualization.plot_dvh(my_plan, dose_1d=dose_1d_full, struct_names=struct_names, style='solid', ax=ax, norm_flag=True)
113-
ax.set_title("slr_vs_full")
104+
dose_1d_full_slr = (A @ sol_slr['optimal_intensity']) * my_plan.get_num_of_fractions()
105+
ax1 = pp.Visualization.plot_dvh(my_plan, dose_1d=dose_1d_slr, struct_names=struct_names, style='dashed', ax=ax[1], norm_flag=True)
106+
ax1 = pp.Visualization.plot_dvh(my_plan, dose_1d=dose_1d_full_slr, struct_names=struct_names, style='solid', ax=ax1, norm_flag=True)
107+
ax1.set_title("slr_vs_full")
114108
plt.show(block=False)
115109

116110
"""
@@ -122,7 +116,7 @@ def sparse_plus_low_rank():
122116
"""
123117

124118
# visualize plan metrics based upon clinical criteria
125-
pp.Evaluation.display_clinical_criteria(my_plan, dose_1d=[dose_1d_sparse, dose_1d_slr], in_browser=True)
119+
pp.Evaluation.display_clinical_criteria(my_plan, dose_1d=[dose_1d_full_sparse, dose_1d_full_slr], sol_names=['Without compression', 'With compression'])
126120

127121
"""
128122
5) saving and loading the plan for future use (utils)

0 commit comments

Comments
 (0)