Skip to content

Commit a27f6c1

Browse files
committed
Fix Upgrading Error
1 parent 3164aee commit a27f6c1

9 files changed

Lines changed: 66 additions & 30 deletions

File tree

.github/workflows/test-pytest.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ jobs:
88
- uses: actions/checkout@v3
99
- uses: actions/setup-python@v4
1010
with:
11-
python-version: 3.8
11+
python-version: 3.9
1212
cache: "pip"
1313
cache-dependency-path: settings.ini
1414
- name: Run pytest

dabest/_modidx.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@
6565
'dabest/forest_plot.py'),
6666
'dabest.forest_plot.forest_plot': ('API/forest_plot.html#forest_plot', 'dabest/forest_plot.py'),
6767
'dabest.forest_plot.load_plot_data': ('API/forest_plot.html#load_plot_data', 'dabest/forest_plot.py')},
68-
'dabest.misc_tools': { 'dabest.misc_tools.get_varname': ('API/misc_tools.html#get_varname', 'dabest/misc_tools.py'),
68+
'dabest.misc_tools': { 'dabest.misc_tools.get_unique_categories': ( 'API/misc_tools.html#get_unique_categories',
69+
'dabest/misc_tools.py'),
70+
'dabest.misc_tools.get_varname': ('API/misc_tools.html#get_varname', 'dabest/misc_tools.py'),
6971
'dabest.misc_tools.merge_two_dicts': ('API/misc_tools.html#merge_two_dicts', 'dabest/misc_tools.py'),
7072
'dabest.misc_tools.print_greeting': ('API/misc_tools.html#print_greeting', 'dabest/misc_tools.py'),
7173
'dabest.misc_tools.unpack_and_add': ('API/misc_tools.html#unpack_and_add', 'dabest/misc_tools.py')},

dabest/misc_tools.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/misc_tools.ipynb.
22

33
# %% auto 0
4-
__all__ = ['merge_two_dicts', 'unpack_and_add', 'print_greeting', 'get_varname']
4+
__all__ = ['merge_two_dicts', 'unpack_and_add', 'print_greeting', 'get_varname', 'get_unique_categories']
55

66
# %% ../nbs/API/misc_tools.ipynb 4
77
import datetime as dt
8+
import numpy as np
9+
import pandas as pd
810
from numpy import repeat
911

1012
# %% ../nbs/API/misc_tools.ipynb 5
@@ -68,3 +70,15 @@ def get_varname(obj):
6870
if len(matching_vars) > 0:
6971
return matching_vars[0]
7072
return ""
73+
74+
def get_unique_categories(names):
75+
"""
76+
Extract unique categories from various input types.
77+
"""
78+
if isinstance(names, np.ndarray):
79+
return names # numpy.unique() returns a sorted array
80+
elif isinstance(names, (pd.Categorical, pd.Series)):
81+
return names.cat.categories if hasattr(names, 'cat') else names.unique()
82+
else:
83+
# For dict_keys and other iterables
84+
return np.unique(list(names))

dabest/plot_tools.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,9 +1136,10 @@ def _swarm(
11361136
raise ValueError("`dsize` must be a scalar or float.")
11371137

11381138
# Sorting algorithm based off of: https://github.com/mgymrek/pybeeswarm
1139-
points_data = pd.DataFrame(
1140-
{"y": [yval * 1.0 / dsize for yval in values], "x": [0] * len(values)}
1141-
)
1139+
points_data = pd.DataFrame({
1140+
"y": [yval * 1.0 / dsize for yval in values],
1141+
"x": np.zeros(len(values), dtype=float) # Initialize with float zeros
1142+
})
11421143
for i in range(1, points_data.shape[0]):
11431144
y_i = points_data["y"].values[i]
11441145
points_placed = points_data[0:i]
@@ -1271,7 +1272,7 @@ def plot(
12711272
0 # x-coordinate of center of each individual swarm of the swarm plot
12721273
)
12731274
x_tick_tabels = []
1274-
for group_i, values_i in self.__data_copy.groupby(self.__x):
1275+
for group_i, values_i in self.__data_copy.groupby(self.__x, observed=False):
12751276
x_new = []
12761277
values_i_y = values_i[self.__y]
12771278
x_offset = self._swarm(

dabest/plotter.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
5454
fontsize_contrastxlabel=12, fontsize_contrastylabel=12,
5555
fontsize_delta2label=12
5656
"""
57-
from .misc_tools import merge_two_dicts
57+
from .misc_tools import merge_two_dicts, get_unique_categories
5858
from .plot_tools import (
5959
halfviolin,
6060
get_swarm_spans,
@@ -298,14 +298,16 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
298298
raise ValueError(err1 + err2)
299299

300300
if custom_pal is None and color_col is None:
301+
categories = get_unique_categories(names)
302+
301303
swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors]
302-
plot_palette_raw = dict(zip(names.categories, swarm_colors))
303-
304304
bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors]
305-
plot_palette_bar = dict(zip(names.categories, bar_color))
306-
307305
contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors]
308-
plot_palette_contrast = dict(zip(names.categories, contrast_colors))
306+
307+
308+
plot_palette_raw = dict(zip(categories, swarm_colors))
309+
plot_palette_bar = dict(zip(categories, bar_color))
310+
plot_palette_contrast = dict(zip(categories, contrast_colors))
309311

310312
# For Sankey Diagram plot, no need to worry about the color, each bar will have the same two colors
311313
# default color palette will be set to "hls"
@@ -1081,10 +1083,10 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
10811083
)
10821084
elif effect_size_type == "median_diff":
10831085
control_group_summary = (
1084-
plot_data.groupby(xvar).median().loc[current_control, yvar]
1086+
plot_data.groupby(xvar).median(numeric_only=True).loc[current_control, yvar]
10851087
)
10861088
test_group_summary = (
1087-
plot_data.groupby(xvar).median().loc[current_group, yvar]
1089+
plot_data.groupby(xvar).median(numeric_only=True).loc[current_group, yvar]
10881090
)
10891091

10901092
if swarm_ylim is None:

nbs/API/misc_tools.ipynb

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
"source": [
5656
"#| export\n",
5757
"import datetime as dt\n",
58+
"import numpy as np\n",
59+
"import pandas as pd\n",
5860
"from numpy import repeat"
5961
]
6062
},
@@ -125,7 +127,19 @@
125127
" matching_vars = [k for k, v in globals().items() if v is obj]\n",
126128
" if len(matching_vars) > 0:\n",
127129
" return matching_vars[0]\n",
128-
" return \"\""
130+
" return \"\"\n",
131+
"\n",
132+
"def get_unique_categories(names):\n",
133+
" \"\"\"\n",
134+
" Extract unique categories from various input types.\n",
135+
" \"\"\"\n",
136+
" if isinstance(names, np.ndarray):\n",
137+
" return names # numpy.unique() returns a sorted array\n",
138+
" elif isinstance(names, (pd.Categorical, pd.Series)):\n",
139+
" return names.cat.categories if hasattr(names, 'cat') else names.unique()\n",
140+
" else:\n",
141+
" # For dict_keys and other iterables\n",
142+
" return np.unique(list(names))"
129143
]
130144
}
131145
],

nbs/API/plot_tools.ipynb

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,9 +1199,10 @@
11991199
" raise ValueError(\"`dsize` must be a scalar or float.\")\n",
12001200
"\n",
12011201
" # Sorting algorithm based off of: https://github.com/mgymrek/pybeeswarm\n",
1202-
" points_data = pd.DataFrame(\n",
1203-
" {\"y\": [yval * 1.0 / dsize for yval in values], \"x\": [0] * len(values)}\n",
1204-
" )\n",
1202+
" points_data = pd.DataFrame({\n",
1203+
" \"y\": [yval * 1.0 / dsize for yval in values],\n",
1204+
" \"x\": np.zeros(len(values), dtype=float) # Initialize with float zeros\n",
1205+
" })\n",
12051206
" for i in range(1, points_data.shape[0]):\n",
12061207
" y_i = points_data[\"y\"].values[i]\n",
12071208
" points_placed = points_data[0:i]\n",
@@ -1334,7 +1335,7 @@
13341335
" 0 # x-coordinate of center of each individual swarm of the swarm plot\n",
13351336
" )\n",
13361337
" x_tick_tabels = []\n",
1337-
" for group_i, values_i in self.__data_copy.groupby(self.__x):\n",
1338+
" for group_i, values_i in self.__data_copy.groupby(self.__x, observed=False):\n",
13381339
" x_new = []\n",
13391340
" values_i_y = values_i[self.__y]\n",
13401341
" x_offset = self._swarm(\n",

nbs/API/plotter.ipynb

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@
113113
" fontsize_contrastxlabel=12, fontsize_contrastylabel=12,\n",
114114
" fontsize_delta2label=12\n",
115115
" \"\"\"\n",
116-
" from .misc_tools import merge_two_dicts\n",
116+
" from .misc_tools import merge_two_dicts, get_unique_categories\n",
117117
" from .plot_tools import (\n",
118118
" halfviolin,\n",
119119
" get_swarm_spans,\n",
@@ -357,14 +357,16 @@
357357
" raise ValueError(err1 + err2)\n",
358358
"\n",
359359
" if custom_pal is None and color_col is None:\n",
360+
" categories = get_unique_categories(names)\n",
361+
" \n",
360362
" swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors]\n",
361-
" plot_palette_raw = dict(zip(names.categories, swarm_colors))\n",
362-
"\n",
363363
" bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors]\n",
364-
" plot_palette_bar = dict(zip(names.categories, bar_color))\n",
365-
"\n",
366364
" contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors]\n",
367-
" plot_palette_contrast = dict(zip(names.categories, contrast_colors))\n",
365+
"\n",
366+
" \n",
367+
" plot_palette_raw = dict(zip(categories, swarm_colors))\n",
368+
" plot_palette_bar = dict(zip(categories, bar_color))\n",
369+
" plot_palette_contrast = dict(zip(categories, contrast_colors))\n",
368370
"\n",
369371
" # For Sankey Diagram plot, no need to worry about the color, each bar will have the same two colors\n",
370372
" # default color palette will be set to \"hls\"\n",
@@ -1140,10 +1142,10 @@
11401142
" )\n",
11411143
" elif effect_size_type == \"median_diff\":\n",
11421144
" control_group_summary = (\n",
1143-
" plot_data.groupby(xvar).median().loc[current_control, yvar]\n",
1145+
" plot_data.groupby(xvar).median(numeric_only=True).loc[current_control, yvar]\n",
11441146
" )\n",
11451147
" test_group_summary = (\n",
1146-
" plot_data.groupby(xvar).median().loc[current_group, yvar]\n",
1148+
" plot_data.groupby(xvar).median(numeric_only=True).loc[current_group, yvar]\n",
11471149
" )\n",
11481150
"\n",
11491151
" if swarm_ylim is None:\n",

settings.ini

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
repo = DABEST-python
44
lib_name = dabest
55
version = 2024.03.29
6-
min_python = 3.8
6+
min_python = 3.9
77
license = apache2
88

99
### nbdev ###
@@ -37,7 +37,7 @@ language = English
3737
status = 3
3838
user = acclab
3939

40-
requirements = fastcore pandas~=1.5.0 numpy~=1.23.5 matplotlib~=3.8.4 seaborn~=0.12.2 scipy~=1.9.3 datetime statsmodels lqrt
40+
requirements = fastcore pandas~=1.5.3 numpy~=1.26 matplotlib~=3.8.4 seaborn~=0.12.2 scipy~=1.12 datetime statsmodels lqrt
4141
dev_requirements = pytest~=7.2.1 pytest-mpl~=0.16.1
4242

4343
### Optional ###

0 commit comments

Comments
 (0)