[v2,2/4] ipa: rpi: controller: awb: Add Neural Network AWB

Message ID: 20251211142824.26635-3-david.plowman@raspberrypi.com
State: Superseded
Series: Raspberry Pi AWB using neural networks

Commit Message

David Plowman Dec. 11, 2025, 2:25 p.m. UTC
From: Peter Bailey <peter.bailey@raspberrypi.com>

Add an AWB algorithm which uses neural networks.

The new AwbNN class derives from the existing Awb algorithm. The zone
statistics are rendered as they would appear at the network's canonical
colour temperature (5000K) and passed, together with a lux estimate, to
a TensorFlow Lite model that predicts the scene colour temperature. The
prediction is converted back into red/blue gains through the CT curves
and refined with a transverse search. If the model cannot be loaded, or
the algorithm is misconfigured, it falls back to the existing grey-world
method.

The code is only built when the tensorflow-lite dependency is found, and
a new rpi-awb-nn meson option is added to control the feature.

Signed-off-by: Peter Bailey <peter.bailey@raspberrypi.com>
---
 meson_options.txt                     |   5 +
 src/ipa/rpi/controller/meson.build    |   9 +
 src/ipa/rpi/controller/rpi/awb_nn.cpp | 446 ++++++++++++++++++++++++++
 3 files changed, 460 insertions(+)
 create mode 100644 src/ipa/rpi/controller/rpi/awb_nn.cpp
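
For reference, the algorithm reads its parameters from the camera tuning
file. The following fragment is a hypothetical illustration (the values,
including the model path, are placeholders; only the key names come from
AwbNNConfig::read() in the patch below):

    "rpi.nn.awb":
    {
        "model": "/usr/share/libcamera/awb_model.tflite",
        "min_temp": 2800,
        "max_temp": 7600,
        "enable_nn": 1,
        "ccm": [ 2.0, -0.6, -0.4,
                 -0.5, 1.8, -0.3,
                 -0.1, -0.7, 1.8 ]
    }

An empty "model" string makes the algorithm search the standard per-target
locations for awb_model.tflite, while a missing CCM (any zero entry)
disables the network and falls back to the grey-world method.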

Comments

David Plowman Dec. 11, 2025, 3:36 p.m. UTC | #1
Hi

On Thu, 11 Dec 2025 at 14:28, David Plowman
<david.plowman@raspberrypi.com> wrote:
>
> From: Peter Bailey <peter.bailey@raspberrypi.com>
>
> Add an AWB algorithm which uses neural networks.
>
> Signed-off-by: Peter Bailey <peter.bailey@raspberrypi.com>
> ---
>  meson_options.txt                     |   5 +
>  src/ipa/rpi/controller/meson.build    |   9 +
>  src/ipa/rpi/controller/rpi/awb_nn.cpp | 446 ++++++++++++++++++++++++++
>  3 files changed, 460 insertions(+)
>  create mode 100644 src/ipa/rpi/controller/rpi/awb_nn.cpp
>
> diff --git a/meson_options.txt b/meson_options.txt
> index 5954e028..89eece52 100644
> --- a/meson_options.txt
> +++ b/meson_options.txt
> @@ -78,6 +78,11 @@ option('qcam',
>          value : 'disabled',
>          description : 'Compile the qcam test application')
>
> +option('rpi-awb-nn',
> +        type : 'feature',
> +        value : 'auto',
> +        description : 'Enable the Raspberry Pi Neural Network AWB algorithm')
> +
>  option('test',
>          type : 'boolean',
>          value : false,
> diff --git a/src/ipa/rpi/controller/meson.build b/src/ipa/rpi/controller/meson.build
> index 90d9e285..9c261d99 100644
> --- a/src/ipa/rpi/controller/meson.build
> +++ b/src/ipa/rpi/controller/meson.build
> @@ -33,6 +33,15 @@ rpi_ipa_controller_deps = [
>      libcamera_private,
>  ]
>
> +tflite_dep = dependency('tensorflow-lite', required : false)

Oops. I edited this to read

tflite_dep = dependency('tensorflow-lite', required : get_option('rpi-awb-nn'))

and then failed to hit the save button. Version 3 incoming shortly!
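
For reference, the corrected block in v3 would then presumably read:

tflite_dep = dependency('tensorflow-lite', required : get_option('rpi-awb-nn'))

if tflite_dep.found()
    rpi_ipa_controller_sources += files([
        'rpi/awb_nn.cpp',
    ])
    rpi_ipa_controller_deps += tflite_dep
endif

so that -Drpi-awb-nn=enabled (or =disabled) actually controls whether the
algorithm is built.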

David

> [snip - the remainder of the quoted patch is reproduced in full below]

Patch

diff --git a/meson_options.txt b/meson_options.txt
index 5954e028..89eece52 100644
--- a/meson_options.txt
+++ b/meson_options.txt
@@ -78,6 +78,11 @@ option('qcam',
         value : 'disabled',
         description : 'Compile the qcam test application')
 
+option('rpi-awb-nn',
+        type : 'feature',
+        value : 'auto',
+        description : 'Enable the Raspberry Pi Neural Network AWB algorithm')
+
 option('test',
         type : 'boolean',
         value : false,
diff --git a/src/ipa/rpi/controller/meson.build b/src/ipa/rpi/controller/meson.build
index 90d9e285..9c261d99 100644
--- a/src/ipa/rpi/controller/meson.build
+++ b/src/ipa/rpi/controller/meson.build
@@ -33,6 +33,15 @@ rpi_ipa_controller_deps = [
     libcamera_private,
 ]
 
+tflite_dep = dependency('tensorflow-lite', required : false)
+
+if tflite_dep.found()
+    rpi_ipa_controller_sources += files([
+        'rpi/awb_nn.cpp',
+    ])
+    rpi_ipa_controller_deps += tflite_dep
+endif
+
 rpi_ipa_controller_lib = static_library('rpi_ipa_controller', rpi_ipa_controller_sources,
                                         include_directories : libipa_includes,
                                         dependencies : rpi_ipa_controller_deps)
diff --git a/src/ipa/rpi/controller/rpi/awb_nn.cpp b/src/ipa/rpi/controller/rpi/awb_nn.cpp
new file mode 100644
index 00000000..35d1270e
--- /dev/null
+++ b/src/ipa/rpi/controller/rpi/awb_nn.cpp
@@ -0,0 +1,446 @@
+/* SPDX-License-Identifier: BSD-2-Clause */
+/*
+ * Copyright (C) 2025, Raspberry Pi Ltd
+ *
+ * AWB control algorithm using neural network
+ */
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <string>
+
+#include <libcamera/base/file.h>
+#include <libcamera/base/log.h>
+#include <libcamera/base/utils.h>
+
+#include <tensorflow/lite/interpreter.h>
+#include <tensorflow/lite/kernels/register.h>
+#include <tensorflow/lite/model.h>
+
+#include "../awb_algorithm.h"
+#include "../awb_status.h"
+#include "../lux_status.h"
+#include "libipa/pwl.h"
+
+#include "alsc_status.h"
+#include "awb.h"
+
+using namespace libcamera;
+
+LOG_DECLARE_CATEGORY(RPiAwb)
+
+constexpr double kDefaultCT = 4500.0;
+
+/*
+ * The neural networks are trained to work on images rendered at a canonical
+ * colour temperature. That value is 5000K, which must be reproduced here.
+ */
+constexpr double kNetworkCanonicalCT = 5000.0;
+
+#define NAME "rpi.nn.awb"
+
+namespace RPiController {
+
+struct AwbNNConfig {
+	AwbNNConfig() {}
+	int read(const libcamera::YamlObject &params, AwbConfig &config);
+
+	/*
+	 * An empty model path makes loadModel() search the default
+	 * locations for awb_model.tflite.
+	 */
+	std::string model;
+	float minTemp;
+	float maxTemp;
+
+	bool enableNn;
+
+	/* CCM matrix for canonical network CT */
+	double ccm[9];
+};
+
+class AwbNN : public Awb
+{
+public:
+	AwbNN(Controller *controller = NULL);
+	~AwbNN();
+	char const *name() const override;
+	void initialise() override;
+	int read(const libcamera::YamlObject &params) override;
+
+protected:
+	void doAwb() override;
+	void prepareStats() override;
+
+private:
+	bool isAutoEnabled() const;
+	AwbNNConfig nnConfig_;
+	void transverseSearch(double t, double &r, double &b);
+	RGB processZone(RGB zone, float redGain, float blueGain);
+	void awbNN();
+	void loadModel();
+
+	libcamera::Size zoneSize_;
+	std::unique_ptr<tflite::FlatBufferModel> model_;
+	std::unique_ptr<tflite::Interpreter> interpreter_;
+};
+
+int AwbNNConfig::read(const libcamera::YamlObject &params, AwbConfig &config)
+{
+	model = params["model"].get<std::string>("");
+	minTemp = params["min_temp"].get<float>(2800.0);
+	maxTemp = params["max_temp"].get<float>(7600.0);
+
+	for (int i = 0; i < 9; i++)
+		ccm[i] = params["ccm"][i].get<double>(0.0);
+
+	enableNn = params["enable_nn"].get<int>(1);
+
+	if (enableNn) {
+		if (!config.hasCtCurve()) {
+			LOG(RPiAwb, Error) << "CT curve not specified";
+			enableNn = false;
+		}
+
+		if (!model.empty() && model.find(".tflite") == std::string::npos) {
+			LOG(RPiAwb, Error) << "Model must be a .tflite file";
+			enableNn = false;
+		}
+
+		bool validCcm = true;
+		for (int i = 0; i < 9; i++)
+			if (ccm[i] == 0.0)
+				validCcm = false;
+
+		if (!validCcm) {
+			LOG(RPiAwb, Error) << "CCM not specified or invalid";
+			enableNn = false;
+		}
+
+		if (!enableNn) {
+		LOG(RPiAwb, Warning) << "Neural Network AWB misconfigured - switching to grey-world method";
+		}
+	}
+
+	if (!enableNn) {
+		config.sensitivityR = config.sensitivityB = 1.0;
+		config.greyWorld = true;
+	}
+
+	return 0;
+}
+
+AwbNN::AwbNN(Controller *controller)
+	: Awb(controller)
+{
+	zoneSize_ = getHardwareConfig().awbRegions;
+}
+
+AwbNN::~AwbNN()
+{
+}
+
+char const *AwbNN::name() const
+{
+	return NAME;
+}
+
+int AwbNN::read(const libcamera::YamlObject &params)
+{
+	int ret;
+
+	ret = config_.read(params);
+	if (ret)
+		return ret;
+
+	ret = nnConfig_.read(params, config_);
+	if (ret)
+		return ret;
+
+	return 0;
+}
+
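+/* Check a tensor's dimension count and sizes against an expected shape. */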
+static bool checkTensorShape(TfLiteTensor *tensor, const int *expectedDims, const int expectedDimsSize)
+{
+	if (tensor->dims->size != expectedDimsSize)
+		return false;
+
+	for (int i = 0; i < tensor->dims->size; i++) {
+		if (tensor->dims->data[i] != expectedDims[i]) {
+			return false;
+		}
+	}
+	return true;
+}
+
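+/* Format a dimension array, e.g. "[1,32,32,3]", for log messages. */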
+static std::string buildDimString(const int *dims, const int dimsSize)
+{
+	std::string s = "[";
+	for (int i = 0; i < dimsSize; i++) {
+		s += std::to_string(dims[i]);
+		if (i < dimsSize - 1)
+			s += ",";
+		else
+			s += "]";
+	}
+	return s;
+}
+
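+/*
+ * Locate the .tflite model, either at the configured path or in a default
+ * per-target location, build a TFLite interpreter for it, and validate the
+ * expected tensor shapes and types. Any failure disables the NN path so
+ * that doAwb() falls back to the grey-world method.
+ */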
+void AwbNN::loadModel()
+{
+	std::string modelPath;
+	if (getTarget() == "bcm2835") {
+		modelPath = "/ipa/rpi/vc4/awb_model.tflite";
+	} else {
+		modelPath = "/ipa/rpi/pisp/awb_model.tflite";
+	}
+
+	if (nnConfig_.model.empty()) {
+		std::string root = utils::libcameraSourcePath();
+		if (!root.empty()) {
+			modelPath = root + modelPath;
+		} else {
+			modelPath = LIBCAMERA_DATA_DIR + modelPath;
+		}
+
+		if (!File::exists(modelPath)) {
+			LOG(RPiAwb, Error) << "No model file found in standard locations";
+			nnConfig_.enableNn = false;
+			return;
+		}
+	} else {
+		modelPath = nnConfig_.model;
+	}
+
+	LOG(RPiAwb, Debug) << "Attempting to load model from: " << modelPath;
+
+	model_ = tflite::FlatBufferModel::BuildFromFile(modelPath.c_str());
+
+	if (!model_) {
+		LOG(RPiAwb, Error) << "Failed to load model from " << modelPath;
+		nnConfig_.enableNn = false;
+		return;
+	}
+
+	tflite::MutableOpResolver resolver;
+	tflite::ops::builtin::BuiltinOpResolver builtinResolver;
+	resolver.AddAll(builtinResolver);
+	tflite::InterpreterBuilder(*model_, resolver)(&interpreter_);
+	if (!interpreter_) {
+		LOG(RPiAwb, Error) << "Failed to build interpreter for model " << nnConfig_.model;
+		nnConfig_.enableNn = false;
+		return;
+	}
+
+	if (interpreter_->AllocateTensors() != kTfLiteOk) {
+		LOG(RPiAwb, Error) << "Failed to allocate tensors for model " << modelPath;
+		nnConfig_.enableNn = false;
+		return;
+	}
+	TfLiteTensor *inputTensor = interpreter_->input_tensor(0);
+	TfLiteTensor *inputLuxTensor = interpreter_->input_tensor(1);
+	TfLiteTensor *outputTensor = interpreter_->output_tensor(0);
+	if (!inputTensor || !inputLuxTensor || !outputTensor) {
+		LOG(RPiAwb, Error) << "Model missing input or output tensor";
+		nnConfig_.enableNn = false;
+		return;
+	}
+
+	const int expectedInputDims[] = { 1, static_cast<int>(zoneSize_.height),
+					  static_cast<int>(zoneSize_.width), 3 };
+	const int expectedInputLuxDims[] = { 1 };
+	const int expectedOutputDims[] = { 1 };
+
+	if (!checkTensorShape(inputTensor, expectedInputDims, 4)) {
+		LOG(RPiAwb, Error) << "Model input tensor dimension mismatch. Expected: " << buildDimString(expectedInputDims, 4)
+				   << ", Got: " << buildDimString(inputTensor->dims->data, inputTensor->dims->size);
+		nnConfig_.enableNn = false;
+		return;
+	}
+
+	if (!checkTensorShape(inputLuxTensor, expectedInputLuxDims, 1)) {
+		LOG(RPiAwb, Error) << "Model input lux tensor dimension mismatch. Expected: " << buildDimString(expectedInputLuxDims, 1)
+				   << ", Got: " << buildDimString(inputLuxTensor->dims->data, inputLuxTensor->dims->size);
+		nnConfig_.enableNn = false;
+		return;
+	}
+
+	if (!checkTensorShape(outputTensor, expectedOutputDims, 1)) {
+		LOG(RPiAwb, Error) << "Model output tensor dimension mismatch. Expected: " << buildDimString(expectedOutputDims, 1)
+				   << ", Got: " << buildDimString(outputTensor->dims->data, outputTensor->dims->size);
+		nnConfig_.enableNn = false;
+		return;
+	}
+
+	if (inputTensor->type != kTfLiteFloat32 || inputLuxTensor->type != kTfLiteFloat32 || outputTensor->type != kTfLiteFloat32) {
+		LOG(RPiAwb, Error) << "Model input and output tensors must be float32";
+		nnConfig_.enableNn = false;
+		return;
+	}
+
+	LOG(RPiAwb, Info) << "Model loaded successfully from " << modelPath;
+	LOG(RPiAwb, Debug) << "Model validation successful - Input Image: "
+			   << buildDimString(expectedInputDims, 4)
+			   << ", Input Lux: " << buildDimString(expectedInputLuxDims, 1)
+			   << ", Output: " << buildDimString(expectedOutputDims, 1) << " floats";
+}
+
+void AwbNN::initialise()
+{
+	Awb::initialise();
+
+	if (nnConfig_.enableNn) {
+		loadModel();
+		if (!nnConfig_.enableNn) {
+			LOG(RPiAwb, Warning) << "Neural Network AWB model failed to load - switching to grey-world method";
+			config_.greyWorld = true;
+			config_.sensitivityR = config_.sensitivityB = 1.0;
+		}
+	}
+}
+
+void AwbNN::prepareStats()
+{
+	zones_.clear();
+	/*
+	 * LSC has already been applied to the stats in this pipeline, so
+	 * disable any further LSC compensation. We also ignore config_.fast
+	 * in this version.
+	 */
+	generateStats(zones_, statistics_, 0.0, 0.0, getGlobalMetadata(), 0.0, 0.0, 0.0);
+	/*
+	 * Apply the sensitivities so that the values appear to come from our
+	 * "canonical" sensor.
+	 */
+	for (auto &zone : zones_) {
+		zone.R *= config_.sensitivityR;
+		zone.B *= config_.sensitivityB;
+	}
+}
+
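+/*
+ * Refine (r, b) for colour temperature t by searching perpendicular to the
+ * CT curve: the curve tangent comes from finite differences of ctR/ctB, the
+ * zone error metric (computeDelta2Sum) is sampled at offsets along the
+ * transverse direction, and a quadratic fit around the best sample
+ * interpolates the final point.
+ */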
+void AwbNN::transverseSearch(double t, double &r, double &b)
+{
+	int spanR = -1, spanB = -1;
+	config_.ctR.eval(t, &spanR);
+	config_.ctB.eval(t, &spanB);
+
+	const int diff = 10;
+	double rDiff = config_.ctR.eval(t + diff, &spanR) -
+		       config_.ctR.eval(t - diff, &spanR);
+	double bDiff = config_.ctB.eval(t + diff, &spanB) -
+		       config_.ctB.eval(t - diff, &spanB);
+
+	ipa::Pwl::Point transverse({ bDiff, -rDiff });
+	if (transverse.length2() < 1e-6)
+		return;
+
+	transverse = transverse / transverse.length();
+	double transverseRange = config_.transverseNeg + config_.transversePos;
+	const int maxNumDeltas = 12;
+	int numDeltas = static_cast<int>(std::floor(transverseRange * 100 + 0.5)) + 1;
+	numDeltas = std::clamp(numDeltas, 3, maxNumDeltas);
+
+	ipa::Pwl::Point points[maxNumDeltas];
+	int bestPoint = 0;
+
+	for (int i = 0; i < numDeltas; i++) {
+		points[i][0] = -config_.transverseNeg +
+			       (transverseRange * i) / (numDeltas - 1);
+		ipa::Pwl::Point rbTest = ipa::Pwl::Point({ r, b }) +
+					 transverse * points[i].x();
+		double rTest = rbTest.x(), bTest = rbTest.y();
+		double gainR = 1 / rTest, gainB = 1 / bTest;
+		double delta2Sum = computeDelta2Sum(gainR, gainB, 0.0, 0.0);
+		points[i][1] = delta2Sum;
+		if (points[i].y() < points[bestPoint].y())
+			bestPoint = i;
+	}
+
+	bestPoint = std::clamp(bestPoint, 1, numDeltas - 2);
+	ipa::Pwl::Point rbBest = ipa::Pwl::Point({ r, b }) +
+				 transverse * interpolateQuadatric(points[bestPoint - 1],
+								   points[bestPoint],
+								   points[bestPoint + 1]);
+	double rBest = rbBest.x(), bBest = rbBest.y();
+
+	r = rBest, b = bBest;
+}
+
+AwbNN::RGB AwbNN::processZone(AwbNN::RGB zone, float redGain, float blueGain)
+{
+	/*
+	 * Render the zone at the canonical network colour temperature by
+	 * applying the canonical white balance gains followed by the CCM.
+	 */
+	RGB zoneGains = zone;
+
+	zoneGains.R *= redGain;
+	zoneGains.G *= 1.0;
+	zoneGains.B *= blueGain;
+
+	RGB zoneCcm;
+
+	zoneCcm.R = nnConfig_.ccm[0] * zoneGains.R + nnConfig_.ccm[1] * zoneGains.G + nnConfig_.ccm[2] * zoneGains.B;
+	zoneCcm.G = nnConfig_.ccm[3] * zoneGains.R + nnConfig_.ccm[4] * zoneGains.G + nnConfig_.ccm[5] * zoneGains.B;
+	zoneCcm.B = nnConfig_.ccm[6] * zoneGains.R + nnConfig_.ccm[7] * zoneGains.G + nnConfig_.ccm[8] * zoneGains.B;
+
+	return zoneCcm;
+}
+
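+/*
+ * Render the zone statistics as they would appear at the canonical 5000K
+ * colour temperature, feed them and the lux estimate to the network, and
+ * convert the predicted colour temperature back into red/blue gains via
+ * the CT curves, refined by a transverse search.
+ */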
+void AwbNN::awbNN()
+{
+	float *inputData = interpreter_->typed_input_tensor<float>(0);
+	float *inputLux = interpreter_->typed_input_tensor<float>(1);
+
+	float redGain = 1.0 / config_.ctR.eval(kNetworkCanonicalCT);
+	float blueGain = 1.0 / config_.ctB.eval(kNetworkCanonicalCT);
+
+	for (unsigned int i = 0; i < zoneSize_.height; i++) {
+		for (unsigned int j = 0; j < zoneSize_.width; j++) {
+			unsigned int zoneIdx = i * zoneSize_.width + j;
+
+			RGB processedZone = processZone(zones_[zoneIdx] * (1.0 / 65535), redGain, blueGain);
+			unsigned int baseIdx = zoneIdx * 3;
+
+			inputData[baseIdx + 0] = static_cast<float>(processedZone.R);
+			inputData[baseIdx + 1] = static_cast<float>(processedZone.G);
+			inputData[baseIdx + 2] = static_cast<float>(processedZone.B);
+		}
+	}
+
+	inputLux[0] = static_cast<float>(lux_);
+
+	TfLiteStatus status = interpreter_->Invoke();
+	if (status != kTfLiteOk) {
+		LOG(RPiAwb, Error) << "Model inference failed with status: " << status;
+		return;
+	}
+
+	float *outputData = interpreter_->typed_output_tensor<float>(0);
+
+	double t = outputData[0];
+
+	LOG(RPiAwb, Debug) << "Model output temperature: " << t;
+
+	t = std::clamp(t, mode_->ctLo, mode_->ctHi);
+
+	double r = config_.ctR.eval(t);
+	double b = config_.ctB.eval(t);
+
+	transverseSearch(t, r, b);
+
+	LOG(RPiAwb, Debug) << "After transverse search: Temperature: " << t << " Red gain: " << 1.0 / r << " Blue gain: " << 1.0 / b;
+
+	asyncResults_.temperatureK = t;
+	asyncResults_.gainR = 1.0 / r * config_.sensitivityR;
+	asyncResults_.gainG = 1.0;
+	asyncResults_.gainB = 1.0 / b * config_.sensitivityB;
+}
+
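+/*
+ * Run the network only when it is enabled and a full set of zone
+ * statistics is available; otherwise fall back to the grey-world method.
+ */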
+void AwbNN::doAwb()
+{
+	prepareStats();
+	if (zones_.size() == (zoneSize_.width * zoneSize_.height) && nnConfig_.enableNn)
+		awbNN();
+	else
+		awbGrey();
+	statistics_.reset();
+}
+
+/* Register algorithm with the system. */
+static Algorithm *create(Controller *controller)
+{
+	return (Algorithm *)new AwbNN(controller);
+}
+static RegisterAlgorithm reg(NAME, &create);
+
+} /* namespace RPiController */