from dataclasses import dataclass from math import sqrt, cbrt, atan2, cos, sin, pi, inf @dataclass class Lab: L: float a: float b: float @dataclass class RGB: r: float g: float b: float @dataclass class HSL: h: float s: float l: float @dataclass class LC: L: float C: float # Alternative representation of (L_cusp, C_cusp) # Encoded so S = C_cusp/L_cusp and T = C_cusp/(1-L_cusp) # The maximum value for C in the triangle is then found as fmin(S*L, T*(1-L)), for a given L @dataclass class ST: S: float T: float @dataclass class Cs: C_0: float C_mid: float C_max: float #-------------------------------------------------------------- def srgb_transfer_function(a: float) -> float: if 0.0031308 >= a: return 12.92 * a else: return 1.055 * a**0.4166666666666667 - 0.055 def srgb_transfer_function_inv(a: float) -> float: if 0.04045 < a: return ((a + 0.055) / 1.055)**2.4 else: return a / 12.92 def linear_srgb_to_oklab(c: RGB) -> Lab: l = 0.4122214708 * c.r + 0.5363325363 * c.g + 0.0514459929 * c.b m = 0.2119034982 * c.r + 0.6806995451 * c.g + 0.1073969566 * c.b s = 0.0883024619 * c.r + 0.2817188376 * c.g + 0.6299787005 * c.b l_ = cbrt(l) m_ = cbrt(m) s_ = cbrt(s) return Lab( 0.2104542553 * l_ + 0.7936177850 * m_ - 0.0040720468 * s_, 1.9779984951 * l_ - 2.4285922050 * m_ + 0.4505937099 * s_, 0.0259040371 * l_ + 0.7827717662 * m_ - 0.8086757660 * s_, ) def oklab_to_linear_srgb(c: Lab) -> RGB: l_ = c.L + 0.3963377774 * c.a + 0.2158037573 * c.b m_ = c.L - 0.1055613458 * c.a - 0.0638541728 * c.b s_ = c.L - 0.0894841775 * c.a - 1.2914855480 * c.b l = l_ * l_ * l_ m = m_ * m_ * m_ s = s_ * s_ * s_ return RGB( +4.0767416621 * l - 3.3077115913 * m + 0.2309699292 * s, -1.2684380046 * l + 2.6097574011 * m - 0.3413193965 * s, -0.0041960863 * l - 0.7034186147 * m + 1.7076147010 * s, ) def srgb_to_oklab(c: RGB) -> Lab: return linear_srgb_to_oklab(RGB( srgb_transfer_function_inv(c.r), srgb_transfer_function_inv(c.g), srgb_transfer_function_inv(c.b), )) def oklab_to_srgb(c: Lab) -> RGB: rgb = oklab_to_linear_srgb(c) return RGB( srgb_transfer_function(rgb.r), srgb_transfer_function(rgb.g), srgb_transfer_function(rgb.b), ) #-------------------------------------------------------------- # Finds the maximum saturation possible for a given hue that fits in sRGB # Saturation here is defined as S = C/L # a and b must be normalized so a^2 + b^2 == 1 def compute_max_saturation(a: float, b: float) -> float: # Max saturation will be when one of r, g or b goes below zero. # Select different coefficients depending on which component goes below zero first k0 = k1 = k2 = k3 = k4 = wl = wm = ws = 0.0 if (-1.88170328 * a - 0.80936493 * b) > 1.0: # Red component k0 = 1.19086277 k1 = 1.76576728 k2 = 0.59662641 k3 = 0.75515197 k4 = 0.56771245 wl = 4.0767416621 wm = -3.3077115913 ws = 0.2309699292 elif (1.81444104 * a - 1.19445276 * b) > 1.0: # Green component k0 = 0.73956515 k1 = -0.45954404 k2 = 0.08285427 k3 = 0.12541070 k4 = 0.14503204 wl = -1.2684380046 wm = 2.6097574011 ws = -0.3413193965 else: # Blue component k0 = 1.35733652 k1 = -0.00915799 k2 = -1.15130210 k3 = -0.50559606 k4 = 0.00692167 wl = -0.0041960863 wm = -0.7034186147 ws = 1.7076147010 # Approximate max saturation using a polynomial: S = k0 + k1 * a + k2 * b + k3 * a * a + k4 * a * b # Do one step Halley's method to get closer # this gives an error less than 10e6, except for some blue hues where the dS/dh is close to infinite # this should be sufficient for most applications, otherwise do two/three steps k_l = 0.3963377774 * a + 0.2158037573 * b k_m = -0.1055613458 * a - 0.0638541728 * b k_s = -0.0894841775 * a - 1.2914855480 * b l_ = 1.0 + S * k_l m_ = 1.0 + S * k_m s_ = 1.0 + S * k_s l = l_ * l_ * l_ m = m_ * m_ * m_ s = s_ * s_ * s_ l_dS = 3.0 * k_l * l_ * l_ m_dS = 3.0 * k_m * m_ * m_ s_dS = 3.0 * k_s * s_ * s_ l_dS2 = 6.0 * k_l * k_l * l_ m_dS2 = 6.0 * k_m * k_m * m_ s_dS2 = 6.0 * k_s * k_s * s_ f = wl * l + wm * m + ws * s f1 = wl * l_dS + wm * m_dS + ws * s_dS f2 = wl * l_dS2 + wm * m_dS2 + ws * s_dS2 S = S - f * f1 / (f1*f1 - 0.5 * f * f2) return S # finds L_cusp and C_cusp for a given hue # a and b must be normalized so a^2 + b^2 == 1 def find_cusp(a: float, b: float) -> LC: # First, find the maximum saturation (saturation S = C/L) S_cusp = compute_max_saturation(a, b) # Convert to linear sRGB to find the first point where at least one of r,g or b >= 1: rgb_at_max = oklab_to_linear_srgb(Lab(1.0, S_cusp * a, S_cusp * b)) L_cusp = cbrt(1.0 / max(rgb_at_max.r, rgb_at_max.g, rgb_at_max.b)) C_cusp = L_cusp * S_cusp return LC(L_cusp , C_cusp) # Finds intersection of the line defined by # L = L0 * (1 - t) + t * L1; # C = t * C1; # a and b must be normalized so a^2 + b^2 == 1 def find_gamut_intersection(a: float, b: float, L1: float, C1: float, L0: float, cusp: LC) -> float: # Find the intersection for upper and lower half seprately t = 0.0 if (((L1 - L0) * cusp.C - (cusp.L - L0) * C1)) <= 0.0: # Lower half t = cusp.C * L0 / (C1 * cusp.L + cusp.C * (L0 - L1)) else: # Upper half # First intersect with triangle t = cusp.C * (L0 - 1.0) / (C1 * (cusp.L - 1.0) + cusp.C * (L0 - L1)) # Then one step Halley's method dL = L1 - L0 dC = C1 k_l = +0.3963377774 * a + 0.2158037573 * b k_m = -0.1055613458 * a - 0.0638541728 * b k_s = -0.0894841775 * a - 1.2914855480 * b l_dt = dL + dC * k_l m_dt = dL + dC * k_m s_dt = dL + dC * k_s # If higher accuracy is required, 2 or 3 iterations of the following block can be used: L = L0 * (1.0 - t) + t * L1 C = t * C1 l_ = L + C * k_l m_ = L + C * k_m s_ = L + C * k_s l = l_ * l_ * l_ m = m_ * m_ * m_ s = s_ * s_ * s_ ldt = 3.0 * l_dt * l_ * l_ mdt = 3.0 * m_dt * m_ * m_ sdt = 3.0 * s_dt * s_ * s_ ldt2 = 6.0 * l_dt * l_dt * l_ mdt2 = 6.0 * m_dt * m_dt * m_ sdt2 = 6.0 * s_dt * s_dt * s_ r = 4.0767416621 * l - 3.3077115913 * m + 0.2309699292 * s - 1 r1 = 4.0767416621 * ldt - 3.3077115913 * mdt + 0.2309699292 * sdt r2 = 4.0767416621 * ldt2 - 3.3077115913 * mdt2 + 0.2309699292 * sdt2 u_r = r1 / (r1 * r1 - 0.5 * r * r2) t_r = -r * u_r g = -1.2684380046 * l + 2.6097574011 * m - 0.3413193965 * s - 1 g1 = -1.2684380046 * ldt + 2.6097574011 * mdt - 0.3413193965 * sdt g2 = -1.2684380046 * ldt2 + 2.6097574011 * mdt2 - 0.3413193965 * sdt2 u_g = g1 / (g1 * g1 - 0.5 * g * g2) t_g = -g * u_g b = -0.0041960863 * l - 0.7034186147 * m + 1.7076147010 * s - 1 b1 = -0.0041960863 * ldt - 0.7034186147 * mdt + 1.7076147010 * sdt b2 = -0.0041960863 * ldt2 - 0.7034186147 * mdt2 + 1.7076147010 * sdt2 u_b = b1 / (b1 * b1 - 0.5 * b * b2) t_b = -b * u_b # t_r = u_r >= 0.f ? t_r : FLT_MAX; # t_g = u_g >= 0.f ? t_g : FLT_MAX; # t_b = u_b >= 0.f ? t_b : FLT_MAX; t_r = t_r if u_r >= 0.0 else inf t_g = t_g if u_g >= 0.0 else inf t_b = t_b if u_b >= 0.0 else inf t += min(t_r, t_g, t_b) return t #-------------------------------------------------------------- # toe function for L_r def toe(x: float) -> float: k_1 = 0.206 k_2 = 0.03 k_3 = (1.0 + k_1) / (1.0 + k_2); return 0.5 * (k_3 * x - k_1 + sqrt((k_3 * x - k_1) * (k_3 * x - k_1) + 4 * k_2 * k_3 * x)) # inverse toe function for L_r def toe_inv(x: float) -> float: k_1 = 0.206 k_2 = 0.03 k_3 = (1.0 + k_1) / (1.0 + k_2) return (x * x + k_1 * x) / (k_3 * (x + k_2)) def to_ST(cusp: LC) -> ST: L = cusp.L C = cusp.C return ST(C / L, C / (1 - L)) # Returns a smooth approximation of the location of the cusp # This polynomial was created by an optimization process # It has been designed so that S_mid < S_max and T_mid < T_max def get_ST_mid(a_: float, b_: float) -> ST: S = 0.115169930 + 1.0 / ( +7.447789700 + 4.159012400 * b_ + a_ * (-2.195573470 + 1.751984010 * b_ + a_ * (-2.137049480 - 10.023010430 * b_ + a_ * (-4.248945610 + 5.387708190 * b_ + 4.698910130 * a_ ))) ) T = 0.112396420 + 1.0 / ( +1.613203200 - 0.681243790 * b_ + a_ * (+0.403706120 + 0.901481230 * b_ + a_ * (-0.270879430 + 0.612239900 * b_ + a_ * (+0.002992150 - 0.453995680 * b_ - 0.146618720 * a_ ))) ) return ST(S, T) def get_Cs(L: float, a_: float, b_: float) -> Cs: cusp = find_cusp(a_, b_) C_max = find_gamut_intersection(a_, b_, L, 1, L, cusp) ST_max = to_ST(cusp) # Scale factor to compensate for the curved part of gamut shape: k = C_max / min((L * ST_max.S), (1.0 - L) * ST_max.T) ST_mid = get_ST_mid(a_, b_) # Use a soft minimum function, instead of a sharp triangle shape to get a smooth value for chroma. C_a = L * ST_mid.S C_b = (1.0 - L) * ST_mid.T C_mid = 0.9 * k * sqrt(sqrt(1.0 / (1.0 / (C_a * C_a * C_a * C_a) + 1.0 / (C_b * C_b * C_b * C_b)))) # for C_0, the shape is independent of hue, so ST are constant. Values picked to roughly be the average values of ST. C_a = L * 0.4 C_b = (1.0 - L) * 0.8 # Use a soft minimum function, instead of a sharp triangle shape to get a smooth value for chroma. C_0 = sqrt(1.0 / (1.0 / (C_a * C_a) + 1.0 / (C_b * C_b))) return Cs(C_0, C_mid, C_max) def okhsl_to_srgb(hsl: HSL) -> RGB: h = hsl.h s = hsl.s l = hsl.l if l == 1.0: return RGB(1.0, 1.0, 1.0) elif l == 0.0: return RGB(0.0, 0.0, 0.0) a_ = cos(2.0 * pi * h) b_ = sin(2.0 * pi * h) L = toe_inv(l) cs = get_Cs(L, a_, b_) C_0 = cs.C_0 C_mid = cs.C_mid C_max = cs.C_max # Interpolate the three values for C so that: # At s=0: dC/ds = C_0, C=0 # At s=0.8: C=C_mid # At s=1.0: C=C_max mid = 0.8 mid_inv = 1.25 C = t = k_0 = k_1 = k_2 = 0.0 if s < mid: t = mid_inv * s k_1 = mid * C_0 k_2 = (1.0 - k_1 / C_mid) C = t * k_1 / (1.0 - k_2 * t) else: t = (s - mid)/ (1 - mid) k_0 = C_mid k_1 = (1.0 - mid) * C_mid * C_mid * mid_inv * mid_inv / C_0 k_2 = (1.0 - (k_1) / (C_max - C_mid)) C = k_0 + t * k_1 / (1.0 - k_2 * t) rgb = oklab_to_srgb(Lab(L, C * a_, C * b_)) return RGB( min(1.0, max(0.0, rgb.r)), min(1.0, max(0.0, rgb.g)), min(1.0, max(0.0, rgb.b)), ) def srgb_to_okhsl(rgb: RGB) -> HSL: lab = srgb_to_oklab(rgb) C = sqrt(lab.a * lab.a + lab.b * lab.b) a_ = lab.a / C b_ = lab.b / C L = lab.L h = 0.5 + 0.5 * atan2(-lab.b, -lab.a) / pi cs = get_Cs(L, a_, b_) C_0 = cs.C_0 C_mid = cs.C_mid C_max = cs.C_max # Inverse of the interpolation in okhsl_to_srgb: mid = 0.8 mid_inv = 1.25 s = 0.0; if C < C_mid: k_1 = mid * C_0 k_2 = (1.0 - k_1 / C_mid) t = C / (k_1 + k_2 * C) s = t * mid else: k_0 = C_mid k_1 = (1.0 - mid) * C_mid * C_mid * mid_inv * mid_inv / C_0 k_2 = (1.0 - (k_1) / (C_max - C_mid)) t = (C - k_0) / (k_1 + k_2 * (C - k_0)) s = mid + (1.0 - mid) * t l = toe(L); return HSL(h, s, l)