| Message ID | 20251212103401.3776-3-david.plowman@raspberrypi.com |
|---|---|
| State | Superseded |
Hi David,

Thank you for the patch.

On Fri, Dec 12, 2025 at 10:23:51AM +0000, David Plowman wrote:
> From: Peter Bailey <peter.bailey@raspberrypi.com>
>
> Add an AWB algorithm which uses neural networks.

Staying at a high level (I'm still in Japan, I'll get back to reviews
later this week), I would like to discuss the license of the neural
network.

I understand you don't plan to publish the full training data set, as it
contains images that depict people. How about the network model?

> Signed-off-by: Peter Bailey <peter.bailey@raspberrypi.com>
> Reviewed-by: David Plowman <david.plowman@raspberrypi.com>
> Reviewed-by: Naushir Patuck <naush@raspberrypi.com>
> ---
>  meson_options.txt                     |   5 +
>  src/ipa/rpi/controller/meson.build    |   9 +
>  src/ipa/rpi/controller/rpi/awb_nn.cpp | 446 ++++++++++++++++++++++++++
>  3 files changed, 460 insertions(+)
>  create mode 100644 src/ipa/rpi/controller/rpi/awb_nn.cpp
>
> [snip - the full patch is reproduced at the end of this page]

--
Regards,

Laurent Pinchart
Hi Laurent

Yes, we're going to publish the model and the training data. You can find
both here <https://github.com/raspberrypi/AWB_NN>, where there are full
instructions for adding your own images and re-training.

We have in fact published two versions of the dataset, one for Pi 5 and
one for earlier Pis. In the Pi 5 case the training images are all reduced
to 32x32 resolution, so there are no privacy worries, and for earlier Pis
it's 16x12. (These are the resolutions of the AWB statistics on the two
platforms.) Users can, if they want, re-train with only their own images,
or use a number (or all) of ours as well.

Hope that clarifies things a bit!

David

On Tue, 16 Dec 2025 at 10:33, Laurent Pinchart
<laurent.pinchart@ideasonboard.com> wrote:
> Staying at a high level (I'm still in Japan, I'll get back to reviews
> later this week), I would like to discuss the license of the neural
> network.
>
> I understand you don't plan to publish the full training data set, as it
> contains images that depict people. How about the network model?
>
> [snip]
>
> --
> Regards,
>
> Laurent Pinchart
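The parameters consumed by AwbNNConfig::read() in the patch below indicate
how a re-trained network would be selected from a camera tuning file. The
fragment that follows is a minimal sketch only: the model path and CCM
values are illustrative placeholders, and the base AWB parameters (CT
curve, sensitivities and so on) that the same section must also carry are
omitted:

    "rpi.nn.awb":
    {
        "model": "/home/pi/models/my_awb_model.tflite",
        "enable_nn": 1,
        "min_temp": 2800,
        "max_temp": 7600,
        "ccm": [ 1.50, -0.30, -0.20,
                 -0.40, 1.60, -0.20,
                 -0.10, -0.50, 1.60 ]
    }

Per the validation in AwbNNConfig::read(), all nine "ccm" entries must be
non-zero, the model file name must contain ".tflite", and a CT curve must
be present, otherwise the code logs an error and drops back to the
grey-world method. An empty "model" makes loadModel() search the standard
locations for awb_model.tflite instead.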
diff --git a/meson_options.txt b/meson_options.txt
index 5954e028..89eece52 100644
--- a/meson_options.txt
+++ b/meson_options.txt
@@ -78,6 +78,11 @@ option('qcam',
         value : 'disabled',
         description : 'Compile the qcam test application')
 
+option('rpi-awb-nn',
+        type : 'feature',
+        value : 'auto',
+        description : 'Enable the Raspberry Pi Neural Network AWB algorithm')
+
 option('test',
         type : 'boolean',
         value : false,
diff --git a/src/ipa/rpi/controller/meson.build b/src/ipa/rpi/controller/meson.build
index 90d9e285..eba6cb28 100644
--- a/src/ipa/rpi/controller/meson.build
+++ b/src/ipa/rpi/controller/meson.build
@@ -33,6 +33,15 @@ rpi_ipa_controller_deps = [
     libcamera_private,
 ]
 
+tflite_dep = dependency('tensorflow-lite', required : get_option('rpi-awb-nn'))
+
+if tflite_dep.found()
+    rpi_ipa_controller_sources += files([
+        'rpi/awb_nn.cpp',
+    ])
+    rpi_ipa_controller_deps += tflite_dep
+endif
+
 rpi_ipa_controller_lib = static_library('rpi_ipa_controller', rpi_ipa_controller_sources,
                                         include_directories : libipa_includes,
                                         dependencies : rpi_ipa_controller_deps)
diff --git a/src/ipa/rpi/controller/rpi/awb_nn.cpp b/src/ipa/rpi/controller/rpi/awb_nn.cpp
new file mode 100644
index 00000000..35d1270e
--- /dev/null
+++ b/src/ipa/rpi/controller/rpi/awb_nn.cpp
@@ -0,0 +1,446 @@
+/* SPDX-License-Identifier: BSD-2-Clause */
+/*
+ * Copyright (C) 2025, Raspberry Pi Ltd
+ *
+ * AWB control algorithm using neural network
+ */
+
+#include <chrono>
+#include <condition_variable>
+#include <thread>
+
+#include <libcamera/base/file.h>
+#include <libcamera/base/log.h>
+
+#include <tensorflow/lite/interpreter.h>
+#include <tensorflow/lite/kernels/register.h>
+#include <tensorflow/lite/model.h>
+
+#include "../awb_algorithm.h"
+#include "../awb_status.h"
+#include "../lux_status.h"
+#include "libipa/pwl.h"
+
+#include "alsc_status.h"
+#include "awb.h"
+
+using namespace libcamera;
+
+LOG_DECLARE_CATEGORY(RPiAwb)
+
+constexpr double kDefaultCT = 4500.0;
+
+/*
+ * The neural networks are trained to work on images rendered at a canonical
+ * colour temperature. That value is 5000K, which must be reproduced here.
+ */
+constexpr double kNetworkCanonicalCT = 5000.0;
+
+#define NAME "rpi.nn.awb"
+
+namespace RPiController {
+
+struct AwbNNConfig {
+        AwbNNConfig() {}
+        int read(const libcamera::YamlObject &params, AwbConfig &config);
+
+        /* An empty model will check default locations for model.tflite */
+        std::string model;
+        float minTemp;
+        float maxTemp;
+
+        bool enableNn;
+
+        /* CCM matrix for canonical network CT */
+        double ccm[9];
+};
+
+class AwbNN : public Awb
+{
+public:
+        AwbNN(Controller *controller = NULL);
+        ~AwbNN();
+        char const *name() const override;
+        void initialise() override;
+        int read(const libcamera::YamlObject &params) override;
+
+protected:
+        void doAwb() override;
+        void prepareStats() override;
+
+private:
+        bool isAutoEnabled() const;
+        AwbNNConfig nnConfig_;
+        void transverseSearch(double t, double &r, double &b);
+        RGB processZone(RGB zone, float red_gain, float blue_gain);
+        void awbNN();
+        void loadModel();
+
+        libcamera::Size zoneSize_;
+        std::unique_ptr<tflite::FlatBufferModel> model_;
+        std::unique_ptr<tflite::Interpreter> interpreter_;
+};
+
+int AwbNNConfig::read(const libcamera::YamlObject &params, AwbConfig &config)
+{
+        model = params["model"].get<std::string>("");
+        minTemp = params["min_temp"].get<float>(2800.0);
+        maxTemp = params["max_temp"].get<float>(7600.0);
+
+        for (int i = 0; i < 9; i++)
+                ccm[i] = params["ccm"][i].get<double>(0.0);
+
+        enableNn = params["enable_nn"].get<int>(1);
+
+        if (enableNn) {
+                if (!config.hasCtCurve()) {
+                        LOG(RPiAwb, Error) << "CT curve not specified";
+                        enableNn = false;
+                }
+
+                if (!model.empty() && model.find(".tflite") == std::string::npos) {
+                        LOG(RPiAwb, Error) << "Model must be a .tflite file";
+                        enableNn = false;
+                }
+
+                bool validCcm = true;
+                for (int i = 0; i < 9; i++)
+                        if (ccm[i] == 0.0)
+                                validCcm = false;
+
+                if (!validCcm) {
+                        LOG(RPiAwb, Error) << "CCM not specified or invalid";
+                        enableNn = false;
+                }
+
+                if (!enableNn) {
+                        LOG(RPiAwb, Warning) << "Neural Network AWB mis-configured - switch to Grey method";
+                }
+        }
+
+        if (!enableNn) {
+                config.sensitivityR = config.sensitivityB = 1.0;
+                config.greyWorld = true;
+        }
+
+        return 0;
+}
+
+AwbNN::AwbNN(Controller *controller)
+        : Awb(controller)
+{
+        zoneSize_ = getHardwareConfig().awbRegions;
+}
+
+AwbNN::~AwbNN()
+{
+}
+
+char const *AwbNN::name() const
+{
+        return NAME;
+}
+
+int AwbNN::read(const libcamera::YamlObject &params)
+{
+        int ret;
+
+        ret = config_.read(params);
+        if (ret)
+                return ret;
+
+        ret = nnConfig_.read(params, config_);
+        if (ret)
+                return ret;
+
+        return 0;
+}
+
+static bool checkTensorShape(TfLiteTensor *tensor, const int *expectedDims, const int expectedDimsSize)
+{
+        if (tensor->dims->size != expectedDimsSize)
+                return false;
+
+        for (int i = 0; i < tensor->dims->size; i++) {
+                if (tensor->dims->data[i] != expectedDims[i]) {
+                        return false;
+                }
+        }
+        return true;
+}
+
+static std::string buildDimString(const int *dims, const int dimsSize)
+{
+        std::string s = "[";
+        for (int i = 0; i < dimsSize; i++) {
+                s += std::to_string(dims[i]);
+                if (i < dimsSize - 1)
+                        s += ",";
+                else
+                        s += "]";
+        }
+        return s;
+}
+
+void AwbNN::loadModel()
+{
+        std::string modelPath;
+        if (getTarget() == "bcm2835") {
+                modelPath = "/ipa/rpi/vc4/awb_model.tflite";
+        } else {
+                modelPath = "/ipa/rpi/pisp/awb_model.tflite";
+        }
+
+        if (nnConfig_.model.empty()) {
+                std::string root = utils::libcameraSourcePath();
+                if (!root.empty()) {
+                        modelPath = root + modelPath;
+                } else {
+                        modelPath = LIBCAMERA_DATA_DIR + modelPath;
+                }
+
+                if (!File::exists(modelPath)) {
+                        LOG(RPiAwb, Error) << "No model file found in standard locations";
+                        nnConfig_.enableNn = false;
+                        return;
+                }
+        } else {
+                modelPath = nnConfig_.model;
+        }
+
+        LOG(RPiAwb, Debug) << "Attempting to load model from: " << modelPath;
+
+        model_ = tflite::FlatBufferModel::BuildFromFile(modelPath.c_str());
+
+        if (!model_) {
+                LOG(RPiAwb, Error) << "Failed to load model from " << modelPath;
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        tflite::MutableOpResolver resolver;
+        tflite::ops::builtin::BuiltinOpResolver builtin_resolver;
+        resolver.AddAll(builtin_resolver);
+        tflite::InterpreterBuilder(*model_, resolver)(&interpreter_);
+        if (!interpreter_) {
+                LOG(RPiAwb, Error) << "Failed to build interpreter for model " << nnConfig_.model;
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        interpreter_->AllocateTensors();
+        TfLiteTensor *inputTensor = interpreter_->input_tensor(0);
+        TfLiteTensor *inputLuxTensor = interpreter_->input_tensor(1);
+        TfLiteTensor *outputTensor = interpreter_->output_tensor(0);
+        if (!inputTensor || !inputLuxTensor || !outputTensor) {
+                LOG(RPiAwb, Error) << "Model missing input or output tensor";
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        const int expectedInputDims[] = { 1, (int)zoneSize_.height, (int)zoneSize_.width, 3 };
+        const int expectedInputLuxDims[] = { 1 };
+        const int expectedOutputDims[] = { 1 };
+
+        if (!checkTensorShape(inputTensor, expectedInputDims, 4)) {
+                LOG(RPiAwb, Error) << "Model input tensor dimension mismatch. Expected: " << buildDimString(expectedInputDims, 4)
+                                   << ", Got: " << buildDimString(inputTensor->dims->data, inputTensor->dims->size);
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        if (!checkTensorShape(inputLuxTensor, expectedInputLuxDims, 1)) {
+                LOG(RPiAwb, Error) << "Model input lux tensor dimension mismatch. Expected: " << buildDimString(expectedInputLuxDims, 1)
+                                   << ", Got: " << buildDimString(inputLuxTensor->dims->data, inputLuxTensor->dims->size);
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        if (!checkTensorShape(outputTensor, expectedOutputDims, 1)) {
+                LOG(RPiAwb, Error) << "Model output tensor dimension mismatch. Expected: " << buildDimString(expectedOutputDims, 1)
+                                   << ", Got: " << buildDimString(outputTensor->dims->data, outputTensor->dims->size);
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        if (inputTensor->type != kTfLiteFloat32 || inputLuxTensor->type != kTfLiteFloat32 || outputTensor->type != kTfLiteFloat32) {
+                LOG(RPiAwb, Error) << "Model input and output tensors must be float32";
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        LOG(RPiAwb, Info) << "Model loaded successfully from " << modelPath;
+        LOG(RPiAwb, Debug) << "Model validation successful - Input Image: "
+                           << buildDimString(expectedInputDims, 4)
+                           << ", Input Lux: " << buildDimString(expectedInputLuxDims, 1)
+                           << ", Output: " << buildDimString(expectedOutputDims, 1) << " floats";
+}
+
+void AwbNN::initialise()
+{
+        Awb::initialise();
+
+        if (nnConfig_.enableNn) {
+                loadModel();
+                if (!nnConfig_.enableNn) {
+                        LOG(RPiAwb, Warning) << "Neural Network AWB failed to load - switch to Grey method";
+                        config_.greyWorld = true;
+                        config_.sensitivityR = config_.sensitivityB = 1.0;
+                }
+        }
+}
+
+void AwbNN::prepareStats()
+{
+        zones_.clear();
+        /*
+         * LSC has already been applied to the stats in this pipeline, so stop
+         * any LSC compensation. We also ignore config_.fast in this version.
+         */
+        generateStats(zones_, statistics_, 0.0, 0.0, getGlobalMetadata(), 0.0, 0.0, 0.0);
+        /*
+         * apply sensitivities, so values appear to come from our "canonical"
+         * sensor.
+         */
+        for (auto &zone : zones_) {
+                zone.R *= config_.sensitivityR;
+                zone.B *= config_.sensitivityB;
+        }
+}
+
+void AwbNN::transverseSearch(double t, double &r, double &b)
+{
+        int spanR = -1, spanB = -1;
+        config_.ctR.eval(t, &spanR);
+        config_.ctB.eval(t, &spanB);
+
+        const int diff = 10;
+        double rDiff = config_.ctR.eval(t + diff, &spanR) -
+                       config_.ctR.eval(t - diff, &spanR);
+        double bDiff = config_.ctB.eval(t + diff, &spanB) -
+                       config_.ctB.eval(t - diff, &spanB);
+
+        ipa::Pwl::Point transverse({ bDiff, -rDiff });
+        if (transverse.length2() < 1e-6)
+                return;
+
+        transverse = transverse / transverse.length();
+        double transverseRange = config_.transverseNeg + config_.transversePos;
+        const int maxNumDeltas = 12;
+        int numDeltas = floor(transverseRange * 100 + 0.5) + 1;
+        numDeltas = numDeltas < 3 ? 3 : (numDeltas > maxNumDeltas ? maxNumDeltas : numDeltas);
+
+        ipa::Pwl::Point points[maxNumDeltas];
+        int bestPoint = 0;
+
+        for (int i = 0; i < numDeltas; i++) {
+                points[i][0] = -config_.transverseNeg +
+                               (transverseRange * i) / (numDeltas - 1);
+                ipa::Pwl::Point rbTest = ipa::Pwl::Point({ r, b }) +
+                                         transverse * points[i].x();
+                double rTest = rbTest.x(), bTest = rbTest.y();
+                double gainR = 1 / rTest, gainB = 1 / bTest;
+                double delta2Sum = computeDelta2Sum(gainR, gainB, 0.0, 0.0);
+                points[i][1] = delta2Sum;
+                if (points[i].y() < points[bestPoint].y())
+                        bestPoint = i;
+        }
+
+        bestPoint = std::clamp(bestPoint, 1, numDeltas - 2);
+        ipa::Pwl::Point rbBest = ipa::Pwl::Point({ r, b }) +
+                                 transverse * interpolateQuadatric(points[bestPoint - 1],
+                                                                   points[bestPoint],
+                                                                   points[bestPoint + 1]);
+        double rBest = rbBest.x(), bBest = rbBest.y();
+
+        r = rBest, b = bBest;
+}
+
+AwbNN::RGB AwbNN::processZone(AwbNN::RGB zone, float redGain, float blueGain)
+{
+        /*
+         * Renders the pixel at canonical network colour temperature
+         */
+        RGB zoneGains = zone;
+
+        zoneGains.R *= redGain;
+        zoneGains.G *= 1.0;
+        zoneGains.B *= blueGain;
+
+        RGB zoneCcm;
+
+        zoneCcm.R = nnConfig_.ccm[0] * zoneGains.R + nnConfig_.ccm[1] * zoneGains.G + nnConfig_.ccm[2] * zoneGains.B;
+        zoneCcm.G = nnConfig_.ccm[3] * zoneGains.R + nnConfig_.ccm[4] * zoneGains.G + nnConfig_.ccm[5] * zoneGains.B;
+        zoneCcm.B = nnConfig_.ccm[6] * zoneGains.R + nnConfig_.ccm[7] * zoneGains.G + nnConfig_.ccm[8] * zoneGains.B;
+
+        return zoneCcm;
+}
+
+void AwbNN::awbNN()
+{
+        float *inputData = interpreter_->typed_input_tensor<float>(0);
+        float *inputLux = interpreter_->typed_input_tensor<float>(1);
+
+        float redGain = 1.0 / config_.ctR.eval(kNetworkCanonicalCT);
+        float blueGain = 1.0 / config_.ctB.eval(kNetworkCanonicalCT);
+
+        for (uint i = 0; i < zoneSize_.height; i++) {
+                for (uint j = 0; j < zoneSize_.width; j++) {
+                        uint zoneIdx = i * zoneSize_.width + j;
+
+                        RGB processedZone = processZone(zones_[zoneIdx] * (1.0 / 65535), redGain, blueGain);
+                        uint baseIdx = zoneIdx * 3;
+
+                        inputData[baseIdx + 0] = static_cast<float>(processedZone.R);
+                        inputData[baseIdx + 1] = static_cast<float>(processedZone.G);
+                        inputData[baseIdx + 2] = static_cast<float>(processedZone.B);
+                }
+        }
+
+        inputLux[0] = static_cast<float>(lux_);
+
+        TfLiteStatus status = interpreter_->Invoke();
+        if (status != kTfLiteOk) {
+                LOG(RPiAwb, Error) << "Model inference failed with status: " << status;
+                return;
+        }
+
+        float *outputData = interpreter_->typed_output_tensor<float>(0);
+
+        double t = outputData[0];
+
+        LOG(RPiAwb, Debug) << "Model output temperature: " << t;
+
+        t = std::clamp(t, mode_->ctLo, mode_->ctHi);
+
+        double r = config_.ctR.eval(t);
+        double b = config_.ctB.eval(t);
+
+        transverseSearch(t, r, b);
+
+        LOG(RPiAwb, Debug) << "After transverse search: Temperature: " << t << " Red gain: " << 1.0 / r << " Blue gain: " << 1.0 / b;
+
+        asyncResults_.temperatureK = t;
+        asyncResults_.gainR = 1.0 / r * config_.sensitivityR;
+        asyncResults_.gainG = 1.0;
+        asyncResults_.gainB = 1.0 / b * config_.sensitivityB;
+}
+
+void AwbNN::doAwb()
+{
+        prepareStats();
+        if (zones_.size() == (zoneSize_.width * zoneSize_.height) && nnConfig_.enableNn)
+                awbNN();
+        else
+                awbGrey();
+        statistics_.reset();
+}
+
+/* Register algorithm with the system. */
+static Algorithm *create(Controller *controller)
+{
+        return (Algorithm *)new AwbNN(controller);
+}
+static RegisterAlgorithm reg(NAME, &create);
+
+} /* namespace RPiController */
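As a build note, rpi-awb-nn is declared as a Meson feature option defaulting
to auto, so the algorithm is compiled in whenever the tensorflow-lite
dependency can be found and silently skipped otherwise. To make the build
fail loudly when TensorFlow Lite is missing, the option can be forced on
(an illustrative invocation, assuming tensorflow-lite is discoverable via
pkg-config):

    meson setup build -Drpi-awb-nn=enabled
    ninja -C build

Conversely, with -Drpi-awb-nn=disabled the dependency lookup is skipped
entirely, awb_nn.cpp is never added to rpi_ipa_controller_sources, and the
existing AWB implementation remains the only one built.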