In my previous series Desktop QA Assistant With Llama3 in Rust we built a desktop app capable of interfacing with LLaMA3. In this post, we’ll extend the modality of the app to accept audio instructions instead of just text.

In this part of the series, we’ll deal with the setup, load the models and get our text inference up and running using the Rust ML framework Candle.

The series:

Who is this for?

You’ll feel right at home if you are a programmer, have some exposure to Rust and a bit of experience working with Svelte, React or any other modern client-side framework.

Tools and Libraries

Note on GGUF

GGUF is a file format for storing models for inference with GGML and executors based on GGML. GGUF is a binary format that is designed for fast loading and saving of models, and for ease of reading. Models are traditionally developed using PyTorch or another framework, and then converted to GGUF for use in GGML.
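
Since GGUF files carry their own metadata (architecture, hyper-parameters, tokenizer details) alongside the tensor index, you can peek inside one with Candle’s gguf_file reader - the same API we’ll use to load LLaMA3 later in this post. A minimal sketch, assuming you have some local .gguf file lying around (the path below is just a placeholder):

use candle_core::quantized::gguf_file;

fn main() -> anyhow::Result<()> {
    // placeholder path - point this at any local gguf file
    let mut file = std::fs::File::open("models/some-model.Q4_K_M.gguf")?;
    let content = gguf_file::Content::read(&mut file)?;

    // gguf keeps its metadata as simple key/value pairs
    for (key, value) in content.metadata.iter() {
        println!("{key}: {value:?}");
    }
    println!("tensors: {}", content.tensor_infos.len());

    Ok(())
}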

GGML Repo

If you are not familiar with the GGML/ GGUF ecosystem, I’ve written a small note about it here

About Tauri 2.0 Beta
I’m taking this opportunity to test out the Tauri 2.x Beta. I’ve never used it before, so a few things might change from the setup of the previous project - consider this a hiccup warning in advance.

TL;DR

Github Repo for this project

The Final Output

Alert: contains audio

A note on Multimodality: Model vs Pipeline:

Modality:

Learning modalities are the sensory channels or pathways through which individuals give, receive, and store information. Perception, memory, and sensation comprise the concept of modality. The modalities or senses include visual, auditory, tactile/kinesthetic, smell, and taste.

cortland.edu

In the context of ML/DL, modality refers to the different kinds of data a model can work with, comprehend or process.

Multimodal Model:

Unimodal models work with a single modality: the BERT family uses text while ResNet-18 uses images, i.e. they can comprehend, interpret and work with text (BERT) or image (ResNet) data - that’s what they have been trained on and what they can run inference on.

On the other hand, multimodal models can work with, comprehend and interpret multiple modalities like text, image, audio and so on. E.g.

  • ChatGPT can work with text, image and audio inputs & outputs
  • Open source LLaVA Family of models are trained to work with both text and image inputs

Simply put, a multimodal model can work with more than one type of input or output.

Multimodal Pipeline: Mocking Multimodality with multiple unimodal models

Multimodal models tend to have complex architectures and sometimes exhibit the not-so-efficient jack of all trades, master of none problem. While solving real-world problems with AI/ML, I’ve often found that a bunch of specialist models working together produces significantly better results than a multimodal generalist.

In this project we are going to stitch together 2 specialists: OpenAI’s Whisper audio-to-text model for transcribing the audio, and Meta’s LLaMA3 text and language specialist model as our LLM backend. Our multimodal pipeline will be the glue between them.

Setup

Setting up tauri

With the rust toolchain and tauri-cli installed, we’ll just run

cargo create-tauri-app audio-instruct --beta
Prerequisites
  • Installation instructions for rust can be found here
  • Follow this for tauri installation

For the options create-tauri-app --beta asks for, my choices were as follows:

  • Frontend Language -> Typescript
  • Package Manager -> npm
  • UI Template -> Svelte
  • UI Flavor -> Typescript
  • Mobile Project -> No - we’ll try this in a future project 🤩

setup

Now, we’ll move into our project directory and run

npm i

After the installation completes, let’s run our desktop app for the first time.

npm run tauri dev

And we get a neat looking default window

first-run

Observations
  1. Tauri 2.0 feels like a much more polished product than what we got in Tauri v1 (which is expected). The out-of-the-box experience is far superior to v1, and the initial boilerplate setup has also been reduced significantly. 🥳
  2. But, because Tauri 2.0 is much more powerful, its permission structure is now a lot more modular, and the v2.x documentation is still early - it requires a bit of a deep dive to find the right stuff.

The project/directory structure looks familiar - at least superficially I don’t see a major difference from Tauri v1. You can find a quick note and explanation about the project layout here

A notable change from the Tauri v1 structure is the audio-instruct/src-tauri/capabilities directory. As per the Tauri 2.0 documentation about capabilities, this directory contains sets of permissions, defined as JSON, mapped to the application windows and webviews by their respective labels. This is what my audio-instruct/src-tauri/capabilities/default.json looks like (with dialogs enabled):

audio-instruct/src-tauri/capabilities/default.json
{
  "$schema": "../gen/schemas/desktop-schema.json",
  "identifier": "default",
  "description": "Capability for the main window",
  "windows": ["main"],
  "permissions": [
    "path:default",
    "event:default",
    "window:default",
    "app:default",
    "image:default",
    "resources:default",
    "menu:default",
    "tray:default",
    "shell:allow-open",
    "dialog:allow-open",
    "dialog:allow-save",
    "dialog:allow-message",
    "dialog:allow-ask",
    "dialog:allow-confirm",
    "dialog:default"
  ]
}

Setting up the backend dependencies

Candle first

We are going to use Candle as the ML framework for our inference. Candle requires a slightly more involved setup, so let’s get that out of the way. The instructions for setting up Candle for a project can be found here

cargo add --git https://github.com/huggingface/candle.git candle-core --features "metal"
Choosing a `Backend`

Because I’m on a Mac M1, I’m using the metal backend. Candle also works with CUDA, MKL etc. Choose the backend appropriate for your hardware.

Candle has reasonably elaborate guides; check them out if you get stuck (I did, when I attempted this for the first time).

More on Tauri

Tauri 2.0 introduces a much more powerful and modular approach to access, permissions and scopes. We’ll need to add a couple of plugin crates to our Cargo.toml to get our Dialog and File System access right.

  • Crate tauri-plugin-dialog for enabling frontend dialogs, confirmation etc.
  • Crate tauri-plugin-fs for filesystem access

Usual suspects

Now, with candle-core added to our dependencies, let’s go ahead and add the usual suspects; logs & pretty_env_logger for logs, anyhow for errors, etc.

audio-instruct/src-tauri/Cargo.toml
[dependencies]
anyhow              = "1"
candle-core         = { git = "https://github.com/huggingface/candle.git", version = "0.6.0", features = [] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.6.0", features = [] }
candle-nn           = { git = "https://github.com/huggingface/candle.git", version = "0.6.0", features = [] }
hf-hub              = { version = "0" }
log                 = "0"
pretty_env_logger   = "0"
rand                = "0"
serde = { version   = "1", features = ["derive"] }
serde_json          = "1"
tauri = { version   = "2.0.0-beta", features = [] }
tauri-plugin-dialog = "2.0.0-alpha.2"
tauri-plugin-fs     = "2.0.0-beta"
tauri-plugin-shell  = "2.0.0-beta"
tokenizers          = { version = "0" }

[features]
cuda = ["candle-core/cuda", "candle-transformers/cuda", "candle-nn/cuda"]
metal = ["candle-core/metal", "candle-transformers/metal", "candle-nn/metal"]
default = ["metal"]
Note

Note how I defined features for our tauri project. This will help us create the GPU Device automatically based on the cfg!(feature = <something>) macro.

I’ll run the app with --features "metal" and that should initialize the device automatically. If you are using cuda or some other GPU device, make sure to change this to --features "cuda" etc.
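
For reference, here’s roughly what that fn device() helper (which we’ll keep in audio-instruct/src-tauri/src/utils.rs) can look like - a sketch based on the cfg!(feature = ...) approach above, so treat it as one possible implementation rather than the final code:

// audio-instruct/src-tauri/src/utils.rs
use anyhow::Result;
use candle_core::Device;

/// Picks the device based on the cargo feature the app was compiled with,
/// falling back to CPU when no accelerator feature is enabled
pub fn device() -> Result<Device> {
    if cfg!(feature = "metal") {
        Ok(Device::new_metal(0)?)
    } else if cfg!(feature = "cuda") {
        Ok(Device::new_cuda(0)?)
    } else {
        Ok(Device::Cpu)
    }
}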

We’ve added:

  • anyhow to find our way around errors
  • candle-transformers to almost plug and play whisper and llama models, candle-nn to work with tensors
  • hf-hub to download the models
  • log and pretty_env_logger for some neat logging
  • tokenizers - for decoding our data for inference with whisper
  • rand - we’ll need to generate some rand ranges for inference

Structuring our App

We’ll end up having the following overall structure and modules for our inference engine:

  • a module instruct where the state of the app will be maintained and which the handlers will call into
  • a whisper module to do the audio transcription and associated pre and post processing
  • a llama module for the text inference
  • we’ll also end up with a utils module; we always need some helper functions that don’t fit anywhere else
  • and a commands module where we’ll define our handlers (or commands in the Tauri universe)

Let’s just go ahead and create these empty files (instruct.rs, whisper.rs, llama.rs and so on …) and declare them as modules in our main.rs

A last bit of boilerplate I always need to do is to update the identifier field in src-tauri/tauri.conf.json to something unique. Let’s change this to audio-instruct.llm.

Our main.rs should now look like the following:

audio-instruct/src-tauri/src/main.rs
#[macro_use]
extern crate log;

mod commands;
mod instruct;
mod llama;
mod utils;
mod whisper;

fn main() {
    pretty_env_logger::init();

    todo!()
}

App-state, commands and communication

To get our app to work, we’ll need:

  • An app state that will be instantiated on launch, this will basically hold our models once loaded and expose APIs for preprocessing, inference etc.
  • A way of communicating the instruction or command from the front-end to the backend engine
  • We’ll capture audio through the front-end interface and then send it for transcription. We could do this by recording the whole audio in the client interface and sending the entire blob to the backend, but let’s try to make this a little juicier: we’ll buffer chunks of the audio and emit them to the backend, and we’ll need an MPSC (multiple producer single consumer) channel to orchestrate this. More on this later.
  • And of course, last but not least, we’ll need a couple of structs to create instances of the candle models and some associated methods.

App-state with struct Instruct

We define our struct Instruct in audio-instruct/src-tauri/src/instruct.rs. We’ll also define an initializer for it, which will be responsible for instantiating the models, downloading them from the HuggingFace Hub with the hf_hub crate if required.

audio-instruct/src-tauri/src/instruct.rs
/// A struct to maintain our app state
pub struct Instruct {
    /// holds the instantiated `Llama` quantized `gguf` model and associated methods
    llama: LlamaWrap,
    /// holds the `distil-whisper` model and the associated methods
    whisper: WhisperWrap,
    /// a channel for triggering Instruct methods through events
    send: Sender<Vec<f32>>,
}

impl Instruct {
    pub fn new(datadir: PathBuf) -> Result<Arc<Self>> {
        let llama = LlamaWrap::new(datadir.as_path())?;
        let whisper = WhisperWrap::new(datadir.as_path())?;

        let (send, recv) = channel();

        let app = Arc::new(Self {
            llama, whisper, send: send.clone()
        });

        // spawn a listener to receive incoming events
        let appclone = Arc::clone(&app);
        thread::spawn(move || {
            Self::listen(appclone, recv)
        });

        Ok(app)
    }

    /// Exposes an API to send data into our MPSC channel
    pub fn send(&self, data: Vec<f32>) -> Result<()> {
        self.send.send(data)?;
        
        Ok(())
    }

    // We initiate a listener to listen to incoming messages
    fn listen(app: Arc<Instruct>, recv: Receiver<Vec<f32>>) {
        while let Ok(next) = recv.recv() {
            
        }
    }

    pub fn text(&self, instruct: &str) -> Result<Response> {
        todo!()
    }

    pub fn audio(&self) -> Result<Response> {
        todo!()
    }
}

The event listener is pretty much an empty shell for now; we’ll change that soon enough. The idea behind this communication pattern is simple - a command handler receives a request and pushes it into the MPSC channel, so the end user doesn’t need to wait for it to finish processing.
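
To make the intent concrete, here’s a rough sketch of where fn listen() is headed - the received chunks get handed over to the WhisperWrap for buffering. The push_chunk() method is hypothetical at this point; we’ll build the real thing when we wire up the audio flow:

// a sketch of the eventual `listen()` body, not the final implementation
fn listen(app: Arc<Instruct>, recv: Receiver<Vec<f32>>) {
    while let Ok(chunk) = recv.recv() {
        // `push_chunk()` is a hypothetical `WhisperWrap` helper that would append
        // the incoming samples to the `Vec<f32>` buffer it maintains
        if let Err(e) = app.whisper.push_chunk(chunk) {
            error!("listen: error buffering audio chunk: {e:?}");
        }
    }
}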

Command handlers

A Command in tauri is simply a request sent by the client (in our case the front-end of the app) to the server (the backend of our app). Unlike a typical webserver, this is not an HTTP request but rather something closer to a remote procedure call.

To register a command handler, we’ll need to inform the tauri builder about it during initialization. Let’s define our handlers in audio-instruct/src-tauri/src/commands.rs.

audio-instruct/src-tauri/src/commands.rs
/// A struct to represent the incoming request
/// For `text` inference it contains the instruction itself,
/// but for `audio` instructions we just need a flag indicating that inference should run on the recorded audio, since the audio data is sent in chunks and
/// already maintained in the app state
#[derive(Debug, Deserialize)]
pub struct Command {
    text: Option<String>,
    audio: Option<bool>
}

/// Enum to maintain what kind of instruction this is
pub enum Mode {
    Text(String),
    Audio
}

impl Command {
    pub fn mode(&self) -> Result<Mode> {
        if let Some(t) = self.text.as_ref() {
            Ok(Mode::Text(t.to_owned()))
        } else if self.audio.map_or(false, |d| d) {
            Ok(Mode::Audio)
        } else {
            anyhow::bail!("not a valid command")
        }
    }
}

/// A struct to hold the response data and some stats or metadata required to show the inference
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    text: String,
    meta: Meta,
    // since we are accepting instructions over `audio` now, we'll also return the original `transcribed` instruction
    instruct: String
}

/// A struct to hold some metadata and additional information about the QA/ Response/ Instruction etc.
#[derive(Debug, Serialize, Deserialize)]
pub struct Meta {
    // number of tokens generated
    n_tokens: u32,
    // number of seconds elapsed
    n_secs: u64
}

impl Response {
    pub fn new(instruct: &str, txt: &str, n_tokens: u32, n_secs: u64) -> Self {
        Self {
            instruct: instruct.to_string(),
            text: txt.to_string(),
            meta: Meta { n_secs, n_tokens }
        }
    }
}

/// A command to accept incoming `instruction` and respond with the `inference`
#[tauri::command]
pub fn ask(
    app: tauri::State<'_, Arc<Instruct>>,
    cmd: Command
) -> Result<Response, &'static str> {
    let command = match cmd.mode() {
        Ok(c) => c,
        Err(e) => {
            error!("ask: invalid incoming command: {e:?}");
            return Err("invalid command")
        }
    };
    
    let res = match command {
        Mode::Text(t) => app.text(&t),
        Mode::Audio => app.audio()
    };

    match res {
        Ok(r) => Ok(r),
        Err(e) => {
            error!("ask: error during inference: {e:?}");
            Err("inference error")
        }
    }
}

/// This tauri command would receive a Vec<f32> which represents a chunk of audio being recorded
/// The chunk will be forwarded through the MPSC channel
#[tauri::command]
pub fn audio_chunk(
    app: tauri::State<'_, Arc<Instruct>>,
    req: ipc::Request<'_>
) -> Result<(), &'static str> {
    if let tauri::ipc::InvokeBody::Raw(data) = req.body() {
        let chunk = bytes_to_f32(&data[..]);
        if let Err(e) = app.send(chunk) {
            error!("audio_chunk: error: {e:?}");
            return Err("invalid chunk")    
        }
    } else {
        return Err("invalid chunk")
    }
    Ok(())
}

Ok, so quite a lot is happening there, let’s break it down.

  1. The struct Command defines the shape of our incoming instruction. Unlike our previous text-only inference attempt, we can now have a text as well as an audio instruction mode. A text instruction works much like it did in our previous blogs - it simply carries the text input - but an audio instruction doesn’t contain any data. That’s because the audio chunks are transmitted through fn audio_chunk(); each chunk is (pre)processed and stored by the WhisperWrap object, and once we send the ask() command with audio: Some(true), the inference runs on the data already collected.

  2. The enum Mode is simply to identify what kind of command is being requested in ask()

  3. Structs Response and Meta simply hold our inference and a bunch of metadata around it

  4. fn ask() is a command handler which processes incoming Command and responds with a Response

  5. Finally, the handler fn audio_chunk() is slightly different (and this is why I chose Tauri 2.0 for this project). You see, in previous versions of Tauri the incoming data of a tauri::command needed to be text serializable, which kind of defeats the purpose of sending chunked data. Tauri 2.0 introduces the ability to receive raw bytes (I couldn’t find any documentation yet, but there is some reference in this github issue). Hence, fn audio_chunk() basically reads the incoming bytes as a Vec<f32>, which is what our chunk processing needs.
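
The bytes_to_f32() helper referenced above is one of those small functions that ends up in our utils module. Here’s a minimal version of it, assuming the frontend sends the samples as little-endian 32-bit floats:

// audio-instruct/src-tauri/src/utils.rs
/// Reinterprets a raw byte buffer as a vec of `f32` audio samples.
/// Assumes the frontend sends the samples as little-endian 32-bit floats.
pub fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect()
}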

Let’s modify our fn main() to account for these handlers.

audio-instruct/src-tauri/src/main.rs
fn main() {
    pretty_env_logger::init();

    let app = tauri::Builder::default()
        .plugin(tauri_plugin_fs::init())
        .plugin(tauri_plugin_dialog::init())
        .setup(|tauri_app| {
            let datadir = tauri_app.path().app_data_dir()?;
            // Initialize the instruct app
            let instruct = Instruct::new(datadir.clone()).expect("initialization failed");            
            tauri_app.manage(instruct);

            Ok(())
        })
        .invoke_handler(tauri::generate_handler![
            crate::commands::ask,
            crate::commands::audio_chunk
        ])
        .build(tauri::generate_context!())
        .expect("Failed to build app!");
    
    // ... code omitted ...
}

Ok, great. Now that our state, commands and a simple communication infrastructure are set up, let’s focus our attention on the models.

Models

We are going to use two models: a Whisper-based distil-whisper-large-v3 model for our audio transcription and a LLaMA3 8B gguf quantized model for our text inference. When our struct Instruct initializes, it needs to instantiate both models. We’ll define a wrapper struct for each of them (let’s call them struct WhisperWrap and struct LlamaWrap). These structs will expose their constructors, inference methods, and any pre- and post-processing logic, and will be responsible for downloading their model if and when required.

Model Loading

Instantiation and model loading will largely involve the following steps:

  1. Check if the required files exist locally in our app_data_dir(); if they don’t, download them from hf_hub.

  2. Each wrapper should then have a mechanism to initialize its model based on its configuration. An interesting thing here is that we are using two different model types: the LLaMA3 model is a gguf quantized variant while the distil-whisper-large-v3 model is a HuggingFace-style safetensors (https://huggingface.co/docs/safetensors/en/index) model. Candle exposes slightly different ways of loading them, and we’ll need to account for that.

With that information, let’s code up the model loaders.

struct WhisperWrap and model distil-whisper-large-v3

For a non-gguf candle model to load we’ll need:

  1. a tokenizer.json file specific to the model - this specifies the vocabulary of the model along with a bunch of other tokenizer configurations
  2. a config.json file for the model - the model architecture is defined here
  • and finally the model.safetensors file - the model weights, which in our case are in float16 format
Reference
Check out the model card for distil-whisper-large-v3 here, browse through the Files and versions section to see what other formats are available

In our audio-instruct/src-tauri/src/whisper.rs we’ll first define a struct WhisperWrap and its initializers:

audio-instruct/src-tauri/src/whisper.rs
const MODEL_REPO: &str = "distil-whisper/distil-large-v3";

/// A struct to hold `distil-whisper` object and associated methods
pub struct WhisperWrap {
    device: Device,
    // we'll store our raw audio data in this vec
    data: Arc<Mutex<Vec<f32>>>,
    // Model config stored here - basically our `config.json` representation
    config: Config,
    // deep dive below
    mel_filters: Vec<f32>,
    // the actual model initialized
    model: Arc<Mutex<Whisper>>,
    tokenizer: Tokenizer,
    // some tokens that we need to access during the inference
    default_tokens: DefaultTokens
}

/// Whisper's default tokens pre-initialized and fetched - helps us later during inference
pub struct DefaultTokens {
    // start of transcript token
    sot: u32,
    // end of transcript
    eot: u32,
    // token representing language
    lang: u32,
    // token representing task - for us it's transcribe
    transcribe: u32,
    // token for no-timestamp - we don't need timestamp
    no_ts: u32,
    // token representing no speech
    no_speech: u32
}

This requires a quick explanation. The model field of the struct will hold an instance of the Whisper model, and the tokenizer field is an instance of struct Tokenizer from the tokenizers crate, yet again by the awesome HuggingFace team. Most of the fields in our struct WhisperWrap should be self-explanatory, but mel_filters requires a deep dive.

mel_filters: Mel-Spectrogram

Here’s a Fantastic Writeup on Mel-Spectrogram.

In text processing we convert the incoming text (words, sub-words, bytes etc.) into a bunch of ids or tokens - each is part of the vocabulary of the model, the set of tokens that the model knows from its training. Anything beyond that is unknown and, in recent models, is typically represented by some form of UNK token. These tokens are your inputs to a text LLM.
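
As a quick illustration with the tokenizers crate we just added (the tokenizer.json path is a placeholder - any HuggingFace-style tokenizer file will do):

use tokenizers::Tokenizer;

fn main() -> anyhow::Result<()> {
    // placeholder path - e.g. the tokenizer.json we download for llama later on
    let tokenizer = Tokenizer::from_file("tokenizer.json")
        .map_err(|e| anyhow::anyhow!("{e:?}"))?;

    // text in, token ids out - this is all a text LLM ever "sees"
    let encoding = tokenizer
        .encode("Who is Steve Wozniak?", true)
        .map_err(|e| anyhow::anyhow!("{e:?}"))?;
    println!("{:?}", encoding.get_ids());

    Ok(())
}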

Audio input is far more complicated. If I understand correctly, this is what is happening:

  1. A sliding window is applied to the audio waveform and then a Fourier transform is applied to each window. This converts the audio signal to time-frequency representation.

  2. The resulting magnitudes are then converted to the Mel scale, which approximates human auditory perception, by applying a pre-computed mel filterbank to the data. This filterbank is what the mel_filters field of struct WhisperWrap holds.

  3. A bunch of processing steps later (log scaling, normalization… it depends, I guess, on the model), the spectrograms are - specifically for Whisper - split into 30s chunks, with shorter sequences padded. These are the input features for a whisper model.

Note
I don’t have an expert-level understanding of the steps above; I might have gotten something wrong or missed a few. Please do point it out if you find an issue.
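
For orientation, this is roughly how that conversion looks in code - candle-transformers ships a pcm_to_mel helper alongside its Whisper implementation. Treat this as a sketch (double-check the exact signature in your Candle version); we’ll do the real pre-processing in the next post:

use candle_transformers::models::whisper::{audio, Config};

// `config` is the parsed `config.json`, `mel_filters` is the filterbank we load
// in `WhisperWrap` below, and `pcm_samples` is mono f32 audio resampled to 16kHz
fn to_mel(config: &Config, mel_filters: &[f32], pcm_samples: &[f32]) -> Vec<f32> {
    // windowed FFT + mel filterbank + log scaling, as described above
    audio::pcm_to_mel(config, pcm_samples, mel_filters)
}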

Now that we have some understanding of the Mel Spectrogram and whisper’s input, let’s continue with our model initialization.

audio-instruct/src-tauri/src/whisper.rs
impl WhisperWrap {
    pub fn new(dir: &Path) -> Result<Self> {
        let device = device()?; // builds the device we are going to use. A helper function defined in `audio-instruct/src-tauri/src/utils.rs`

        let model_path = Self::model_path(dir)?;
        let (model, mel_filters, tokenizer, config) = Self::load_model(model_path.as_path(), &device)?;

        info!("Whisper ready!");
        Ok(
            Self { device, config, data: Arc::new(Mutex::new(Vec::new())), default_tokens: DefaultTokens::init(&tokenizer), mel_filters, model, tokenizer }
        )
    }

    // checks if the model file(s) exists or not and downloads it
    // Unlike a `gguf` model, we'll need 3 files for our model to work
    // the model weights will be in `model.safetensors`
    // the tokenizer.json file for the vocab
    // config.json for the model config
    fn model_path(base_dir: &Path) -> Result<PathBuf> {
        // Check if tokenizer.json exists, or create
        if !base_dir.join("tokenizer.json").is_file() {
            // a helper function defined in `audio-instruct/src-tauri/src/utils.rs`
            hf_download(base_dir, MODEL_REPO, "tokenizer.json", Some(LOCAL_MODEL_TOK))?;
        }

        // check if config.json exists else download
        if !base_dir.join("config.json").is_file() {
            // a helper function defined in `audio-instruct/src-tauri/src/utils.rs`
            hf_download(base_dir, MODEL_REPO, "config.json", Some(LOCAL_MODEL_CFG))?;
        }

        // finally, check if model.safetensors exist or download
        // this is a large file, and might take a while
        if !base_dir.join("model.safetensors").is_file() {
            // a helper function defined in `audio-instruct/src-tauri/src/utils.rs`
            hf_download(base_dir, MODEL_REPO, "model.safetensors", Some(LOCAL_MODEL_MODEL))?;
        }

        Ok(base_dir.to_path_buf())
    }

    fn load_model(model_dir: &Path, device: &Device) -> Result<(Arc<Mutex<Whisper>>, Vec<f32>, Tokenizer, Config)> {
        info!("Loading whisper");

        let tokenizer = match Tokenizer::from_file(model_dir.join(LOCAL_MODEL_TOK)) {
            Ok(t) => t,
            Err(e) => {
                error!("Error loading tokenizer: {e:?}");
                return Err(anyhow!("{e:?}"));
            }
        };

        let config: Config = serde_json::from_str(&std::fs::read_to_string(model_dir.join(LOCAL_MODEL_CFG))?)?;

        // loading the `mel_filters`, the binary filters can be found in `audio-instruct/src-tauri/melfilters<128>.bytes`
        // whisper-large-v3 models use `128 bin` inputs
        let mel = match &config.num_mel_bins {
            80 => include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/melfilters.bytes")).as_slice(),
            128 => include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/melfilters128.bytes")).as_slice(),
            nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"),
        };
        let mut mel_filters = vec![0f32; mel.len() / 4];
        <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel, &mut mel_filters);

        // This is the construct in which we load `Candle` models, we define a `varbuilder` to load the `model.safetensors` file and then pass them on while model is loaded
        // Each `nn.Module` would access it's respective `weights` and `biases` from the `VarBuilder` object
        // This is unsafe :/
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_dir.join(LOCAL_MODEL_MODEL)], DType::F32, device)? };
        let model = Arc::new(Mutex::new(Whisper::load(&vb, config.clone())?));

        Ok((
            model,
            mel_filters,
            tokenizer,
            config
        ))
    }
}

So that’s enough to initialize our model. The comments should detail what we are doing, so I’ll not spend a lot of time on this now.

struct LlamaWrap and model LLaMA3 Quant GGUF

Loading the LLaMA3 GGUF model is a lot less involved; that’s because our gguf file is a self-contained entity with everything necessary to run the model. Also, the awesome Candle team has already provided us with a simple interface to load the model file.

audio-instruct/src-tauri/src/llama.rs
/// A struct to maintain an initialized Llama quantized `gguf` model and associated methods
pub struct LlamaWrap {
    device: Device,
    model: Arc<Mutex<ModelWeights>>,
    tokenizer: Tokenizer,
    sampler: Arc<Mutex<LogitsProcessor>>,
    stop_tokens: [u32; 2]
}

impl LlamaWrap {
    /// Initializer for new llama manager
    pub fn new(dir: &Path) -> Result<Self> {
        let device = device()?;

        let model_path = Self::model_path(dir)?;
        let (model, tokenizer) = Self::load_model(model_path.as_path(), &device)?;
        let stop_tokens = [tokenizer.token_to_id("<|eot_id|>").unwrap(), tokenizer.token_to_id("<|end_of_text|>").unwrap()];

        let sampler = Arc::new(
            Mutex::new(
                LogitsProcessor::from_sampling(42, Sampling::TopKThenTopP { k: TOP_K, p: TOP_P, temperature: TEMPERATURE })
            )
        );
        
        info!("Llama ready!");
        Ok(Self { device, model, tokenizer, sampler, stop_tokens })
    }

    fn model_path(base_dir: &Path) -> Result<PathBuf> {
        let model_path = base_dir;

        // If the file doesn't exist, let's download it
        if !model_path.join(MODEL_FILE).is_file() {
            hf_download(base_dir, MODEL_REPO, MODEL_FILE, None)?;
        }

        // Download the tokenizer.json and rename it; it's in a separate repo
        if !model_path.join(LOCAL_MODEL_TOK).is_file() {
            hf_download(base_dir, TOKENIZER_REPO, "tokenizer.json", Some(LOCAL_MODEL_TOK))?;
        }

        Ok(model_path.to_path_buf())
    }

    fn load_model(model_dir: &Path, device: &Device) -> Result<(Arc<Mutex<ModelWeights>>, Tokenizer)> {
        let model_file = model_dir.join(MODEL_FILE);
        let tok_file = model_dir.join(LOCAL_MODEL_TOK);

        info!("Loading gguf model @{:?}", model_file);

        let mut file = std::fs::File::open(model_file)?;
        // reading the params from file
        let model = gguf_file::Content::read(&mut file)?;

        let model = Arc::new(
            Mutex::new(
                ModelWeights::from_gguf(model, &mut file, device)?
            ));

        info!("Loading tokenizer @{:?}", tok_file);
        let tokenizer = Tokenizer::from_file(tok_file).unwrap();

        Ok((model, tokenizer))
    }
}

And that’s it. Now, let’s run our app and see if everything is in order.

Note

The LogitsProcessor and ModelWeights structs require a mutable borrow during the inference pass. Notice how we wrap these fields in Arc<Mutex<..>>.

There are other ways of doing this, but I’ll leave it up to you to use your imagination. :)

RUST_LOG=info npm run tauri dev --release -- --features metal
Note

I’m passing the --release flag to get the best performance out of the model load

The --features metal is to ensure that our fn device() helper creates a Metal device. If you are using cuda ensure that you have installed candle in cuda mode and pass --features cuda to the launch command

If everything has gone according to plan, the models would be downloaded, and the app should show the default tauri welcome window in a few seconds.

Now that our model loading is complete, let us set up our text inference flow.

Text inference

We’ve done this before in our previous series Desktop QA Assistant with LLaMA3, and our current implementation is not very different. It’s not a straight copy-paste job because we have switched our ML framework from the GGML wrapper llama3_cpp-rs to huggingface candle, but it should be very similar.

First, let’s add some methods to our struct LlamaWrap to accept an incoming text request, preprocess and generate some response.

audio-instruct/src-tauri/src/llama.rs
impl LlamaWrap {
    // .. initializer code omitted ..

    /// Helper function to convert incoming `command` to templated prompt
    /// Prompt template: https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py#L202
    pub fn prompt(txt: &str) -> String {
        format!(
            "<|start_header_id|>system<|end_header_id|>\n\nYou are a knowledgeable, efficient, intelligent and direct AI assistant. Provide concise answers, focusing on the key information needed. Respond only with the answer to the instruction based on the given data. Do not add any additional text, introduction, context or explanation. If you are unsure about an answer, truthfully return \"Not Known\".<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
            txt
        )
    }

    /// the main method to run inference
    /// 1. Converts the incoming text to fit the `llama3` prompt template with the `Self::prompt()` function
    /// 2. Encodes the `input` into tokens with the tokenizer
    /// 3. Runs generation loop and stores generated tokens into `all_tokens` vec - the output `logits` are passed through the sampler to get the next token
    /// 4. Finally `decodes` tokens generated to a string
    /// returns the (generated text, number of tokens generated, duration)
    pub fn infer(&self, instruct: &str) -> Result<(String, usize, std::time::Duration)> {
        let prompt = Self::prompt(instruct);
        let prompttokens = match self.tokenizer.encode(prompt, true) {
            Ok(t) => t,
            Err(e) => {
                error!("infer: error tokenizing prompt tokens: {e:?}");
                anyhow::bail!("error tokenizing prompt");
            }
        };

        // Acquiring a mutable instance of `self.model`
        let mut model = match self.model.lock() {
            Ok(m) => m,
            Err(e) => {
                error!("infer: acquiring model mutex lock: {e:?}");
                anyhow::bail!("error acquiring model");
            }
        };

        // Acquiring a mutable instance of `self.sampler`
        let mut sampler = match self.sampler.lock() {
            Ok(s) => s,
            Err(e) => {
                error!("infer: acquiring model mutex lock: {e:?}");
                anyhow::bail!("error acquiring model");
            }
        };

        let mut all_tokens = vec![];
        let start_prompt_processing = std::time::Instant::now();

        let mut input = Tensor::new(prompttokens.get_ids(), &self.device)?.unsqueeze(0)?;
        let mut logits = model.forward(&input, 0)?;
        let mut next = sampler.sample(&logits.squeeze(0)?)?;

        all_tokens.push(next);

        for i in prompttokens.len() .. MAX_NEW_TOKENS {
            input = Tensor::new(&[next], &self.device)?.unsqueeze(0)?;

            logits = model.forward(&input, i)?;
            next = sampler.sample(&logits.squeeze(0)?)?;

            if self.stop_tokens.contains(&next) {
                break;
            }

            all_tokens.push(next);
        }

        
        let tk = match self.tokenizer.decode(&all_tokens[..], false) {
            Ok(t) => t,
            Err(e) => {
                error!("Error generating tokens: {e:?}");
                anyhow::bail!("Error generating tokens")
            }
        };

        Ok((
            tk,
            all_tokens.len(),
            std::time::Instant::now().duration_since(start_prompt_processing)
        ))
    }
}

Let’s write a quick testcase to check if our generation works.

audio-instruct/src-tauri/src/llama.rs
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::LlamaWrap;

    #[test]
    fn llama_infer() -> anyhow::Result<()> {
        pretty_env_logger::init();

        // Hardcode this based on your `app_data_dir` or use a ENV variable
        // Since we are not initializing the entire tauri app, we don't have access to it here
        // Or, at-least, I couldn't find one :/
        // For me, on my Mac M1, the `app_data_dir` points to "/Users/anubhab/Library/Application Support/audio-instruct.llm"
        let dir = Path::new("<REPLACE_WITH_YOUR_APP_DATA_DIR>");
        let llama = LlamaWrap::new(dir)?;

        let inf = llama.infer("Who is Steve Wozniak?")?;

        info!("{inf:?}");
        Ok(())
    }
}

Now, let’s run our test. Note the flags I’m passing to enable features metal.

cd src-tauri
RUST_LOG=info cargo test llama_infer --release --features metal -- --nocapture
cd ..

And VOILÀ … our text inference works!

Let’s enable our text inference through the public API before moving on to the audio inference. It’s simple: we have already stubbed out the pub fn text() method of our struct Instruct, so we’ll just replace the todo!() with an actual call. Yet another todo!() is DONE.

audio-instruct/src-tauri/src/instruct.rs
impl Instruct {
    // ... code omitted ...
    
    /// Public API to call text inference
    pub fn text(&self, instruct: &str) -> Result<Response> {
        let (txt, n_tokens, elapsed) = self.llama.infer(instruct)?;

        Ok(
            Response::new(instruct, &txt, n_tokens as u32, elapsed.as_secs())
        )
    }
}

Wrapping up & Next steps:

In this post we set up the required scaffolding for running our desktop application, loaded our LLaMA3 model and got it to generate some text for us. In the next post, Part II of this series, we’ll wrap this up with the audio inference and an integrated frontend. Till then, adios ..

Before we close today …

If you have found this post helpful consider spreading the word, it would act as a strong motivator for me to create more. If you found an issue or hit a snag, reach out to me @beingAnubhab.

Acknowledgements & reference

This project is built on the shoulders of stalwarts; a huge shout-out to all of them

  1. Rust maintainers for this awesome language
  2. The tauri app and its creators and maintainers
  3. Meta for creating the LLaMA family of models and giving open-source AI a fair shot
  4. HuggingFace🤗 for everything they do, including Candle, distil-whisper and Tokenizers
  5. Georgi Gerganov for creating the GGML/GGUF movement
  6. Quant Factory Team for the LLaMA GGUF model files
  7. Svelte team and the creator Rich Harris for Svelte :)

And many, many more …