1 // Written in the D programming language. 2 3 /** 4 This module is a port of a growing fragment of the $(D_PARAM numeric) 5 header in Alexander Stepanov's $(LINK2 https://en.wikipedia.org/wiki/Standard_Template_Library, 6 Standard Template Library), with a few additions. 7 8 Macros: 9 Copyright: Copyright Andrei Alexandrescu 2008 - 2009. 10 License: $(HTTP www.boost.org/LICENSE_1_0.txt, Boost License 1.0). 11 Authors: $(HTTP erdani.org, Andrei Alexandrescu), 12 Don Clugston, Robert Jacques, Ilya Yaroshenko 13 Source: $(PHOBOSSRC std/numeric.d) 14 */ 15 /* 16 Copyright Andrei Alexandrescu 2008 - 2009. 17 Distributed under the Boost Software License, Version 1.0. 18 (See accompanying file LICENSE_1_0.txt or copy at 19 http://www.boost.org/LICENSE_1_0.txt) 20 */ 21 module std.numeric; 22 23 import std.complex; 24 import std.math; 25 import core.math : fabs, ldexp, sin, sqrt; 26 import std.range.primitives; 27 import std.traits; 28 import std.typecons; 29 30 /// Format flags for CustomFloat. 31 public enum CustomFloatFlags 32 { 33 /// Adds a sign bit to allow for signed numbers. 34 signed = 1, 35 36 /** 37 * Store values in normalized form by default. The actual precision of the 38 * significand is extended by 1 bit by assuming an implicit leading bit of 1 39 * instead of 0. i.e. `1.nnnn` instead of `0.nnnn`. 40 * True for all $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEE754) types 41 */ 42 storeNormalized = 2, 43 44 /** 45 * Stores the significand in $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers, 46 * IEEE754 denormalized) form when the exponent is 0. Required to express the value 0. 47 */ 48 allowDenorm = 4, 49 50 /** 51 * Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Positive_and_negative_infinity, 52 * IEEE754 _infinity) values. 53 */ 54 infinity = 8, 55 56 /// Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/NaN, IEEE754 Not a Number) values. 57 nan = 16, 58 59 /** 60 * If set, select an exponent bias such that max_exp = 1. 61 * i.e. so that the maximum value is >= 1.0 and < 2.0. 62 * Ignored if the exponent bias is manually specified. 63 */ 64 probability = 32, 65 66 /// If set, unsigned custom floats are assumed to be negative. 67 negativeUnsigned = 64, 68 69 /**If set, 0 is the only allowed $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers, 70 * IEEE754 denormalized) number. 71 * Requires allowDenorm and storeNormalized. 72 */ 73 allowDenormZeroOnly = 128 | allowDenorm | storeNormalized, 74 75 /// Include _all of the $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) options. 76 ieee = signed | storeNormalized | allowDenorm | infinity | nan , 77 78 /// Include none of the above options. 79 none = 0 80 } 81 82 private enum isIEEEQuadruple = floatTraits!real.realFormat == RealFormat.ieeeQuadruple; 83 84 private template CustomFloatParams(uint bits) 85 { 86 enum CustomFloatFlags flags = CustomFloatFlags.ieee 87 ^ ((bits == 80 && !isIEEEQuadruple) ? CustomFloatFlags.storeNormalized : CustomFloatFlags.none); 88 static if (bits == 8) alias CustomFloatParams = CustomFloatParams!( 4, 3, flags); 89 static if (bits == 16) alias CustomFloatParams = CustomFloatParams!(10, 5, flags); 90 static if (bits == 32) alias CustomFloatParams = CustomFloatParams!(23, 8, flags); 91 static if (bits == 64) alias CustomFloatParams = CustomFloatParams!(52, 11, flags); 92 static if (bits == 80) alias CustomFloatParams = CustomFloatParams!(64, 15, flags); 93 } 94 95 private template CustomFloatParams(uint precision, uint exponentWidth, CustomFloatFlags flags) 96 { 97 import std.meta : AliasSeq; 98 alias CustomFloatParams = 99 AliasSeq!( 100 precision, 101 exponentWidth, 102 flags, 103 (1 << (exponentWidth - ((flags & flags.probability) == 0))) 104 - ((flags & (flags.nan | flags.infinity)) != 0) - ((flags & flags.probability) != 0) 105 ); // ((flags & CustomFloatFlags.probability) == 0) 106 } 107 108 /** 109 * Allows user code to define custom floating-point formats. These formats are 110 * for storage only; all operations on them are performed by first implicitly 111 * extracting them to `real` first. After the operation is completed the 112 * result can be stored in a custom floating-point value via assignment. 113 */ 114 template CustomFloat(uint bits) 115 if (bits == 8 || bits == 16 || bits == 32 || bits == 64 || bits == 80) 116 { 117 alias CustomFloat = CustomFloat!(CustomFloatParams!(bits)); 118 } 119 120 /// ditto 121 template CustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags = CustomFloatFlags.ieee) 122 if (((flags & flags.signed) + precision + exponentWidth) % 8 == 0 && precision + exponentWidth > 0) 123 { 124 alias CustomFloat = CustomFloat!(CustomFloatParams!(precision, exponentWidth, flags)); 125 } 126 127 /// 128 @safe unittest 129 { 130 import std.math.trigonometry : sin, cos; 131 132 // Define a 16-bit floating point values 133 CustomFloat!16 x; // Using the number of bits 134 CustomFloat!(10, 5) y; // Using the precision and exponent width 135 CustomFloat!(10, 5,CustomFloatFlags.ieee) z; // Using the precision, exponent width and format flags 136 CustomFloat!(10, 5,CustomFloatFlags.ieee, 15) w; // Using the precision, exponent width, format flags and exponent offset bias 137 138 // Use the 16-bit floats mostly like normal numbers 139 w = x*y - 1; 140 141 // Functions calls require conversion 142 z = sin(+x) + cos(+y); // Use unary plus to concisely convert to a real 143 z = sin(x.get!float) + cos(y.get!float); // Or use get!T 144 z = sin(cast(float) x) + cos(cast(float) y); // Or use cast(T) to explicitly convert 145 146 // Define a 8-bit custom float for storing probabilities 147 alias Probability = CustomFloat!(4, 4, CustomFloatFlags.ieee^CustomFloatFlags.probability^CustomFloatFlags.signed ); 148 auto p = Probability(0.5); 149 } 150 151 // Facilitate converting numeric types to custom float 152 private union ToBinary(F) 153 if (is(typeof(CustomFloatParams!(F.sizeof*8))) || is(F == real)) 154 { 155 F set; 156 157 // If on Linux or Mac, where 80-bit reals are padded, ignore the 158 // padding. 159 import std.algorithm.comparison : min; 160 CustomFloat!(CustomFloatParams!(min(F.sizeof*8, 80))) get; 161 162 // Convert F to the correct binary type. 163 static typeof(get) opCall(F value) 164 { 165 ToBinary r; 166 r.set = value; 167 return r.get; 168 } 169 alias get this; 170 } 171 172 /// ditto 173 struct CustomFloat(uint precision, // fraction bits (23 for float) 174 uint exponentWidth, // exponent bits (8 for float) Exponent width 175 CustomFloatFlags flags, 176 uint bias) 177 if (isCorrectCustomFloat(precision, exponentWidth, flags)) 178 { 179 import std.bitmanip : bitfields; 180 import std.meta : staticIndexOf; 181 private: 182 // get the correct unsigned bitfield type to support > 32 bits 183 template uType(uint bits) 184 { 185 static if (bits <= size_t.sizeof*8) alias uType = size_t; 186 else alias uType = ulong ; 187 } 188 189 // get the correct signed bitfield type to support > 32 bits 190 template sType(uint bits) 191 { 192 static if (bits <= ptrdiff_t.sizeof*8-1) alias sType = ptrdiff_t; 193 else alias sType = long; 194 } 195 196 alias T_sig = uType!precision; 197 alias T_exp = uType!exponentWidth; 198 alias T_signed_exp = sType!exponentWidth; 199 200 alias Flags = CustomFloatFlags; 201 202 // Perform IEEE rounding with round to nearest detection 203 void roundedShift(T,U)(ref T sig, U shift) 204 { 205 if (shift >= T.sizeof*8) 206 { 207 // avoid illegal shift 208 sig = 0; 209 } 210 else if (sig << (T.sizeof*8 - shift) == cast(T) 1uL << (T.sizeof*8 - 1)) 211 { 212 // round to even 213 sig >>= shift; 214 sig += sig & 1; 215 } 216 else 217 { 218 sig >>= shift - 1; 219 sig += sig & 1; 220 // Perform standard rounding 221 sig >>= 1; 222 } 223 } 224 225 // Convert the current value to signed exponent, normalized form 226 void toNormalized(T,U)(ref T sig, ref U exp) 227 { 228 sig = significand; 229 auto shift = (T.sizeof*8) - precision; 230 exp = exponent; 231 static if (flags&(Flags.infinity|Flags.nan)) 232 { 233 // Handle inf or nan 234 if (exp == exponent_max) 235 { 236 exp = exp.max; 237 sig <<= shift; 238 static if (flags&Flags.storeNormalized) 239 { 240 // Save inf/nan in denormalized format 241 sig >>= 1; 242 sig += cast(T) 1uL << (T.sizeof*8 - 1); 243 } 244 return; 245 } 246 } 247 if ((~flags&Flags.storeNormalized) || 248 // Convert denormalized form to normalized form 249 ((flags&Flags.allowDenorm) && exp == 0)) 250 { 251 if (sig > 0) 252 { 253 import core.bitop : bsr; 254 auto shift2 = precision - bsr(sig); 255 exp -= shift2-1; 256 shift += shift2; 257 } 258 else // value = 0.0 259 { 260 exp = exp.min; 261 return; 262 } 263 } 264 sig <<= shift; 265 exp -= bias; 266 } 267 268 // Set the current value from signed exponent, normalized form 269 void fromNormalized(T,U)(ref T sig, ref U exp) 270 { 271 auto shift = (T.sizeof*8) - precision; 272 if (exp == exp.max) 273 { 274 // infinity or nan 275 exp = exponent_max; 276 static if (flags & Flags.storeNormalized) 277 sig <<= 1; 278 279 // convert back to normalized form 280 static if (~flags & Flags.infinity) 281 // No infinity support? 282 assert(sig != 0, "Infinity floating point value assigned to a " 283 ~ typeof(this).stringof ~ " (no infinity support)."); 284 285 static if (~flags & Flags.nan) // No NaN support? 286 assert(sig == 0, "NaN floating point value assigned to a " ~ 287 typeof(this).stringof ~ " (no nan support)."); 288 sig >>= shift; 289 return; 290 } 291 if (exp == exp.min) // 0.0 292 { 293 exp = 0; 294 sig = 0; 295 return; 296 } 297 298 exp += bias; 299 if (exp <= 0) 300 { 301 static if ((flags&Flags.allowDenorm) || 302 // Convert from normalized form to denormalized 303 (~flags&Flags.storeNormalized)) 304 { 305 shift += -exp; 306 roundedShift(sig,1); 307 sig += cast(T) 1uL << (T.sizeof*8 - 1); 308 // Add the leading 1 309 exp = 0; 310 } 311 else 312 assert((flags&Flags.storeNormalized) && exp == 0, 313 "Underflow occured assigning to a " ~ 314 typeof(this).stringof ~ " (no denormal support)."); 315 } 316 else 317 { 318 static if (~flags&Flags.storeNormalized) 319 { 320 // Convert from normalized form to denormalized 321 roundedShift(sig,1); 322 sig += cast(T) 1uL << (T.sizeof*8 - 1); 323 // Add the leading 1 324 } 325 } 326 327 if (shift > 0) 328 roundedShift(sig,shift); 329 if (sig > significand_max) 330 { 331 // handle significand overflow (should only be 1 bit) 332 static if (~flags&Flags.storeNormalized) 333 { 334 sig >>= 1; 335 } 336 else 337 sig &= significand_max; 338 exp++; 339 } 340 static if ((flags&Flags.allowDenormZeroOnly)==Flags.allowDenormZeroOnly) 341 { 342 // disallow non-zero denormals 343 if (exp == 0) 344 { 345 sig <<= 1; 346 if (sig > significand_max && (sig&significand_max) > 0) 347 // Check and round to even 348 exp++; 349 sig = 0; 350 } 351 } 352 353 if (exp >= exponent_max) 354 { 355 static if (flags&(Flags.infinity|Flags.nan)) 356 { 357 sig = 0; 358 exp = exponent_max; 359 static if (~flags&(Flags.infinity)) 360 assert(0, "Overflow occured assigning to a " ~ 361 typeof(this).stringof ~ " (no infinity support)."); 362 } 363 else 364 assert(exp == exponent_max, "Overflow occured assigning to a " 365 ~ typeof(this).stringof ~ " (no infinity support)."); 366 } 367 } 368 369 public: 370 static if (precision == 64) // CustomFloat!80 support hack 371 { 372 static if (isIEEEQuadruple) 373 { 374 // Only use highest 64 significand bits from 112 explicitly stored 375 align (1): 376 enum ulong significand_max = ulong.max; 377 version (LittleEndian) 378 { 379 private ubyte[6] _padding; // 48-bit of padding 380 ulong significand; 381 mixin(bitfields!( 382 T_exp , "exponent", exponentWidth, 383 bool , "sign" , flags & flags.signed )); 384 } 385 else 386 { 387 mixin(bitfields!( 388 T_exp , "exponent", exponentWidth, 389 bool , "sign" , flags & flags.signed )); 390 ulong significand; 391 private ubyte[6] _padding; // 48-bit of padding 392 } 393 } 394 else 395 { 396 ulong significand; 397 enum ulong significand_max = ulong.max; 398 mixin(bitfields!( 399 T_exp , "exponent", exponentWidth, 400 bool , "sign" , flags & flags.signed )); 401 } 402 } 403 else 404 { 405 mixin(bitfields!( 406 T_sig, "significand", precision, 407 T_exp, "exponent" , exponentWidth, 408 bool , "sign" , flags & flags.signed )); 409 } 410 411 /// Returns: infinity value 412 static if (flags & Flags.infinity) 413 static @property CustomFloat infinity() 414 { 415 CustomFloat value; 416 static if (flags & Flags.signed) 417 value.sign = 0; 418 value.significand = 0; 419 value.exponent = exponent_max; 420 return value; 421 } 422 423 /// Returns: NaN value 424 static if (flags & Flags.nan) 425 static @property CustomFloat nan() 426 { 427 CustomFloat value; 428 static if (flags & Flags.signed) 429 value.sign = 0; 430 value.significand = cast(typeof(significand_max)) 1L << (precision-1); 431 value.exponent = exponent_max; 432 return value; 433 } 434 435 /// Returns: number of decimal digits of precision 436 static @property size_t dig() 437 { 438 auto shiftcnt = precision - ((flags&Flags.storeNormalized) == 0); 439 return shiftcnt == 64 ? 19 : cast(size_t) log10(real(1uL << shiftcnt)); 440 } 441 442 /// Returns: smallest increment to the value 1 443 static @property CustomFloat epsilon() 444 { 445 CustomFloat one = CustomFloat(1); 446 CustomFloat onePlusEpsilon = one; 447 onePlusEpsilon.significand = onePlusEpsilon.significand | 1; // |= does not work here 448 449 return CustomFloat(onePlusEpsilon - one); 450 } 451 452 /// the number of bits in mantissa 453 enum mant_dig = precision + ((flags&Flags.storeNormalized) != 0); 454 455 /// Returns: maximum int value such that 10<sup>max_10_exp</sup> is representable 456 static @property int max_10_exp(){ return cast(int) log10( +max ); } 457 458 /// maximum int value such that 2<sup>max_exp-1</sup> is representable 459 enum max_exp = exponent_max - bias - ((flags & (Flags.infinity | Flags.nan)) != 0) + 1; 460 461 /// Returns: minimum int value such that 10<sup>min_10_exp</sup> is representable 462 static @property int min_10_exp(){ return cast(int) log10( +min_normal ); } 463 464 /// minimum int value such that 2<sup>min_exp-1</sup> is representable as a normalized value 465 enum min_exp = cast(T_signed_exp) -(cast(long) bias) + 1 + ((flags & Flags.allowDenorm) != 0); 466 467 /// Returns: largest representable value that's not infinity 468 static @property CustomFloat max() 469 { 470 CustomFloat value; 471 static if (flags & Flags.signed) 472 value.sign = 0; 473 value.exponent = exponent_max - ((flags&(flags.infinity|flags.nan)) != 0); 474 value.significand = significand_max; 475 return value; 476 } 477 478 /// Returns: smallest representable normalized value that's not 0 479 static @property CustomFloat min_normal() 480 { 481 CustomFloat value; 482 static if (flags & Flags.signed) 483 value.sign = 0; 484 value.exponent = (flags & Flags.allowDenorm) != 0; 485 static if (flags & Flags.storeNormalized) 486 value.significand = 0; 487 else 488 value.significand = cast(T_sig) 1uL << (precision - 1); 489 return value; 490 } 491 492 /// Returns: real part 493 @property CustomFloat re() { return this; } 494 495 /// Returns: imaginary part 496 static @property CustomFloat im() { return CustomFloat(0.0f); } 497 498 /// Initialize from any `real` compatible type. 499 this(F)(F input) if (__traits(compiles, cast(real) input )) 500 { 501 this = input; 502 } 503 504 /// Self assignment 505 void opAssign(F:CustomFloat)(F input) 506 { 507 static if (flags & Flags.signed) 508 sign = input.sign; 509 exponent = input.exponent; 510 significand = input.significand; 511 } 512 513 /// Assigns from any `real` compatible type. 514 void opAssign(F)(F input) 515 if (__traits(compiles, cast(real) input)) 516 { 517 import std.conv : text; 518 519 static if (staticIndexOf!(immutable F, immutable float, immutable double, immutable real) >= 0) 520 auto value = ToBinary!(Unqual!F)(input); 521 else 522 auto value = ToBinary!(real )(input); 523 524 // Assign the sign bit 525 static if (~flags & Flags.signed) 526 assert((!value.sign) ^ ((flags&flags.negativeUnsigned) > 0), 527 "Incorrectly signed floating point value assigned to a " ~ 528 typeof(this).stringof ~ " (no sign support)."); 529 else 530 sign = value.sign; 531 532 CommonType!(T_signed_exp ,value.T_signed_exp) exp = value.exponent; 533 CommonType!(T_sig, value.T_sig ) sig = value.significand; 534 535 value.toNormalized(sig,exp); 536 fromNormalized(sig,exp); 537 538 assert(exp <= exponent_max, text(typeof(this).stringof ~ 539 " exponent too large: " ,exp," > ",exponent_max, "\t",input,"\t",sig)); 540 assert(sig <= significand_max, text(typeof(this).stringof ~ 541 " significand too large: ",sig," > ",significand_max, 542 "\t",input,"\t",exp," ",exponent_max)); 543 exponent = cast(T_exp) exp; 544 significand = cast(T_sig) sig; 545 } 546 547 /// Fetches the stored value either as a `float`, `double` or `real`. 548 @property F get(F)() 549 if (staticIndexOf!(immutable F, immutable float, immutable double, immutable real) >= 0) 550 { 551 import std.conv : text; 552 553 ToBinary!F result; 554 555 static if (flags&Flags.signed) 556 result.sign = sign; 557 else 558 result.sign = (flags&flags.negativeUnsigned) > 0; 559 560 CommonType!(T_signed_exp ,result.get.T_signed_exp ) exp = exponent; // Assign the exponent and fraction 561 CommonType!(T_sig, result.get.T_sig ) sig = significand; 562 563 toNormalized(sig,exp); 564 result.fromNormalized(sig,exp); 565 assert(exp <= result.exponent_max, text("get exponent too large: " ,exp," > ",result.exponent_max) ); 566 assert(sig <= result.significand_max, text("get significand too large: ",sig," > ",result.significand_max) ); 567 result.exponent = cast(result.get.T_exp) exp; 568 result.significand = cast(result.get.T_sig) sig; 569 return result.set; 570 } 571 572 ///ditto 573 alias opCast = get; 574 575 /// Convert the CustomFloat to a real and perform the relevant operator on the result 576 real opUnary(string op)() 577 if (__traits(compiles, mixin(op~`(get!real)`)) || op=="++" || op=="--") 578 { 579 static if (op=="++" || op=="--") 580 { 581 auto result = get!real; 582 this = mixin(op~`result`); 583 return result; 584 } 585 else 586 return mixin(op~`get!real`); 587 } 588 589 /// ditto 590 // Define an opBinary `CustomFloat op CustomFloat` so that those below 591 // do not match equally, which is disallowed by the spec: 592 // https://dlang.org/spec/operatoroverloading.html#binary 593 real opBinary(string op,T)(T b) 594 if (__traits(compiles, mixin(`get!real`~op~`b.get!real`))) 595 { 596 return mixin(`get!real`~op~`b.get!real`); 597 } 598 599 /// ditto 600 real opBinary(string op,T)(T b) 601 if ( __traits(compiles, mixin(`get!real`~op~`b`)) && 602 !__traits(compiles, mixin(`get!real`~op~`b.get!real`))) 603 { 604 return mixin(`get!real`~op~`b`); 605 } 606 607 /// ditto 608 real opBinaryRight(string op,T)(T a) 609 if ( __traits(compiles, mixin(`a`~op~`get!real`)) && 610 !__traits(compiles, mixin(`get!real`~op~`b`)) && 611 !__traits(compiles, mixin(`get!real`~op~`b.get!real`))) 612 { 613 return mixin(`a`~op~`get!real`); 614 } 615 616 /// ditto 617 int opCmp(T)(auto ref T b) 618 if (__traits(compiles, cast(real) b)) 619 { 620 auto x = get!real; 621 auto y = cast(real) b; 622 return (x >= y)-(x <= y); 623 } 624 625 /// ditto 626 void opOpAssign(string op, T)(auto ref T b) 627 if (__traits(compiles, mixin(`get!real`~op~`cast(real) b`))) 628 { 629 return mixin(`this = this `~op~` cast(real) b`); 630 } 631 632 /// ditto 633 template toString() 634 { 635 import std.format.spec : FormatSpec; 636 import std.format.write : formatValue; 637 // Needs to be a template because of https://issues.dlang.org/show_bug.cgi?id=13737. 638 void toString()(scope void delegate(const(char)[]) sink, scope const ref FormatSpec!char fmt) 639 { 640 sink.formatValue(get!real, fmt); 641 } 642 } 643 } 644 645 @safe unittest 646 { 647 import std.meta; 648 alias FPTypes = 649 AliasSeq!( 650 CustomFloat!(5, 10), 651 CustomFloat!(5, 11, CustomFloatFlags.ieee ^ CustomFloatFlags.signed), 652 CustomFloat!(1, 7, CustomFloatFlags.ieee ^ CustomFloatFlags.signed), 653 CustomFloat!(4, 3, CustomFloatFlags.ieee | CustomFloatFlags.probability ^ CustomFloatFlags.signed) 654 ); 655 656 foreach (F; FPTypes) 657 { 658 auto x = F(0.125); 659 assert(x.get!float == 0.125F); 660 assert(x.get!double == 0.125); 661 assert(x.get!real == 0.125L); 662 663 x -= 0.0625; 664 assert(x.get!float == 0.0625F); 665 assert(x.get!double == 0.0625); 666 assert(x.get!real == 0.0625L); 667 668 x *= 2; 669 assert(x.get!float == 0.125F); 670 assert(x.get!double == 0.125); 671 assert(x.get!real == 0.125L); 672 673 x /= 4; 674 assert(x.get!float == 0.03125); 675 assert(x.get!double == 0.03125); 676 assert(x.get!real == 0.03125L); 677 678 x = 0.5; 679 x ^^= 4; 680 assert(x.get!float == 1 / 16.0F); 681 assert(x.get!double == 1 / 16.0); 682 assert(x.get!real == 1 / 16.0L); 683 } 684 } 685 686 @system unittest 687 { 688 // @system due to to!string(CustomFloat) 689 import std.conv; 690 CustomFloat!(5, 10) y = CustomFloat!(5, 10)(0.125); 691 assert(y.to!string == "0.125"); 692 } 693 694 @safe unittest 695 { 696 alias cf = CustomFloat!(5, 2); 697 698 auto a = cf.infinity; 699 assert(a.sign == 0); 700 assert(a.exponent == 3); 701 assert(a.significand == 0); 702 703 auto b = cf.nan; 704 assert(b.exponent == 3); 705 assert(b.significand != 0); 706 707 assert(cf.dig == 1); 708 709 auto c = cf.epsilon; 710 assert(c.sign == 0); 711 assert(c.exponent == 0); 712 assert(c.significand == 1); 713 714 assert(cf.mant_dig == 6); 715 716 assert(cf.max_10_exp == 0); 717 assert(cf.max_exp == 2); 718 assert(cf.min_10_exp == 0); 719 assert(cf.min_exp == 1); 720 721 auto d = cf.max; 722 assert(d.sign == 0); 723 assert(d.exponent == 2); 724 assert(d.significand == 31); 725 726 auto e = cf.min_normal; 727 assert(e.sign == 0); 728 assert(e.exponent == 1); 729 assert(e.significand == 0); 730 731 assert(e.re == e); 732 assert(e.im == cf(0.0)); 733 } 734 735 // check whether CustomFloats identical to float/double behave like float/double 736 @safe unittest 737 { 738 import std.conv : to; 739 740 alias myFloat = CustomFloat!(23, 8); 741 742 static assert(myFloat.dig == float.dig); 743 static assert(myFloat.mant_dig == float.mant_dig); 744 assert(myFloat.max_10_exp == float.max_10_exp); 745 static assert(myFloat.max_exp == float.max_exp); 746 assert(myFloat.min_10_exp == float.min_10_exp); 747 static assert(myFloat.min_exp == float.min_exp); 748 assert(to!float(myFloat.epsilon) == float.epsilon); 749 assert(to!float(myFloat.max) == float.max); 750 assert(to!float(myFloat.min_normal) == float.min_normal); 751 752 alias myDouble = CustomFloat!(52, 11); 753 754 static assert(myDouble.dig == double.dig); 755 static assert(myDouble.mant_dig == double.mant_dig); 756 assert(myDouble.max_10_exp == double.max_10_exp); 757 static assert(myDouble.max_exp == double.max_exp); 758 assert(myDouble.min_10_exp == double.min_10_exp); 759 static assert(myDouble.min_exp == double.min_exp); 760 assert(to!double(myDouble.epsilon) == double.epsilon); 761 assert(to!double(myDouble.max) == double.max); 762 assert(to!double(myDouble.min_normal) == double.min_normal); 763 } 764 765 // testing .dig 766 @safe unittest 767 { 768 static assert(CustomFloat!(1, 6).dig == 0); 769 static assert(CustomFloat!(9, 6).dig == 2); 770 static assert(CustomFloat!(10, 5).dig == 3); 771 static assert(CustomFloat!(10, 6, CustomFloatFlags.none).dig == 2); 772 static assert(CustomFloat!(11, 5, CustomFloatFlags.none).dig == 3); 773 static assert(CustomFloat!(64, 7).dig == 19); 774 } 775 776 // testing .mant_dig 777 @safe unittest 778 { 779 static assert(CustomFloat!(10, 5).mant_dig == 11); 780 static assert(CustomFloat!(10, 6, CustomFloatFlags.none).mant_dig == 10); 781 } 782 783 // testing .max_exp 784 @safe unittest 785 { 786 static assert(CustomFloat!(1, 6).max_exp == 2^^5); 787 static assert(CustomFloat!(2, 6, CustomFloatFlags.none).max_exp == 2^^5); 788 static assert(CustomFloat!(5, 10).max_exp == 2^^9); 789 static assert(CustomFloat!(6, 10, CustomFloatFlags.none).max_exp == 2^^9); 790 static assert(CustomFloat!(2, 6, CustomFloatFlags.nan).max_exp == 2^^5); 791 static assert(CustomFloat!(6, 10, CustomFloatFlags.nan).max_exp == 2^^9); 792 } 793 794 // testing .min_exp 795 @safe unittest 796 { 797 static assert(CustomFloat!(1, 6).min_exp == -2^^5+3); 798 static assert(CustomFloat!(5, 10).min_exp == -2^^9+3); 799 static assert(CustomFloat!(2, 6, CustomFloatFlags.none).min_exp == -2^^5+1); 800 static assert(CustomFloat!(6, 10, CustomFloatFlags.none).min_exp == -2^^9+1); 801 static assert(CustomFloat!(2, 6, CustomFloatFlags.nan).min_exp == -2^^5+2); 802 static assert(CustomFloat!(6, 10, CustomFloatFlags.nan).min_exp == -2^^9+2); 803 static assert(CustomFloat!(2, 6, CustomFloatFlags.allowDenorm).min_exp == -2^^5+2); 804 static assert(CustomFloat!(6, 10, CustomFloatFlags.allowDenorm).min_exp == -2^^9+2); 805 } 806 807 // testing .max_10_exp 808 @safe unittest 809 { 810 assert(CustomFloat!(1, 6).max_10_exp == 9); 811 assert(CustomFloat!(5, 10).max_10_exp == 154); 812 assert(CustomFloat!(2, 6, CustomFloatFlags.none).max_10_exp == 9); 813 assert(CustomFloat!(6, 10, CustomFloatFlags.none).max_10_exp == 154); 814 assert(CustomFloat!(2, 6, CustomFloatFlags.nan).max_10_exp == 9); 815 assert(CustomFloat!(6, 10, CustomFloatFlags.nan).max_10_exp == 154); 816 } 817 818 // testing .min_10_exp 819 @safe unittest 820 { 821 assert(CustomFloat!(1, 6).min_10_exp == -9); 822 assert(CustomFloat!(5, 10).min_10_exp == -153); 823 assert(CustomFloat!(2, 6, CustomFloatFlags.none).min_10_exp == -9); 824 assert(CustomFloat!(6, 10, CustomFloatFlags.none).min_10_exp == -154); 825 assert(CustomFloat!(2, 6, CustomFloatFlags.nan).min_10_exp == -9); 826 assert(CustomFloat!(6, 10, CustomFloatFlags.nan).min_10_exp == -153); 827 assert(CustomFloat!(2, 6, CustomFloatFlags.allowDenorm).min_10_exp == -9); 828 assert(CustomFloat!(6, 10, CustomFloatFlags.allowDenorm).min_10_exp == -153); 829 } 830 831 // testing .epsilon 832 @safe unittest 833 { 834 assert(CustomFloat!(1,6).epsilon.sign == 0); 835 assert(CustomFloat!(1,6).epsilon.exponent == 30); 836 assert(CustomFloat!(1,6).epsilon.significand == 0); 837 assert(CustomFloat!(2,5).epsilon.sign == 0); 838 assert(CustomFloat!(2,5).epsilon.exponent == 13); 839 assert(CustomFloat!(2,5).epsilon.significand == 0); 840 assert(CustomFloat!(3,4).epsilon.sign == 0); 841 assert(CustomFloat!(3,4).epsilon.exponent == 4); 842 assert(CustomFloat!(3,4).epsilon.significand == 0); 843 // the following epsilons are only available, when denormalized numbers are allowed: 844 assert(CustomFloat!(4,3).epsilon.sign == 0); 845 assert(CustomFloat!(4,3).epsilon.exponent == 0); 846 assert(CustomFloat!(4,3).epsilon.significand == 4); 847 assert(CustomFloat!(5,2).epsilon.sign == 0); 848 assert(CustomFloat!(5,2).epsilon.exponent == 0); 849 assert(CustomFloat!(5,2).epsilon.significand == 1); 850 } 851 852 // testing .max 853 @safe unittest 854 { 855 static assert(CustomFloat!(5,2).max.sign == 0); 856 static assert(CustomFloat!(5,2).max.exponent == 2); 857 static assert(CustomFloat!(5,2).max.significand == 31); 858 static assert(CustomFloat!(4,3).max.sign == 0); 859 static assert(CustomFloat!(4,3).max.exponent == 6); 860 static assert(CustomFloat!(4,3).max.significand == 15); 861 static assert(CustomFloat!(3,4).max.sign == 0); 862 static assert(CustomFloat!(3,4).max.exponent == 14); 863 static assert(CustomFloat!(3,4).max.significand == 7); 864 static assert(CustomFloat!(2,5).max.sign == 0); 865 static assert(CustomFloat!(2,5).max.exponent == 30); 866 static assert(CustomFloat!(2,5).max.significand == 3); 867 static assert(CustomFloat!(1,6).max.sign == 0); 868 static assert(CustomFloat!(1,6).max.exponent == 62); 869 static assert(CustomFloat!(1,6).max.significand == 1); 870 static assert(CustomFloat!(3,5, CustomFloatFlags.none).max.exponent == 31); 871 static assert(CustomFloat!(3,5, CustomFloatFlags.none).max.significand == 7); 872 } 873 874 // testing .min_normal 875 @safe unittest 876 { 877 static assert(CustomFloat!(5,2).min_normal.sign == 0); 878 static assert(CustomFloat!(5,2).min_normal.exponent == 1); 879 static assert(CustomFloat!(5,2).min_normal.significand == 0); 880 static assert(CustomFloat!(4,3).min_normal.sign == 0); 881 static assert(CustomFloat!(4,3).min_normal.exponent == 1); 882 static assert(CustomFloat!(4,3).min_normal.significand == 0); 883 static assert(CustomFloat!(3,4).min_normal.sign == 0); 884 static assert(CustomFloat!(3,4).min_normal.exponent == 1); 885 static assert(CustomFloat!(3,4).min_normal.significand == 0); 886 static assert(CustomFloat!(2,5).min_normal.sign == 0); 887 static assert(CustomFloat!(2,5).min_normal.exponent == 1); 888 static assert(CustomFloat!(2,5).min_normal.significand == 0); 889 static assert(CustomFloat!(1,6).min_normal.sign == 0); 890 static assert(CustomFloat!(1,6).min_normal.exponent == 1); 891 static assert(CustomFloat!(1,6).min_normal.significand == 0); 892 static assert(CustomFloat!(3,5, CustomFloatFlags.none).min_normal.exponent == 0); 893 static assert(CustomFloat!(3,5, CustomFloatFlags.none).min_normal.significand == 4); 894 } 895 896 @safe unittest 897 { 898 import std.math.traits : isNaN; 899 900 alias cf = CustomFloat!(5, 2); 901 902 auto f = cf.nan.get!float(); 903 assert(isNaN(f)); 904 905 cf a; 906 a = real.max; 907 assert(a == cf.infinity); 908 909 a = 0.015625; 910 assert(a.exponent == 0); 911 assert(a.significand == 0); 912 913 a = 0.984375; 914 assert(a.exponent == 1); 915 assert(a.significand == 0); 916 } 917 918 @system unittest 919 { 920 import std.exception : assertThrown; 921 import core.exception : AssertError; 922 923 alias cf = CustomFloat!(3, 5, CustomFloatFlags.none); 924 925 cf a; 926 assertThrown!AssertError(a = real.max); 927 } 928 929 @system unittest 930 { 931 import std.exception : assertThrown; 932 import core.exception : AssertError; 933 934 alias cf = CustomFloat!(3, 5, CustomFloatFlags.nan); 935 936 cf a; 937 assertThrown!AssertError(a = real.max); 938 } 939 940 @system unittest 941 { 942 import std.exception : assertThrown; 943 import core.exception : AssertError; 944 945 alias cf = CustomFloat!(24, 8, CustomFloatFlags.none); 946 947 cf a; 948 assertThrown!AssertError(a = float.infinity); 949 } 950 951 private bool isCorrectCustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags) @safe pure nothrow @nogc 952 { 953 // Restrictions from bitfield 954 // due to CustomFloat!80 support hack precision with 64 bits is handled specially 955 auto length = (flags & flags.signed) + exponentWidth + ((precision == 64) ? 0 : precision); 956 if (length != 8 && length != 16 && length != 32 && length != 64) return false; 957 958 // mantissa needs to fit into real mantissa 959 if (precision > real.mant_dig - 1 && precision != 64) return false; 960 961 // exponent needs to fit into real exponent 962 if (1L << exponentWidth - 1 > real.max_exp) return false; 963 964 // mantissa should have at least one bit 965 if (precision == 0) return false; 966 967 // exponent should have at least one bit, in some cases two 968 if (exponentWidth <= ((flags & (flags.allowDenorm | flags.infinity | flags.nan)) != 0)) return false; 969 970 return true; 971 } 972 973 @safe pure nothrow @nogc unittest 974 { 975 assert(isCorrectCustomFloat(3,4,CustomFloatFlags.ieee)); 976 assert(isCorrectCustomFloat(3,5,CustomFloatFlags.none)); 977 assert(!isCorrectCustomFloat(3,3,CustomFloatFlags.ieee)); 978 assert(isCorrectCustomFloat(64,7,CustomFloatFlags.ieee)); 979 assert(!isCorrectCustomFloat(64,4,CustomFloatFlags.ieee)); 980 assert(!isCorrectCustomFloat(508,3,CustomFloatFlags.ieee)); 981 assert(!isCorrectCustomFloat(3,100,CustomFloatFlags.ieee)); 982 assert(!isCorrectCustomFloat(0,7,CustomFloatFlags.ieee)); 983 assert(!isCorrectCustomFloat(6,1,CustomFloatFlags.ieee)); 984 assert(isCorrectCustomFloat(7,1,CustomFloatFlags.none)); 985 assert(!isCorrectCustomFloat(8,0,CustomFloatFlags.none)); 986 } 987 988 /** 989 Defines the fastest type to use when storing temporaries of a 990 calculation intended to ultimately yield a result of type `F` 991 (where `F` must be one of `float`, `double`, or $(D 992 real)). When doing a multi-step computation, you may want to store 993 intermediate results as `FPTemporary!F`. 994 995 The necessity of `FPTemporary` stems from the optimized 996 floating-point operations and registers present in virtually all 997 processors. When adding numbers in the example above, the addition may 998 in fact be done in `real` precision internally. In that case, 999 storing the intermediate `result` in $(D double format) is not only 1000 less precise, it is also (surprisingly) slower, because a conversion 1001 from `real` to `double` is performed every pass through the 1002 loop. This being a lose-lose situation, `FPTemporary!F` has been 1003 defined as the $(I fastest) type to use for calculations at precision 1004 `F`. There is no need to define a type for the $(I most accurate) 1005 calculations, as that is always `real`. 1006 1007 Finally, there is no guarantee that using `FPTemporary!F` will 1008 always be fastest, as the speed of floating-point calculations depends 1009 on very many factors. 1010 */ 1011 template FPTemporary(F) 1012 if (isFloatingPoint!F) 1013 { 1014 version (X86) 1015 alias FPTemporary = real; 1016 else 1017 alias FPTemporary = Unqual!F; 1018 } 1019 1020 /// 1021 @safe unittest 1022 { 1023 import std.math.operations : isClose; 1024 1025 // Average numbers in an array 1026 double avg(in double[] a) 1027 { 1028 if (a.length == 0) return 0; 1029 FPTemporary!double result = 0; 1030 foreach (e; a) result += e; 1031 return result / a.length; 1032 } 1033 1034 auto a = [1.0, 2.0, 3.0]; 1035 assert(isClose(avg(a), 2)); 1036 } 1037 1038 /** 1039 Implements the $(HTTP tinyurl.com/2zb9yr, secant method) for finding a 1040 root of the function `fun` starting from points $(D [xn_1, x_n]) 1041 (ideally close to the root). `Num` may be `float`, `double`, 1042 or `real`. 1043 */ 1044 template secantMethod(alias fun) 1045 { 1046 import std.functional : unaryFun; 1047 Num secantMethod(Num)(Num xn_1, Num xn) 1048 { 1049 auto fxn = unaryFun!(fun)(xn_1), d = xn_1 - xn; 1050 typeof(fxn) fxn_1; 1051 1052 xn = xn_1; 1053 while (!isClose(d, 0, 0.0, 1e-5) && isFinite(d)) 1054 { 1055 xn_1 = xn; 1056 xn -= d; 1057 fxn_1 = fxn; 1058 fxn = unaryFun!(fun)(xn); 1059 d *= -fxn / (fxn - fxn_1); 1060 } 1061 return xn; 1062 } 1063 } 1064 1065 /// 1066 @safe unittest 1067 { 1068 import std.math.operations : isClose; 1069 import std.math.trigonometry : cos; 1070 1071 float f(float x) 1072 { 1073 return cos(x) - x*x*x; 1074 } 1075 auto x = secantMethod!(f)(0f, 1f); 1076 assert(isClose(x, 0.865474)); 1077 } 1078 1079 @system unittest 1080 { 1081 // @system because of __gshared stderr 1082 import std.stdio; 1083 scope(failure) stderr.writeln("Failure testing secantMethod"); 1084 float f(float x) 1085 { 1086 return cos(x) - x*x*x; 1087 } 1088 immutable x = secantMethod!(f)(0f, 1f); 1089 assert(isClose(x, 0.865474)); 1090 auto d = &f; 1091 immutable y = secantMethod!(d)(0f, 1f); 1092 assert(isClose(y, 0.865474)); 1093 } 1094 1095 1096 /** 1097 * Return true if a and b have opposite sign. 1098 */ 1099 private bool oppositeSigns(T1, T2)(T1 a, T2 b) 1100 { 1101 return signbit(a) != signbit(b); 1102 } 1103 1104 public: 1105 1106 /** Find a real root of a real function f(x) via bracketing. 1107 * 1108 * Given a function `f` and a range `[a .. b]` such that `f(a)` 1109 * and `f(b)` have opposite signs or at least one of them equals ±0, 1110 * returns the value of `x` in 1111 * the range which is closest to a root of `f(x)`. If `f(x)` 1112 * has more than one root in the range, one will be chosen 1113 * arbitrarily. If `f(x)` returns NaN, NaN will be returned; 1114 * otherwise, this algorithm is guaranteed to succeed. 1115 * 1116 * Uses an algorithm based on TOMS748, which uses inverse cubic 1117 * interpolation whenever possible, otherwise reverting to parabolic 1118 * or secant interpolation. Compared to TOMS748, this implementation 1119 * improves worst-case performance by a factor of more than 100, and 1120 * typical performance by a factor of 2. For 80-bit reals, most 1121 * problems require 8 to 15 calls to `f(x)` to achieve full machine 1122 * precision. The worst-case performance (pathological cases) is 1123 * approximately twice the number of bits. 1124 * 1125 * References: "On Enclosing Simple Roots of Nonlinear Equations", 1126 * G. Alefeld, F.A. Potra, Yixun Shi, Mathematics of Computation 61, 1127 * pp733-744 (1993). Fortran code available from 1128 * $(HTTP www.netlib.org,www.netlib.org) as algorithm TOMS478. 1129 * 1130 */ 1131 T findRoot(T, DF, DT)(scope DF f, const T a, const T b, 1132 scope DT tolerance) //= (T a, T b) => false) 1133 if ( 1134 isFloatingPoint!T && 1135 is(typeof(tolerance(T.init, T.init)) : bool) && 1136 is(typeof(f(T.init)) == R, R) && isFloatingPoint!R 1137 ) 1138 { 1139 immutable fa = f(a); 1140 if (fa == 0) 1141 return a; 1142 immutable fb = f(b); 1143 if (fb == 0) 1144 return b; 1145 immutable r = findRoot(f, a, b, fa, fb, tolerance); 1146 // Return the first value if it is smaller or NaN 1147 return !(fabs(r[2]) > fabs(r[3])) ? r[0] : r[1]; 1148 } 1149 1150 ///ditto 1151 T findRoot(T, DF)(scope DF f, const T a, const T b) 1152 { 1153 return findRoot(f, a, b, (T a, T b) => false); 1154 } 1155 1156 /** Find root of a real function f(x) by bracketing, allowing the 1157 * termination condition to be specified. 1158 * 1159 * Params: 1160 * 1161 * f = Function to be analyzed 1162 * 1163 * ax = Left bound of initial range of `f` known to contain the 1164 * root. 1165 * 1166 * bx = Right bound of initial range of `f` known to contain the 1167 * root. 1168 * 1169 * fax = Value of `f(ax)`. 1170 * 1171 * fbx = Value of `f(bx)`. `fax` and `fbx` should have opposite signs. 1172 * (`f(ax)` and `f(bx)` are commonly known in advance.) 1173 * 1174 * 1175 * tolerance = Defines an early termination condition. Receives the 1176 * current upper and lower bounds on the root. The 1177 * delegate must return `true` when these bounds are 1178 * acceptable. If this function always returns `false`, 1179 * full machine precision will be achieved. 1180 * 1181 * Returns: 1182 * 1183 * A tuple consisting of two ranges. The first two elements are the 1184 * range (in `x`) of the root, while the second pair of elements 1185 * are the corresponding function values at those points. If an exact 1186 * root was found, both of the first two elements will contain the 1187 * root, and the second pair of elements will be 0. 1188 */ 1189 Tuple!(T, T, R, R) findRoot(T, R, DF, DT)(scope DF f, 1190 const T ax, const T bx, const R fax, const R fbx, 1191 scope DT tolerance) // = (T a, T b) => false) 1192 if ( 1193 isFloatingPoint!T && 1194 is(typeof(tolerance(T.init, T.init)) : bool) && 1195 is(typeof(f(T.init)) == R) && isFloatingPoint!R 1196 ) 1197 in 1198 { 1199 assert(!ax.isNaN() && !bx.isNaN(), "Limits must not be NaN"); 1200 assert(signbit(fax) != signbit(fbx), "Parameters must bracket the root."); 1201 } 1202 do 1203 { 1204 // Author: Don Clugston. This code is (heavily) modified from TOMS748 1205 // (www.netlib.org). The changes to improve the worst-cast performance are 1206 // entirely original. 1207 1208 T a, b, d; // [a .. b] is our current bracket. d is the third best guess. 1209 R fa, fb, fd; // Values of f at a, b, d. 1210 bool done = false; // Has a root been found? 1211 1212 // Allow ax and bx to be provided in reverse order 1213 if (ax <= bx) 1214 { 1215 a = ax; fa = fax; 1216 b = bx; fb = fbx; 1217 } 1218 else 1219 { 1220 a = bx; fa = fbx; 1221 b = ax; fb = fax; 1222 } 1223 1224 // Test the function at point c; update brackets accordingly 1225 void bracket(T c) 1226 { 1227 R fc = f(c); 1228 if (fc == 0 || fc.isNaN()) // Exact solution, or NaN 1229 { 1230 a = c; 1231 fa = fc; 1232 d = c; 1233 fd = fc; 1234 done = true; 1235 return; 1236 } 1237 1238 // Determine new enclosing interval 1239 if (signbit(fa) != signbit(fc)) 1240 { 1241 d = b; 1242 fd = fb; 1243 b = c; 1244 fb = fc; 1245 } 1246 else 1247 { 1248 d = a; 1249 fd = fa; 1250 a = c; 1251 fa = fc; 1252 } 1253 } 1254 1255 /* Perform a secant interpolation. If the result would lie on a or b, or if 1256 a and b differ so wildly in magnitude that the result would be meaningless, 1257 perform a bisection instead. 1258 */ 1259 static T secant_interpolate(T a, T b, R fa, R fb) 1260 { 1261 if (( ((a - b) == a) && b != 0) || (a != 0 && ((b - a) == b))) 1262 { 1263 // Catastrophic cancellation 1264 if (a == 0) 1265 a = copysign(T(0), b); 1266 else if (b == 0) 1267 b = copysign(T(0), a); 1268 else if (signbit(a) != signbit(b)) 1269 return 0; 1270 T c = ieeeMean(a, b); 1271 return c; 1272 } 1273 // avoid overflow 1274 if (b - a > T.max) 1275 return b / 2 + a / 2; 1276 if (fb - fa > R.max) 1277 return a - (b - a) / 2; 1278 T c = a - (fa / (fb - fa)) * (b - a); 1279 if (c == a || c == b) 1280 return (a + b) / 2; 1281 return c; 1282 } 1283 1284 /* Uses 'numsteps' newton steps to approximate the zero in [a .. b] of the 1285 quadratic polynomial interpolating f(x) at a, b, and d. 1286 Returns: 1287 The approximate zero in [a .. b] of the quadratic polynomial. 1288 */ 1289 T newtonQuadratic(int numsteps) 1290 { 1291 // Find the coefficients of the quadratic polynomial. 1292 immutable T a0 = fa; 1293 immutable T a1 = (fb - fa)/(b - a); 1294 immutable T a2 = ((fd - fb)/(d - b) - a1)/(d - a); 1295 1296 // Determine the starting point of newton steps. 1297 T c = oppositeSigns(a2, fa) ? a : b; 1298 1299 // start the safeguarded newton steps. 1300 foreach (int i; 0 .. numsteps) 1301 { 1302 immutable T pc = a0 + (a1 + a2 * (c - b))*(c - a); 1303 immutable T pdc = a1 + a2*((2 * c) - (a + b)); 1304 if (pdc == 0) 1305 return a - a0 / a1; 1306 else 1307 c = c - pc / pdc; 1308 } 1309 return c; 1310 } 1311 1312 // On the first iteration we take a secant step: 1313 if (fa == 0 || fa.isNaN()) 1314 { 1315 done = true; 1316 b = a; 1317 fb = fa; 1318 } 1319 else if (fb == 0 || fb.isNaN()) 1320 { 1321 done = true; 1322 a = b; 1323 fa = fb; 1324 } 1325 else 1326 { 1327 bracket(secant_interpolate(a, b, fa, fb)); 1328 } 1329 1330 // Starting with the second iteration, higher-order interpolation can 1331 // be used. 1332 int itnum = 1; // Iteration number 1333 int baditer = 1; // Num bisections to take if an iteration is bad. 1334 T c, e; // e is our fourth best guess 1335 R fe; 1336 1337 whileloop: 1338 while (!done && (b != nextUp(a)) && !tolerance(a, b)) 1339 { 1340 T a0 = a, b0 = b; // record the brackets 1341 1342 // Do two higher-order (cubic or parabolic) interpolation steps. 1343 foreach (int QQ; 0 .. 2) 1344 { 1345 // Cubic inverse interpolation requires that 1346 // all four function values fa, fb, fd, and fe are distinct; 1347 // otherwise use quadratic interpolation. 1348 bool distinct = (fa != fb) && (fa != fd) && (fa != fe) 1349 && (fb != fd) && (fb != fe) && (fd != fe); 1350 // The first time, cubic interpolation is impossible. 1351 if (itnum<2) distinct = false; 1352 bool ok = distinct; 1353 if (distinct) 1354 { 1355 // Cubic inverse interpolation of f(x) at a, b, d, and e 1356 immutable q11 = (d - e) * fd / (fe - fd); 1357 immutable q21 = (b - d) * fb / (fd - fb); 1358 immutable q31 = (a - b) * fa / (fb - fa); 1359 immutable d21 = (b - d) * fd / (fd - fb); 1360 immutable d31 = (a - b) * fb / (fb - fa); 1361 1362 immutable q22 = (d21 - q11) * fb / (fe - fb); 1363 immutable q32 = (d31 - q21) * fa / (fd - fa); 1364 immutable d32 = (d31 - q21) * fd / (fd - fa); 1365 immutable q33 = (d32 - q22) * fa / (fe - fa); 1366 c = a + (q31 + q32 + q33); 1367 if (c.isNaN() || (c <= a) || (c >= b)) 1368 { 1369 // DAC: If the interpolation predicts a or b, it's 1370 // probable that it's the actual root. Only allow this if 1371 // we're already close to the root. 1372 if (c == a && a - b != a) 1373 { 1374 c = nextUp(a); 1375 } 1376 else if (c == b && a - b != -b) 1377 { 1378 c = nextDown(b); 1379 } 1380 else 1381 { 1382 ok = false; 1383 } 1384 } 1385 } 1386 if (!ok) 1387 { 1388 // DAC: Alefeld doesn't explain why the number of newton steps 1389 // should vary. 1390 c = newtonQuadratic(distinct ? 3 : 2); 1391 if (c.isNaN() || (c <= a) || (c >= b)) 1392 { 1393 // Failure, try a secant step: 1394 c = secant_interpolate(a, b, fa, fb); 1395 } 1396 } 1397 ++itnum; 1398 e = d; 1399 fe = fd; 1400 bracket(c); 1401 if (done || ( b == nextUp(a)) || tolerance(a, b)) 1402 break whileloop; 1403 if (itnum == 2) 1404 continue whileloop; 1405 } 1406 1407 // Now we take a double-length secant step: 1408 T u; 1409 R fu; 1410 if (fabs(fa) < fabs(fb)) 1411 { 1412 u = a; 1413 fu = fa; 1414 } 1415 else 1416 { 1417 u = b; 1418 fu = fb; 1419 } 1420 c = u - 2 * (fu / (fb - fa)) * (b - a); 1421 1422 // DAC: If the secant predicts a value equal to an endpoint, it's 1423 // probably false. 1424 if (c == a || c == b || c.isNaN() || fabs(c - u) > (b - a) / 2) 1425 { 1426 if ((a-b) == a || (b-a) == b) 1427 { 1428 if ((a>0 && b<0) || (a<0 && b>0)) 1429 c = 0; 1430 else 1431 { 1432 if (a == 0) 1433 c = ieeeMean(copysign(T(0), b), b); 1434 else if (b == 0) 1435 c = ieeeMean(copysign(T(0), a), a); 1436 else 1437 c = ieeeMean(a, b); 1438 } 1439 } 1440 else 1441 { 1442 c = a + (b - a) / 2; 1443 } 1444 } 1445 e = d; 1446 fe = fd; 1447 bracket(c); 1448 if (done || (b == nextUp(a)) || tolerance(a, b)) 1449 break; 1450 1451 // IMPROVE THE WORST-CASE PERFORMANCE 1452 // We must ensure that the bounds reduce by a factor of 2 1453 // in binary space! every iteration. If we haven't achieved this 1454 // yet, or if we don't yet know what the exponent is, 1455 // perform a binary chop. 1456 1457 if ((a == 0 || b == 0 || 1458 (fabs(a) >= T(0.5) * fabs(b) && fabs(b) >= T(0.5) * fabs(a))) 1459 && (b - a) < T(0.25) * (b0 - a0)) 1460 { 1461 baditer = 1; 1462 continue; 1463 } 1464 1465 // DAC: If this happens on consecutive iterations, we probably have a 1466 // pathological function. Perform a number of bisections equal to the 1467 // total number of consecutive bad iterations. 1468 1469 if ((b - a) < T(0.25) * (b0 - a0)) 1470 baditer = 1; 1471 foreach (int QQ; 0 .. baditer) 1472 { 1473 e = d; 1474 fe = fd; 1475 1476 T w; 1477 if ((a>0 && b<0) || (a<0 && b>0)) 1478 w = 0; 1479 else 1480 { 1481 T usea = a; 1482 T useb = b; 1483 if (a == 0) 1484 usea = copysign(T(0), b); 1485 else if (b == 0) 1486 useb = copysign(T(0), a); 1487 w = ieeeMean(usea, useb); 1488 } 1489 bracket(w); 1490 } 1491 ++baditer; 1492 } 1493 return Tuple!(T, T, R, R)(a, b, fa, fb); 1494 } 1495 1496 ///ditto 1497 Tuple!(T, T, R, R) findRoot(T, R, DF)(scope DF f, 1498 const T ax, const T bx, const R fax, const R fbx) 1499 { 1500 return findRoot(f, ax, bx, fax, fbx, (T a, T b) => false); 1501 } 1502 1503 ///ditto 1504 T findRoot(T, R)(scope R delegate(T) f, const T a, const T b, 1505 scope bool delegate(T lo, T hi) tolerance = (T a, T b) => false) 1506 { 1507 return findRoot!(T, R delegate(T), bool delegate(T lo, T hi))(f, a, b, tolerance); 1508 } 1509 1510 @safe nothrow unittest 1511 { 1512 int numProblems = 0; 1513 int numCalls; 1514 1515 void testFindRoot(real delegate(real) @nogc @safe nothrow pure f , real x1, real x2) @nogc @safe nothrow pure 1516 { 1517 //numCalls=0; 1518 //++numProblems; 1519 assert(!x1.isNaN() && !x2.isNaN()); 1520 assert(signbit(f(x1)) != signbit(f(x2))); 1521 auto result = findRoot(f, x1, x2, f(x1), f(x2), 1522 (real lo, real hi) { return false; }); 1523 1524 auto flo = f(result[0]); 1525 auto fhi = f(result[1]); 1526 if (flo != 0) 1527 { 1528 assert(oppositeSigns(flo, fhi)); 1529 } 1530 } 1531 1532 // Test functions 1533 real cubicfn(real x) @nogc @safe nothrow pure 1534 { 1535 //++numCalls; 1536 if (x>float.max) 1537 x = float.max; 1538 if (x<-float.max) 1539 x = -float.max; 1540 // This has a single real root at -59.286543284815 1541 return 0.386*x*x*x + 23*x*x + 15.7*x + 525.2; 1542 } 1543 // Test a function with more than one root. 1544 real multisine(real x) { ++numCalls; return sin(x); } 1545 testFindRoot( &multisine, 6, 90); 1546 testFindRoot(&cubicfn, -100, 100); 1547 testFindRoot( &cubicfn, -double.max, real.max); 1548 1549 1550 /* Tests from the paper: 1551 * "On Enclosing Simple Roots of Nonlinear Equations", G. Alefeld, F.A. Potra, 1552 * Yixun Shi, Mathematics of Computation 61, pp733-744 (1993). 1553 */ 1554 // Parameters common to many alefeld tests. 1555 int n; 1556 real ale_a, ale_b; 1557 1558 int powercalls = 0; 1559 1560 real power(real x) 1561 { 1562 ++powercalls; 1563 ++numCalls; 1564 return pow(x, n) + double.min_normal; 1565 } 1566 int [] power_nvals = [3, 5, 7, 9, 19, 25]; 1567 // Alefeld paper states that pow(x,n) is a very poor case, where bisection 1568 // outperforms his method, and gives total numcalls = 1569 // 921 for bisection (2.4 calls per bit), 1830 for Alefeld (4.76/bit), 1570 // 2624 for brent (6.8/bit) 1571 // ... but that is for double, not real80. 1572 // This poor performance seems mainly due to catastrophic cancellation, 1573 // which is avoided here by the use of ieeeMean(). 1574 // I get: 231 (0.48/bit). 1575 // IE this is 10X faster in Alefeld's worst case 1576 numProblems=0; 1577 foreach (k; power_nvals) 1578 { 1579 n = k; 1580 testFindRoot(&power, -1, 10); 1581 } 1582 1583 int powerProblems = numProblems; 1584 1585 // Tests from Alefeld paper 1586 1587 int [9] alefeldSums; 1588 real alefeld0(real x) 1589 { 1590 ++alefeldSums[0]; 1591 ++numCalls; 1592 real q = sin(x) - x/2; 1593 for (int i=1; i<20; ++i) 1594 q+=(2*i-5.0)*(2*i-5.0)/((x-i*i)*(x-i*i)*(x-i*i)); 1595 return q; 1596 } 1597 real alefeld1(real x) 1598 { 1599 ++numCalls; 1600 ++alefeldSums[1]; 1601 return ale_a*x + exp(ale_b * x); 1602 } 1603 real alefeld2(real x) 1604 { 1605 ++numCalls; 1606 ++alefeldSums[2]; 1607 return pow(x, n) - ale_a; 1608 } 1609 real alefeld3(real x) 1610 { 1611 ++numCalls; 1612 ++alefeldSums[3]; 1613 return (1.0 +pow(1.0L-n, 2))*x - pow(1.0L-n*x, 2); 1614 } 1615 real alefeld4(real x) 1616 { 1617 ++numCalls; 1618 ++alefeldSums[4]; 1619 return x*x - pow(1-x, n); 1620 } 1621 real alefeld5(real x) 1622 { 1623 ++numCalls; 1624 ++alefeldSums[5]; 1625 return (1+pow(1.0L-n, 4))*x - pow(1.0L-n*x, 4); 1626 } 1627 real alefeld6(real x) 1628 { 1629 ++numCalls; 1630 ++alefeldSums[6]; 1631 return exp(-n*x)*(x-1.01L) + pow(x, n); 1632 } 1633 real alefeld7(real x) 1634 { 1635 ++numCalls; 1636 ++alefeldSums[7]; 1637 return (n*x-1)/((n-1)*x); 1638 } 1639 1640 numProblems=0; 1641 testFindRoot(&alefeld0, PI_2, PI); 1642 for (n=1; n <= 10; ++n) 1643 { 1644 testFindRoot(&alefeld0, n*n+1e-9L, (n+1)*(n+1)-1e-9L); 1645 } 1646 ale_a = -40; ale_b = -1; 1647 testFindRoot(&alefeld1, -9, 31); 1648 ale_a = -100; ale_b = -2; 1649 testFindRoot(&alefeld1, -9, 31); 1650 ale_a = -200; ale_b = -3; 1651 testFindRoot(&alefeld1, -9, 31); 1652 int [] nvals_3 = [1, 2, 5, 10, 15, 20]; 1653 int [] nvals_5 = [1, 2, 4, 5, 8, 15, 20]; 1654 int [] nvals_6 = [1, 5, 10, 15, 20]; 1655 int [] nvals_7 = [2, 5, 15, 20]; 1656 1657 for (int i=4; i<12; i+=2) 1658 { 1659 n = i; 1660 ale_a = 0.2; 1661 testFindRoot(&alefeld2, 0, 5); 1662 ale_a=1; 1663 testFindRoot(&alefeld2, 0.95, 4.05); 1664 testFindRoot(&alefeld2, 0, 1.5); 1665 } 1666 foreach (i; nvals_3) 1667 { 1668 n=i; 1669 testFindRoot(&alefeld3, 0, 1); 1670 } 1671 foreach (i; nvals_3) 1672 { 1673 n=i; 1674 testFindRoot(&alefeld4, 0, 1); 1675 } 1676 foreach (i; nvals_5) 1677 { 1678 n=i; 1679 testFindRoot(&alefeld5, 0, 1); 1680 } 1681 foreach (i; nvals_6) 1682 { 1683 n=i; 1684 testFindRoot(&alefeld6, 0, 1); 1685 } 1686 foreach (i; nvals_7) 1687 { 1688 n=i; 1689 testFindRoot(&alefeld7, 0.01L, 1); 1690 } 1691 real worstcase(real x) 1692 { 1693 ++numCalls; 1694 return x<0.3*real.max? -0.999e-3 : 1.0; 1695 } 1696 testFindRoot(&worstcase, -real.max, real.max); 1697 1698 // just check that the double + float cases compile 1699 findRoot((double x){ return 0.0; }, -double.max, double.max); 1700 findRoot((float x){ return 0.0f; }, -float.max, float.max); 1701 1702 /* 1703 int grandtotal=0; 1704 foreach (calls; alefeldSums) 1705 { 1706 grandtotal+=calls; 1707 } 1708 grandtotal-=2*numProblems; 1709 printf("\nALEFELD TOTAL = %d avg = %f (alefeld avg=19.3 for double)\n", 1710 grandtotal, (1.0*grandtotal)/numProblems); 1711 powercalls -= 2*powerProblems; 1712 printf("POWER TOTAL = %d avg = %f ", powercalls, 1713 (1.0*powercalls)/powerProblems); 1714 */ 1715 // https://issues.dlang.org/show_bug.cgi?id=14231 1716 auto xp = findRoot((float x) => x, 0f, 1f); 1717 auto xn = findRoot((float x) => x, -1f, -0f); 1718 } 1719 1720 //regression control 1721 @system unittest 1722 { 1723 // @system due to the case in the 2nd line 1724 static assert(__traits(compiles, findRoot((float x)=>cast(real) x, float.init, float.init))); 1725 static assert(__traits(compiles, findRoot!real((x)=>cast(double) x, real.init, real.init))); 1726 static assert(__traits(compiles, findRoot((real x)=>cast(double) x, real.init, real.init))); 1727 } 1728 1729 /++ 1730 Find a real minimum of a real function `f(x)` via bracketing. 1731 Given a function `f` and a range `(ax .. bx)`, 1732 returns the value of `x` in the range which is closest to a minimum of `f(x)`. 1733 `f` is never evaluted at the endpoints of `ax` and `bx`. 1734 If `f(x)` has more than one minimum in the range, one will be chosen arbitrarily. 1735 If `f(x)` returns NaN or -Infinity, `(x, f(x), NaN)` will be returned; 1736 otherwise, this algorithm is guaranteed to succeed. 1737 1738 Params: 1739 f = Function to be analyzed 1740 ax = Left bound of initial range of f known to contain the minimum. 1741 bx = Right bound of initial range of f known to contain the minimum. 1742 relTolerance = Relative tolerance. 1743 absTolerance = Absolute tolerance. 1744 1745 Preconditions: 1746 `ax` and `bx` shall be finite reals. $(BR) 1747 `relTolerance` shall be normal positive real. $(BR) 1748 `absTolerance` shall be normal positive real no less then `T.epsilon*2`. 1749 1750 Returns: 1751 A tuple consisting of `x`, `y = f(x)` and `error = 3 * (absTolerance * fabs(x) + relTolerance)`. 1752 1753 The method used is a combination of golden section search and 1754 successive parabolic interpolation. Convergence is never much slower 1755 than that for a Fibonacci search. 1756 1757 References: 1758 "Algorithms for Minimization without Derivatives", Richard Brent, Prentice-Hall, Inc. (1973) 1759 1760 See_Also: $(LREF findRoot), $(REF isNormal, std,math) 1761 +/ 1762 Tuple!(T, "x", Unqual!(ReturnType!DF), "y", T, "error") 1763 findLocalMin(T, DF)( 1764 scope DF f, 1765 const T ax, 1766 const T bx, 1767 const T relTolerance = sqrt(T.epsilon), 1768 const T absTolerance = sqrt(T.epsilon), 1769 ) 1770 if (isFloatingPoint!T 1771 && __traits(compiles, {T _ = DF.init(T.init);})) 1772 in 1773 { 1774 assert(isFinite(ax), "ax is not finite"); 1775 assert(isFinite(bx), "bx is not finite"); 1776 assert(isNormal(relTolerance), "relTolerance is not normal floating point number"); 1777 assert(isNormal(absTolerance), "absTolerance is not normal floating point number"); 1778 assert(relTolerance >= 0, "absTolerance is not positive"); 1779 assert(absTolerance >= T.epsilon*2, "absTolerance is not greater then `2*T.epsilon`"); 1780 } 1781 out (result) 1782 { 1783 assert(isFinite(result.x)); 1784 } 1785 do 1786 { 1787 alias R = Unqual!(CommonType!(ReturnType!DF, T)); 1788 // c is the squared inverse of the golden ratio 1789 // (3 - sqrt(5))/2 1790 // Value obtained from Wolfram Alpha. 1791 enum T c = 0x0.61c8864680b583ea0c633f9fa31237p+0L; 1792 enum T cm1 = 0x0.9e3779b97f4a7c15f39cc0605cedc8p+0L; 1793 R tolerance; 1794 T a = ax > bx ? bx : ax; 1795 T b = ax > bx ? ax : bx; 1796 // sequence of declarations suitable for SIMD instructions 1797 T v = a * cm1 + b * c; 1798 assert(isFinite(v)); 1799 R fv = f(v); 1800 if (isNaN(fv) || fv == -T.infinity) 1801 { 1802 return typeof(return)(v, fv, T.init); 1803 } 1804 T w = v; 1805 R fw = fv; 1806 T x = v; 1807 R fx = fv; 1808 size_t i; 1809 for (R d = 0, e = 0;;) 1810 { 1811 i++; 1812 T m = (a + b) / 2; 1813 // This fix is not part of the original algorithm 1814 if (!isFinite(m)) // fix infinity loop. Issue can be reproduced in R. 1815 { 1816 m = a / 2 + b / 2; 1817 if (!isFinite(m)) // fast-math compiler switch is enabled 1818 { 1819 //SIMD instructions can be used by compiler, do not reduce declarations 1820 int a_exp = void; 1821 int b_exp = void; 1822 immutable an = frexp(a, a_exp); 1823 immutable bn = frexp(b, b_exp); 1824 immutable am = ldexp(an, a_exp-1); 1825 immutable bm = ldexp(bn, b_exp-1); 1826 m = am + bm; 1827 if (!isFinite(m)) // wrong input: constraints are disabled in release mode 1828 { 1829 return typeof(return).init; 1830 } 1831 } 1832 } 1833 tolerance = absTolerance * fabs(x) + relTolerance; 1834 immutable t2 = tolerance * 2; 1835 // check stopping criterion 1836 if (!(fabs(x - m) > t2 - (b - a) / 2)) 1837 { 1838 break; 1839 } 1840 R p = 0; 1841 R q = 0; 1842 R r = 0; 1843 // fit parabola 1844 if (fabs(e) > tolerance) 1845 { 1846 immutable xw = x - w; 1847 immutable fxw = fx - fw; 1848 immutable xv = x - v; 1849 immutable fxv = fx - fv; 1850 immutable xwfxv = xw * fxv; 1851 immutable xvfxw = xv * fxw; 1852 p = xv * xvfxw - xw * xwfxv; 1853 q = (xvfxw - xwfxv) * 2; 1854 if (q > 0) 1855 p = -p; 1856 else 1857 q = -q; 1858 r = e; 1859 e = d; 1860 } 1861 T u; 1862 // a parabolic-interpolation step 1863 if (fabs(p) < fabs(q * r / 2) && p > q * (a - x) && p < q * (b - x)) 1864 { 1865 d = p / q; 1866 u = x + d; 1867 // f must not be evaluated too close to a or b 1868 if (u - a < t2 || b - u < t2) 1869 d = x < m ? tolerance : -tolerance; 1870 } 1871 // a golden-section step 1872 else 1873 { 1874 e = (x < m ? b : a) - x; 1875 d = c * e; 1876 } 1877 // f must not be evaluated too close to x 1878 u = x + (fabs(d) >= tolerance ? d : d > 0 ? tolerance : -tolerance); 1879 immutable fu = f(u); 1880 if (isNaN(fu) || fu == -T.infinity) 1881 { 1882 return typeof(return)(u, fu, T.init); 1883 } 1884 // update a, b, v, w, and x 1885 if (fu <= fx) 1886 { 1887 (u < x ? b : a) = x; 1888 v = w; fv = fw; 1889 w = x; fw = fx; 1890 x = u; fx = fu; 1891 } 1892 else 1893 { 1894 (u < x ? a : b) = u; 1895 if (fu <= fw || w == x) 1896 { 1897 v = w; fv = fw; 1898 w = u; fw = fu; 1899 } 1900 else if (fu <= fv || v == x || v == w) 1901 { // do not remove this braces 1902 v = u; fv = fu; 1903 } 1904 } 1905 } 1906 return typeof(return)(x, fx, tolerance * 3); 1907 } 1908 1909 /// 1910 @safe unittest 1911 { 1912 import std.math.operations : isClose; 1913 1914 auto ret = findLocalMin((double x) => (x-4)^^2, -1e7, 1e7); 1915 assert(ret.x.isClose(4.0)); 1916 assert(ret.y.isClose(0.0, 0.0, 1e-10)); 1917 } 1918 1919 @safe unittest 1920 { 1921 import std.meta : AliasSeq; 1922 static foreach (T; AliasSeq!(double, float, real)) 1923 { 1924 { 1925 auto ret = findLocalMin!T((T x) => (x-4)^^2, T.min_normal, 1e7); 1926 assert(ret.x.isClose(T(4))); 1927 assert(ret.y.isClose(T(0), 0.0, T.epsilon)); 1928 } 1929 { 1930 auto ret = findLocalMin!T((T x) => fabs(x-1), -T.max/4, T.max/4, T.min_normal, 2*T.epsilon); 1931 assert(isClose(ret.x, T(1))); 1932 assert(isClose(ret.y, T(0), 0.0, T.epsilon)); 1933 assert(ret.error <= 10 * T.epsilon); 1934 } 1935 { 1936 auto ret = findLocalMin!T((T x) => T.init, 0, 1, T.min_normal, 2*T.epsilon); 1937 assert(!ret.x.isNaN); 1938 assert(ret.y.isNaN); 1939 assert(ret.error.isNaN); 1940 } 1941 { 1942 auto ret = findLocalMin!T((T x) => log(x), 0, 1, T.min_normal, 2*T.epsilon); 1943 assert(ret.error < 3.00001 * ((2*T.epsilon)*fabs(ret.x)+ T.min_normal)); 1944 assert(ret.x >= 0 && ret.x <= ret.error); 1945 } 1946 { 1947 auto ret = findLocalMin!T((T x) => log(x), 0, T.max, T.min_normal, 2*T.epsilon); 1948 assert(ret.y < -18); 1949 assert(ret.error < 5e-08); 1950 assert(ret.x >= 0 && ret.x <= ret.error); 1951 } 1952 { 1953 auto ret = findLocalMin!T((T x) => -fabs(x), -1, 1, T.min_normal, 2*T.epsilon); 1954 assert(ret.x.fabs.isClose(T(1))); 1955 assert(ret.y.fabs.isClose(T(1))); 1956 assert(ret.error.isClose(T(0), 0.0, 100*T.epsilon)); 1957 } 1958 } 1959 } 1960 1961 /** 1962 Computes $(LINK2 https://en.wikipedia.org/wiki/Euclidean_distance, 1963 Euclidean distance) between input ranges `a` and 1964 `b`. The two ranges must have the same length. The three-parameter 1965 version stops computation as soon as the distance is greater than or 1966 equal to `limit` (this is useful to save computation if a small 1967 distance is sought). 1968 */ 1969 CommonType!(ElementType!(Range1), ElementType!(Range2)) 1970 euclideanDistance(Range1, Range2)(Range1 a, Range2 b) 1971 if (isInputRange!(Range1) && isInputRange!(Range2)) 1972 { 1973 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 1974 static if (haveLen) assert(a.length == b.length); 1975 Unqual!(typeof(return)) result = 0; 1976 for (; !a.empty; a.popFront(), b.popFront()) 1977 { 1978 immutable t = a.front - b.front; 1979 result += t * t; 1980 } 1981 static if (!haveLen) assert(b.empty); 1982 return sqrt(result); 1983 } 1984 1985 /// Ditto 1986 CommonType!(ElementType!(Range1), ElementType!(Range2)) 1987 euclideanDistance(Range1, Range2, F)(Range1 a, Range2 b, F limit) 1988 if (isInputRange!(Range1) && isInputRange!(Range2)) 1989 { 1990 limit *= limit; 1991 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 1992 static if (haveLen) assert(a.length == b.length); 1993 Unqual!(typeof(return)) result = 0; 1994 for (; ; a.popFront(), b.popFront()) 1995 { 1996 if (a.empty) 1997 { 1998 static if (!haveLen) assert(b.empty); 1999 break; 2000 } 2001 immutable t = a.front - b.front; 2002 result += t * t; 2003 if (result >= limit) break; 2004 } 2005 return sqrt(result); 2006 } 2007 2008 @safe unittest 2009 { 2010 import std.meta : AliasSeq; 2011 static foreach (T; AliasSeq!(double, const double, immutable double)) 2012 {{ 2013 T[] a = [ 1.0, 2.0, ]; 2014 T[] b = [ 4.0, 6.0, ]; 2015 assert(euclideanDistance(a, b) == 5); 2016 assert(euclideanDistance(a, b, 6) == 5); 2017 assert(euclideanDistance(a, b, 5) == 5); 2018 assert(euclideanDistance(a, b, 4) == 5); 2019 assert(euclideanDistance(a, b, 2) == 3); 2020 }} 2021 } 2022 2023 /** 2024 Computes the $(LINK2 https://en.wikipedia.org/wiki/Dot_product, 2025 dot product) of input ranges `a` and $(D 2026 b). The two ranges must have the same length. If both ranges define 2027 length, the check is done once; otherwise, it is done at each 2028 iteration. 2029 */ 2030 CommonType!(ElementType!(Range1), ElementType!(Range2)) 2031 dotProduct(Range1, Range2)(Range1 a, Range2 b) 2032 if (isInputRange!(Range1) && isInputRange!(Range2) && 2033 !(isArray!(Range1) && isArray!(Range2))) 2034 { 2035 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2036 static if (haveLen) assert(a.length == b.length); 2037 Unqual!(typeof(return)) result = 0; 2038 for (; !a.empty; a.popFront(), b.popFront()) 2039 { 2040 result += a.front * b.front; 2041 } 2042 static if (!haveLen) assert(b.empty); 2043 return result; 2044 } 2045 2046 /// Ditto 2047 CommonType!(F1, F2) 2048 dotProduct(F1, F2)(in F1[] avector, in F2[] bvector) 2049 { 2050 immutable n = avector.length; 2051 assert(n == bvector.length); 2052 auto avec = avector.ptr, bvec = bvector.ptr; 2053 Unqual!(typeof(return)) sum0 = 0, sum1 = 0; 2054 2055 const all_endp = avec + n; 2056 const smallblock_endp = avec + (n & ~3); 2057 const bigblock_endp = avec + (n & ~15); 2058 2059 for (; avec != bigblock_endp; avec += 16, bvec += 16) 2060 { 2061 sum0 += avec[0] * bvec[0]; 2062 sum1 += avec[1] * bvec[1]; 2063 sum0 += avec[2] * bvec[2]; 2064 sum1 += avec[3] * bvec[3]; 2065 sum0 += avec[4] * bvec[4]; 2066 sum1 += avec[5] * bvec[5]; 2067 sum0 += avec[6] * bvec[6]; 2068 sum1 += avec[7] * bvec[7]; 2069 sum0 += avec[8] * bvec[8]; 2070 sum1 += avec[9] * bvec[9]; 2071 sum0 += avec[10] * bvec[10]; 2072 sum1 += avec[11] * bvec[11]; 2073 sum0 += avec[12] * bvec[12]; 2074 sum1 += avec[13] * bvec[13]; 2075 sum0 += avec[14] * bvec[14]; 2076 sum1 += avec[15] * bvec[15]; 2077 } 2078 2079 for (; avec != smallblock_endp; avec += 4, bvec += 4) 2080 { 2081 sum0 += avec[0] * bvec[0]; 2082 sum1 += avec[1] * bvec[1]; 2083 sum0 += avec[2] * bvec[2]; 2084 sum1 += avec[3] * bvec[3]; 2085 } 2086 2087 sum0 += sum1; 2088 2089 /* Do trailing portion in naive loop. */ 2090 while (avec != all_endp) 2091 { 2092 sum0 += *avec * *bvec; 2093 ++avec; 2094 ++bvec; 2095 } 2096 2097 return sum0; 2098 } 2099 2100 /// ditto 2101 F dotProduct(F, uint N)(const ref scope F[N] a, const ref scope F[N] b) 2102 if (N <= 16) 2103 { 2104 F sum0 = 0; 2105 F sum1 = 0; 2106 static foreach (i; 0 .. N / 2) 2107 { 2108 sum0 += a[i*2] * b[i*2]; 2109 sum1 += a[i*2+1] * b[i*2+1]; 2110 } 2111 static if (N % 2 == 1) 2112 { 2113 sum0 += a[N-1] * b[N-1]; 2114 } 2115 return sum0 + sum1; 2116 } 2117 2118 @system unittest 2119 { 2120 // @system due to dotProduct and assertCTFEable 2121 import std.exception : assertCTFEable; 2122 import std.meta : AliasSeq; 2123 static foreach (T; AliasSeq!(double, const double, immutable double)) 2124 {{ 2125 T[] a = [ 1.0, 2.0, ]; 2126 T[] b = [ 4.0, 6.0, ]; 2127 assert(dotProduct(a, b) == 16); 2128 assert(dotProduct([1, 3, -5], [4, -2, -1]) == 3); 2129 // Test with fixed-length arrays. 2130 T[2] c = [ 1.0, 2.0, ]; 2131 T[2] d = [ 4.0, 6.0, ]; 2132 assert(dotProduct(c, d) == 16); 2133 T[3] e = [1, 3, -5]; 2134 T[3] f = [4, -2, -1]; 2135 assert(dotProduct(e, f) == 3); 2136 }} 2137 2138 // Make sure the unrolled loop codepath gets tested. 2139 static const x = 2140 [1.0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]; 2141 static const y = 2142 [2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]; 2143 assertCTFEable!({ assert(dotProduct(x, y) == 4048); }); 2144 } 2145 2146 /** 2147 Computes the $(LINK2 https://en.wikipedia.org/wiki/Cosine_similarity, 2148 cosine similarity) of input ranges `a` and $(D 2149 b). The two ranges must have the same length. If both ranges define 2150 length, the check is done once; otherwise, it is done at each 2151 iteration. If either range has all-zero elements, return 0. 2152 */ 2153 CommonType!(ElementType!(Range1), ElementType!(Range2)) 2154 cosineSimilarity(Range1, Range2)(Range1 a, Range2 b) 2155 if (isInputRange!(Range1) && isInputRange!(Range2)) 2156 { 2157 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2158 static if (haveLen) assert(a.length == b.length); 2159 Unqual!(typeof(return)) norma = 0, normb = 0, dotprod = 0; 2160 for (; !a.empty; a.popFront(), b.popFront()) 2161 { 2162 immutable t1 = a.front, t2 = b.front; 2163 norma += t1 * t1; 2164 normb += t2 * t2; 2165 dotprod += t1 * t2; 2166 } 2167 static if (!haveLen) assert(b.empty); 2168 if (norma == 0 || normb == 0) return 0; 2169 return dotprod / sqrt(norma * normb); 2170 } 2171 2172 @safe unittest 2173 { 2174 import std.meta : AliasSeq; 2175 static foreach (T; AliasSeq!(double, const double, immutable double)) 2176 {{ 2177 T[] a = [ 1.0, 2.0, ]; 2178 T[] b = [ 4.0, 3.0, ]; 2179 assert(isClose( 2180 cosineSimilarity(a, b), 10.0 / sqrt(5.0 * 25), 2181 0.01)); 2182 }} 2183 } 2184 2185 /** 2186 Normalizes values in `range` by multiplying each element with a 2187 number chosen such that values sum up to `sum`. If elements in $(D 2188 range) sum to zero, assigns $(D sum / range.length) to 2189 all. Normalization makes sense only if all elements in `range` are 2190 positive. `normalize` assumes that is the case without checking it. 2191 2192 Returns: `true` if normalization completed normally, `false` if 2193 all elements in `range` were zero or if `range` is empty. 2194 */ 2195 bool normalize(R)(R range, ElementType!(R) sum = 1) 2196 if (isForwardRange!(R)) 2197 { 2198 ElementType!(R) s = 0; 2199 // Step 1: Compute sum and length of the range 2200 static if (hasLength!(R)) 2201 { 2202 const length = range.length; 2203 foreach (e; range) 2204 { 2205 s += e; 2206 } 2207 } 2208 else 2209 { 2210 uint length = 0; 2211 foreach (e; range) 2212 { 2213 s += e; 2214 ++length; 2215 } 2216 } 2217 // Step 2: perform normalization 2218 if (s == 0) 2219 { 2220 if (length) 2221 { 2222 immutable f = sum / range.length; 2223 foreach (ref e; range) e = f; 2224 } 2225 return false; 2226 } 2227 // The path most traveled 2228 assert(s >= 0); 2229 immutable f = sum / s; 2230 foreach (ref e; range) 2231 e *= f; 2232 return true; 2233 } 2234 2235 /// 2236 @safe unittest 2237 { 2238 double[] a = []; 2239 assert(!normalize(a)); 2240 a = [ 1.0, 3.0 ]; 2241 assert(normalize(a)); 2242 assert(a == [ 0.25, 0.75 ]); 2243 assert(normalize!(typeof(a))(a, 50)); // a = [12.5, 37.5] 2244 a = [ 0.0, 0.0 ]; 2245 assert(!normalize(a)); 2246 assert(a == [ 0.5, 0.5 ]); 2247 } 2248 2249 /** 2250 Compute the sum of binary logarithms of the input range `r`. 2251 The error of this method is much smaller than with a naive sum of log2. 2252 */ 2253 ElementType!Range sumOfLog2s(Range)(Range r) 2254 if (isInputRange!Range && isFloatingPoint!(ElementType!Range)) 2255 { 2256 long exp = 0; 2257 Unqual!(typeof(return)) x = 1; 2258 foreach (e; r) 2259 { 2260 if (e < 0) 2261 return typeof(return).nan; 2262 int lexp = void; 2263 x *= frexp(e, lexp); 2264 exp += lexp; 2265 if (x < 0.5) 2266 { 2267 x *= 2; 2268 exp--; 2269 } 2270 } 2271 return exp + log2(x); 2272 } 2273 2274 /// 2275 @safe unittest 2276 { 2277 import std.math.traits : isNaN; 2278 2279 assert(sumOfLog2s(new double[0]) == 0); 2280 assert(sumOfLog2s([0.0L]) == -real.infinity); 2281 assert(sumOfLog2s([-0.0L]) == -real.infinity); 2282 assert(sumOfLog2s([2.0L]) == 1); 2283 assert(sumOfLog2s([-2.0L]).isNaN()); 2284 assert(sumOfLog2s([real.nan]).isNaN()); 2285 assert(sumOfLog2s([-real.nan]).isNaN()); 2286 assert(sumOfLog2s([real.infinity]) == real.infinity); 2287 assert(sumOfLog2s([-real.infinity]).isNaN()); 2288 assert(sumOfLog2s([ 0.25, 0.25, 0.25, 0.125 ]) == -9); 2289 } 2290 2291 /** 2292 Computes $(LINK2 https://en.wikipedia.org/wiki/Entropy_(information_theory), 2293 _entropy) of input range `r` in bits. This 2294 function assumes (without checking) that the values in `r` are all 2295 in $(D [0, 1]). For the entropy to be meaningful, often `r` should 2296 be normalized too (i.e., its values should sum to 1). The 2297 two-parameter version stops evaluating as soon as the intermediate 2298 result is greater than or equal to `max`. 2299 */ 2300 ElementType!Range entropy(Range)(Range r) 2301 if (isInputRange!Range) 2302 { 2303 Unqual!(typeof(return)) result = 0.0; 2304 for (;!r.empty; r.popFront) 2305 { 2306 if (!r.front) continue; 2307 result -= r.front * log2(r.front); 2308 } 2309 return result; 2310 } 2311 2312 /// Ditto 2313 ElementType!Range entropy(Range, F)(Range r, F max) 2314 if (isInputRange!Range && 2315 !is(CommonType!(ElementType!Range, F) == void)) 2316 { 2317 Unqual!(typeof(return)) result = 0.0; 2318 for (;!r.empty; r.popFront) 2319 { 2320 if (!r.front) continue; 2321 result -= r.front * log2(r.front); 2322 if (result >= max) break; 2323 } 2324 return result; 2325 } 2326 2327 @safe unittest 2328 { 2329 import std.meta : AliasSeq; 2330 static foreach (T; AliasSeq!(double, const double, immutable double)) 2331 {{ 2332 T[] p = [ 0.0, 0, 0, 1 ]; 2333 assert(entropy(p) == 0); 2334 p = [ 0.25, 0.25, 0.25, 0.25 ]; 2335 assert(entropy(p) == 2); 2336 assert(entropy(p, 1) == 1); 2337 }} 2338 } 2339 2340 /** 2341 Computes the $(LINK2 https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence, 2342 Kullback-Leibler divergence) between input ranges 2343 `a` and `b`, which is the sum $(D ai * log(ai / bi)). The base 2344 of logarithm is 2. The ranges are assumed to contain elements in $(D 2345 [0, 1]). Usually the ranges are normalized probability distributions, 2346 but this is not required or checked by $(D 2347 kullbackLeiblerDivergence). If any element `bi` is zero and the 2348 corresponding element `ai` nonzero, returns infinity. (Otherwise, 2349 if $(D ai == 0 && bi == 0), the term $(D ai * log(ai / bi)) is 2350 considered zero.) If the inputs are normalized, the result is 2351 positive. 2352 */ 2353 CommonType!(ElementType!Range1, ElementType!Range2) 2354 kullbackLeiblerDivergence(Range1, Range2)(Range1 a, Range2 b) 2355 if (isInputRange!(Range1) && isInputRange!(Range2)) 2356 { 2357 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2358 static if (haveLen) assert(a.length == b.length); 2359 Unqual!(typeof(return)) result = 0; 2360 for (; !a.empty; a.popFront(), b.popFront()) 2361 { 2362 immutable t1 = a.front; 2363 if (t1 == 0) continue; 2364 immutable t2 = b.front; 2365 if (t2 == 0) return result.infinity; 2366 assert(t1 > 0 && t2 > 0); 2367 result += t1 * log2(t1 / t2); 2368 } 2369 static if (!haveLen) assert(b.empty); 2370 return result; 2371 } 2372 2373 /// 2374 @safe unittest 2375 { 2376 import std.math.operations : isClose; 2377 2378 double[] p = [ 0.0, 0, 0, 1 ]; 2379 assert(kullbackLeiblerDivergence(p, p) == 0); 2380 double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ]; 2381 assert(kullbackLeiblerDivergence(p1, p1) == 0); 2382 assert(kullbackLeiblerDivergence(p, p1) == 2); 2383 assert(kullbackLeiblerDivergence(p1, p) == double.infinity); 2384 double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ]; 2385 assert(isClose(kullbackLeiblerDivergence(p1, p2), 0.0719281, 1e-5)); 2386 assert(isClose(kullbackLeiblerDivergence(p2, p1), 0.0780719, 1e-5)); 2387 } 2388 2389 /** 2390 Computes the $(LINK2 https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence, 2391 Jensen-Shannon divergence) between `a` and $(D 2392 b), which is the sum $(D (ai * log(2 * ai / (ai + bi)) + bi * log(2 * 2393 bi / (ai + bi))) / 2). The base of logarithm is 2. The ranges are 2394 assumed to contain elements in $(D [0, 1]). Usually the ranges are 2395 normalized probability distributions, but this is not required or 2396 checked by `jensenShannonDivergence`. If the inputs are normalized, 2397 the result is bounded within $(D [0, 1]). The three-parameter version 2398 stops evaluations as soon as the intermediate result is greater than 2399 or equal to `limit`. 2400 */ 2401 CommonType!(ElementType!Range1, ElementType!Range2) 2402 jensenShannonDivergence(Range1, Range2)(Range1 a, Range2 b) 2403 if (isInputRange!Range1 && isInputRange!Range2 && 2404 is(CommonType!(ElementType!Range1, ElementType!Range2))) 2405 { 2406 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2407 static if (haveLen) assert(a.length == b.length); 2408 Unqual!(typeof(return)) result = 0; 2409 for (; !a.empty; a.popFront(), b.popFront()) 2410 { 2411 immutable t1 = a.front; 2412 immutable t2 = b.front; 2413 immutable avg = (t1 + t2) / 2; 2414 if (t1 != 0) 2415 { 2416 result += t1 * log2(t1 / avg); 2417 } 2418 if (t2 != 0) 2419 { 2420 result += t2 * log2(t2 / avg); 2421 } 2422 } 2423 static if (!haveLen) assert(b.empty); 2424 return result / 2; 2425 } 2426 2427 /// Ditto 2428 CommonType!(ElementType!Range1, ElementType!Range2) 2429 jensenShannonDivergence(Range1, Range2, F)(Range1 a, Range2 b, F limit) 2430 if (isInputRange!Range1 && isInputRange!Range2 && 2431 is(typeof(CommonType!(ElementType!Range1, ElementType!Range2).init 2432 >= F.init) : bool)) 2433 { 2434 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2435 static if (haveLen) assert(a.length == b.length); 2436 Unqual!(typeof(return)) result = 0; 2437 limit *= 2; 2438 for (; !a.empty; a.popFront(), b.popFront()) 2439 { 2440 immutable t1 = a.front; 2441 immutable t2 = b.front; 2442 immutable avg = (t1 + t2) / 2; 2443 if (t1 != 0) 2444 { 2445 result += t1 * log2(t1 / avg); 2446 } 2447 if (t2 != 0) 2448 { 2449 result += t2 * log2(t2 / avg); 2450 } 2451 if (result >= limit) break; 2452 } 2453 static if (!haveLen) assert(b.empty); 2454 return result / 2; 2455 } 2456 2457 /// 2458 @safe unittest 2459 { 2460 import std.math.operations : isClose; 2461 2462 double[] p = [ 0.0, 0, 0, 1 ]; 2463 assert(jensenShannonDivergence(p, p) == 0); 2464 double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ]; 2465 assert(jensenShannonDivergence(p1, p1) == 0); 2466 assert(isClose(jensenShannonDivergence(p1, p), 0.548795, 1e-5)); 2467 double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ]; 2468 assert(isClose(jensenShannonDivergence(p1, p2), 0.0186218, 1e-5)); 2469 assert(isClose(jensenShannonDivergence(p2, p1), 0.0186218, 1e-5)); 2470 assert(isClose(jensenShannonDivergence(p2, p1, 0.005), 0.00602366, 1e-5)); 2471 } 2472 2473 /** 2474 The so-called "all-lengths gap-weighted string kernel" computes a 2475 similarity measure between `s` and `t` based on all of their 2476 common subsequences of all lengths. Gapped subsequences are also 2477 included. 2478 2479 To understand what $(D gapWeightedSimilarity(s, t, lambda)) computes, 2480 consider first the case $(D lambda = 1) and the strings $(D s = 2481 ["Hello", "brave", "new", "world"]) and $(D t = ["Hello", "new", 2482 "world"]). In that case, `gapWeightedSimilarity` counts the 2483 following matches: 2484 2485 $(OL $(LI three matches of length 1, namely `"Hello"`, `"new"`, 2486 and `"world"`;) $(LI three matches of length 2, namely ($(D 2487 "Hello", "new")), ($(D "Hello", "world")), and ($(D "new", "world"));) 2488 $(LI one match of length 3, namely ($(D "Hello", "new", "world")).)) 2489 2490 The call $(D gapWeightedSimilarity(s, t, 1)) simply counts all of 2491 these matches and adds them up, returning 7. 2492 2493 ---- 2494 string[] s = ["Hello", "brave", "new", "world"]; 2495 string[] t = ["Hello", "new", "world"]; 2496 assert(gapWeightedSimilarity(s, t, 1) == 7); 2497 ---- 2498 2499 Note how the gaps in matching are simply ignored, for example ($(D 2500 "Hello", "new")) is deemed as good a match as ($(D "new", 2501 "world")). This may be too permissive for some applications. To 2502 eliminate gapped matches entirely, use $(D lambda = 0): 2503 2504 ---- 2505 string[] s = ["Hello", "brave", "new", "world"]; 2506 string[] t = ["Hello", "new", "world"]; 2507 assert(gapWeightedSimilarity(s, t, 0) == 4); 2508 ---- 2509 2510 The call above eliminated the gapped matches ($(D "Hello", "new")), 2511 ($(D "Hello", "world")), and ($(D "Hello", "new", "world")) from the 2512 tally. That leaves only 4 matches. 2513 2514 The most interesting case is when gapped matches still participate in 2515 the result, but not as strongly as ungapped matches. The result will 2516 be a smooth, fine-grained similarity measure between the input 2517 strings. This is where values of `lambda` between 0 and 1 enter 2518 into play: gapped matches are $(I exponentially penalized with the 2519 number of gaps) with base `lambda`. This means that an ungapped 2520 match adds 1 to the return value; a match with one gap in either 2521 string adds `lambda` to the return value; ...; a match with a total 2522 of `n` gaps in both strings adds $(D pow(lambda, n)) to the return 2523 value. In the example above, we have 4 matches without gaps, 2 matches 2524 with one gap, and 1 match with three gaps. The latter match is ($(D 2525 "Hello", "world")), which has two gaps in the first string and one gap 2526 in the second string, totaling to three gaps. Summing these up we get 2527 $(D 4 + 2 * lambda + pow(lambda, 3)). 2528 2529 ---- 2530 string[] s = ["Hello", "brave", "new", "world"]; 2531 string[] t = ["Hello", "new", "world"]; 2532 assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 0.5 * 2 + 0.125); 2533 ---- 2534 2535 `gapWeightedSimilarity` is useful wherever a smooth similarity 2536 measure between sequences allowing for approximate matches is 2537 needed. The examples above are given with words, but any sequences 2538 with elements comparable for equality are allowed, e.g. characters or 2539 numbers. `gapWeightedSimilarity` uses a highly optimized dynamic 2540 programming implementation that needs $(D 16 * min(s.length, 2541 t.length)) extra bytes of memory and $(BIGOH s.length * t.length) time 2542 to complete. 2543 */ 2544 F gapWeightedSimilarity(alias comp = "a == b", R1, R2, F)(R1 s, R2 t, F lambda) 2545 if (isRandomAccessRange!(R1) && hasLength!(R1) && 2546 isRandomAccessRange!(R2) && hasLength!(R2)) 2547 { 2548 import core.exception : onOutOfMemoryError; 2549 import core.stdc.stdlib : malloc, free; 2550 import std.algorithm.mutation : swap; 2551 import std.functional : binaryFun; 2552 2553 if (s.length < t.length) return gapWeightedSimilarity(t, s, lambda); 2554 if (!t.length) return 0; 2555 2556 auto dpvi = cast(F*) malloc(F.sizeof * 2 * t.length); 2557 if (!dpvi) 2558 onOutOfMemoryError(); 2559 2560 auto dpvi1 = dpvi + t.length; 2561 scope(exit) free(dpvi < dpvi1 ? dpvi : dpvi1); 2562 dpvi[0 .. t.length] = 0; 2563 dpvi1[0] = 0; 2564 immutable lambda2 = lambda * lambda; 2565 2566 F result = 0; 2567 foreach (i; 0 .. s.length) 2568 { 2569 const si = s[i]; 2570 for (size_t j = 0;;) 2571 { 2572 F dpsij = void; 2573 if (binaryFun!(comp)(si, t[j])) 2574 { 2575 dpsij = 1 + dpvi[j]; 2576 result += dpsij; 2577 } 2578 else 2579 { 2580 dpsij = 0; 2581 } 2582 immutable j1 = j + 1; 2583 if (j1 == t.length) break; 2584 dpvi1[j1] = dpsij + lambda * (dpvi1[j] + dpvi[j1]) - 2585 lambda2 * dpvi[j]; 2586 j = j1; 2587 } 2588 swap(dpvi, dpvi1); 2589 } 2590 return result; 2591 } 2592 2593 @system unittest 2594 { 2595 string[] s = ["Hello", "brave", "new", "world"]; 2596 string[] t = ["Hello", "new", "world"]; 2597 assert(gapWeightedSimilarity(s, t, 1) == 7); 2598 assert(gapWeightedSimilarity(s, t, 0) == 4); 2599 assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 2 * 0.5 + 0.125); 2600 } 2601 2602 /** 2603 The similarity per `gapWeightedSimilarity` has an issue in that it 2604 grows with the lengths of the two strings, even though the strings are 2605 not actually very similar. For example, the range $(D ["Hello", 2606 "world"]) is increasingly similar with the range $(D ["Hello", 2607 "world", "world", "world",...]) as more instances of `"world"` are 2608 appended. To prevent that, `gapWeightedSimilarityNormalized` 2609 computes a normalized version of the similarity that is computed as 2610 $(D gapWeightedSimilarity(s, t, lambda) / 2611 sqrt(gapWeightedSimilarity(s, t, lambda) * gapWeightedSimilarity(s, t, 2612 lambda))). The function `gapWeightedSimilarityNormalized` (a 2613 so-called normalized kernel) is bounded in $(D [0, 1]), reaches `0` 2614 only for ranges that don't match in any position, and `1` only for 2615 identical ranges. 2616 2617 The optional parameters `sSelfSim` and `tSelfSim` are meant for 2618 avoiding duplicate computation. Many applications may have already 2619 computed $(D gapWeightedSimilarity(s, s, lambda)) and/or $(D 2620 gapWeightedSimilarity(t, t, lambda)). In that case, they can be passed 2621 as `sSelfSim` and `tSelfSim`, respectively. 2622 */ 2623 Select!(isFloatingPoint!(F), F, double) 2624 gapWeightedSimilarityNormalized(alias comp = "a == b", R1, R2, F) 2625 (R1 s, R2 t, F lambda, F sSelfSim = F.init, F tSelfSim = F.init) 2626 if (isRandomAccessRange!(R1) && hasLength!(R1) && 2627 isRandomAccessRange!(R2) && hasLength!(R2)) 2628 { 2629 static bool uncomputed(F n) 2630 { 2631 static if (isFloatingPoint!(F)) 2632 return isNaN(n); 2633 else 2634 return n == n.init; 2635 } 2636 if (uncomputed(sSelfSim)) 2637 sSelfSim = gapWeightedSimilarity!(comp)(s, s, lambda); 2638 if (sSelfSim == 0) return 0; 2639 if (uncomputed(tSelfSim)) 2640 tSelfSim = gapWeightedSimilarity!(comp)(t, t, lambda); 2641 if (tSelfSim == 0) return 0; 2642 2643 return gapWeightedSimilarity!(comp)(s, t, lambda) / 2644 sqrt(cast(typeof(return)) sSelfSim * tSelfSim); 2645 } 2646 2647 /// 2648 @system unittest 2649 { 2650 import std.math.operations : isClose; 2651 import std.math.algebraic : sqrt; 2652 2653 string[] s = ["Hello", "brave", "new", "world"]; 2654 string[] t = ["Hello", "new", "world"]; 2655 assert(gapWeightedSimilarity(s, s, 1) == 15); 2656 assert(gapWeightedSimilarity(t, t, 1) == 7); 2657 assert(gapWeightedSimilarity(s, t, 1) == 7); 2658 assert(isClose(gapWeightedSimilarityNormalized(s, t, 1), 2659 7.0 / sqrt(15.0 * 7), 0.01)); 2660 } 2661 2662 /** 2663 Similar to `gapWeightedSimilarity`, just works in an incremental 2664 manner by first revealing the matches of length 1, then gapped matches 2665 of length 2, and so on. The memory requirement is $(BIGOH s.length * 2666 t.length). The time complexity is $(BIGOH s.length * t.length) time 2667 for computing each step. Continuing on the previous example: 2668 2669 The implementation is based on the pseudocode in Fig. 4 of the paper 2670 $(HTTP jmlr.csail.mit.edu/papers/volume6/rousu05a/rousu05a.pdf, 2671 "Efficient Computation of Gapped Substring Kernels on Large Alphabets") 2672 by Rousu et al., with additional algorithmic and systems-level 2673 optimizations. 2674 */ 2675 struct GapWeightedSimilarityIncremental(Range, F = double) 2676 if (isRandomAccessRange!(Range) && hasLength!(Range)) 2677 { 2678 import core.stdc.stdlib : malloc, realloc, alloca, free; 2679 2680 private: 2681 Range s, t; 2682 F currentValue = 0; 2683 F* kl; 2684 size_t gram = void; 2685 F lambda = void, lambda2 = void; 2686 2687 public: 2688 /** 2689 Constructs an object given two ranges `s` and `t` and a penalty 2690 `lambda`. Constructor completes in $(BIGOH s.length * t.length) 2691 time and computes all matches of length 1. 2692 */ 2693 this(Range s, Range t, F lambda) 2694 { 2695 import core.exception : onOutOfMemoryError; 2696 2697 assert(lambda > 0); 2698 this.gram = 0; 2699 this.lambda = lambda; 2700 this.lambda2 = lambda * lambda; // for efficiency only 2701 2702 size_t iMin = size_t.max, jMin = size_t.max, 2703 iMax = 0, jMax = 0; 2704 /* initialize */ 2705 Tuple!(size_t, size_t) * k0; 2706 size_t k0len; 2707 scope(exit) free(k0); 2708 currentValue = 0; 2709 foreach (i, si; s) 2710 { 2711 foreach (j; 0 .. t.length) 2712 { 2713 if (si != t[j]) continue; 2714 k0 = cast(typeof(k0)) realloc(k0, ++k0len * (*k0).sizeof); 2715 with (k0[k0len - 1]) 2716 { 2717 field[0] = i; 2718 field[1] = j; 2719 } 2720 // Maintain the minimum and maximum i and j 2721 if (iMin > i) iMin = i; 2722 if (iMax < i) iMax = i; 2723 if (jMin > j) jMin = j; 2724 if (jMax < j) jMax = j; 2725 } 2726 } 2727 2728 if (iMin > iMax) return; 2729 assert(k0len); 2730 2731 currentValue = k0len; 2732 // Chop strings down to the useful sizes 2733 s = s[iMin .. iMax + 1]; 2734 t = t[jMin .. jMax + 1]; 2735 this.s = s; 2736 this.t = t; 2737 2738 kl = cast(F*) malloc(s.length * t.length * F.sizeof); 2739 if (!kl) 2740 onOutOfMemoryError(); 2741 2742 kl[0 .. s.length * t.length] = 0; 2743 foreach (pos; 0 .. k0len) 2744 { 2745 with (k0[pos]) 2746 { 2747 kl[(field[0] - iMin) * t.length + field[1] -jMin] = lambda2; 2748 } 2749 } 2750 } 2751 2752 /** 2753 Returns: `this`. 2754 */ 2755 ref GapWeightedSimilarityIncremental opSlice() 2756 { 2757 return this; 2758 } 2759 2760 /** 2761 Computes the match of the popFront length. Completes in $(BIGOH s.length * 2762 t.length) time. 2763 */ 2764 void popFront() 2765 { 2766 import std.algorithm.mutation : swap; 2767 2768 // This is a large source of optimization: if similarity at 2769 // the gram-1 level was 0, then we can safely assume 2770 // similarity at the gram level is 0 as well. 2771 if (empty) return; 2772 2773 // Now attempt to match gapped substrings of length `gram' 2774 ++gram; 2775 currentValue = 0; 2776 2777 auto Si = cast(F*) alloca(t.length * F.sizeof); 2778 Si[0 .. t.length] = 0; 2779 foreach (i; 0 .. s.length) 2780 { 2781 const si = s[i]; 2782 F Sij_1 = 0; 2783 F Si_1j_1 = 0; 2784 auto kli = kl + i * t.length; 2785 for (size_t j = 0;;) 2786 { 2787 const klij = kli[j]; 2788 const Si_1j = Si[j]; 2789 const tmp = klij + lambda * (Si_1j + Sij_1) - lambda2 * Si_1j_1; 2790 // now update kl and currentValue 2791 if (si == t[j]) 2792 currentValue += kli[j] = lambda2 * Si_1j_1; 2793 else 2794 kli[j] = 0; 2795 // commit to Si 2796 Si[j] = tmp; 2797 if (++j == t.length) break; 2798 // get ready for the popFront step; virtually increment j, 2799 // so essentially stuffj_1 <-- stuffj 2800 Si_1j_1 = Si_1j; 2801 Sij_1 = tmp; 2802 } 2803 } 2804 currentValue /= pow(lambda, 2 * (gram + 1)); 2805 2806 version (none) 2807 { 2808 Si_1[0 .. t.length] = 0; 2809 kl[0 .. min(t.length, maxPerimeter + 1)] = 0; 2810 foreach (i; 1 .. min(s.length, maxPerimeter + 1)) 2811 { 2812 auto kli = kl + i * t.length; 2813 assert(s.length > i); 2814 const si = s[i]; 2815 auto kl_1i_1 = kl_1 + (i - 1) * t.length; 2816 kli[0] = 0; 2817 F lastS = 0; 2818 foreach (j; 1 .. min(maxPerimeter - i + 1, t.length)) 2819 { 2820 immutable j_1 = j - 1; 2821 immutable tmp = kl_1i_1[j_1] 2822 + lambda * (Si_1[j] + lastS) 2823 - lambda2 * Si_1[j_1]; 2824 kl_1i_1[j_1] = float.nan; 2825 Si_1[j_1] = lastS; 2826 lastS = tmp; 2827 if (si == t[j]) 2828 { 2829 currentValue += kli[j] = lambda2 * lastS; 2830 } 2831 else 2832 { 2833 kli[j] = 0; 2834 } 2835 } 2836 Si_1[t.length - 1] = lastS; 2837 } 2838 currentValue /= pow(lambda, 2 * (gram + 1)); 2839 // get ready for the popFront computation 2840 swap(kl, kl_1); 2841 } 2842 } 2843 2844 /** 2845 Returns: The gapped similarity at the current match length (initially 2846 1, grows with each call to `popFront`). 2847 */ 2848 @property F front() { return currentValue; } 2849 2850 /** 2851 Returns: Whether there are more matches. 2852 */ 2853 @property bool empty() 2854 { 2855 if (currentValue) return false; 2856 if (kl) 2857 { 2858 free(kl); 2859 kl = null; 2860 } 2861 return true; 2862 } 2863 } 2864 2865 /** 2866 Ditto 2867 */ 2868 GapWeightedSimilarityIncremental!(R, F) gapWeightedSimilarityIncremental(R, F) 2869 (R r1, R r2, F penalty) 2870 { 2871 return typeof(return)(r1, r2, penalty); 2872 } 2873 2874 /// 2875 @system unittest 2876 { 2877 string[] s = ["Hello", "brave", "new", "world"]; 2878 string[] t = ["Hello", "new", "world"]; 2879 auto simIter = gapWeightedSimilarityIncremental(s, t, 1.0); 2880 assert(simIter.front == 3); // three 1-length matches 2881 simIter.popFront(); 2882 assert(simIter.front == 3); // three 2-length matches 2883 simIter.popFront(); 2884 assert(simIter.front == 1); // one 3-length match 2885 simIter.popFront(); 2886 assert(simIter.empty); // no more match 2887 } 2888 2889 @system unittest 2890 { 2891 import std.conv : text; 2892 string[] s = ["Hello", "brave", "new", "world"]; 2893 string[] t = ["Hello", "new", "world"]; 2894 auto simIter = gapWeightedSimilarityIncremental(s, t, 1.0); 2895 //foreach (e; simIter) writeln(e); 2896 assert(simIter.front == 3); // three 1-length matches 2897 simIter.popFront(); 2898 assert(simIter.front == 3, text(simIter.front)); // three 2-length matches 2899 simIter.popFront(); 2900 assert(simIter.front == 1); // one 3-length matches 2901 simIter.popFront(); 2902 assert(simIter.empty); // no more match 2903 2904 s = ["Hello"]; 2905 t = ["bye"]; 2906 simIter = gapWeightedSimilarityIncremental(s, t, 0.5); 2907 assert(simIter.empty); 2908 2909 s = ["Hello"]; 2910 t = ["Hello"]; 2911 simIter = gapWeightedSimilarityIncremental(s, t, 0.5); 2912 assert(simIter.front == 1); // one match 2913 simIter.popFront(); 2914 assert(simIter.empty); 2915 2916 s = ["Hello", "world"]; 2917 t = ["Hello"]; 2918 simIter = gapWeightedSimilarityIncremental(s, t, 0.5); 2919 assert(simIter.front == 1); // one match 2920 simIter.popFront(); 2921 assert(simIter.empty); 2922 2923 s = ["Hello", "world"]; 2924 t = ["Hello", "yah", "world"]; 2925 simIter = gapWeightedSimilarityIncremental(s, t, 0.5); 2926 assert(simIter.front == 2); // two 1-gram matches 2927 simIter.popFront(); 2928 assert(simIter.front == 0.5, text(simIter.front)); // one 2-gram match, 1 gap 2929 } 2930 2931 @system unittest 2932 { 2933 GapWeightedSimilarityIncremental!(string[]) sim = 2934 GapWeightedSimilarityIncremental!(string[])( 2935 ["nyuk", "I", "have", "no", "chocolate", "giba"], 2936 ["wyda", "I", "have", "I", "have", "have", "I", "have", "hehe"], 2937 0.5); 2938 double[] witness = [ 7.0, 4.03125, 0, 0 ]; 2939 foreach (e; sim) 2940 { 2941 //writeln(e); 2942 assert(e == witness.front); 2943 witness.popFront(); 2944 } 2945 witness = [ 3.0, 1.3125, 0.25 ]; 2946 sim = GapWeightedSimilarityIncremental!(string[])( 2947 ["I", "have", "no", "chocolate"], 2948 ["I", "have", "some", "chocolate"], 2949 0.5); 2950 foreach (e; sim) 2951 { 2952 //writeln(e); 2953 assert(e == witness.front); 2954 witness.popFront(); 2955 } 2956 assert(witness.empty); 2957 } 2958 2959 /** 2960 Computes the greatest common divisor of `a` and `b` by using 2961 an efficient algorithm such as $(HTTPS en.wikipedia.org/wiki/Euclidean_algorithm, Euclid's) 2962 or $(HTTPS en.wikipedia.org/wiki/Binary_GCD_algorithm, Stein's) algorithm. 2963 2964 Params: 2965 a = Integer value of any numerical type that supports the modulo operator `%`. 2966 If bit-shifting `<<` and `>>` are also supported, Stein's algorithm will 2967 be used; otherwise, Euclid's algorithm is used as _a fallback. 2968 b = Integer value of any equivalent numerical type. 2969 2970 Returns: 2971 The greatest common divisor of the given arguments. 2972 */ 2973 typeof(Unqual!(T).init % Unqual!(U).init) gcd(T, U)(T a, U b) 2974 if (isIntegral!T && isIntegral!U) 2975 { 2976 // Operate on a common type between the two arguments. 2977 alias UCT = Unsigned!(CommonType!(Unqual!T, Unqual!U)); 2978 2979 // `std.math.abs` doesn't support unsigned integers, and `T.min` is undefined. 2980 static if (is(T : immutable short) || is(T : immutable byte)) 2981 UCT ax = (isUnsigned!T || a >= 0) ? a : cast(UCT) -int(a); 2982 else 2983 UCT ax = (isUnsigned!T || a >= 0) ? a : -UCT(a); 2984 2985 static if (is(U : immutable short) || is(U : immutable byte)) 2986 UCT bx = (isUnsigned!U || b >= 0) ? b : cast(UCT) -int(b); 2987 else 2988 UCT bx = (isUnsigned!U || b >= 0) ? b : -UCT(b); 2989 2990 // Special cases. 2991 if (ax == 0) 2992 return bx; 2993 if (bx == 0) 2994 return ax; 2995 2996 return gcdImpl(ax, bx); 2997 } 2998 2999 private typeof(T.init % T.init) gcdImpl(T)(T a, T b) 3000 if (isIntegral!T) 3001 { 3002 pragma(inline, true); 3003 import core.bitop : bsf; 3004 import std.algorithm.mutation : swap; 3005 3006 immutable uint shift = bsf(a | b); 3007 a >>= a.bsf; 3008 do 3009 { 3010 b >>= b.bsf; 3011 if (a > b) 3012 swap(a, b); 3013 b -= a; 3014 } while (b); 3015 3016 return a << shift; 3017 } 3018 3019 /// 3020 @safe unittest 3021 { 3022 assert(gcd(2 * 5 * 7 * 7, 5 * 7 * 11) == 5 * 7); 3023 const int a = 5 * 13 * 23 * 23, b = 13 * 59; 3024 assert(gcd(a, b) == 13); 3025 } 3026 3027 @safe unittest 3028 { 3029 import std.meta : AliasSeq; 3030 static foreach (T; AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong, 3031 const byte, const short, const int, const long, 3032 immutable ubyte, immutable ushort, immutable uint, immutable ulong)) 3033 { 3034 static foreach (U; AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong, 3035 const ubyte, const ushort, const uint, const ulong, 3036 immutable byte, immutable short, immutable int, immutable long)) 3037 { 3038 // Signed and unsigned tests. 3039 static if (T.max > byte.max && U.max > byte.max) 3040 assert(gcd(T(200), U(200)) == 200); 3041 static if (T.max > ubyte.max) 3042 { 3043 assert(gcd(T(2000), U(20)) == 20); 3044 assert(gcd(T(2011), U(17)) == 1); 3045 } 3046 static if (T.max > ubyte.max && U.max > ubyte.max) 3047 assert(gcd(T(1071), U(462)) == 21); 3048 3049 assert(gcd(T(0), U(13)) == 13); 3050 assert(gcd(T(29), U(0)) == 29); 3051 assert(gcd(T(0), U(0)) == 0); 3052 assert(gcd(T(1), U(2)) == 1); 3053 assert(gcd(T(9), U(6)) == 3); 3054 assert(gcd(T(3), U(4)) == 1); 3055 assert(gcd(T(32), U(24)) == 8); 3056 assert(gcd(T(5), U(6)) == 1); 3057 assert(gcd(T(54), U(36)) == 18); 3058 3059 // Int and Long tests. 3060 static if (T.max > short.max && U.max > short.max) 3061 assert(gcd(T(46391), U(62527)) == 2017); 3062 static if (T.max > ushort.max && U.max > ushort.max) 3063 assert(gcd(T(63245986), U(39088169)) == 1); 3064 static if (T.max > uint.max && U.max > uint.max) 3065 { 3066 assert(gcd(T(77160074263), U(47687519812)) == 1); 3067 assert(gcd(T(77160074264), U(47687519812)) == 4); 3068 } 3069 3070 // Negative tests. 3071 static if (T.min < 0) 3072 { 3073 assert(gcd(T(-21), U(28)) == 7); 3074 assert(gcd(T(-3), U(4)) == 1); 3075 } 3076 static if (U.min < 0) 3077 { 3078 assert(gcd(T(1), U(-2)) == 1); 3079 assert(gcd(T(33), U(-44)) == 11); 3080 } 3081 static if (T.min < 0 && U.min < 0) 3082 { 3083 assert(gcd(T(-5), U(-6)) == 1); 3084 assert(gcd(T(-50), U(-60)) == 10); 3085 } 3086 } 3087 } 3088 } 3089 3090 // https://issues.dlang.org/show_bug.cgi?id=21834 3091 @safe unittest 3092 { 3093 assert(gcd(-120, 10U) == 10); 3094 assert(gcd(120U, -10) == 10); 3095 assert(gcd(int.min, 0L) == 1L + int.max); 3096 assert(gcd(0L, int.min) == 1L + int.max); 3097 assert(gcd(int.min, 0L + int.min) == 1L + int.max); 3098 assert(gcd(int.min, 1L + int.max) == 1L + int.max); 3099 assert(gcd(short.min, 1U + short.max) == 1U + short.max); 3100 } 3101 3102 // This overload is for non-builtin numerical types like BigInt or 3103 // user-defined types. 3104 /// ditto 3105 auto gcd(T)(T a, T b) 3106 if (!isIntegral!T && 3107 is(typeof(T.init % T.init)) && 3108 is(typeof(T.init == 0 || T.init > 0))) 3109 { 3110 static if (!is(T == Unqual!T)) 3111 { 3112 return gcd!(Unqual!T)(a, b); 3113 } 3114 else 3115 { 3116 // Ensure arguments are unsigned. 3117 a = a >= 0 ? a : -a; 3118 b = b >= 0 ? b : -b; 3119 3120 // Special cases. 3121 if (a == 0) 3122 return b; 3123 if (b == 0) 3124 return a; 3125 3126 return gcdImpl(a, b); 3127 } 3128 } 3129 3130 private auto gcdImpl(T)(T a, T b) 3131 if (!isIntegral!T) 3132 { 3133 pragma(inline, true); 3134 import std.algorithm.mutation : swap; 3135 enum canUseBinaryGcd = is(typeof(() { 3136 T t, u; 3137 t <<= 1; 3138 t >>= 1; 3139 t -= u; 3140 bool b = (t & 1) == 0; 3141 swap(t, u); 3142 })); 3143 3144 static if (canUseBinaryGcd) 3145 { 3146 uint shift = 0; 3147 while ((a & 1) == 0 && (b & 1) == 0) 3148 { 3149 a >>= 1; 3150 b >>= 1; 3151 shift++; 3152 } 3153 3154 if ((a & 1) == 0) swap(a, b); 3155 3156 do 3157 { 3158 assert((a & 1) != 0); 3159 while ((b & 1) == 0) 3160 b >>= 1; 3161 if (a > b) 3162 swap(a, b); 3163 b -= a; 3164 } while (b); 3165 3166 return a << shift; 3167 } 3168 else 3169 { 3170 // The only thing we have is %; fallback to Euclidean algorithm. 3171 while (b != 0) 3172 { 3173 auto t = b; 3174 b = a % b; 3175 a = t; 3176 } 3177 return a; 3178 } 3179 } 3180 3181 // https://issues.dlang.org/show_bug.cgi?id=7102 3182 @system pure unittest 3183 { 3184 import std.bigint : BigInt; 3185 assert(gcd(BigInt("71_000_000_000_000_000_000"), 3186 BigInt("31_000_000_000_000_000_000")) == 3187 BigInt("1_000_000_000_000_000_000")); 3188 3189 assert(gcd(BigInt(0), BigInt(1234567)) == BigInt(1234567)); 3190 assert(gcd(BigInt(1234567), BigInt(0)) == BigInt(1234567)); 3191 } 3192 3193 @safe pure nothrow unittest 3194 { 3195 // A numerical type that only supports % and - (to force gcd implementation 3196 // to use Euclidean algorithm). 3197 struct CrippledInt 3198 { 3199 int impl; 3200 CrippledInt opBinary(string op : "%")(CrippledInt i) 3201 { 3202 return CrippledInt(impl % i.impl); 3203 } 3204 CrippledInt opUnary(string op : "-")() 3205 { 3206 return CrippledInt(-impl); 3207 } 3208 int opEquals(CrippledInt i) { return impl == i.impl; } 3209 int opEquals(int i) { return impl == i; } 3210 int opCmp(int i) { return (impl < i) ? -1 : (impl > i) ? 1 : 0; } 3211 } 3212 assert(gcd(CrippledInt(2310), CrippledInt(1309)) == CrippledInt(77)); 3213 assert(gcd(CrippledInt(-120), CrippledInt(10U)) == CrippledInt(10)); 3214 assert(gcd(CrippledInt(120U), CrippledInt(-10)) == CrippledInt(10)); 3215 } 3216 3217 // https://issues.dlang.org/show_bug.cgi?id=19514 3218 @system pure unittest 3219 { 3220 import std.bigint : BigInt; 3221 assert(gcd(BigInt(2), BigInt(1)) == BigInt(1)); 3222 } 3223 3224 // Issue 20924 3225 @safe unittest 3226 { 3227 import std.bigint : BigInt; 3228 const a = BigInt("123143238472389492934020"); 3229 const b = BigInt("902380489324729338420924"); 3230 assert(__traits(compiles, gcd(a, b))); 3231 } 3232 3233 // https://issues.dlang.org/show_bug.cgi?id=21834 3234 @safe unittest 3235 { 3236 import std.bigint : BigInt; 3237 assert(gcd(BigInt(-120), BigInt(10U)) == BigInt(10)); 3238 assert(gcd(BigInt(120U), BigInt(-10)) == BigInt(10)); 3239 assert(gcd(BigInt(int.min), BigInt(0L)) == BigInt(1L + int.max)); 3240 assert(gcd(BigInt(0L), BigInt(int.min)) == BigInt(1L + int.max)); 3241 assert(gcd(BigInt(int.min), BigInt(0L + int.min)) == BigInt(1L + int.max)); 3242 assert(gcd(BigInt(int.min), BigInt(1L + int.max)) == BigInt(1L + int.max)); 3243 assert(gcd(BigInt(short.min), BigInt(1U + short.max)) == BigInt(1U + short.max)); 3244 } 3245 3246 3247 /** 3248 Computes the least common multiple of `a` and `b`. 3249 Arguments are the same as $(MYREF gcd). 3250 3251 Returns: 3252 The least common multiple of the given arguments. 3253 */ 3254 typeof(Unqual!(T).init % Unqual!(U).init) lcm(T, U)(T a, U b) 3255 if (isIntegral!T && isIntegral!U) 3256 { 3257 // Operate on a common type between the two arguments. 3258 alias UCT = Unsigned!(CommonType!(Unqual!T, Unqual!U)); 3259 3260 // `std.math.abs` doesn't support unsigned integers, and `T.min` is undefined. 3261 static if (is(T : immutable short) || is(T : immutable byte)) 3262 UCT ax = (isUnsigned!T || a >= 0) ? a : cast(UCT) -int(a); 3263 else 3264 UCT ax = (isUnsigned!T || a >= 0) ? a : -UCT(a); 3265 3266 static if (is(U : immutable short) || is(U : immutable byte)) 3267 UCT bx = (isUnsigned!U || b >= 0) ? b : cast(UCT) -int(b); 3268 else 3269 UCT bx = (isUnsigned!U || b >= 0) ? b : -UCT(b); 3270 3271 // Special cases. 3272 if (ax == 0) 3273 return ax; 3274 if (bx == 0) 3275 return bx; 3276 3277 return (ax / gcdImpl(ax, bx)) * bx; 3278 } 3279 3280 /// 3281 @safe unittest 3282 { 3283 assert(lcm(1, 2) == 2); 3284 assert(lcm(3, 4) == 12); 3285 assert(lcm(5, 6) == 30); 3286 } 3287 3288 @safe unittest 3289 { 3290 import std.meta : AliasSeq; 3291 static foreach (T; AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong, 3292 const byte, const short, const int, const long, 3293 immutable ubyte, immutable ushort, immutable uint, immutable ulong)) 3294 { 3295 static foreach (U; AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong, 3296 const ubyte, const ushort, const uint, const ulong, 3297 immutable byte, immutable short, immutable int, immutable long)) 3298 { 3299 assert(lcm(T(21), U(6)) == 42); 3300 assert(lcm(T(41), U(0)) == 0); 3301 assert(lcm(T(0), U(7)) == 0); 3302 assert(lcm(T(0), U(0)) == 0); 3303 assert(lcm(T(1U), U(2)) == 2); 3304 assert(lcm(T(3), U(4U)) == 12); 3305 assert(lcm(T(5U), U(6U)) == 30); 3306 static if (T.min < 0) 3307 assert(lcm(T(-42), U(21U)) == 42); 3308 } 3309 } 3310 } 3311 3312 /// ditto 3313 auto lcm(T)(T a, T b) 3314 if (!isIntegral!T && 3315 is(typeof(T.init % T.init)) && 3316 is(typeof(T.init == 0 || T.init > 0))) 3317 { 3318 // Ensure arguments are unsigned. 3319 a = a >= 0 ? a : -a; 3320 b = b >= 0 ? b : -b; 3321 3322 // Special cases. 3323 if (a == 0) 3324 return a; 3325 if (b == 0) 3326 return b; 3327 3328 return (a / gcdImpl(a, b)) * b; 3329 } 3330 3331 @safe unittest 3332 { 3333 import std.bigint : BigInt; 3334 assert(lcm(BigInt(21), BigInt(6)) == BigInt(42)); 3335 assert(lcm(BigInt(41), BigInt(0)) == BigInt(0)); 3336 assert(lcm(BigInt(0), BigInt(7)) == BigInt(0)); 3337 assert(lcm(BigInt(0), BigInt(0)) == BigInt(0)); 3338 assert(lcm(BigInt(1U), BigInt(2)) == BigInt(2)); 3339 assert(lcm(BigInt(3), BigInt(4U)) == BigInt(12)); 3340 assert(lcm(BigInt(5U), BigInt(6U)) == BigInt(30)); 3341 assert(lcm(BigInt(-42), BigInt(21U)) == BigInt(42)); 3342 } 3343 3344 // This is to make tweaking the speed/size vs. accuracy tradeoff easy, 3345 // though floats seem accurate enough for all practical purposes, since 3346 // they pass the "isClose(inverseFft(fft(arr)), arr)" test even for 3347 // size 2 ^^ 22. 3348 private alias lookup_t = float; 3349 3350 /**A class for performing fast Fourier transforms of power of two sizes. 3351 * This class encapsulates a large amount of state that is reusable when 3352 * performing multiple FFTs of sizes smaller than or equal to that specified 3353 * in the constructor. This results in substantial speedups when performing 3354 * multiple FFTs with a known maximum size. However, 3355 * a free function API is provided for convenience if you need to perform a 3356 * one-off FFT. 3357 * 3358 * References: 3359 * $(HTTP en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm) 3360 */ 3361 final class Fft 3362 { 3363 import core.bitop : bsf; 3364 import std.algorithm.iteration : map; 3365 import std.array : uninitializedArray; 3366 3367 private: 3368 immutable lookup_t[][] negSinLookup; 3369 3370 void enforceSize(R)(R range) const 3371 { 3372 import std.conv : text; 3373 assert(range.length <= size, text( 3374 "FFT size mismatch. Expected ", size, ", got ", range.length)); 3375 } 3376 3377 void fftImpl(Ret, R)(Stride!R range, Ret buf) const 3378 in 3379 { 3380 assert(range.length >= 4); 3381 assert(isPowerOf2(range.length)); 3382 } 3383 do 3384 { 3385 auto recurseRange = range; 3386 recurseRange.doubleSteps(); 3387 3388 if (buf.length > 4) 3389 { 3390 fftImpl(recurseRange, buf[0..$ / 2]); 3391 recurseRange.popHalf(); 3392 fftImpl(recurseRange, buf[$ / 2..$]); 3393 } 3394 else 3395 { 3396 // Do this here instead of in another recursion to save on 3397 // recursion overhead. 3398 slowFourier2(recurseRange, buf[0..$ / 2]); 3399 recurseRange.popHalf(); 3400 slowFourier2(recurseRange, buf[$ / 2..$]); 3401 } 3402 3403 butterfly(buf); 3404 } 3405 3406 // This algorithm works by performing the even and odd parts of our FFT 3407 // using the "two for the price of one" method mentioned at 3408 // https://web.archive.org/web/20180312110051/http://www.engineeringproductivitytools.com/stuff/T0001/PT10.HTM#Head521 3409 // by making the odd terms into the imaginary components of our new FFT, 3410 // and then using symmetry to recombine them. 3411 void fftImplPureReal(Ret, R)(R range, Ret buf) const 3412 in 3413 { 3414 assert(range.length >= 4); 3415 assert(isPowerOf2(range.length)); 3416 } 3417 do 3418 { 3419 alias E = ElementType!R; 3420 3421 // Converts odd indices of range to the imaginary components of 3422 // a range half the size. The even indices become the real components. 3423 static if (isArray!R && isFloatingPoint!E) 3424 { 3425 // Then the memory layout of complex numbers provides a dirt 3426 // cheap way to convert. This is a common case, so take advantage. 3427 auto oddsImag = cast(Complex!E[]) range; 3428 } 3429 else 3430 { 3431 // General case: Use a higher order range. We can assume 3432 // source.length is even because it has to be a power of 2. 3433 static struct OddToImaginary 3434 { 3435 R source; 3436 alias C = Complex!(CommonType!(E, typeof(buf[0].re))); 3437 3438 @property 3439 { 3440 C front() 3441 { 3442 return C(source[0], source[1]); 3443 } 3444 3445 C back() 3446 { 3447 immutable n = source.length; 3448 return C(source[n - 2], source[n - 1]); 3449 } 3450 3451 typeof(this) save() 3452 { 3453 return typeof(this)(source.save); 3454 } 3455 3456 bool empty() 3457 { 3458 return source.empty; 3459 } 3460 3461 size_t length() 3462 { 3463 return source.length / 2; 3464 } 3465 } 3466 3467 void popFront() 3468 { 3469 source.popFront(); 3470 source.popFront(); 3471 } 3472 3473 void popBack() 3474 { 3475 source.popBack(); 3476 source.popBack(); 3477 } 3478 3479 C opIndex(size_t index) 3480 { 3481 return C(source[index * 2], source[index * 2 + 1]); 3482 } 3483 3484 typeof(this) opSlice(size_t lower, size_t upper) 3485 { 3486 return typeof(this)(source[lower * 2 .. upper * 2]); 3487 } 3488 } 3489 3490 auto oddsImag = OddToImaginary(range); 3491 } 3492 3493 fft(oddsImag, buf[0..$ / 2]); 3494 auto evenFft = buf[0..$ / 2]; 3495 auto oddFft = buf[$ / 2..$]; 3496 immutable halfN = evenFft.length; 3497 oddFft[0].re = buf[0].im; 3498 oddFft[0].im = 0; 3499 evenFft[0].im = 0; 3500 // evenFft[0].re is already right b/c it's aliased with buf[0].re. 3501 3502 foreach (k; 1 .. halfN / 2 + 1) 3503 { 3504 immutable bufk = buf[k]; 3505 immutable bufnk = buf[buf.length / 2 - k]; 3506 evenFft[k].re = 0.5 * (bufk.re + bufnk.re); 3507 evenFft[halfN - k].re = evenFft[k].re; 3508 evenFft[k].im = 0.5 * (bufk.im - bufnk.im); 3509 evenFft[halfN - k].im = -evenFft[k].im; 3510 3511 oddFft[k].re = 0.5 * (bufk.im + bufnk.im); 3512 oddFft[halfN - k].re = oddFft[k].re; 3513 oddFft[k].im = 0.5 * (bufnk.re - bufk.re); 3514 oddFft[halfN - k].im = -oddFft[k].im; 3515 } 3516 3517 butterfly(buf); 3518 } 3519 3520 void butterfly(R)(R buf) const 3521 in 3522 { 3523 assert(isPowerOf2(buf.length)); 3524 } 3525 do 3526 { 3527 immutable n = buf.length; 3528 immutable localLookup = negSinLookup[bsf(n)]; 3529 assert(localLookup.length == n); 3530 3531 immutable cosMask = n - 1; 3532 immutable cosAdd = n / 4 * 3; 3533 3534 lookup_t negSinFromLookup(size_t index) pure nothrow 3535 { 3536 return localLookup[index]; 3537 } 3538 3539 lookup_t cosFromLookup(size_t index) pure nothrow 3540 { 3541 // cos is just -sin shifted by PI * 3 / 2. 3542 return localLookup[(index + cosAdd) & cosMask]; 3543 } 3544 3545 immutable halfLen = n / 2; 3546 3547 // This loop is unrolled and the two iterations are interleaved 3548 // relative to the textbook FFT to increase ILP. This gives roughly 5% 3549 // speedups on DMD. 3550 for (size_t k = 0; k < halfLen; k += 2) 3551 { 3552 immutable cosTwiddle1 = cosFromLookup(k); 3553 immutable sinTwiddle1 = negSinFromLookup(k); 3554 immutable cosTwiddle2 = cosFromLookup(k + 1); 3555 immutable sinTwiddle2 = negSinFromLookup(k + 1); 3556 3557 immutable realLower1 = buf[k].re; 3558 immutable imagLower1 = buf[k].im; 3559 immutable realLower2 = buf[k + 1].re; 3560 immutable imagLower2 = buf[k + 1].im; 3561 3562 immutable upperIndex1 = k + halfLen; 3563 immutable upperIndex2 = upperIndex1 + 1; 3564 immutable realUpper1 = buf[upperIndex1].re; 3565 immutable imagUpper1 = buf[upperIndex1].im; 3566 immutable realUpper2 = buf[upperIndex2].re; 3567 immutable imagUpper2 = buf[upperIndex2].im; 3568 3569 immutable realAdd1 = cosTwiddle1 * realUpper1 3570 - sinTwiddle1 * imagUpper1; 3571 immutable imagAdd1 = sinTwiddle1 * realUpper1 3572 + cosTwiddle1 * imagUpper1; 3573 immutable realAdd2 = cosTwiddle2 * realUpper2 3574 - sinTwiddle2 * imagUpper2; 3575 immutable imagAdd2 = sinTwiddle2 * realUpper2 3576 + cosTwiddle2 * imagUpper2; 3577 3578 buf[k].re += realAdd1; 3579 buf[k].im += imagAdd1; 3580 buf[k + 1].re += realAdd2; 3581 buf[k + 1].im += imagAdd2; 3582 3583 buf[upperIndex1].re = realLower1 - realAdd1; 3584 buf[upperIndex1].im = imagLower1 - imagAdd1; 3585 buf[upperIndex2].re = realLower2 - realAdd2; 3586 buf[upperIndex2].im = imagLower2 - imagAdd2; 3587 } 3588 } 3589 3590 // This constructor is used within this module for allocating the 3591 // buffer space elsewhere besides the GC heap. It's definitely **NOT** 3592 // part of the public API and definitely **IS** subject to change. 3593 // 3594 // Also, this is unsafe because the memSpace buffer will be cast 3595 // to immutable. 3596 // 3597 // Public b/c of https://issues.dlang.org/show_bug.cgi?id=4636. 3598 public this(lookup_t[] memSpace) 3599 { 3600 immutable size = memSpace.length / 2; 3601 3602 /* Create a lookup table of all negative sine values at a resolution of 3603 * size and all smaller power of two resolutions. This may seem 3604 * inefficient, but having all the lookups be next to each other in 3605 * memory at every level of iteration is a huge win performance-wise. 3606 */ 3607 if (size == 0) 3608 { 3609 return; 3610 } 3611 3612 assert(isPowerOf2(size), 3613 "Can only do FFTs on ranges with a size that is a power of two."); 3614 3615 auto table = new lookup_t[][bsf(size) + 1]; 3616 3617 table[$ - 1] = memSpace[$ - size..$]; 3618 memSpace = memSpace[0 .. size]; 3619 3620 auto lastRow = table[$ - 1]; 3621 lastRow[0] = 0; // -sin(0) == 0. 3622 foreach (ptrdiff_t i; 1 .. size) 3623 { 3624 // The hard coded cases are for improved accuracy and to prevent 3625 // annoying non-zeroness when stuff should be zero. 3626 3627 if (i == size / 4) 3628 lastRow[i] = -1; // -sin(pi / 2) == -1. 3629 else if (i == size / 2) 3630 lastRow[i] = 0; // -sin(pi) == 0. 3631 else if (i == size * 3 / 4) 3632 lastRow[i] = 1; // -sin(pi * 3 / 2) == 1 3633 else 3634 lastRow[i] = -sin(i * 2.0L * PI / size); 3635 } 3636 3637 // Fill in all the other rows with strided versions. 3638 foreach (i; 1 .. table.length - 1) 3639 { 3640 immutable strideLength = size / (2 ^^ i); 3641 auto strided = Stride!(lookup_t[])(lastRow, strideLength); 3642 table[i] = memSpace[$ - strided.length..$]; 3643 memSpace = memSpace[0..$ - strided.length]; 3644 3645 size_t copyIndex; 3646 foreach (elem; strided) 3647 { 3648 table[i][copyIndex++] = elem; 3649 } 3650 } 3651 3652 negSinLookup = cast(immutable) table; 3653 } 3654 3655 public: 3656 /**Create an `Fft` object for computing fast Fourier transforms of 3657 * power of two sizes of `size` or smaller. `size` must be a 3658 * power of two. 3659 */ 3660 this(size_t size) 3661 { 3662 // Allocate all twiddle factor buffers in one contiguous block so that, 3663 // when one is done being used, the next one is next in cache. 3664 auto memSpace = uninitializedArray!(lookup_t[])(2 * size); 3665 this(memSpace); 3666 } 3667 3668 @property size_t size() const 3669 { 3670 return (negSinLookup is null) ? 0 : negSinLookup[$ - 1].length; 3671 } 3672 3673 /**Compute the Fourier transform of range using the $(BIGOH N log N) 3674 * Cooley-Tukey Algorithm. `range` must be a random-access range with 3675 * slicing and a length equal to `size` as provided at the construction of 3676 * this object. The contents of range can be either numeric types, 3677 * which will be interpreted as pure real values, or complex types with 3678 * properties or members `.re` and `.im` that can be read. 3679 * 3680 * Note: Pure real FFTs are automatically detected and the relevant 3681 * optimizations are performed. 3682 * 3683 * Returns: An array of complex numbers representing the transformed data in 3684 * the frequency domain. 3685 * 3686 * Conventions: The exponent is negative and the factor is one, 3687 * i.e., output[j] := sum[ exp(-2 PI i j k / N) input[k] ]. 3688 */ 3689 Complex!F[] fft(F = double, R)(R range) const 3690 if (isFloatingPoint!F && isRandomAccessRange!R) 3691 { 3692 enforceSize(range); 3693 Complex!F[] ret; 3694 if (range.length == 0) 3695 { 3696 return ret; 3697 } 3698 3699 // Don't waste time initializing the memory for ret. 3700 ret = uninitializedArray!(Complex!F[])(range.length); 3701 3702 fft(range, ret); 3703 return ret; 3704 } 3705 3706 /**Same as the overload, but allows for the results to be stored in a user- 3707 * provided buffer. The buffer must be of the same length as range, must be 3708 * a random-access range, must have slicing, and must contain elements that are 3709 * complex-like. This means that they must have a .re and a .im member or 3710 * property that can be both read and written and are floating point numbers. 3711 */ 3712 void fft(Ret, R)(R range, Ret buf) const 3713 if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret) 3714 { 3715 assert(buf.length == range.length); 3716 enforceSize(range); 3717 3718 if (range.length == 0) 3719 { 3720 return; 3721 } 3722 else if (range.length == 1) 3723 { 3724 buf[0] = range[0]; 3725 return; 3726 } 3727 else if (range.length == 2) 3728 { 3729 slowFourier2(range, buf); 3730 return; 3731 } 3732 else 3733 { 3734 alias E = ElementType!R; 3735 static if (is(E : real)) 3736 { 3737 return fftImplPureReal(range, buf); 3738 } 3739 else 3740 { 3741 static if (is(R : Stride!R)) 3742 return fftImpl(range, buf); 3743 else 3744 return fftImpl(Stride!R(range, 1), buf); 3745 } 3746 } 3747 } 3748 3749 /** 3750 * Computes the inverse Fourier transform of a range. The range must be a 3751 * random access range with slicing, have a length equal to the size 3752 * provided at construction of this object, and contain elements that are 3753 * either of type std.complex.Complex or have essentially 3754 * the same compile-time interface. 3755 * 3756 * Returns: The time-domain signal. 3757 * 3758 * Conventions: The exponent is positive and the factor is 1/N, i.e., 3759 * output[j] := (1 / N) sum[ exp(+2 PI i j k / N) input[k] ]. 3760 */ 3761 Complex!F[] inverseFft(F = double, R)(R range) const 3762 if (isRandomAccessRange!R && isComplexLike!(ElementType!R) && isFloatingPoint!F) 3763 { 3764 enforceSize(range); 3765 Complex!F[] ret; 3766 if (range.length == 0) 3767 { 3768 return ret; 3769 } 3770 3771 // Don't waste time initializing the memory for ret. 3772 ret = uninitializedArray!(Complex!F[])(range.length); 3773 3774 inverseFft(range, ret); 3775 return ret; 3776 } 3777 3778 /** 3779 * Inverse FFT that allows a user-supplied buffer to be provided. The buffer 3780 * must be a random access range with slicing, and its elements 3781 * must be some complex-like type. 3782 */ 3783 void inverseFft(Ret, R)(R range, Ret buf) const 3784 if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret) 3785 { 3786 enforceSize(range); 3787 3788 auto swapped = map!swapRealImag(range); 3789 fft(swapped, buf); 3790 3791 immutable lenNeg1 = 1.0 / buf.length; 3792 foreach (ref elem; buf) 3793 { 3794 immutable temp = elem.re * lenNeg1; 3795 elem.re = elem.im * lenNeg1; 3796 elem.im = temp; 3797 } 3798 } 3799 } 3800 3801 // This mixin creates an Fft object in the scope it's mixed into such that all 3802 // memory owned by the object is deterministically destroyed at the end of that 3803 // scope. 3804 private enum string MakeLocalFft = q{ 3805 import core.stdc.stdlib; 3806 import core.exception : onOutOfMemoryError; 3807 3808 auto lookupBuf = (cast(lookup_t*) malloc(range.length * 2 * lookup_t.sizeof)) 3809 [0 .. 2 * range.length]; 3810 if (!lookupBuf.ptr) 3811 onOutOfMemoryError(); 3812 3813 scope(exit) free(cast(void*) lookupBuf.ptr); 3814 auto fftObj = scoped!Fft(lookupBuf); 3815 }; 3816 3817 /**Convenience functions that create an `Fft` object, run the FFT or inverse 3818 * FFT and return the result. Useful for one-off FFTs. 3819 * 3820 * Note: In addition to convenience, these functions are slightly more 3821 * efficient than manually creating an Fft object for a single use, 3822 * as the Fft object is deterministically destroyed before these 3823 * functions return. 3824 */ 3825 Complex!F[] fft(F = double, R)(R range) 3826 { 3827 mixin(MakeLocalFft); 3828 return fftObj.fft!(F, R)(range); 3829 } 3830 3831 /// ditto 3832 void fft(Ret, R)(R range, Ret buf) 3833 { 3834 mixin(MakeLocalFft); 3835 return fftObj.fft!(Ret, R)(range, buf); 3836 } 3837 3838 /// ditto 3839 Complex!F[] inverseFft(F = double, R)(R range) 3840 { 3841 mixin(MakeLocalFft); 3842 return fftObj.inverseFft!(F, R)(range); 3843 } 3844 3845 /// ditto 3846 void inverseFft(Ret, R)(R range, Ret buf) 3847 { 3848 mixin(MakeLocalFft); 3849 return fftObj.inverseFft!(Ret, R)(range, buf); 3850 } 3851 3852 @system unittest 3853 { 3854 import std.algorithm; 3855 import std.conv; 3856 import std.range; 3857 // Test values from R and Octave. 3858 auto arr = [1,2,3,4,5,6,7,8]; 3859 auto fft1 = fft(arr); 3860 assert(isClose(map!"a.re"(fft1), 3861 [36.0, -4, -4, -4, -4, -4, -4, -4], 1e-4)); 3862 assert(isClose(map!"a.im"(fft1), 3863 [0, 9.6568, 4, 1.6568, 0, -1.6568, -4, -9.6568], 1e-4)); 3864 3865 auto fft1Retro = fft(retro(arr)); 3866 assert(isClose(map!"a.re"(fft1Retro), 3867 [36.0, 4, 4, 4, 4, 4, 4, 4], 1e-4)); 3868 assert(isClose(map!"a.im"(fft1Retro), 3869 [0, -9.6568, -4, -1.6568, 0, 1.6568, 4, 9.6568], 1e-4)); 3870 3871 auto fft1Float = fft(to!(float[])(arr)); 3872 assert(isClose(map!"a.re"(fft1), map!"a.re"(fft1Float))); 3873 assert(isClose(map!"a.im"(fft1), map!"a.im"(fft1Float))); 3874 3875 alias C = Complex!float; 3876 auto arr2 = [C(1,2), C(3,4), C(5,6), C(7,8), C(9,10), 3877 C(11,12), C(13,14), C(15,16)]; 3878 auto fft2 = fft(arr2); 3879 assert(isClose(map!"a.re"(fft2), 3880 [64.0, -27.3137, -16, -11.3137, -8, -4.6862, 0, 11.3137], 1e-4)); 3881 assert(isClose(map!"a.im"(fft2), 3882 [72, 11.3137, 0, -4.686, -8, -11.3137, -16, -27.3137], 1e-4)); 3883 3884 auto inv1 = inverseFft(fft1); 3885 assert(isClose(map!"a.re"(inv1), arr, 1e-6)); 3886 assert(reduce!max(map!"a.im"(inv1)) < 1e-10); 3887 3888 auto inv2 = inverseFft(fft2); 3889 assert(isClose(map!"a.re"(inv2), map!"a.re"(arr2))); 3890 assert(isClose(map!"a.im"(inv2), map!"a.im"(arr2))); 3891 3892 // FFTs of size 0, 1 and 2 are handled as special cases. Test them here. 3893 ushort[] empty; 3894 assert(fft(empty) == null); 3895 assert(inverseFft(fft(empty)) == null); 3896 3897 real[] oneElem = [4.5L]; 3898 auto oneFft = fft(oneElem); 3899 assert(oneFft.length == 1); 3900 assert(oneFft[0].re == 4.5L); 3901 assert(oneFft[0].im == 0); 3902 3903 auto oneInv = inverseFft(oneFft); 3904 assert(oneInv.length == 1); 3905 assert(isClose(oneInv[0].re, 4.5)); 3906 assert(isClose(oneInv[0].im, 0, 0.0, 1e-10)); 3907 3908 long[2] twoElems = [8, 4]; 3909 auto twoFft = fft(twoElems[]); 3910 assert(twoFft.length == 2); 3911 assert(isClose(twoFft[0].re, 12)); 3912 assert(isClose(twoFft[0].im, 0, 0.0, 1e-10)); 3913 assert(isClose(twoFft[1].re, 4)); 3914 assert(isClose(twoFft[1].im, 0, 0.0, 1e-10)); 3915 auto twoInv = inverseFft(twoFft); 3916 assert(isClose(twoInv[0].re, 8)); 3917 assert(isClose(twoInv[0].im, 0, 0.0, 1e-10)); 3918 assert(isClose(twoInv[1].re, 4)); 3919 assert(isClose(twoInv[1].im, 0, 0.0, 1e-10)); 3920 } 3921 3922 // Swaps the real and imaginary parts of a complex number. This is useful 3923 // for inverse FFTs. 3924 C swapRealImag(C)(C input) 3925 { 3926 return C(input.im, input.re); 3927 } 3928 3929 /** This function transforms `decimal` value into a value in the factorial number 3930 system stored in `fac`. 3931 3932 A factorial number is constructed as: 3933 $(D fac[0] * 0! + fac[1] * 1! + ... fac[20] * 20!) 3934 3935 Params: 3936 decimal = The decimal value to convert into the factorial number system. 3937 fac = The array to store the factorial number. The array is of size 21 as 3938 `ulong.max` requires 21 digits in the factorial number system. 3939 Returns: 3940 A variable storing the number of digits of the factorial number stored in 3941 `fac`. 3942 */ 3943 size_t decimalToFactorial(ulong decimal, ref ubyte[21] fac) 3944 @safe pure nothrow @nogc 3945 { 3946 import std.algorithm.mutation : reverse; 3947 size_t idx; 3948 3949 for (ulong i = 1; decimal != 0; ++i) 3950 { 3951 auto temp = decimal % i; 3952 decimal /= i; 3953 fac[idx++] = cast(ubyte)(temp); 3954 } 3955 3956 if (idx == 0) 3957 { 3958 fac[idx++] = cast(ubyte) 0; 3959 } 3960 3961 reverse(fac[0 .. idx]); 3962 3963 // first digit of the number in factorial will always be zero 3964 assert(fac[idx - 1] == 0); 3965 3966 return idx; 3967 } 3968 3969 /// 3970 @safe pure @nogc unittest 3971 { 3972 ubyte[21] fac; 3973 size_t idx = decimalToFactorial(2982, fac); 3974 3975 assert(fac[0] == 4); 3976 assert(fac[1] == 0); 3977 assert(fac[2] == 4); 3978 assert(fac[3] == 1); 3979 assert(fac[4] == 0); 3980 assert(fac[5] == 0); 3981 assert(fac[6] == 0); 3982 } 3983 3984 @safe pure unittest 3985 { 3986 ubyte[21] fac; 3987 size_t idx = decimalToFactorial(0UL, fac); 3988 assert(idx == 1); 3989 assert(fac[0] == 0); 3990 3991 fac[] = 0; 3992 idx = 0; 3993 idx = decimalToFactorial(ulong.max, fac); 3994 assert(idx == 21); 3995 auto t = [7, 11, 12, 4, 3, 15, 3, 5, 3, 5, 0, 8, 3, 5, 0, 0, 0, 2, 1, 1, 0]; 3996 foreach (i, it; fac[0 .. 21]) 3997 { 3998 assert(it == t[i]); 3999 } 4000 4001 fac[] = 0; 4002 idx = decimalToFactorial(2982, fac); 4003 4004 assert(idx == 7); 4005 t = [4, 0, 4, 1, 0, 0, 0]; 4006 foreach (i, it; fac[0 .. idx]) 4007 { 4008 assert(it == t[i]); 4009 } 4010 } 4011 4012 private: 4013 // The reasons I couldn't use std.algorithm were b/c its stride length isn't 4014 // modifiable on the fly and because range has grown some performance hacks 4015 // for powers of 2. 4016 struct Stride(R) 4017 { 4018 import core.bitop : bsf; 4019 Unqual!R range; 4020 size_t _nSteps; 4021 size_t _length; 4022 alias E = ElementType!(R); 4023 4024 this(R range, size_t nStepsIn) 4025 { 4026 this.range = range; 4027 _nSteps = nStepsIn; 4028 _length = (range.length + _nSteps - 1) / nSteps; 4029 } 4030 4031 size_t length() const @property 4032 { 4033 return _length; 4034 } 4035 4036 typeof(this) save() @property 4037 { 4038 auto ret = this; 4039 ret.range = ret.range.save; 4040 return ret; 4041 } 4042 4043 E opIndex(size_t index) 4044 { 4045 return range[index * _nSteps]; 4046 } 4047 4048 E front() @property 4049 { 4050 return range[0]; 4051 } 4052 4053 void popFront() 4054 { 4055 if (range.length >= _nSteps) 4056 { 4057 range = range[_nSteps .. range.length]; 4058 _length--; 4059 } 4060 else 4061 { 4062 range = range[0 .. 0]; 4063 _length = 0; 4064 } 4065 } 4066 4067 // Pops half the range's stride. 4068 void popHalf() 4069 { 4070 range = range[_nSteps / 2 .. range.length]; 4071 } 4072 4073 bool empty() const @property 4074 { 4075 return length == 0; 4076 } 4077 4078 size_t nSteps() const @property 4079 { 4080 return _nSteps; 4081 } 4082 4083 void doubleSteps() 4084 { 4085 _nSteps *= 2; 4086 _length /= 2; 4087 } 4088 4089 size_t nSteps(size_t newVal) @property 4090 { 4091 _nSteps = newVal; 4092 4093 // Using >> bsf(nSteps) is a few cycles faster than / nSteps. 4094 _length = (range.length + _nSteps - 1) >> bsf(nSteps); 4095 return newVal; 4096 } 4097 } 4098 4099 // Hard-coded base case for FFT of size 2. This is actually a TON faster than 4100 // using a generic slow DFT. This seems to be the best base case. (Size 1 4101 // can be coded inline as buf[0] = range[0]). 4102 void slowFourier2(Ret, R)(R range, Ret buf) 4103 { 4104 assert(range.length == 2); 4105 assert(buf.length == 2); 4106 buf[0] = range[0] + range[1]; 4107 buf[1] = range[0] - range[1]; 4108 } 4109 4110 // Hard-coded base case for FFT of size 4. Doesn't work as well as the size 4111 // 2 case. 4112 void slowFourier4(Ret, R)(R range, Ret buf) 4113 { 4114 alias C = ElementType!Ret; 4115 4116 assert(range.length == 4); 4117 assert(buf.length == 4); 4118 buf[0] = range[0] + range[1] + range[2] + range[3]; 4119 buf[1] = range[0] - range[1] * C(0, 1) - range[2] + range[3] * C(0, 1); 4120 buf[2] = range[0] - range[1] + range[2] - range[3]; 4121 buf[3] = range[0] + range[1] * C(0, 1) - range[2] - range[3] * C(0, 1); 4122 } 4123 4124 N roundDownToPowerOf2(N)(N num) 4125 if (isScalarType!N && !isFloatingPoint!N) 4126 { 4127 import core.bitop : bsr; 4128 return num & (cast(N) 1 << bsr(num)); 4129 } 4130 4131 @safe unittest 4132 { 4133 assert(roundDownToPowerOf2(7) == 4); 4134 assert(roundDownToPowerOf2(4) == 4); 4135 } 4136 4137 template isComplexLike(T) 4138 { 4139 enum bool isComplexLike = is(typeof(T.init.re)) && 4140 is(typeof(T.init.im)); 4141 } 4142 4143 @safe unittest 4144 { 4145 static assert(isComplexLike!(Complex!double)); 4146 static assert(!isComplexLike!(uint)); 4147 }