@@ -82,11 +82,16 @@ def plot_2d(X, y, w, flat=True, alpha=None, show_noiseless=True):
8282
8383 fig = plt .figure (figsize = (6.7 , 2.5 ))
8484
85+ #####################
8586 # left plot: y = f(x)
86- ax = fig .add_subplot (121 , projection = '3d' , computed_zorder = False )
87+
88+ try : # computed_zorder is only available in matplotlib >= 3.4
89+ ax = fig .add_subplot (121 , projection = '3d' , computed_zorder = False )
90+ except AttributeError :
91+ ax = fig .add_subplot (121 , projection = '3d' )
8792
8893 # to help matplotlib displays scatter points behind any surface, we
89- # first plot the point below, then the surface, then the ponts above,
94+ # first plot the point below, then the surface, then the points above,
9095 # and use computed_zorder=False.
9196 above = y > X @ w
9297 ax .scatter3D (X [~ above , 0 ], X [~ above , 1 ], y [~ above ], alpha = 0.5 , color = "C0" )
@@ -103,6 +108,7 @@ def plot_2d(X, y, w, flat=True, alpha=None, show_noiseless=True):
103108 ax .set (xlabel = "X[:, 0]" , ylabel = "X[:, 1]" , zlabel = "y" ,
104109 zlim = [yy .min (), yy .max ()])
105110
111+ #########################
106112 # right plot: loss = f(w)
107113 if flat :
108114 ax = fig .add_subplot (122 )
0 commit comments