Part II: Voice Assistant Desktop App with LLaMA3 and Whisper in Rust
Step-by-step tutorial on building a desktop app to interface with the LLM LLaMA3 using text and audio instructions in Rust. This is the 2nd and final installment of the series.
July 7, 2024 · 19 min · 4025 words
In this series we are on a mission to build our very own native desktop app that can interface with the LLM LLaMA3-8B through text and audio instructions. In the previous installment of the series we did our setup with Tauri 2.0 Beta, loaded our LLaMA3 and Whisper models using the HuggingFace Candle framework, and ran our text generation pipeline. In this post we complete the journey with voice instructions and get our app to respond to our audio instructions.
You’ll feel right at home if you are a programmer, have some exposure to Rust and a bit of experience working with Svelte, React or any other modern client-side framework.
With the text inference ready to rock, let's move on to the Whisper-based audio transcription, which is going to be a lot more involved. So, gear up.
Let’s break down the steps to audio transcription.
We receive chunks of audio data from the frontend and keep appending them to a buffer
When the recording stops, the frontend calls the ask() command handler to initiate audio inference, which in turn calls the audio() method of our struct Instruct
The call is then passed on to the WhisperWrap method infer(), which actually runs the transcription.
Note, these steps only generate the transcript of the audio and don't interact with our LLaMA3 yet. We'll work on that in the Pipeline phase of our processing.
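Before diving in, here is a rough recap of the shapes involved. This is a sketch only, not the exact definitions from Part I; the field names match how they are used in the snippets below, and names like LlamaWrap are assumptions.

pub struct Instruct {
    send: Sender<Vec<f32>>, // MPSC sender; incoming audio chunks are pushed here
    whisper: WhisperWrap,   // Whisper model + transcription state
    llama: LlamaWrap,       // LLaMA3 text-generation pipeline from Part I (name assumed)
}

pub struct WhisperWrap {
    data: Mutex<Vec<f32>>,  // accumulated pcm audio samples
    model: Mutex<Whisper>,  // candle Whisper model
    tokenizer: Tokenizer,   // Whisper tokenizer
    config: Config,         // Whisper config (num_mel_bins, max_target_positions, ...)
    mel_filters: Vec<f32>,  // precomputed mel filterbank
    device: Device,         // CPU / Metal / CUDA device
    // ... plus helpers like `default_tokens` used during decoding
}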
We’ve already exposed a tauri command fn audio_chunk() in audio-instruct/src-tauri/src/commands.rs; let’s modify it to actually send the chunks to the instance of our struct Instruct over the MPSC channel we defined earlier.
/// This tauri command would receive a Vec<f32> which represents a chunk of audio being recorded
/// The chunk will be forwarded through the MPSC channel
#[tauri::command]
pub fn audio_chunk(app: tauri::State<'_, Arc<Instruct>>, req: ipc::Request<'_>) -> Result<(), &'static str> {
    if let tauri::ipc::InvokeBody::Raw(data) = req.body() {
        let chunk = bytes_to_f32(&data[..]);
        if let Err(e) = app.send(chunk) {
            error!("audio_chunk: error: {e:?}");
            return Err("invalid chunk");
        }
    } else {
        return Err("invalid chunk");
    }

    Ok(())
}
Let’s also change the listen() method in audio-instruct/src-tauri/src/instruct.rs to actually do some work.
impl Instruct {
    // .. code omitted ..

    /// Exposes an API to send data into our MPSC channel
    pub fn send(&self, data: Vec<f32>) -> Result<()> {
        self.send.send(data)?;
        Ok(())
    }

    // This method just forwards the incoming chunks to a method exposed by our `struct WhisperWrap`.
    // The client doesn't need to wait for this to happen
    fn listen(app: Arc<Instruct>, recv: Receiver<Vec<f32>>) {
        while let Ok(next) = recv.recv() {
            app.whisper.chunk(next);
        }
    }

    // .. code omitted ..
}
Now, we’ll expose a method on struct WhisperWrap to accept the incoming chunks and update its data field.
impl WhisperWrap {
    // ... code omitted ...

    /// Accepts an incoming chunk of data and appends it to the `data` field of the struct
    pub fn chunk(&self, chunk: Vec<f32>) {
        let mut c = chunk;
        let mut chunk = self.data.lock().unwrap();
        // while appending, this also `drains` the incoming Vec<>, saving some space
        chunk.append(&mut c);
    }

    // ... code omitted ...
}
That should do the trick: every time a new audio chunk is received, this method simply appends it to the data field.
Once the audio recording stops, our tauri command ask() is called by the frontend, but this time requesting an audio inference instead of a text inference. We’ll work on the frontend side of this flow later; for now, let’s stay scoped to the backend. Our tauri::command ask() already derives the Mode of the incoming command and forwards it to the Instruct method text() or audio(). Let’s modify our audio() method to trigger the inference.
audio-instruct/src-tauri/src/instruct.rs
/// Public API to trigger audio inference
pub fn audio(&self) -> Result<Response> {
    let (transcript, n_tokens, elapsed) = self.whisper.infer()?;

    // More when we work on the `Pipeline` part of our inference
    Ok(Response::new(&transcript, &transcript, n_tokens, 0))
}
Let’s start with the actual transcription now. Whisper is an encoder-decoder model, unlike LLaMA3, which is a decoder-only model.
In encoder-decoder models the input data is first passed through the encoder part of the model; the output of the encoder is then fed into the decoder part of the model to generate tokens.
Note
A quick read about transformers and the various modules that are involved.
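To make that concrete with the calls we’ll actually use below (candle’s Whisper model exposes the two halves separately), the flow looks roughly like this. It is a conceptual sketch with illustrative variable names, not new API; the real calls appear in decode() later in this post.

// 1. encoder: a mel spectrogram segment -> audio `features`
let features = model.encoder.forward(&mel_segment, true)?;
// 2. decoder: `features` + the tokens generated so far -> hidden states,
//    which `final_linear()` then turns into next-token logits
let dec = model.decoder.forward(&tokens_tensor, &features, true)?;
let logits = model.decoder.final_linear(&dec)?;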
We’ve already discussed the Mel Spectrogram and our mel_filters before; we’ll put them to use now.
Preprocessing involves the candle API pcm_to_mel() (here, pcm stands for the pulse-code-modulation representation of the audio data, and mel is the Mel Spectrogram representation).
impl WhisperWrap {
    // ... code omitted ..

    // 1. Checks if we have valid data
    // 2. Creates the `Mel Spectrogram` representation of our audio data
    // 3. Creates and returns a `Tensor` from the given data
    fn preproc(&self) -> Result<Tensor> {
        let data = match self.data.lock() {
            Ok(mut d) => {
                if d.len() < 4096 * 4 {
                    anyhow::bail!("Not enough audio data in buffer!");
                }
                let d = d.drain(..).collect::<Vec<_>>();
                d
            }
            Err(e) => {
                error!("error acquiring data lock: {e:?}");
                anyhow::bail!("Not enough audio data in buffer!");
            }
        };

        let mel = pcm_to_mel(&self.config, &data[..], &self.mel_filters[..]);
        let mel_len = mel.len();
        let mel = Tensor::from_vec(
            mel,
            (1, self.config.num_mel_bins, mel_len / self.config.num_mel_bins),
            &self.device,
        )?;

        Ok(mel)
    }

    // ... code omitted ..

    /// Runs transcription
    pub fn infer(&self) -> Result<(String, u32, std::time::Duration)> {
        // generates `mel`
        let mels = self.preproc()?;

        let mut model = match self.model.lock() {
            Ok(m) => m,
            Err(e) => {
                error!("infer: error acquiring model lock: {e:?}");
                anyhow::bail!("error during inference");
            }
        };

        let (_, _, content_frames) = mels.dims3()?;
        let mut seek = 0;
        let mut segments = vec![];

        // newline tokens to insert after each segment
        let nltokens = self.tokenizer.encode("\n", false).unwrap().get_ids().to_vec();

        let mut total_dur = Duration::from_millis(0);
        let mut total_tokens = 0;

        // seek through the generated `mels` and call the `decode_segment` method on a chunk
        while seek < content_frames {
            let start = std::time::Instant::now();
            let segment_size = usize::min(content_frames - seek, N_FRAMES);
            let mel_segment = mels.narrow(2, seek, segment_size)?;

            let mut decoded = self.decode_segment(&mut model, &mel_segment)?;

            seek += segment_size;
            total_dur += std::time::Instant::now() - start;
            total_tokens += decoded.tokens.len();

            if decoded.no_speech_prob > NO_SPEECH_THRESHOLD && decoded.avg_logprob < LOGPROB_THRESHOLD {
                println!("no speech detected, skipping {seek} {decoded:?}");
                continue;
            }

            segments.append(&mut decoded.tokens);

            // adding newline tokens after each segment
            nltokens.iter().for_each(|&t| {
                segments.push(t);
            });
        }

        // Let us now create the final text output
        let instruct = self
            .tokenizer
            .decode(&segments, true)
            .map_err(|_| anyhow!("error creating text from tokens"))?;

        Ok((instruct, total_tokens as u32, total_dur))
    }

    // Decodes a single segment at different values of the `temperature` hyperparameter
    // The current values of the hyperparameter TEMPERATURES = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    fn decode_segment(&self, model: &mut Whisper, segment: &Tensor) -> Result<DecodingResult> {
        // Decode at a particular temperature, check if we have a valid result or move on to the next temperature
        for (i, &t) in TEMPERATURES.iter().enumerate() {
            let decoded = self.decode(model, segment, t);
            if i == TEMPERATURES.len() - 1 {
                return decoded;
            }

            match decoded {
                Ok(decoded) => {
                    if decoded.avg_logprob >= LOGPROB_THRESHOLD || decoded.no_speech_prob > NO_SPEECH_THRESHOLD {
                        info!("Decoded: {decoded:?}");
                        return Ok(decoded);
                    }
                }
                Err(e) => {
                    warn!("Error decoding @ temperature: {t}: {e:?}");
                }
            }
        }

        unreachable!()
    }
}
Let’s look at what is happening above: the logical segments of the mel representation are passed on to the method decode_segment(). For each mel segment we attempt to generate an inference at various pre-set temperatures, a hyperparameter used by Whisper for decoding.
Basically, we are trying to get a valid result for each segment across the different temperature values; validity is defined by a few more hyperparameters like LOGPROB_THRESHOLD (roughly, how sure the model is) and NO_SPEECH_THRESHOLD (what are the chances of this segment being random noise?).
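For reference, these are the constants in play. The TEMPERATURES values are the ones quoted above, while the two threshold values below follow the defaults in candle’s whisper example; treat them as tunable assumptions rather than fixed requirements.

// Decoding hyperparameters (threshold values follow candle's whisper example)
const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
const LOGPROB_THRESHOLD: f64 = -1.0;  // below this average log-probability, retry at a higher temperature
const NO_SPEECH_THRESHOLD: f64 = 0.6; // above this, the segment is likely silence or noise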
Inside self.decode() the mel representation is passed through the encoder and then, after some processing, on to the decoder part of our model. Let’s jump into that; it’s the crux of this whole process.
// inside impl WhisperWrap
fn decode(&self, model: &mut Whisper, segment: &Tensor, temp: f64) -> Result<DecodingResult> {
    // generating some random seed
    let mut rng = rand::thread_rng();

    // the `mel` segment block is passed through the encoder here; you can think of this process
    // as generating `features` from your input audio data segment
    let features = model.encoder.forward(segment, true)?;

    // some token pre-generation, basically creating a `token` representation which tells the `decoder`
    // part of the model [start of transcript, english, task transcribe, don't generate timestamps] -
    // these special tokens kind of act like a prompt for the decoder model
    let mut tokens = self.preproc_decode();

    // now, we initialize some variables that will maintain our stats and metrics derived from this part of the inference
    // probability of this segment being some random/ background noise
    let mut no_speech_prob = f64::MAX;
    // the average probability of our prediction across all the decoding passes
    let mut sum_log_p = 0.;

    // Ok, so now we loop through a `max number of possible outputs` and `autoregressively` generate the next token
    for i in 0..self.config.max_target_positions {
        // we'll convert the `Vec<tokens>` slice to a `Tensor`
        let tensor = Tensor::new(&tokens[..], &self.device)?;
        // and send that tensor to the `decoder` to generate the next token
        let dec = model.decoder.forward(&tensor.unsqueeze(0)?, &features, i == 0)?;

        // Extract the no speech probability on the first iteration by looking at the first
        // token logits and the probability for the according token.
        if i == 0 {
            let logits = model.decoder.final_linear(&dec.i(..1)?)?.i(0)?.i(0)?;
            no_speech_prob = softmax(&logits, 0)?
                .i(self.default_tokens.no_speech as usize)?
                .to_scalar::<f32>()? as f64;
        }

        let (_, seq_len, _) = dec.dims3()?;
        let logits = model
            .decoder
            .final_linear(&dec.i((..1, seq_len - 1..))?)?
            .i(0)?
            .i(0)?;

        // a simple sampler, picked up from `https://github.com/huggingface/candle/blob/main/candle-examples/examples/whisper/main.rs`
        let next_token = if temp > 0. {
            let prs = softmax(&(&logits / temp)?, 0)?;
            let logits_v: Vec<f32> = prs.to_vec1()?;
            let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
            distr.sample(&mut rng) as u32
        } else {
            let logits_v: Vec<f32> = logits.to_vec1()?;
            logits_v
                .iter()
                .enumerate()
                .max_by(|(_, u), (_, v)| u.total_cmp(v))
                .map(|(i, _)| i as u32)
                .unwrap()
        };

        // remember that this decoding is `autoregressive`, meaning the output of the current pass is
        // passed on as input to the next pass till some stop condition is reached
        tokens.push(next_token);

        // bookkeeping the `probability`; we'll calculate the average of this later and that will serve
        // as our `decisioning` benchmark compared against `LOGPROB_THRESHOLD`
        let prob = softmax(&logits, candle_core::D::Minus1)?
            .i(next_token as usize)?
            .to_scalar::<f32>()? as f64;

        // the stop condition: the end-of-transcript token or the max target length has been reached
        if next_token == self.default_tokens.eot || tokens.len() > self.config.max_target_positions {
            break;
        }

        sum_log_p += prob.ln();
    }

    // Finally create a struct to hold our output and metadata for this decoding pass
    Ok(DecodingResult {
        text: self
            .tokenizer
            .decode(&tokens, true)
            .map_err(|_| anyhow!("error creating text from tokens"))?,
        avg_logprob: sum_log_p / tokens.len() as f64,
        tokens,
        no_speech_prob,
        temperature: temp,
    })
}
The code snippet above (please read through the comments for details) can be summarized into the following broad steps:
We pass our mel through the encoder of our model to generate features - the input for the decoder
We generate a set of starter tokens (language, task etc.)
We go into an autoregressive loop using our tokens and mel features - each iteration produces a new token, which we append to our tokens
The loop breaks when we hit certain conditions, like the end-of-transcript token
We also maintain some metrics that we use to decide whether the current segment is speech or noise
That’s it! We have our transcript ready. To test this out we could write a test case that reads an audio file, converts it to pcm, then to a mel spectrogram, and so on, but that’s rather time consuming, and since we are not going to read from a file in our final flow, I’d rather finish off the frontend to test this out.
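That said, if you ever want a quick file-based sanity check, a rough sketch could look like the following. It assumes the hound crate for WAV decoding and a hypothetical 16kHz mono fixture file; constructing the WhisperWrap itself (covered in Part I) is left out.

// Sketch only: `hound` and the fixture path are assumptions, adapt to your setup
#[cfg(test)]
fn transcribe_fixture(whisper: &WhisperWrap) -> anyhow::Result<String> {
    // read a 16kHz mono WAV fixture and normalize the i16 samples to f32
    let reader = hound::WavReader::open("tests/fixtures/hello_16k_mono.wav")?;
    let pcm: Vec<f32> = reader
        .into_samples::<i16>()
        .map(|s| s.map(|v| v as f32 / i16::MAX as f32))
        .collect::<Result<_, _>>()?;

    // feed the samples through the same path the frontend uses
    whisper.chunk(pcm);
    let (transcript, _n_tokens, _elapsed) = whisper.infer()?;
    Ok(transcript)
}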
As we did last time, we already have a Svelte frontend ready, and not a lot has changed from the previous setup. Let’s move on to the unique and juicy bits of the current flow.
Let’s take a moment to figure out what needs to change on the client side to accommodate audio instructions on top of our existing text-only inference!
This is a pre-requisite and not strictly a part of the flow. We’ll need to set up the right permissions so that our client side can access and find the microphones.
a. [TODO] figure this out for Tauri 2.0 Beta
First, of course, we need a way of capturing the audio; since our Rust backend expects pcm-encoded Vec<f32> input, we’ll also need a way of converting the audio waveform to Vec<f32>
We need a way of sending the converted Vec<f32> audio in chunks
We need a way of stopping the recording; once stopped, our client should send the ask() command to the backend to process the audio instruction, which means we’ll end up modifying our call to ask() to support both use cases: text and audio
We’ll add a button alongside our text input; clicking it toggles the record / stop-record functionality.
<script lang="ts">
// Adapted from https://github.com/kgullion/vite-typescript-audio-worklet-example/blob/main/src/main.ts
import audioProcUrl from "$lib/audio-proc/audio-processor?url";
import { invoke } from '@tauri-apps/api/core';
import type { Inference, QuestionAnswer } from "$lib/types";
import Qa from "QA.svelte";

// defining a constant buffer size for chunking audio
const BUFFER_SIZE = 4096;
// sampling rate for whisper 16KHz
const SAMPLE_RATE = 16000; // Whisper typically expects 16kHz audio

// ... code omitted ...

// a variable to hold the media stream
let stream: MediaStream | null = null;
// an audio context
let audioContext: AudioContext | null = null;
let source: MediaStreamAudioSourceNode | null = null;
let workletNode: AudioWorkletNode | null = null;
</script>
Ok, some definitions and explanations
let stream: MediaStream
The MediaStream interface of the Media Capture and Streams API represents a stream of media content. A stream consists of several tracks, such as video or audio tracks.
let audioContext: AudioContext
The AudioContext interface represents an audio-processing graph built from audio modules linked together, each represented by an AudioNode.
An audio context controls both the creation of the nodes it contains and the execution of the audio processing, or decoding. You need to create an AudioContext before you do anything else, as everything happens inside a context. It’s recommended to create one AudioContext and reuse it instead of initializing a new one each time, and it’s OK to use a single AudioContext for several different audio sources and pipeline concurrently.
Simply put, one or more audio devices or streams (sources and destinations) are connected via conceptual AudioNodes, and they are all logically bound together by this AudioContext interface.
let workletNode: AudioWorkletNode
This interface represents the base class for a user-defined AudioNode. It is backed by an AudioWorkletProcessor, where the actual processing happens, but on the browser’s Web Audio rendering thread (separate from the main thread, if I get this right), which makes it pretty efficient and cool.
Ok, now that we understand what those heavy-hitting words are for, let’s move on. Below we define the functions record, stopRecord and toggleRecord. Read the comments for more details.
<script lang="ts">
// audio worklet defined and exported from here
import audioProcUrl from "$lib/audio-proc/audio-processor?url";

// ... code omitted ...

// begins recording of a new stream
const record = async () => {
    if (stream) {
        console.error("Duplicate record??");
        return;
    }

    // requests the `audio` user media using the navigator API. On the first run, this will ask for
    // permission to grant access to the microphone
    stream = await navigator.mediaDevices.getUserMedia({ audio: true });

    // Create AudioContext with our 16KHz sample rate
    audioContext = new AudioContext({ sampleRate: SAMPLE_RATE });

    // Load and register the audio worklet
    // Worker file loaded as a module
    await audioContext.audioWorklet.addModule(audioProcUrl);

    // Create MediaStreamSource
    source = audioContext.createMediaStreamSource(stream);

    // Create AudioWorkletNode - the `audioProcUrl` content is attached to the execution context
    // Like a pipe, the audio stream passes through this transformation
    workletNode = new AudioWorkletNode(audioContext, 'audio-processor', {
        outputChannelCount: [1],
        processorOptions: { bufferSize: BUFFER_SIZE }
    });

    // Connect the nodes
    source.connect(workletNode);
    workletNode.connect(audioContext.destination);

    // Set up message handling from the audio worklet
    workletNode.port.onmessage = handleAudioData;
}

// the output chunk of the AudioWorkletNode is passed on to this function
// and this function `emits` the audio_chunk to the `backend` using `tauri::command audio_chunk()`
const handleAudioData = async (event: MessageEvent): Promise<void> => {
    const float32Array = event.data as Float32Array;
    invoke("audio_chunk", float32Array);
}

// stops the recording, invokes the `ask()` tauri command with an indication that we are going to process audio
// then cleans up all the audio related instances and objects
const stopRecord = async () => {
    goAskAudio();
    if (workletNode) {
        workletNode.disconnect();
        workletNode = null;
    }
    if (source) {
        source.disconnect();
        source = null;
    }
    if (audioContext) {
        audioContext.close();
        audioContext = null;
    }
    if (stream) {
        stream.getTracks().forEach(t => t.stop());
        stream = null;
    }
    recordstart = null;
}

// toggle start/ stop recording
const toggleRecord = async () => {
    isrecording = !isrecording;
    if (isrecording) {
        recordstart = new Date();
        record();
    } else {
        stopRecord();
    }
}

// ... code omitted ...

// prepares an audio inference request and calls the `tauri ask()` command with `audio: true` and `text: undefined`
const goAskAudio = async () => {
    asking = true;
    // We are just using a simple keyword to mark the answer as pending
    qas.push({ q: "..", a: "__asking__", ts: new Date() });
    question = "";
    qas = [...qas];
    // The inference generation is extremely resource intensive, give our UI time to update before the call
    setTimeout(() => { command(undefined, true) }, 100);
}
</script>
Now, let’s look at the audio worklet file we have been talking about.
// Adapted from https://github.com/kgullion/vite-typescript-audio-worklet-example/blob/main/src/main.ts
class AudioProcessor extends AudioWorkletProcessor {
    private bufferSize: number;
    private buffer: Float32Array;
    private bufferIndex: number;

    constructor(options?: AudioWorkletNodeOptions) {
        super();
        this.bufferSize = options?.processorOptions.bufferSize || 4096;
        this.buffer = new Float32Array(this.bufferSize);
        this.bufferIndex = 0;
    }

    // this method takes the input buffer and produces output chunks of a specific size
    process(inputs: Float32Array[][], outputs: Float32Array[][], parameters: Record<string, Float32Array>): boolean {
        const input = inputs[0];
        const channel = input[0];
        if (channel) {
            for (let i = 0; i < channel.length; i++) {
                this.buffer[this.bufferIndex++] = channel[i];
                if (this.bufferIndex === this.bufferSize) {
                    this.port.postMessage(this.buffer);
                    this.buffer = new Float32Array(this.bufferSize);
                    this.bufferIndex = 0;
                }
            }
        }
        return true;
    }
}

registerProcessor('audio-processor', AudioProcessor);
Ok, that should do the trick; it’s time to try this out!
RUST_LOG=info npm run tauri dev --release -- --features metal
And in a few seconds, we should see …
There you go, if you play the video you’ll hear my awkward voice, but you’ll also see the transcript. Our audio workflow WORKS 🎉🎉🎉!
So far, we have our text inference up and running and our audio transcription doing its job. Let’s tie them together in this Pipeline.
We are simply going to pass the transcript on to our text inference as the instruction in the prompt.
impl Instruct {
    // .. code omitted ..

    /// Public API to trigger audio inference
    pub fn audio(&self) -> Result<Response> {
        let (transcript, n_tokens, elapsed) = self.whisper.infer()?;
        let (generated, n_txt_tok, txt_elapsed) = self.llama.infer(&transcript)?;

        Ok(Response::new(
            &transcript,
            &generated,
            (n_tokens + n_txt_tok) as u32,
            (elapsed + txt_elapsed).as_secs(),
        ))
    }
}
And that’s it. Your Personal Voice Assistant is ready for your command.
That was a lot of information! But if you have reached this point you have done GREAT, and you have my thanks and congratulations. You’ve captured audio from a WebView frontend, pcm-encoded the chunks of the waveform, emitted them to the backend, run mel transformations on them, and made predictions with 2 models, all on your own computer.
What next from here you ask? Here are some ideas …
We hardcoded our audio model to work with English only; try a multi-lingual model. That would involve a language-detection phase
Create commands out of your instructions - open notes could open the note taking or text edit app
Implement push-to-talk - global listeners of your app could launch this app and always be there for you
If you have found this post helpful consider spreading the word, it would act as a strong motivator for me to create more.
If you found an issue or hit a snag, reach out to me @beingAnubhab.