[2/4] ipa: rpi: controller: awb: Add Neural Network Awb

Message ID 20251024144049.3311-3-david.plowman@raspberrypi.com
State Superseded
Series: Raspberry Pi AWB using neural networks

Commit Message

David Plowman Oct. 24, 2025, 2:16 p.m. UTC
From: Peter Bailey <peter.bailey@raspberrypi.com>

Add an Awb algorithm which uses neural networks.

Signed-off-by: Peter Bailey <peter.bailey@raspberrypi.com>
---
 src/ipa/rpi/controller/meson.build    |   9 +
 src/ipa/rpi/controller/rpi/awb_nn.cpp | 442 ++++++++++++++++++++++++++
 2 files changed, 451 insertions(+)
 create mode 100644 src/ipa/rpi/controller/rpi/awb_nn.cpp

Comments

Naushir Patuck Dec. 10, 2025, 2:08 p.m. UTC | #1
Hi David and Peter,

Only minor styling comments below:

On Fri, 24 Oct 2025 at 15:41, David Plowman
<david.plowman@raspberrypi.com> wrote:
>
> From: Peter Bailey <peter.bailey@raspberrypi.com>
>
> Add an Awb algorithm which uses neural networks.
>
> Signed-off-by: Peter Bailey <peter.bailey@raspberrypi.com>
> ---
>  src/ipa/rpi/controller/meson.build    |   9 +
>  src/ipa/rpi/controller/rpi/awb_nn.cpp | 442 ++++++++++++++++++++++++++
>  2 files changed, 451 insertions(+)
>  create mode 100644 src/ipa/rpi/controller/rpi/awb_nn.cpp
>
> diff --git a/src/ipa/rpi/controller/meson.build b/src/ipa/rpi/controller/meson.build
> index 73c93dca..2541d073 100644
> --- a/src/ipa/rpi/controller/meson.build
> +++ b/src/ipa/rpi/controller/meson.build
> @@ -32,6 +32,15 @@ rpi_ipa_controller_deps = [
>      libcamera_private,
>  ]
>
> +tflite_dep = dependency('tensorflow-lite', required : false)
> +
> +if tflite_dep.found()
> +    rpi_ipa_controller_sources += files([
> +        'rpi/awb_nn.cpp',
> +    ])
> +    rpi_ipa_controller_deps += tflite_dep
> +endif
> +
>  rpi_ipa_controller_lib = static_library('rpi_ipa_controller', rpi_ipa_controller_sources,
>                                          include_directories : libipa_includes,
>                                          dependencies : rpi_ipa_controller_deps)
> diff --git a/src/ipa/rpi/controller/rpi/awb_nn.cpp b/src/ipa/rpi/controller/rpi/awb_nn.cpp
> new file mode 100644
> index 00000000..c309ca3f
> --- /dev/null
> +++ b/src/ipa/rpi/controller/rpi/awb_nn.cpp
> @@ -0,0 +1,442 @@
> +/* SPDX-License-Identifier: BSD-2-Clause */
> +/*
> + * Copyright (C) 2025, Raspberry Pi Ltd
> + *
> + * AWB control algorithm using neural network
> + */
> +
> +#include <chrono>
> +#include <condition_variable>
> +#include <thread>
> +
> +#include <libcamera/base/file.h>
> +#include <libcamera/base/log.h>
> +
> +#include <tensorflow/lite/interpreter.h>
> +#include <tensorflow/lite/kernels/register.h>
> +#include <tensorflow/lite/model.h>
> +
> +#include "../awb_algorithm.h"
> +#include "../awb_status.h"
> +#include "../lux_status.h"
> +#include "libipa/pwl.h"
> +
> +#include "alsc_status.h"
> +#include "awb.h"
> +
> +using namespace libcamera;
> +
> +LOG_DECLARE_CATEGORY(RPiAwb)
> +
> +constexpr double kDefaultCT = 4500.0;
> +
> +#define NAME "rpi.nn.awb"
> +
> +namespace RPiController {
> +
> +struct AwbNNConfig {
> +       AwbNNConfig() {}
> +       int read(const libcamera::YamlObject &params, AwbConfig &config);
> +
> +       /* An empty model will check default locations for model.tflite */
> +       std::string model;
> +       float minTemp;
> +       float maxTemp;
> +
> +       bool enableNn;
> +
> +       /* CCM matrix for 5000K temperature */
> +       double ccm[9];
> +};
> +
> +class AwbNN : public Awb
> +{
> +public:
> +       AwbNN(Controller *controller = NULL);
> +       ~AwbNN();
> +       char const *name() const override;
> +       void initialise() override;
> +       int read(const libcamera::YamlObject &params) override;
> +
> +protected:
> +       void doAwb() override;
> +       void prepareStats() override;
> +
> +private:
> +       bool isAutoEnabled() const;
> +       AwbNNConfig nnConfig_;
> +       void transverseSearch(double t, double &r, double &b);
> +       RGB processZone(RGB zone, float red_gain, float blue_gain);
> +       void awbNN();
> +       void loadModel();
> +
> +       libcamera::Size zoneSize_;
> +       std::unique_ptr<tflite::FlatBufferModel> model_;
> +       std::unique_ptr<tflite::Interpreter> interpreter_;
> +};
> +
> +int AwbNNConfig::read(const libcamera::YamlObject &params, AwbConfig &config)
> +{
> +       model = params["model"].get<std::string>("");
> +       minTemp = params["min_temp"].get<float>(2800.0);
> +       maxTemp = params["max_temp"].get<float>(7600.0);
> +
> +       for (int i = 0; i < 9; i++)
> +               ccm[i] = params["ccm"][i].get<double>(0.0);
> +
> +       enableNn = params["enable_nn"].get<int>(1);
> +
> +       if (enableNn) {
> +               if (!config.hasCtCurve()) {
> +                       LOG(RPiAwb, Error) << "CT curve not specified";
> +                       enableNn = false;
> +               }
> +
> +               if (!model.empty() && model.find(".tflite") == std::string::npos) {
> +                       LOG(RPiAwb, Error) << "Model must be a .tflite file";
> +                       enableNn = false;
> +               }
> +
> +               bool validCcm = true;
> +               for (int i = 0; i < 9; i++)
> +                       if (ccm[i] == 0.0)
> +                               validCcm = false;
> +
> +               if (!validCcm) {
> +                       LOG(RPiAwb, Error) << "CCM not specified or invalid";
> +                       enableNn = false;
> +               }
> +
> +               if (!enableNn) {
> +                       LOG(RPiAwb, Warning) << "Neural Network AWB mis-configured - switch to Grey method";
> +               }
> +       }
> +
> +       if (!enableNn) {
> +               config.sensitivityR = config.sensitivityB = 1.0;
> +               config.greyWorld = true;
> +       }
> +
> +       return 0;
> +}
> +
> +AwbNN::AwbNN(Controller *controller)
> +       : Awb(controller)
> +{
> +       zoneSize_ = getHardwareConfig().awbRegions;
> +}
> +
> +AwbNN::~AwbNN()
> +{
> +}
> +
> +char const *AwbNN::name() const
> +{
> +       return NAME;
> +}
> +
> +int AwbNN::read(const libcamera::YamlObject &params)
> +{
> +       int ret;
> +
> +       ret = config_.read(params);
> +       if (ret)
> +               return ret;
> +
> +       ret = nnConfig_.read(params, config_);
> +       if (ret)
> +               return ret;
> +
> +       return 0;
> +}
> +
> +static bool checkTensorShape(TfLiteTensor *tensor, const int *expectedDims, const int expectedDimsSize)
> +{
> +       if (tensor->dims->size != expectedDimsSize) {
> +               return false;
> +       }
> +
> +       for (int i = 0; i < tensor->dims->size; i++) {
> +               if (tensor->dims->data[i] != expectedDims[i]) {
> +                       return false;
> +               }

No need for braces here and above.
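
I.e., something like:

    if (tensor->dims->size != expectedDimsSize)
            return false;

    for (int i = 0; i < tensor->dims->size; i++)
            if (tensor->dims->data[i] != expectedDims[i])
                    return false;

    return true;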

> +       }
> +       return true;
> +}
> +
> +static std::string buildDimString(const int *dims, const int dimsSize)
> +{
> +       std::string s = "[";
> +       for (int i = 0; i < dimsSize; i++) {
> +               s += std::to_string(dims[i]);
> +               if (i < dimsSize - 1)
> +                       s += ",";
> +               else
> +                       s += "]";
> +       }
> +       return s;
> +}
> +
> +void AwbNN::loadModel()
> +{
> +       std::string modelPath;
> +       if (getTarget() == "bcm2835") {
> +               modelPath = "/ipa/rpi/vc4/awb_model.tflite";
> +       } else {
> +               modelPath = "/ipa/rpi/pisp/awb_model.tflite";
> +       }
> +
> +       if (nnConfig_.model.empty()) {
> +               std::string root = utils::libcameraSourcePath();
> +               if (!root.empty()) {
> +                       modelPath = root + modelPath;
> +               } else {
> +                       modelPath = LIBCAMERA_DATA_DIR + modelPath;
> +               }
> +
> +               if (!File::exists(modelPath)) {
> +                       LOG(RPiAwb, Error) << "No model file found in standard locations";
> +                       nnConfig_.enableNn = false;
> +                       return;
> +               }
> +       } else {
> +               modelPath = nnConfig_.model;
> +       }
> +
> +       LOG(RPiAwb, Debug) << "Attempting to load model from: " << modelPath;
> +
> +       model_ = tflite::FlatBufferModel::BuildFromFile(modelPath.c_str());
> +
> +       if (!model_) {
> +               LOG(RPiAwb, Error) << "Failed to load model from " << modelPath;
> +               nnConfig_.enableNn = false;
> +               return;
> +       }
> +
> +       tflite::MutableOpResolver resolver;
> +       tflite::ops::builtin::BuiltinOpResolver builtin_resolver;
> +       resolver.AddAll(builtin_resolver);
> +       tflite::InterpreterBuilder(*model_, resolver)(&interpreter_);
> +       if (!interpreter_) {
> +               LOG(RPiAwb, Error) << "Failed to build interpreter for model " << nnConfig_.model;
> +               nnConfig_.enableNn = false;
> +               return;
> +       }
> +
> +       interpreter_->AllocateTensors();
> +       TfLiteTensor *inputTensor = interpreter_->input_tensor(0);
> +       TfLiteTensor *inputLuxTensor = interpreter_->input_tensor(1);
> +       TfLiteTensor *outputTensor = interpreter_->output_tensor(0);
> +       if (!inputTensor || !inputLuxTensor || !outputTensor) {
> +               LOG(RPiAwb, Error) << "Model missing input or output tensor";
> +               nnConfig_.enableNn = false;
> +               return;
> +       }
> +
> +       const int expectedInputDims[] = { 1, (int)zoneSize_.height, (int)zoneSize_.width, 3 };
> +       const int expectedInputLuxDims[] = { 1 };
> +       const int expectedOutputDims[] = { 1 };
> +
> +       if (!checkTensorShape(inputTensor, expectedInputDims, 4)) {
> +               LOG(RPiAwb, Error) << "Model input tensor dimension mismatch. Expected: " << buildDimString(expectedInputDims, 4)
> +                                  << ", Got: " << buildDimString(inputTensor->dims->data, inputTensor->dims->size);
> +               nnConfig_.enableNn = false;
> +               return;
> +       }
> +
> +       if (!checkTensorShape(inputLuxTensor, expectedInputLuxDims, 1)) {
> +               LOG(RPiAwb, Error) << "Model input lux tensor dimension mismatch. Expected: " << buildDimString(expectedInputLuxDims, 1)
> +                                  << ", Got: " << buildDimString(inputLuxTensor->dims->data, inputLuxTensor->dims->size);
> +               nnConfig_.enableNn = false;
> +               return;
> +       }
> +
> +       if (!checkTensorShape(outputTensor, expectedOutputDims, 1)) {
> +               LOG(RPiAwb, Error) << "Model output tensor dimension mismatch. Expected: " << buildDimString(expectedOutputDims, 1)
> +                                  << ", Got: " << buildDimString(outputTensor->dims->data, outputTensor->dims->size);
> +               nnConfig_.enableNn = false;
> +               return;
> +       }
> +
> +       if (inputTensor->type != kTfLiteFloat32 || inputLuxTensor->type != kTfLiteFloat32 || outputTensor->type != kTfLiteFloat32) {
> +               LOG(RPiAwb, Error) << "Model input and output tensors must be float32";
> +               nnConfig_.enableNn = false;
> +               return;
> +       }
> +
> +       LOG(RPiAwb, Info) << "Model loaded successfully from " << modelPath;
> +       LOG(RPiAwb, Debug) << "Model validation successful - Input Image: "
> +                          << buildDimString(expectedInputDims, 4)
> +                          << ", Input Lux: " << buildDimString(expectedInputLuxDims, 1)
> +                          << ", Output: " << buildDimString(expectedOutputDims, 1) << " floats";
> +}
> +
> +void AwbNN::initialise()
> +{
> +       Awb::initialise();
> +
> +       if (nnConfig_.enableNn) {
> +               loadModel();
> +               if (!nnConfig_.enableNn) {
> +                       LOG(RPiAwb, Warning) << "Neural Network AWB failed to load - switch to Grey method";
> +                       config_.greyWorld = true;
> +                       config_.sensitivityR = config_.sensitivityB = 1.0;
> +               }
> +       }
> +}
> +
> +void AwbNN::prepareStats()
> +{
> +       zones_.clear();
> +       /*
> +        * LSC has already been applied to the stats in this pipeline, so stop
> +        * any LSC compensation.  We also ignore config_.fast in this version.
> +        */
> +       generateStats(zones_, statistics_, 0.0, 0.0, getGlobalMetadata(), 0.0, 0.0, 0.0);
> +       /*
> +        * apply sensitivities, so values appear to come from our "canonical"
> +        * sensor.
> +        */
> +       for (auto &zone : zones_) {
> +               zone.R *= config_.sensitivityR;
> +               zone.B *= config_.sensitivityB;
> +       }
> +}
> +
> +void AwbNN::transverseSearch(double t, double &r, double &b)
> +{
> +       int spanR = -1, spanB = -1;
> +       config_.ctR.eval(t, &spanR);
> +       config_.ctB.eval(t, &spanB);
> +
> +       const int diff = 10;
> +       double rDiff = config_.ctR.eval(t + diff, &spanR) -
> +                      config_.ctR.eval(t - diff, &spanR);
> +       double bDiff = config_.ctB.eval(t + diff, &spanB) -
> +                      config_.ctB.eval(t - diff, &spanB);
> +
> +       ipa::Pwl::Point transverse({ bDiff, -rDiff });
> +       if (transverse.length2() < 1e-6)
> +               return;
> +
> +       transverse = transverse / transverse.length();
> +       double transverseRange = config_.transverseNeg + config_.transversePos;
> +       const int maxNumDeltas = 12;
> +       int numDeltas = floor(transverseRange * 100 + 0.5) + 1;
> +       numDeltas = numDeltas < 3 ? 3 : (numDeltas > maxNumDeltas ? maxNumDeltas : numDeltas);
> +
> +       ipa::Pwl::Point points[maxNumDeltas];
> +       int bestPoint = 0;
> +
> +       for (int i = 0; i < numDeltas; i++) {
> +               points[i][0] = -config_.transverseNeg +
> +                              (transverseRange * i) / (numDeltas - 1);
> +               ipa::Pwl::Point rbTest = ipa::Pwl::Point({ r, b }) +
> +                                        transverse * points[i].x();
> +               double rTest = rbTest.x(), bTest = rbTest.y();
> +               double gainR = 1 / rTest, gainB = 1 / bTest;
> +               double delta2Sum = computeDelta2Sum(gainR, gainB, 0.0, 0.0);
> +               points[i][1] = delta2Sum;
> +               if (points[i].y() < points[bestPoint].y())
> +                       bestPoint = i;
> +       }
> +
> +       bestPoint = std::clamp(bestPoint, 1, numDeltas - 2);
> +       ipa::Pwl::Point rbBest = ipa::Pwl::Point({ r, b }) +
> +                                transverse * interpolateQuadatric(points[bestPoint - 1],
> +                                                                  points[bestPoint],
> +                                                                  points[bestPoint + 1]);
> +       double rBest = rbBest.x(), bBest = rbBest.y();
> +
> +       r = rBest, b = bBest;
> +}
> +
> +AwbNN::RGB AwbNN::processZone(AwbNN::RGB zone, float redGain, float blueGain)
> +{
> +       /*
> +        * Renders the pixel at 5000K temperature
> +        */
> +       RGB zoneGains = zone;
> +
> +       zoneGains.R *= redGain;
> +       zoneGains.G *= 1.0;
> +       zoneGains.B *= blueGain;
> +
> +       RGB zoneCcm;
> +
> +       zoneCcm.R = nnConfig_.ccm[0] * zoneGains.R + nnConfig_.ccm[1] * zoneGains.G + nnConfig_.ccm[2] * zoneGains.B;
> +       zoneCcm.G = nnConfig_.ccm[3] * zoneGains.R + nnConfig_.ccm[4] * zoneGains.G + nnConfig_.ccm[5] * zoneGains.B;
> +       zoneCcm.B = nnConfig_.ccm[6] * zoneGains.R + nnConfig_.ccm[7] * zoneGains.G + nnConfig_.ccm[8] * zoneGains.B;
> +
> +       return zoneCcm;
> +}
> +
> +void AwbNN::awbNN()
> +{
> +       float *inputData = interpreter_->typed_input_tensor<float>(0);
> +       float *inputLux = interpreter_->typed_input_tensor<float>(1);
> +
> +       float redGain = 1.0 / config_.ctR.eval(5000);
> +       float blueGain = 1.0 / config_.ctB.eval(5000);

Maybe the 5000K here and elsewhere needs to be a constant
kCanonicalTemp or something?
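
For instance, as a sketch (kCanonicalTemp is only the suggested name,
not something the patch defines):

    constexpr double kCanonicalTemp = 5000.0;

    float redGain = 1.0 / config_.ctR.eval(kCanonicalTemp);
    float blueGain = 1.0 / config_.ctB.eval(kCanonicalTemp);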

> +
> +       for (uint i = 0; i < zoneSize_.height; i++) {
> +               for (uint j = 0; j < zoneSize_.width; j++) {
> +                       uint zoneIdx = i * zoneSize_.width + j;
> +
> +                       RGB processedZone = processZone(zones_[zoneIdx] * (1.0 / 65535), redGain, blueGain);
> +                       uint baseIdx = zoneIdx * 3;
> +
> +                       inputData[baseIdx + 0] = static_cast<float>(processedZone.R);
> +                       inputData[baseIdx + 1] = static_cast<float>(processedZone.G);
> +                       inputData[baseIdx + 2] = static_cast<float>(processedZone.B);
> +               }
> +       }
> +
> +       inputLux[0] = static_cast<float>(lux_);
> +
> +       TfLiteStatus status = interpreter_->Invoke();
> +       if (status != kTfLiteOk) {
> +               LOG(RPiAwb, Error) << "Model inference failed with status: " << status;
> +               return;
> +       }
> +
> +       float *outputData = interpreter_->typed_output_tensor<float>(0);
> +
> +       double t = outputData[0];
> +
> +       LOG(RPiAwb, Debug) << "Model output temperature: " << t;
> +
> +       t = std::clamp(t, mode_->ctLo, mode_->ctHi);
> +
> +       double r = config_.ctR.eval(t);
> +       double b = config_.ctB.eval(t);
> +
> +       transverseSearch(t, r, b);
> +
> +       LOG(RPiAwb, Debug) << "After transverse search: Temperature: " << t << " Red gain: " << 1.0 / r << " Blue gain: " << 1.0 / b;
> +
> +       asyncResults_.temperatureK = t;
> +       asyncResults_.gainR = 1.0 / r * config_.sensitivityR;
> +       asyncResults_.gainG = 1.0;
> +       asyncResults_.gainB = 1.0 / b * config_.sensitivityB;
> +}
> +
> +void AwbNN::doAwb()
> +{
> +       prepareStats();
> +       if (zones_.size() == (zoneSize_.width * zoneSize_.height) && nnConfig_.enableNn) {
> +               awbNN();
> +       } else {
> +               awbGrey();
> +       }

No need for the braces here.
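
I.e.:

    if (zones_.size() == (zoneSize_.width * zoneSize_.height) && nnConfig_.enableNn)
            awbNN();
    else
            awbGrey();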

Reviewed-by: Naushir Patuck <naush@raspberrypi.com>

> +       statistics_.reset();
> +}
> +
> +/* Register algorithm with the system. */
> +static Algorithm *create(Controller *controller)
> +{
> +       return (Algorithm *)new AwbNN(controller);
> +}
> +static RegisterAlgorithm reg(NAME, &create);
> +
> +} /* namespace RPiController */
> --
> 2.47.3
>
Laurent Pinchart Dec. 11, 2025, 8:21 a.m. UTC | #2
Hi David,

Thank you for the patch.

On Fri, Oct 24, 2025 at 03:16:01PM +0100, David Plowman wrote:
> From: Peter Bailey <peter.bailey@raspberrypi.com>
> 
> Add an Awb algorithm which uses neural networks.
> 
> Signed-off-by: Peter Bailey <peter.bailey@raspberrypi.com>
> ---
>  src/ipa/rpi/controller/meson.build    |   9 +
>  src/ipa/rpi/controller/rpi/awb_nn.cpp | 442 ++++++++++++++++++++++++++
>  2 files changed, 451 insertions(+)
>  create mode 100644 src/ipa/rpi/controller/rpi/awb_nn.cpp
> 
> diff --git a/src/ipa/rpi/controller/meson.build b/src/ipa/rpi/controller/meson.build
> index 73c93dca..2541d073 100644
> --- a/src/ipa/rpi/controller/meson.build
> +++ b/src/ipa/rpi/controller/meson.build
> @@ -32,6 +32,15 @@ rpi_ipa_controller_deps = [
>      libcamera_private,
>  ]
>  
> +tflite_dep = dependency('tensorflow-lite', required : false)

This needs a configuration option, otherwise libcamera may accidentally
depend on tensorflow-lite when not desired. Package maintainers for
distributions hate those accidental dependencies that make builds
non-reproducible.
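
For example, a sketch (the option name 'awb_nn' and its placement in
meson_options.txt are assumptions, not part of this patch):

    option('awb_nn',
           type : 'feature',
           value : 'auto',
           description : 'Enable the neural network based AWB algorithm')

with the dependency lookup then becoming:

    tflite_dep = dependency('tensorflow-lite', required : get_option('awb_nn'))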

> [snip]
> +	TfLiteStatus status = interpreter_->Invoke();

How long does this typically take?

> [snip]
David Plowman Dec. 11, 2025, 3:56 p.m. UTC | #3
Hi Laurent,

Thanks for the comments.

On Thu, 11 Dec 2025 at 08:21, Laurent Pinchart
<laurent.pinchart@ideasonboard.com> wrote:
>
> Hi David,
>
> Thank you for the patch.
>
> On Fri, Oct 24, 2025 at 03:16:01PM +0100, David Plowman wrote:
> > From: Peter Bailey <peter.bailey@raspberrypi.com>
> >
> > Add an Awb algorithm which uses neural networks.
> >
> > Signed-off-by: Peter Bailey <peter.bailey@raspberrypi.com>
> > ---
> >  src/ipa/rpi/controller/meson.build    |   9 +
> >  src/ipa/rpi/controller/rpi/awb_nn.cpp | 442 ++++++++++++++++++++++++++
> >  2 files changed, 451 insertions(+)
> >  create mode 100644 src/ipa/rpi/controller/rpi/awb_nn.cpp
> >
> > diff --git a/src/ipa/rpi/controller/meson.build b/src/ipa/rpi/controller/meson.build
> > index 73c93dca..2541d073 100644
> > --- a/src/ipa/rpi/controller/meson.build
> > +++ b/src/ipa/rpi/controller/meson.build
> > @@ -32,6 +32,15 @@ rpi_ipa_controller_deps = [
> >      libcamera_private,
> >  ]
> >
> > +tflite_dep = dependency('tensorflow-lite', required : false)
>
> This needs a configuration option, otherwise libcamera may accidentally
> depend on tensorflow-lite when not desired. Package maintainers for
> distributions hate those accidental dependencies that make builds
> non-reproducible.

Will do. Thought I'd done it for v2 but failed to save the file! Will be in v3.

>
> > [snip]
> > +     TfLiteStatus status = interpreter_->Invoke();
>
> How long does this typically take?

Because they run on the image statistics, they're actually tiny models
(compared to most image processing/analysis models). I forget exactly,
but it's down at the millisecond or two level on a Pi 5, so "a few
times slower" on other Pis.
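
One way to confirm would be a quick sketch around the Invoke() call,
using the <chrono> header the file already includes:

    auto start = std::chrono::steady_clock::now();
    TfLiteStatus status = interpreter_->Invoke();
    auto us = std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::steady_clock::now() - start);
    LOG(RPiAwb, Debug) << "Inference took " << us.count() << "us";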

Thanks
David

> [snip]
Naushir Patuck Dec. 11, 2025, 3:59 p.m. UTC | #4
On Thu, 11 Dec 2025 at 15:56, David Plowman
<david.plowman@raspberrypi.com> wrote:
>
> Hi Laruent
>
> Thanks for the comments.
>
> On Thu, 11 Dec 2025 at 08:21, Laurent Pinchart
> <laurent.pinchart@ideasonboard.com> wrote:
> >
> > Hi David,
> >
> > Thank you for the patch.
> >
> > On Fri, Oct 24, 2025 at 03:16:01PM +0100, David Plowman wrote:
> > > From: Peter Bailey <peter.bailey@raspberrypi.com>
> > >
> > > Add an Awb algorithm which uses neural networks.
> > >
> > > Signed-off-by: Peter Bailey <peter.bailey@raspberrypi.com>
> > > ---
> > >  src/ipa/rpi/controller/meson.build    |   9 +
> > >  src/ipa/rpi/controller/rpi/awb_nn.cpp | 442 ++++++++++++++++++++++++++
> > >  2 files changed, 451 insertions(+)
> > >  create mode 100644 src/ipa/rpi/controller/rpi/awb_nn.cpp
> > >
> > > diff --git a/src/ipa/rpi/controller/meson.build b/src/ipa/rpi/controller/meson.build
> > > index 73c93dca..2541d073 100644
> > > --- a/src/ipa/rpi/controller/meson.build
> > > +++ b/src/ipa/rpi/controller/meson.build
> > > @@ -32,6 +32,15 @@ rpi_ipa_controller_deps = [
> > >      libcamera_private,
> > >  ]
> > >
> > > +tflite_dep = dependency('tensorflow-lite', required : false)
> >
> > This needs a configuration option, otherwise libcamera may accidentally
> > depend on tensorflow-lite when not desired. Package maintainers for
> > distributions hate those accidental dependencies that make builds
> > non-reproducible.
>
> Will do. Thought I'd done it for v2 but failed to save the file! Will be in v3.
>
> >
> > > [snip]
> > > +
> > > +     inputLux[0] = static_cast<float>(lux_);
> > > +
> > > +     TfLiteStatus status = interpreter_->Invoke();
> >
> > How long does this typically take?
>
> Because they run on the image statistics, they're actually tiny models
> (compared to most image processing/analysis models). I forget exactly,
> but it's down at the millisecond-or-two level on a Pi 5, and perhaps a
> few times slower on other Pis.

If I recall correctly, the runtime was around 100-200us per frame on a Pi 5.
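
In case anyone wants to reproduce that measurement, it's easy to time
the Invoke() call in awbNN(). A rough sketch (the timing code is
illustrative only and not part of the patch):

    #include <chrono>

    auto start = std::chrono::steady_clock::now();
    TfLiteStatus status = interpreter_->Invoke();
    auto stop = std::chrono::steady_clock::now();

    /* Report the per-frame inference time in microseconds. */
    auto us = std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count();
    LOG(RPiAwb, Debug) << "Model inference took " << us << "us";

The exact numbers will of course vary with the model and the size of
the zone grid.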

>
> Thanks
> David
>
> >
> > > [snip]
> >
> > --
> > Regards,
> >
> > Laurent Pinchart

Patch
diff mbox series

diff --git a/src/ipa/rpi/controller/meson.build b/src/ipa/rpi/controller/meson.build
index 73c93dca..2541d073 100644
--- a/src/ipa/rpi/controller/meson.build
+++ b/src/ipa/rpi/controller/meson.build
@@ -32,6 +32,15 @@  rpi_ipa_controller_deps = [
     libcamera_private,
 ]
 
+tflite_dep = dependency('tensorflow-lite', required : false)
+
+if tflite_dep.found()
+    rpi_ipa_controller_sources += files([
+        'rpi/awb_nn.cpp',
+    ])
+    rpi_ipa_controller_deps += tflite_dep
+endif
+
 rpi_ipa_controller_lib = static_library('rpi_ipa_controller', rpi_ipa_controller_sources,
                                         include_directories : libipa_includes,
                                         dependencies : rpi_ipa_controller_deps)
diff --git a/src/ipa/rpi/controller/rpi/awb_nn.cpp b/src/ipa/rpi/controller/rpi/awb_nn.cpp
new file mode 100644
index 00000000..c309ca3f
--- /dev/null
+++ b/src/ipa/rpi/controller/rpi/awb_nn.cpp
@@ -0,0 +1,442 @@ 
+/* SPDX-License-Identifier: BSD-2-Clause */
+/*
+ * Copyright (C) 2025, Raspberry Pi Ltd
+ *
+ * AWB control algorithm using neural network
+ */
+
+#include <chrono>
+#include <condition_variable>
+#include <thread>
+
+#include <libcamera/base/file.h>
+#include <libcamera/base/log.h>
+
+#include <tensorflow/lite/interpreter.h>
+#include <tensorflow/lite/kernels/register.h>
+#include <tensorflow/lite/model.h>
+
+#include "../awb_algorithm.h"
+#include "../awb_status.h"
+#include "../lux_status.h"
+#include "libipa/pwl.h"
+
+#include "alsc_status.h"
+#include "awb.h"
+
+using namespace libcamera;
+
+LOG_DECLARE_CATEGORY(RPiAwb)
+
+constexpr double kDefaultCT = 4500.0;
+
+#define NAME "rpi.nn.awb"
+
+namespace RPiController {
+
+struct AwbNNConfig {
+	AwbNNConfig() {}
+	int read(const libcamera::YamlObject &params, AwbConfig &config);
+
+	/* An empty model will check default locations for model.tflite */
+	std::string model;
+	float minTemp;
+	float maxTemp;
+
+	bool enableNn;
+
+	/* CCM matrix for 5000K temperature */
+	double ccm[9];
+};
+
+class AwbNN : public Awb
+{
+public:
+	AwbNN(Controller *controller = NULL);
+	~AwbNN();
+	char const *name() const override;
+	void initialise() override;
+	int read(const libcamera::YamlObject &params) override;
+
+protected:
+	void doAwb() override;
+	void prepareStats() override;
+
+private:
+	bool isAutoEnabled() const;
+	AwbNNConfig nnConfig_;
+	void transverseSearch(double t, double &r, double &b);
+	RGB processZone(RGB zone, float red_gain, float blue_gain);
+	void awbNN();
+	void loadModel();
+
+	libcamera::Size zoneSize_;
+	std::unique_ptr<tflite::FlatBufferModel> model_;
+	std::unique_ptr<tflite::Interpreter> interpreter_;
+};
+
+int AwbNNConfig::read(const libcamera::YamlObject &params, AwbConfig &config)
+{
+	model = params["model"].get<std::string>("");
+	minTemp = params["min_temp"].get<float>(2800.0);
+	maxTemp = params["max_temp"].get<float>(7600.0);
+
+	for (int i = 0; i < 9; i++)
+		ccm[i] = params["ccm"][i].get<double>(0.0);
+
+	enableNn = params["enable_nn"].get<int>(1);
+
+	if (enableNn) {
+		if (!config.hasCtCurve()) {
+			LOG(RPiAwb, Error) << "CT curve not specified";
+			enableNn = false;
+		}
+
+		if (!model.empty() && model.find(".tflite") == std::string::npos) {
+			LOG(RPiAwb, Error) << "Model must be a .tflite file";
+			enableNn = false;
+		}
+
+		bool validCcm = true;
+		for (int i = 0; i < 9; i++)
+			if (ccm[i] == 0.0)
+				validCcm = false;
+
+		if (!validCcm) {
+			LOG(RPiAwb, Error) << "CCM not specified or invalid";
+			enableNn = false;
+		}
+
+		if (!enableNn) {
+			LOG(RPiAwb, Warning) << "Neural Network AWB mis-configured - switch to Grey method";
+		}
+	}
+
+	if (!enableNn) {
+		config.sensitivityR = config.sensitivityB = 1.0;
+		config.greyWorld = true;
+	}
+
+	return 0;
+}
+
+AwbNN::AwbNN(Controller *controller)
+	: Awb(controller)
+{
+	zoneSize_ = getHardwareConfig().awbRegions;
+}
+
+AwbNN::~AwbNN()
+{
+}
+
+char const *AwbNN::name() const
+{
+	return NAME;
+}
+
+int AwbNN::read(const libcamera::YamlObject &params)
+{
+	int ret;
+
+	ret = config_.read(params);
+	if (ret)
+		return ret;
+
+	ret = nnConfig_.read(params, config_);
+	if (ret)
+		return ret;
+
+	return 0;
+}
+
+static bool checkTensorShape(TfLiteTensor *tensor, const int *expectedDims, const int expectedDimsSize)
+{
+	if (tensor->dims->size != expectedDimsSize) {
+		return false;
+	}
+
+	for (int i = 0; i < tensor->dims->size; i++) {
+		if (tensor->dims->data[i] != expectedDims[i]) {
+			return false;
+		}
+	}
+	return true;
+}
+
+static std::string buildDimString(const int *dims, const int dimsSize)
+{
+	std::string s = "[";
+	for (int i = 0; i < dimsSize; i++) {
+		s += std::to_string(dims[i]);
+		if (i < dimsSize - 1)
+			s += ",";
+		else
+			s += "]";
+	}
+	return s;
+}
+
+void AwbNN::loadModel()
+{
+	std::string modelPath;
+	if (getTarget() == "bcm2835") {
+		modelPath = "/ipa/rpi/vc4/awb_model.tflite";
+	} else {
+		modelPath = "/ipa/rpi/pisp/awb_model.tflite";
+	}
+
+	if (nnConfig_.model.empty()) {
+		std::string root = utils::libcameraSourcePath();
+		if (!root.empty()) {
+			modelPath = root + modelPath;
+		} else {
+			modelPath = LIBCAMERA_DATA_DIR + modelPath;
+		}
+
+		if (!File::exists(modelPath)) {
+			LOG(RPiAwb, Error) << "No model file found in standard locations";
+			nnConfig_.enableNn = false;
+			return;
+		}
+	} else {
+		modelPath = nnConfig_.model;
+	}
+
+	LOG(RPiAwb, Debug) << "Attempting to load model from: " << modelPath;
+
+	model_ = tflite::FlatBufferModel::BuildFromFile(modelPath.c_str());
+
+	if (!model_) {
+		LOG(RPiAwb, Error) << "Failed to load model from " << modelPath;
+		nnConfig_.enableNn = false;
+		return;
+	}
+
+	tflite::MutableOpResolver resolver;
+	tflite::ops::builtin::BuiltinOpResolver builtin_resolver;
+	resolver.AddAll(builtin_resolver);
+	tflite::InterpreterBuilder(*model_, resolver)(&interpreter_);
+	if (!interpreter_) {
+		LOG(RPiAwb, Error) << "Failed to build interpreter for model " << nnConfig_.model;
+		nnConfig_.enableNn = false;
+		return;
+	}
+
+	interpreter_->AllocateTensors();
+	TfLiteTensor *inputTensor = interpreter_->input_tensor(0);
+	TfLiteTensor *inputLuxTensor = interpreter_->input_tensor(1);
+	TfLiteTensor *outputTensor = interpreter_->output_tensor(0);
+	if (!inputTensor || !inputLuxTensor || !outputTensor) {
+		LOG(RPiAwb, Error) << "Model missing input or output tensor";
+		nnConfig_.enableNn = false;
+		return;
+	}
+
+	const int expectedInputDims[] = { 1, (int)zoneSize_.height, (int)zoneSize_.width, 3 };
+	const int expectedInputLuxDims[] = { 1 };
+	const int expectedOutputDims[] = { 1 };
+
+	if (!checkTensorShape(inputTensor, expectedInputDims, 4)) {
+		LOG(RPiAwb, Error) << "Model input tensor dimension mismatch. Expected: " << buildDimString(expectedInputDims, 4)
+				   << ", Got: " << buildDimString(inputTensor->dims->data, inputTensor->dims->size);
+		nnConfig_.enableNn = false;
+		return;
+	}
+
+	if (!checkTensorShape(inputLuxTensor, expectedInputLuxDims, 1)) {
+		LOG(RPiAwb, Error) << "Model input lux tensor dimension mismatch. Expected: " << buildDimString(expectedInputLuxDims, 1)
+				   << ", Got: " << buildDimString(inputLuxTensor->dims->data, inputLuxTensor->dims->size);
+		nnConfig_.enableNn = false;
+		return;
+	}
+
+	if (!checkTensorShape(outputTensor, expectedOutputDims, 1)) {
+		LOG(RPiAwb, Error) << "Model output tensor dimension mismatch. Expected: " << buildDimString(expectedOutputDims, 1)
+				   << ", Got: " << buildDimString(outputTensor->dims->data, outputTensor->dims->size);
+		nnConfig_.enableNn = false;
+		return;
+	}
+
+	if (inputTensor->type != kTfLiteFloat32 || inputLuxTensor->type != kTfLiteFloat32 || outputTensor->type != kTfLiteFloat32) {
+		LOG(RPiAwb, Error) << "Model input and output tensors must be float32";
+		nnConfig_.enableNn = false;
+		return;
+	}
+
+	LOG(RPiAwb, Info) << "Model loaded successfully from " << modelPath;
+	LOG(RPiAwb, Debug) << "Model validation successful - Input Image: "
+			   << buildDimString(expectedInputDims, 4)
+			   << ", Input Lux: " << buildDimString(expectedInputLuxDims, 1)
+			   << ", Output: " << buildDimString(expectedOutputDims, 1) << " floats";
+}
+
+void AwbNN::initialise()
+{
+	Awb::initialise();
+
+	if (nnConfig_.enableNn) {
+		loadModel();
+		if (!nnConfig_.enableNn) {
+			LOG(RPiAwb, Warning) << "Neural Network AWB failed to load - switch to Grey method";
+			config_.greyWorld = true;
+			config_.sensitivityR = config_.sensitivityB = 1.0;
+		}
+	}
+}
+
+void AwbNN::prepareStats()
+{
+	zones_.clear();
+	/*
+	 * LSC has already been applied to the stats in this pipeline, so stop
+	 * any LSC compensation.  We also ignore config_.fast in this version.
+	 */
+	generateStats(zones_, statistics_, 0.0, 0.0, getGlobalMetadata(), 0.0, 0.0, 0.0);
+	/*
+	 * apply sensitivities, so values appear to come from our "canonical"
+	 * sensor.
+	 */
+	for (auto &zone : zones_) {
+		zone.R *= config_.sensitivityR;
+		zone.B *= config_.sensitivityB;
+	}
+}
+
+void AwbNN::transverseSearch(double t, double &r, double &b)
+{
+	int spanR = -1, spanB = -1;
+	config_.ctR.eval(t, &spanR);
+	config_.ctB.eval(t, &spanB);
+
+	const int diff = 10;
+	double rDiff = config_.ctR.eval(t + diff, &spanR) -
+		       config_.ctR.eval(t - diff, &spanR);
+	double bDiff = config_.ctB.eval(t + diff, &spanB) -
+		       config_.ctB.eval(t - diff, &spanB);
+
+	ipa::Pwl::Point transverse({ bDiff, -rDiff });
+	if (transverse.length2() < 1e-6)
+		return;
+
+	transverse = transverse / transverse.length();
+	double transverseRange = config_.transverseNeg + config_.transversePos;
+	const int maxNumDeltas = 12;
+	int numDeltas = floor(transverseRange * 100 + 0.5) + 1;
+	numDeltas = numDeltas < 3 ? 3 : (numDeltas > maxNumDeltas ? maxNumDeltas : numDeltas);
+
+	ipa::Pwl::Point points[maxNumDeltas];
+	int bestPoint = 0;
+
+	for (int i = 0; i < numDeltas; i++) {
+		points[i][0] = -config_.transverseNeg +
+			       (transverseRange * i) / (numDeltas - 1);
+		ipa::Pwl::Point rbTest = ipa::Pwl::Point({ r, b }) +
+					 transverse * points[i].x();
+		double rTest = rbTest.x(), bTest = rbTest.y();
+		double gainR = 1 / rTest, gainB = 1 / bTest;
+		double delta2Sum = computeDelta2Sum(gainR, gainB, 0.0, 0.0);
+		points[i][1] = delta2Sum;
+		if (points[i].y() < points[bestPoint].y())
+			bestPoint = i;
+	}
+
+	bestPoint = std::clamp(bestPoint, 1, numDeltas - 2);
+	ipa::Pwl::Point rbBest = ipa::Pwl::Point({ r, b }) +
+				 transverse * interpolateQuadatric(points[bestPoint - 1],
+								   points[bestPoint],
+								   points[bestPoint + 1]);
+	double rBest = rbBest.x(), bBest = rbBest.y();
+
+	r = rBest, b = bBest;
+}
+
+AwbNN::RGB AwbNN::processZone(AwbNN::RGB zone, float redGain, float blueGain)
+{
+	/*
+	 * Renders the pixel at 5000K temperature
+	 */
+	RGB zoneGains = zone;
+
+	zoneGains.R *= redGain;
+	zoneGains.G *= 1.0;
+	zoneGains.B *= blueGain;
+
+	RGB zoneCcm;
+
+	zoneCcm.R = nnConfig_.ccm[0] * zoneGains.R + nnConfig_.ccm[1] * zoneGains.G + nnConfig_.ccm[2] * zoneGains.B;
+	zoneCcm.G = nnConfig_.ccm[3] * zoneGains.R + nnConfig_.ccm[4] * zoneGains.G + nnConfig_.ccm[5] * zoneGains.B;
+	zoneCcm.B = nnConfig_.ccm[6] * zoneGains.R + nnConfig_.ccm[7] * zoneGains.G + nnConfig_.ccm[8] * zoneGains.B;
+
+	return zoneCcm;
+}
+
+void AwbNN::awbNN()
+{
+	float *inputData = interpreter_->typed_input_tensor<float>(0);
+	float *inputLux = interpreter_->typed_input_tensor<float>(1);
+
+	float redGain = 1.0 / config_.ctR.eval(5000);
+	float blueGain = 1.0 / config_.ctB.eval(5000);
+
+	for (uint i = 0; i < zoneSize_.height; i++) {
+		for (uint j = 0; j < zoneSize_.width; j++) {
+			uint zoneIdx = i * zoneSize_.width + j;
+
+			RGB processedZone = processZone(zones_[zoneIdx] * (1.0 / 65535), redGain, blueGain);
+			uint baseIdx = zoneIdx * 3;
+
+			inputData[baseIdx + 0] = static_cast<float>(processedZone.R);
+			inputData[baseIdx + 1] = static_cast<float>(processedZone.G);
+			inputData[baseIdx + 2] = static_cast<float>(processedZone.B);
+		}
+	}
+
+	inputLux[0] = static_cast<float>(lux_);
+
+	TfLiteStatus status = interpreter_->Invoke();
+	if (status != kTfLiteOk) {
+		LOG(RPiAwb, Error) << "Model inference failed with status: " << status;
+		return;
+	}
+
+	float *outputData = interpreter_->typed_output_tensor<float>(0);
+
+	double t = outputData[0];
+
+	LOG(RPiAwb, Debug) << "Model output temperature: " << t;
+
+	t = std::clamp(t, mode_->ctLo, mode_->ctHi);
+
+	double r = config_.ctR.eval(t);
+	double b = config_.ctB.eval(t);
+
+	transverseSearch(t, r, b);
+
+	LOG(RPiAwb, Debug) << "After transverse search: Temperature: " << t << " Red gain: " << 1.0 / r << " Blue gain: " << 1.0 / b;
+
+	asyncResults_.temperatureK = t;
+	asyncResults_.gainR = 1.0 / r * config_.sensitivityR;
+	asyncResults_.gainG = 1.0;
+	asyncResults_.gainB = 1.0 / b * config_.sensitivityB;
+}
+
+void AwbNN::doAwb()
+{
+	prepareStats();
+	if (zones_.size() == (zoneSize_.width * zoneSize_.height) && nnConfig_.enableNn) {
+		awbNN();
+	} else {
+		awbGrey();
+	}
+	statistics_.reset();
+}
+
+/* Register algorithm with the system. */
+static Algorithm *create(Controller *controller)
+{
+	return (Algorithm *)new AwbNN(controller);
+}
+static RegisterAlgorithm reg(NAME, &create);
+
+} /* namespace RPiController */