| Message ID | 20260127120604.6560-3-david.plowman@raspberrypi.com |
|---|---|
| State | New |
Hi

Just a couple quick comments.


On 2026. 01. 27. 12:59, David Plowman wrote:
> From: Peter Bailey <peter.bailey@raspberrypi.com>
>
> Add an AWB algorithm which uses neural networks.
>
> Signed-off-by: Peter Bailey <peter.bailey@raspberrypi.com>
> Reviewed-by: David Plowman <david.plowman@raspberrypi.com>
> Reviewed-by: Naushir Patuck <naush@raspberrypi.com>
> ---
>   meson_options.txt                     |   5 +
>   src/ipa/rpi/controller/meson.build    |   9 +
>   src/ipa/rpi/controller/rpi/awb_nn.cpp | 456 ++++++++++++++++++++++++++
>   3 files changed, 470 insertions(+)
>   create mode 100644 src/ipa/rpi/controller/rpi/awb_nn.cpp
>
> diff --git a/meson_options.txt b/meson_options.txt
> index c052e85a..07847294 100644
> --- a/meson_options.txt
> +++ b/meson_options.txt
> @@ -76,6 +76,11 @@ option('qcam',
>          value : 'auto',
>          description : 'Compile the qcam test application')
>
> +option('rpi-awb-nn',

If dots work, then I think `rpi.awb-nn` is better name.

> +        type : 'feature',
> +        value : 'auto',
> +        description : 'Enable the Raspberry Pi Neural Network AWB algorithm')
> +
>  option('test',
>          type : 'boolean',
>          value : false,
> diff --git a/src/ipa/rpi/controller/meson.build b/src/ipa/rpi/controller/meson.build
> index c8637906..03ee7c20 100644
> --- a/src/ipa/rpi/controller/meson.build
> +++ b/src/ipa/rpi/controller/meson.build
> @@ -32,6 +32,15 @@ rpi_ipa_controller_deps = [
>      libcamera_private,
>  ]
>
> +tflite_dep = dependency('tensorflow-lite', required : get_option('rpi-awb-nn'))
> +
> +if tflite_dep.found()
> +    rpi_ipa_controller_sources += files([
> +        'rpi/awb_nn.cpp',
> +    ])
> +    rpi_ipa_controller_deps += tflite_dep
> +endif
> +
>  rpi_ipa_controller_lib = static_library('rpi_ipa_controller', rpi_ipa_controller_sources,
>                                          include_directories : libipa_includes,
>                                          dependencies : rpi_ipa_controller_deps)
> diff --git a/src/ipa/rpi/controller/rpi/awb_nn.cpp b/src/ipa/rpi/controller/rpi/awb_nn.cpp
> new file mode 100644
> index 00000000..395add85
> --- /dev/null
> +++ b/src/ipa/rpi/controller/rpi/awb_nn.cpp
> @@ -0,0 +1,456 @@
> +/* SPDX-License-Identifier: BSD-2-Clause */
> +/*
> + * Copyright (C) 2025, Raspberry Pi Ltd
> + *
> + * AWB control algorithm using neural network
> + *
> + * The AWB Neural Network algorithm can be run entirely with the code here
> + * and the supplied TFLite models. Those interested in the full model
> + * definitions, or who may want to re-train the models should visit
> + *
> + * https://github.com/raspberrypi/awb_nn
> + *
> + * where you will find full source code for the models, the full datasets
> + * used for training our supplied models, and full instructions for capturing
> + * your own images and re-training the models for your own use cases.
> + */
> +
> +#include <chrono>
> +#include <condition_variable>
> +#include <thread>

The ones above don't seem to be used.

> +
> +#include <libcamera/base/file.h>
> +#include <libcamera/base/log.h>
> +
> +#include <tensorflow/lite/interpreter.h>
> +#include <tensorflow/lite/kernels/register.h>
> +#include <tensorflow/lite/model.h>
> +
> +#include "../awb_algorithm.h"
> +#include "../awb_status.h"
> +#include "../lux_status.h"
> +#include "libipa/pwl.h"
> +
> +#include "alsc_status.h"

This also does not look used.

> +#include "awb.h"
> +
> +using namespace libcamera;
> +
> +LOG_DECLARE_CATEGORY(RPiAwb)
> +
> +constexpr double kDefaultCT = 4500.0;
> +
> +/*
> + * The neural networks are trained to work on images rendered at a canonical
> + * colour temperature. That value is 5000K, which must be reproduced here.
> + */
> +constexpr double kNetworkCanonicalCT = 5000.0;
> +
> +#define NAME "rpi.nn.awb"
> +
> +namespace RPiController {
> +
> +struct AwbNNConfig {
> +        AwbNNConfig() {}

Is this empty constructor needed?

> +        int read(const libcamera::YamlObject &params, AwbConfig &config);
> +
> +        /* An empty model will check default locations for model.tflite */
> +        std::string model;
> +        float minTemp;
> +        float maxTemp;
> +
> +        bool enableNn;
> +
> +        /* CCM matrix for canonical network CT */
> +        double ccm[9];
> +};
> +
> +class AwbNN : public Awb
> +{
> +public:
> +        AwbNN(Controller *controller = NULL);

nullptr

> +        ~AwbNN();
> +        char const *name() const override;
> +        void initialise() override;
> +        int read(const libcamera::YamlObject &params) override;
> +
> +protected:
> +        void doAwb() override;
> +        void prepareStats() override;
> +
> +private:
> +        bool isAutoEnabled() const;
> +        AwbNNConfig nnConfig_;
> +        void transverseSearch(double t, double &r, double &b);
> +        RGB processZone(RGB zone, float red_gain, float blue_gain);
> +        void awbNN();
> +        void loadModel();
> +
> +        libcamera::Size zoneSize_;
> +        std::unique_ptr<tflite::FlatBufferModel> model_;
> +        std::unique_ptr<tflite::Interpreter> interpreter_;
> +};
> +
> +int AwbNNConfig::read(const libcamera::YamlObject &params, AwbConfig &config)
> +{
> +        model = params["model"].get<std::string>("");
> +        minTemp = params["min_temp"].get<float>(2800.0);
> +        maxTemp = params["max_temp"].get<float>(7600.0);
> +
> +        for (int i = 0; i < 9; i++)
> +                ccm[i] = params["ccm"][i].get<double>(0.0);
> +
> +        enableNn = params["enable_nn"].get<int>(1);
> +
> +        if (enableNn) {
> +                if (!config.hasCtCurve()) {
> +                        LOG(RPiAwb, Error) << "CT curve not specified";
> +                        enableNn = false;
> +                }
> +
> +                if (!model.empty() && model.find(".tflite") == std::string::npos) {
> +                        LOG(RPiAwb, Error) << "Model must be a .tflite file";
> +                        enableNn = false;

Is it useful to force the extension?

> +                }
> +
> +                bool validCcm = true;
> +                for (int i = 0; i < 9; i++)
> +                        if (ccm[i] == 0.0)
> +                                validCcm = false;
> +
> +                if (!validCcm) {
> +                        LOG(RPiAwb, Error) << "CCM not specified or invalid";
> +                        enableNn = false;
> +                }
> +
> +                if (!enableNn) {
> +                        LOG(RPiAwb, Warning) << "Neural Network AWB mis-configured - switch to Grey method";

"misconfigured" ?

> +                }
> +        }
> +
> +        if (!enableNn) {
> +                config.sensitivityR = config.sensitivityB = 1.0;
> +                config.greyWorld = true;
> +        }
> +
> +        return 0;
> +}
> +
> +AwbNN::AwbNN(Controller *controller)
> +        : Awb(controller)
> +{
> +        zoneSize_ = getHardwareConfig().awbRegions;
> +}
> +
> +AwbNN::~AwbNN()
> +{
> +}
> +
> +char const *AwbNN::name() const
> +{
> +        return NAME;
> +}
> +
> +int AwbNN::read(const libcamera::YamlObject &params)
> +{
> +        int ret;
> +
> +        ret = config_.read(params);
> +        if (ret)
> +                return ret;
> +
> +        ret = nnConfig_.read(params, config_);
> +        if (ret)
> +                return ret;
> +
> +        return 0;
> +}
> +
> +static bool checkTensorShape(TfLiteTensor *tensor, const int *expectedDims, const int expectedDimsSize)
> +{
> +        if (tensor->dims->size != expectedDimsSize)
> +                return false;
> +
> +        for (int i = 0; i < tensor->dims->size; i++) {
> +                if (tensor->dims->data[i] != expectedDims[i]) {
> +                        return false;
> +                }
> +        }
> +        return true;

from <algorithm>

    return std::equal(expectedDims, expectedDims + expectedDimsSize,
                      tensor->dims->data, tensor->dims->data + tensor->dims->size);

?

> +}
> +
> +static std::string buildDimString(const int *dims, const int dimsSize)
> +{
> +        std::string s = "[";
> +        for (int i = 0; i < dimsSize; i++) {
> +                s += std::to_string(dims[i]);
> +                if (i < dimsSize - 1)
> +                        s += ",";
> +                else
> +                        s += "]";
> +        }
> +        return s;

return '[' + utils::join(Span{ dims, dimsSize }, ",") + ']';

?

> +}
> +
> +void AwbNN::loadModel()
> +{
> +        std::string modelPath;
> +        if (getTarget() == "bcm2835") {
> +                modelPath = "/ipa/rpi/vc4/awb_model.tflite";
> +        } else {
> +                modelPath = "/ipa/rpi/pisp/awb_model.tflite";
> +        }
> +
> +        if (nnConfig_.model.empty()) {
> +                std::string root = utils::libcameraSourcePath();
> +                if (!root.empty()) {
> +                        modelPath = root + modelPath;
> +                } else {
> +                        modelPath = LIBCAMERA_DATA_DIR + modelPath;
> +                }
> +
> +                if (!File::exists(modelPath)) {
> +                        LOG(RPiAwb, Error) << "No model file found in standard locations";
> +                        nnConfig_.enableNn = false;
> +                        return;
> +                }
> +        } else {
> +                modelPath = nnConfig_.model;
> +        }
> +
> +        LOG(RPiAwb, Debug) << "Attempting to load model from: " << modelPath;
> +
> +        model_ = tflite::FlatBufferModel::BuildFromFile(modelPath.c_str());

As far as I can see, `BuildFromFile` takes an `ErrorReporter` parameter. Would it
make sense to create a static instance of one and use it to route messages into
libcamera log? If not specified, does it report anything to stderr or similar?
(And the errors from tflite are logged, then I would probably also remove the
`File::exists()` check as well.)

> +
> +        if (!model_) {
> +                LOG(RPiAwb, Error) << "Failed to load model from " << modelPath;
> +                nnConfig_.enableNn = false;
> +                return;
> +        }
> +
> +        tflite::MutableOpResolver resolver;
> +        tflite::ops::builtin::BuiltinOpResolver builtin_resolver;
> +        resolver.AddAll(builtin_resolver);
> +        tflite::InterpreterBuilder(*model_, resolver)(&interpreter_);
> +        if (!interpreter_) {
> +                LOG(RPiAwb, Error) << "Failed to build interpreter for model " << nnConfig_.model;
> +                nnConfig_.enableNn = false;
> +                return;
> +        }
> +
> +        interpreter_->AllocateTensors();
> +        TfLiteTensor *inputTensor = interpreter_->input_tensor(0);
> +        TfLiteTensor *inputLuxTensor = interpreter_->input_tensor(1);
> +        TfLiteTensor *outputTensor = interpreter_->output_tensor(0);
> +        if (!inputTensor || !inputLuxTensor || !outputTensor) {
> +                LOG(RPiAwb, Error) << "Model missing input or output tensor";
> +                nnConfig_.enableNn = false;
> +                return;
> +        }
> +
> +        const int expectedInputDims[] = { 1, (int)zoneSize_.height, (int)zoneSize_.width, 3 };
> +        const int expectedInputLuxDims[] = { 1 };
> +        const int expectedOutputDims[] = { 1 };
> +
> +        if (!checkTensorShape(inputTensor, expectedInputDims, 4)) {
> +                LOG(RPiAwb, Error) << "Model input tensor dimension mismatch. Expected: " << buildDimString(expectedInputDims, 4)
> +                                   << ", Got: " << buildDimString(inputTensor->dims->data, inputTensor->dims->size);
> +                nnConfig_.enableNn = false;
> +                return;
> +        }
> +
> +        if (!checkTensorShape(inputLuxTensor, expectedInputLuxDims, 1)) {
> +                LOG(RPiAwb, Error) << "Model input lux tensor dimension mismatch. Expected: " << buildDimString(expectedInputLuxDims, 1)
> +                                   << ", Got: " << buildDimString(inputLuxTensor->dims->data, inputLuxTensor->dims->size);
> +                nnConfig_.enableNn = false;
> +                return;
> +        }
> +
> +        if (!checkTensorShape(outputTensor, expectedOutputDims, 1)) {
> +                LOG(RPiAwb, Error) << "Model output tensor dimension mismatch. Expected: " << buildDimString(expectedOutputDims, 1)
> +                                   << ", Got: " << buildDimString(outputTensor->dims->data, outputTensor->dims->size);
> +                nnConfig_.enableNn = false;
> +                return;
> +        }
> +
> +        if (inputTensor->type != kTfLiteFloat32 || inputLuxTensor->type != kTfLiteFloat32 || outputTensor->type != kTfLiteFloat32) {
> +                LOG(RPiAwb, Error) << "Model input and output tensors must be float32";
> +                nnConfig_.enableNn = false;
> +                return;
> +        }
> +
> +        LOG(RPiAwb, Info) << "Model loaded successfully from " << modelPath;
> +        LOG(RPiAwb, Debug) << "Model validation successful - Input Image: "
> +                           << buildDimString(expectedInputDims, 4)
> +                           << ", Input Lux: " << buildDimString(expectedInputLuxDims, 1)
> +                           << ", Output: " << buildDimString(expectedOutputDims, 1) << " floats";
> +}
> +
> +void AwbNN::initialise()
> +{
> +        Awb::initialise();
> +
> +        if (nnConfig_.enableNn) {
> +                loadModel();
> +                if (!nnConfig_.enableNn) {
> +                        LOG(RPiAwb, Warning) << "Neural Network AWB failed to load - switch to Grey method";
> +                        config_.greyWorld = true;
> +                        config_.sensitivityR = config_.sensitivityB = 1.0;
> +                }
> +        }
> +}
> +
> +void AwbNN::prepareStats()
> +{
> +        zones_.clear();
> +        /*
> +         * LSC has already been applied to the stats in this pipeline, so stop
> +         * any LSC compensation. We also ignore config_.fast in this version.
> +         */
> +        generateStats(zones_, statistics_, 0.0, 0.0, getGlobalMetadata(), 0.0, 0.0, 0.0);
> +        /*
> +         * apply sensitivities, so values appear to come from our "canonical"
> +         * sensor.
> +         */
> +        for (auto &zone : zones_) {
> +                zone.R *= config_.sensitivityR;
> +                zone.B *= config_.sensitivityB;
> +        }
> +}
> +
> +void AwbNN::transverseSearch(double t, double &r, double &b)
> +{
> +        int spanR = -1, spanB = -1;
> +        config_.ctR.eval(t, &spanR);
> +        config_.ctB.eval(t, &spanB);
> +
> +        const int diff = 10;
> +        double rDiff = config_.ctR.eval(t + diff, &spanR) -
> +                       config_.ctR.eval(t - diff, &spanR);
> +        double bDiff = config_.ctB.eval(t + diff, &spanB) -
> +                       config_.ctB.eval(t - diff, &spanB);
> +
> +        ipa::Pwl::Point transverse({ bDiff, -rDiff });
> +        if (transverse.length2() < 1e-6)
> +                return;
> +
> +        transverse = transverse / transverse.length();
> +        double transverseRange = config_.transverseNeg + config_.transversePos;
> +        const int maxNumDeltas = 12;
> +        int numDeltas = floor(transverseRange * 100 + 0.5) + 1;
> +        numDeltas = numDeltas < 3 ? 3 : (numDeltas > maxNumDeltas ? maxNumDeltas : numDeltas);

int numDeltas = std::clamp<int>(floor(transverseRange * 100 + 0.5) + 1, 3, maxNumDeltas);

?

> +
> +        ipa::Pwl::Point points[maxNumDeltas];
> +        int bestPoint = 0;
> +
> +        for (int i = 0; i < numDeltas; i++) {
> +                points[i][0] = -config_.transverseNeg +
> +                               (transverseRange * i) / (numDeltas - 1);
> +                ipa::Pwl::Point rbTest = ipa::Pwl::Point({ r, b }) +
> +                                         transverse * points[i].x();
> +                double rTest = rbTest.x(), bTest = rbTest.y();
> +                double gainR = 1 / rTest, gainB = 1 / bTest;
> +                double delta2Sum = computeDelta2Sum(gainR, gainB, 0.0, 0.0);
> +                points[i][1] = delta2Sum;
> +                if (points[i].y() < points[bestPoint].y())
> +                        bestPoint = i;
> +        }
> +
> +        bestPoint = std::clamp(bestPoint, 1, numDeltas - 2);
> +        ipa::Pwl::Point rbBest = ipa::Pwl::Point({ r, b }) +
> +                                 transverse * interpolateQuadatric(points[bestPoint - 1],
> +                                                                   points[bestPoint],
> +                                                                   points[bestPoint + 1]);
> +        double rBest = rbBest.x(), bBest = rbBest.y();
> +
> +        r = rBest, b = bBest;

r = rbBest.x();
b = rbBest.y();

?

> +}
> +
> +AwbNN::RGB AwbNN::processZone(AwbNN::RGB zone, float redGain, float blueGain)
> +{
> +        /*
> +         * Renders the pixel at canonical network colour temperature
> +         */
> +        RGB zoneGains = zone;
> +
> +        zoneGains.R *= redGain;
> +        zoneGains.G *= 1.0;
> +        zoneGains.B *= blueGain;
> +
> +        RGB zoneCcm;
> +
> +        zoneCcm.R = nnConfig_.ccm[0] * zoneGains.R + nnConfig_.ccm[1] * zoneGains.G + nnConfig_.ccm[2] * zoneGains.B;
> +        zoneCcm.G = nnConfig_.ccm[3] * zoneGains.R + nnConfig_.ccm[4] * zoneGains.G + nnConfig_.ccm[5] * zoneGains.B;
> +        zoneCcm.B = nnConfig_.ccm[6] * zoneGains.R + nnConfig_.ccm[7] * zoneGains.G + nnConfig_.ccm[8] * zoneGains.B;
> +
> +        return zoneCcm;
> +}
> +
> +void AwbNN::awbNN()
> +{
> +        float *inputData = interpreter_->typed_input_tensor<float>(0);
> +        float *inputLux = interpreter_->typed_input_tensor<float>(1);
> +
> +        float redGain = 1.0 / config_.ctR.eval(kNetworkCanonicalCT);
> +        float blueGain = 1.0 / config_.ctB.eval(kNetworkCanonicalCT);
> +
> +        for (uint i = 0; i < zoneSize_.height; i++) {
> +                for (uint j = 0; j < zoneSize_.width; j++) {
> +                        uint zoneIdx = i * zoneSize_.width + j;
> +
> +                        RGB processedZone = processZone(zones_[zoneIdx] * (1.0 / 65535), redGain, blueGain);
> +                        uint baseIdx = zoneIdx * 3;

Where is this `uint` type coming from? tflite? Given that `zoneSize_` is `libcamera::Size`,
which uses `unsigned int`, is this `uint` type necessary here?

> +
> +                        inputData[baseIdx + 0] = static_cast<float>(processedZone.R);
> +                        inputData[baseIdx + 1] = static_cast<float>(processedZone.G);
> +                        inputData[baseIdx + 2] = static_cast<float>(processedZone.B);
> +                }
> +        }
> +
> +        inputLux[0] = static_cast<float>(lux_);
> +
> +        TfLiteStatus status = interpreter_->Invoke();
> +        if (status != kTfLiteOk) {
> +                LOG(RPiAwb, Error) << "Model inference failed with status: " << status;
> +                return;
> +        }
> +
> +        float *outputData = interpreter_->typed_output_tensor<float>(0);
> +
> +        double t = outputData[0];
> +
> +        LOG(RPiAwb, Debug) << "Model output temperature: " << t;
> +
> +        t = std::clamp(t, mode_->ctLo, mode_->ctHi);
> +
> +        double r = config_.ctR.eval(t);
> +        double b = config_.ctB.eval(t);
> +
> +        transverseSearch(t, r, b);
> +
> +        LOG(RPiAwb, Debug) << "After transverse search: Temperature: " << t << " Red gain: " << 1.0 / r << " Blue gain: " << 1.0 / b;
> +
> +        asyncResults_.temperatureK = t;
> +        asyncResults_.gainR = 1.0 / r * config_.sensitivityR;
> +        asyncResults_.gainG = 1.0;
> +        asyncResults_.gainB = 1.0 / b * config_.sensitivityB;
> +}
> +
> +void AwbNN::doAwb()
> +{
> +        prepareStats();
> +        if (zones_.size() == (zoneSize_.width * zoneSize_.height) && nnConfig_.enableNn)
> +                awbNN();
> +        else
> +                awbGrey();
> +        statistics_.reset();
> +}
> +
> +/* Register algorithm with the system. */
> +static Algorithm *create(Controller *controller)
> +{
> +        return (Algorithm *)new AwbNN(controller);

Please omit this cast.


Regards,
Barnabás Pőcze

> +}
> +static RegisterAlgorithm reg(NAME, &create);
> +
> +} /* namespace RPiController */
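For reference, the ErrorReporter idea raised above could look roughly like the
sketch below. It assumes the usual `tflite::ErrorReporter` interface
(`Report(const char *, va_list)`; the header path can vary between TFLite
versions), and relies on the `LOG_DECLARE_CATEGORY(RPiAwb)` already present in
awb_nn.cpp. The class name `LogErrorReporter` is invented for illustration;
this is an untested sketch, not part of the patch.

    #include <cstdarg>
    #include <cstdio>

    #include <libcamera/base/log.h>

    #include <tensorflow/lite/core/api/error_reporter.h>
    #include <tensorflow/lite/model.h>

    namespace {

    /* Forward TFLite's printf-style messages to the libcamera log. */
    class LogErrorReporter : public tflite::ErrorReporter
    {
    public:
            int Report(const char *format, va_list args) override
            {
                    /* Expand the printf-style message into a buffer... */
                    char buf[256];
                    int len = vsnprintf(buf, sizeof(buf), format, args);
                    /* ...and emit it through the libcamera log instead of stderr. */
                    LOG(RPiAwb, Error) << buf;
                    return len;
            }
    };

    LogErrorReporter logErrorReporter;

    } /* namespace */

Inside AwbNN::loadModel(), the builder call would then become:

    model_ = tflite::FlatBufferModel::BuildFromFile(modelPath.c_str(),
                                                    &logErrorReporter);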
Hi Barnabas

Thank you for the comments.

On Tue, 27 Jan 2026 at 14:28, Barnabás Pőcze
<barnabas.pocze@ideasonboard.com> wrote:
>
> Hi
>
> Just a couple quick comments.
>
>
> On 2026. 01. 27. 12:59, David Plowman wrote:
> > From: Peter Bailey <peter.bailey@raspberrypi.com>
> >
> > Add an AWB algorithm which uses neural networks.
> >
> > Signed-off-by: Peter Bailey <peter.bailey@raspberrypi.com>
> > Reviewed-by: David Plowman <david.plowman@raspberrypi.com>
> > Reviewed-by: Naushir Patuck <naush@raspberrypi.com>
> > ---
> >   meson_options.txt                     |   5 +
> >   src/ipa/rpi/controller/meson.build    |   9 +
> >   src/ipa/rpi/controller/rpi/awb_nn.cpp | 456 ++++++++++++++++++++++++++
> >   3 files changed, 470 insertions(+)
> >   create mode 100644 src/ipa/rpi/controller/rpi/awb_nn.cpp
> >
> > diff --git a/meson_options.txt b/meson_options.txt
> > index c052e85a..07847294 100644
> > --- a/meson_options.txt
> > +++ b/meson_options.txt
> > @@ -76,6 +76,11 @@ option('qcam',
> >          value : 'auto',
> >          description : 'Compile the qcam test application')
> >
> > +option('rpi-awb-nn',
>
> If dots work, then I think `rpi.awb-nn` is better name.

I don't believe this works, unfortunately. Someone please correct me
if I'm wrong!

>
> > +        type : 'feature',
> > +        value : 'auto',
> > +        description : 'Enable the Raspberry Pi Neural Network AWB algorithm')
> > +
> >  option('test',
> >          type : 'boolean',
> >          value : false,
> > diff --git a/src/ipa/rpi/controller/meson.build b/src/ipa/rpi/controller/meson.build
> > index c8637906..03ee7c20 100644
> > --- a/src/ipa/rpi/controller/meson.build
> > +++ b/src/ipa/rpi/controller/meson.build
> > @@ -32,6 +32,15 @@ rpi_ipa_controller_deps = [
> >      libcamera_private,
> >  ]
> >
> > +tflite_dep = dependency('tensorflow-lite', required : get_option('rpi-awb-nn'))
> > +
> > +if tflite_dep.found()
> > +    rpi_ipa_controller_sources += files([
> > +        'rpi/awb_nn.cpp',
> > +    ])
> > +    rpi_ipa_controller_deps += tflite_dep
> > +endif
> > +
> >  rpi_ipa_controller_lib = static_library('rpi_ipa_controller', rpi_ipa_controller_sources,
> >                                          include_directories : libipa_includes,
> >                                          dependencies : rpi_ipa_controller_deps)
> > diff --git a/src/ipa/rpi/controller/rpi/awb_nn.cpp b/src/ipa/rpi/controller/rpi/awb_nn.cpp
> > new file mode 100644
> > index 00000000..395add85
> > --- /dev/null
> > +++ b/src/ipa/rpi/controller/rpi/awb_nn.cpp
> > @@ -0,0 +1,456 @@
> > +/* SPDX-License-Identifier: BSD-2-Clause */
> > +/*
> > + * Copyright (C) 2025, Raspberry Pi Ltd
> > + *
> > + * AWB control algorithm using neural network
> > + *
> > + * The AWB Neural Network algorithm can be run entirely with the code here
> > + * and the supplied TFLite models. Those interested in the full model
> > + * definitions, or who may want to re-train the models should visit
> > + *
> > + * https://github.com/raspberrypi/awb_nn
> > + *
> > + * where you will find full source code for the models, the full datasets
> > + * used for training our supplied models, and full instructions for capturing
> > + * your own images and re-training the models for your own use cases.
> > + */
> > +
> > +#include <chrono>
> > +#include <condition_variable>
> > +#include <thread>
>
> The ones above don't seem to be used.

Yes, will remove.

>
> > +
> > +#include <libcamera/base/file.h>
> > +#include <libcamera/base/log.h>
> > +
> > +#include <tensorflow/lite/interpreter.h>
> > +#include <tensorflow/lite/kernels/register.h>
> > +#include <tensorflow/lite/model.h>
> > +
> > +#include "../awb_algorithm.h"
> > +#include "../awb_status.h"
> > +#include "../lux_status.h"
> > +#include "libipa/pwl.h"
> > +
> > +#include "alsc_status.h"
>
> This also does not look used.

Yes, this too.

>
> > +#include "awb.h"
> > +
> > +using namespace libcamera;
> > +
> > +LOG_DECLARE_CATEGORY(RPiAwb)
> > +
> > +constexpr double kDefaultCT = 4500.0;
> > +
> > +/*
> > + * The neural networks are trained to work on images rendered at a canonical
> > + * colour temperature. That value is 5000K, which must be reproduced here.
> > + */
> > +constexpr double kNetworkCanonicalCT = 5000.0;
> > +
> > +#define NAME "rpi.nn.awb"
> > +
> > +namespace RPiController {
> > +
> > +struct AwbNNConfig {
> > +        AwbNNConfig() {}
>
> Is this empty constructor needed?

True, it probably isn't. Though I quite like making it explicit if
it's actually being used somewhere. But "AwbNNConfig() = default;"
would certainly be better, so maybe I'll go with that?

>
> > +        int read(const libcamera::YamlObject &params, AwbConfig &config);
> > +
> > +        /* An empty model will check default locations for model.tflite */
> > +        std::string model;
> > +        float minTemp;
> > +        float maxTemp;
> > +
> > +        bool enableNn;
> > +
> > +        /* CCM matrix for canonical network CT */
> > +        double ccm[9];
> > +};
> > +
> > +class AwbNN : public Awb
> > +{
> > +public:
> > +        AwbNN(Controller *controller = NULL);
>
> nullptr

I'll change that. Actually there's one in awb.h too so I'll also
change that as well.

>
> > +        ~AwbNN();
> > +        char const *name() const override;
> > +        void initialise() override;
> > +        int read(const libcamera::YamlObject &params) override;
> > +
> > +protected:
> > +        void doAwb() override;
> > +        void prepareStats() override;
> > +
> > +private:
> > +        bool isAutoEnabled() const;
> > +        AwbNNConfig nnConfig_;
> > +        void transverseSearch(double t, double &r, double &b);
> > +        RGB processZone(RGB zone, float red_gain, float blue_gain);
> > +        void awbNN();
> > +        void loadModel();
> > +
> > +        libcamera::Size zoneSize_;
> > +        std::unique_ptr<tflite::FlatBufferModel> model_;
> > +        std::unique_ptr<tflite::Interpreter> interpreter_;
> > +};
> > +
> > +int AwbNNConfig::read(const libcamera::YamlObject &params, AwbConfig &config)
> > +{
> > +        model = params["model"].get<std::string>("");
> > +        minTemp = params["min_temp"].get<float>(2800.0);
> > +        maxTemp = params["max_temp"].get<float>(7600.0);
> > +
> > +        for (int i = 0; i < 9; i++)
> > +                ccm[i] = params["ccm"][i].get<double>(0.0);
> > +
> > +        enableNn = params["enable_nn"].get<int>(1);
> > +
> > +        if (enableNn) {
> > +                if (!config.hasCtCurve()) {
> > +                        LOG(RPiAwb, Error) << "CT curve not specified";
> > +                        enableNn = false;
> > +                }
> > +
> > +                if (!model.empty() && model.find(".tflite") == std::string::npos) {
> > +                        LOG(RPiAwb, Error) << "Model must be a .tflite file";
> > +                        enableNn = false;
>
> Is it useful to force the extension?

I quite like forcing the extension, it makes it very clear that we
want the .tflite file from our repo, not the .keras or anything else.

>
> > +                }
> > +
> > +                bool validCcm = true;
> > +                for (int i = 0; i < 9; i++)
> > +                        if (ccm[i] == 0.0)
> > +                                validCcm = false;
> > +
> > +                if (!validCcm) {
> > +                        LOG(RPiAwb, Error) << "CCM not specified or invalid";
> > +                        enableNn = false;
> > +                }
> > +
> > +                if (!enableNn) {
> > +                        LOG(RPiAwb, Warning) << "Neural Network AWB mis-configured - switch to Grey method";
>
> "misconfigured" ?

Yes, I'm inclined to agree.

>
> > +                }
> > +        }
> > +
> > +        if (!enableNn) {
> > +                config.sensitivityR = config.sensitivityB = 1.0;
> > +                config.greyWorld = true;
> > +        }
> > +
> > +        return 0;
> > +}
> > +
> > +AwbNN::AwbNN(Controller *controller)
> > +        : Awb(controller)
> > +{
> > +        zoneSize_ = getHardwareConfig().awbRegions;
> > +}
> > +
> > +AwbNN::~AwbNN()
> > +{
> > +}
> > +
> > +char const *AwbNN::name() const
> > +{
> > +        return NAME;
> > +}
> > +
> > +int AwbNN::read(const libcamera::YamlObject &params)
> > +{
> > +        int ret;
> > +
> > +        ret = config_.read(params);
> > +        if (ret)
> > +                return ret;
> > +
> > +        ret = nnConfig_.read(params, config_);
> > +        if (ret)
> > +                return ret;
> > +
> > +        return 0;
> > +}
> > +
> > +static bool checkTensorShape(TfLiteTensor *tensor, const int *expectedDims, const int expectedDimsSize)
> > +{
> > +        if (tensor->dims->size != expectedDimsSize)
> > +                return false;
> > +
> > +        for (int i = 0; i < tensor->dims->size; i++) {
> > +                if (tensor->dims->data[i] != expectedDims[i]) {
> > +                        return false;
> > +                }
> > +        }
> > +        return true;
>
> from <algorithm>
>
>     return std::equal(expectedDims, expectedDims + expectedDimsSize,
>                       tensor->dims->data, tensor->dims->data + tensor->dims->size);
>
> ?

Agree.

>
> > +}
> > +
> > +static std::string buildDimString(const int *dims, const int dimsSize)
> > +{
> > +        std::string s = "[";
> > +        for (int i = 0; i < dimsSize; i++) {
> > +                s += std::to_string(dims[i]);
> > +                if (i < dimsSize - 1)
> > +                        s += ",";
> > +                else
> > +                        s += "]";
> > +        }
> > +        return s;
>
> return '[' + utils::join(Span{ dims, dimsSize }, ",") + ']';
>
> ?

Yes, that's nicer.

>
> > +}
> > +
> > +void AwbNN::loadModel()
> > +{
> > +        std::string modelPath;
> > +        if (getTarget() == "bcm2835") {
> > +                modelPath = "/ipa/rpi/vc4/awb_model.tflite";
> > +        } else {
> > +                modelPath = "/ipa/rpi/pisp/awb_model.tflite";
> > +        }
> > +
> > +        if (nnConfig_.model.empty()) {
> > +                std::string root = utils::libcameraSourcePath();
> > +                if (!root.empty()) {
> > +                        modelPath = root + modelPath;
> > +                } else {
> > +                        modelPath = LIBCAMERA_DATA_DIR + modelPath;
> > +                }
> > +
> > +                if (!File::exists(modelPath)) {
> > +                        LOG(RPiAwb, Error) << "No model file found in standard locations";
> > +                        nnConfig_.enableNn = false;
> > +                        return;
> > +                }
> > +        } else {
> > +                modelPath = nnConfig_.model;
> > +        }
> > +
> > +        LOG(RPiAwb, Debug) << "Attempting to load model from: " << modelPath;
> > +
> > +        model_ = tflite::FlatBufferModel::BuildFromFile(modelPath.c_str());
>
> As far as I can see, `BuildFromFile` takes an `ErrorReporter` parameter. Would it
> make sense to create a static instance of one and use it to route messages into
> libcamera log? If not specified, does it report anything to stderr or similar?
> (And the errors from tflite are logged, then I would probably also remove the
> `File::exists()` check as well.)

So it's my belief that if you don't say anything, they go to stderr.
I'm happy to remove the File::exists() check, the error message
doesn't even tell you where it looked, whereas the TFLite message
will.

Again, I'm slightly inclined not to bother with an ErrorReporter. It
works like printf which is slightly irritating, and then I'll just
start to worry about parsing the messages for errors, warnings, info
etc. and I'm not really seeing a benefit. If anyone disagrees, perhaps
we can leave that as a subsequent patch?

>
> > +
> > +        if (!model_) {
> > +                LOG(RPiAwb, Error) << "Failed to load model from " << modelPath;
> > +                nnConfig_.enableNn = false;
> > +                return;
> > +        }
> > +
> > +        tflite::MutableOpResolver resolver;
> > +        tflite::ops::builtin::BuiltinOpResolver builtin_resolver;
> > +        resolver.AddAll(builtin_resolver);
> > +        tflite::InterpreterBuilder(*model_, resolver)(&interpreter_);
> > +        if (!interpreter_) {
> > +                LOG(RPiAwb, Error) << "Failed to build interpreter for model " << nnConfig_.model;
> > +                nnConfig_.enableNn = false;
> > +                return;
> > +        }
> > +
> > +        interpreter_->AllocateTensors();
> > +        TfLiteTensor *inputTensor = interpreter_->input_tensor(0);
> > +        TfLiteTensor *inputLuxTensor = interpreter_->input_tensor(1);
> > +        TfLiteTensor *outputTensor = interpreter_->output_tensor(0);
> > +        if (!inputTensor || !inputLuxTensor || !outputTensor) {
> > +                LOG(RPiAwb, Error) << "Model missing input or output tensor";
> > +                nnConfig_.enableNn = false;
> > +                return;
> > +        }
> > +
> > +        const int expectedInputDims[] = { 1, (int)zoneSize_.height, (int)zoneSize_.width, 3 };
> > +        const int expectedInputLuxDims[] = { 1 };
> > +        const int expectedOutputDims[] = { 1 };
> > +
> > +        if (!checkTensorShape(inputTensor, expectedInputDims, 4)) {
> > +                LOG(RPiAwb, Error) << "Model input tensor dimension mismatch. Expected: " << buildDimString(expectedInputDims, 4)
> > +                                   << ", Got: " << buildDimString(inputTensor->dims->data, inputTensor->dims->size);
> > +                nnConfig_.enableNn = false;
> > +                return;
> > +        }
> > +
> > +        if (!checkTensorShape(inputLuxTensor, expectedInputLuxDims, 1)) {
> > +                LOG(RPiAwb, Error) << "Model input lux tensor dimension mismatch. Expected: " << buildDimString(expectedInputLuxDims, 1)
> > +                                   << ", Got: " << buildDimString(inputLuxTensor->dims->data, inputLuxTensor->dims->size);
> > +                nnConfig_.enableNn = false;
> > +                return;
> > +        }
> > +
> > +        if (!checkTensorShape(outputTensor, expectedOutputDims, 1)) {
> > +                LOG(RPiAwb, Error) << "Model output tensor dimension mismatch. Expected: " << buildDimString(expectedOutputDims, 1)
> > +                                   << ", Got: " << buildDimString(outputTensor->dims->data, outputTensor->dims->size);
> > +                nnConfig_.enableNn = false;
> > +                return;
> > +        }
> > +
> > +        if (inputTensor->type != kTfLiteFloat32 || inputLuxTensor->type != kTfLiteFloat32 || outputTensor->type != kTfLiteFloat32) {
> > +                LOG(RPiAwb, Error) << "Model input and output tensors must be float32";
> > +                nnConfig_.enableNn = false;
> > +                return;
> > +        }
> > +
> > +        LOG(RPiAwb, Info) << "Model loaded successfully from " << modelPath;
> > +        LOG(RPiAwb, Debug) << "Model validation successful - Input Image: "
> > +                           << buildDimString(expectedInputDims, 4)
> > +                           << ", Input Lux: " << buildDimString(expectedInputLuxDims, 1)
> > +                           << ", Output: " << buildDimString(expectedOutputDims, 1) << " floats";
> > +}
> > +
> > +void AwbNN::initialise()
> > +{
> > +        Awb::initialise();
> > +
> > +        if (nnConfig_.enableNn) {
> > +                loadModel();
> > +                if (!nnConfig_.enableNn) {
> > +                        LOG(RPiAwb, Warning) << "Neural Network AWB failed to load - switch to Grey method";
> > +                        config_.greyWorld = true;
> > +                        config_.sensitivityR = config_.sensitivityB = 1.0;
> > +                }
> > +        }
> > +}
> > +
> > +void AwbNN::prepareStats()
> > +{
> > +        zones_.clear();
> > +        /*
> > +         * LSC has already been applied to the stats in this pipeline, so stop
> > +         * any LSC compensation. We also ignore config_.fast in this version.
> > +         */
> > +        generateStats(zones_, statistics_, 0.0, 0.0, getGlobalMetadata(), 0.0, 0.0, 0.0);
> > +        /*
> > +         * apply sensitivities, so values appear to come from our "canonical"
> > +         * sensor.
> > +         */
> > +        for (auto &zone : zones_) {
> > +                zone.R *= config_.sensitivityR;
> > +                zone.B *= config_.sensitivityB;
> > +        }
> > +}
> > +
> > +void AwbNN::transverseSearch(double t, double &r, double &b)
> > +{
> > +        int spanR = -1, spanB = -1;
> > +        config_.ctR.eval(t, &spanR);
> > +        config_.ctB.eval(t, &spanB);
> > +
> > +        const int diff = 10;
> > +        double rDiff = config_.ctR.eval(t + diff, &spanR) -
> > +                       config_.ctR.eval(t - diff, &spanR);
> > +        double bDiff = config_.ctB.eval(t + diff, &spanB) -
> > +                       config_.ctB.eval(t - diff, &spanB);
> > +
> > +        ipa::Pwl::Point transverse({ bDiff, -rDiff });
> > +        if (transverse.length2() < 1e-6)
> > +                return;
> > +
> > +        transverse = transverse / transverse.length();
> > +        double transverseRange = config_.transverseNeg + config_.transversePos;
> > +        const int maxNumDeltas = 12;
> > +        int numDeltas = floor(transverseRange * 100 + 0.5) + 1;
> > +        numDeltas = numDeltas < 3 ? 3 : (numDeltas > maxNumDeltas ? maxNumDeltas : numDeltas);
>
> int numDeltas = std::clamp<int>(floor(transverseRange * 100 + 0.5) + 1, 3, maxNumDeltas);
>
> ?

Yes, will tidy that.

>
> > +
> > +        ipa::Pwl::Point points[maxNumDeltas];
> > +        int bestPoint = 0;
> > +
> > +        for (int i = 0; i < numDeltas; i++) {
> > +                points[i][0] = -config_.transverseNeg +
> > +                               (transverseRange * i) / (numDeltas - 1);
> > +                ipa::Pwl::Point rbTest = ipa::Pwl::Point({ r, b }) +
> > +                                         transverse * points[i].x();
> > +                double rTest = rbTest.x(), bTest = rbTest.y();
> > +                double gainR = 1 / rTest, gainB = 1 / bTest;
> > +                double delta2Sum = computeDelta2Sum(gainR, gainB, 0.0, 0.0);
> > +                points[i][1] = delta2Sum;
> > +                if (points[i].y() < points[bestPoint].y())
> > +                        bestPoint = i;
> > +        }
> > +
> > +        bestPoint = std::clamp(bestPoint, 1, numDeltas - 2);
> > +        ipa::Pwl::Point rbBest = ipa::Pwl::Point({ r, b }) +
> > +                                 transverse * interpolateQuadatric(points[bestPoint - 1],
> > +                                                                   points[bestPoint],
> > +                                                                   points[bestPoint + 1]);
> > +        double rBest = rbBest.x(), bBest = rbBest.y();
> > +
> > +        r = rBest, b = bBest;
>
> r = rbBest.x();
> b = rbBest.y();
>
> ?

Yes, this too.

>
> > +}
> > +
> > +AwbNN::RGB AwbNN::processZone(AwbNN::RGB zone, float redGain, float blueGain)
> > +{
> > +        /*
> > +         * Renders the pixel at canonical network colour temperature
> > +         */
> > +        RGB zoneGains = zone;
> > +
> > +        zoneGains.R *= redGain;
> > +        zoneGains.G *= 1.0;
> > +        zoneGains.B *= blueGain;
> > +
> > +        RGB zoneCcm;
> > +
> > +        zoneCcm.R = nnConfig_.ccm[0] * zoneGains.R + nnConfig_.ccm[1] * zoneGains.G + nnConfig_.ccm[2] * zoneGains.B;
> > +        zoneCcm.G = nnConfig_.ccm[3] * zoneGains.R + nnConfig_.ccm[4] * zoneGains.G + nnConfig_.ccm[5] * zoneGains.B;
> > +        zoneCcm.B = nnConfig_.ccm[6] * zoneGains.R + nnConfig_.ccm[7] * zoneGains.G + nnConfig_.ccm[8] * zoneGains.B;
> > +
> > +        return zoneCcm;
> > +}
> > +
> > +void AwbNN::awbNN()
> > +{
> > +        float *inputData = interpreter_->typed_input_tensor<float>(0);
> > +        float *inputLux = interpreter_->typed_input_tensor<float>(1);
> > +
> > +        float redGain = 1.0 / config_.ctR.eval(kNetworkCanonicalCT);
> > +        float blueGain = 1.0 / config_.ctB.eval(kNetworkCanonicalCT);
> > +
> > +        for (uint i = 0; i < zoneSize_.height; i++) {
> > +                for (uint j = 0; j < zoneSize_.width; j++) {
> > +                        uint zoneIdx = i * zoneSize_.width + j;
> > +
> > +                        RGB processedZone = processZone(zones_[zoneIdx] * (1.0 / 65535), redGain, blueGain);
> > +                        uint baseIdx = zoneIdx * 3;
>
> Where is this `uint` type coming from? tflite? Given that `zoneSize_` is `libcamera::Size`,
> which uses `unsigned int`, is this `uint` type necessary here?

Indeed, uint isn't a standard thing, though it often gets defined.
unsigned int is clearly more portable and better.

>
> > +
> > +                        inputData[baseIdx + 0] = static_cast<float>(processedZone.R);
> > +                        inputData[baseIdx + 1] = static_cast<float>(processedZone.G);
> > +                        inputData[baseIdx + 2] = static_cast<float>(processedZone.B);
> > +                }
> > +        }
> > +
> > +        inputLux[0] = static_cast<float>(lux_);
> > +
> > +        TfLiteStatus status = interpreter_->Invoke();
> > +        if (status != kTfLiteOk) {
> > +                LOG(RPiAwb, Error) << "Model inference failed with status: " << status;
> > +                return;
> > +        }
> > +
> > +        float *outputData = interpreter_->typed_output_tensor<float>(0);
> > +
> > +        double t = outputData[0];
> > +
> > +        LOG(RPiAwb, Debug) << "Model output temperature: " << t;
> > +
> > +        t = std::clamp(t, mode_->ctLo, mode_->ctHi);
> > +
> > +        double r = config_.ctR.eval(t);
> > +        double b = config_.ctB.eval(t);
> > +
> > +        transverseSearch(t, r, b);
> > +
> > +        LOG(RPiAwb, Debug) << "After transverse search: Temperature: " << t << " Red gain: " << 1.0 / r << " Blue gain: " << 1.0 / b;
> > +
> > +        asyncResults_.temperatureK = t;
> > +        asyncResults_.gainR = 1.0 / r * config_.sensitivityR;
> > +        asyncResults_.gainG = 1.0;
> > +        asyncResults_.gainB = 1.0 / b * config_.sensitivityB;
> > +}
> > +
> > +void AwbNN::doAwb()
> > +{
> > +        prepareStats();
> > +        if (zones_.size() == (zoneSize_.width * zoneSize_.height) && nnConfig_.enableNn)
> > +                awbNN();
> > +        else
> > +                awbGrey();
> > +        statistics_.reset();
> > +}
> > +
> > +/* Register algorithm with the system. */
> > +static Algorithm *create(Controller *controller)
> > +{
> > +        return (Algorithm *)new AwbNN(controller);
>
> Please omit this cast.

Actually I think there are quite a few of these. I'll fix it in the
files touched by this patch set, but the others will have to wait!

Will post a v6 shortly.

Best regards
David

>
>
> Regards,
> Barnabás Pőcze
>
> > +}
> > +static RegisterAlgorithm reg(NAME, &create);
> > +
> > +} /* namespace RPiController */
>
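Since the std::equal() form of checkTensorShape() is agreed for v6, here is a
stand-alone illustration of it. TfLiteIntArray and TfLiteTensor are stubbed
with minimal structs so the snippet compiles on its own (the real definitions
come from the TFLite headers and differ in detail), so treat it as an untested
sketch rather than the final v6 code.

    #include <algorithm>
    #include <cassert>

    /* Minimal stand-ins for the TFLite types, for illustration only. */
    struct TfLiteIntArray {
            int size;
            int data[8];
    };

    struct TfLiteTensor {
            TfLiteIntArray *dims;
    };

    static bool checkTensorShape(const TfLiteTensor *tensor, const int *expectedDims,
                                 int expectedDimsSize)
    {
            /* The four-iterator overload also rejects mismatched lengths. */
            return std::equal(expectedDims, expectedDims + expectedDimsSize,
                              tensor->dims->data, tensor->dims->data + tensor->dims->size);
    }

    int main()
    {
            TfLiteIntArray dims{ 4, { 1, 32, 32, 3 } };
            TfLiteTensor tensor{ &dims };

            const int expected[] = { 1, 32, 32, 3 };
            assert(checkTensorShape(&tensor, expected, 4));

            /* A shorter expected shape must fail, replacing the old size check. */
            const int wrong[] = { 1, 32, 32 };
            assert(!checkTensorShape(&tensor, wrong, 3));

            return 0;
    }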
diff --git a/meson_options.txt b/meson_options.txt
index c052e85a..07847294 100644
--- a/meson_options.txt
+++ b/meson_options.txt
@@ -76,6 +76,11 @@ option('qcam',
         value : 'auto',
         description : 'Compile the qcam test application')
 
+option('rpi-awb-nn',
+        type : 'feature',
+        value : 'auto',
+        description : 'Enable the Raspberry Pi Neural Network AWB algorithm')
+
 option('test',
         type : 'boolean',
         value : false,
diff --git a/src/ipa/rpi/controller/meson.build b/src/ipa/rpi/controller/meson.build
index c8637906..03ee7c20 100644
--- a/src/ipa/rpi/controller/meson.build
+++ b/src/ipa/rpi/controller/meson.build
@@ -32,6 +32,15 @@ rpi_ipa_controller_deps = [
     libcamera_private,
 ]
 
+tflite_dep = dependency('tensorflow-lite', required : get_option('rpi-awb-nn'))
+
+if tflite_dep.found()
+    rpi_ipa_controller_sources += files([
+        'rpi/awb_nn.cpp',
+    ])
+    rpi_ipa_controller_deps += tflite_dep
+endif
+
 rpi_ipa_controller_lib = static_library('rpi_ipa_controller', rpi_ipa_controller_sources,
                                         include_directories : libipa_includes,
                                         dependencies : rpi_ipa_controller_deps)
diff --git a/src/ipa/rpi/controller/rpi/awb_nn.cpp b/src/ipa/rpi/controller/rpi/awb_nn.cpp
new file mode 100644
index 00000000..395add85
--- /dev/null
+++ b/src/ipa/rpi/controller/rpi/awb_nn.cpp
@@ -0,0 +1,456 @@
+/* SPDX-License-Identifier: BSD-2-Clause */
+/*
+ * Copyright (C) 2025, Raspberry Pi Ltd
+ *
+ * AWB control algorithm using neural network
+ *
+ * The AWB Neural Network algorithm can be run entirely with the code here
+ * and the supplied TFLite models. Those interested in the full model
+ * definitions, or who may want to re-train the models should visit
+ *
+ * https://github.com/raspberrypi/awb_nn
+ *
+ * where you will find full source code for the models, the full datasets
+ * used for training our supplied models, and full instructions for capturing
+ * your own images and re-training the models for your own use cases.
+ */
+
+#include <chrono>
+#include <condition_variable>
+#include <thread>
+
+#include <libcamera/base/file.h>
+#include <libcamera/base/log.h>
+
+#include <tensorflow/lite/interpreter.h>
+#include <tensorflow/lite/kernels/register.h>
+#include <tensorflow/lite/model.h>
+
+#include "../awb_algorithm.h"
+#include "../awb_status.h"
+#include "../lux_status.h"
+#include "libipa/pwl.h"
+
+#include "alsc_status.h"
+#include "awb.h"
+
+using namespace libcamera;
+
+LOG_DECLARE_CATEGORY(RPiAwb)
+
+constexpr double kDefaultCT = 4500.0;
+
+/*
+ * The neural networks are trained to work on images rendered at a canonical
+ * colour temperature. That value is 5000K, which must be reproduced here.
+ */
+constexpr double kNetworkCanonicalCT = 5000.0;
+
+#define NAME "rpi.nn.awb"
+
+namespace RPiController {
+
+struct AwbNNConfig {
+        AwbNNConfig() {}
+        int read(const libcamera::YamlObject &params, AwbConfig &config);
+
+        /* An empty model will check default locations for model.tflite */
+        std::string model;
+        float minTemp;
+        float maxTemp;
+
+        bool enableNn;
+
+        /* CCM matrix for canonical network CT */
+        double ccm[9];
+};
+
+class AwbNN : public Awb
+{
+public:
+        AwbNN(Controller *controller = NULL);
+        ~AwbNN();
+        char const *name() const override;
+        void initialise() override;
+        int read(const libcamera::YamlObject &params) override;
+
+protected:
+        void doAwb() override;
+        void prepareStats() override;
+
+private:
+        bool isAutoEnabled() const;
+        AwbNNConfig nnConfig_;
+        void transverseSearch(double t, double &r, double &b);
+        RGB processZone(RGB zone, float red_gain, float blue_gain);
+        void awbNN();
+        void loadModel();
+
+        libcamera::Size zoneSize_;
+        std::unique_ptr<tflite::FlatBufferModel> model_;
+        std::unique_ptr<tflite::Interpreter> interpreter_;
+};
+
+int AwbNNConfig::read(const libcamera::YamlObject &params, AwbConfig &config)
+{
+        model = params["model"].get<std::string>("");
+        minTemp = params["min_temp"].get<float>(2800.0);
+        maxTemp = params["max_temp"].get<float>(7600.0);
+
+        for (int i = 0; i < 9; i++)
+                ccm[i] = params["ccm"][i].get<double>(0.0);
+
+        enableNn = params["enable_nn"].get<int>(1);
+
+        if (enableNn) {
+                if (!config.hasCtCurve()) {
+                        LOG(RPiAwb, Error) << "CT curve not specified";
+                        enableNn = false;
+                }
+
+                if (!model.empty() && model.find(".tflite") == std::string::npos) {
+                        LOG(RPiAwb, Error) << "Model must be a .tflite file";
+                        enableNn = false;
+                }
+
+                bool validCcm = true;
+                for (int i = 0; i < 9; i++)
+                        if (ccm[i] == 0.0)
+                                validCcm = false;
+
+                if (!validCcm) {
+                        LOG(RPiAwb, Error) << "CCM not specified or invalid";
+                        enableNn = false;
+                }
+
+                if (!enableNn) {
+                        LOG(RPiAwb, Warning) << "Neural Network AWB mis-configured - switch to Grey method";
+                }
+        }
+
+        if (!enableNn) {
+                config.sensitivityR = config.sensitivityB = 1.0;
+                config.greyWorld = true;
+        }
+
+        return 0;
+}
+
+AwbNN::AwbNN(Controller *controller)
+        : Awb(controller)
+{
+        zoneSize_ = getHardwareConfig().awbRegions;
+}
+
+AwbNN::~AwbNN()
+{
+}
+
+char const *AwbNN::name() const
+{
+        return NAME;
+}
+
+int AwbNN::read(const libcamera::YamlObject &params)
+{
+        int ret;
+
+        ret = config_.read(params);
+        if (ret)
+                return ret;
+
+        ret = nnConfig_.read(params, config_);
+        if (ret)
+                return ret;
+
+        return 0;
+}
+
+static bool checkTensorShape(TfLiteTensor *tensor, const int *expectedDims, const int expectedDimsSize)
+{
+        if (tensor->dims->size != expectedDimsSize)
+                return false;
+
+        for (int i = 0; i < tensor->dims->size; i++) {
+                if (tensor->dims->data[i] != expectedDims[i]) {
+                        return false;
+                }
+        }
+        return true;
+}
+
+static std::string buildDimString(const int *dims, const int dimsSize)
+{
+        std::string s = "[";
+        for (int i = 0; i < dimsSize; i++) {
+                s += std::to_string(dims[i]);
+                if (i < dimsSize - 1)
+                        s += ",";
+                else
+                        s += "]";
+        }
+        return s;
+}
+
+void AwbNN::loadModel()
+{
+        std::string modelPath;
+        if (getTarget() == "bcm2835") {
+                modelPath = "/ipa/rpi/vc4/awb_model.tflite";
+        } else {
+                modelPath = "/ipa/rpi/pisp/awb_model.tflite";
+        }
+
+        if (nnConfig_.model.empty()) {
+                std::string root = utils::libcameraSourcePath();
+                if (!root.empty()) {
+                        modelPath = root + modelPath;
+                } else {
+                        modelPath = LIBCAMERA_DATA_DIR + modelPath;
+                }
+
+                if (!File::exists(modelPath)) {
+                        LOG(RPiAwb, Error) << "No model file found in standard locations";
+                        nnConfig_.enableNn = false;
+                        return;
+                }
+        } else {
+                modelPath = nnConfig_.model;
+        }
+
+        LOG(RPiAwb, Debug) << "Attempting to load model from: " << modelPath;
+
+        model_ = tflite::FlatBufferModel::BuildFromFile(modelPath.c_str());
+
+        if (!model_) {
+                LOG(RPiAwb, Error) << "Failed to load model from " << modelPath;
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        tflite::MutableOpResolver resolver;
+        tflite::ops::builtin::BuiltinOpResolver builtin_resolver;
+        resolver.AddAll(builtin_resolver);
+        tflite::InterpreterBuilder(*model_, resolver)(&interpreter_);
+        if (!interpreter_) {
+                LOG(RPiAwb, Error) << "Failed to build interpreter for model " << nnConfig_.model;
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        interpreter_->AllocateTensors();
+        TfLiteTensor *inputTensor = interpreter_->input_tensor(0);
+        TfLiteTensor *inputLuxTensor = interpreter_->input_tensor(1);
+        TfLiteTensor *outputTensor = interpreter_->output_tensor(0);
+        if (!inputTensor || !inputLuxTensor || !outputTensor) {
+                LOG(RPiAwb, Error) << "Model missing input or output tensor";
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        const int expectedInputDims[] = { 1, (int)zoneSize_.height, (int)zoneSize_.width, 3 };
+        const int expectedInputLuxDims[] = { 1 };
+        const int expectedOutputDims[] = { 1 };
+
+        if (!checkTensorShape(inputTensor, expectedInputDims, 4)) {
+                LOG(RPiAwb, Error) << "Model input tensor dimension mismatch. Expected: " << buildDimString(expectedInputDims, 4)
+                                   << ", Got: " << buildDimString(inputTensor->dims->data, inputTensor->dims->size);
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        if (!checkTensorShape(inputLuxTensor, expectedInputLuxDims, 1)) {
+                LOG(RPiAwb, Error) << "Model input lux tensor dimension mismatch. Expected: " << buildDimString(expectedInputLuxDims, 1)
+                                   << ", Got: " << buildDimString(inputLuxTensor->dims->data, inputLuxTensor->dims->size);
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        if (!checkTensorShape(outputTensor, expectedOutputDims, 1)) {
+                LOG(RPiAwb, Error) << "Model output tensor dimension mismatch. Expected: " << buildDimString(expectedOutputDims, 1)
+                                   << ", Got: " << buildDimString(outputTensor->dims->data, outputTensor->dims->size);
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        if (inputTensor->type != kTfLiteFloat32 || inputLuxTensor->type != kTfLiteFloat32 || outputTensor->type != kTfLiteFloat32) {
+                LOG(RPiAwb, Error) << "Model input and output tensors must be float32";
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        LOG(RPiAwb, Info) << "Model loaded successfully from " << modelPath;
+        LOG(RPiAwb, Debug) << "Model validation successful - Input Image: "
+                           << buildDimString(expectedInputDims, 4)
+                           << ", Input Lux: " << buildDimString(expectedInputLuxDims, 1)
+                           << ", Output: " << buildDimString(expectedOutputDims, 1) << " floats";
+}
+
+void AwbNN::initialise()
+{
+        Awb::initialise();
+
+        if (nnConfig_.enableNn) {
+                loadModel();
+                if (!nnConfig_.enableNn) {
+                        LOG(RPiAwb, Warning) << "Neural Network AWB failed to load - switch to Grey method";
+                        config_.greyWorld = true;
+                        config_.sensitivityR = config_.sensitivityB = 1.0;
+                }
+        }
+}
+
+void AwbNN::prepareStats()
+{
+        zones_.clear();
+        /*
+         * LSC has already been applied to the stats in this pipeline, so stop
+         * any LSC compensation. We also ignore config_.fast in this version.
+         */
+        generateStats(zones_, statistics_, 0.0, 0.0, getGlobalMetadata(), 0.0, 0.0, 0.0);
+        /*
+         * apply sensitivities, so values appear to come from our "canonical"
+         * sensor.
+         */
+        for (auto &zone : zones_) {
+                zone.R *= config_.sensitivityR;
+                zone.B *= config_.sensitivityB;
+        }
+}
+
+void AwbNN::transverseSearch(double t, double &r, double &b)
+{
+        int spanR = -1, spanB = -1;
+        config_.ctR.eval(t, &spanR);
+        config_.ctB.eval(t, &spanB);
+
+        const int diff = 10;
+        double rDiff = config_.ctR.eval(t + diff, &spanR) -
+                       config_.ctR.eval(t - diff, &spanR);
+        double bDiff = config_.ctB.eval(t + diff, &spanB) -
+                       config_.ctB.eval(t - diff, &spanB);
+
+        ipa::Pwl::Point transverse({ bDiff, -rDiff });
+        if (transverse.length2() < 1e-6)
+                return;
+
+        transverse = transverse / transverse.length();
+        double transverseRange = config_.transverseNeg + config_.transversePos;
+        const int maxNumDeltas = 12;
+        int numDeltas = floor(transverseRange * 100 + 0.5) + 1;
+        numDeltas = numDeltas < 3 ? 3 : (numDeltas > maxNumDeltas ? maxNumDeltas : numDeltas);
+
+        ipa::Pwl::Point points[maxNumDeltas];
+        int bestPoint = 0;
+
+        for (int i = 0; i < numDeltas; i++) {
+                points[i][0] = -config_.transverseNeg +
+                               (transverseRange * i) / (numDeltas - 1);
+                ipa::Pwl::Point rbTest = ipa::Pwl::Point({ r, b }) +
+                                         transverse * points[i].x();
+                double rTest = rbTest.x(), bTest = rbTest.y();
+                double gainR = 1 / rTest, gainB = 1 / bTest;
+                double delta2Sum = computeDelta2Sum(gainR, gainB, 0.0, 0.0);
+                points[i][1] = delta2Sum;
+                if (points[i].y() < points[bestPoint].y())
+                        bestPoint = i;
+        }
+
+        bestPoint = std::clamp(bestPoint, 1, numDeltas - 2);
+        ipa::Pwl::Point rbBest = ipa::Pwl::Point({ r, b }) +
+                                 transverse * interpolateQuadatric(points[bestPoint - 1],
+                                                                   points[bestPoint],
+                                                                   points[bestPoint + 1]);
+        double rBest = rbBest.x(), bBest = rbBest.y();
+
+        r = rBest, b = bBest;
+}
+
+AwbNN::RGB AwbNN::processZone(AwbNN::RGB zone, float redGain, float blueGain)
+{
+        /*
+         * Renders the pixel at canonical network colour temperature
+         */
+        RGB zoneGains = zone;
+
+        zoneGains.R *= redGain;
+        zoneGains.G *= 1.0;
+        zoneGains.B *= blueGain;
+
+        RGB zoneCcm;
+
+        zoneCcm.R = nnConfig_.ccm[0] * zoneGains.R + nnConfig_.ccm[1] * zoneGains.G + nnConfig_.ccm[2] * zoneGains.B;
+        zoneCcm.G = nnConfig_.ccm[3] * zoneGains.R + nnConfig_.ccm[4] * zoneGains.G + nnConfig_.ccm[5] * zoneGains.B;
+        zoneCcm.B = nnConfig_.ccm[6] * zoneGains.R + nnConfig_.ccm[7] * zoneGains.G + nnConfig_.ccm[8] * zoneGains.B;
+
+        return zoneCcm;
+}
+
+void AwbNN::awbNN()
+{
+        float *inputData = interpreter_->typed_input_tensor<float>(0);
+        float *inputLux = interpreter_->typed_input_tensor<float>(1);
+
+        float redGain = 1.0 / config_.ctR.eval(kNetworkCanonicalCT);
+        float blueGain = 1.0 / config_.ctB.eval(kNetworkCanonicalCT);
+
+        for (uint i = 0; i < zoneSize_.height; i++) {
+                for (uint j = 0; j < zoneSize_.width; j++) {
+                        uint zoneIdx = i * zoneSize_.width + j;
+
+                        RGB processedZone = processZone(zones_[zoneIdx] * (1.0 / 65535), redGain, blueGain);
+                        uint baseIdx = zoneIdx * 3;
+
+                        inputData[baseIdx + 0] = static_cast<float>(processedZone.R);
+                        inputData[baseIdx + 1] = static_cast<float>(processedZone.G);
+                        inputData[baseIdx + 2] = static_cast<float>(processedZone.B);
+                }
+        }
+
+        inputLux[0] = static_cast<float>(lux_);
+
+        TfLiteStatus status = interpreter_->Invoke();
+        if (status != kTfLiteOk) {
+                LOG(RPiAwb, Error) << "Model inference failed with status: " << status;
+                return;
+        }
+
+        float *outputData = interpreter_->typed_output_tensor<float>(0);
+
+        double t = outputData[0];
+
+        LOG(RPiAwb, Debug) << "Model output temperature: " << t;
+
+        t = std::clamp(t, mode_->ctLo, mode_->ctHi);
+
+        double r = config_.ctR.eval(t);
+        double b = config_.ctB.eval(t);
+
+        transverseSearch(t, r, b);
+
+        LOG(RPiAwb, Debug) << "After transverse search: Temperature: " << t << " Red gain: " << 1.0 / r << " Blue gain: " << 1.0 / b;
+
+        asyncResults_.temperatureK = t;
+        asyncResults_.gainR = 1.0 / r * config_.sensitivityR;
+        asyncResults_.gainG = 1.0;
+        asyncResults_.gainB = 1.0 / b * config_.sensitivityB;
+}
+
+void AwbNN::doAwb()
+{
+        prepareStats();
+        if (zones_.size() == (zoneSize_.width * zoneSize_.height) && nnConfig_.enableNn)
+                awbNN();
+        else
+                awbGrey();
+        statistics_.reset();
+}
+
+/* Register algorithm with the system. */
+static Algorithm *create(Controller *controller)
+{
+        return (Algorithm *)new AwbNN(controller);
+}
+static RegisterAlgorithm reg(NAME, &create);
+
+} /* namespace RPiController */
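For readers wiring this up, the options read by AwbNNConfig::read() above come
from the camera tuning file. A hypothetical fragment might look like the
following; the CCM numbers are placeholders for illustration (any entry of
exactly 0.0 would disable the network), an empty "model" means the default
locations are searched, and the algorithm also relies on the base Awb
configuration (CT curve, sensitivities) being present in the same block:

    {
        "rpi.nn.awb":
        {
            "model": "",
            "min_temp": 2800,
            "max_temp": 7600,
            "enable_nn": 1,
            "ccm": [ 1.80, -0.50, -0.30,
                     -0.40, 1.70, -0.30,
                     -0.10, -0.60, 1.70 ]
        }
    }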