In our previous posts of this series, we set up embedding generation and storage using the Stella_en_1.5B_v5 model and created a mini Vector Store inspired by Spotify’s ANNOY. Now, we’ll look at extracting text from PDFs and using LLaMA3 for text generation.

Series Snapshot
  • Part 1: we implement Embedding generation from text data. We use Stella_en_1.5B_v5 via its Candle Transformers implementation as the embedding model and the crate text-splitter to split our text into meaningful chunks.
  • Part 2: we build our own mini Vector Store inspired by Spotify’s ANNOY.
  • Part 3 (this): we code up a pipeline to analyze and extract text from .pdf files and also set the foundation for text generation with a LLaMA Model.
  • Part 4: we work on the retrieve-and-answer flow from our corpus.
  • Part 5: we implement and evaluate some techniques for a better RAG.

TL;DR

GitHub

Output

Note: This video has been sped up

Text Extraction from Files

Extracting text from files poses another set of unique challenges, especially if .pdf is involved! The horrors of .pdf processing have left deep scars on all of us who have attempted to work with PDF files. Here, we’ll make a shallow attempt at layout detection and text extraction from .pdf files using pre-trained models by the folks at Unstructured-IO.

Text extraction from .txt files is straightforward: just read the file. For .pdf, the problem can largely be broken down into the following steps:

  1. Layout Analysis: Understanding the structure of each page of a .pdf. This involves converting each page of the PDF to an image and running a layout detection model on it to deduce the regions of interest.

  2. Data extraction: Extracting text or other objects like tables, images etc. from the regions of interest. For now, we’ll skip tables and images. That’s another beast!

  3. OCR (if required): Image-based PDF files would require an OCR tool to extract text. I’ve used tesseract or ghostscript + tesseract before. To keep things simple-ish, let’s work only with text-based PDF files.

Layout Detection

We’ll use a Detectron2 based model (onnx port) for our Layout Detection.

The model can be downloaded from the HuggingFace Hub (detectron2_mask_rcnn_X_101_32x8d_FPN_3x/onnx.model), and we’ll place it in our models directory renamed to layout.onnx.

Candle ONNX Error!

Attempting to run the Detectron2-based ONNX model with candle-onnx hit a snag: not all ONNX ops are supported by candle-onnx yet!

While I’ve used the tract ONNX runtime before, for this project I’ll lean on the ort ONNX runtime.

Let’s add the crates ort and image to our dependencies. While at it, we’ll also add the pdfium-render crate, a thin wrapper around google/pdfium, Google’s library for rendering and working with .pdf files.

Note on `pdfium-render`
Setting up pdfium-render is a bit more involved than just adding it to our Cargo.toml. We’ll tackle this soon!
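
For reference, the additions to src-tauri/Cargo.toml would look roughly like the following. This is only a sketch: the version numbers are illustrative, not the exact ones pinned in the project.

[dependencies]
# ONNX Runtime bindings used to run the layout detection model
ort = "2.0.0-rc.2"
# Image decoding and resizing for the rendered pages
image = "0.25"
# Thin wrapper around google/pdfium for rendering .pdf pages
pdfium-render = "0.8"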

Let’s code up layout detection in a new file, src-tauri/src/layout.rs. We’ll create a struct RegionOfInterest to hold the areas of the document we’ll work on and an enum DetectedElem to map the model’s predicted classes.

src-tauri/src/layout.rs
// Imports omitted

// Copied from: https://github.com/styrowolf/layoutparser-ort/blob/master/src/utils.rs
/// Utility function to convert a bbox Vec to an array
fn vec_to_bbox<T: Copy>(v: Vec<T>) -> [T; 4] {
    [v[0], v[1], v[2], v[3]]
}

// An enum to represent the classes of regions of interest
// detected by the `layout detection` model
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum DetectedElem {
    Text,
    Title,
    List,
    Table,
    Figure,
}

impl Display for DetectedElem {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}",
            match self {
                Self::Text => "Text",
                Self::Title => "Title",
                Self::List => "List",
                Self::Table => "Table",
                Self::Figure => "Figure",
            }
        )
    }
}

/// This struct represents a Region of interest
#[derive(Debug)]
pub struct RegionOfInterest {
    kind: DetectedElem,
    // the bounding box - x1, y1, x2, y2 - left, top, right, bottom
    bbox: [f32; 4],
    // the model's confidence score for this prediction
    confidence: f32,
}

impl RegionOfInterest {
    pub fn kind(&self) -> DetectedElem {
        self.kind
    }

    pub fn bbox(&self) -> [f32; 4] {
        self.bbox
    }
}

Now, we’ll need a struct for the Detectron2 based Layout Analysis model.

src-tauri/src/layout.rs
// Code omitted ..

/// A [`Detectron2`](https://github.com/facebookresearch/detectron2)-based model.
pub struct Detectron2Model {
    model: Session,
    label_map: [DetectedElem; 5],
}

impl Detectron2Model {
    /// Required input image width.
    pub const REQUIRED_WIDTH: usize = 800;
    /// Required input image height.
    pub const REQUIRED_HEIGHT: usize = 1035;
    /// Default confidence threshold for detections.
    pub const DEFAULT_CONFIDENCE_THRESHOLD: f32 = 0.85;

    pub fn new() -> Result<Self> {
        // Loading and initializing the model from `onnx` file
        let model = Session::builder()?
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            // We could make this a little more generic with the `num_cpus` crate
            .with_intra_threads(8)?
            .commit_from_file("../models/layout.onnx")?;

        // You could print the model outputs to figure out which prediction datapoints are useful
        // println!("{:?}", model.outputs);

        Ok(Self {
            model,
            label_map: [
                DetectedElem::Text,
                DetectedElem::Title,
                DetectedElem::List,
                DetectedElem::Table,
                DetectedElem::Figure,
            ],
        })
    }

    pub fn predict(&self, page: &image::DynamicImage) -> Result<Vec<RegionOfInterest>> {
        let (img_width, img_height, input) = self.preprocess(page)?;
        // let hm = HashMap::from([("x.1".to_string(), input)]);
        let res = self.model.run(ort::inputs!["x.1" => input]?)?;

        self.postprocess(res, img_width, img_height)
    }

    // 1. Resizes an image to the required format!
    // 2. Creates a tensor from the image
    // 3. Reshapes the tensor to channel first format
    // 4. Creates input ndarray for `ort` to consume
    fn preprocess(&self, img: &image::DynamicImage) -> Result<(u32, u32, ort::Value)> {
        // TODO: re-visit this and resize smarter
        let (img_width, img_height) = (img.width(), img.height());
        let img = img.resize_exact(
            Self::REQUIRED_WIDTH as u32,
            Self::REQUIRED_HEIGHT as u32,
            imageops::FilterType::Triangle,
        );

        let img = img.to_rgb8().into_raw();

        // Read the image as a tensor
        let t = Tensor::from_vec(
            img,
            (Self::REQUIRED_HEIGHT, Self::REQUIRED_WIDTH, 3),
            &Device::Cpu,
        )?
        .to_dtype(DType::F32)?
        .permute((2, 0, 1))? // shape: [3, height, width]
        .to_vec3::<f32>()?
        .concat()
        .concat();

        // Create a `ndarray` input for `ort` runtime to consume
        let input = ort::Value::from_array(
            ([3, Self::REQUIRED_HEIGHT, Self::REQUIRED_WIDTH], &t[..])
        )?;
        
        Ok((img_width, img_height, input.into()))
    }

    // Reads the predictions and converts them to regions of interest
    fn postprocess(
        &self,
        outputs: SessionOutputs<'_, '_>,
        width: u32,
        height: u32,
    ) -> Result<Vec<RegionOfInterest>> {
        // Extract predictions for bounding boxes,
        // labels and confidence scores
        // Shape: [num pred, 4]
        let bboxes = &outputs[0].try_extract_tensor::<f32>()?;
        // Shape: [num pred]
        let labels = &outputs[1].try_extract_tensor::<i64>()?;
        // 3 for MASK_RCNN_X_101_32X8D_FPN_3x | 2 for FASTER_RCNN_R_50_FPN_3X
        // Shape: [num pred]
        let confidence = &outputs[3].try_extract_tensor::<f32>()?;

        // We had originally `resized` the image to fit
        // the required input dimensions,
        // we are just going to adjust the predictions to factor in the resize
        let width_factor = width as f32 / Self::REQUIRED_WIDTH as f32;
        let height_factor = height as f32 / Self::REQUIRED_HEIGHT as f32;

        // Iterate over (region bounding boxes, predicted classes/ labels, and confidence scores)
        let mut elements = bboxes
            .rows()
            .into_iter()
            .zip(labels.iter().zip(confidence.iter()))
            .filter_map(|(bbox, (&label, &confidence))| {
                // Skip everything below some confidence score we want to work with
                if confidence < Self::DEFAULT_CONFIDENCE_THRESHOLD {
                    return None;
                }

                // Getting the predicted label from the predicted index
                let label = self.label_map.get(label as usize)?;
                // We don't have any way of interpreting Figure and Table as text
                // So, we'll skip that
                if label == &DetectedElem::Figure || label == &DetectedElem::Table {
                    return None;
                }
                let [x1, y1, x2, y2] = vec_to_bbox(bbox.iter().copied().collect::<Vec<_>>());
                // Adjusting the predicted bounding box to our original image size
                Some(RegionOfInterest {
                    kind: *label,
                    bbox: [
                        x1 * width_factor,
                        y1 * height_factor,
                        x2 * width_factor,
                        y2 * height_factor,
                    ],
                    confidence
                })
            })
            .collect::<Vec<_>>();

        // Now we sort the predictions to (kind of) visual hierarchy
        // from top left
        elements.par_sort_unstable_by(|a, b| {
            (a.bbox()[1].max(a.bbox()[3])).total_cmp(&(b.bbox()[1].max(b.bbox()[3])))
        });

        Ok(elements)
    }
}

The comments in the code should suffice as an overview of what’s happening here!

Pro Tip: Figuring out Model Outputs

Most ONNX models are well documented, so we’d know which outputs we are working with. In case the model is NOT, print the model.outputs field.

For our model that gives us:

[
  Output { name: "onnx::Concat_2670", output_type: Tensor { ty: Float32, dimensions: [-1, 4] } },
  Output { name: "value.3", output_type: Tensor { ty: Int64, dimensions: [-1] } },
  Output { name: "value.7", output_type: Tensor { ty: Float32, dimensions: [-1, -1, -1, -1] } },
  Output { name: "value", output_type: Tensor { ty: Float32, dimensions: [-1] } },
  Output { name: "onnx::Split_570", output_type: Tensor { ty: Int64, dimensions: [2] } }
]

Let’s elaborate:

  1. 0th index is of shape [-1, 4] - meaning [variable size, 4] - we can assume this to be the bounding boxes.
  2. The 1st index of the output is of type int and shape [-1] (variable) - we can be reasonably certain that this is the index of the predicted class.
  3. The 3rd index is of type Float32 and variable shape [-1]; this would be the prediction confidence scores. (We skip the 2nd index, the [-1, -1, -1, -1] tensor, since we don’t use it.)
  4. Finally, the 4th index says it’s [2], which would generally indicate that we are looking at a binary classification of some kind. I’m pretty sure that refers to landscape vs portrait classification of the given document image, though we don’t use this output.

Let’s test out the predictions!

src-tauri/src/layout.rs
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::Detectron2Model;

    #[test]
    fn single_page_layout() -> Result<()> {
        let d2model = Detectron2Model::new()?;
        let img = image::open("../test-data/paper-image.jpg")?;

        let pred = d2model.predict(&img)?;
        println!("{pred:?}");

        Ok(())
    }
}

This is the test image we are using!

test image

cd src-tauri
cargo test single_page_layout --release -- --nocapture
cd ..

And we should get …

[
  RegionOfInterest { kind: Text, bbox: [94.5113, 664.0678, 1128.6223, 739.99677], confidence: 0.9726107 },
  RegionOfInterest { kind: Text, bbox: [96.92838, 754.4805, 584.53, 954.15466], confidence: 0.9837298 },
  RegionOfInterest { kind: Title, bbox: [100.300995, 982.7206, 438.7208, 1009.6998], confidence: 0.98317474 },
  RegionOfInterest { kind: Text, bbox: [630.19, 761.4926, 1122.4296, 1029.26], confidence: 0.9966737 },
  RegionOfInterest { kind: Text, bbox: [634.7173, 1072.1256, 1129.8917, 1148.6306], confidence: 0.8633146 },
  RegionOfInterest { kind: Text, bbox: [85.23773, 1013.9407, 585.50305, 1315.0413], confidence: 0.99716276 },
  RegionOfInterest { kind: Text, bbox: [94.63046, 1319.7373, 593.8499, 1419.4666], confidence: 0.9766832 },
  RegionOfInterest { kind: Text, bbox: [646.0491, 1415.9106, 1127.1869, 1478.3527], confidence: 0.99850094 },
  RegionOfInterest { kind: Text, bbox: [102.88627, 1444.6996, 594.5993, 1484.9055], confidence: 0.9771153 }
]

We don’t know how accurate the predictions are yet! But, as far as the flow is concerned, that concludes our Layout Detection.

Text Extraction

To extract text based on the detected layout, we’ll need to:

  1. Convert each page of a .pdf to an image - the pdfium-render and image crates we added earlier have a role to play here
  2. Predict the layout of the page
  3. From the predicted regions of interest, we’ll retrieve the text
  4. Generate embeddings for the text
`pdfium` Quickstart

The crate pdfium-render provides high-level wrappers on top of the original pdfium C++ library, but it doesn’t ship with the required library files. The author(s) of pdfium-render provide multiple ways of binding to the C++ library; we are going to take the dynamic-linking approach.

Steps to get this up and running:

  1. Create a .cargo directory inside src-tauri and a src-tauri/.cargo/config.toml file with the following content:
[env]
PDFIUM_DYNAMIC_LIB_PATH = { value = "../binaries/pdfium", relative = true }
  2. Download the .tgz archive for your OS and platform from the pdfium-binaries repo for version v6666 and put it inside the directory binaries/pdfium in our project root.

  3. Untar it inside the binaries/pdfium directory:

tar -xvzf <downloaded_file>.tgz

  4. On macOS, I had to move binaries/pdfium/lib/libpdfium.dylib up to binaries/pdfium/libpdfium.dylib - macOS won’t allow it to load because it can’t verify the developer. Moving it changes the metadata of the file, and that’s probably why it works!🧐🤯🫨😵‍💫
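
With the library files in place, initializing the dynamic binding usually follows pdfium-render’s documented pattern. Here is a minimal sketch; the function name and the relative path are assumptions (the path assumes we run from src-tauri), and we fall back to a system-wide pdfium install if the local copy can’t be loaded.

use pdfium_render::prelude::*;

// A minimal sketch of binding to the pdfium library unpacked above.
fn init_pdfium() -> Result<Pdfium, PdfiumError> {
    let bindings = Pdfium::bind_to_library(
        Pdfium::pdfium_platform_library_name_at_path("../binaries/pdfium/"),
    )
    .or_else(|_| Pdfium::bind_to_system_library())?;

    Ok(Pdfium::new(bindings))
}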

With pdfium-render ready, let’s wrap up the .pdf -> layout analysis -> text flow. We create a struct PdfProc that will hold everything we need to analyze and extract data from .pdf files.

src-tauri/src/doc.rs
// Imports and code omitted

pub struct PdfProc {
    pdfs: Vec<PathBuf>,
    layout: Detectron2Model,
    pdfium: Pdfium,
    pdfium_cfg: PdfRenderConfig,
}

Now we expose a bunch of methods for this flow to work; for brevity, I’ll focus on the key ones.

src-tauri/src/doc.rs
// code omitted ..
impl PdfProc {
    // initializer ..
    pub fn new(model_path: &Path, pdfs: Vec<PathBuf>) -> Result<Self> {
        // Straightforward implementation initializing the different fields of the struct
        Ok(Self { ... })
    }

    /// Returns the total number of pages to be analyzed
    pub fn estimate(&self) -> usize {
        // this can be used to show some progress in real world
    }

    /// Extract text from `pdfs`
    pub fn extract(&self, send: Sender<ExtractorEvt>) -> Result<Vec<Vec<(String, FileKind)>>> {
        // for each `.pdf` file we are going to convert the pages to images
        let file_encoded = self
            .pdfs
            .iter()
            .filter_map(|file| {
                let pdf = self.pdfium.load_pdf_from_file(&file, None).ok()?;

                self.process_pages(file, pdf, send.clone())
            })
            .collect::<Vec<_>>();

        Ok(file_encoded)
    }

    /// Processes each page:
    /// - renders the page with the rendering config
    /// - runs layout detection
    /// - reads text from the bounding boxes detected by the model
    pub fn process_pages(
        &self,
        file: &PathBuf,
        doc: PdfDocument<'_>,
        send: Sender<ExtractorEvt>,
    ) -> Option<Vec<(String, FileKind)>> {
        Some(
            doc.pages()
                .iter()
                .enumerate()
                .filter_map(|(idx, page)| {
                    // Convert the page to an image
                    let img = page.render_with_config(&self.pdfium_cfg).ok()?.as_image(); // Renders this page to an image::DynamicImage...

                    // Keep track of the factors by which the page and images of pages were resized to
                    // This is required to get accurate output from the predicted regions of interest
                    let w_f = page.width().value / img.width() as f32;
                    let h_f = page.height().value / img.height() as f32;
                    let pg_num = idx + 1;

                    // send the image for prediction
                    // and for each predicted `region of interest`
                    // fetch the text inside the bounding box
                    let text = self
                        .layout
                        .predict(&img)
                        .ok()?
                        .iter()
                        .filter_map(|e| {
                            // The bounding box for the region of interest
                            let bbox = e.bbox(); // x1, y1, x2, y2

                            // The bounding boxes for the predicted regions follow a `left-top` co-ordinate system
                            // But `pdfium` uses a bottom-left coordinate system, let's convert it
                            // We'll also factor in the original page size here
                            let top = page.height().value - bbox[1] * h_f + PADDING;
                            let bottom = page.height().value - bbox[3] * h_f - PADDING;
                            let left = bbox[0] * w_f - PADDING;
                            let right = bbox[2] * w_f + PADDING;

                            // Now, we have the `pdfium` compatible bounding boxes
                            // Let's fetch the text
                            let text = page
                                .text()
                                .ok()?
                                .inside_rect(PdfRect::new_from_values(bottom, left, top, right))
                                .replace("\t", " ")
                                .replace("\r\n", "\n");

                            Some(match e.kind() {
                                // We are using `MarkDownSplitter` for our text splitting task
                                // Here we are adding `##` to mark the generated text as title
                                DetectedElem::Title => {
                                    format!("## {}\n", text.replace("\n", "; "))
                                }
                                // Rest of the text remains as is
                                DetectedElem::Text | DetectedElem::List => text,
                                _ => unimplemented!(),
                            })
                        })
                        .collect::<Vec<_>>()
                        .join("\n");
                    
                    if let Err(e) = send.send(ExtractorEvt::Page) {
                        eprintln!("Warn: error sending page event: {e:?}");
                    }
                    
                    Some((text, FileKind::Pdf((file.to_owned(), pg_num))))
                })
                .collect::<Vec<_>>(),
        )
    }
}

Notice that some methods accept a std::sync::mpsc::Sender. We’ll elaborate on this later, but the idea is to emit execution status events so we can show progress on the client side.
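
As a rough sketch of that idea - ExtractorEvt::Page and the PdfProc methods come from the code above, while the function name and the wiring are illustrative - the receiving end could look like this:

// A hedged sketch: `PdfProc`, `ExtractorEvt` and `FileKind` are this project's own types.
fn run_extraction(processor: PdfProc) -> anyhow::Result<Vec<Vec<(String, FileKind)>>> {
    let (send, recv) = std::sync::mpsc::channel::<ExtractorEvt>();
    let total_pages = processor.estimate();

    // Listen for page events on a separate thread while extraction runs on this one
    let progress = std::thread::spawn(move || {
        let mut done = 0;
        while let Ok(ExtractorEvt::Page) = recv.recv() {
            done += 1;
            println!("Processed {done}/{total_pages} pages");
        }
    });

    // `send` is moved into `extract( .. )` and dropped when it returns,
    // which closes the channel and ends the loop above
    let extracted = processor.extract(send)?;
    progress.join().expect("progress thread panicked");

    Ok(extracted)
}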

For reference, this is the first page of the .pdf we are using for this test:

first page archaeology pdf

cd src-tauri
cargo test extract_from_pdf --release -- --nocapture
cd ..

The results are very accurate!

first layout pred results

Pro Tip

On real-world data, the Layout Detection predictions will not always be this clean!

You’ll need to play around with model hyperparameters, different models, padding for text extraction etc. to narrow down on a reasonable output that works Most of the Time.

Text Generation with LLaMA

So far we have tackled Embeddings generation with the Stella_en_1.5B_v5 model and Document Layout Analysis with a Detectron2-based Mask R-CNN model, and we have extracted text from .pdf files and learnt how to split it into semantic chunks.

To conclude Part 3 of this series, we’ll lay the foundation for text generation using LLaMA3. I’m using a LLaMA 3.1 series model, but you can choose anything you prefer based on your hardware.

Quickstart with LLaMA3.1
  • LLaMA3.x series requires you to accept the Meta LLaMA License - follow this blog post to know more about LLaMA 3.1 and how to acquire the LLaMA models
  • Once you have access to the models on HuggingFace, download the weight files (.safetensors), tokenizer.json and config.json to our models directory
Pro tip!
The proper way of fetching model weights and related files would be to accept the token string provided by Meta and use it to dynamically download the files from HuggingFace Hub on application init. I’ll leave that implementation up to you!
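
If you do want to automate it, a hedged sketch with the hf-hub crate could look like the following. The repo id, the HF_TOKEN environment variable and the function name are assumptions; the safetensors shard names depend on the checkpoint you pick, and tokenizer.json / config.json still need to be renamed to match the constants in Generator below (llama_tokenizer.json, llama_config.json).

use hf_hub::api::sync::ApiBuilder;

// A sketch only: repo id, env var and file names are assumptions.
fn fetch_llama_files() -> anyhow::Result<()> {
    let token = std::env::var("HF_TOKEN").ok(); // your HuggingFace access token
    let api = ApiBuilder::new().with_token(token).build()?;
    let repo = api.model("meta-llama/Llama-3.1-8B-Instruct".to_string());

    for file in [
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
        "tokenizer.json",
        "config.json",
    ] {
        // Files land in the local HuggingFace cache; copy them into `models/` afterwards
        let cached = repo.get(file)?;
        println!("fetched {file} -> {}", cached.display());
    }

    Ok(())
}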

The struct Generator

With the model weights, tokenizer and config files tucked away in our models directory, let’s create a struct Generator to manage the text generation.

src-tauri/src/gen.rs
// Imports omitted ..

// Sampling constants
const TEMPERATURE: f64 = 0.8;
const TOP_P: f64 = 0.95;
const TOP_K: usize = 32;

/// A struct to maintain an initialized LLaMA model (loaded from `safetensors` weights) and associated methods
pub struct Generator {
    cfg: Config,
    device: Device,
    model: Llama,
    tokenizer: Tokenizer,
    sampler: LogitsProcessor,
    stop_tokens: [u32; 2],
}

impl Generator {
    // Download model `safetensor` files into your project dir `models` folder
    // I'm using LLaMA3.1 8B instruct, you can use whatever you want
    const MODEL_FILES: [&'static str; 2] = [
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ];
    const TOKENIZER_FILE: &'static str = "llama_tokenizer.json";
    const MODEL_CONFIG_FILE: &'static str = "llama_config.json";

    /// Initializer for new llama manager
    pub fn new(dir: &Path, device: &Device) -> Result<Self> {
        let mut device = device.to_owned();
        if let candle_core::Device::Metal(mut m) = device {
            m.set_use_mlx_mm(false);
            device = Device::Metal(m);
        }

        let (model, mut cfg, tokenizer) = Self::load_model(dir, &device)?;
        let stop_tokens = [
            tokenizer.token_to_id("<|eot_id|>").unwrap(),
            tokenizer.token_to_id("<|end_of_text|>").unwrap(),
        ];
        cfg.max_position_embeddings = 4096;

        // Initializing the sampler
        let sampler = LogitsProcessor::from_sampling(
            42,
            Sampling::TopKThenTopP {
                k: TOP_K,
                p: TOP_P,
                temperature: TEMPERATURE,
            },
        );
        
        println!("Llama ready!");
        Ok(Self {
            cfg,
            device: device.clone(),
            model,
            tokenizer,
            sampler,
            // sampler2,
            stop_tokens,
        })
    }

    // A utility function to load the model and tokenizer
    fn load_model(model_dir: &Path, device: &Device) -> Result<(Llama, Config, Tokenizer)> {
        // let model_file = model_dir.join(Self::MODEL_FILE);
        let tok_file = model_dir.join(Self::TOKENIZER_FILE);
        let cfg_file = model_dir.join(Self::MODEL_CONFIG_FILE);
        let model_files = Self::MODEL_FILES
            .iter()
            .map(|mf| model_dir.join(mf))
            .collect::<Vec<_>>();

        println!("Loading LLaMA ..");
        let start = Instant::now();
        let cfg =
            serde_json::from_slice::<LlamaConfig>(&std::fs::read(&cfg_file)?)?.into_config(false);
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, DType::BF16, device)? };
        let llama = Llama::load(vb, &cfg)?;
        println!("LLaMA loaded in {}s", (Instant::now() - start).as_secs());

        let tokenizer = Tokenizer::from_file(tok_file).unwrap();

        Ok((llama, cfg, tokenizer))
    }

    // Utility function to run the generation loop
    fn generate(&mut self, prompt: &str) -> Result<String> {
        // Tokenize the input
        let input = self
            .tokenizer
            .encode(prompt, true)
            .map_err(|e| anyhow!(e))?;

        if input.len() >= self.cfg.max_position_embeddings {
            return Err(anyhow!("large input tokens!"));
        }

        let mut cache = Cache::new(true, DType::BF16, &self.cfg, &self.device)?;
        // Creating a tensor of input tokens
        let mut ip = Tensor::new(input.get_ids(), &self.device)?.unsqueeze(0)?;

        let mut start = std::time::Instant::now();

        // The forward pass to the first token
        let mut logits = self.model.forward(&ip, 0, &mut cache)?;
        // Sampling the first token
        let mut next = self.sampler.sample(&logits.squeeze(0)?)?;
        
        println!(
            "{} prompt tokens processed @ {}t/s",
            input.len(),
            input.len() as f32 / (std::time::Instant::now() - start).as_secs_f32()
        );
        // A container for all tokens generated
        let mut all_tokens = vec![next];

        start = std::time::Instant::now();

        // Forward pass - decoder loop
        for i in input.len()..self.cfg.max_position_embeddings {
            ip = Tensor::new(&[next], &self.device)?.unsqueeze(0)?;

            logits = self.model.forward(&ip, i, &mut cache)?;
            next = self.sampler.sample(&logits.squeeze(0)?)?;
            if self.stop_tokens.contains(&next) {
                break;
            }

            all_tokens.push(next);
        }
        println!(
            "{} tokens generated @ {}t/s",
            all_tokens.len() - 1,
            (all_tokens.len() - 1) as f32 / (std::time::Instant::now() - start).as_secs_f32()
        );

        // Decode tokens and return result
        Ok(match self.tokenizer.decode(&all_tokens[..], false) {
            Ok(t) => t,
            Err(e) => {
                eprintln!("Error generating tokens: {e:?}");
                anyhow::bail!("Error generating tokens")
            }
        })
    }
}

With that, the foundation for the G of our RAG is ready.

Write a test!

Write a test for the Generator. The new( .. ) function accepts the path to your models directory and a candle Device, while the method generate( .. ) accepts a prompt string.

If you are new to text generation, read up on prompt templates and try to create the prompt string based on the LLaMA3 prompt template.
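
For reference, here’s a hedged sketch of such a test. The prompt string follows the LLaMA 3 instruct template; the models path, the device and the question are assumptions you’ll want to adjust.

#[cfg(test)]
mod tests {
    use std::path::Path;

    use anyhow::Result;
    use candle_core::Device;

    use super::Generator;

    #[test]
    fn llama_generate() -> Result<()> {
        // `Device::Cpu` keeps the sketch portable; swap in `Device::new_metal(0)?`
        // or `Device::new_cuda(0)?` if you have the hardware
        let mut llama = Generator::new(Path::new("../models"), &Device::Cpu)?;

        // LLaMA 3.x instruct template: a system turn, a user turn and an
        // open assistant header for the model to complete
        let prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is Retrieval Augmented Generation?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";

        let answer = llama.generate(prompt)?;
        println!("{answer}");

        Ok(())
    }
}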

Next Steps

Now we have all of the different components ready for our RAG to work: Retrieval with Embeddings and a Vector Store, and Generation with LLaMA. In the next post, Part 4, we’ll tie these independent blocks together.