Skip to content

Commit 65628e3

Browse files
yeldarby and claude committed
feat(cli): train start auto-exports version in required format
Before training, the handler now: (1) determines the required export format for the model type using get_model_format() (e.g., rfdetr needs coco, yolov8 needs yolov5pytorch); (2) checks whether the version is still generating and, if so, waits while printing progress updates; (3) checks whether the version already has the required export and, if not, triggers the export and polls until it is ready; (4) then starts training. It also improves the "Unknown error" message from the train API by adding a hint that the version may not be exported yet. This prevents the confusing failure mode where `train start` returns "Unknown error" because the server expects an export that doesn't exist. 405 tests pass, all linting clean. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 17c81cb commit 65628e3

1 file changed

Lines changed: 77 additions & 1 deletion

File tree

roboflow/cli/handlers/train.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def _start(args: argparse.Namespace) -> None:
9292
output_error(args, "No API key found.", hint="Set ROBOFLOW_API_KEY or run 'roboflow auth login'.", exit_code=2)
9393
return
9494

95+
# Ensure the version has the required export format before training
96+
if args.model_type:
97+
_ensure_export(args, api_key, workspace_url, project_slug, str(args.version_number), args.model_type)
98+
9599
try:
96100
rfapi.start_version_training(
97101
api_key,
@@ -104,7 +108,16 @@ def _start(args: argparse.Namespace) -> None:
104108
epochs=args.epochs,
105109
)
106110
except rfapi.RoboflowError as exc:
107-
output_error(args, str(exc))
111+
err_str = str(exc)
112+
if "Unknown error" in err_str:
113+
output_error(
114+
args,
115+
"Training failed. The server returned an unexpected error.",
116+
hint="Ensure the version is fully generated and exported. "
117+
"Run 'roboflow version export -p <project> <version> -f coco' first.",
118+
)
119+
else:
120+
output_error(args, err_str)
108121
return
109122

110123
data = {
@@ -113,3 +126,66 @@ def _start(args: argparse.Namespace) -> None:
113126
"version": args.version_number,
114127
}
115128
output(args, data, text=f"Training started for {project_slug} version {args.version_number}.")
129+
130+
131+
def _ensure_export(args, api_key, workspace_url, project_slug, version_str, model_type):
    """Make sure the version has the export format ``model_type`` needs before training.

    This is a best-effort pre-flight step: every ``rfapi.RoboflowError`` is
    swallowed so that the subsequent train call stays the single place where
    API failures are reported to the user.

    Steps:
        1. Look up the export format required for ``model_type``.
        2. If the version reports ``generating``, poll until it no longer does
           (or until a status request fails).
        3. If the required format is not listed in the version's ``exports``,
           trigger the export and poll until it shows up or the poll budget
           is exhausted.

    Args:
        args: Parsed CLI namespace; only the optional ``quiet`` attribute is read.
        api_key: API key forwarded to every ``rfapi`` call.
        workspace_url: Workspace identifier forwarded to ``rfapi``.
        project_slug: Project identifier forwarded to ``rfapi``.
        version_str: Version number as a string.
        model_type: Model type passed to ``get_model_format``.

    Returns:
        None. Progress is written to stderr unless ``args.quiet`` is truthy.
    """
    import sys
    import time

    from roboflow.adapters import rfapi
    from roboflow.util.versions import get_model_format

    poll_seconds = 5
    max_export_polls = 120  # 120 polls * 5 s of sleep = roughly 10 minutes

    # The flag cannot change while we run, so read it once.
    quiet = getattr(args, "quiet", False)

    def _notify(message):
        """Write a progress line to stderr unless quiet mode is on."""
        if not quiet:
            print(message, file=sys.stderr)

    def _progress(info):
        """Return the reported progress as a float; missing or invalid values count as 0."""
        try:
            return float(info.get("progress") or 0)
        except (TypeError, ValueError):
            return 0.0

    required_format = get_model_format(model_type)

    try:
        version_data = rfapi.get_version(api_key, workspace_url, project_slug, version_str)
    except rfapi.RoboflowError:
        return  # Can't check; let the train call handle errors

    # ``or {}`` also covers a key that is present but null.
    version_info = version_data.get("version") or {}

    # Wait for dataset generation to finish. Deliberately unbounded, as in the
    # original flow; the user can interrupt the CLI.
    if version_info.get("generating"):
        _notify(f"Version is still generating ({_progress(version_info):.0%})... waiting.")
        while True:
            time.sleep(poll_seconds)
            try:
                version_data = rfapi.get_version(api_key, workspace_url, project_slug, version_str, nocache=True)
            except rfapi.RoboflowError:
                # Status is unavailable; stop waiting and fall through with
                # the last known version info.
                break
            version_info = version_data.get("version") or {}
            if not version_info.get("generating"):
                break
            _notify(f"  Generating... {_progress(version_info):.0%}")

    # Nothing to do when the required export already exists.
    exports = version_info.get("exports") or []
    if required_format in exports:
        return

    _notify(f"Exporting version in {required_format} format (required for {model_type})...")
    try:
        rfapi.get_version_export(api_key, workspace_url, project_slug, version_str, required_format)
    except rfapi.RoboflowError:
        pass  # Export may have been triggered; poll below

    # Poll until the export is listed or the poll budget runs out.
    for _ in range(max_export_polls):
        time.sleep(poll_seconds)
        try:
            version_data = rfapi.get_version(api_key, workspace_url, project_slug, version_str, nocache=True)
        except rfapi.RoboflowError:
            continue  # Transient failure; try again on the next tick
        current_exports = (version_data.get("version") or {}).get("exports") or []
        if required_format in current_exports:
            _notify("  Export complete.")
            return

    # Previously the loop ended silently, making a timeout look like success.
    _notify(f"  Timed out waiting for the {required_format} export; continuing.")

0 commit comments

Comments
 (0)