Skip to content

Commit 60fcdcb

Browse files
pytorchbot and mattjcly
authored
Fix Voxtral Realtime runner flush (#18387)
### Summary `VoxtralRealtimeRunner` was outputting excessive duplicate tokens/gibberish on stream flush. In an audio file where I say "The weather is clear today", run like ``` voxtral_realtime_runner \ --model_path model.pte \ --tokenizer_path tekken.json \ --preprocessor_path preprocessor.pte \ --streaming \ --audio_path audio.wav ``` I would get output: `The weather is clear todayoday.</s>` I also experienced this with periods and in many other circumstances with repeating tokens at the end of the stream. Upon investigation into vLLM (Mistral's recommended inference runner for [Voxtral Realtime](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602#vllm-recommended)), I observed that vLLM finishes the stream by closing the streaming input and draining model-defined right-padding audio, whereas ExecuTorch `flush()` finished by switching into post-audio text-only decoding after audio ended. See vLLM ref: https://github.com/vllm-project/vllm/blob/2f9f946/vllm/model_executor/models/voxtral_realtime.py#L239-L270. Therefore, apply similar logic here by converting the model-defined transcription delay into a finite number of trailing silent streaming steps to properly conclude the stream. After this, the same command outputs: ``` The weather is clear today. ``` As I would expect. No `</s>` because, like vLLM, the stream ends by naturally draining the padded audio tail and letting the model emit whatever final delayed text it wants. ### Test plan Tested with the above example and a few other audio files to observe the behavior improvement and the absence of gibberish/incorrect end of stream. Co-authored-by: Matt Clayton <156335168+mattjcly@users.noreply.github.com>
1 parent fa605be commit 60fcdcb

6 files changed

Lines changed: 95 additions & 72 deletions

File tree

examples/models/voxtral_realtime/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ Ctrl+C stops recording and flushes remaining text.
269269
| `--preprocessor_path` | (none) | Path to mel preprocessor `.pte` |
270270
| `--audio_path` | (none) | Path to 16kHz mono WAV file |
271271
| `--temperature` | `0.0` | Sampling temperature (0 = greedy) |
272-
| `--max_new_tokens` | `500` | Maximum tokens to generate |
272+
| `--offline_max_new_tokens` | `500` | Offline-only: maximum extra tokens after audio embeddings are exhausted |
273273
| `--streaming` | off | Use streaming transcription (from WAV file) |
274274
| `--mic` | off | Live microphone mode (reads raw f32le PCM from stdin) |
275275
| `--mic_chunk_ms` | `80` | Mic read chunk size in ms (multiples of 80 recommended) |

examples/models/voxtral_realtime/export_voxtral_rt.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,8 @@ def main():
600600
else:
601601
programs, metadata = export_all(model, args.max_seq_len, **quant_args)
602602

603+
metadata["delay_tokens"] = args.delay_tokens
604+
603605
# Lower
604606
et = lower_to_executorch(programs, metadata, backend=args.backend)
605607

examples/models/voxtral_realtime/main.cpp

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,11 @@ DEFINE_string(tokenizer_path, "tekken.json", "Path to Tekken tokenizer file.");
4949
DEFINE_string(preprocessor_path, "", "Path to mel preprocessor (.pte).");
5050
DEFINE_string(audio_path, "", "Path to input audio file (.wav).");
5151
DEFINE_double(temperature, 0.0, "Sampling temperature (0 = greedy).");
52-
DEFINE_int32(max_new_tokens, 500, "Maximum number of tokens to generate.");
52+
DEFINE_int32(
53+
offline_max_new_tokens,
54+
500,
55+
"Offline-only: maximum extra tokens to generate after audio embeddings "
56+
"are exhausted.");
5357
DEFINE_bool(streaming, false, "Use streaming transcription mode.");
5458
DEFINE_bool(
5559
mic,
@@ -74,6 +78,12 @@ volatile sig_atomic_t g_interrupted = 0;
7478
void sigint_handler(int) {
7579
g_interrupted = 1;
7680
}
81+
82+
voxtral_realtime::StreamingTranscribeConfig make_streaming_config() {
83+
voxtral_realtime::StreamingTranscribeConfig config;
84+
config.temperature = static_cast<float>(FLAGS_temperature);
85+
return config;
86+
}
7787
} // namespace
7888

7989
int main(int argc, char** argv) {
@@ -89,6 +99,16 @@ int main(int argc, char** argv) {
8999
return 1;
90100
}
91101

102+
if ((FLAGS_streaming || FLAGS_mic) &&
103+
!gflags::GetCommandLineFlagInfoOrDie("offline_max_new_tokens")
104+
.is_default) {
105+
ET_LOG(
106+
Error,
107+
"--offline_max_new_tokens only applies to offline transcription. "
108+
"Streaming mode drains until EOS or the padded audio tail is exhausted.");
109+
return 1;
110+
}
111+
92112
if (FLAGS_preprocessor_path.empty()) {
93113
ET_LOG(Error, "preprocessor_path flag must be provided.");
94114
return 1;
@@ -109,10 +129,6 @@ int main(int argc, char** argv) {
109129
stats.model_load_end_ms = ::executorch::extension::llm::time_in_ms();
110130
stats.inference_start_ms = ::executorch::extension::llm::time_in_ms();
111131

112-
voxtral_realtime::TranscribeConfig config;
113-
config.temperature = static_cast<float>(FLAGS_temperature);
114-
config.max_new_tokens = FLAGS_max_new_tokens;
115-
116132
stats.num_prompt_tokens = 0;
117133
bool first_token = true;
118134

@@ -156,7 +172,8 @@ int main(int argc, char** argv) {
156172
ET_CHECK_MSG(
157173
runner.is_streaming(),
158174
"Model was not exported with --streaming. Re-export with --streaming flag.");
159-
auto session = runner.create_streaming_session(config, token_cb);
175+
auto session =
176+
runner.create_streaming_session(make_streaming_config(), token_cb);
160177

161178
// Drain any audio that buffered in stdin during model loading/warmup.
162179
// Without this, piped audio (e.g., from ffmpeg) accumulates while the
@@ -207,7 +224,8 @@ int main(int argc, char** argv) {
207224
ET_LOG(Info, "Loading audio from: %s", FLAGS_audio_path.c_str());
208225
auto audio_data =
209226
::executorch::extension::llm::load_wav_audio_data(FLAGS_audio_path);
210-
auto session = runner.create_streaming_session(config, token_cb);
227+
auto session =
228+
runner.create_streaming_session(make_streaming_config(), token_cb);
211229

212230
const int64_t chunk_size = 1280;
213231
for (int64_t offset = 0; offset < static_cast<int64_t>(audio_data.size());
@@ -221,10 +239,13 @@ int main(int argc, char** argv) {
221239
ET_LOG(Info, "Loading audio from: %s", FLAGS_audio_path.c_str());
222240
auto audio_data =
223241
::executorch::extension::llm::load_wav_audio_data(FLAGS_audio_path);
242+
voxtral_realtime::OfflineTranscribeConfig offline_config;
243+
offline_config.temperature = static_cast<float>(FLAGS_temperature);
244+
offline_config.max_new_tokens = FLAGS_offline_max_new_tokens;
224245
num_generated = runner.transcribe(
225246
audio_data.data(),
226247
static_cast<int64_t>(audio_data.size()),
227-
config,
248+
offline_config,
228249
token_cb);
229250
}
230251

examples/models/voxtral_realtime/model.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,9 @@ runner (`StreamingSession::decode_step`) then:
248248
3. Feeds the combined embedding to `text_decoder` at the current position
249249
4. Samples one token from the output logits
250250

251-
After audio ends, `flush()` continues text-only decoding (token
252-
embedding only, no audio) until EOS or max tokens.
251+
After audio ends, `flush()` pads the unfinished tail with silence and keeps
252+
running the same audio-conditioned streaming path until the final partial
253+
step and transcription delay are drained.
253254

254255
### Conv state management
255256

examples/models/voxtral_realtime/voxtral_realtime_runner.cpp

Lines changed: 43 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ VoxtralRealtimeRunner::VoxtralRealtimeRunner(
124124
if (msf.ok())
125125
mel_skip_frames_ = msf.get()[0].toInt();
126126

127+
auto dt = model_->execute("delay_tokens", empty);
128+
if (dt.ok())
129+
delay_tokens_ = dt.get()[0].toInt();
130+
127131
ET_LOG(
128132
Info,
129133
"Streaming: chunk_mel=%ld, max_enc=%ld, enc_dim=%ld",
@@ -158,8 +162,7 @@ VoxtralRealtimeRunner::VoxtralRealtimeRunner(
158162
std::vector<float> dummy_audio(static_cast<size_t>(step_samples_), 0.0f);
159163

160164
if (is_streaming_) {
161-
TranscribeConfig warmup_config;
162-
warmup_config.max_new_tokens = 1;
165+
StreamingTranscribeConfig warmup_config;
163166
auto session =
164167
create_streaming_session(warmup_config, [](const std::string&) {});
165168
session->feed_audio(dummy_audio.data(), step_samples_);
@@ -291,7 +294,7 @@ TensorPtr VoxtralRealtimeRunner::convert_to_model_dtype(TensorPtr tensor) {
291294
int VoxtralRealtimeRunner::transcribe(
292295
const float* audio_data,
293296
int64_t num_samples,
294-
const TranscribeConfig& config,
297+
const OfflineTranscribeConfig& config,
295298
TokenCallback token_cb) {
296299
// --- Step 1: Preprocess raw audio to mel spectrogram ---
297300
ET_CHECK_MSG(preprocessor_ != nullptr, "No preprocessor provided.");
@@ -423,7 +426,7 @@ int VoxtralRealtimeRunner::transcribe(
423426

424427
std::unique_ptr<StreamingSession>
425428
VoxtralRealtimeRunner::create_streaming_session(
426-
const TranscribeConfig& config,
429+
const StreamingTranscribeConfig& config,
427430
TokenCallback token_cb) {
428431
ET_CHECK_MSG(is_streaming_, "Model was not exported with --streaming.");
429432
ET_CHECK_MSG(
@@ -434,10 +437,9 @@ VoxtralRealtimeRunner::create_streaming_session(
434437

435438
StreamingSession::StreamingSession(
436439
VoxtralRealtimeRunner& runner,
437-
TranscribeConfig config,
440+
StreamingTranscribeConfig config,
438441
TokenCallback token_cb)
439442
: runner_(runner),
440-
config_(config),
441443
token_cb_(std::move(token_cb)),
442444
prev_token_(runner.bos_id_),
443445
sampler_(
@@ -452,9 +454,9 @@ StreamingSession::StreamingSession(
452454
int StreamingSession::feed_audio(const float* data, int64_t num_samples) {
453455
audio_buf_.insert(audio_buf_.end(), data, data + num_samples);
454456

455-
int new_tokens = 0;
457+
const int generated_before = num_generated_;
456458
while (!eos_reached_ && try_process_step()) {
457-
new_tokens++;
459+
// num_generated_ is updated inside try_process_step()
458460
}
459461

460462
// Trim consumed audio to bound memory growth. Keep stft_left_overlap_
@@ -467,7 +469,7 @@ int StreamingSession::feed_audio(const float* data, int64_t num_samples) {
467469
samples_consumed_ -= keep_from;
468470
}
469471

470-
return new_tokens;
472+
return num_generated_ - generated_before;
471473
}
472474

473475
bool StreamingSession::try_process_step() {
@@ -586,10 +588,10 @@ bool StreamingSession::try_process_step() {
586588
samples_consumed_ += step;
587589

588590
// --- Decode one step ---
589-
return decode_step(&audio_embeds_ptr);
591+
return decode_step(audio_embeds_ptr);
590592
}
591593

592-
bool StreamingSession::decode_step(const TensorPtr* audio_embeds_tensor) {
594+
bool StreamingSession::decode_step(const TensorPtr& audio_embeds_tensor) {
593595
const int64_t dim = runner_.dim_;
594596
const auto model_dtype = runner_.model_dtype_;
595597

@@ -603,33 +605,25 @@ bool StreamingSession::decode_step(const TensorPtr* audio_embeds_tensor) {
603605
ET_CHECK_MSG(tok_result.ok(), "token_embedding failed.");
604606
auto tok_embed = tok_result.get()[0].toTensor();
605607

606-
// Sum audio + token embeddings (or token-only if no audio).
608+
// Sum audio + token embeddings.
607609
// Reuses pre-allocated input_embeds_ buffer (no per-token allocation).
608-
if (audio_embeds_tensor != nullptr) {
609-
auto& audio_embeds = **audio_embeds_tensor;
610-
if (model_dtype == ::executorch::aten::ScalarType::BFloat16) {
611-
auto* out =
612-
input_embeds_->mutable_data_ptr<::executorch::aten::BFloat16>();
613-
const auto* af =
614-
audio_embeds.const_data_ptr<::executorch::aten::BFloat16>();
615-
const auto* tf = tok_embed.const_data_ptr<::executorch::aten::BFloat16>();
616-
for (int64_t i = 0; i < dim; i++) {
617-
out[i] = ::executorch::aten::BFloat16(
618-
static_cast<float>(af[i]) + static_cast<float>(tf[i]));
619-
}
620-
} else {
621-
auto* out = input_embeds_->mutable_data_ptr<float>();
622-
const auto* af = audio_embeds.const_data_ptr<float>();
623-
const auto* tf = tok_embed.const_data_ptr<float>();
624-
for (int64_t i = 0; i < dim; i++) {
625-
out[i] = af[i] + tf[i];
626-
}
610+
auto& audio_embeds = *audio_embeds_tensor;
611+
if (model_dtype == ::executorch::aten::ScalarType::BFloat16) {
612+
auto* out = input_embeds_->mutable_data_ptr<::executorch::aten::BFloat16>();
613+
const auto* af =
614+
audio_embeds.const_data_ptr<::executorch::aten::BFloat16>();
615+
const auto* tf = tok_embed.const_data_ptr<::executorch::aten::BFloat16>();
616+
for (int64_t i = 0; i < dim; i++) {
617+
out[i] = ::executorch::aten::BFloat16(
618+
static_cast<float>(af[i]) + static_cast<float>(tf[i]));
627619
}
628620
} else {
629-
std::memcpy(
630-
input_embeds_->mutable_data_ptr(),
631-
tok_embed.const_data_ptr(),
632-
static_cast<size_t>(dim) * input_embeds_->element_size());
621+
auto* out = input_embeds_->mutable_data_ptr<float>();
622+
const auto* af = audio_embeds.const_data_ptr<float>();
623+
const auto* tf = tok_embed.const_data_ptr<float>();
624+
for (int64_t i = 0; i < dim; i++) {
625+
out[i] = af[i] + tf[i];
626+
}
633627
}
634628

635629
auto cache_pos =
@@ -669,31 +663,31 @@ int StreamingSession::flush() {
669663
}
670664
flushed_ = true;
671665

672-
// Pad with silence so any remaining audio (including partial steps and
673-
// the right look-ahead for the last complete step) can be processed.
674666
const int64_t remaining =
675667
static_cast<int64_t>(audio_buf_.size()) - samples_consumed_;
676668
if (remaining > 0 && !eos_reached_) {
677669
const int64_t step = runner_.step_samples_;
678670
const int64_t right_lookahead = runner_.stft_right_lookahead_;
679-
// Pad to next full step + right look-ahead
680-
int64_t pad_to = ((remaining + step - 1) / step) * step + right_lookahead;
681-
std::vector<float> silence(static_cast<size_t>(pad_to - remaining), 0.0f);
671+
const int64_t right_pad_audio_steps = runner_.delay_tokens_;
672+
673+
// Stay on the normal audio-conditioned path through the final partial
674+
// step, the preprocessor look-ahead, and the model's transcription delay.
675+
// Matches vLLM flush behavior:
676+
// https://github.com/vllm-project/vllm/blob/2f9f946/vllm/model_executor/models/voxtral_realtime.py#L239-L270
677+
int64_t pad_to =
678+
(((remaining + step - 1) / step) + right_pad_audio_steps) * step +
679+
right_lookahead;
680+
const int64_t silence_padded_samples = pad_to - remaining;
681+
std::vector<float> silence(
682+
static_cast<size_t>(silence_padded_samples), 0.0f);
682683
audio_buf_.insert(audio_buf_.end(), silence.begin(), silence.end());
683684

685+
// Guaranteed to terminate b/c each call to try_process_step() consumes a
686+
// fixed number of audio samples and the padded audio buffer is finite.
684687
while (!eos_reached_ && try_process_step()) {
685688
}
686689
}
687690

688-
// Text-only decoding after audio ends.
689-
const int64_t max_text_steps = std::min(
690-
static_cast<int64_t>(config_.max_new_tokens) - num_generated_,
691-
runner_.max_seq_len_ - dec_pos_);
692-
693-
for (int64_t i = 0; i < max_text_steps && !eos_reached_; i++) {
694-
decode_step(nullptr);
695-
}
696-
697691
return num_generated_;
698692
}
699693

examples/models/voxtral_realtime/voxtral_realtime_runner.h

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,15 @@ namespace voxtral_realtime {
2525
// audio and text embeddings at each position (element-wise add), while
2626
// MultimodalRunner concatenates modality segments sequentially.
2727

28-
struct TranscribeConfig {
28+
struct OfflineTranscribeConfig {
2929
int max_new_tokens = 500;
3030
float temperature = 0.0f; // 0 = greedy
3131
};
3232

33+
struct StreamingTranscribeConfig {
34+
float temperature = 0.0f; // 0 = greedy
35+
};
36+
3337
using TokenCallback = std::function<void(const std::string&)>;
3438

3539
class StreamingSession;
@@ -47,14 +51,14 @@ class VoxtralRealtimeRunner {
4751
int transcribe(
4852
const float* audio_data,
4953
int64_t num_samples,
50-
const TranscribeConfig& config,
54+
const OfflineTranscribeConfig& config,
5155
TokenCallback token_cb);
5256

5357
// Streaming transcription: processes raw audio incrementally via
5458
// StreamingSession. Requires a model exported with --streaming and
5559
// a streaming preprocessor .pte.
5660
std::unique_ptr<StreamingSession> create_streaming_session(
57-
const TranscribeConfig& config,
61+
const StreamingTranscribeConfig& config,
5862
TokenCallback token_cb);
5963

6064
int64_t max_seq_len() const {
@@ -101,6 +105,9 @@ class VoxtralRealtimeRunner {
101105
int64_t stft_right_lookahead_ = 40;
102106
int64_t mel_skip_frames_ = 2;
103107

108+
// Streaming transcription delay in steps (read from model metadata).
109+
int64_t delay_tokens_ = 6;
110+
104111
// Tokenizer special tokens
105112
uint64_t bos_id_ = 1;
106113
uint64_t eos_id_ = 2;
@@ -121,16 +128,16 @@ class StreamingSession {
121128
public:
122129
StreamingSession(
123130
VoxtralRealtimeRunner& runner,
124-
TranscribeConfig config,
131+
StreamingTranscribeConfig config,
125132
TokenCallback token_cb);
126133

127134
// Feed raw audio (16kHz float32). Processes as many complete 80ms steps
128135
// as possible. Returns number of new tokens generated.
129136
int feed_audio(const float* data, int64_t num_samples);
130137

131-
// Signal end of audio. Pads last partial step, then generates remaining
132-
// text-only tokens until EOS or max_new_tokens. Returns total tokens
133-
// generated across the entire session.
138+
// Signal end of audio. Pads the unfinished tail with silence so the final
139+
// partial step and model delay drain through the normal audio-conditioned
140+
// streaming path, then returns the total tokens generated for the session.
134141
int flush();
135142

136143
int total_tokens() const {
@@ -139,7 +146,6 @@ class StreamingSession {
139146

140147
private:
141148
VoxtralRealtimeRunner& runner_;
142-
TranscribeConfig config_;
143149
TokenCallback token_cb_;
144150

145151
// Raw audio accumulation buffer
@@ -163,11 +169,10 @@ class StreamingSession {
163169
// Process one 80ms step from the audio buffer.
164170
bool try_process_step();
165171

166-
// Run one decoder step (token_embed + optional audio_embed -> logits).
167-
// audio_embeds_tensor is the output from encode_audio_chunk, or nullptr
168-
// for text-only decoding after audio ends.
172+
// Run one audio-conditioned decoder step
173+
// (token_embed + audio_embed -> logits).
169174
bool decode_step(
170-
const ::executorch::extension::TensorPtr* audio_embeds_tensor);
175+
const ::executorch::extension::TensorPtr& audio_embeds_tensor);
171176
};
172177

173178
} // namespace voxtral_realtime

0 commit comments

Comments
 (0)