[1/4] ipa: rpi: controller: awb: Separate Bayesian Awb into AwbBayes
diff mbox series

Message ID 20251024144049.3311-2-david.plowman@raspberrypi.com
State New
Headers show
Series
  • Raspberry Pi AWB using neural networks
Related show

Commit Message

David Plowman Oct. 24, 2025, 2:16 p.m. UTC
From: Peter Bailey <peter.bailey@raspberrypi.com>

Move parts of the AWB algorithm specific to the Bayesian algorithm into a
new class. This will make it easier to add new Awb algorithms in the future.

Signed-off-by: Peter Bailey <peter.bailey@raspberrypi.com>
---
 src/ipa/rpi/controller/meson.build       |   1 +
 src/ipa/rpi/controller/rpi/awb.cpp       | 409 +++------------------
 src/ipa/rpi/controller/rpi/awb.h         |  99 ++---
 src/ipa/rpi/controller/rpi/awb_bayes.cpp | 444 +++++++++++++++++++++++
 4 files changed, 533 insertions(+), 420 deletions(-)
 create mode 100644 src/ipa/rpi/controller/rpi/awb_bayes.cpp

Patch
diff mbox series

diff --git a/src/ipa/rpi/controller/meson.build b/src/ipa/rpi/controller/meson.build
index dde4ac12..73c93dca 100644
--- a/src/ipa/rpi/controller/meson.build
+++ b/src/ipa/rpi/controller/meson.build
@@ -10,6 +10,7 @@  rpi_ipa_controller_sources = files([
     'rpi/agc_channel.cpp',
     'rpi/alsc.cpp',
     'rpi/awb.cpp',
+    'rpi/awb_bayes.cpp',
     'rpi/black_level.cpp',
     'rpi/cac.cpp',
     'rpi/ccm.cpp',
diff --git a/src/ipa/rpi/controller/rpi/awb.cpp b/src/ipa/rpi/controller/rpi/awb.cpp
index 365b595f..de5fa59b 100644
--- a/src/ipa/rpi/controller/rpi/awb.cpp
+++ b/src/ipa/rpi/controller/rpi/awb.cpp
@@ -1,20 +1,14 @@ 
 /* SPDX-License-Identifier: BSD-2-Clause */
 /*
- * Copyright (C) 2019, Raspberry Pi Ltd
+ * Copyright (C) 2025, Raspberry Pi Ltd
  *
  * AWB control algorithm
  */
-
-#include <assert.h>
-#include <cmath>
-#include <functional>
-
-#include <libcamera/base/log.h>
+#include "awb.h"
 
 #include "../lux_status.h"
 
 #include "alsc_status.h"
-#include "awb.h"
 
 using namespace RPiController;
 using namespace libcamera;
@@ -23,39 +17,6 @@  LOG_DEFINE_CATEGORY(RPiAwb)
 
 constexpr double kDefaultCT = 4500.0;
 
-#define NAME "rpi.awb"
-
-/*
- * todo - the locking in this algorithm needs some tidying up as has been done
- * elsewhere (ALSC and AGC).
- */
-
-int AwbMode::read(const libcamera::YamlObject &params)
-{
-	auto value = params["lo"].get<double>();
-	if (!value)
-		return -EINVAL;
-	ctLo = *value;
-
-	value = params["hi"].get<double>();
-	if (!value)
-		return -EINVAL;
-	ctHi = *value;
-
-	return 0;
-}
-
-int AwbPrior::read(const libcamera::YamlObject &params)
-{
-	auto value = params["lux"].get<double>();
-	if (!value)
-		return -EINVAL;
-	lux = *value;
-
-	prior = params["prior"].get<ipa::Pwl>(ipa::Pwl{});
-	return prior.empty() ? -EINVAL : 0;
-}
-
 static int readCtCurve(ipa::Pwl &ctR, ipa::Pwl &ctB, const libcamera::YamlObject &params)
 {
 	if (params.size() % 3) {
@@ -92,11 +53,25 @@  static int readCtCurve(ipa::Pwl &ctR, ipa::Pwl &ctB, const libcamera::YamlObject
 	return 0;
 }
 
+int AwbMode::read(const libcamera::YamlObject &params)
+{
+	auto value = params["lo"].get<double>();
+	if (!value)
+		return -EINVAL;
+	ctLo = *value;
+
+	value = params["hi"].get<double>();
+	if (!value)
+		return -EINVAL;
+	ctHi = *value;
+
+	return 0;
+}
+
 int AwbConfig::read(const libcamera::YamlObject &params)
 {
 	int ret;
 
-	bayes = params["bayes"].get<int>(1);
 	framePeriod = params["frame_period"].get<uint16_t>(10);
 	startupFrames = params["startup_frames"].get<uint16_t>(10);
 	convergenceFrames = params["convergence_frames"].get<unsigned int>(3);
@@ -111,23 +86,6 @@  int AwbConfig::read(const libcamera::YamlObject &params)
 		ctBInverse = ctB.inverse().first;
 	}
 
-	if (params.contains("priors")) {
-		for (const auto &p : params["priors"].asList()) {
-			AwbPrior prior;
-			ret = prior.read(p);
-			if (ret)
-				return ret;
-			if (!priors.empty() && prior.lux <= priors.back().lux) {
-				LOG(RPiAwb, Error) << "AwbConfig: Prior must be ordered in increasing lux value";
-				return -EINVAL;
-			}
-			priors.push_back(prior);
-		}
-		if (priors.empty()) {
-			LOG(RPiAwb, Error) << "AwbConfig: no AWB priors configured";
-			return -EINVAL;
-		}
-	}
 	if (params.contains("modes")) {
 		for (const auto &[key, value] : params["modes"].asDict()) {
 			ret = modes[key].read(value);
@@ -142,13 +100,10 @@  int AwbConfig::read(const libcamera::YamlObject &params)
 		}
 	}
 
-	minPixels = params["min_pixels"].get<double>(16.0);
-	minG = params["min_G"].get<uint16_t>(32);
-	minRegions = params["min_regions"].get<uint32_t>(10);
 	deltaLimit = params["delta_limit"].get<double>(0.2);
-	coarseStep = params["coarse_step"].get<double>(0.2);
 	transversePos = params["transverse_pos"].get<double>(0.01);
 	transverseNeg = params["transverse_neg"].get<double>(0.01);
+
 	if (transversePos <= 0 || transverseNeg <= 0) {
 		LOG(RPiAwb, Error) << "AwbConfig: transverse_pos/neg must be > 0";
 		return -EINVAL;
@@ -157,29 +112,21 @@  int AwbConfig::read(const libcamera::YamlObject &params)
 	sensitivityR = params["sensitivity_r"].get<double>(1.0);
 	sensitivityB = params["sensitivity_b"].get<double>(1.0);
 
-	if (bayes) {
-		if (ctR.empty() || ctB.empty() || priors.empty() ||
-		    defaultMode == nullptr) {
-			LOG(RPiAwb, Warning)
-				<< "Bayesian AWB mis-configured - switch to Grey method";
-			bayes = false;
-		}
-	}
-	whitepointR = params["whitepoint_r"].get<double>(0.0);
-	whitepointB = params["whitepoint_b"].get<double>(0.0);
-	if (bayes == false)
+	if (hasCtCurve() && defaultMode != nullptr) {
+		greyWorld = false;
+	} else {
+		greyWorld = true;
 		sensitivityR = sensitivityB = 1.0; /* nor do sensitivities make any sense */
-	/*
-	 * The biasProportion parameter adds a small proportion of the counted
-	 * pixles to a region biased to the biasCT colour temperature.
-	 *
-	 * A typical value for biasProportion would be between 0.05 to 0.1.
-	 */
-	biasProportion = params["bias_proportion"].get<double>(0.0);
-	biasCT = params["bias_ct"].get<double>(kDefaultCT);
+	}
+
 	return 0;
 }
 
+bool AwbConfig::hasCtCurve() const
+{
+	return !ctR.empty() && !ctB.empty();
+}
+
 Awb::Awb(Controller *controller)
 	: AwbAlgorithm(controller)
 {
@@ -199,16 +146,6 @@  Awb::~Awb()
 	asyncThread_.join();
 }
 
-char const *Awb::name() const
-{
-	return NAME;
-}
-
-int Awb::read(const libcamera::YamlObject &params)
-{
-	return config_.read(params);
-}
-
 void Awb::initialise()
 {
 	frameCount_ = framePhase_ = 0;
@@ -217,7 +154,7 @@  void Awb::initialise()
 	 * just in case the first few frames don't have anything meaningful in
 	 * them.
 	 */
-	if (!config_.ctR.empty() && !config_.ctB.empty()) {
+	if (!config_.greyWorld) {
 		syncResults_.temperatureK = config_.ctR.domain().clamp(4000);
 		syncResults_.gainR = 1.0 / config_.ctR.eval(syncResults_.temperatureK);
 		syncResults_.gainG = 1.0;
@@ -282,7 +219,7 @@  void Awb::setManualGains(double manualR, double manualB)
 		syncResults_.gainR = prevSyncResults_.gainR = manualR_;
 		syncResults_.gainG = prevSyncResults_.gainG = 1.0;
 		syncResults_.gainB = prevSyncResults_.gainB = manualB_;
-		if (config_.bayes) {
+		if (!config_.greyWorld) {
 			/* Also estimate the best corresponding colour temperature from the curves. */
 			double ctR = config_.ctRInverse.eval(config_.ctRInverse.domain().clamp(1 / manualR_));
 			double ctB = config_.ctBInverse.eval(config_.ctBInverse.domain().clamp(1 / manualB_));
@@ -294,7 +231,7 @@  void Awb::setManualGains(double manualR, double manualB)
 
 void Awb::setColourTemperature(double temperatureK)
 {
-	if (!config_.bayes) {
+	if (config_.greyWorld) {
 		LOG(RPiAwb, Warning) << "AWB uncalibrated - cannot set colour temperature";
 		return;
 	}
@@ -433,10 +370,10 @@  void Awb::asyncFunc()
 	}
 }
 
-static void generateStats(std::vector<Awb::RGB> &zones,
-			  StatisticsPtr &stats, double minPixels,
-			  double minG, Metadata &globalMetadata,
-			  double biasProportion, double biasCtR, double biasCtB)
+void Awb::generateStats(std::vector<Awb::RGB> &zones,
+			StatisticsPtr &stats, double minPixels,
+			double minG, Metadata &globalMetadata,
+			double biasProportion, double biasCtR, double biasCtB)
 {
 	std::scoped_lock<RPiController::Metadata> l(globalMetadata);
 
@@ -450,9 +387,9 @@  static void generateStats(std::vector<Awb::RGB> &zones,
 			zone.R = region.val.rSum / region.counted;
 			zone.B = region.val.bSum / region.counted;
 			/*
-			 * Add some bias samples to allow the search to tend to a
-			 * bias CT in failure cases.
-			 */
+			* Add some bias samples to allow the search to tend to a
+			* bias CT in failure cases.
+			*/
 			const unsigned int proportion = biasProportion * region.counted;
 			zone.R += proportion * biasCtR;
 			zone.B += proportion * biasCtB;
@@ -469,29 +406,7 @@  static void generateStats(std::vector<Awb::RGB> &zones,
 	}
 }
 
-void Awb::prepareStats()
-{
-	zones_.clear();
-	/*
-	 * LSC has already been applied to the stats in this pipeline, so stop
-	 * any LSC compensation.  We also ignore config_.fast in this version.
-	 */
-	const double biasCtR = config_.bayes ? config_.ctR.eval(config_.biasCT) : 0;
-	const double biasCtB = config_.bayes ? config_.ctB.eval(config_.biasCT) : 0;
-	generateStats(zones_, statistics_, config_.minPixels,
-		      config_.minG, getGlobalMetadata(),
-		      config_.biasProportion, biasCtR, biasCtB);
-	/*
-	 * apply sensitivities, so values appear to come from our "canonical"
-	 * sensor.
-	 */
-	for (auto &zone : zones_) {
-		zone.R *= config_.sensitivityR;
-		zone.B *= config_.sensitivityB;
-	}
-}
-
-double Awb::computeDelta2Sum(double gainR, double gainB)
+double Awb::computeDelta2Sum(double gainR, double gainB, double whitepointR, double whitepointB)
 {
 	/*
 	 * Compute the sum of the squared colour error (non-greyness) as it
@@ -499,8 +414,8 @@  double Awb::computeDelta2Sum(double gainR, double gainB)
 	 */
 	double delta2Sum = 0;
 	for (auto &z : zones_) {
-		double deltaR = gainR * z.R - 1 - config_.whitepointR;
-		double deltaB = gainB * z.B - 1 - config_.whitepointB;
+		double deltaR = gainR * z.R - 1 - whitepointR;
+		double deltaB = gainB * z.B - 1 - whitepointB;
 		double delta2 = deltaR * deltaR + deltaB * deltaB;
 		/* LOG(RPiAwb, Debug) << "deltaR " << deltaR << " deltaB " << deltaB << " delta2 " << delta2; */
 		delta2 = std::min(delta2, config_.deltaLimit);
@@ -509,39 +424,14 @@  double Awb::computeDelta2Sum(double gainR, double gainB)
 	return delta2Sum;
 }
 
-ipa::Pwl Awb::interpolatePrior()
+double Awb::interpolateQuadatric(libcamera::ipa::Pwl::Point const &a,
+				 libcamera::ipa::Pwl::Point const &b,
+				 libcamera::ipa::Pwl::Point const &c)
 {
 	/*
-	 * Interpolate the prior log likelihood function for our current lux
-	 * value.
-	 */
-	if (lux_ <= config_.priors.front().lux)
-		return config_.priors.front().prior;
-	else if (lux_ >= config_.priors.back().lux)
-		return config_.priors.back().prior;
-	else {
-		int idx = 0;
-		/* find which two we lie between */
-		while (config_.priors[idx + 1].lux < lux_)
-			idx++;
-		double lux0 = config_.priors[idx].lux,
-		       lux1 = config_.priors[idx + 1].lux;
-		return ipa::Pwl::combine(config_.priors[idx].prior,
-				    config_.priors[idx + 1].prior,
-				    [&](double /*x*/, double y0, double y1) {
-					    return y0 + (y1 - y0) *
-							(lux_ - lux0) / (lux1 - lux0);
-				    });
-	}
-}
-
-static double interpolateQuadatric(ipa::Pwl::Point const &a, ipa::Pwl::Point const &b,
-				   ipa::Pwl::Point const &c)
-{
-	/*
-	 * Given 3 points on a curve, find the extremum of the function in that
-	 * interval by fitting a quadratic.
-	 */
+	* Given 3 points on a curve, find the extremum of the function in that
+	* interval by fitting a quadratic.
+	*/
 	const double eps = 1e-3;
 	ipa::Pwl::Point ca = c - a, ba = b - a;
 	double denominator = 2 * (ba.y() * ca.x() - ca.y() * ba.x());
@@ -554,180 +444,6 @@  static double interpolateQuadatric(ipa::Pwl::Point const &a, ipa::Pwl::Point con
 	return a.y() < c.y() - eps ? a.x() : (c.y() < a.y() - eps ? c.x() : b.x());
 }
 
-double Awb::coarseSearch(ipa::Pwl const &prior)
-{
-	points_.clear(); /* assume doesn't deallocate memory */
-	size_t bestPoint = 0;
-	double t = mode_->ctLo;
-	int spanR = 0, spanB = 0;
-	/* Step down the CT curve evaluating log likelihood. */
-	while (true) {
-		double r = config_.ctR.eval(t, &spanR);
-		double b = config_.ctB.eval(t, &spanB);
-		double gainR = 1 / r, gainB = 1 / b;
-		double delta2Sum = computeDelta2Sum(gainR, gainB);
-		double priorLogLikelihood = prior.eval(prior.domain().clamp(t));
-		double finalLogLikelihood = delta2Sum - priorLogLikelihood;
-		LOG(RPiAwb, Debug)
-			<< "t: " << t << " gain R " << gainR << " gain B "
-			<< gainB << " delta2_sum " << delta2Sum
-			<< " prior " << priorLogLikelihood << " final "
-			<< finalLogLikelihood;
-		points_.push_back(ipa::Pwl::Point({ t, finalLogLikelihood }));
-		if (points_.back().y() < points_[bestPoint].y())
-			bestPoint = points_.size() - 1;
-		if (t == mode_->ctHi)
-			break;
-		/* for even steps along the r/b curve scale them by the current t */
-		t = std::min(t + t / 10 * config_.coarseStep, mode_->ctHi);
-	}
-	t = points_[bestPoint].x();
-	LOG(RPiAwb, Debug) << "Coarse search found CT " << t;
-	/*
-	 * We have the best point of the search, but refine it with a quadratic
-	 * interpolation around its neighbours.
-	 */
-	if (points_.size() > 2) {
-		unsigned long bp = std::min(bestPoint, points_.size() - 2);
-		bestPoint = std::max(1UL, bp);
-		t = interpolateQuadatric(points_[bestPoint - 1],
-					 points_[bestPoint],
-					 points_[bestPoint + 1]);
-		LOG(RPiAwb, Debug)
-			<< "After quadratic refinement, coarse search has CT "
-			<< t;
-	}
-	return t;
-}
-
-void Awb::fineSearch(double &t, double &r, double &b, ipa::Pwl const &prior)
-{
-	int spanR = -1, spanB = -1;
-	config_.ctR.eval(t, &spanR);
-	config_.ctB.eval(t, &spanB);
-	double step = t / 10 * config_.coarseStep * 0.1;
-	int nsteps = 5;
-	double rDiff = config_.ctR.eval(t + nsteps * step, &spanR) -
-		       config_.ctR.eval(t - nsteps * step, &spanR);
-	double bDiff = config_.ctB.eval(t + nsteps * step, &spanB) -
-		       config_.ctB.eval(t - nsteps * step, &spanB);
-	ipa::Pwl::Point transverse({ bDiff, -rDiff });
-	if (transverse.length2() < 1e-6)
-		return;
-	/*
-	 * unit vector orthogonal to the b vs. r function (pointing outwards
-	 * with r and b increasing)
-	 */
-	transverse = transverse / transverse.length();
-	double bestLogLikelihood = 0, bestT = 0, bestR = 0, bestB = 0;
-	double transverseRange = config_.transverseNeg + config_.transversePos;
-	const int maxNumDeltas = 12;
-	/* a transverse step approximately every 0.01 r/b units */
-	int numDeltas = floor(transverseRange * 100 + 0.5) + 1;
-	numDeltas = numDeltas < 3 ? 3 : (numDeltas > maxNumDeltas ? maxNumDeltas : numDeltas);
-	/*
-	 * Step down CT curve. March a bit further if the transverse range is
-	 * large.
-	 */
-	nsteps += numDeltas;
-	for (int i = -nsteps; i <= nsteps; i++) {
-		double tTest = t + i * step;
-		double priorLogLikelihood =
-			prior.eval(prior.domain().clamp(tTest));
-		double rCurve = config_.ctR.eval(tTest, &spanR);
-		double bCurve = config_.ctB.eval(tTest, &spanB);
-		/* x will be distance off the curve, y the log likelihood there */
-		ipa::Pwl::Point points[maxNumDeltas];
-		int bestPoint = 0;
-		/* Take some measurements transversely *off* the CT curve. */
-		for (int j = 0; j < numDeltas; j++) {
-			points[j][0] = -config_.transverseNeg +
-				       (transverseRange * j) / (numDeltas - 1);
-			ipa::Pwl::Point rbTest = ipa::Pwl::Point({ rCurve, bCurve }) +
-						 transverse * points[j].x();
-			double rTest = rbTest.x(), bTest = rbTest.y();
-			double gainR = 1 / rTest, gainB = 1 / bTest;
-			double delta2Sum = computeDelta2Sum(gainR, gainB);
-			points[j][1] = delta2Sum - priorLogLikelihood;
-			LOG(RPiAwb, Debug)
-				<< "At t " << tTest << " r " << rTest << " b "
-				<< bTest << ": " << points[j].y();
-			if (points[j].y() < points[bestPoint].y())
-				bestPoint = j;
-		}
-		/*
-		 * We have NUM_DELTAS points transversely across the CT curve,
-		 * now let's do a quadratic interpolation for the best result.
-		 */
-		bestPoint = std::max(1, std::min(bestPoint, numDeltas - 2));
-		ipa::Pwl::Point rbTest = ipa::Pwl::Point({ rCurve, bCurve }) +
-					 transverse * interpolateQuadatric(points[bestPoint - 1],
-									   points[bestPoint],
-									   points[bestPoint + 1]);
-		double rTest = rbTest.x(), bTest = rbTest.y();
-		double gainR = 1 / rTest, gainB = 1 / bTest;
-		double delta2Sum = computeDelta2Sum(gainR, gainB);
-		double finalLogLikelihood = delta2Sum - priorLogLikelihood;
-		LOG(RPiAwb, Debug)
-			<< "Finally "
-			<< tTest << " r " << rTest << " b " << bTest << ": "
-			<< finalLogLikelihood
-			<< (finalLogLikelihood < bestLogLikelihood ? " BEST" : "");
-		if (bestT == 0 || finalLogLikelihood < bestLogLikelihood)
-			bestLogLikelihood = finalLogLikelihood,
-			bestT = tTest, bestR = rTest, bestB = bTest;
-	}
-	t = bestT, r = bestR, b = bestB;
-	LOG(RPiAwb, Debug)
-		<< "Fine search found t " << t << " r " << r << " b " << b;
-}
-
-void Awb::awbBayes()
-{
-	/*
-	 * May as well divide out G to save computeDelta2Sum from doing it over
-	 * and over.
-	 */
-	for (auto &z : zones_)
-		z.R = z.R / (z.G + 1), z.B = z.B / (z.G + 1);
-	/*
-	 * Get the current prior, and scale according to how many zones are
-	 * valid... not entirely sure about this.
-	 */
-	ipa::Pwl prior = interpolatePrior();
-	prior *= zones_.size() / (double)(statistics_->awbRegions.numRegions());
-	prior.map([](double x, double y) {
-		LOG(RPiAwb, Debug) << "(" << x << "," << y << ")";
-	});
-	double t = coarseSearch(prior);
-	double r = config_.ctR.eval(t);
-	double b = config_.ctB.eval(t);
-	LOG(RPiAwb, Debug)
-		<< "After coarse search: r " << r << " b " << b << " (gains r "
-		<< 1 / r << " b " << 1 / b << ")";
-	/*
-	 * Not entirely sure how to handle the fine search yet. Mostly the
-	 * estimated CT is already good enough, but the fine search allows us to
-	 * wander transverely off the CT curve. Under some illuminants, where
-	 * there may be more or less green light, this may prove beneficial,
-	 * though I probably need more real datasets before deciding exactly how
-	 * this should be controlled and tuned.
-	 */
-	fineSearch(t, r, b, prior);
-	LOG(RPiAwb, Debug)
-		<< "After fine search: r " << r << " b " << b << " (gains r "
-		<< 1 / r << " b " << 1 / b << ")";
-	/*
-	 * Write results out for the main thread to pick up. Remember to adjust
-	 * the gains from the ones that the "canonical sensor" would require to
-	 * the ones needed by *this* sensor.
-	 */
-	asyncResults_.temperatureK = t;
-	asyncResults_.gainR = 1.0 / r * config_.sensitivityR;
-	asyncResults_.gainG = 1.0;
-	asyncResults_.gainB = 1.0 / b * config_.sensitivityB;
-}
-
 void Awb::awbGrey()
 {
 	LOG(RPiAwb, Debug) << "Grey world AWB";
@@ -765,32 +481,3 @@  void Awb::awbGrey()
 	asyncResults_.gainG = 1.0;
 	asyncResults_.gainB = gainB;
 }
-
-void Awb::doAwb()
-{
-	prepareStats();
-	LOG(RPiAwb, Debug) << "Valid zones: " << zones_.size();
-	if (zones_.size() > config_.minRegions) {
-		if (config_.bayes)
-			awbBayes();
-		else
-			awbGrey();
-		LOG(RPiAwb, Debug)
-			<< "CT found is "
-			<< asyncResults_.temperatureK
-			<< " with gains r " << asyncResults_.gainR
-			<< " and b " << asyncResults_.gainB;
-	}
-	/*
-	 * we're done with these; we may as well relinquish our hold on the
-	 * pointer.
-	 */
-	statistics_.reset();
-}
-
-/* Register algorithm with the system. */
-static Algorithm *create(Controller *controller)
-{
-	return (Algorithm *)new Awb(controller);
-}
-static RegisterAlgorithm reg(NAME, &create);
diff --git a/src/ipa/rpi/controller/rpi/awb.h b/src/ipa/rpi/controller/rpi/awb.h
index 2fb91254..8b2d8d1d 100644
--- a/src/ipa/rpi/controller/rpi/awb.h
+++ b/src/ipa/rpi/controller/rpi/awb.h
@@ -1,42 +1,33 @@ 
 /* SPDX-License-Identifier: BSD-2-Clause */
 /*
- * Copyright (C) 2019, Raspberry Pi Ltd
+ * Copyright (C) 2025, Raspberry Pi Ltd
  *
  * AWB control algorithm
  */
 #pragma once
 
-#include <mutex>
 #include <condition_variable>
+#include <mutex>
 #include <thread>
 
-#include <libcamera/geometry.h>
-
 #include "../awb_algorithm.h"
 #include "../awb_status.h"
-#include "../statistics.h"
-
 #include "libipa/pwl.h"
 
 namespace RPiController {
 
-/* Control algorithm to perform AWB calculations. */
-
 struct AwbMode {
 	int read(const libcamera::YamlObject &params);
 	double ctLo; /* low CT value for search */
 	double ctHi; /* high CT value for search */
 };
 
-struct AwbPrior {
-	int read(const libcamera::YamlObject &params);
-	double lux; /* lux level */
-	libcamera::ipa::Pwl prior; /* maps CT to prior log likelihood for this lux level */
-};
-
 struct AwbConfig {
-	AwbConfig() : defaultMode(nullptr) {}
+	AwbConfig()
+		: defaultMode(nullptr) {}
 	int read(const libcamera::YamlObject &params);
+	bool hasCtCurve() const;
+
 	/* Only repeat the AWB calculation every "this many" frames */
 	uint16_t framePeriod;
 	/* number of initial frames for which speed taken as 1.0 (maximum) */
@@ -47,27 +38,13 @@  struct AwbConfig {
 	libcamera::ipa::Pwl ctB; /* function maps CT to b (= B/G) */
 	libcamera::ipa::Pwl ctRInverse; /* inverse of ctR */
 	libcamera::ipa::Pwl ctBInverse; /* inverse of ctB */
-	/* table of illuminant priors at different lux levels */
-	std::vector<AwbPrior> priors;
+
 	/* AWB "modes" (determines the search range) */
 	std::map<std::string, AwbMode> modes;
 	AwbMode *defaultMode; /* mode used if no mode selected */
-	/*
-	 * minimum proportion of pixels counted within AWB region for it to be
-	 * "useful"
-	 */
-	double minPixels;
-	/* minimum G value of those pixels, to be regarded a "useful" */
-	uint16_t minG;
-	/*
-	 * number of AWB regions that must be "useful" in order to do the AWB
-	 * calculation
-	 */
-	uint32_t minRegions;
+
 	/* clamp on colour error term (so as not to penalise non-grey excessively) */
 	double deltaLimit;
-	/* step size control in coarse search */
-	double coarseStep;
 	/* how far to wander off CT curve towards "more purple" */
 	double transversePos;
 	/* how far to wander off CT curve towards "more green" */
@@ -82,14 +59,8 @@  struct AwbConfig {
 	 * sensor's B/G)
 	 */
 	double sensitivityB;
-	/* The whitepoint (which we normally "aim" for) can be moved. */
-	double whitepointR;
-	double whitepointB;
-	bool bayes; /* use Bayesian algorithm */
-	/* proportion of counted samples to add for the search bias */
-	double biasProportion;
-	/* CT target for the search bias */
-	double biasCT;
+
+	bool greyWorld; /* don't use the ct curve when in grey world mode */
 };
 
 class Awb : public AwbAlgorithm
@@ -97,9 +68,7 @@  class Awb : public AwbAlgorithm
 public:
 	Awb(Controller *controller = NULL);
 	~Awb();
-	char const *name() const override;
-	void initialise() override;
-	int read(const libcamera::YamlObject &params) override;
+	virtual void initialise() override;
 	unsigned int getConvergenceFrames() const override;
 	void initialValues(double &gainR, double &gainB) override;
 	void setMode(std::string const &name) override;
@@ -110,6 +79,11 @@  public:
 	void switchMode(CameraMode const &cameraMode, Metadata *metadata) override;
 	void prepare(Metadata *imageMetadata) override;
 	void process(StatisticsPtr &stats, Metadata *imageMetadata) override;
+
+	static double interpolateQuadatric(libcamera::ipa::Pwl::Point const &a,
+					   libcamera::ipa::Pwl::Point const &b,
+					   libcamera::ipa::Pwl::Point const &c);
+
 	struct RGB {
 		RGB(double r = 0, double g = 0, double b = 0)
 			: R(r), G(g), B(b)
@@ -123,10 +97,30 @@  public:
 		}
 	};
 
-private:
-	bool isAutoEnabled() const;
+protected:
 	/* configuration is read-only, and available to both threads */
 	AwbConfig config_;
+	/*
+	 * The following are for the asynchronous thread to use, though the main
+	 * thread can set/reset them if the async thread is known to be idle:
+	 */
+	std::vector<RGB> zones_;
+	StatisticsPtr statistics_;
+	double lux_;
+	AwbMode *mode_;
+	AwbStatus asyncResults_;
+
+	virtual void doAwb() = 0;
+	virtual void prepareStats() = 0;
+	double computeDelta2Sum(double gainR, double gainB, double whitepointR, double whitepointB);
+	void awbGrey();
+	static void generateStats(std::vector<Awb::RGB> &zones,
+				  StatisticsPtr &stats, double minPixels,
+				  double minG, Metadata &globalMetadata,
+				  double biasProportion, double biasCtR, double biasCtB);
+
+private:
+	bool isAutoEnabled() const;
 	std::thread asyncThread_;
 	void asyncFunc(); /* asynchronous thread function */
 	std::mutex mutex_;
@@ -152,6 +146,7 @@  private:
 	AwbStatus syncResults_;
 	AwbStatus prevSyncResults_;
 	std::string modeName_;
+
 	/*
 	 * The following are for the asynchronous thread to use, though the main
 	 * thread can set/reset them if the async thread is known to be idle:
@@ -159,20 +154,6 @@  private:
 	void restartAsync(StatisticsPtr &stats, double lux);
 	/* copy out the results from the async thread so that it can be restarted */
 	void fetchAsyncResults();
-	StatisticsPtr statistics_;
-	AwbMode *mode_;
-	double lux_;
-	AwbStatus asyncResults_;
-	void doAwb();
-	void awbBayes();
-	void awbGrey();
-	void prepareStats();
-	double computeDelta2Sum(double gainR, double gainB);
-	libcamera::ipa::Pwl interpolatePrior();
-	double coarseSearch(libcamera::ipa::Pwl const &prior);
-	void fineSearch(double &t, double &r, double &b, libcamera::ipa::Pwl const &prior);
-	std::vector<RGB> zones_;
-	std::vector<libcamera::ipa::Pwl::Point> points_;
 	/* manual r setting */
 	double manualR_;
 	/* manual b setting */
@@ -196,4 +177,4 @@  static inline Awb::RGB operator*(Awb::RGB const &rgb, double d)
 	return d * rgb;
 }
 
-} /* namespace RPiController */
+} // namespace RPiController
diff --git a/src/ipa/rpi/controller/rpi/awb_bayes.cpp b/src/ipa/rpi/controller/rpi/awb_bayes.cpp
new file mode 100644
index 00000000..09233cec
--- /dev/null
+++ b/src/ipa/rpi/controller/rpi/awb_bayes.cpp
@@ -0,0 +1,444 @@ 
+/* SPDX-License-Identifier: BSD-2-Clause */
+/*
+ * Copyright (C) 2019, Raspberry Pi Ltd
+ *
+ * AWB control algorithm
+ */
+
+#include <assert.h>
+#include <cmath>
+#include <condition_variable>
+#include <functional>
+#include <mutex>
+#include <thread>
+
+#include <libcamera/base/log.h>
+
+#include <libcamera/geometry.h>
+
+#include "../awb_algorithm.h"
+#include "../awb_status.h"
+#include "../lux_status.h"
+#include "../statistics.h"
+#include "libipa/pwl.h"
+
+#include "alsc_status.h"
+#include "awb.h"
+
+using namespace libcamera;
+
+LOG_DECLARE_CATEGORY(RPiAwb)
+
+constexpr double kDefaultCT = 4500.0;
+
+#define NAME "rpi.awb"
+
+/*
+ * todo - the locking in this algorithm needs some tidying up as has been done
+ * elsewhere (ALSC and AGC).
+ */
+
+namespace RPiController {
+
+struct AwbPrior {
+	int read(const libcamera::YamlObject &params);
+	double lux; /* lux level */
+	libcamera::ipa::Pwl prior; /* maps CT to prior log likelihood for this lux level */
+};
+
+struct AwbBayesConfig {
+	AwbBayesConfig() {}
+	int read(const libcamera::YamlObject &params, AwbConfig &config);
+	/* table of illuminant priors at different lux levels */
+	std::vector<AwbPrior> priors;
+	/*
+	 * minimum proportion of pixels counted within AWB region for it to be
+	 * "useful"
+	 */
+	double minPixels;
+	/* minimum G value of those pixels, to be regarded a "useful" */
+	uint16_t minG;
+	/*
+	 * number of AWB regions that must be "useful" in order to do the AWB
+	 * calculation
+	 */
+	uint32_t minRegions;
+	/* step size control in coarse search */
+	double coarseStep;
+	/* The whitepoint (which we normally "aim" for) can be moved. */
+	double whitepointR;
+	double whitepointB;
+	bool bayes; /* use Bayesian algorithm */
+	/* proportion of counted samples to add for the search bias */
+	double biasProportion;
+	/* CT target for the search bias */
+	double biasCT;
+};
+
+class AwbBayes : public Awb
+{
+public:
+	AwbBayes(Controller *controller = NULL);
+	~AwbBayes();
+	char const *name() const override;
+	int read(const libcamera::YamlObject &params) override;
+
+protected:
+	void prepareStats() override;
+	void doAwb() override;
+
+private:
+	AwbBayesConfig bayesConfig_;
+	void awbBayes();
+	libcamera::ipa::Pwl interpolatePrior();
+	double coarseSearch(libcamera::ipa::Pwl const &prior);
+	void fineSearch(double &t, double &r, double &b, libcamera::ipa::Pwl const &prior);
+	std::vector<libcamera::ipa::Pwl::Point> points_;
+};
+
+int AwbPrior::read(const libcamera::YamlObject &params)
+{
+	auto value = params["lux"].get<double>();
+	if (!value)
+		return -EINVAL;
+	lux = *value;
+
+	prior = params["prior"].get<ipa::Pwl>(ipa::Pwl{});
+	return prior.empty() ? -EINVAL : 0;
+}
+
+int AwbBayesConfig::read(const libcamera::YamlObject &params, AwbConfig &config)
+{
+	int ret;
+
+	bayes = params["bayes"].get<int>(1);
+
+	if (params.contains("priors")) {
+		for (const auto &p : params["priors"].asList()) {
+			AwbPrior prior;
+			ret = prior.read(p);
+			if (ret)
+				return ret;
+			if (!priors.empty() && prior.lux <= priors.back().lux) {
+				LOG(RPiAwb, Error) << "AwbConfig: Prior must be ordered in increasing lux value";
+				return -EINVAL;
+			}
+			priors.push_back(prior);
+		}
+		if (priors.empty()) {
+			LOG(RPiAwb, Error) << "AwbConfig: no AWB priors configured";
+			return -EINVAL;
+		}
+	}
+
+	minPixels = params["min_pixels"].get<double>(16.0);
+	minG = params["min_G"].get<uint16_t>(32);
+	minRegions = params["min_regions"].get<uint32_t>(10);
+	coarseStep = params["coarse_step"].get<double>(0.2);
+
+	if (bayes) {
+		if (!config.hasCtCurve() || priors.empty() ||
+		    config.defaultMode == nullptr) {
+			LOG(RPiAwb, Warning)
+				<< "Bayesian AWB mis-configured - switch to Grey method";
+			bayes = false;
+		}
+	}
+	whitepointR = params["whitepoint_r"].get<double>(0.0);
+	whitepointB = params["whitepoint_b"].get<double>(0.0);
+	if (bayes == false) {
+		config.sensitivityR = config.sensitivityB = 1.0; /* nor do sensitivities make any sense */
+		config.greyWorld = true; /* prevent the ct curve being used in manual mode */
+	}
+	/*
+	 * The biasProportion parameter adds a small proportion of the counted
+	 * pixles to a region biased to the biasCT colour temperature.
+	 *
+	 * A typical value for biasProportion would be between 0.05 to 0.1.
+	 */
+	biasProportion = params["bias_proportion"].get<double>(0.0);
+	biasCT = params["bias_ct"].get<double>(kDefaultCT);
+	return 0;
+}
+
+AwbBayes::AwbBayes(Controller *controller)
+	: Awb(controller)
+{
+}
+
+AwbBayes::~AwbBayes()
+{
+}
+
+char const *AwbBayes::name() const
+{
+	return NAME;
+}
+
+int AwbBayes::read(const libcamera::YamlObject &params)
+{
+	int ret;
+
+	ret = config_.read(params);
+	if (ret)
+		return ret;
+
+	ret = bayesConfig_.read(params, config_);
+	if (ret)
+		return ret;
+
+	return 0;
+}
+
+void AwbBayes::prepareStats()
+{
+	zones_.clear();
+	/*
+	 * LSC has already been applied to the stats in this pipeline, so stop
+	 * any LSC compensation.  We also ignore config_.fast in this version.
+	 */
+	const double biasCtR = bayesConfig_.bayes ? config_.ctR.eval(bayesConfig_.biasCT) : 0;
+	const double biasCtB = bayesConfig_.bayes ? config_.ctB.eval(bayesConfig_.biasCT) : 0;
+	generateStats(zones_, statistics_, bayesConfig_.minPixels,
+		      bayesConfig_.minG, getGlobalMetadata(),
+		      bayesConfig_.biasProportion, biasCtR, biasCtB);
+	/*
+	 * apply sensitivities, so values appear to come from our "canonical"
+	 * sensor.
+	 */
+	for (auto &zone : zones_) {
+		zone.R *= config_.sensitivityR;
+		zone.B *= config_.sensitivityB;
+	}
+}
+
+ipa::Pwl AwbBayes::interpolatePrior()
+{
+	/*
+	 * Interpolate the prior log likelihood function for our current lux
+	 * value.
+	 */
+	if (lux_ <= bayesConfig_.priors.front().lux)
+		return bayesConfig_.priors.front().prior;
+	else if (lux_ >= bayesConfig_.priors.back().lux)
+		return bayesConfig_.priors.back().prior;
+	else {
+		int idx = 0;
+		/* find which two we lie between */
+		while (bayesConfig_.priors[idx + 1].lux < lux_)
+			idx++;
+		double lux0 = bayesConfig_.priors[idx].lux,
+		       lux1 = bayesConfig_.priors[idx + 1].lux;
+		return ipa::Pwl::combine(bayesConfig_.priors[idx].prior,
+					 bayesConfig_.priors[idx + 1].prior,
+					 [&](double /*x*/, double y0, double y1) {
+						 return y0 + (y1 - y0) *
+								     (lux_ - lux0) / (lux1 - lux0);
+					 });
+	}
+}
+
+double AwbBayes::coarseSearch(ipa::Pwl const &prior)
+{
+	points_.clear(); /* assume doesn't deallocate memory */
+	size_t bestPoint = 0;
+	double t = mode_->ctLo;
+	int spanR = 0, spanB = 0;
+	/* Step down the CT curve evaluating log likelihood. */
+	while (true) {
+		double r = config_.ctR.eval(t, &spanR);
+		double b = config_.ctB.eval(t, &spanB);
+		double gainR = 1 / r, gainB = 1 / b;
+		double delta2Sum = computeDelta2Sum(gainR, gainB, bayesConfig_.whitepointR, bayesConfig_.whitepointB);
+		double priorLogLikelihood = prior.eval(prior.domain().clamp(t));
+		double finalLogLikelihood = delta2Sum - priorLogLikelihood;
+		LOG(RPiAwb, Debug)
+			<< "t: " << t << " gain R " << gainR << " gain B "
+			<< gainB << " delta2_sum " << delta2Sum
+			<< " prior " << priorLogLikelihood << " final "
+			<< finalLogLikelihood;
+		points_.push_back(ipa::Pwl::Point({ t, finalLogLikelihood }));
+		if (points_.back().y() < points_[bestPoint].y())
+			bestPoint = points_.size() - 1;
+		if (t == mode_->ctHi)
+			break;
+		/* for even steps along the r/b curve scale them by the current t */
+		t = std::min(t + t / 10 * bayesConfig_.coarseStep, mode_->ctHi);
+	}
+	t = points_[bestPoint].x();
+	LOG(RPiAwb, Debug) << "Coarse search found CT " << t;
+	/*
+	 * We have the best point of the search, but refine it with a quadratic
+	 * interpolation around its neighbours.
+	 */
+	if (points_.size() > 2) {
+		unsigned long bp = std::min(bestPoint, points_.size() - 2);
+		bestPoint = std::max(1UL, bp);
+		t = interpolateQuadatric(points_[bestPoint - 1],
+					 points_[bestPoint],
+					 points_[bestPoint + 1]);
+		LOG(RPiAwb, Debug)
+			<< "After quadratic refinement, coarse search has CT "
+			<< t;
+	}
+	return t;
+}
+
+void AwbBayes::fineSearch(double &t, double &r, double &b, ipa::Pwl const &prior)
+{
+	int spanR = -1, spanB = -1;
+	config_.ctR.eval(t, &spanR);
+	config_.ctB.eval(t, &spanB);
+	double step = t / 10 * bayesConfig_.coarseStep * 0.1;
+	int nsteps = 5;
+	double rDiff = config_.ctR.eval(t + nsteps * step, &spanR) -
+		       config_.ctR.eval(t - nsteps * step, &spanR);
+	double bDiff = config_.ctB.eval(t + nsteps * step, &spanB) -
+		       config_.ctB.eval(t - nsteps * step, &spanB);
+	ipa::Pwl::Point transverse({ bDiff, -rDiff });
+	if (transverse.length2() < 1e-6)
+		return;
+	/*
+	 * unit vector orthogonal to the b vs. r function (pointing outwards
+	 * with r and b increasing)
+	 */
+	transverse = transverse / transverse.length();
+	double bestLogLikelihood = 0, bestT = 0, bestR = 0, bestB = 0;
+	double transverseRange = config_.transverseNeg + config_.transversePos;
+	const int maxNumDeltas = 12;
+	/* a transverse step approximately every 0.01 r/b units */
+	int numDeltas = floor(transverseRange * 100 + 0.5) + 1;
+	numDeltas = numDeltas < 3 ? 3 : (numDeltas > maxNumDeltas ? maxNumDeltas : numDeltas);
+	/*
+	 * Step down CT curve. March a bit further if the transverse range is
+	 * large.
+	 */
+	nsteps += numDeltas;
+	for (int i = -nsteps; i <= nsteps; i++) {
+		double tTest = t + i * step;
+		double priorLogLikelihood =
+			prior.eval(prior.domain().clamp(tTest));
+		double rCurve = config_.ctR.eval(tTest, &spanR);
+		double bCurve = config_.ctB.eval(tTest, &spanB);
+		/* x will be distance off the curve, y the log likelihood there */
+		ipa::Pwl::Point points[maxNumDeltas];
+		int bestPoint = 0;
+		/* Take some measurements transversely *off* the CT curve. */
+		for (int j = 0; j < numDeltas; j++) {
+			points[j][0] = -config_.transverseNeg +
+				       (transverseRange * j) / (numDeltas - 1);
+			ipa::Pwl::Point rbTest = ipa::Pwl::Point({ rCurve, bCurve }) +
+						 transverse * points[j].x();
+			double rTest = rbTest.x(), bTest = rbTest.y();
+			double gainR = 1 / rTest, gainB = 1 / bTest;
+			double delta2Sum = computeDelta2Sum(gainR, gainB, bayesConfig_.whitepointR, bayesConfig_.whitepointB);
+			points[j][1] = delta2Sum - priorLogLikelihood;
+			LOG(RPiAwb, Debug)
+				<< "At t " << tTest << " r " << rTest << " b "
+				<< bTest << ": " << points[j].y();
+			if (points[j].y() < points[bestPoint].y())
+				bestPoint = j;
+		}
+		/*
+		 * We have NUM_DELTAS points transversely across the CT curve,
+		 * now let's do a quadratic interpolation for the best result.
+		 */
+		bestPoint = std::max(1, std::min(bestPoint, numDeltas - 2));
+		ipa::Pwl::Point rbTest = ipa::Pwl::Point({ rCurve, bCurve }) +
+					 transverse * interpolateQuadatric(points[bestPoint - 1],
+									   points[bestPoint],
+									   points[bestPoint + 1]);
+		double rTest = rbTest.x(), bTest = rbTest.y();
+		double gainR = 1 / rTest, gainB = 1 / bTest;
+		double delta2Sum = computeDelta2Sum(gainR, gainB, bayesConfig_.whitepointR, bayesConfig_.whitepointB);
+		double finalLogLikelihood = delta2Sum - priorLogLikelihood;
+		LOG(RPiAwb, Debug)
+			<< "Finally "
+			<< tTest << " r " << rTest << " b " << bTest << ": "
+			<< finalLogLikelihood
+			<< (finalLogLikelihood < bestLogLikelihood ? " BEST" : "");
+		if (bestT == 0 || finalLogLikelihood < bestLogLikelihood)
+			bestLogLikelihood = finalLogLikelihood,
+			bestT = tTest, bestR = rTest, bestB = bTest;
+	}
+	t = bestT, r = bestR, b = bestB;
+	LOG(RPiAwb, Debug)
+		<< "Fine search found t " << t << " r " << r << " b " << b;
+}
+
+void AwbBayes::awbBayes()
+{
+	/*
+	 * May as well divide out G to save computeDelta2Sum from doing it over
+	 * and over.
+	 */
+	for (auto &z : zones_)
+		z.R = z.R / (z.G + 1), z.B = z.B / (z.G + 1);
+	/*
+	 * Get the current prior, and scale according to how many zones are
+	 * valid... not entirely sure about this.
+	 */
+	ipa::Pwl prior = interpolatePrior();
+	prior *= zones_.size() / (double)(statistics_->awbRegions.numRegions());
+	prior.map([](double x, double y) {
+		LOG(RPiAwb, Debug) << "(" << x << "," << y << ")";
+	});
+	double t = coarseSearch(prior);
+	double r = config_.ctR.eval(t);
+	double b = config_.ctB.eval(t);
+	LOG(RPiAwb, Debug)
+		<< "After coarse search: r " << r << " b " << b << " (gains r "
+		<< 1 / r << " b " << 1 / b << ")";
+	/*
+	 * Not entirely sure how to handle the fine search yet. Mostly the
+	 * estimated CT is already good enough, but the fine search allows us to
+	 * wander transverely off the CT curve. Under some illuminants, where
+	 * there may be more or less green light, this may prove beneficial,
+	 * though I probably need more real datasets before deciding exactly how
+	 * this should be controlled and tuned.
+	 */
+	fineSearch(t, r, b, prior);
+	LOG(RPiAwb, Debug)
+		<< "After fine search: r " << r << " b " << b << " (gains r "
+		<< 1 / r << " b " << 1 / b << ")";
+	/*
+	 * Write results out for the main thread to pick up. Remember to adjust
+	 * the gains from the ones that the "canonical sensor" would require to
+	 * the ones needed by *this* sensor.
+	 */
+	asyncResults_.temperatureK = t;
+	asyncResults_.gainR = 1.0 / r * config_.sensitivityR;
+	asyncResults_.gainG = 1.0;
+	asyncResults_.gainB = 1.0 / b * config_.sensitivityB;
+}
+
+void AwbBayes::doAwb()
+{
+	prepareStats();
+	LOG(RPiAwb, Debug) << "Valid zones: " << zones_.size();
+	if (zones_.size() > bayesConfig_.minRegions) {
+		if (bayesConfig_.bayes)
+			awbBayes();
+		else
+			awbGrey();
+		LOG(RPiAwb, Debug)
+			<< "CT found is "
+			<< asyncResults_.temperatureK
+			<< " with gains r " << asyncResults_.gainR
+			<< " and b " << asyncResults_.gainB;
+	}
+	/*
+	 * we're done with these; we may as well relinquish our hold on the
+	 * pointer.
+	 */
+	statistics_.reset();
+}
+
+/* Register algorithm with the system. */
+static Algorithm *create(Controller *controller)
+{
+	return (Algorithm *)new AwbBayes(controller);
+}
+static RegisterAlgorithm reg(NAME, &create);
+
+} /* namespace RPiController */