|
123 | 123 | "\n" |
124 | 124 | ] |
125 | 125 | }, |
| 126 | + { |
| 127 | + "cell_type": "code", |
| 128 | + "execution_count": null, |
| 129 | + "metadata": {}, |
| 130 | + "outputs": [], |
| 131 | + "source": [ |
| 132 | + "#| export\n", |
| 133 | + "import numpy as np\n", |
| 134 | + "from typing import Union, Optional\n", |
| 135 | + "\n", |
def prop_dataset(group:Union[list, tuple, np.ndarray, dict], # Flat sequence of counts, or a dict mapping each group name to its two counts.
                 group_names: Optional[list] = None): # Names of the groups; required unless `group` is a dict.
    '''
    Convenience function to generate a dataframe of binary data.

    Each group is described by two counts: how many 0s and how many 1s its
    column holds. `group` is either a dict of `{name: [n_zeros, n_ones]}` or a
    flat sequence `[n_zeros_1, n_ones_1, n_zeros_2, n_ones_2, ...]` whose
    pairs are matched, in order, with `group_names`.

    Returns a DataFrame with one column per group (all 0s followed by all 1s)
    plus an `ID` column numbered from 1. Columns follow the dict's key order
    for dict input and the order of `group_names` for sequence input.

    Raises `ValueError` if the input is empty, if `group_names` is missing,
    duplicated (sequence input) or does not match the dict keys, if a group
    does not have exactly two counts, or if the groups' totals differ.
    '''
    import pandas as pd

    if isinstance(group, dict):
        # If group_names is not provided, use the keys of the dict as group_names
        if group_names is None:
            group_names = list(group)
        elif set(group_names) != set(group):
            raise ValueError('group_names must be the same as the keys of the dict.')
        # Only the container type is checked here; np.repeat below rejects
        # counts that are not non-negative integers.
        if not all(isinstance(group[name], (list, tuple, np.ndarray)) for name in group_names):
            raise ValueError('group must be a dict of lists, tuples, or numpy ndarrays of numeric types.')
        if not all(len(group[name]) == 2 for name in group_names):
            raise ValueError('Each parent key should have only two elements.')
        group_val = group

    else:
        if group_names is None:
            raise ValueError('group_names must be provided if group is not a dict.')
        if len(group) != 2 * len(group_names):
            raise ValueError('The length of group must be two times of the length of group_names.')
        # A repeated name would silently overwrite an earlier group's counts.
        if len(set(group_names)) != len(group_names):
            raise ValueError('group_names must not contain duplicates.')
        group_val = {name: [group[2 * i], group[2 * i + 1]] for i, name in enumerate(group_names)}

    # Without at least one group there is no row count to build the frame from.
    if not group_val:
        raise ValueError('At least one group must be provided.')

    # Every column must have the same number of rows, i.e. the same total count.
    totals = [sum(counts) for counts in group_val.values()]
    n_rows = totals[0]
    if any(total != n_rows for total in totals):
        raise ValueError('The sum of values under each key must be the same.')

    # Build all columns first and create the frame once, rather than
    # concatenating a new single-column frame per group.
    columns = {name: np.repeat([0, 1], counts).tolist() for name, counts in group_val.items()}
    final_df = pd.DataFrame(columns)
    final_df['ID'] = pd.Series(range(1, n_rows + 1))

    return final_df
| 182 | + ] |
| 183 | + }, |
126 | 184 | { |
127 | 185 | "cell_type": "markdown", |
128 | 186 | "metadata": {}, |
|
0 commit comments