Skip to content

Commit 6fa7039

Browse files
teunbrandclaude
andcommitted
Add automatic line segmentation for variable material aesthetics
Automatically detects when material aesthetics (stroke, linetype) vary within partition groups and converts to segmented rendering. Uses Vega-Lite transforms (window, flatten, calculate) instead of data restructuring. Implementation uses efficient vectorized Polars operations for group boundary detection and preserves the unified dataset architecture with proper source filter integration. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent 4f0f94f commit 6fa7039

1 file changed

Lines changed: 301 additions & 0 deletions

File tree

src/writer/vegalite/layer.rs

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,10 +334,180 @@ impl GeomRenderer for PathRenderer {
334334
// Line Renderer
335335
// =============================================================================
336336

337+
/// Find indices where group values change in a DataFrame
338+
///
339+
/// Returns a sorted vector of indices marking group boundaries. The first element
340+
/// is always 0 (start of first group), followed by indices where any of the
341+
/// group columns change value.
342+
fn find_group_boundaries(df: &DataFrame, group_columns: &[String]) -> Result<Vec<usize>> {
343+
use polars::prelude::*;
344+
345+
let n_rows = df.height();
346+
347+
if group_columns.is_empty() {
348+
// No grouping: treat entire dataset as one group
349+
return Ok(vec![0, n_rows]);
350+
}
351+
352+
if n_rows <= 1 {
353+
return Ok(vec![0, n_rows]);
354+
}
355+
356+
// Initialize change mask as all false (no changes)
357+
let mut change_mask = BooleanChunked::full("change_mask".into(), false, n_rows - 1);
358+
359+
// For each group column, OR its change mask into the accumulator
360+
for col_name in group_columns {
361+
let series = df.column(col_name).map_err(|e| {
362+
GgsqlError::InternalError(format!("Group column '{}' not found: {}", col_name, e))
363+
})?;
364+
365+
// Compare each row with the previous row
366+
// curr = series[1..n], prev = series[0..n-1]
367+
let curr = series.slice(1, n_rows - 1);
368+
let prev = series.slice(0, n_rows - 1);
369+
370+
// Get boolean mask where values differ
371+
let not_equal = curr.not_equal(&prev).map_err(|e| {
372+
GgsqlError::InternalError(format!("Failed to compare column '{}': {}", col_name, e))
373+
})?;
374+
375+
// OR with accumulator (change if this column OR any previous column changed)
376+
change_mask = &change_mask | &not_equal;
377+
}
378+
379+
// Extract indices where mask is true (offset by 1 since we compared with previous)
380+
let mut boundaries = vec![0];
381+
for (idx, changed) in change_mask.into_iter().enumerate() {
382+
if changed == Some(true) {
383+
boundaries.push(idx + 1);
384+
}
385+
}
386+
387+
// Add final boundary (end of data)
388+
boundaries.push(n_rows);
389+
390+
Ok(boundaries)
391+
}
392+
393+
/// Check if an aesthetic varies within any group segment
394+
///
395+
/// Uses precomputed group boundaries to efficiently check if the aesthetic
396+
/// has multiple distinct values within any group segment.
397+
fn aesthetic_varies_within_groups(
398+
df: &DataFrame,
399+
aesthetic_col: &str,
400+
group_boundaries: &[usize],
401+
) -> Result<bool> {
402+
let series = df.column(aesthetic_col).map_err(|e| {
403+
GgsqlError::InternalError(format!("Column '{}' not found: {}", aesthetic_col, e))
404+
})?;
405+
406+
// Check each group segment
407+
for window in group_boundaries.windows(2) {
408+
let start = window[0];
409+
let end = window[1];
410+
411+
if end - start < 2 {
412+
continue; // Single-row groups can't vary
413+
}
414+
415+
// Slice the series for this group and check uniqueness
416+
let segment = series.slice(start as i64, (end - start) as usize);
417+
let n_unique = segment.n_unique().map_err(|e| {
418+
GgsqlError::InternalError(format!("Failed to count unique values: {}", e))
419+
})?;
420+
421+
if n_unique > 1 {
422+
return Ok(true);
423+
}
424+
}
425+
426+
Ok(false)
427+
}
428+
429+
/// Metadata for segmented line rendering
430+
#[derive(Debug)]
431+
struct LineSegmentMetadata {
432+
partition_columns: Vec<String>,
433+
}
434+
337435
/// Renderer for line geom - preserves data order for correct line rendering
436+
///
437+
/// Automatically detects when material aesthetics (stroke, linetype) vary within
438+
/// partition groups and converts to segmented rendering using detail encoding.
338439
pub struct LineRenderer;
339440

340441
impl GeomRenderer for LineRenderer {
442+
fn prepare_data(
443+
&self,
444+
df: &DataFrame,
445+
layer: &Layer,
446+
_data_key: &str,
447+
binned_columns: &HashMap<String, Vec<f64>>,
448+
) -> Result<PreparedData> {
449+
// Identify material aesthetics that are column-mapped
450+
let material_aesthetics = ["stroke", "linetype"];
451+
let mut varying_aesthetics = Vec::new();
452+
453+
// Collect (aesthetic, column) pairs for material aesthetics that are mapped to columns
454+
let mapped_material_aesthetics: Vec<(&str, String)> = material_aesthetics
455+
.iter()
456+
.filter_map(|aesthetic| {
457+
if let Some(AestheticValue::Column { name: col, .. }) = layer.mappings.get(aesthetic) {
458+
Some((*aesthetic, col.clone()))
459+
} else {
460+
None
461+
}
462+
})
463+
.collect();
464+
465+
// Build list of partition columns EXCLUDING the material aesthetics we're checking
466+
// (we need to check if they vary within the other partition groups)
467+
let mut partition_columns: Vec<String> = layer
468+
.partition_by
469+
.iter()
470+
.filter(|col| !mapped_material_aesthetics.iter().any(|(_, c)| c == *col))
471+
.cloned()
472+
.collect();
473+
474+
// Compute group boundaries once (without the material aesthetics)
475+
let group_boundaries = find_group_boundaries(df, &partition_columns)?;
476+
477+
// Check each mapped material aesthetic for within-group variation
478+
for (aesthetic, col) in &mapped_material_aesthetics {
479+
// Check if this aesthetic varies within partition groups
480+
let varies = aesthetic_varies_within_groups(df, col, &group_boundaries)?;
481+
if varies {
482+
varying_aesthetics.push(*aesthetic);
483+
} else {
484+
// If it doesn't vary within groups, treat it as a partition column
485+
// (boundaries remain the same since the aesthetic changes only where partitions change)
486+
partition_columns.push(col.clone());
487+
}
488+
}
489+
490+
// Return the data with segmentation metadata if needed
491+
let values = if binned_columns.is_empty() {
492+
dataframe_to_values(df)?
493+
} else {
494+
dataframe_to_values_with_bins(df, binned_columns)?
495+
};
496+
497+
let needs_segmentation = !varying_aesthetics.is_empty();
498+
499+
if needs_segmentation {
500+
// Use Composite with empty component name so dataset key = data_key (not data_key + suffix)
501+
// This ensures the source filter works correctly with the unified dataset
502+
Ok(PreparedData::Composite {
503+
components: [("".to_string(), values)].iter().cloned().collect(),
504+
metadata: Box::new(LineSegmentMetadata { partition_columns }),
505+
})
506+
} else {
507+
Ok(PreparedData::Single { values })
508+
}
509+
}
510+
341511
fn modify_encoding(
342512
&self,
343513
encoding: &mut Map<String, Value>,
@@ -352,6 +522,137 @@ impl GeomRenderer for LineRenderer {
352522
);
353523
Ok(())
354524
}
525+
526+
fn finalize(
527+
&self,
528+
mut layer_spec: Value,
529+
_layer: &Layer,
530+
_data_key: &str,
531+
prepared: &PreparedData,
532+
) -> Result<Vec<Value>> {
533+
// Early return for standard line rendering
534+
let PreparedData::Composite { metadata, .. } = prepared else {
535+
return Ok(vec![layer_spec]);
536+
};
537+
538+
// Extract partition columns from metadata
539+
let metadata_any = metadata.as_ref() as &dyn Any;
540+
let partition_columns = if let Some(meta) = metadata_any.downcast_ref::<LineSegmentMetadata>() {
541+
&meta.partition_columns
542+
} else {
543+
return Err(GgsqlError::InternalError(
544+
"Invalid metadata type for segmented line".to_string(),
545+
));
546+
};
547+
548+
// Get position column names
549+
let x_col = naming::aesthetic_column("pos1");
550+
let y_col = naming::aesthetic_column("pos2");
551+
552+
// Segmented rendering using detail encoding:
553+
// 1. Create segment IDs (row_index serves as segment ID)
554+
// 2. Create next row's x/y values using window transform
555+
// 3. Flatten to create 2 rows per segment (point_index: 0=start, 1=end)
556+
// 4. Use calculate to pick current or next based on point_index
557+
// 5. Add segment ID to detail encoding
558+
559+
// Preserve existing transforms (e.g., source filter)
560+
let mut transforms = layer_spec
561+
.get("transform")
562+
.and_then(|t| t.as_array())
563+
.cloned()
564+
.unwrap_or_default();
565+
566+
// Step 1 & 2: Window transform to get next row's values
567+
let window_ops = vec![
568+
json!({
569+
"op": "lead",
570+
"field": x_col,
571+
"as": format!("{}_next", x_col)
572+
}),
573+
json!({
574+
"op": "lead",
575+
"field": y_col,
576+
"as": format!("{}_next", y_col)
577+
}),
578+
];
579+
580+
let mut window_transform = json!({
581+
"window": window_ops,
582+
"sort": [{"field": ROW_INDEX_COLUMN}]
583+
});
584+
585+
if !partition_columns.is_empty() {
586+
window_transform["groupby"] = json!(partition_columns);
587+
}
588+
589+
transforms.push(window_transform);
590+
591+
// Step 2b: Filter out last row in each group (no next point)
592+
transforms.push(json!({
593+
"filter": format!("datum.{}_next != null", x_col)
594+
}));
595+
596+
// Step 3: Flatten to create 2 rows per segment
597+
// Create a constant array [0, 1] to flatten
598+
transforms.push(json!({
599+
"calculate": "[0, 1]",
600+
"as": "__segment_points__"
601+
}));
602+
603+
transforms.push(json!({
604+
"flatten": ["__segment_points__"],
605+
"as": ["__point_index__"]
606+
}));
607+
608+
// Step 4: Calculate actual x/y based on point_index
609+
transforms.push(json!({
610+
"calculate": format!("datum.__point_index__ == 0 ? datum.{} : datum.{}_next", x_col, x_col),
611+
"as": format!("{}_final", x_col)
612+
}));
613+
614+
transforms.push(json!({
615+
"calculate": format!("datum.__point_index__ == 0 ? datum.{} : datum.{}_next", y_col, y_col),
616+
"as": format!("{}_final", y_col)
617+
}));
618+
619+
// Step 5: Create segment ID (use original row_index)
620+
transforms.push(json!({
621+
"calculate": format!("datum.{}", ROW_INDEX_COLUMN),
622+
"as": "__segment_id__"
623+
}));
624+
625+
layer_spec["transform"] = json!(transforms);
626+
// Don't set layer_spec["data"] - use the unified top-level dataset
627+
// The source filter transform will select the correct rows
628+
629+
// Update encodings to use final x/y and add segment_id to detail
630+
if let Some(encoding_obj) = layer_spec.get_mut("encoding") {
631+
if let Some(encoding_map) = encoding_obj.as_object_mut() {
632+
// Update x encoding to use x_final
633+
if let Some(x_enc) = encoding_map.get_mut("x") {
634+
if let Some(x_obj) = x_enc.as_object_mut() {
635+
x_obj.insert("field".to_string(), json!(format!("{}_final", x_col)));
636+
}
637+
}
638+
639+
// Update y encoding to use y_final
640+
if let Some(y_enc) = encoding_map.get_mut("y") {
641+
if let Some(y_obj) = y_enc.as_object_mut() {
642+
y_obj.insert("field".to_string(), json!(format!("{}_final", y_col)));
643+
}
644+
}
645+
646+
// Add segment_id to detail encoding
647+
encoding_map.insert("detail".to_string(), json!({
648+
"field": "__segment_id__",
649+
"type": "nominal"
650+
}));
651+
}
652+
}
653+
654+
Ok(vec![layer_spec])
655+
}
355656
}
356657

357658
// =============================================================================

0 commit comments

Comments
 (0)