|
10 | 10 | import os |
11 | 11 | import time |
12 | 12 | import threading |
13 | | -from concurrent.futures import ThreadPoolExecutor, as_completed |
| 13 | +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed |
14 | 14 | from typing import List, Tuple, Dict, Any, Optional |
15 | 15 |
|
16 | 16 | import dotenv |
@@ -78,32 +78,31 @@ def load_pairs_dataset(dataset: Dataset) -> Dict[str, Tuple[str, Image.Image, Im |
78 | 78 | pairs[key2] = (instruction, input_image, data["output_images"][1].convert("RGB")) |
79 | 79 | return pairs |
80 | 80 |
|
| 81 | +def _load_item(data: dict) -> list[tuple[str, tuple[str, Image.Image, Image.Image]]]: |
| 82 | + key1, key2 = data["key"] |
| 83 | + instruction = data["instruction"] |
| 84 | + |
| 85 | + input_image = data["input_image"].convert("RGB") |
| 86 | + output_image1 = data["output_images"][0].convert("RGB") |
| 87 | + output_image2 = data["output_images"][1].convert("RGB") |
| 88 | + |
| 89 | + return [ |
| 90 | + (key1, (instruction, input_image, output_image1)), |
| 91 | + (key2, (instruction, input_image, output_image2)), |
| 92 | + ] |
81 | 93 |
|
82 | 94 | def load_pairs_dataset_multithreaded(dataset: Dataset, max_workers: int = None) -> Dict[str, Tuple[str, Image.Image, Image.Image]]: |
83 | 95 | if max_workers is None: |
84 | | - max_workers = min(32, (os.cpu_count() or 1) * 5) |
| 96 | + # max_workers = min(32, (os.cpu_count() or 1) * 5) |
| 97 | + max_workers = os.cpu_count() or 1 |
85 | 98 |
|
86 | 99 | pairs = {} |
87 | 100 |
|
88 | | - |
89 | | - def _process_item(data: dict) -> list[tuple[str, tuple[str, Image.Image, Image.Image]]]: |
90 | | - key1, key2 = data["key"] |
91 | | - instruction = data["instruction"] |
92 | | - |
93 | | - input_image = data["input_image"].convert("RGB") |
94 | | - output_image1 = data["output_images"][0].convert("RGB") |
95 | | - output_image2 = data["output_images"][1].convert("RGB") |
96 | | - |
97 | | - return [ |
98 | | - (key1, (instruction, input_image, output_image1)), |
99 | | - (key2, (instruction, input_image, output_image2)), |
100 | | - ] |
101 | | - |
102 | 101 | print(f"Processing dataset (length: {len(dataset)}) with {max_workers} threads", flush=True) |
103 | 102 |
|
104 | | - with ThreadPoolExecutor(max_workers=max_workers) as executor: |
| 103 | + with ProcessPoolExecutor(max_workers=max_workers) as executor: |
105 | 104 | results_iterator = tqdm( |
106 | | - executor.map(_process_item, dataset), |
| 105 | + executor.map(_load_item, dataset), |
107 | 106 | total=len(dataset), |
108 | 107 | desc="Processing dataset with multiple threads" |
109 | 108 | ) |
@@ -237,10 +236,12 @@ def main(args): |
237 | 236 | score1 = all_scores[key1][dimension] |
238 | 237 | score2 = all_scores[key2][dimension] |
239 | 238 | data["score"] = [score1, score2] |
| 239 | + |
| 240 | + input_image_path = os.path.join(args.result_dir, "images", f"{key1}_input.png") |
| 241 | +            output_image_path1 = os.path.join(args.result_dir, "images", f"{key1}.png") |
| 242 | + output_image_path2 = os.path.join(args.result_dir, "images", f"{key2}.png") |
240 | 243 |
|
241 | | - input_image_path = os.path.join(args.result_dir, "input_images", f"{key1}_input.png") |
242 | | - output_image_path1 = os.path.join(args.result_dir, "output_images", f"{key1}_output0.png") |
243 | | - output_image_path2 = os.path.join(args.result_dir, "output_images", f"{key2}_output1.png") |
| 244 | + os.makedirs(os.path.dirname(input_image_path), exist_ok=True) |
244 | 245 |
|
245 | 246 | data['input_image'].save(input_image_path) |
246 | 247 | data['output_images'][0].save(output_image_path1) |
|
0 commit comments