| Message ID | 20251211142824.26635-3-david.plowman@raspberrypi.com |
|---|---|
| State | Superseded |

Hi

On Thu, 11 Dec 2025 at 14:28, David Plowman <david.plowman@raspberrypi.com> wrote:
>
> From: Peter Bailey <peter.bailey@raspberrypi.com>
>
> Add an AWB algorithm which uses neural networks.
>
> Signed-off-by: Peter Bailey <peter.bailey@raspberrypi.com>
> ---
>  meson_options.txt                     |   5 +
>  src/ipa/rpi/controller/meson.build    |   9 +
>  src/ipa/rpi/controller/rpi/awb_nn.cpp | 446 ++++++++++++++++++++++++++
>  3 files changed, 460 insertions(+)
>  create mode 100644 src/ipa/rpi/controller/rpi/awb_nn.cpp
>
> [snip]
>
> +tflite_dep = dependency('tensorflow-lite', required : false)

Oops. I edited this to read

    tflite_dep = dependency('tensorflow-lite', required : get_option('rpi-awb-nn'))

and then failed to hit the save button. Version 3 incoming shortly!

David

> [snip]
>
> --
> 2.47.3
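For readers trying the series before v3 arrives, the corrected hunk David describes would presumably look like the sketch below: his quoted one-line change dropped into the otherwise unchanged block from this patch. With the `rpi-awb-nn` feature option at its default `auto`, a missing tensorflow-lite is tolerated and the algorithm is simply not built; `-Drpi-awb-nn=enabled` turns the missing dependency into a configure-time error, and `-Drpi-awb-nn=disabled` skips the lookup entirely.

```meson
# Hypothetical v3 form of the src/ipa/rpi/controller/meson.build hunk,
# combining this patch with the one-line fix quoted in the reply above.
tflite_dep = dependency('tensorflow-lite', required : get_option('rpi-awb-nn'))

if tflite_dep.found()
    rpi_ipa_controller_sources += files([
        'rpi/awb_nn.cpp',
    ])
    rpi_ipa_controller_deps += tflite_dep
endif
```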
diff --git a/meson_options.txt b/meson_options.txt
index 5954e028..89eece52 100644
--- a/meson_options.txt
+++ b/meson_options.txt
@@ -78,6 +78,11 @@ option('qcam',
         value : 'disabled',
         description : 'Compile the qcam test application')
 
+option('rpi-awb-nn',
+        type : 'feature',
+        value : 'auto',
+        description : 'Enable the Raspberry Pi Neural Network AWB algorithm')
+
 option('test',
         type : 'boolean',
         value : false,
diff --git a/src/ipa/rpi/controller/meson.build b/src/ipa/rpi/controller/meson.build
index 90d9e285..9c261d99 100644
--- a/src/ipa/rpi/controller/meson.build
+++ b/src/ipa/rpi/controller/meson.build
@@ -33,6 +33,15 @@ rpi_ipa_controller_deps = [
     libcamera_private,
 ]
 
+tflite_dep = dependency('tensorflow-lite', required : false)
+
+if tflite_dep.found()
+    rpi_ipa_controller_sources += files([
+        'rpi/awb_nn.cpp',
+    ])
+    rpi_ipa_controller_deps += tflite_dep
+endif
+
 rpi_ipa_controller_lib = static_library('rpi_ipa_controller', rpi_ipa_controller_sources,
                                         include_directories : libipa_includes,
                                         dependencies : rpi_ipa_controller_deps)
diff --git a/src/ipa/rpi/controller/rpi/awb_nn.cpp b/src/ipa/rpi/controller/rpi/awb_nn.cpp
new file mode 100644
index 00000000..35d1270e
--- /dev/null
+++ b/src/ipa/rpi/controller/rpi/awb_nn.cpp
@@ -0,0 +1,446 @@
+/* SPDX-License-Identifier: BSD-2-Clause */
+/*
+ * Copyright (C) 2025, Raspberry Pi Ltd
+ *
+ * AWB control algorithm using neural network
+ */
+
+#include <chrono>
+#include <condition_variable>
+#include <thread>
+
+#include <libcamera/base/file.h>
+#include <libcamera/base/log.h>
+
+#include <tensorflow/lite/interpreter.h>
+#include <tensorflow/lite/kernels/register.h>
+#include <tensorflow/lite/model.h>
+
+#include "../awb_algorithm.h"
+#include "../awb_status.h"
+#include "../lux_status.h"
+#include "libipa/pwl.h"
+
+#include "alsc_status.h"
+#include "awb.h"
+
+using namespace libcamera;
+
+LOG_DECLARE_CATEGORY(RPiAwb)
+
+constexpr double kDefaultCT = 4500.0;
+
+/*
+ * The neural networks are trained to work on images rendered at a canonical
+ * colour temperature. That value is 5000K, which must be reproduced here.
+ */
+constexpr double kNetworkCanonicalCT = 5000.0;
+
+#define NAME "rpi.nn.awb"
+
+namespace RPiController {
+
+struct AwbNNConfig {
+        AwbNNConfig() {}
+        int read(const libcamera::YamlObject &params, AwbConfig &config);
+
+        /* An empty model will check default locations for model.tflite */
+        std::string model;
+        float minTemp;
+        float maxTemp;
+
+        bool enableNn;
+
+        /* CCM matrix for canonical network CT */
+        double ccm[9];
+};
+
+class AwbNN : public Awb
+{
+public:
+        AwbNN(Controller *controller = NULL);
+        ~AwbNN();
+        char const *name() const override;
+        void initialise() override;
+        int read(const libcamera::YamlObject &params) override;
+
+protected:
+        void doAwb() override;
+        void prepareStats() override;
+
+private:
+        bool isAutoEnabled() const;
+        AwbNNConfig nnConfig_;
+        void transverseSearch(double t, double &r, double &b);
+        RGB processZone(RGB zone, float red_gain, float blue_gain);
+        void awbNN();
+        void loadModel();
+
+        libcamera::Size zoneSize_;
+        std::unique_ptr<tflite::FlatBufferModel> model_;
+        std::unique_ptr<tflite::Interpreter> interpreter_;
+};
+
+int AwbNNConfig::read(const libcamera::YamlObject &params, AwbConfig &config)
+{
+        model = params["model"].get<std::string>("");
+        minTemp = params["min_temp"].get<float>(2800.0);
+        maxTemp = params["max_temp"].get<float>(7600.0);
+
+        for (int i = 0; i < 9; i++)
+                ccm[i] = params["ccm"][i].get<double>(0.0);
+
+        enableNn = params["enable_nn"].get<int>(1);
+
+        if (enableNn) {
+                if (!config.hasCtCurve()) {
+                        LOG(RPiAwb, Error) << "CT curve not specified";
+                        enableNn = false;
+                }
+
+                if (!model.empty() && model.find(".tflite") == std::string::npos) {
+                        LOG(RPiAwb, Error) << "Model must be a .tflite file";
+                        enableNn = false;
+                }
+
+                bool validCcm = true;
+                for (int i = 0; i < 9; i++)
+                        if (ccm[i] == 0.0)
+                                validCcm = false;
+
+                if (!validCcm) {
+                        LOG(RPiAwb, Error) << "CCM not specified or invalid";
+                        enableNn = false;
+                }
+
+                if (!enableNn) {
+                        LOG(RPiAwb, Warning) << "Neural Network AWB mis-configured - switch to Grey method";
+                }
+        }
+
+        if (!enableNn) {
+                config.sensitivityR = config.sensitivityB = 1.0;
+                config.greyWorld = true;
+        }
+
+        return 0;
+}
+
+AwbNN::AwbNN(Controller *controller)
+        : Awb(controller)
+{
+        zoneSize_ = getHardwareConfig().awbRegions;
+}
+
+AwbNN::~AwbNN()
+{
+}
+
+char const *AwbNN::name() const
+{
+        return NAME;
+}
+
+int AwbNN::read(const libcamera::YamlObject &params)
+{
+        int ret;
+
+        ret = config_.read(params);
+        if (ret)
+                return ret;
+
+        ret = nnConfig_.read(params, config_);
+        if (ret)
+                return ret;
+
+        return 0;
+}
+
+static bool checkTensorShape(TfLiteTensor *tensor, const int *expectedDims, const int expectedDimsSize)
+{
+        if (tensor->dims->size != expectedDimsSize)
+                return false;
+
+        for (int i = 0; i < tensor->dims->size; i++) {
+                if (tensor->dims->data[i] != expectedDims[i]) {
+                        return false;
+                }
+        }
+        return true;
+}
+
+static std::string buildDimString(const int *dims, const int dimsSize)
+{
+        std::string s = "[";
+        for (int i = 0; i < dimsSize; i++) {
+                s += std::to_string(dims[i]);
+                if (i < dimsSize - 1)
+                        s += ",";
+                else
+                        s += "]";
+        }
+        return s;
+}
+
+void AwbNN::loadModel()
+{
+        std::string modelPath;
+        if (getTarget() == "bcm2835") {
+                modelPath = "/ipa/rpi/vc4/awb_model.tflite";
+        } else {
+                modelPath = "/ipa/rpi/pisp/awb_model.tflite";
+        }
+
+        if (nnConfig_.model.empty()) {
+                std::string root = utils::libcameraSourcePath();
+                if (!root.empty()) {
+                        modelPath = root + modelPath;
+                } else {
+                        modelPath = LIBCAMERA_DATA_DIR + modelPath;
+                }
+
+                if (!File::exists(modelPath)) {
+                        LOG(RPiAwb, Error) << "No model file found in standard locations";
+                        nnConfig_.enableNn = false;
+                        return;
+                }
+        } else {
+                modelPath = nnConfig_.model;
+        }
+
+        LOG(RPiAwb, Debug) << "Attempting to load model from: " << modelPath;
+
+        model_ = tflite::FlatBufferModel::BuildFromFile(modelPath.c_str());
+
+        if (!model_) {
+                LOG(RPiAwb, Error) << "Failed to load model from " << modelPath;
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        tflite::MutableOpResolver resolver;
+        tflite::ops::builtin::BuiltinOpResolver builtin_resolver;
+        resolver.AddAll(builtin_resolver);
+        tflite::InterpreterBuilder(*model_, resolver)(&interpreter_);
+        if (!interpreter_) {
+                LOG(RPiAwb, Error) << "Failed to build interpreter for model " << nnConfig_.model;
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        interpreter_->AllocateTensors();
+        TfLiteTensor *inputTensor = interpreter_->input_tensor(0);
+        TfLiteTensor *inputLuxTensor = interpreter_->input_tensor(1);
+        TfLiteTensor *outputTensor = interpreter_->output_tensor(0);
+        if (!inputTensor || !inputLuxTensor || !outputTensor) {
+                LOG(RPiAwb, Error) << "Model missing input or output tensor";
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        const int expectedInputDims[] = { 1, (int)zoneSize_.height, (int)zoneSize_.width, 3 };
+        const int expectedInputLuxDims[] = { 1 };
+        const int expectedOutputDims[] = { 1 };
+
+        if (!checkTensorShape(inputTensor, expectedInputDims, 4)) {
+                LOG(RPiAwb, Error) << "Model input tensor dimension mismatch. Expected: " << buildDimString(expectedInputDims, 4)
+                                   << ", Got: " << buildDimString(inputTensor->dims->data, inputTensor->dims->size);
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        if (!checkTensorShape(inputLuxTensor, expectedInputLuxDims, 1)) {
+                LOG(RPiAwb, Error) << "Model input lux tensor dimension mismatch. Expected: " << buildDimString(expectedInputLuxDims, 1)
+                                   << ", Got: " << buildDimString(inputLuxTensor->dims->data, inputLuxTensor->dims->size);
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        if (!checkTensorShape(outputTensor, expectedOutputDims, 1)) {
+                LOG(RPiAwb, Error) << "Model output tensor dimension mismatch. Expected: " << buildDimString(expectedOutputDims, 1)
+                                   << ", Got: " << buildDimString(outputTensor->dims->data, outputTensor->dims->size);
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        if (inputTensor->type != kTfLiteFloat32 || inputLuxTensor->type != kTfLiteFloat32 || outputTensor->type != kTfLiteFloat32) {
+                LOG(RPiAwb, Error) << "Model input and output tensors must be float32";
+                nnConfig_.enableNn = false;
+                return;
+        }
+
+        LOG(RPiAwb, Info) << "Model loaded successfully from " << modelPath;
+        LOG(RPiAwb, Debug) << "Model validation successful - Input Image: "
+                           << buildDimString(expectedInputDims, 4)
+                           << ", Input Lux: " << buildDimString(expectedInputLuxDims, 1)
+                           << ", Output: " << buildDimString(expectedOutputDims, 1) << " floats";
+}
+
+void AwbNN::initialise()
+{
+        Awb::initialise();
+
+        if (nnConfig_.enableNn) {
+                loadModel();
+                if (!nnConfig_.enableNn) {
+                        LOG(RPiAwb, Warning) << "Neural Network AWB failed to load - switch to Grey method";
+                        config_.greyWorld = true;
+                        config_.sensitivityR = config_.sensitivityB = 1.0;
+                }
+        }
+}
+
+void AwbNN::prepareStats()
+{
+        zones_.clear();
+        /*
+         * LSC has already been applied to the stats in this pipeline, so stop
+         * any LSC compensation. We also ignore config_.fast in this version.
+         */
+        generateStats(zones_, statistics_, 0.0, 0.0, getGlobalMetadata(), 0.0, 0.0, 0.0);
+        /*
+         * apply sensitivities, so values appear to come from our "canonical"
+         * sensor.
+         */
+        for (auto &zone : zones_) {
+                zone.R *= config_.sensitivityR;
+                zone.B *= config_.sensitivityB;
+        }
+}
+
+void AwbNN::transverseSearch(double t, double &r, double &b)
+{
+        int spanR = -1, spanB = -1;
+        config_.ctR.eval(t, &spanR);
+        config_.ctB.eval(t, &spanB);
+
+        const int diff = 10;
+        double rDiff = config_.ctR.eval(t + diff, &spanR) -
+                       config_.ctR.eval(t - diff, &spanR);
+        double bDiff = config_.ctB.eval(t + diff, &spanB) -
+                       config_.ctB.eval(t - diff, &spanB);
+
+        ipa::Pwl::Point transverse({ bDiff, -rDiff });
+        if (transverse.length2() < 1e-6)
+                return;
+
+        transverse = transverse / transverse.length();
+        double transverseRange = config_.transverseNeg + config_.transversePos;
+        const int maxNumDeltas = 12;
+        int numDeltas = floor(transverseRange * 100 + 0.5) + 1;
+        numDeltas = numDeltas < 3 ? 3 : (numDeltas > maxNumDeltas ? maxNumDeltas : numDeltas);
+
+        ipa::Pwl::Point points[maxNumDeltas];
+        int bestPoint = 0;
+
+        for (int i = 0; i < numDeltas; i++) {
+                points[i][0] = -config_.transverseNeg +
+                               (transverseRange * i) / (numDeltas - 1);
+                ipa::Pwl::Point rbTest = ipa::Pwl::Point({ r, b }) +
+                                         transverse * points[i].x();
+                double rTest = rbTest.x(), bTest = rbTest.y();
+                double gainR = 1 / rTest, gainB = 1 / bTest;
+                double delta2Sum = computeDelta2Sum(gainR, gainB, 0.0, 0.0);
+                points[i][1] = delta2Sum;
+                if (points[i].y() < points[bestPoint].y())
+                        bestPoint = i;
+        }
+
+        bestPoint = std::clamp(bestPoint, 1, numDeltas - 2);
+        ipa::Pwl::Point rbBest = ipa::Pwl::Point({ r, b }) +
+                                 transverse * interpolateQuadatric(points[bestPoint - 1],
+                                                                   points[bestPoint],
+                                                                   points[bestPoint + 1]);
+        double rBest = rbBest.x(), bBest = rbBest.y();
+
+        r = rBest, b = bBest;
+}
+
+AwbNN::RGB AwbNN::processZone(AwbNN::RGB zone, float redGain, float blueGain)
+{
+        /*
+         * Renders the pixel at canonical network colour temperature
+         */
+        RGB zoneGains = zone;
+
+        zoneGains.R *= redGain;
+        zoneGains.G *= 1.0;
+        zoneGains.B *= blueGain;
+
+        RGB zoneCcm;
+
+        zoneCcm.R = nnConfig_.ccm[0] * zoneGains.R + nnConfig_.ccm[1] * zoneGains.G + nnConfig_.ccm[2] * zoneGains.B;
+        zoneCcm.G = nnConfig_.ccm[3] * zoneGains.R + nnConfig_.ccm[4] * zoneGains.G + nnConfig_.ccm[5] * zoneGains.B;
+        zoneCcm.B = nnConfig_.ccm[6] * zoneGains.R + nnConfig_.ccm[7] * zoneGains.G + nnConfig_.ccm[8] * zoneGains.B;
+
+        return zoneCcm;
+}
+
+void AwbNN::awbNN()
+{
+        float *inputData = interpreter_->typed_input_tensor<float>(0);
+        float *inputLux = interpreter_->typed_input_tensor<float>(1);
+
+        float redGain = 1.0 / config_.ctR.eval(kNetworkCanonicalCT);
+        float blueGain = 1.0 / config_.ctB.eval(kNetworkCanonicalCT);
+
+        for (uint i = 0; i < zoneSize_.height; i++) {
+                for (uint j = 0; j < zoneSize_.width; j++) {
+                        uint zoneIdx = i * zoneSize_.width + j;
+
+                        RGB processedZone = processZone(zones_[zoneIdx] * (1.0 / 65535), redGain, blueGain);
+                        uint baseIdx = zoneIdx * 3;
+
+                        inputData[baseIdx + 0] = static_cast<float>(processedZone.R);
+                        inputData[baseIdx + 1] = static_cast<float>(processedZone.G);
+                        inputData[baseIdx + 2] = static_cast<float>(processedZone.B);
+                }
+        }
+
+        inputLux[0] = static_cast<float>(lux_);
+
+        TfLiteStatus status = interpreter_->Invoke();
+        if (status != kTfLiteOk) {
+                LOG(RPiAwb, Error) << "Model inference failed with status: " << status;
+                return;
+        }
+
+        float *outputData = interpreter_->typed_output_tensor<float>(0);
+
+        double t = outputData[0];
+
+        LOG(RPiAwb, Debug) << "Model output temperature: " << t;
+
+        t = std::clamp(t, mode_->ctLo, mode_->ctHi);
+
+        double r = config_.ctR.eval(t);
+        double b = config_.ctB.eval(t);
+
+        transverseSearch(t, r, b);
+
+        LOG(RPiAwb, Debug) << "After transverse search: Temperature: " << t << " Red gain: " << 1.0 / r << " Blue gain: " << 1.0 / b;
+
+        asyncResults_.temperatureK = t;
+        asyncResults_.gainR = 1.0 / r * config_.sensitivityR;
+        asyncResults_.gainG = 1.0;
+        asyncResults_.gainB = 1.0 / b * config_.sensitivityB;
+}
+
+void AwbNN::doAwb()
+{
+        prepareStats();
+        if (zones_.size() == (zoneSize_.width * zoneSize_.height) && nnConfig_.enableNn)
+                awbNN();
+        else
+                awbGrey();
+        statistics_.reset();
+}
+
+/* Register algorithm with the system. */
+static Algorithm *create(Controller *controller)
+{
+        return (Algorithm *)new AwbNN(controller);
+}
+static RegisterAlgorithm reg(NAME, &create);
+
+} /* namespace RPiController */
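The algorithm registers itself as "rpi.nn.awb", so a camera tuning file needs an entry under that name before any of this runs. A hypothetical fragment covering the keys parsed by AwbNNConfig::read() might look as follows; the key names come from the patch, but the path and all numeric values are purely illustrative. Note that read() falls back to the grey-world method unless the base AWB config also supplies a CT curve, and it rejects a CCM containing any zero coefficient.

```json
{
    "rpi.nn.awb":
    {
        "model": "/home/pi/models/awb_model.tflite",
        "min_temp": 2800,
        "max_temp": 7600,
        "enable_nn": 1,
        "ccm":
        [
            1.80, -0.60, -0.20,
            -0.40, 1.70, -0.30,
            -0.10, -0.70, 1.80
        ]
    }
}
```

Leaving "model" empty instead would make loadModel() search the standard locations (the source tree if available, otherwise LIBCAMERA_DATA_DIR) for a platform-specific awb_model.tflite.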