1 // Written in the D programming language. 2 3 /** 4 This module is a port of a growing fragment of the $(D_PARAM numeric) 5 header in Alexander Stepanov's $(LINK2 https://en.wikipedia.org/wiki/Standard_Template_Library, 6 Standard Template Library), with a few additions. 7 8 Macros: 9 Copyright: Copyright Andrei Alexandrescu 2008 - 2009. 10 License: $(HTTP www.boost.org/LICENSE_1_0.txt, Boost License 1.0). 11 Authors: $(HTTP erdani.org, Andrei Alexandrescu), 12 Don Clugston, Robert Jacques, Ilya Yaroshenko 13 Source: $(PHOBOSSRC std/numeric.d) 14 */ 15 /* 16 Copyright Andrei Alexandrescu 2008 - 2009. 17 Distributed under the Boost Software License, Version 1.0. 18 (See accompanying file LICENSE_1_0.txt or copy at 19 http://www.boost.org/LICENSE_1_0.txt) 20 */ 21 module std.numeric; 22 23 import std.complex; 24 import std.math; 25 import core.math : fabs, ldexp, sin, sqrt; 26 import std.range.primitives; 27 import std.traits; 28 import std.typecons; 29 30 /// Format flags for CustomFloat. 31 public enum CustomFloatFlags 32 { 33 /// Adds a sign bit to allow for signed numbers. 34 signed = 1, 35 36 /** 37 * Store values in normalized form by default. The actual precision of the 38 * significand is extended by 1 bit by assuming an implicit leading bit of 1 39 * instead of 0. i.e. `1.nnnn` instead of `0.nnnn`. 40 * True for all $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEE754) types 41 */ 42 storeNormalized = 2, 43 44 /** 45 * Stores the significand in $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers, 46 * IEEE754 denormalized) form when the exponent is 0. Required to express the value 0. 47 */ 48 allowDenorm = 4, 49 50 /** 51 * Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Positive_and_negative_infinity, 52 * IEEE754 _infinity) values. 53 */ 54 infinity = 8, 55 56 /// Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/NaN, IEEE754 Not a Number) values. 57 nan = 16, 58 59 /** 60 * If set, select an exponent bias such that max_exp = 1. 61 * i.e. so that the maximum value is >= 1.0 and < 2.0. 62 * Ignored if the exponent bias is manually specified. 63 */ 64 probability = 32, 65 66 /// If set, unsigned custom floats are assumed to be negative. 67 negativeUnsigned = 64, 68 69 /**If set, 0 is the only allowed $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers, 70 * IEEE754 denormalized) number. 71 * Requires allowDenorm and storeNormalized. 72 */ 73 allowDenormZeroOnly = 128 | allowDenorm | storeNormalized, 74 75 /// Include _all of the $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) options. 76 ieee = signed | storeNormalized | allowDenorm | infinity | nan , 77 78 /// Include none of the above options. 79 none = 0 80 } 81 82 private enum isIEEEQuadruple = floatTraits!real.realFormat == RealFormat.ieeeQuadruple; 83 84 private template CustomFloatParams(uint bits) 85 { 86 enum CustomFloatFlags flags = CustomFloatFlags.ieee 87 ^ ((bits == 80 && !isIEEEQuadruple) ? CustomFloatFlags.storeNormalized : CustomFloatFlags.none); 88 static if (bits == 8) alias CustomFloatParams = CustomFloatParams!( 4, 3, flags); 89 static if (bits == 16) alias CustomFloatParams = CustomFloatParams!(10, 5, flags); 90 static if (bits == 32) alias CustomFloatParams = CustomFloatParams!(23, 8, flags); 91 static if (bits == 64) alias CustomFloatParams = CustomFloatParams!(52, 11, flags); 92 static if (bits == 80) alias CustomFloatParams = CustomFloatParams!(64, 15, flags); 93 } 94 95 private template CustomFloatParams(uint precision, uint exponentWidth, CustomFloatFlags flags) 96 { 97 import std.meta : AliasSeq; 98 alias CustomFloatParams = 99 AliasSeq!( 100 precision, 101 exponentWidth, 102 flags, 103 (1 << (exponentWidth - ((flags & flags.probability) == 0))) 104 - ((flags & (flags.nan | flags.infinity)) != 0) - ((flags & flags.probability) != 0) 105 ); // ((flags & CustomFloatFlags.probability) == 0) 106 } 107 108 /** 109 * Allows user code to define custom floating-point formats. These formats are 110 * for storage only; all operations on them are performed by first implicitly 111 * extracting them to `real` first. After the operation is completed the 112 * result can be stored in a custom floating-point value via assignment. 113 */ 114 template CustomFloat(uint bits) 115 if (bits == 8 || bits == 16 || bits == 32 || bits == 64 || bits == 80) 116 { 117 alias CustomFloat = CustomFloat!(CustomFloatParams!(bits)); 118 } 119 120 /// ditto 121 template CustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags = CustomFloatFlags.ieee) 122 if (((flags & flags.signed) + precision + exponentWidth) % 8 == 0 && precision + exponentWidth > 0) 123 { 124 alias CustomFloat = CustomFloat!(CustomFloatParams!(precision, exponentWidth, flags)); 125 } 126 127 /// 128 @safe unittest 129 { 130 import std.math.trigonometry : sin, cos; 131 132 // Define a 16-bit floating point values 133 CustomFloat!16 x; // Using the number of bits 134 CustomFloat!(10, 5) y; // Using the precision and exponent width 135 CustomFloat!(10, 5,CustomFloatFlags.ieee) z; // Using the precision, exponent width and format flags 136 CustomFloat!(10, 5,CustomFloatFlags.ieee, 15) w; // Using the precision, exponent width, format flags and exponent offset bias 137 138 // Use the 16-bit floats mostly like normal numbers 139 w = x*y - 1; 140 141 // Functions calls require conversion 142 z = sin(+x) + cos(+y); // Use unary plus to concisely convert to a real 143 z = sin(x.get!float) + cos(y.get!float); // Or use get!T 144 z = sin(cast(float) x) + cos(cast(float) y); // Or use cast(T) to explicitly convert 145 146 // Define a 8-bit custom float for storing probabilities 147 alias Probability = CustomFloat!(4, 4, CustomFloatFlags.ieee^CustomFloatFlags.probability^CustomFloatFlags.signed ); 148 auto p = Probability(0.5); 149 } 150 151 // Facilitate converting numeric types to custom float 152 private union ToBinary(F) 153 if (is(typeof(CustomFloatParams!(F.sizeof*8))) || is(F == real)) 154 { 155 F set; 156 157 // If on Linux or Mac, where 80-bit reals are padded, ignore the 158 // padding. 159 import std.algorithm.comparison : min; 160 CustomFloat!(CustomFloatParams!(min(F.sizeof*8, 80))) get; 161 162 // Convert F to the correct binary type. 163 static typeof(get) opCall(F value) 164 { 165 ToBinary r; 166 r.set = value; 167 return r.get; 168 } 169 alias get this; 170 } 171 172 /// ditto 173 struct CustomFloat(uint precision, // fraction bits (23 for float) 174 uint exponentWidth, // exponent bits (8 for float) Exponent width 175 CustomFloatFlags flags, 176 uint bias) 177 if (isCorrectCustomFloat(precision, exponentWidth, flags)) 178 { 179 import std.bitmanip : bitfields; 180 import std.meta : staticIndexOf; 181 private: 182 // get the correct unsigned bitfield type to support > 32 bits 183 template uType(uint bits) 184 { 185 static if (bits <= size_t.sizeof*8) alias uType = size_t; 186 else alias uType = ulong ; 187 } 188 189 // get the correct signed bitfield type to support > 32 bits 190 template sType(uint bits) 191 { 192 static if (bits <= ptrdiff_t.sizeof*8-1) alias sType = ptrdiff_t; 193 else alias sType = long; 194 } 195 196 alias T_sig = uType!precision; 197 alias T_exp = uType!exponentWidth; 198 alias T_signed_exp = sType!exponentWidth; 199 200 alias Flags = CustomFloatFlags; 201 202 // Perform IEEE rounding with round to nearest detection 203 void roundedShift(T,U)(ref T sig, U shift) 204 { 205 if (shift >= T.sizeof*8) 206 { 207 // avoid illegal shift 208 sig = 0; 209 } 210 else if (sig << (T.sizeof*8 - shift) == cast(T) 1uL << (T.sizeof*8 - 1)) 211 { 212 // round to even 213 sig >>= shift; 214 sig += sig & 1; 215 } 216 else 217 { 218 sig >>= shift - 1; 219 sig += sig & 1; 220 // Perform standard rounding 221 sig >>= 1; 222 } 223 } 224 225 // Convert the current value to signed exponent, normalized form 226 void toNormalized(T,U)(ref T sig, ref U exp) const 227 { 228 sig = significand; 229 auto shift = (T.sizeof*8) - precision; 230 exp = exponent; 231 static if (flags&(Flags.infinity|Flags.nan)) 232 { 233 // Handle inf or nan 234 if (exp == exponent_max) 235 { 236 exp = exp.max; 237 sig <<= shift; 238 static if (flags&Flags.storeNormalized) 239 { 240 // Save inf/nan in denormalized format 241 sig >>= 1; 242 sig += cast(T) 1uL << (T.sizeof*8 - 1); 243 } 244 return; 245 } 246 } 247 if ((~flags&Flags.storeNormalized) || 248 // Convert denormalized form to normalized form 249 ((flags&Flags.allowDenorm) && exp == 0)) 250 { 251 if (sig > 0) 252 { 253 import core.bitop : bsr; 254 auto shift2 = precision - bsr(sig); 255 exp -= shift2-1; 256 shift += shift2; 257 } 258 else // value = 0.0 259 { 260 exp = exp.min; 261 return; 262 } 263 } 264 sig <<= shift; 265 exp -= bias; 266 } 267 268 // Set the current value from signed exponent, normalized form 269 void fromNormalized(T,U)(ref T sig, ref U exp) 270 { 271 auto shift = (T.sizeof*8) - precision; 272 if (exp == exp.max) 273 { 274 // infinity or nan 275 exp = exponent_max; 276 static if (flags & Flags.storeNormalized) 277 sig <<= 1; 278 279 // convert back to normalized form 280 static if (~flags & Flags.infinity) 281 // No infinity support? 282 assert(sig != 0, "Infinity floating point value assigned to a " 283 ~ typeof(this).stringof ~ " (no infinity support)."); 284 285 static if (~flags & Flags.nan) // No NaN support? 286 assert(sig == 0, "NaN floating point value assigned to a " ~ 287 typeof(this).stringof ~ " (no nan support)."); 288 sig >>= shift; 289 return; 290 } 291 if (exp == exp.min) // 0.0 292 { 293 exp = 0; 294 sig = 0; 295 return; 296 } 297 298 exp += bias; 299 if (exp <= 0) 300 { 301 static if ((flags&Flags.allowDenorm) || 302 // Convert from normalized form to denormalized 303 (~flags&Flags.storeNormalized)) 304 { 305 shift += -exp; 306 roundedShift(sig,1); 307 sig += cast(T) 1uL << (T.sizeof*8 - 1); 308 // Add the leading 1 309 exp = 0; 310 } 311 else 312 assert((flags&Flags.storeNormalized) && exp == 0, 313 "Underflow occured assigning to a " ~ 314 typeof(this).stringof ~ " (no denormal support)."); 315 } 316 else 317 { 318 static if (~flags&Flags.storeNormalized) 319 { 320 // Convert from normalized form to denormalized 321 roundedShift(sig,1); 322 sig += cast(T) 1uL << (T.sizeof*8 - 1); 323 // Add the leading 1 324 } 325 } 326 327 if (shift > 0) 328 roundedShift(sig,shift); 329 if (sig > significand_max) 330 { 331 // handle significand overflow (should only be 1 bit) 332 static if (~flags&Flags.storeNormalized) 333 { 334 sig >>= 1; 335 } 336 else 337 sig &= significand_max; 338 exp++; 339 } 340 static if ((flags&Flags.allowDenormZeroOnly)==Flags.allowDenormZeroOnly) 341 { 342 // disallow non-zero denormals 343 if (exp == 0) 344 { 345 sig <<= 1; 346 if (sig > significand_max && (sig&significand_max) > 0) 347 // Check and round to even 348 exp++; 349 sig = 0; 350 } 351 } 352 353 if (exp >= exponent_max) 354 { 355 static if (flags&(Flags.infinity|Flags.nan)) 356 { 357 sig = 0; 358 exp = exponent_max; 359 static if (~flags&(Flags.infinity)) 360 assert(0, "Overflow occured assigning to a " ~ 361 typeof(this).stringof ~ " (no infinity support)."); 362 } 363 else 364 assert(exp == exponent_max, "Overflow occured assigning to a " 365 ~ typeof(this).stringof ~ " (no infinity support)."); 366 } 367 } 368 369 public: 370 static if (precision == 64) // CustomFloat!80 support hack 371 { 372 static if (isIEEEQuadruple) 373 { 374 // Only use highest 64 significand bits from 112 explicitly stored 375 align (1): 376 enum ulong significand_max = ulong.max; 377 version (LittleEndian) 378 { 379 private ubyte[6] _padding; // 48-bit of padding 380 ulong significand; 381 mixin(bitfields!( 382 T_exp , "exponent", exponentWidth, 383 bool , "sign" , flags & flags.signed )); 384 } 385 else 386 { 387 mixin(bitfields!( 388 T_exp , "exponent", exponentWidth, 389 bool , "sign" , flags & flags.signed )); 390 ulong significand; 391 private ubyte[6] _padding; // 48-bit of padding 392 } 393 } 394 else 395 { 396 ulong significand; 397 enum ulong significand_max = ulong.max; 398 mixin(bitfields!( 399 T_exp , "exponent", exponentWidth, 400 bool , "sign" , flags & flags.signed )); 401 } 402 } 403 else 404 { 405 mixin(bitfields!( 406 T_sig, "significand", precision, 407 T_exp, "exponent" , exponentWidth, 408 bool , "sign" , flags & flags.signed )); 409 } 410 411 /// Returns: infinity value 412 static if (flags & Flags.infinity) 413 static @property CustomFloat infinity() 414 { 415 CustomFloat value; 416 static if (flags & Flags.signed) 417 value.sign = 0; 418 value.significand = 0; 419 value.exponent = exponent_max; 420 return value; 421 } 422 423 /// Returns: NaN value 424 static if (flags & Flags.nan) 425 static @property CustomFloat nan() 426 { 427 CustomFloat value; 428 static if (flags & Flags.signed) 429 value.sign = 0; 430 value.significand = cast(typeof(significand_max)) 1L << (precision-1); 431 value.exponent = exponent_max; 432 return value; 433 } 434 435 /// Returns: number of decimal digits of precision 436 static @property size_t dig() 437 { 438 auto shiftcnt = precision - ((flags&Flags.storeNormalized) == 0); 439 return shiftcnt == 64 ? 19 : cast(size_t) log10(real(1uL << shiftcnt)); 440 } 441 442 /// Returns: smallest increment to the value 1 443 static @property CustomFloat epsilon() 444 { 445 CustomFloat one = CustomFloat(1); 446 CustomFloat onePlusEpsilon = one; 447 onePlusEpsilon.significand = onePlusEpsilon.significand | 1; // |= does not work here 448 449 return CustomFloat(onePlusEpsilon - one); 450 } 451 452 /// the number of bits in mantissa 453 enum mant_dig = precision + ((flags&Flags.storeNormalized) != 0); 454 455 /// Returns: maximum int value such that 10<sup>max_10_exp</sup> is representable 456 static @property int max_10_exp(){ return cast(int) log10( +max ); } 457 458 /// maximum int value such that 2<sup>max_exp-1</sup> is representable 459 enum max_exp = exponent_max - bias - ((flags & (Flags.infinity | Flags.nan)) != 0) + 1; 460 461 /// Returns: minimum int value such that 10<sup>min_10_exp</sup> is representable 462 static @property int min_10_exp(){ return cast(int) log10( +min_normal ); } 463 464 /// minimum int value such that 2<sup>min_exp-1</sup> is representable as a normalized value 465 enum min_exp = cast(T_signed_exp) -(cast(long) bias) + 1 + ((flags & Flags.allowDenorm) != 0); 466 467 /// Returns: largest representable value that's not infinity 468 static @property CustomFloat max() 469 { 470 CustomFloat value; 471 static if (flags & Flags.signed) 472 value.sign = 0; 473 value.exponent = exponent_max - ((flags&(flags.infinity|flags.nan)) != 0); 474 value.significand = significand_max; 475 return value; 476 } 477 478 /// Returns: smallest representable normalized value that's not 0 479 static @property CustomFloat min_normal() 480 { 481 CustomFloat value; 482 static if (flags & Flags.signed) 483 value.sign = 0; 484 value.exponent = (flags & Flags.allowDenorm) != 0; 485 static if (flags & Flags.storeNormalized) 486 value.significand = 0; 487 else 488 value.significand = cast(T_sig) 1uL << (precision - 1); 489 return value; 490 } 491 492 /// Returns: real part 493 @property CustomFloat re() const { return this; } 494 495 /// Returns: imaginary part 496 static @property CustomFloat im() { return CustomFloat(0.0f); } 497 498 /// Initialize from any `real` compatible type. 499 this(F)(F input) 500 if (__traits(compiles, cast(real) input )) 501 { 502 this = input; 503 } 504 505 /// Self assignment 506 void opAssign(F:CustomFloat)(F input) 507 { 508 static if (flags & Flags.signed) 509 sign = input.sign; 510 exponent = input.exponent; 511 significand = input.significand; 512 } 513 514 /// Assigns from any `real` compatible type. 515 void opAssign(F)(F input) 516 if (__traits(compiles, cast(real) input)) 517 { 518 import std.conv : text; 519 520 static if (staticIndexOf!(immutable F, immutable float, immutable double, immutable real) >= 0) 521 auto value = ToBinary!(Unqual!F)(input); 522 else 523 auto value = ToBinary!(real )(input); 524 525 // Assign the sign bit 526 static if (~flags & Flags.signed) 527 assert((!value.sign) ^ ((flags&flags.negativeUnsigned) > 0), 528 "Incorrectly signed floating point value assigned to a " ~ 529 typeof(this).stringof ~ " (no sign support)."); 530 else 531 sign = value.sign; 532 533 CommonType!(T_signed_exp ,value.T_signed_exp) exp = value.exponent; 534 CommonType!(T_sig, value.T_sig ) sig = value.significand; 535 536 value.toNormalized(sig,exp); 537 fromNormalized(sig,exp); 538 539 assert(exp <= exponent_max, text(typeof(this).stringof ~ 540 " exponent too large: " ,exp," > ",exponent_max, "\t",input,"\t",sig)); 541 assert(sig <= significand_max, text(typeof(this).stringof ~ 542 " significand too large: ",sig," > ",significand_max, 543 "\t",input,"\t",exp," ",exponent_max)); 544 exponent = cast(T_exp) exp; 545 significand = cast(T_sig) sig; 546 } 547 548 /// Fetches the stored value either as a `float`, `double` or `real`. 549 @property F get(F)() const 550 if (staticIndexOf!(immutable F, immutable float, immutable double, immutable real) >= 0) 551 { 552 import std.conv : text; 553 554 ToBinary!F result; 555 556 static if (flags&Flags.signed) 557 result.sign = sign; 558 else 559 result.sign = (flags&flags.negativeUnsigned) > 0; 560 561 CommonType!(T_signed_exp ,result.get.T_signed_exp ) exp = exponent; // Assign the exponent and fraction 562 CommonType!(T_sig, result.get.T_sig ) sig = significand; 563 564 toNormalized(sig,exp); 565 result.fromNormalized(sig,exp); 566 assert(exp <= result.exponent_max, text("get exponent too large: " ,exp," > ",result.exponent_max) ); 567 assert(sig <= result.significand_max, text("get significand too large: ",sig," > ",result.significand_max) ); 568 result.exponent = cast(result.get.T_exp) exp; 569 result.significand = cast(result.get.T_sig) sig; 570 return result.set; 571 } 572 573 ///ditto 574 alias opCast = get; 575 576 /// Convert the CustomFloat to a real and perform the relevant operator on the result 577 real opUnary(string op)() 578 if (__traits(compiles, mixin(op~`(get!real)`)) || op=="++" || op=="--") 579 { 580 static if (op=="++" || op=="--") 581 { 582 auto result = get!real; 583 this = mixin(op~`result`); 584 return result; 585 } 586 else 587 return mixin(op~`get!real`); 588 } 589 590 /// ditto 591 // Define an opBinary `CustomFloat op CustomFloat` so that those below 592 // do not match equally, which is disallowed by the spec: 593 // https://dlang.org/spec/operatoroverloading.html#binary 594 real opBinary(string op,T)(T b) const 595 if (__traits(compiles, mixin(`get!real`~op~`b.get!real`))) 596 { 597 return mixin(`get!real`~op~`b.get!real`); 598 } 599 600 /// ditto 601 real opBinary(string op,T)(T b) const 602 if ( __traits(compiles, mixin(`get!real`~op~`b`)) && 603 !__traits(compiles, mixin(`get!real`~op~`b.get!real`))) 604 { 605 return mixin(`get!real`~op~`b`); 606 } 607 608 /// ditto 609 real opBinaryRight(string op,T)(T a) const 610 if ( __traits(compiles, mixin(`a`~op~`get!real`)) && 611 !__traits(compiles, mixin(`get!real`~op~`b`)) && 612 !__traits(compiles, mixin(`get!real`~op~`b.get!real`))) 613 { 614 return mixin(`a`~op~`get!real`); 615 } 616 617 /// ditto 618 int opCmp(T)(auto ref T b) const 619 if (__traits(compiles, cast(real) b)) 620 { 621 auto x = get!real; 622 auto y = cast(real) b; 623 return (x >= y)-(x <= y); 624 } 625 626 /// ditto 627 void opOpAssign(string op, T)(auto ref T b) 628 if (__traits(compiles, mixin(`get!real`~op~`cast(real) b`))) 629 { 630 return mixin(`this = this `~op~` cast(real) b`); 631 } 632 633 /// ditto 634 template toString() 635 { 636 import std.format.spec : FormatSpec; 637 import std.format.write : formatValue; 638 // Needs to be a template because of https://issues.dlang.org/show_bug.cgi?id=13737. 639 void toString()(scope void delegate(const(char)[]) sink, scope const ref FormatSpec!char fmt) 640 { 641 sink.formatValue(get!real, fmt); 642 } 643 } 644 } 645 646 @safe unittest 647 { 648 import std.meta; 649 alias FPTypes = 650 AliasSeq!( 651 CustomFloat!(5, 10), 652 CustomFloat!(5, 11, CustomFloatFlags.ieee ^ CustomFloatFlags.signed), 653 CustomFloat!(1, 7, CustomFloatFlags.ieee ^ CustomFloatFlags.signed), 654 CustomFloat!(4, 3, CustomFloatFlags.ieee | CustomFloatFlags.probability ^ CustomFloatFlags.signed) 655 ); 656 657 foreach (F; FPTypes) 658 { 659 auto x = F(0.125); 660 assert(x.get!float == 0.125F); 661 assert(x.get!double == 0.125); 662 assert(x.get!real == 0.125L); 663 664 x -= 0.0625; 665 assert(x.get!float == 0.0625F); 666 assert(x.get!double == 0.0625); 667 assert(x.get!real == 0.0625L); 668 669 x *= 2; 670 assert(x.get!float == 0.125F); 671 assert(x.get!double == 0.125); 672 assert(x.get!real == 0.125L); 673 674 x /= 4; 675 assert(x.get!float == 0.03125); 676 assert(x.get!double == 0.03125); 677 assert(x.get!real == 0.03125L); 678 679 x = 0.5; 680 x ^^= 4; 681 assert(x.get!float == 1 / 16.0F); 682 assert(x.get!double == 1 / 16.0); 683 assert(x.get!real == 1 / 16.0L); 684 } 685 } 686 687 @system unittest 688 { 689 // @system due to to!string(CustomFloat) 690 import std.conv; 691 CustomFloat!(5, 10) y = CustomFloat!(5, 10)(0.125); 692 assert(y.to!string == "0.125"); 693 } 694 695 @safe unittest 696 { 697 alias cf = CustomFloat!(5, 2); 698 699 auto a = cf.infinity; 700 assert(a.sign == 0); 701 assert(a.exponent == 3); 702 assert(a.significand == 0); 703 704 auto b = cf.nan; 705 assert(b.exponent == 3); 706 assert(b.significand != 0); 707 708 assert(cf.dig == 1); 709 710 auto c = cf.epsilon; 711 assert(c.sign == 0); 712 assert(c.exponent == 0); 713 assert(c.significand == 1); 714 715 assert(cf.mant_dig == 6); 716 717 assert(cf.max_10_exp == 0); 718 assert(cf.max_exp == 2); 719 assert(cf.min_10_exp == 0); 720 assert(cf.min_exp == 1); 721 722 auto d = cf.max; 723 assert(d.sign == 0); 724 assert(d.exponent == 2); 725 assert(d.significand == 31); 726 727 auto e = cf.min_normal; 728 assert(e.sign == 0); 729 assert(e.exponent == 1); 730 assert(e.significand == 0); 731 732 assert(e.re == e); 733 assert(e.im == cf(0.0)); 734 } 735 736 // check whether CustomFloats identical to float/double behave like float/double 737 @safe unittest 738 { 739 import std.conv : to; 740 741 alias myFloat = CustomFloat!(23, 8); 742 743 static assert(myFloat.dig == float.dig); 744 static assert(myFloat.mant_dig == float.mant_dig); 745 assert(myFloat.max_10_exp == float.max_10_exp); 746 static assert(myFloat.max_exp == float.max_exp); 747 assert(myFloat.min_10_exp == float.min_10_exp); 748 static assert(myFloat.min_exp == float.min_exp); 749 assert(to!float(myFloat.epsilon) == float.epsilon); 750 assert(to!float(myFloat.max) == float.max); 751 assert(to!float(myFloat.min_normal) == float.min_normal); 752 753 alias myDouble = CustomFloat!(52, 11); 754 755 static assert(myDouble.dig == double.dig); 756 static assert(myDouble.mant_dig == double.mant_dig); 757 assert(myDouble.max_10_exp == double.max_10_exp); 758 static assert(myDouble.max_exp == double.max_exp); 759 assert(myDouble.min_10_exp == double.min_10_exp); 760 static assert(myDouble.min_exp == double.min_exp); 761 assert(to!double(myDouble.epsilon) == double.epsilon); 762 assert(to!double(myDouble.max) == double.max); 763 assert(to!double(myDouble.min_normal) == double.min_normal); 764 } 765 766 // testing .dig 767 @safe unittest 768 { 769 static assert(CustomFloat!(1, 6).dig == 0); 770 static assert(CustomFloat!(9, 6).dig == 2); 771 static assert(CustomFloat!(10, 5).dig == 3); 772 static assert(CustomFloat!(10, 6, CustomFloatFlags.none).dig == 2); 773 static assert(CustomFloat!(11, 5, CustomFloatFlags.none).dig == 3); 774 static assert(CustomFloat!(64, 7).dig == 19); 775 } 776 777 // testing .mant_dig 778 @safe unittest 779 { 780 static assert(CustomFloat!(10, 5).mant_dig == 11); 781 static assert(CustomFloat!(10, 6, CustomFloatFlags.none).mant_dig == 10); 782 } 783 784 // testing .max_exp 785 @safe unittest 786 { 787 static assert(CustomFloat!(1, 6).max_exp == 2^^5); 788 static assert(CustomFloat!(2, 6, CustomFloatFlags.none).max_exp == 2^^5); 789 static assert(CustomFloat!(5, 10).max_exp == 2^^9); 790 static assert(CustomFloat!(6, 10, CustomFloatFlags.none).max_exp == 2^^9); 791 static assert(CustomFloat!(2, 6, CustomFloatFlags.nan).max_exp == 2^^5); 792 static assert(CustomFloat!(6, 10, CustomFloatFlags.nan).max_exp == 2^^9); 793 } 794 795 // testing .min_exp 796 @safe unittest 797 { 798 static assert(CustomFloat!(1, 6).min_exp == -2^^5+3); 799 static assert(CustomFloat!(5, 10).min_exp == -2^^9+3); 800 static assert(CustomFloat!(2, 6, CustomFloatFlags.none).min_exp == -2^^5+1); 801 static assert(CustomFloat!(6, 10, CustomFloatFlags.none).min_exp == -2^^9+1); 802 static assert(CustomFloat!(2, 6, CustomFloatFlags.nan).min_exp == -2^^5+2); 803 static assert(CustomFloat!(6, 10, CustomFloatFlags.nan).min_exp == -2^^9+2); 804 static assert(CustomFloat!(2, 6, CustomFloatFlags.allowDenorm).min_exp == -2^^5+2); 805 static assert(CustomFloat!(6, 10, CustomFloatFlags.allowDenorm).min_exp == -2^^9+2); 806 } 807 808 // testing .max_10_exp 809 @safe unittest 810 { 811 assert(CustomFloat!(1, 6).max_10_exp == 9); 812 assert(CustomFloat!(5, 10).max_10_exp == 154); 813 assert(CustomFloat!(2, 6, CustomFloatFlags.none).max_10_exp == 9); 814 assert(CustomFloat!(6, 10, CustomFloatFlags.none).max_10_exp == 154); 815 assert(CustomFloat!(2, 6, CustomFloatFlags.nan).max_10_exp == 9); 816 assert(CustomFloat!(6, 10, CustomFloatFlags.nan).max_10_exp == 154); 817 } 818 819 // testing .min_10_exp 820 @safe unittest 821 { 822 assert(CustomFloat!(1, 6).min_10_exp == -9); 823 assert(CustomFloat!(5, 10).min_10_exp == -153); 824 assert(CustomFloat!(2, 6, CustomFloatFlags.none).min_10_exp == -9); 825 assert(CustomFloat!(6, 10, CustomFloatFlags.none).min_10_exp == -154); 826 assert(CustomFloat!(2, 6, CustomFloatFlags.nan).min_10_exp == -9); 827 assert(CustomFloat!(6, 10, CustomFloatFlags.nan).min_10_exp == -153); 828 assert(CustomFloat!(2, 6, CustomFloatFlags.allowDenorm).min_10_exp == -9); 829 assert(CustomFloat!(6, 10, CustomFloatFlags.allowDenorm).min_10_exp == -153); 830 } 831 832 // testing .epsilon 833 @safe unittest 834 { 835 assert(CustomFloat!(1,6).epsilon.sign == 0); 836 assert(CustomFloat!(1,6).epsilon.exponent == 30); 837 assert(CustomFloat!(1,6).epsilon.significand == 0); 838 assert(CustomFloat!(2,5).epsilon.sign == 0); 839 assert(CustomFloat!(2,5).epsilon.exponent == 13); 840 assert(CustomFloat!(2,5).epsilon.significand == 0); 841 assert(CustomFloat!(3,4).epsilon.sign == 0); 842 assert(CustomFloat!(3,4).epsilon.exponent == 4); 843 assert(CustomFloat!(3,4).epsilon.significand == 0); 844 // the following epsilons are only available, when denormalized numbers are allowed: 845 assert(CustomFloat!(4,3).epsilon.sign == 0); 846 assert(CustomFloat!(4,3).epsilon.exponent == 0); 847 assert(CustomFloat!(4,3).epsilon.significand == 4); 848 assert(CustomFloat!(5,2).epsilon.sign == 0); 849 assert(CustomFloat!(5,2).epsilon.exponent == 0); 850 assert(CustomFloat!(5,2).epsilon.significand == 1); 851 } 852 853 // testing .max 854 @safe unittest 855 { 856 static assert(CustomFloat!(5,2).max.sign == 0); 857 static assert(CustomFloat!(5,2).max.exponent == 2); 858 static assert(CustomFloat!(5,2).max.significand == 31); 859 static assert(CustomFloat!(4,3).max.sign == 0); 860 static assert(CustomFloat!(4,3).max.exponent == 6); 861 static assert(CustomFloat!(4,3).max.significand == 15); 862 static assert(CustomFloat!(3,4).max.sign == 0); 863 static assert(CustomFloat!(3,4).max.exponent == 14); 864 static assert(CustomFloat!(3,4).max.significand == 7); 865 static assert(CustomFloat!(2,5).max.sign == 0); 866 static assert(CustomFloat!(2,5).max.exponent == 30); 867 static assert(CustomFloat!(2,5).max.significand == 3); 868 static assert(CustomFloat!(1,6).max.sign == 0); 869 static assert(CustomFloat!(1,6).max.exponent == 62); 870 static assert(CustomFloat!(1,6).max.significand == 1); 871 static assert(CustomFloat!(3,5, CustomFloatFlags.none).max.exponent == 31); 872 static assert(CustomFloat!(3,5, CustomFloatFlags.none).max.significand == 7); 873 } 874 875 // testing .min_normal 876 @safe unittest 877 { 878 static assert(CustomFloat!(5,2).min_normal.sign == 0); 879 static assert(CustomFloat!(5,2).min_normal.exponent == 1); 880 static assert(CustomFloat!(5,2).min_normal.significand == 0); 881 static assert(CustomFloat!(4,3).min_normal.sign == 0); 882 static assert(CustomFloat!(4,3).min_normal.exponent == 1); 883 static assert(CustomFloat!(4,3).min_normal.significand == 0); 884 static assert(CustomFloat!(3,4).min_normal.sign == 0); 885 static assert(CustomFloat!(3,4).min_normal.exponent == 1); 886 static assert(CustomFloat!(3,4).min_normal.significand == 0); 887 static assert(CustomFloat!(2,5).min_normal.sign == 0); 888 static assert(CustomFloat!(2,5).min_normal.exponent == 1); 889 static assert(CustomFloat!(2,5).min_normal.significand == 0); 890 static assert(CustomFloat!(1,6).min_normal.sign == 0); 891 static assert(CustomFloat!(1,6).min_normal.exponent == 1); 892 static assert(CustomFloat!(1,6).min_normal.significand == 0); 893 static assert(CustomFloat!(3,5, CustomFloatFlags.none).min_normal.exponent == 0); 894 static assert(CustomFloat!(3,5, CustomFloatFlags.none).min_normal.significand == 4); 895 } 896 897 @safe unittest 898 { 899 import std.math.traits : isNaN; 900 901 alias cf = CustomFloat!(5, 2); 902 903 auto f = cf.nan.get!float(); 904 assert(isNaN(f)); 905 906 cf a; 907 a = real.max; 908 assert(a == cf.infinity); 909 910 a = 0.015625; 911 assert(a.exponent == 0); 912 assert(a.significand == 0); 913 914 a = 0.984375; 915 assert(a.exponent == 1); 916 assert(a.significand == 0); 917 } 918 919 @system unittest 920 { 921 import std.exception : assertThrown; 922 import core.exception : AssertError; 923 924 alias cf = CustomFloat!(3, 5, CustomFloatFlags.none); 925 926 cf a; 927 assertThrown!AssertError(a = real.max); 928 } 929 930 @system unittest 931 { 932 import std.exception : assertThrown; 933 import core.exception : AssertError; 934 935 alias cf = CustomFloat!(3, 5, CustomFloatFlags.nan); 936 937 cf a; 938 assertThrown!AssertError(a = real.max); 939 } 940 941 @system unittest 942 { 943 import std.exception : assertThrown; 944 import core.exception : AssertError; 945 946 alias cf = CustomFloat!(24, 8, CustomFloatFlags.none); 947 948 cf a; 949 assertThrown!AssertError(a = float.infinity); 950 } 951 952 @safe unittest 953 { 954 const CustomFloat!16 x = CustomFloat!16(3); 955 assert(x.get!float == 3); 956 assert(x.re.get!float == 3); 957 assert(x + x == 6); 958 assert(x + 1 == 4); 959 assert(2 + x == 5); 960 assert(x < 4); 961 } 962 963 private bool isCorrectCustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags) @safe pure nothrow @nogc 964 { 965 // Restrictions from bitfield 966 // due to CustomFloat!80 support hack precision with 64 bits is handled specially 967 auto length = (flags & flags.signed) + exponentWidth + ((precision == 64) ? 0 : precision); 968 if (length != 8 && length != 16 && length != 32 && length != 64) return false; 969 970 // mantissa needs to fit into real mantissa 971 if (precision > real.mant_dig - 1 && precision != 64) return false; 972 973 // exponent needs to fit into real exponent 974 if (1L << exponentWidth - 1 > real.max_exp) return false; 975 976 // mantissa should have at least one bit 977 if (precision == 0) return false; 978 979 // exponent should have at least one bit, in some cases two 980 if (exponentWidth <= ((flags & (flags.allowDenorm | flags.infinity | flags.nan)) != 0)) return false; 981 982 return true; 983 } 984 985 @safe pure nothrow @nogc unittest 986 { 987 assert(isCorrectCustomFloat(3,4,CustomFloatFlags.ieee)); 988 assert(isCorrectCustomFloat(3,5,CustomFloatFlags.none)); 989 assert(!isCorrectCustomFloat(3,3,CustomFloatFlags.ieee)); 990 assert(isCorrectCustomFloat(64,7,CustomFloatFlags.ieee)); 991 assert(!isCorrectCustomFloat(64,4,CustomFloatFlags.ieee)); 992 assert(!isCorrectCustomFloat(508,3,CustomFloatFlags.ieee)); 993 assert(!isCorrectCustomFloat(3,100,CustomFloatFlags.ieee)); 994 assert(!isCorrectCustomFloat(0,7,CustomFloatFlags.ieee)); 995 assert(!isCorrectCustomFloat(6,1,CustomFloatFlags.ieee)); 996 assert(isCorrectCustomFloat(7,1,CustomFloatFlags.none)); 997 assert(!isCorrectCustomFloat(8,0,CustomFloatFlags.none)); 998 } 999 1000 /** 1001 Defines the fastest type to use when storing temporaries of a 1002 calculation intended to ultimately yield a result of type `F` 1003 (where `F` must be one of `float`, `double`, or $(D 1004 real)). When doing a multi-step computation, you may want to store 1005 intermediate results as `FPTemporary!F`. 1006 1007 The necessity of `FPTemporary` stems from the optimized 1008 floating-point operations and registers present in virtually all 1009 processors. When adding numbers in the example above, the addition may 1010 in fact be done in `real` precision internally. In that case, 1011 storing the intermediate `result` in $(D double format) is not only 1012 less precise, it is also (surprisingly) slower, because a conversion 1013 from `real` to `double` is performed every pass through the 1014 loop. This being a lose-lose situation, `FPTemporary!F` has been 1015 defined as the $(I fastest) type to use for calculations at precision 1016 `F`. There is no need to define a type for the $(I most accurate) 1017 calculations, as that is always `real`. 1018 1019 Finally, there is no guarantee that using `FPTemporary!F` will 1020 always be fastest, as the speed of floating-point calculations depends 1021 on very many factors. 1022 */ 1023 template FPTemporary(F) 1024 if (isFloatingPoint!F) 1025 { 1026 version (X86) 1027 alias FPTemporary = real; 1028 else 1029 alias FPTemporary = Unqual!F; 1030 } 1031 1032 /// 1033 @safe unittest 1034 { 1035 import std.math.operations : isClose; 1036 1037 // Average numbers in an array 1038 double avg(in double[] a) 1039 { 1040 if (a.length == 0) return 0; 1041 FPTemporary!double result = 0; 1042 foreach (e; a) result += e; 1043 return result / a.length; 1044 } 1045 1046 auto a = [1.0, 2.0, 3.0]; 1047 assert(isClose(avg(a), 2)); 1048 } 1049 1050 /** 1051 Implements the $(HTTP tinyurl.com/2zb9yr, secant method) for finding a 1052 root of the function `fun` starting from points $(D [xn_1, x_n]) 1053 (ideally close to the root). `Num` may be `float`, `double`, 1054 or `real`. 1055 */ 1056 template secantMethod(alias fun) 1057 { 1058 import std.functional : unaryFun; 1059 Num secantMethod(Num)(Num xn_1, Num xn) 1060 { 1061 auto fxn = unaryFun!(fun)(xn_1), d = xn_1 - xn; 1062 typeof(fxn) fxn_1; 1063 1064 xn = xn_1; 1065 while (!isClose(d, 0, 0.0, 1e-5) && isFinite(d)) 1066 { 1067 xn_1 = xn; 1068 xn -= d; 1069 fxn_1 = fxn; 1070 fxn = unaryFun!(fun)(xn); 1071 d *= -fxn / (fxn - fxn_1); 1072 } 1073 return xn; 1074 } 1075 } 1076 1077 /// 1078 @safe unittest 1079 { 1080 import std.math.operations : isClose; 1081 import std.math.trigonometry : cos; 1082 1083 float f(float x) 1084 { 1085 return cos(x) - x*x*x; 1086 } 1087 auto x = secantMethod!(f)(0f, 1f); 1088 assert(isClose(x, 0.865474)); 1089 } 1090 1091 @system unittest 1092 { 1093 // @system because of __gshared stderr 1094 import std.stdio; 1095 scope(failure) stderr.writeln("Failure testing secantMethod"); 1096 float f(float x) 1097 { 1098 return cos(x) - x*x*x; 1099 } 1100 immutable x = secantMethod!(f)(0f, 1f); 1101 assert(isClose(x, 0.865474)); 1102 auto d = &f; 1103 immutable y = secantMethod!(d)(0f, 1f); 1104 assert(isClose(y, 0.865474)); 1105 } 1106 1107 1108 /** 1109 * Return true if a and b have opposite sign. 1110 */ 1111 private bool oppositeSigns(T1, T2)(T1 a, T2 b) 1112 { 1113 return signbit(a) != signbit(b); 1114 } 1115 1116 public: 1117 1118 /** Find a real root of a real function f(x) via bracketing. 1119 * 1120 * Given a function `f` and a range `[a .. b]` such that `f(a)` 1121 * and `f(b)` have opposite signs or at least one of them equals ±0, 1122 * returns the value of `x` in 1123 * the range which is closest to a root of `f(x)`. If `f(x)` 1124 * has more than one root in the range, one will be chosen 1125 * arbitrarily. If `f(x)` returns NaN, NaN will be returned; 1126 * otherwise, this algorithm is guaranteed to succeed. 1127 * 1128 * Uses an algorithm based on TOMS748, which uses inverse cubic 1129 * interpolation whenever possible, otherwise reverting to parabolic 1130 * or secant interpolation. Compared to TOMS748, this implementation 1131 * improves worst-case performance by a factor of more than 100, and 1132 * typical performance by a factor of 2. For 80-bit reals, most 1133 * problems require 8 to 15 calls to `f(x)` to achieve full machine 1134 * precision. The worst-case performance (pathological cases) is 1135 * approximately twice the number of bits. 1136 * 1137 * References: "On Enclosing Simple Roots of Nonlinear Equations", 1138 * G. Alefeld, F.A. Potra, Yixun Shi, Mathematics of Computation 61, 1139 * pp733-744 (1993). Fortran code available from 1140 * $(HTTP www.netlib.org,www.netlib.org) as algorithm TOMS478. 1141 * 1142 */ 1143 T findRoot(T, DF, DT)(scope DF f, const T a, const T b, 1144 scope DT tolerance) //= (T a, T b) => false) 1145 if ( 1146 isFloatingPoint!T && 1147 is(typeof(tolerance(T.init, T.init)) : bool) && 1148 is(typeof(f(T.init)) == R, R) && isFloatingPoint!R 1149 ) 1150 { 1151 immutable fa = f(a); 1152 if (fa == 0) 1153 return a; 1154 immutable fb = f(b); 1155 if (fb == 0) 1156 return b; 1157 immutable r = findRoot(f, a, b, fa, fb, tolerance); 1158 // Return the first value if it is smaller or NaN 1159 return !(fabs(r[2]) > fabs(r[3])) ? r[0] : r[1]; 1160 } 1161 1162 ///ditto 1163 T findRoot(T, DF)(scope DF f, const T a, const T b) 1164 { 1165 return findRoot(f, a, b, (T a, T b) => false); 1166 } 1167 1168 /** Find root of a real function f(x) by bracketing, allowing the 1169 * termination condition to be specified. 1170 * 1171 * Params: 1172 * 1173 * f = Function to be analyzed 1174 * 1175 * ax = Left bound of initial range of `f` known to contain the 1176 * root. 1177 * 1178 * bx = Right bound of initial range of `f` known to contain the 1179 * root. 1180 * 1181 * fax = Value of `f(ax)`. 1182 * 1183 * fbx = Value of `f(bx)`. `fax` and `fbx` should have opposite signs. 1184 * (`f(ax)` and `f(bx)` are commonly known in advance.) 1185 * 1186 * 1187 * tolerance = Defines an early termination condition. Receives the 1188 * current upper and lower bounds on the root. The 1189 * delegate must return `true` when these bounds are 1190 * acceptable. If this function always returns `false`, 1191 * full machine precision will be achieved. 1192 * 1193 * Returns: 1194 * 1195 * A tuple consisting of two ranges. The first two elements are the 1196 * range (in `x`) of the root, while the second pair of elements 1197 * are the corresponding function values at those points. If an exact 1198 * root was found, both of the first two elements will contain the 1199 * root, and the second pair of elements will be 0. 1200 */ 1201 Tuple!(T, T, R, R) findRoot(T, R, DF, DT)(scope DF f, 1202 const T ax, const T bx, const R fax, const R fbx, 1203 scope DT tolerance) // = (T a, T b) => false) 1204 if ( 1205 isFloatingPoint!T && 1206 is(typeof(tolerance(T.init, T.init)) : bool) && 1207 is(typeof(f(T.init)) == R) && isFloatingPoint!R 1208 ) 1209 in 1210 { 1211 assert(!ax.isNaN() && !bx.isNaN(), "Limits must not be NaN"); 1212 assert(signbit(fax) != signbit(fbx), "Parameters must bracket the root."); 1213 } 1214 do 1215 { 1216 // Author: Don Clugston. This code is (heavily) modified from TOMS748 1217 // (www.netlib.org). The changes to improve the worst-cast performance are 1218 // entirely original. 1219 1220 T a, b, d; // [a .. b] is our current bracket. d is the third best guess. 1221 R fa, fb, fd; // Values of f at a, b, d. 1222 bool done = false; // Has a root been found? 1223 1224 // Allow ax and bx to be provided in reverse order 1225 if (ax <= bx) 1226 { 1227 a = ax; fa = fax; 1228 b = bx; fb = fbx; 1229 } 1230 else 1231 { 1232 a = bx; fa = fbx; 1233 b = ax; fb = fax; 1234 } 1235 1236 // Test the function at point c; update brackets accordingly 1237 void bracket(T c) 1238 { 1239 R fc = f(c); 1240 if (fc == 0 || fc.isNaN()) // Exact solution, or NaN 1241 { 1242 a = c; 1243 fa = fc; 1244 d = c; 1245 fd = fc; 1246 done = true; 1247 return; 1248 } 1249 1250 // Determine new enclosing interval 1251 if (signbit(fa) != signbit(fc)) 1252 { 1253 d = b; 1254 fd = fb; 1255 b = c; 1256 fb = fc; 1257 } 1258 else 1259 { 1260 d = a; 1261 fd = fa; 1262 a = c; 1263 fa = fc; 1264 } 1265 } 1266 1267 /* Perform a secant interpolation. If the result would lie on a or b, or if 1268 a and b differ so wildly in magnitude that the result would be meaningless, 1269 perform a bisection instead. 1270 */ 1271 static T secant_interpolate(T a, T b, R fa, R fb) 1272 { 1273 if (( ((a - b) == a) && b != 0) || (a != 0 && ((b - a) == b))) 1274 { 1275 // Catastrophic cancellation 1276 if (a == 0) 1277 a = copysign(T(0), b); 1278 else if (b == 0) 1279 b = copysign(T(0), a); 1280 else if (signbit(a) != signbit(b)) 1281 return 0; 1282 T c = ieeeMean(a, b); 1283 return c; 1284 } 1285 // avoid overflow 1286 if (b - a > T.max) 1287 return b / 2 + a / 2; 1288 if (fb - fa > R.max) 1289 return a - (b - a) / 2; 1290 T c = a - (fa / (fb - fa)) * (b - a); 1291 if (c == a || c == b) 1292 return (a + b) / 2; 1293 return c; 1294 } 1295 1296 /* Uses 'numsteps' newton steps to approximate the zero in [a .. b] of the 1297 quadratic polynomial interpolating f(x) at a, b, and d. 1298 Returns: 1299 The approximate zero in [a .. b] of the quadratic polynomial. 1300 */ 1301 T newtonQuadratic(int numsteps) 1302 { 1303 // Find the coefficients of the quadratic polynomial. 1304 immutable T a0 = fa; 1305 immutable T a1 = (fb - fa)/(b - a); 1306 immutable T a2 = ((fd - fb)/(d - b) - a1)/(d - a); 1307 1308 // Determine the starting point of newton steps. 1309 T c = oppositeSigns(a2, fa) ? a : b; 1310 1311 // start the safeguarded newton steps. 1312 foreach (int i; 0 .. numsteps) 1313 { 1314 immutable T pc = a0 + (a1 + a2 * (c - b))*(c - a); 1315 immutable T pdc = a1 + a2*((2 * c) - (a + b)); 1316 if (pdc == 0) 1317 return a - a0 / a1; 1318 else 1319 c = c - pc / pdc; 1320 } 1321 return c; 1322 } 1323 1324 // On the first iteration we take a secant step: 1325 if (fa == 0 || fa.isNaN()) 1326 { 1327 done = true; 1328 b = a; 1329 fb = fa; 1330 } 1331 else if (fb == 0 || fb.isNaN()) 1332 { 1333 done = true; 1334 a = b; 1335 fa = fb; 1336 } 1337 else 1338 { 1339 bracket(secant_interpolate(a, b, fa, fb)); 1340 } 1341 1342 // Starting with the second iteration, higher-order interpolation can 1343 // be used. 1344 int itnum = 1; // Iteration number 1345 int baditer = 1; // Num bisections to take if an iteration is bad. 1346 T c, e; // e is our fourth best guess 1347 R fe; 1348 1349 whileloop: 1350 while (!done && (b != nextUp(a)) && !tolerance(a, b)) 1351 { 1352 T a0 = a, b0 = b; // record the brackets 1353 1354 // Do two higher-order (cubic or parabolic) interpolation steps. 1355 foreach (int QQ; 0 .. 2) 1356 { 1357 // Cubic inverse interpolation requires that 1358 // all four function values fa, fb, fd, and fe are distinct; 1359 // otherwise use quadratic interpolation. 1360 bool distinct = (fa != fb) && (fa != fd) && (fa != fe) 1361 && (fb != fd) && (fb != fe) && (fd != fe); 1362 // The first time, cubic interpolation is impossible. 1363 if (itnum<2) distinct = false; 1364 bool ok = distinct; 1365 if (distinct) 1366 { 1367 // Cubic inverse interpolation of f(x) at a, b, d, and e 1368 immutable q11 = (d - e) * fd / (fe - fd); 1369 immutable q21 = (b - d) * fb / (fd - fb); 1370 immutable q31 = (a - b) * fa / (fb - fa); 1371 immutable d21 = (b - d) * fd / (fd - fb); 1372 immutable d31 = (a - b) * fb / (fb - fa); 1373 1374 immutable q22 = (d21 - q11) * fb / (fe - fb); 1375 immutable q32 = (d31 - q21) * fa / (fd - fa); 1376 immutable d32 = (d31 - q21) * fd / (fd - fa); 1377 immutable q33 = (d32 - q22) * fa / (fe - fa); 1378 c = a + (q31 + q32 + q33); 1379 if (c.isNaN() || (c <= a) || (c >= b)) 1380 { 1381 // DAC: If the interpolation predicts a or b, it's 1382 // probable that it's the actual root. Only allow this if 1383 // we're already close to the root. 1384 if (c == a && a - b != a) 1385 { 1386 c = nextUp(a); 1387 } 1388 else if (c == b && a - b != -b) 1389 { 1390 c = nextDown(b); 1391 } 1392 else 1393 { 1394 ok = false; 1395 } 1396 } 1397 } 1398 if (!ok) 1399 { 1400 // DAC: Alefeld doesn't explain why the number of newton steps 1401 // should vary. 1402 c = newtonQuadratic(distinct ? 3 : 2); 1403 if (c.isNaN() || (c <= a) || (c >= b)) 1404 { 1405 // Failure, try a secant step: 1406 c = secant_interpolate(a, b, fa, fb); 1407 } 1408 } 1409 ++itnum; 1410 e = d; 1411 fe = fd; 1412 bracket(c); 1413 if (done || ( b == nextUp(a)) || tolerance(a, b)) 1414 break whileloop; 1415 if (itnum == 2) 1416 continue whileloop; 1417 } 1418 1419 // Now we take a double-length secant step: 1420 T u; 1421 R fu; 1422 if (fabs(fa) < fabs(fb)) 1423 { 1424 u = a; 1425 fu = fa; 1426 } 1427 else 1428 { 1429 u = b; 1430 fu = fb; 1431 } 1432 c = u - 2 * (fu / (fb - fa)) * (b - a); 1433 1434 // DAC: If the secant predicts a value equal to an endpoint, it's 1435 // probably false. 1436 if (c == a || c == b || c.isNaN() || fabs(c - u) > (b - a) / 2) 1437 { 1438 if ((a-b) == a || (b-a) == b) 1439 { 1440 if ((a>0 && b<0) || (a<0 && b>0)) 1441 c = 0; 1442 else 1443 { 1444 if (a == 0) 1445 c = ieeeMean(copysign(T(0), b), b); 1446 else if (b == 0) 1447 c = ieeeMean(copysign(T(0), a), a); 1448 else 1449 c = ieeeMean(a, b); 1450 } 1451 } 1452 else 1453 { 1454 c = a + (b - a) / 2; 1455 } 1456 } 1457 e = d; 1458 fe = fd; 1459 bracket(c); 1460 if (done || (b == nextUp(a)) || tolerance(a, b)) 1461 break; 1462 1463 // IMPROVE THE WORST-CASE PERFORMANCE 1464 // We must ensure that the bounds reduce by a factor of 2 1465 // in binary space! every iteration. If we haven't achieved this 1466 // yet, or if we don't yet know what the exponent is, 1467 // perform a binary chop. 1468 1469 if ((a == 0 || b == 0 || 1470 (fabs(a) >= T(0.5) * fabs(b) && fabs(b) >= T(0.5) * fabs(a))) 1471 && (b - a) < T(0.25) * (b0 - a0)) 1472 { 1473 baditer = 1; 1474 continue; 1475 } 1476 1477 // DAC: If this happens on consecutive iterations, we probably have a 1478 // pathological function. Perform a number of bisections equal to the 1479 // total number of consecutive bad iterations. 1480 1481 if ((b - a) < T(0.25) * (b0 - a0)) 1482 baditer = 1; 1483 foreach (int QQ; 0 .. baditer) 1484 { 1485 e = d; 1486 fe = fd; 1487 1488 T w; 1489 if ((a>0 && b<0) || (a<0 && b>0)) 1490 w = 0; 1491 else 1492 { 1493 T usea = a; 1494 T useb = b; 1495 if (a == 0) 1496 usea = copysign(T(0), b); 1497 else if (b == 0) 1498 useb = copysign(T(0), a); 1499 w = ieeeMean(usea, useb); 1500 } 1501 bracket(w); 1502 } 1503 ++baditer; 1504 } 1505 return Tuple!(T, T, R, R)(a, b, fa, fb); 1506 } 1507 1508 ///ditto 1509 Tuple!(T, T, R, R) findRoot(T, R, DF)(scope DF f, 1510 const T ax, const T bx, const R fax, const R fbx) 1511 { 1512 return findRoot(f, ax, bx, fax, fbx, (T a, T b) => false); 1513 } 1514 1515 ///ditto 1516 T findRoot(T, R)(scope R delegate(T) f, const T a, const T b, 1517 scope bool delegate(T lo, T hi) tolerance = (T a, T b) => false) 1518 { 1519 return findRoot!(T, R delegate(T), bool delegate(T lo, T hi))(f, a, b, tolerance); 1520 } 1521 1522 @safe nothrow unittest 1523 { 1524 int numProblems = 0; 1525 int numCalls; 1526 1527 void testFindRoot(real delegate(real) @nogc @safe nothrow pure f , real x1, real x2) @nogc @safe nothrow pure 1528 { 1529 //numCalls=0; 1530 //++numProblems; 1531 assert(!x1.isNaN() && !x2.isNaN()); 1532 assert(signbit(f(x1)) != signbit(f(x2))); 1533 auto result = findRoot(f, x1, x2, f(x1), f(x2), 1534 (real lo, real hi) { return false; }); 1535 1536 auto flo = f(result[0]); 1537 auto fhi = f(result[1]); 1538 if (flo != 0) 1539 { 1540 assert(oppositeSigns(flo, fhi)); 1541 } 1542 } 1543 1544 // Test functions 1545 real cubicfn(real x) @nogc @safe nothrow pure 1546 { 1547 //++numCalls; 1548 if (x>float.max) 1549 x = float.max; 1550 if (x<-float.max) 1551 x = -float.max; 1552 // This has a single real root at -59.286543284815 1553 return 0.386*x*x*x + 23*x*x + 15.7*x + 525.2; 1554 } 1555 // Test a function with more than one root. 1556 real multisine(real x) { ++numCalls; return sin(x); } 1557 testFindRoot( &multisine, 6, 90); 1558 testFindRoot(&cubicfn, -100, 100); 1559 testFindRoot( &cubicfn, -double.max, real.max); 1560 1561 1562 /* Tests from the paper: 1563 * "On Enclosing Simple Roots of Nonlinear Equations", G. Alefeld, F.A. Potra, 1564 * Yixun Shi, Mathematics of Computation 61, pp733-744 (1993). 1565 */ 1566 // Parameters common to many alefeld tests. 1567 int n; 1568 real ale_a, ale_b; 1569 1570 int powercalls = 0; 1571 1572 real power(real x) 1573 { 1574 ++powercalls; 1575 ++numCalls; 1576 return pow(x, n) + double.min_normal; 1577 } 1578 int [] power_nvals = [3, 5, 7, 9, 19, 25]; 1579 // Alefeld paper states that pow(x,n) is a very poor case, where bisection 1580 // outperforms his method, and gives total numcalls = 1581 // 921 for bisection (2.4 calls per bit), 1830 for Alefeld (4.76/bit), 1582 // 2624 for brent (6.8/bit) 1583 // ... but that is for double, not real80. 1584 // This poor performance seems mainly due to catastrophic cancellation, 1585 // which is avoided here by the use of ieeeMean(). 1586 // I get: 231 (0.48/bit). 1587 // IE this is 10X faster in Alefeld's worst case 1588 numProblems=0; 1589 foreach (k; power_nvals) 1590 { 1591 n = k; 1592 testFindRoot(&power, -1, 10); 1593 } 1594 1595 int powerProblems = numProblems; 1596 1597 // Tests from Alefeld paper 1598 1599 int [9] alefeldSums; 1600 real alefeld0(real x) 1601 { 1602 ++alefeldSums[0]; 1603 ++numCalls; 1604 real q = sin(x) - x/2; 1605 for (int i=1; i<20; ++i) 1606 q+=(2*i-5.0)*(2*i-5.0)/((x-i*i)*(x-i*i)*(x-i*i)); 1607 return q; 1608 } 1609 real alefeld1(real x) 1610 { 1611 ++numCalls; 1612 ++alefeldSums[1]; 1613 return ale_a*x + exp(ale_b * x); 1614 } 1615 real alefeld2(real x) 1616 { 1617 ++numCalls; 1618 ++alefeldSums[2]; 1619 return pow(x, n) - ale_a; 1620 } 1621 real alefeld3(real x) 1622 { 1623 ++numCalls; 1624 ++alefeldSums[3]; 1625 return (1.0 +pow(1.0L-n, 2))*x - pow(1.0L-n*x, 2); 1626 } 1627 real alefeld4(real x) 1628 { 1629 ++numCalls; 1630 ++alefeldSums[4]; 1631 return x*x - pow(1-x, n); 1632 } 1633 real alefeld5(real x) 1634 { 1635 ++numCalls; 1636 ++alefeldSums[5]; 1637 return (1+pow(1.0L-n, 4))*x - pow(1.0L-n*x, 4); 1638 } 1639 real alefeld6(real x) 1640 { 1641 ++numCalls; 1642 ++alefeldSums[6]; 1643 return exp(-n*x)*(x-1.01L) + pow(x, n); 1644 } 1645 real alefeld7(real x) 1646 { 1647 ++numCalls; 1648 ++alefeldSums[7]; 1649 return (n*x-1)/((n-1)*x); 1650 } 1651 1652 numProblems=0; 1653 testFindRoot(&alefeld0, PI_2, PI); 1654 for (n=1; n <= 10; ++n) 1655 { 1656 testFindRoot(&alefeld0, n*n+1e-9L, (n+1)*(n+1)-1e-9L); 1657 } 1658 ale_a = -40; ale_b = -1; 1659 testFindRoot(&alefeld1, -9, 31); 1660 ale_a = -100; ale_b = -2; 1661 testFindRoot(&alefeld1, -9, 31); 1662 ale_a = -200; ale_b = -3; 1663 testFindRoot(&alefeld1, -9, 31); 1664 int [] nvals_3 = [1, 2, 5, 10, 15, 20]; 1665 int [] nvals_5 = [1, 2, 4, 5, 8, 15, 20]; 1666 int [] nvals_6 = [1, 5, 10, 15, 20]; 1667 int [] nvals_7 = [2, 5, 15, 20]; 1668 1669 for (int i=4; i<12; i+=2) 1670 { 1671 n = i; 1672 ale_a = 0.2; 1673 testFindRoot(&alefeld2, 0, 5); 1674 ale_a=1; 1675 testFindRoot(&alefeld2, 0.95, 4.05); 1676 testFindRoot(&alefeld2, 0, 1.5); 1677 } 1678 foreach (i; nvals_3) 1679 { 1680 n=i; 1681 testFindRoot(&alefeld3, 0, 1); 1682 } 1683 foreach (i; nvals_3) 1684 { 1685 n=i; 1686 testFindRoot(&alefeld4, 0, 1); 1687 } 1688 foreach (i; nvals_5) 1689 { 1690 n=i; 1691 testFindRoot(&alefeld5, 0, 1); 1692 } 1693 foreach (i; nvals_6) 1694 { 1695 n=i; 1696 testFindRoot(&alefeld6, 0, 1); 1697 } 1698 foreach (i; nvals_7) 1699 { 1700 n=i; 1701 testFindRoot(&alefeld7, 0.01L, 1); 1702 } 1703 real worstcase(real x) 1704 { 1705 ++numCalls; 1706 return x<0.3*real.max? -0.999e-3 : 1.0; 1707 } 1708 testFindRoot(&worstcase, -real.max, real.max); 1709 1710 // just check that the double + float cases compile 1711 findRoot((double x){ return 0.0; }, -double.max, double.max); 1712 findRoot((float x){ return 0.0f; }, -float.max, float.max); 1713 1714 /* 1715 int grandtotal=0; 1716 foreach (calls; alefeldSums) 1717 { 1718 grandtotal+=calls; 1719 } 1720 grandtotal-=2*numProblems; 1721 printf("\nALEFELD TOTAL = %d avg = %f (alefeld avg=19.3 for double)\n", 1722 grandtotal, (1.0*grandtotal)/numProblems); 1723 powercalls -= 2*powerProblems; 1724 printf("POWER TOTAL = %d avg = %f ", powercalls, 1725 (1.0*powercalls)/powerProblems); 1726 */ 1727 // https://issues.dlang.org/show_bug.cgi?id=14231 1728 auto xp = findRoot((float x) => x, 0f, 1f); 1729 auto xn = findRoot((float x) => x, -1f, -0f); 1730 } 1731 1732 //regression control 1733 @system unittest 1734 { 1735 // @system due to the case in the 2nd line 1736 static assert(__traits(compiles, findRoot((float x)=>cast(real) x, float.init, float.init))); 1737 static assert(__traits(compiles, findRoot!real((x)=>cast(double) x, real.init, real.init))); 1738 static assert(__traits(compiles, findRoot((real x)=>cast(double) x, real.init, real.init))); 1739 } 1740 1741 /++ 1742 Find a real minimum of a real function `f(x)` via bracketing. 1743 Given a function `f` and a range `(ax .. bx)`, 1744 returns the value of `x` in the range which is closest to a minimum of `f(x)`. 1745 `f` is never evaluted at the endpoints of `ax` and `bx`. 1746 If `f(x)` has more than one minimum in the range, one will be chosen arbitrarily. 1747 If `f(x)` returns NaN or -Infinity, `(x, f(x), NaN)` will be returned; 1748 otherwise, this algorithm is guaranteed to succeed. 1749 1750 Params: 1751 f = Function to be analyzed 1752 ax = Left bound of initial range of f known to contain the minimum. 1753 bx = Right bound of initial range of f known to contain the minimum. 1754 relTolerance = Relative tolerance. 1755 absTolerance = Absolute tolerance. 1756 1757 Preconditions: 1758 `ax` and `bx` shall be finite reals. $(BR) 1759 `relTolerance` shall be normal positive real. $(BR) 1760 `absTolerance` shall be normal positive real no less then `T.epsilon*2`. 1761 1762 Returns: 1763 A tuple consisting of `x`, `y = f(x)` and `error = 3 * (absTolerance * fabs(x) + relTolerance)`. 1764 1765 The method used is a combination of golden section search and 1766 successive parabolic interpolation. Convergence is never much slower 1767 than that for a Fibonacci search. 1768 1769 References: 1770 "Algorithms for Minimization without Derivatives", Richard Brent, Prentice-Hall, Inc. (1973) 1771 1772 See_Also: $(LREF findRoot), $(REF isNormal, std,math) 1773 +/ 1774 Tuple!(T, "x", Unqual!(ReturnType!DF), "y", T, "error") 1775 findLocalMin(T, DF)( 1776 scope DF f, 1777 const T ax, 1778 const T bx, 1779 const T relTolerance = sqrt(T.epsilon), 1780 const T absTolerance = sqrt(T.epsilon), 1781 ) 1782 if (isFloatingPoint!T 1783 && __traits(compiles, {T _ = DF.init(T.init);})) 1784 in 1785 { 1786 assert(isFinite(ax), "ax is not finite"); 1787 assert(isFinite(bx), "bx is not finite"); 1788 assert(isNormal(relTolerance), "relTolerance is not normal floating point number"); 1789 assert(isNormal(absTolerance), "absTolerance is not normal floating point number"); 1790 assert(relTolerance >= 0, "absTolerance is not positive"); 1791 assert(absTolerance >= T.epsilon*2, "absTolerance is not greater then `2*T.epsilon`"); 1792 } 1793 out (result) 1794 { 1795 assert(isFinite(result.x)); 1796 } 1797 do 1798 { 1799 alias R = Unqual!(CommonType!(ReturnType!DF, T)); 1800 // c is the squared inverse of the golden ratio 1801 // (3 - sqrt(5))/2 1802 // Value obtained from Wolfram Alpha. 1803 enum T c = 0x0.61c8864680b583ea0c633f9fa31237p+0L; 1804 enum T cm1 = 0x0.9e3779b97f4a7c15f39cc0605cedc8p+0L; 1805 R tolerance; 1806 T a = ax > bx ? bx : ax; 1807 T b = ax > bx ? ax : bx; 1808 // sequence of declarations suitable for SIMD instructions 1809 T v = a * cm1 + b * c; 1810 assert(isFinite(v)); 1811 R fv = f(v); 1812 if (isNaN(fv) || fv == -T.infinity) 1813 { 1814 return typeof(return)(v, fv, T.init); 1815 } 1816 T w = v; 1817 R fw = fv; 1818 T x = v; 1819 R fx = fv; 1820 size_t i; 1821 for (R d = 0, e = 0;;) 1822 { 1823 i++; 1824 T m = (a + b) / 2; 1825 // This fix is not part of the original algorithm 1826 if (!isFinite(m)) // fix infinity loop. Issue can be reproduced in R. 1827 { 1828 m = a / 2 + b / 2; 1829 if (!isFinite(m)) // fast-math compiler switch is enabled 1830 { 1831 //SIMD instructions can be used by compiler, do not reduce declarations 1832 int a_exp = void; 1833 int b_exp = void; 1834 immutable an = frexp(a, a_exp); 1835 immutable bn = frexp(b, b_exp); 1836 immutable am = ldexp(an, a_exp-1); 1837 immutable bm = ldexp(bn, b_exp-1); 1838 m = am + bm; 1839 if (!isFinite(m)) // wrong input: constraints are disabled in release mode 1840 { 1841 return typeof(return).init; 1842 } 1843 } 1844 } 1845 tolerance = absTolerance * fabs(x) + relTolerance; 1846 immutable t2 = tolerance * 2; 1847 // check stopping criterion 1848 if (!(fabs(x - m) > t2 - (b - a) / 2)) 1849 { 1850 break; 1851 } 1852 R p = 0; 1853 R q = 0; 1854 R r = 0; 1855 // fit parabola 1856 if (fabs(e) > tolerance) 1857 { 1858 immutable xw = x - w; 1859 immutable fxw = fx - fw; 1860 immutable xv = x - v; 1861 immutable fxv = fx - fv; 1862 immutable xwfxv = xw * fxv; 1863 immutable xvfxw = xv * fxw; 1864 p = xv * xvfxw - xw * xwfxv; 1865 q = (xvfxw - xwfxv) * 2; 1866 if (q > 0) 1867 p = -p; 1868 else 1869 q = -q; 1870 r = e; 1871 e = d; 1872 } 1873 T u; 1874 // a parabolic-interpolation step 1875 if (fabs(p) < fabs(q * r / 2) && p > q * (a - x) && p < q * (b - x)) 1876 { 1877 d = p / q; 1878 u = x + d; 1879 // f must not be evaluated too close to a or b 1880 if (u - a < t2 || b - u < t2) 1881 d = x < m ? tolerance : -tolerance; 1882 } 1883 // a golden-section step 1884 else 1885 { 1886 e = (x < m ? b : a) - x; 1887 d = c * e; 1888 } 1889 // f must not be evaluated too close to x 1890 u = x + (fabs(d) >= tolerance ? d : d > 0 ? tolerance : -tolerance); 1891 immutable fu = f(u); 1892 if (isNaN(fu) || fu == -T.infinity) 1893 { 1894 return typeof(return)(u, fu, T.init); 1895 } 1896 // update a, b, v, w, and x 1897 if (fu <= fx) 1898 { 1899 (u < x ? b : a) = x; 1900 v = w; fv = fw; 1901 w = x; fw = fx; 1902 x = u; fx = fu; 1903 } 1904 else 1905 { 1906 (u < x ? a : b) = u; 1907 if (fu <= fw || w == x) 1908 { 1909 v = w; fv = fw; 1910 w = u; fw = fu; 1911 } 1912 else if (fu <= fv || v == x || v == w) 1913 { // do not remove this braces 1914 v = u; fv = fu; 1915 } 1916 } 1917 } 1918 return typeof(return)(x, fx, tolerance * 3); 1919 } 1920 1921 /// 1922 @safe unittest 1923 { 1924 import std.math.operations : isClose; 1925 1926 auto ret = findLocalMin((double x) => (x-4)^^2, -1e7, 1e7); 1927 assert(ret.x.isClose(4.0)); 1928 assert(ret.y.isClose(0.0, 0.0, 1e-10)); 1929 } 1930 1931 @safe unittest 1932 { 1933 import std.meta : AliasSeq; 1934 static foreach (T; AliasSeq!(double, float, real)) 1935 { 1936 { 1937 auto ret = findLocalMin!T((T x) => (x-4)^^2, T.min_normal, 1e7); 1938 assert(ret.x.isClose(T(4))); 1939 assert(ret.y.isClose(T(0), 0.0, T.epsilon)); 1940 } 1941 { 1942 auto ret = findLocalMin!T((T x) => fabs(x-1), -T.max/4, T.max/4, T.min_normal, 2*T.epsilon); 1943 assert(isClose(ret.x, T(1))); 1944 assert(isClose(ret.y, T(0), 0.0, T.epsilon)); 1945 assert(ret.error <= 10 * T.epsilon); 1946 } 1947 { 1948 auto ret = findLocalMin!T((T x) => T.init, 0, 1, T.min_normal, 2*T.epsilon); 1949 assert(!ret.x.isNaN); 1950 assert(ret.y.isNaN); 1951 assert(ret.error.isNaN); 1952 } 1953 { 1954 auto ret = findLocalMin!T((T x) => log(x), 0, 1, T.min_normal, 2*T.epsilon); 1955 assert(ret.error < 3.00001 * ((2*T.epsilon)*fabs(ret.x)+ T.min_normal)); 1956 assert(ret.x >= 0 && ret.x <= ret.error); 1957 } 1958 { 1959 auto ret = findLocalMin!T((T x) => log(x), 0, T.max, T.min_normal, 2*T.epsilon); 1960 assert(ret.y < -18); 1961 assert(ret.error < 5e-08); 1962 assert(ret.x >= 0 && ret.x <= ret.error); 1963 } 1964 { 1965 auto ret = findLocalMin!T((T x) => -fabs(x), -1, 1, T.min_normal, 2*T.epsilon); 1966 assert(ret.x.fabs.isClose(T(1))); 1967 assert(ret.y.fabs.isClose(T(1))); 1968 assert(ret.error.isClose(T(0), 0.0, 100*T.epsilon)); 1969 } 1970 } 1971 } 1972 1973 /** 1974 Computes $(LINK2 https://en.wikipedia.org/wiki/Euclidean_distance, 1975 Euclidean distance) between input ranges `a` and 1976 `b`. The two ranges must have the same length. The three-parameter 1977 version stops computation as soon as the distance is greater than or 1978 equal to `limit` (this is useful to save computation if a small 1979 distance is sought). 1980 */ 1981 CommonType!(ElementType!(Range1), ElementType!(Range2)) 1982 euclideanDistance(Range1, Range2)(Range1 a, Range2 b) 1983 if (isInputRange!(Range1) && isInputRange!(Range2)) 1984 { 1985 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 1986 static if (haveLen) assert(a.length == b.length); 1987 Unqual!(typeof(return)) result = 0; 1988 for (; !a.empty; a.popFront(), b.popFront()) 1989 { 1990 immutable t = a.front - b.front; 1991 result += t * t; 1992 } 1993 static if (!haveLen) assert(b.empty); 1994 return sqrt(result); 1995 } 1996 1997 /// Ditto 1998 CommonType!(ElementType!(Range1), ElementType!(Range2)) 1999 euclideanDistance(Range1, Range2, F)(Range1 a, Range2 b, F limit) 2000 if (isInputRange!(Range1) && isInputRange!(Range2)) 2001 { 2002 limit *= limit; 2003 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2004 static if (haveLen) assert(a.length == b.length); 2005 Unqual!(typeof(return)) result = 0; 2006 for (; ; a.popFront(), b.popFront()) 2007 { 2008 if (a.empty) 2009 { 2010 static if (!haveLen) assert(b.empty); 2011 break; 2012 } 2013 immutable t = a.front - b.front; 2014 result += t * t; 2015 if (result >= limit) break; 2016 } 2017 return sqrt(result); 2018 } 2019 2020 @safe unittest 2021 { 2022 import std.meta : AliasSeq; 2023 static foreach (T; AliasSeq!(double, const double, immutable double)) 2024 {{ 2025 T[] a = [ 1.0, 2.0, ]; 2026 T[] b = [ 4.0, 6.0, ]; 2027 assert(euclideanDistance(a, b) == 5); 2028 assert(euclideanDistance(a, b, 6) == 5); 2029 assert(euclideanDistance(a, b, 5) == 5); 2030 assert(euclideanDistance(a, b, 4) == 5); 2031 assert(euclideanDistance(a, b, 2) == 3); 2032 }} 2033 } 2034 2035 /** 2036 Computes the $(LINK2 https://en.wikipedia.org/wiki/Dot_product, 2037 dot product) of input ranges `a` and $(D 2038 b). The two ranges must have the same length. If both ranges define 2039 length, the check is done once; otherwise, it is done at each 2040 iteration. 2041 */ 2042 CommonType!(ElementType!(Range1), ElementType!(Range2)) 2043 dotProduct(Range1, Range2)(Range1 a, Range2 b) 2044 if (isInputRange!(Range1) && isInputRange!(Range2) && 2045 !(isArray!(Range1) && isArray!(Range2))) 2046 { 2047 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2048 static if (haveLen) assert(a.length == b.length); 2049 Unqual!(typeof(return)) result = 0; 2050 for (; !a.empty; a.popFront(), b.popFront()) 2051 { 2052 result += a.front * b.front; 2053 } 2054 static if (!haveLen) assert(b.empty); 2055 return result; 2056 } 2057 2058 /// Ditto 2059 CommonType!(F1, F2) 2060 dotProduct(F1, F2)(in F1[] avector, in F2[] bvector) 2061 { 2062 immutable n = avector.length; 2063 assert(n == bvector.length); 2064 auto avec = avector.ptr, bvec = bvector.ptr; 2065 Unqual!(typeof(return)) sum0 = 0, sum1 = 0; 2066 2067 const all_endp = avec + n; 2068 const smallblock_endp = avec + (n & ~3); 2069 const bigblock_endp = avec + (n & ~15); 2070 2071 for (; avec != bigblock_endp; avec += 16, bvec += 16) 2072 { 2073 sum0 += avec[0] * bvec[0]; 2074 sum1 += avec[1] * bvec[1]; 2075 sum0 += avec[2] * bvec[2]; 2076 sum1 += avec[3] * bvec[3]; 2077 sum0 += avec[4] * bvec[4]; 2078 sum1 += avec[5] * bvec[5]; 2079 sum0 += avec[6] * bvec[6]; 2080 sum1 += avec[7] * bvec[7]; 2081 sum0 += avec[8] * bvec[8]; 2082 sum1 += avec[9] * bvec[9]; 2083 sum0 += avec[10] * bvec[10]; 2084 sum1 += avec[11] * bvec[11]; 2085 sum0 += avec[12] * bvec[12]; 2086 sum1 += avec[13] * bvec[13]; 2087 sum0 += avec[14] * bvec[14]; 2088 sum1 += avec[15] * bvec[15]; 2089 } 2090 2091 for (; avec != smallblock_endp; avec += 4, bvec += 4) 2092 { 2093 sum0 += avec[0] * bvec[0]; 2094 sum1 += avec[1] * bvec[1]; 2095 sum0 += avec[2] * bvec[2]; 2096 sum1 += avec[3] * bvec[3]; 2097 } 2098 2099 sum0 += sum1; 2100 2101 /* Do trailing portion in naive loop. */ 2102 while (avec != all_endp) 2103 { 2104 sum0 += *avec * *bvec; 2105 ++avec; 2106 ++bvec; 2107 } 2108 2109 return sum0; 2110 } 2111 2112 /// ditto 2113 F dotProduct(F, uint N)(const ref scope F[N] a, const ref scope F[N] b) 2114 if (N <= 16) 2115 { 2116 F sum0 = 0; 2117 F sum1 = 0; 2118 static foreach (i; 0 .. N / 2) 2119 { 2120 sum0 += a[i*2] * b[i*2]; 2121 sum1 += a[i*2+1] * b[i*2+1]; 2122 } 2123 static if (N % 2 == 1) 2124 { 2125 sum0 += a[N-1] * b[N-1]; 2126 } 2127 return sum0 + sum1; 2128 } 2129 2130 @system unittest 2131 { 2132 // @system due to dotProduct and assertCTFEable 2133 import std.exception : assertCTFEable; 2134 import std.meta : AliasSeq; 2135 static foreach (T; AliasSeq!(double, const double, immutable double)) 2136 {{ 2137 T[] a = [ 1.0, 2.0, ]; 2138 T[] b = [ 4.0, 6.0, ]; 2139 assert(dotProduct(a, b) == 16); 2140 assert(dotProduct([1, 3, -5], [4, -2, -1]) == 3); 2141 // Test with fixed-length arrays. 2142 T[2] c = [ 1.0, 2.0, ]; 2143 T[2] d = [ 4.0, 6.0, ]; 2144 assert(dotProduct(c, d) == 16); 2145 T[3] e = [1, 3, -5]; 2146 T[3] f = [4, -2, -1]; 2147 assert(dotProduct(e, f) == 3); 2148 }} 2149 2150 // Make sure the unrolled loop codepath gets tested. 2151 static const x = 2152 [1.0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]; 2153 static const y = 2154 [2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]; 2155 assertCTFEable!({ assert(dotProduct(x, y) == 4048); }); 2156 } 2157 2158 /** 2159 Computes the $(LINK2 https://en.wikipedia.org/wiki/Cosine_similarity, 2160 cosine similarity) of input ranges `a` and $(D 2161 b). The two ranges must have the same length. If both ranges define 2162 length, the check is done once; otherwise, it is done at each 2163 iteration. If either range has all-zero elements, return 0. 2164 */ 2165 CommonType!(ElementType!(Range1), ElementType!(Range2)) 2166 cosineSimilarity(Range1, Range2)(Range1 a, Range2 b) 2167 if (isInputRange!(Range1) && isInputRange!(Range2)) 2168 { 2169 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2170 static if (haveLen) assert(a.length == b.length); 2171 Unqual!(typeof(return)) norma = 0, normb = 0, dotprod = 0; 2172 for (; !a.empty; a.popFront(), b.popFront()) 2173 { 2174 immutable t1 = a.front, t2 = b.front; 2175 norma += t1 * t1; 2176 normb += t2 * t2; 2177 dotprod += t1 * t2; 2178 } 2179 static if (!haveLen) assert(b.empty); 2180 if (norma == 0 || normb == 0) return 0; 2181 return dotprod / sqrt(norma * normb); 2182 } 2183 2184 @safe unittest 2185 { 2186 import std.meta : AliasSeq; 2187 static foreach (T; AliasSeq!(double, const double, immutable double)) 2188 {{ 2189 T[] a = [ 1.0, 2.0, ]; 2190 T[] b = [ 4.0, 3.0, ]; 2191 assert(isClose( 2192 cosineSimilarity(a, b), 10.0 / sqrt(5.0 * 25), 2193 0.01)); 2194 }} 2195 } 2196 2197 /** 2198 Normalizes values in `range` by multiplying each element with a 2199 number chosen such that values sum up to `sum`. If elements in $(D 2200 range) sum to zero, assigns $(D sum / range.length) to 2201 all. Normalization makes sense only if all elements in `range` are 2202 positive. `normalize` assumes that is the case without checking it. 2203 2204 Returns: `true` if normalization completed normally, `false` if 2205 all elements in `range` were zero or if `range` is empty. 2206 */ 2207 bool normalize(R)(R range, ElementType!(R) sum = 1) 2208 if (isForwardRange!(R)) 2209 { 2210 ElementType!(R) s = 0; 2211 // Step 1: Compute sum and length of the range 2212 static if (hasLength!(R)) 2213 { 2214 const length = range.length; 2215 foreach (e; range) 2216 { 2217 s += e; 2218 } 2219 } 2220 else 2221 { 2222 uint length = 0; 2223 foreach (e; range) 2224 { 2225 s += e; 2226 ++length; 2227 } 2228 } 2229 // Step 2: perform normalization 2230 if (s == 0) 2231 { 2232 if (length) 2233 { 2234 immutable f = sum / range.length; 2235 foreach (ref e; range) e = f; 2236 } 2237 return false; 2238 } 2239 // The path most traveled 2240 assert(s >= 0); 2241 immutable f = sum / s; 2242 foreach (ref e; range) 2243 e *= f; 2244 return true; 2245 } 2246 2247 /// 2248 @safe unittest 2249 { 2250 double[] a = []; 2251 assert(!normalize(a)); 2252 a = [ 1.0, 3.0 ]; 2253 assert(normalize(a)); 2254 assert(a == [ 0.25, 0.75 ]); 2255 assert(normalize!(typeof(a))(a, 50)); // a = [12.5, 37.5] 2256 a = [ 0.0, 0.0 ]; 2257 assert(!normalize(a)); 2258 assert(a == [ 0.5, 0.5 ]); 2259 } 2260 2261 /** 2262 Compute the sum of binary logarithms of the input range `r`. 2263 The error of this method is much smaller than with a naive sum of log2. 2264 */ 2265 ElementType!Range sumOfLog2s(Range)(Range r) 2266 if (isInputRange!Range && isFloatingPoint!(ElementType!Range)) 2267 { 2268 long exp = 0; 2269 Unqual!(typeof(return)) x = 1; 2270 foreach (e; r) 2271 { 2272 if (e < 0) 2273 return typeof(return).nan; 2274 int lexp = void; 2275 x *= frexp(e, lexp); 2276 exp += lexp; 2277 if (x < 0.5) 2278 { 2279 x *= 2; 2280 exp--; 2281 } 2282 } 2283 return exp + log2(x); 2284 } 2285 2286 /// 2287 @safe unittest 2288 { 2289 import std.math.traits : isNaN; 2290 2291 assert(sumOfLog2s(new double[0]) == 0); 2292 assert(sumOfLog2s([0.0L]) == -real.infinity); 2293 assert(sumOfLog2s([-0.0L]) == -real.infinity); 2294 assert(sumOfLog2s([2.0L]) == 1); 2295 assert(sumOfLog2s([-2.0L]).isNaN()); 2296 assert(sumOfLog2s([real.nan]).isNaN()); 2297 assert(sumOfLog2s([-real.nan]).isNaN()); 2298 assert(sumOfLog2s([real.infinity]) == real.infinity); 2299 assert(sumOfLog2s([-real.infinity]).isNaN()); 2300 assert(sumOfLog2s([ 0.25, 0.25, 0.25, 0.125 ]) == -9); 2301 } 2302 2303 /** 2304 Computes $(LINK2 https://en.wikipedia.org/wiki/Entropy_(information_theory), 2305 _entropy) of input range `r` in bits. This 2306 function assumes (without checking) that the values in `r` are all 2307 in $(D [0, 1]). For the entropy to be meaningful, often `r` should 2308 be normalized too (i.e., its values should sum to 1). The 2309 two-parameter version stops evaluating as soon as the intermediate 2310 result is greater than or equal to `max`. 2311 */ 2312 ElementType!Range entropy(Range)(Range r) 2313 if (isInputRange!Range) 2314 { 2315 Unqual!(typeof(return)) result = 0.0; 2316 for (;!r.empty; r.popFront) 2317 { 2318 if (!r.front) continue; 2319 result -= r.front * log2(r.front); 2320 } 2321 return result; 2322 } 2323 2324 /// Ditto 2325 ElementType!Range entropy(Range, F)(Range r, F max) 2326 if (isInputRange!Range && 2327 !is(CommonType!(ElementType!Range, F) == void)) 2328 { 2329 Unqual!(typeof(return)) result = 0.0; 2330 for (;!r.empty; r.popFront) 2331 { 2332 if (!r.front) continue; 2333 result -= r.front * log2(r.front); 2334 if (result >= max) break; 2335 } 2336 return result; 2337 } 2338 2339 @safe unittest 2340 { 2341 import std.meta : AliasSeq; 2342 static foreach (T; AliasSeq!(double, const double, immutable double)) 2343 {{ 2344 T[] p = [ 0.0, 0, 0, 1 ]; 2345 assert(entropy(p) == 0); 2346 p = [ 0.25, 0.25, 0.25, 0.25 ]; 2347 assert(entropy(p) == 2); 2348 assert(entropy(p, 1) == 1); 2349 }} 2350 } 2351 2352 /** 2353 Computes the $(LINK2 https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence, 2354 Kullback-Leibler divergence) between input ranges 2355 `a` and `b`, which is the sum $(D ai * log(ai / bi)). The base 2356 of logarithm is 2. The ranges are assumed to contain elements in $(D 2357 [0, 1]). Usually the ranges are normalized probability distributions, 2358 but this is not required or checked by $(D 2359 kullbackLeiblerDivergence). If any element `bi` is zero and the 2360 corresponding element `ai` nonzero, returns infinity. (Otherwise, 2361 if $(D ai == 0 && bi == 0), the term $(D ai * log(ai / bi)) is 2362 considered zero.) If the inputs are normalized, the result is 2363 positive. 2364 */ 2365 CommonType!(ElementType!Range1, ElementType!Range2) 2366 kullbackLeiblerDivergence(Range1, Range2)(Range1 a, Range2 b) 2367 if (isInputRange!(Range1) && isInputRange!(Range2)) 2368 { 2369 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2370 static if (haveLen) assert(a.length == b.length); 2371 Unqual!(typeof(return)) result = 0; 2372 for (; !a.empty; a.popFront(), b.popFront()) 2373 { 2374 immutable t1 = a.front; 2375 if (t1 == 0) continue; 2376 immutable t2 = b.front; 2377 if (t2 == 0) return result.infinity; 2378 assert(t1 > 0 && t2 > 0); 2379 result += t1 * log2(t1 / t2); 2380 } 2381 static if (!haveLen) assert(b.empty); 2382 return result; 2383 } 2384 2385 /// 2386 @safe unittest 2387 { 2388 import std.math.operations : isClose; 2389 2390 double[] p = [ 0.0, 0, 0, 1 ]; 2391 assert(kullbackLeiblerDivergence(p, p) == 0); 2392 double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ]; 2393 assert(kullbackLeiblerDivergence(p1, p1) == 0); 2394 assert(kullbackLeiblerDivergence(p, p1) == 2); 2395 assert(kullbackLeiblerDivergence(p1, p) == double.infinity); 2396 double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ]; 2397 assert(isClose(kullbackLeiblerDivergence(p1, p2), 0.0719281, 1e-5)); 2398 assert(isClose(kullbackLeiblerDivergence(p2, p1), 0.0780719, 1e-5)); 2399 } 2400 2401 /** 2402 Computes the $(LINK2 https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence, 2403 Jensen-Shannon divergence) between `a` and $(D 2404 b), which is the sum $(D (ai * log(2 * ai / (ai + bi)) + bi * log(2 * 2405 bi / (ai + bi))) / 2). The base of logarithm is 2. The ranges are 2406 assumed to contain elements in $(D [0, 1]). Usually the ranges are 2407 normalized probability distributions, but this is not required or 2408 checked by `jensenShannonDivergence`. If the inputs are normalized, 2409 the result is bounded within $(D [0, 1]). The three-parameter version 2410 stops evaluations as soon as the intermediate result is greater than 2411 or equal to `limit`. 2412 */ 2413 CommonType!(ElementType!Range1, ElementType!Range2) 2414 jensenShannonDivergence(Range1, Range2)(Range1 a, Range2 b) 2415 if (isInputRange!Range1 && isInputRange!Range2 && 2416 is(CommonType!(ElementType!Range1, ElementType!Range2))) 2417 { 2418 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2419 static if (haveLen) assert(a.length == b.length); 2420 Unqual!(typeof(return)) result = 0; 2421 for (; !a.empty; a.popFront(), b.popFront()) 2422 { 2423 immutable t1 = a.front; 2424 immutable t2 = b.front; 2425 immutable avg = (t1 + t2) / 2; 2426 if (t1 != 0) 2427 { 2428 result += t1 * log2(t1 / avg); 2429 } 2430 if (t2 != 0) 2431 { 2432 result += t2 * log2(t2 / avg); 2433 } 2434 } 2435 static if (!haveLen) assert(b.empty); 2436 return result / 2; 2437 } 2438 2439 /// Ditto 2440 CommonType!(ElementType!Range1, ElementType!Range2) 2441 jensenShannonDivergence(Range1, Range2, F)(Range1 a, Range2 b, F limit) 2442 if (isInputRange!Range1 && isInputRange!Range2 && 2443 is(typeof(CommonType!(ElementType!Range1, ElementType!Range2).init 2444 >= F.init) : bool)) 2445 { 2446 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2); 2447 static if (haveLen) assert(a.length == b.length); 2448 Unqual!(typeof(return)) result = 0; 2449 limit *= 2; 2450 for (; !a.empty; a.popFront(), b.popFront()) 2451 { 2452 immutable t1 = a.front; 2453 immutable t2 = b.front; 2454 immutable avg = (t1 + t2) / 2; 2455 if (t1 != 0) 2456 { 2457 result += t1 * log2(t1 / avg); 2458 } 2459 if (t2 != 0) 2460 { 2461 result += t2 * log2(t2 / avg); 2462 } 2463 if (result >= limit) break; 2464 } 2465 static if (!haveLen) assert(b.empty); 2466 return result / 2; 2467 } 2468 2469 /// 2470 @safe unittest 2471 { 2472 import std.math.operations : isClose; 2473 2474 double[] p = [ 0.0, 0, 0, 1 ]; 2475 assert(jensenShannonDivergence(p, p) == 0); 2476 double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ]; 2477 assert(jensenShannonDivergence(p1, p1) == 0); 2478 assert(isClose(jensenShannonDivergence(p1, p), 0.548795, 1e-5)); 2479 double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ]; 2480 assert(isClose(jensenShannonDivergence(p1, p2), 0.0186218, 1e-5)); 2481 assert(isClose(jensenShannonDivergence(p2, p1), 0.0186218, 1e-5)); 2482 assert(isClose(jensenShannonDivergence(p2, p1, 0.005), 0.00602366, 1e-5)); 2483 } 2484 2485 /** 2486 The so-called "all-lengths gap-weighted string kernel" computes a 2487 similarity measure between `s` and `t` based on all of their 2488 common subsequences of all lengths. Gapped subsequences are also 2489 included. 2490 2491 To understand what $(D gapWeightedSimilarity(s, t, lambda)) computes, 2492 consider first the case $(D lambda = 1) and the strings $(D s = 2493 ["Hello", "brave", "new", "world"]) and $(D t = ["Hello", "new", 2494 "world"]). In that case, `gapWeightedSimilarity` counts the 2495 following matches: 2496 2497 $(OL $(LI three matches of length 1, namely `"Hello"`, `"new"`, 2498 and `"world"`;) $(LI three matches of length 2, namely ($(D 2499 "Hello", "new")), ($(D "Hello", "world")), and ($(D "new", "world"));) 2500 $(LI one match of length 3, namely ($(D "Hello", "new", "world")).)) 2501 2502 The call $(D gapWeightedSimilarity(s, t, 1)) simply counts all of 2503 these matches and adds them up, returning 7. 2504 2505 ---- 2506 string[] s = ["Hello", "brave", "new", "world"]; 2507 string[] t = ["Hello", "new", "world"]; 2508 assert(gapWeightedSimilarity(s, t, 1) == 7); 2509 ---- 2510 2511 Note how the gaps in matching are simply ignored, for example ($(D 2512 "Hello", "new")) is deemed as good a match as ($(D "new", 2513 "world")). This may be too permissive for some applications. To 2514 eliminate gapped matches entirely, use $(D lambda = 0): 2515 2516 ---- 2517 string[] s = ["Hello", "brave", "new", "world"]; 2518 string[] t = ["Hello", "new", "world"]; 2519 assert(gapWeightedSimilarity(s, t, 0) == 4); 2520 ---- 2521 2522 The call above eliminated the gapped matches ($(D "Hello", "new")), 2523 ($(D "Hello", "world")), and ($(D "Hello", "new", "world")) from the 2524 tally. That leaves only 4 matches. 2525 2526 The most interesting case is when gapped matches still participate in 2527 the result, but not as strongly as ungapped matches. The result will 2528 be a smooth, fine-grained similarity measure between the input 2529 strings. This is where values of `lambda` between 0 and 1 enter 2530 into play: gapped matches are $(I exponentially penalized with the 2531 number of gaps) with base `lambda`. This means that an ungapped 2532 match adds 1 to the return value; a match with one gap in either 2533 string adds `lambda` to the return value; ...; a match with a total 2534 of `n` gaps in both strings adds $(D pow(lambda, n)) to the return 2535 value. In the example above, we have 4 matches without gaps, 2 matches 2536 with one gap, and 1 match with three gaps. The latter match is ($(D 2537 "Hello", "world")), which has two gaps in the first string and one gap 2538 in the second string, totaling to three gaps. Summing these up we get 2539 $(D 4 + 2 * lambda + pow(lambda, 3)). 2540 2541 ---- 2542 string[] s = ["Hello", "brave", "new", "world"]; 2543 string[] t = ["Hello", "new", "world"]; 2544 assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 0.5 * 2 + 0.125); 2545 ---- 2546 2547 `gapWeightedSimilarity` is useful wherever a smooth similarity 2548 measure between sequences allowing for approximate matches is 2549 needed. The examples above are given with words, but any sequences 2550 with elements comparable for equality are allowed, e.g. characters or 2551 numbers. `gapWeightedSimilarity` uses a highly optimized dynamic 2552 programming implementation that needs $(D 16 * min(s.length, 2553 t.length)) extra bytes of memory and $(BIGOH s.length * t.length) time 2554 to complete. 2555 */ 2556 F gapWeightedSimilarity(alias comp = "a == b", R1, R2, F)(R1 s, R2 t, F lambda) 2557 if (isRandomAccessRange!(R1) && hasLength!(R1) && 2558 isRandomAccessRange!(R2) && hasLength!(R2)) 2559 { 2560 import core.exception : onOutOfMemoryError; 2561 import core.stdc.stdlib : malloc, free; 2562 import std.algorithm.mutation : swap; 2563 import std.functional : binaryFun; 2564 2565 if (s.length < t.length) return gapWeightedSimilarity(t, s, lambda); 2566 if (!t.length) return 0; 2567 2568 auto dpvi = cast(F*) malloc(F.sizeof * 2 * t.length); 2569 if (!dpvi) 2570 onOutOfMemoryError(); 2571 2572 auto dpvi1 = dpvi + t.length; 2573 scope(exit) free(dpvi < dpvi1 ? dpvi : dpvi1); 2574 dpvi[0 .. t.length] = 0; 2575 dpvi1[0] = 0; 2576 immutable lambda2 = lambda * lambda; 2577 2578 F result = 0; 2579 foreach (i; 0 .. s.length) 2580 { 2581 const si = s[i]; 2582 for (size_t j = 0;;) 2583 { 2584 F dpsij = void; 2585 if (binaryFun!(comp)(si, t[j])) 2586 { 2587 dpsij = 1 + dpvi[j]; 2588 result += dpsij; 2589 } 2590 else 2591 { 2592 dpsij = 0; 2593 } 2594 immutable j1 = j + 1; 2595 if (j1 == t.length) break; 2596 dpvi1[j1] = dpsij + lambda * (dpvi1[j] + dpvi[j1]) - 2597 lambda2 * dpvi[j]; 2598 j = j1; 2599 } 2600 swap(dpvi, dpvi1); 2601 } 2602 return result; 2603 } 2604 2605 @system unittest 2606 { 2607 string[] s = ["Hello", "brave", "new", "world"]; 2608 string[] t = ["Hello", "new", "world"]; 2609 assert(gapWeightedSimilarity(s, t, 1) == 7); 2610 assert(gapWeightedSimilarity(s, t, 0) == 4); 2611 assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 2 * 0.5 + 0.125); 2612 } 2613 2614 /** 2615 The similarity per `gapWeightedSimilarity` has an issue in that it 2616 grows with the lengths of the two strings, even though the strings are 2617 not actually very similar. For example, the range $(D ["Hello", 2618 "world"]) is increasingly similar with the range $(D ["Hello", 2619 "world", "world", "world",...]) as more instances of `"world"` are 2620 appended. To prevent that, `gapWeightedSimilarityNormalized` 2621 computes a normalized version of the similarity that is computed as 2622 $(D gapWeightedSimilarity(s, t, lambda) / 2623 sqrt(gapWeightedSimilarity(s, t, lambda) * gapWeightedSimilarity(s, t, 2624 lambda))). The function `gapWeightedSimilarityNormalized` (a 2625 so-called normalized kernel) is bounded in $(D [0, 1]), reaches `0` 2626 only for ranges that don't match in any position, and `1` only for 2627 identical ranges. 2628 2629 The optional parameters `sSelfSim` and `tSelfSim` are meant for 2630 avoiding duplicate computation. Many applications may have already 2631 computed $(D gapWeightedSimilarity(s, s, lambda)) and/or $(D 2632 gapWeightedSimilarity(t, t, lambda)). In that case, they can be passed 2633 as `sSelfSim` and `tSelfSim`, respectively. 2634 */ 2635 Select!(isFloatingPoint!(F), F, double) 2636 gapWeightedSimilarityNormalized(alias comp = "a == b", R1, R2, F) 2637 (R1 s, R2 t, F lambda, F sSelfSim = F.init, F tSelfSim = F.init) 2638 if (isRandomAccessRange!(R1) && hasLength!(R1) && 2639 isRandomAccessRange!(R2) && hasLength!(R2)) 2640 { 2641 static bool uncomputed(F n) 2642 { 2643 static if (isFloatingPoint!(F)) 2644 return isNaN(n); 2645 else 2646 return n == n.init; 2647 } 2648 if (uncomputed(sSelfSim)) 2649 sSelfSim = gapWeightedSimilarity!(comp)(s, s, lambda); 2650 if (sSelfSim == 0) return 0; 2651 if (uncomputed(tSelfSim)) 2652 tSelfSim = gapWeightedSimilarity!(comp)(t, t, lambda); 2653 if (tSelfSim == 0) return 0; 2654 2655 return gapWeightedSimilarity!(comp)(s, t, lambda) / 2656 sqrt(cast(typeof(return)) sSelfSim * tSelfSim); 2657 } 2658 2659 /// 2660 @system unittest 2661 { 2662 import std.math.operations : isClose; 2663 import std.math.algebraic : sqrt; 2664 2665 string[] s = ["Hello", "brave", "new", "world"]; 2666 string[] t = ["Hello", "new", "world"]; 2667 assert(gapWeightedSimilarity(s, s, 1) == 15); 2668 assert(gapWeightedSimilarity(t, t, 1) == 7); 2669 assert(gapWeightedSimilarity(s, t, 1) == 7); 2670 assert(isClose(gapWeightedSimilarityNormalized(s, t, 1), 2671 7.0 / sqrt(15.0 * 7), 0.01)); 2672 } 2673 2674 /** 2675 Similar to `gapWeightedSimilarity`, just works in an incremental 2676 manner by first revealing the matches of length 1, then gapped matches 2677 of length 2, and so on. The memory requirement is $(BIGOH s.length * 2678 t.length). The time complexity is $(BIGOH s.length * t.length) time 2679 for computing each step. Continuing on the previous example: 2680 2681 The implementation is based on the pseudocode in Fig. 4 of the paper 2682 $(HTTP jmlr.csail.mit.edu/papers/volume6/rousu05a/rousu05a.pdf, 2683 "Efficient Computation of Gapped Substring Kernels on Large Alphabets") 2684 by Rousu et al., with additional algorithmic and systems-level 2685 optimizations. 2686 */ 2687 struct GapWeightedSimilarityIncremental(Range, F = double) 2688 if (isRandomAccessRange!(Range) && hasLength!(Range)) 2689 { 2690 import core.stdc.stdlib : malloc, realloc, alloca, free; 2691 2692 private: 2693 Range s, t; 2694 F currentValue = 0; 2695 F* kl; 2696 size_t gram = void; 2697 F lambda = void, lambda2 = void; 2698 2699 public: 2700 /** 2701 Constructs an object given two ranges `s` and `t` and a penalty 2702 `lambda`. Constructor completes in $(BIGOH s.length * t.length) 2703 time and computes all matches of length 1. 2704 */ 2705 this(Range s, Range t, F lambda) 2706 { 2707 import core.exception : onOutOfMemoryError; 2708 2709 assert(lambda > 0); 2710 this.gram = 0; 2711 this.lambda = lambda; 2712 this.lambda2 = lambda * lambda; // for efficiency only 2713 2714 size_t iMin = size_t.max, jMin = size_t.max, 2715 iMax = 0, jMax = 0; 2716 /* initialize */ 2717 Tuple!(size_t, size_t) * k0; 2718 size_t k0len; 2719 scope(exit) free(k0); 2720 currentValue = 0; 2721 foreach (i, si; s) 2722 { 2723 foreach (j; 0 .. t.length) 2724 { 2725 if (si != t[j]) continue; 2726 k0 = cast(typeof(k0)) realloc(k0, ++k0len * (*k0).sizeof); 2727 with (k0[k0len - 1]) 2728 { 2729 field[0] = i; 2730 field[1] = j; 2731 } 2732 // Maintain the minimum and maximum i and j 2733 if (iMin > i) iMin = i; 2734 if (iMax < i) iMax = i; 2735 if (jMin > j) jMin = j; 2736 if (jMax < j) jMax = j; 2737 } 2738 } 2739 2740 if (iMin > iMax) return; 2741 assert(k0len); 2742 2743 currentValue = k0len; 2744 // Chop strings down to the useful sizes 2745 s = s[iMin .. iMax + 1]; 2746 t = t[jMin .. jMax + 1]; 2747 this.s = s; 2748 this.t = t; 2749 2750 kl = cast(F*) malloc(s.length * t.length * F.sizeof); 2751 if (!kl) 2752 onOutOfMemoryError(); 2753 2754 kl[0 .. s.length * t.length] = 0; 2755 foreach (pos; 0 .. k0len) 2756 { 2757 with (k0[pos]) 2758 { 2759 kl[(field[0] - iMin) * t.length + field[1] -jMin] = lambda2; 2760 } 2761 } 2762 } 2763 2764 /** 2765 Returns: `this`. 2766 */ 2767 ref GapWeightedSimilarityIncremental opSlice() 2768 { 2769 return this; 2770 } 2771 2772 /** 2773 Computes the match of the popFront length. Completes in $(BIGOH s.length * 2774 t.length) time. 2775 */ 2776 void popFront() 2777 { 2778 import std.algorithm.mutation : swap; 2779 2780 // This is a large source of optimization: if similarity at 2781 // the gram-1 level was 0, then we can safely assume 2782 // similarity at the gram level is 0 as well. 2783 if (empty) return; 2784 2785 // Now attempt to match gapped substrings of length `gram' 2786 ++gram; 2787 currentValue = 0; 2788 2789 auto Si = cast(F*) alloca(t.length * F.sizeof); 2790 Si[0 .. t.length] = 0; 2791 foreach (i; 0 .. s.length) 2792 { 2793 const si = s[i]; 2794 F Sij_1 = 0; 2795 F Si_1j_1 = 0; 2796 auto kli = kl + i * t.length; 2797 for (size_t j = 0;;) 2798 { 2799 const klij = kli[j]; 2800 const Si_1j = Si[j]; 2801 const tmp = klij + lambda * (Si_1j + Sij_1) - lambda2 * Si_1j_1; 2802 // now update kl and currentValue 2803 if (si == t[j]) 2804 currentValue += kli[j] = lambda2 * Si_1j_1; 2805 else 2806 kli[j] = 0; 2807 // commit to Si 2808 Si[j] = tmp; 2809 if (++j == t.length) break; 2810 // get ready for the popFront step; virtually increment j, 2811 // so essentially stuffj_1 <-- stuffj 2812 Si_1j_1 = Si_1j; 2813 Sij_1 = tmp; 2814 } 2815 } 2816 currentValue /= pow(lambda, 2 * (gram + 1)); 2817 2818 version (none) 2819 { 2820 Si_1[0 .. t.length] = 0; 2821 kl[0 .. min(t.length, maxPerimeter + 1)] = 0; 2822 foreach (i; 1 .. min(s.length, maxPerimeter + 1)) 2823 { 2824 auto kli = kl + i * t.length; 2825 assert(s.length > i); 2826 const si = s[i]; 2827 auto kl_1i_1 = kl_1 + (i - 1) * t.length; 2828 kli[0] = 0; 2829 F lastS = 0; 2830 foreach (j; 1 .. min(maxPerimeter - i + 1, t.length)) 2831 { 2832 immutable j_1 = j - 1; 2833 immutable tmp = kl_1i_1[j_1] 2834 + lambda * (Si_1[j] + lastS) 2835 - lambda2 * Si_1[j_1]; 2836 kl_1i_1[j_1] = float.nan; 2837 Si_1[j_1] = lastS; 2838 lastS = tmp; 2839 if (si == t[j]) 2840 { 2841 currentValue += kli[j] = lambda2 * lastS; 2842 } 2843 else 2844 { 2845 kli[j] = 0; 2846 } 2847 } 2848 Si_1[t.length - 1] = lastS; 2849 } 2850 currentValue /= pow(lambda, 2 * (gram + 1)); 2851 // get ready for the popFront computation 2852 swap(kl, kl_1); 2853 } 2854 } 2855 2856 /** 2857 Returns: The gapped similarity at the current match length (initially 2858 1, grows with each call to `popFront`). 2859 */ 2860 @property F front() { return currentValue; } 2861 2862 /** 2863 Returns: Whether there are more matches. 2864 */ 2865 @property bool empty() 2866 { 2867 if (currentValue) return false; 2868 if (kl) 2869 { 2870 free(kl); 2871 kl = null; 2872 } 2873 return true; 2874 } 2875 } 2876 2877 /** 2878 Ditto 2879 */ 2880 GapWeightedSimilarityIncremental!(R, F) gapWeightedSimilarityIncremental(R, F) 2881 (R r1, R r2, F penalty) 2882 { 2883 return typeof(return)(r1, r2, penalty); 2884 } 2885 2886 /// 2887 @system unittest 2888 { 2889 string[] s = ["Hello", "brave", "new", "world"]; 2890 string[] t = ["Hello", "new", "world"]; 2891 auto simIter = gapWeightedSimilarityIncremental(s, t, 1.0); 2892 assert(simIter.front == 3); // three 1-length matches 2893 simIter.popFront(); 2894 assert(simIter.front == 3); // three 2-length matches 2895 simIter.popFront(); 2896 assert(simIter.front == 1); // one 3-length match 2897 simIter.popFront(); 2898 assert(simIter.empty); // no more match 2899 } 2900 2901 @system unittest 2902 { 2903 import std.conv : text; 2904 string[] s = ["Hello", "brave", "new", "world"]; 2905 string[] t = ["Hello", "new", "world"]; 2906 auto simIter = gapWeightedSimilarityIncremental(s, t, 1.0); 2907 //foreach (e; simIter) writeln(e); 2908 assert(simIter.front == 3); // three 1-length matches 2909 simIter.popFront(); 2910 assert(simIter.front == 3, text(simIter.front)); // three 2-length matches 2911 simIter.popFront(); 2912 assert(simIter.front == 1); // one 3-length matches 2913 simIter.popFront(); 2914 assert(simIter.empty); // no more match 2915 2916 s = ["Hello"]; 2917 t = ["bye"]; 2918 simIter = gapWeightedSimilarityIncremental(s, t, 0.5); 2919 assert(simIter.empty); 2920 2921 s = ["Hello"]; 2922 t = ["Hello"]; 2923 simIter = gapWeightedSimilarityIncremental(s, t, 0.5); 2924 assert(simIter.front == 1); // one match 2925 simIter.popFront(); 2926 assert(simIter.empty); 2927 2928 s = ["Hello", "world"]; 2929 t = ["Hello"]; 2930 simIter = gapWeightedSimilarityIncremental(s, t, 0.5); 2931 assert(simIter.front == 1); // one match 2932 simIter.popFront(); 2933 assert(simIter.empty); 2934 2935 s = ["Hello", "world"]; 2936 t = ["Hello", "yah", "world"]; 2937 simIter = gapWeightedSimilarityIncremental(s, t, 0.5); 2938 assert(simIter.front == 2); // two 1-gram matches 2939 simIter.popFront(); 2940 assert(simIter.front == 0.5, text(simIter.front)); // one 2-gram match, 1 gap 2941 } 2942 2943 @system unittest 2944 { 2945 GapWeightedSimilarityIncremental!(string[]) sim = 2946 GapWeightedSimilarityIncremental!(string[])( 2947 ["nyuk", "I", "have", "no", "chocolate", "giba"], 2948 ["wyda", "I", "have", "I", "have", "have", "I", "have", "hehe"], 2949 0.5); 2950 double[] witness = [ 7.0, 4.03125, 0, 0 ]; 2951 foreach (e; sim) 2952 { 2953 //writeln(e); 2954 assert(e == witness.front); 2955 witness.popFront(); 2956 } 2957 witness = [ 3.0, 1.3125, 0.25 ]; 2958 sim = GapWeightedSimilarityIncremental!(string[])( 2959 ["I", "have", "no", "chocolate"], 2960 ["I", "have", "some", "chocolate"], 2961 0.5); 2962 foreach (e; sim) 2963 { 2964 //writeln(e); 2965 assert(e == witness.front); 2966 witness.popFront(); 2967 } 2968 assert(witness.empty); 2969 } 2970 2971 /** 2972 Computes the greatest common divisor of `a` and `b` by using 2973 an efficient algorithm such as $(HTTPS en.wikipedia.org/wiki/Euclidean_algorithm, Euclid's) 2974 or $(HTTPS en.wikipedia.org/wiki/Binary_GCD_algorithm, Stein's) algorithm. 2975 2976 Params: 2977 a = Integer value of any numerical type that supports the modulo operator `%`. 2978 If bit-shifting `<<` and `>>` are also supported, Stein's algorithm will 2979 be used; otherwise, Euclid's algorithm is used as _a fallback. 2980 b = Integer value of any equivalent numerical type. 2981 2982 Returns: 2983 The greatest common divisor of the given arguments. 2984 */ 2985 typeof(Unqual!(T).init % Unqual!(U).init) gcd(T, U)(T a, U b) 2986 if (isIntegral!T && isIntegral!U) 2987 { 2988 // Operate on a common type between the two arguments. 2989 alias UCT = Unsigned!(CommonType!(Unqual!T, Unqual!U)); 2990 2991 // `std.math.abs` doesn't support unsigned integers, and `T.min` is undefined. 2992 static if (is(T : immutable short) || is(T : immutable byte)) 2993 UCT ax = (isUnsigned!T || a >= 0) ? a : cast(UCT) -int(a); 2994 else 2995 UCT ax = (isUnsigned!T || a >= 0) ? a : -UCT(a); 2996 2997 static if (is(U : immutable short) || is(U : immutable byte)) 2998 UCT bx = (isUnsigned!U || b >= 0) ? b : cast(UCT) -int(b); 2999 else 3000 UCT bx = (isUnsigned!U || b >= 0) ? b : -UCT(b); 3001 3002 // Special cases. 3003 if (ax == 0) 3004 return bx; 3005 if (bx == 0) 3006 return ax; 3007 3008 return gcdImpl(ax, bx); 3009 } 3010 3011 private typeof(T.init % T.init) gcdImpl(T)(T a, T b) 3012 if (isIntegral!T) 3013 { 3014 pragma(inline, true); 3015 import core.bitop : bsf; 3016 import std.algorithm.mutation : swap; 3017 3018 immutable uint shift = bsf(a | b); 3019 a >>= a.bsf; 3020 do 3021 { 3022 b >>= b.bsf; 3023 if (a > b) 3024 swap(a, b); 3025 b -= a; 3026 } while (b); 3027 3028 return a << shift; 3029 } 3030 3031 /// 3032 @safe unittest 3033 { 3034 assert(gcd(2 * 5 * 7 * 7, 5 * 7 * 11) == 5 * 7); 3035 const int a = 5 * 13 * 23 * 23, b = 13 * 59; 3036 assert(gcd(a, b) == 13); 3037 } 3038 3039 @safe unittest 3040 { 3041 import std.meta : AliasSeq; 3042 static foreach (T; AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong, 3043 const byte, const short, const int, const long, 3044 immutable ubyte, immutable ushort, immutable uint, immutable ulong)) 3045 { 3046 static foreach (U; AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong, 3047 const ubyte, const ushort, const uint, const ulong, 3048 immutable byte, immutable short, immutable int, immutable long)) 3049 { 3050 // Signed and unsigned tests. 3051 static if (T.max > byte.max && U.max > byte.max) 3052 assert(gcd(T(200), U(200)) == 200); 3053 static if (T.max > ubyte.max) 3054 { 3055 assert(gcd(T(2000), U(20)) == 20); 3056 assert(gcd(T(2011), U(17)) == 1); 3057 } 3058 static if (T.max > ubyte.max && U.max > ubyte.max) 3059 assert(gcd(T(1071), U(462)) == 21); 3060 3061 assert(gcd(T(0), U(13)) == 13); 3062 assert(gcd(T(29), U(0)) == 29); 3063 assert(gcd(T(0), U(0)) == 0); 3064 assert(gcd(T(1), U(2)) == 1); 3065 assert(gcd(T(9), U(6)) == 3); 3066 assert(gcd(T(3), U(4)) == 1); 3067 assert(gcd(T(32), U(24)) == 8); 3068 assert(gcd(T(5), U(6)) == 1); 3069 assert(gcd(T(54), U(36)) == 18); 3070 3071 // Int and Long tests. 3072 static if (T.max > short.max && U.max > short.max) 3073 assert(gcd(T(46391), U(62527)) == 2017); 3074 static if (T.max > ushort.max && U.max > ushort.max) 3075 assert(gcd(T(63245986), U(39088169)) == 1); 3076 static if (T.max > uint.max && U.max > uint.max) 3077 { 3078 assert(gcd(T(77160074263), U(47687519812)) == 1); 3079 assert(gcd(T(77160074264), U(47687519812)) == 4); 3080 } 3081 3082 // Negative tests. 3083 static if (T.min < 0) 3084 { 3085 assert(gcd(T(-21), U(28)) == 7); 3086 assert(gcd(T(-3), U(4)) == 1); 3087 } 3088 static if (U.min < 0) 3089 { 3090 assert(gcd(T(1), U(-2)) == 1); 3091 assert(gcd(T(33), U(-44)) == 11); 3092 } 3093 static if (T.min < 0 && U.min < 0) 3094 { 3095 assert(gcd(T(-5), U(-6)) == 1); 3096 assert(gcd(T(-50), U(-60)) == 10); 3097 } 3098 } 3099 } 3100 } 3101 3102 // https://issues.dlang.org/show_bug.cgi?id=21834 3103 @safe unittest 3104 { 3105 assert(gcd(-120, 10U) == 10); 3106 assert(gcd(120U, -10) == 10); 3107 assert(gcd(int.min, 0L) == 1L + int.max); 3108 assert(gcd(0L, int.min) == 1L + int.max); 3109 assert(gcd(int.min, 0L + int.min) == 1L + int.max); 3110 assert(gcd(int.min, 1L + int.max) == 1L + int.max); 3111 assert(gcd(short.min, 1U + short.max) == 1U + short.max); 3112 } 3113 3114 // This overload is for non-builtin numerical types like BigInt or 3115 // user-defined types. 3116 /// ditto 3117 auto gcd(T)(T a, T b) 3118 if (!isIntegral!T && 3119 is(typeof(T.init % T.init)) && 3120 is(typeof(T.init == 0 || T.init > 0))) 3121 { 3122 static if (!is(T == Unqual!T)) 3123 { 3124 return gcd!(Unqual!T)(a, b); 3125 } 3126 else 3127 { 3128 // Ensure arguments are unsigned. 3129 a = a >= 0 ? a : -a; 3130 b = b >= 0 ? b : -b; 3131 3132 // Special cases. 3133 if (a == 0) 3134 return b; 3135 if (b == 0) 3136 return a; 3137 3138 return gcdImpl(a, b); 3139 } 3140 } 3141 3142 private auto gcdImpl(T)(T a, T b) 3143 if (!isIntegral!T) 3144 { 3145 pragma(inline, true); 3146 import std.algorithm.mutation : swap; 3147 enum canUseBinaryGcd = is(typeof(() { 3148 T t, u; 3149 t <<= 1; 3150 t >>= 1; 3151 t -= u; 3152 bool b = (t & 1) == 0; 3153 swap(t, u); 3154 })); 3155 3156 static if (canUseBinaryGcd) 3157 { 3158 uint shift = 0; 3159 while ((a & 1) == 0 && (b & 1) == 0) 3160 { 3161 a >>= 1; 3162 b >>= 1; 3163 shift++; 3164 } 3165 3166 if ((a & 1) == 0) swap(a, b); 3167 3168 do 3169 { 3170 assert((a & 1) != 0); 3171 while ((b & 1) == 0) 3172 b >>= 1; 3173 if (a > b) 3174 swap(a, b); 3175 b -= a; 3176 } while (b); 3177 3178 return a << shift; 3179 } 3180 else 3181 { 3182 // The only thing we have is %; fallback to Euclidean algorithm. 3183 while (b != 0) 3184 { 3185 auto t = b; 3186 b = a % b; 3187 a = t; 3188 } 3189 return a; 3190 } 3191 } 3192 3193 // https://issues.dlang.org/show_bug.cgi?id=7102 3194 @system pure unittest 3195 { 3196 import std.bigint : BigInt; 3197 assert(gcd(BigInt("71_000_000_000_000_000_000"), 3198 BigInt("31_000_000_000_000_000_000")) == 3199 BigInt("1_000_000_000_000_000_000")); 3200 3201 assert(gcd(BigInt(0), BigInt(1234567)) == BigInt(1234567)); 3202 assert(gcd(BigInt(1234567), BigInt(0)) == BigInt(1234567)); 3203 } 3204 3205 @safe pure nothrow unittest 3206 { 3207 // A numerical type that only supports % and - (to force gcd implementation 3208 // to use Euclidean algorithm). 3209 struct CrippledInt 3210 { 3211 int impl; 3212 CrippledInt opBinary(string op : "%")(CrippledInt i) 3213 { 3214 return CrippledInt(impl % i.impl); 3215 } 3216 CrippledInt opUnary(string op : "-")() 3217 { 3218 return CrippledInt(-impl); 3219 } 3220 int opEquals(CrippledInt i) { return impl == i.impl; } 3221 int opEquals(int i) { return impl == i; } 3222 int opCmp(int i) { return (impl < i) ? -1 : (impl > i) ? 1 : 0; } 3223 } 3224 assert(gcd(CrippledInt(2310), CrippledInt(1309)) == CrippledInt(77)); 3225 assert(gcd(CrippledInt(-120), CrippledInt(10U)) == CrippledInt(10)); 3226 assert(gcd(CrippledInt(120U), CrippledInt(-10)) == CrippledInt(10)); 3227 } 3228 3229 // https://issues.dlang.org/show_bug.cgi?id=19514 3230 @system pure unittest 3231 { 3232 import std.bigint : BigInt; 3233 assert(gcd(BigInt(2), BigInt(1)) == BigInt(1)); 3234 } 3235 3236 // Issue 20924 3237 @safe unittest 3238 { 3239 import std.bigint : BigInt; 3240 const a = BigInt("123143238472389492934020"); 3241 const b = BigInt("902380489324729338420924"); 3242 assert(__traits(compiles, gcd(a, b))); 3243 } 3244 3245 // https://issues.dlang.org/show_bug.cgi?id=21834 3246 @safe unittest 3247 { 3248 import std.bigint : BigInt; 3249 assert(gcd(BigInt(-120), BigInt(10U)) == BigInt(10)); 3250 assert(gcd(BigInt(120U), BigInt(-10)) == BigInt(10)); 3251 assert(gcd(BigInt(int.min), BigInt(0L)) == BigInt(1L + int.max)); 3252 assert(gcd(BigInt(0L), BigInt(int.min)) == BigInt(1L + int.max)); 3253 assert(gcd(BigInt(int.min), BigInt(0L + int.min)) == BigInt(1L + int.max)); 3254 assert(gcd(BigInt(int.min), BigInt(1L + int.max)) == BigInt(1L + int.max)); 3255 assert(gcd(BigInt(short.min), BigInt(1U + short.max)) == BigInt(1U + short.max)); 3256 } 3257 3258 3259 /** 3260 Computes the least common multiple of `a` and `b`. 3261 Arguments are the same as $(MYREF gcd). 3262 3263 Returns: 3264 The least common multiple of the given arguments. 3265 */ 3266 typeof(Unqual!(T).init % Unqual!(U).init) lcm(T, U)(T a, U b) 3267 if (isIntegral!T && isIntegral!U) 3268 { 3269 // Operate on a common type between the two arguments. 3270 alias UCT = Unsigned!(CommonType!(Unqual!T, Unqual!U)); 3271 3272 // `std.math.abs` doesn't support unsigned integers, and `T.min` is undefined. 3273 static if (is(T : immutable short) || is(T : immutable byte)) 3274 UCT ax = (isUnsigned!T || a >= 0) ? a : cast(UCT) -int(a); 3275 else 3276 UCT ax = (isUnsigned!T || a >= 0) ? a : -UCT(a); 3277 3278 static if (is(U : immutable short) || is(U : immutable byte)) 3279 UCT bx = (isUnsigned!U || b >= 0) ? b : cast(UCT) -int(b); 3280 else 3281 UCT bx = (isUnsigned!U || b >= 0) ? b : -UCT(b); 3282 3283 // Special cases. 3284 if (ax == 0) 3285 return ax; 3286 if (bx == 0) 3287 return bx; 3288 3289 return (ax / gcdImpl(ax, bx)) * bx; 3290 } 3291 3292 /// 3293 @safe unittest 3294 { 3295 assert(lcm(1, 2) == 2); 3296 assert(lcm(3, 4) == 12); 3297 assert(lcm(5, 6) == 30); 3298 } 3299 3300 @safe unittest 3301 { 3302 import std.meta : AliasSeq; 3303 static foreach (T; AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong, 3304 const byte, const short, const int, const long, 3305 immutable ubyte, immutable ushort, immutable uint, immutable ulong)) 3306 { 3307 static foreach (U; AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong, 3308 const ubyte, const ushort, const uint, const ulong, 3309 immutable byte, immutable short, immutable int, immutable long)) 3310 { 3311 assert(lcm(T(21), U(6)) == 42); 3312 assert(lcm(T(41), U(0)) == 0); 3313 assert(lcm(T(0), U(7)) == 0); 3314 assert(lcm(T(0), U(0)) == 0); 3315 assert(lcm(T(1U), U(2)) == 2); 3316 assert(lcm(T(3), U(4U)) == 12); 3317 assert(lcm(T(5U), U(6U)) == 30); 3318 static if (T.min < 0) 3319 assert(lcm(T(-42), U(21U)) == 42); 3320 } 3321 } 3322 } 3323 3324 /// ditto 3325 auto lcm(T)(T a, T b) 3326 if (!isIntegral!T && 3327 is(typeof(T.init % T.init)) && 3328 is(typeof(T.init == 0 || T.init > 0))) 3329 { 3330 // Ensure arguments are unsigned. 3331 a = a >= 0 ? a : -a; 3332 b = b >= 0 ? b : -b; 3333 3334 // Special cases. 3335 if (a == 0) 3336 return a; 3337 if (b == 0) 3338 return b; 3339 3340 return (a / gcdImpl(a, b)) * b; 3341 } 3342 3343 @safe unittest 3344 { 3345 import std.bigint : BigInt; 3346 assert(lcm(BigInt(21), BigInt(6)) == BigInt(42)); 3347 assert(lcm(BigInt(41), BigInt(0)) == BigInt(0)); 3348 assert(lcm(BigInt(0), BigInt(7)) == BigInt(0)); 3349 assert(lcm(BigInt(0), BigInt(0)) == BigInt(0)); 3350 assert(lcm(BigInt(1U), BigInt(2)) == BigInt(2)); 3351 assert(lcm(BigInt(3), BigInt(4U)) == BigInt(12)); 3352 assert(lcm(BigInt(5U), BigInt(6U)) == BigInt(30)); 3353 assert(lcm(BigInt(-42), BigInt(21U)) == BigInt(42)); 3354 } 3355 3356 // This is to make tweaking the speed/size vs. accuracy tradeoff easy, 3357 // though floats seem accurate enough for all practical purposes, since 3358 // they pass the "isClose(inverseFft(fft(arr)), arr)" test even for 3359 // size 2 ^^ 22. 3360 private alias lookup_t = float; 3361 3362 /**A class for performing fast Fourier transforms of power of two sizes. 3363 * This class encapsulates a large amount of state that is reusable when 3364 * performing multiple FFTs of sizes smaller than or equal to that specified 3365 * in the constructor. This results in substantial speedups when performing 3366 * multiple FFTs with a known maximum size. However, 3367 * a free function API is provided for convenience if you need to perform a 3368 * one-off FFT. 3369 * 3370 * References: 3371 * $(HTTP en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm) 3372 */ 3373 final class Fft 3374 { 3375 import core.bitop : bsf; 3376 import std.algorithm.iteration : map; 3377 import std.array : uninitializedArray; 3378 3379 private: 3380 immutable lookup_t[][] negSinLookup; 3381 3382 void enforceSize(R)(R range) const 3383 { 3384 import std.conv : text; 3385 assert(range.length <= size, text( 3386 "FFT size mismatch. Expected ", size, ", got ", range.length)); 3387 } 3388 3389 void fftImpl(Ret, R)(Stride!R range, Ret buf) const 3390 in 3391 { 3392 assert(range.length >= 4); 3393 assert(isPowerOf2(range.length)); 3394 } 3395 do 3396 { 3397 auto recurseRange = range; 3398 recurseRange.doubleSteps(); 3399 3400 if (buf.length > 4) 3401 { 3402 fftImpl(recurseRange, buf[0..$ / 2]); 3403 recurseRange.popHalf(); 3404 fftImpl(recurseRange, buf[$ / 2..$]); 3405 } 3406 else 3407 { 3408 // Do this here instead of in another recursion to save on 3409 // recursion overhead. 3410 slowFourier2(recurseRange, buf[0..$ / 2]); 3411 recurseRange.popHalf(); 3412 slowFourier2(recurseRange, buf[$ / 2..$]); 3413 } 3414 3415 butterfly(buf); 3416 } 3417 3418 // This algorithm works by performing the even and odd parts of our FFT 3419 // using the "two for the price of one" method mentioned at 3420 // https://web.archive.org/web/20180312110051/http://www.engineeringproductivitytools.com/stuff/T0001/PT10.HTM#Head521 3421 // by making the odd terms into the imaginary components of our new FFT, 3422 // and then using symmetry to recombine them. 3423 void fftImplPureReal(Ret, R)(R range, Ret buf) const 3424 in 3425 { 3426 assert(range.length >= 4); 3427 assert(isPowerOf2(range.length)); 3428 } 3429 do 3430 { 3431 alias E = ElementType!R; 3432 3433 // Converts odd indices of range to the imaginary components of 3434 // a range half the size. The even indices become the real components. 3435 static if (isArray!R && isFloatingPoint!E) 3436 { 3437 // Then the memory layout of complex numbers provides a dirt 3438 // cheap way to convert. This is a common case, so take advantage. 3439 auto oddsImag = cast(Complex!E[]) range; 3440 } 3441 else 3442 { 3443 // General case: Use a higher order range. We can assume 3444 // source.length is even because it has to be a power of 2. 3445 static struct OddToImaginary 3446 { 3447 R source; 3448 alias C = Complex!(CommonType!(E, typeof(buf[0].re))); 3449 3450 @property 3451 { 3452 C front() 3453 { 3454 return C(source[0], source[1]); 3455 } 3456 3457 C back() 3458 { 3459 immutable n = source.length; 3460 return C(source[n - 2], source[n - 1]); 3461 } 3462 3463 typeof(this) save() 3464 { 3465 return typeof(this)(source.save); 3466 } 3467 3468 bool empty() 3469 { 3470 return source.empty; 3471 } 3472 3473 size_t length() 3474 { 3475 return source.length / 2; 3476 } 3477 } 3478 3479 void popFront() 3480 { 3481 source.popFront(); 3482 source.popFront(); 3483 } 3484 3485 void popBack() 3486 { 3487 source.popBack(); 3488 source.popBack(); 3489 } 3490 3491 C opIndex(size_t index) 3492 { 3493 return C(source[index * 2], source[index * 2 + 1]); 3494 } 3495 3496 typeof(this) opSlice(size_t lower, size_t upper) 3497 { 3498 return typeof(this)(source[lower * 2 .. upper * 2]); 3499 } 3500 } 3501 3502 auto oddsImag = OddToImaginary(range); 3503 } 3504 3505 fft(oddsImag, buf[0..$ / 2]); 3506 auto evenFft = buf[0..$ / 2]; 3507 auto oddFft = buf[$ / 2..$]; 3508 immutable halfN = evenFft.length; 3509 oddFft[0].re = buf[0].im; 3510 oddFft[0].im = 0; 3511 evenFft[0].im = 0; 3512 // evenFft[0].re is already right b/c it's aliased with buf[0].re. 3513 3514 foreach (k; 1 .. halfN / 2 + 1) 3515 { 3516 immutable bufk = buf[k]; 3517 immutable bufnk = buf[buf.length / 2 - k]; 3518 evenFft[k].re = 0.5 * (bufk.re + bufnk.re); 3519 evenFft[halfN - k].re = evenFft[k].re; 3520 evenFft[k].im = 0.5 * (bufk.im - bufnk.im); 3521 evenFft[halfN - k].im = -evenFft[k].im; 3522 3523 oddFft[k].re = 0.5 * (bufk.im + bufnk.im); 3524 oddFft[halfN - k].re = oddFft[k].re; 3525 oddFft[k].im = 0.5 * (bufnk.re - bufk.re); 3526 oddFft[halfN - k].im = -oddFft[k].im; 3527 } 3528 3529 butterfly(buf); 3530 } 3531 3532 void butterfly(R)(R buf) const 3533 in 3534 { 3535 assert(isPowerOf2(buf.length)); 3536 } 3537 do 3538 { 3539 immutable n = buf.length; 3540 immutable localLookup = negSinLookup[bsf(n)]; 3541 assert(localLookup.length == n); 3542 3543 immutable cosMask = n - 1; 3544 immutable cosAdd = n / 4 * 3; 3545 3546 lookup_t negSinFromLookup(size_t index) pure nothrow 3547 { 3548 return localLookup[index]; 3549 } 3550 3551 lookup_t cosFromLookup(size_t index) pure nothrow 3552 { 3553 // cos is just -sin shifted by PI * 3 / 2. 3554 return localLookup[(index + cosAdd) & cosMask]; 3555 } 3556 3557 immutable halfLen = n / 2; 3558 3559 // This loop is unrolled and the two iterations are interleaved 3560 // relative to the textbook FFT to increase ILP. This gives roughly 5% 3561 // speedups on DMD. 3562 for (size_t k = 0; k < halfLen; k += 2) 3563 { 3564 immutable cosTwiddle1 = cosFromLookup(k); 3565 immutable sinTwiddle1 = negSinFromLookup(k); 3566 immutable cosTwiddle2 = cosFromLookup(k + 1); 3567 immutable sinTwiddle2 = negSinFromLookup(k + 1); 3568 3569 immutable realLower1 = buf[k].re; 3570 immutable imagLower1 = buf[k].im; 3571 immutable realLower2 = buf[k + 1].re; 3572 immutable imagLower2 = buf[k + 1].im; 3573 3574 immutable upperIndex1 = k + halfLen; 3575 immutable upperIndex2 = upperIndex1 + 1; 3576 immutable realUpper1 = buf[upperIndex1].re; 3577 immutable imagUpper1 = buf[upperIndex1].im; 3578 immutable realUpper2 = buf[upperIndex2].re; 3579 immutable imagUpper2 = buf[upperIndex2].im; 3580 3581 immutable realAdd1 = cosTwiddle1 * realUpper1 3582 - sinTwiddle1 * imagUpper1; 3583 immutable imagAdd1 = sinTwiddle1 * realUpper1 3584 + cosTwiddle1 * imagUpper1; 3585 immutable realAdd2 = cosTwiddle2 * realUpper2 3586 - sinTwiddle2 * imagUpper2; 3587 immutable imagAdd2 = sinTwiddle2 * realUpper2 3588 + cosTwiddle2 * imagUpper2; 3589 3590 buf[k].re += realAdd1; 3591 buf[k].im += imagAdd1; 3592 buf[k + 1].re += realAdd2; 3593 buf[k + 1].im += imagAdd2; 3594 3595 buf[upperIndex1].re = realLower1 - realAdd1; 3596 buf[upperIndex1].im = imagLower1 - imagAdd1; 3597 buf[upperIndex2].re = realLower2 - realAdd2; 3598 buf[upperIndex2].im = imagLower2 - imagAdd2; 3599 } 3600 } 3601 3602 // This constructor is used within this module for allocating the 3603 // buffer space elsewhere besides the GC heap. It's definitely **NOT** 3604 // part of the public API and definitely **IS** subject to change. 3605 // 3606 // Also, this is unsafe because the memSpace buffer will be cast 3607 // to immutable. 3608 // 3609 // Public b/c of https://issues.dlang.org/show_bug.cgi?id=4636. 3610 public this(lookup_t[] memSpace) 3611 { 3612 immutable size = memSpace.length / 2; 3613 3614 /* Create a lookup table of all negative sine values at a resolution of 3615 * size and all smaller power of two resolutions. This may seem 3616 * inefficient, but having all the lookups be next to each other in 3617 * memory at every level of iteration is a huge win performance-wise. 3618 */ 3619 if (size == 0) 3620 { 3621 return; 3622 } 3623 3624 assert(isPowerOf2(size), 3625 "Can only do FFTs on ranges with a size that is a power of two."); 3626 3627 auto table = new lookup_t[][bsf(size) + 1]; 3628 3629 table[$ - 1] = memSpace[$ - size..$]; 3630 memSpace = memSpace[0 .. size]; 3631 3632 auto lastRow = table[$ - 1]; 3633 lastRow[0] = 0; // -sin(0) == 0. 3634 foreach (ptrdiff_t i; 1 .. size) 3635 { 3636 // The hard coded cases are for improved accuracy and to prevent 3637 // annoying non-zeroness when stuff should be zero. 3638 3639 if (i == size / 4) 3640 lastRow[i] = -1; // -sin(pi / 2) == -1. 3641 else if (i == size / 2) 3642 lastRow[i] = 0; // -sin(pi) == 0. 3643 else if (i == size * 3 / 4) 3644 lastRow[i] = 1; // -sin(pi * 3 / 2) == 1 3645 else 3646 lastRow[i] = -sin(i * 2.0L * PI / size); 3647 } 3648 3649 // Fill in all the other rows with strided versions. 3650 foreach (i; 1 .. table.length - 1) 3651 { 3652 immutable strideLength = size / (2 ^^ i); 3653 auto strided = Stride!(lookup_t[])(lastRow, strideLength); 3654 table[i] = memSpace[$ - strided.length..$]; 3655 memSpace = memSpace[0..$ - strided.length]; 3656 3657 size_t copyIndex; 3658 foreach (elem; strided) 3659 { 3660 table[i][copyIndex++] = elem; 3661 } 3662 } 3663 3664 negSinLookup = cast(immutable) table; 3665 } 3666 3667 public: 3668 /**Create an `Fft` object for computing fast Fourier transforms of 3669 * power of two sizes of `size` or smaller. `size` must be a 3670 * power of two. 3671 */ 3672 this(size_t size) 3673 { 3674 // Allocate all twiddle factor buffers in one contiguous block so that, 3675 // when one is done being used, the next one is next in cache. 3676 auto memSpace = uninitializedArray!(lookup_t[])(2 * size); 3677 this(memSpace); 3678 } 3679 3680 @property size_t size() const 3681 { 3682 return (negSinLookup is null) ? 0 : negSinLookup[$ - 1].length; 3683 } 3684 3685 /**Compute the Fourier transform of range using the $(BIGOH N log N) 3686 * Cooley-Tukey Algorithm. `range` must be a random-access range with 3687 * slicing and a length equal to `size` as provided at the construction of 3688 * this object. The contents of range can be either numeric types, 3689 * which will be interpreted as pure real values, or complex types with 3690 * properties or members `.re` and `.im` that can be read. 3691 * 3692 * Note: Pure real FFTs are automatically detected and the relevant 3693 * optimizations are performed. 3694 * 3695 * Returns: An array of complex numbers representing the transformed data in 3696 * the frequency domain. 3697 * 3698 * Conventions: The exponent is negative and the factor is one, 3699 * i.e., output[j] := sum[ exp(-2 PI i j k / N) input[k] ]. 3700 */ 3701 Complex!F[] fft(F = double, R)(R range) const 3702 if (isFloatingPoint!F && isRandomAccessRange!R) 3703 { 3704 enforceSize(range); 3705 Complex!F[] ret; 3706 if (range.length == 0) 3707 { 3708 return ret; 3709 } 3710 3711 // Don't waste time initializing the memory for ret. 3712 ret = uninitializedArray!(Complex!F[])(range.length); 3713 3714 fft(range, ret); 3715 return ret; 3716 } 3717 3718 /**Same as the overload, but allows for the results to be stored in a user- 3719 * provided buffer. The buffer must be of the same length as range, must be 3720 * a random-access range, must have slicing, and must contain elements that are 3721 * complex-like. This means that they must have a .re and a .im member or 3722 * property that can be both read and written and are floating point numbers. 3723 */ 3724 void fft(Ret, R)(R range, Ret buf) const 3725 if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret) 3726 { 3727 assert(buf.length == range.length); 3728 enforceSize(range); 3729 3730 if (range.length == 0) 3731 { 3732 return; 3733 } 3734 else if (range.length == 1) 3735 { 3736 buf[0] = range[0]; 3737 return; 3738 } 3739 else if (range.length == 2) 3740 { 3741 slowFourier2(range, buf); 3742 return; 3743 } 3744 else 3745 { 3746 alias E = ElementType!R; 3747 static if (is(E : real)) 3748 { 3749 return fftImplPureReal(range, buf); 3750 } 3751 else 3752 { 3753 static if (is(R : Stride!R)) 3754 return fftImpl(range, buf); 3755 else 3756 return fftImpl(Stride!R(range, 1), buf); 3757 } 3758 } 3759 } 3760 3761 /** 3762 * Computes the inverse Fourier transform of a range. The range must be a 3763 * random access range with slicing, have a length equal to the size 3764 * provided at construction of this object, and contain elements that are 3765 * either of type std.complex.Complex or have essentially 3766 * the same compile-time interface. 3767 * 3768 * Returns: The time-domain signal. 3769 * 3770 * Conventions: The exponent is positive and the factor is 1/N, i.e., 3771 * output[j] := (1 / N) sum[ exp(+2 PI i j k / N) input[k] ]. 3772 */ 3773 Complex!F[] inverseFft(F = double, R)(R range) const 3774 if (isRandomAccessRange!R && isComplexLike!(ElementType!R) && isFloatingPoint!F) 3775 { 3776 enforceSize(range); 3777 Complex!F[] ret; 3778 if (range.length == 0) 3779 { 3780 return ret; 3781 } 3782 3783 // Don't waste time initializing the memory for ret. 3784 ret = uninitializedArray!(Complex!F[])(range.length); 3785 3786 inverseFft(range, ret); 3787 return ret; 3788 } 3789 3790 /** 3791 * Inverse FFT that allows a user-supplied buffer to be provided. The buffer 3792 * must be a random access range with slicing, and its elements 3793 * must be some complex-like type. 3794 */ 3795 void inverseFft(Ret, R)(R range, Ret buf) const 3796 if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret) 3797 { 3798 enforceSize(range); 3799 3800 auto swapped = map!swapRealImag(range); 3801 fft(swapped, buf); 3802 3803 immutable lenNeg1 = 1.0 / buf.length; 3804 foreach (ref elem; buf) 3805 { 3806 immutable temp = elem.re * lenNeg1; 3807 elem.re = elem.im * lenNeg1; 3808 elem.im = temp; 3809 } 3810 } 3811 } 3812 3813 // This mixin creates an Fft object in the scope it's mixed into such that all 3814 // memory owned by the object is deterministically destroyed at the end of that 3815 // scope. 3816 private enum string MakeLocalFft = q{ 3817 import core.stdc.stdlib; 3818 import core.exception : onOutOfMemoryError; 3819 3820 auto lookupBuf = (cast(lookup_t*) malloc(range.length * 2 * lookup_t.sizeof)) 3821 [0 .. 2 * range.length]; 3822 if (!lookupBuf.ptr) 3823 onOutOfMemoryError(); 3824 3825 scope(exit) free(cast(void*) lookupBuf.ptr); 3826 auto fftObj = scoped!Fft(lookupBuf); 3827 }; 3828 3829 /**Convenience functions that create an `Fft` object, run the FFT or inverse 3830 * FFT and return the result. Useful for one-off FFTs. 3831 * 3832 * Note: In addition to convenience, these functions are slightly more 3833 * efficient than manually creating an Fft object for a single use, 3834 * as the Fft object is deterministically destroyed before these 3835 * functions return. 3836 */ 3837 Complex!F[] fft(F = double, R)(R range) 3838 { 3839 mixin(MakeLocalFft); 3840 return fftObj.fft!(F, R)(range); 3841 } 3842 3843 /// ditto 3844 void fft(Ret, R)(R range, Ret buf) 3845 { 3846 mixin(MakeLocalFft); 3847 return fftObj.fft!(Ret, R)(range, buf); 3848 } 3849 3850 /// ditto 3851 Complex!F[] inverseFft(F = double, R)(R range) 3852 { 3853 mixin(MakeLocalFft); 3854 return fftObj.inverseFft!(F, R)(range); 3855 } 3856 3857 /// ditto 3858 void inverseFft(Ret, R)(R range, Ret buf) 3859 { 3860 mixin(MakeLocalFft); 3861 return fftObj.inverseFft!(Ret, R)(range, buf); 3862 } 3863 3864 @system unittest 3865 { 3866 import std.algorithm; 3867 import std.conv; 3868 import std.range; 3869 // Test values from R and Octave. 3870 auto arr = [1,2,3,4,5,6,7,8]; 3871 auto fft1 = fft(arr); 3872 assert(isClose(map!"a.re"(fft1), 3873 [36.0, -4, -4, -4, -4, -4, -4, -4], 1e-4)); 3874 assert(isClose(map!"a.im"(fft1), 3875 [0, 9.6568, 4, 1.6568, 0, -1.6568, -4, -9.6568], 1e-4)); 3876 3877 auto fft1Retro = fft(retro(arr)); 3878 assert(isClose(map!"a.re"(fft1Retro), 3879 [36.0, 4, 4, 4, 4, 4, 4, 4], 1e-4)); 3880 assert(isClose(map!"a.im"(fft1Retro), 3881 [0, -9.6568, -4, -1.6568, 0, 1.6568, 4, 9.6568], 1e-4)); 3882 3883 auto fft1Float = fft(to!(float[])(arr)); 3884 assert(isClose(map!"a.re"(fft1), map!"a.re"(fft1Float))); 3885 assert(isClose(map!"a.im"(fft1), map!"a.im"(fft1Float))); 3886 3887 alias C = Complex!float; 3888 auto arr2 = [C(1,2), C(3,4), C(5,6), C(7,8), C(9,10), 3889 C(11,12), C(13,14), C(15,16)]; 3890 auto fft2 = fft(arr2); 3891 assert(isClose(map!"a.re"(fft2), 3892 [64.0, -27.3137, -16, -11.3137, -8, -4.6862, 0, 11.3137], 1e-4)); 3893 assert(isClose(map!"a.im"(fft2), 3894 [72, 11.3137, 0, -4.686, -8, -11.3137, -16, -27.3137], 1e-4)); 3895 3896 auto inv1 = inverseFft(fft1); 3897 assert(isClose(map!"a.re"(inv1), arr, 1e-6)); 3898 assert(reduce!max(map!"a.im"(inv1)) < 1e-10); 3899 3900 auto inv2 = inverseFft(fft2); 3901 assert(isClose(map!"a.re"(inv2), map!"a.re"(arr2))); 3902 assert(isClose(map!"a.im"(inv2), map!"a.im"(arr2))); 3903 3904 // FFTs of size 0, 1 and 2 are handled as special cases. Test them here. 3905 ushort[] empty; 3906 assert(fft(empty) == null); 3907 assert(inverseFft(fft(empty)) == null); 3908 3909 real[] oneElem = [4.5L]; 3910 auto oneFft = fft(oneElem); 3911 assert(oneFft.length == 1); 3912 assert(oneFft[0].re == 4.5L); 3913 assert(oneFft[0].im == 0); 3914 3915 auto oneInv = inverseFft(oneFft); 3916 assert(oneInv.length == 1); 3917 assert(isClose(oneInv[0].re, 4.5)); 3918 assert(isClose(oneInv[0].im, 0, 0.0, 1e-10)); 3919 3920 long[2] twoElems = [8, 4]; 3921 auto twoFft = fft(twoElems[]); 3922 assert(twoFft.length == 2); 3923 assert(isClose(twoFft[0].re, 12)); 3924 assert(isClose(twoFft[0].im, 0, 0.0, 1e-10)); 3925 assert(isClose(twoFft[1].re, 4)); 3926 assert(isClose(twoFft[1].im, 0, 0.0, 1e-10)); 3927 auto twoInv = inverseFft(twoFft); 3928 assert(isClose(twoInv[0].re, 8)); 3929 assert(isClose(twoInv[0].im, 0, 0.0, 1e-10)); 3930 assert(isClose(twoInv[1].re, 4)); 3931 assert(isClose(twoInv[1].im, 0, 0.0, 1e-10)); 3932 } 3933 3934 // Swaps the real and imaginary parts of a complex number. This is useful 3935 // for inverse FFTs. 3936 C swapRealImag(C)(C input) 3937 { 3938 return C(input.im, input.re); 3939 } 3940 3941 /** This function transforms `decimal` value into a value in the factorial number 3942 system stored in `fac`. 3943 3944 A factorial number is constructed as: 3945 $(D fac[0] * 0! + fac[1] * 1! + ... fac[20] * 20!) 3946 3947 Params: 3948 decimal = The decimal value to convert into the factorial number system. 3949 fac = The array to store the factorial number. The array is of size 21 as 3950 `ulong.max` requires 21 digits in the factorial number system. 3951 Returns: 3952 A variable storing the number of digits of the factorial number stored in 3953 `fac`. 3954 */ 3955 size_t decimalToFactorial(ulong decimal, ref ubyte[21] fac) 3956 @safe pure nothrow @nogc 3957 { 3958 import std.algorithm.mutation : reverse; 3959 size_t idx; 3960 3961 for (ulong i = 1; decimal != 0; ++i) 3962 { 3963 auto temp = decimal % i; 3964 decimal /= i; 3965 fac[idx++] = cast(ubyte)(temp); 3966 } 3967 3968 if (idx == 0) 3969 { 3970 fac[idx++] = cast(ubyte) 0; 3971 } 3972 3973 reverse(fac[0 .. idx]); 3974 3975 // first digit of the number in factorial will always be zero 3976 assert(fac[idx - 1] == 0); 3977 3978 return idx; 3979 } 3980 3981 /// 3982 @safe pure @nogc unittest 3983 { 3984 ubyte[21] fac; 3985 size_t idx = decimalToFactorial(2982, fac); 3986 3987 assert(fac[0] == 4); 3988 assert(fac[1] == 0); 3989 assert(fac[2] == 4); 3990 assert(fac[3] == 1); 3991 assert(fac[4] == 0); 3992 assert(fac[5] == 0); 3993 assert(fac[6] == 0); 3994 } 3995 3996 @safe pure unittest 3997 { 3998 ubyte[21] fac; 3999 size_t idx = decimalToFactorial(0UL, fac); 4000 assert(idx == 1); 4001 assert(fac[0] == 0); 4002 4003 fac[] = 0; 4004 idx = 0; 4005 idx = decimalToFactorial(ulong.max, fac); 4006 assert(idx == 21); 4007 auto t = [7, 11, 12, 4, 3, 15, 3, 5, 3, 5, 0, 8, 3, 5, 0, 0, 0, 2, 1, 1, 0]; 4008 foreach (i, it; fac[0 .. 21]) 4009 { 4010 assert(it == t[i]); 4011 } 4012 4013 fac[] = 0; 4014 idx = decimalToFactorial(2982, fac); 4015 4016 assert(idx == 7); 4017 t = [4, 0, 4, 1, 0, 0, 0]; 4018 foreach (i, it; fac[0 .. idx]) 4019 { 4020 assert(it == t[i]); 4021 } 4022 } 4023 4024 private: 4025 // The reasons I couldn't use std.algorithm were b/c its stride length isn't 4026 // modifiable on the fly and because range has grown some performance hacks 4027 // for powers of 2. 4028 struct Stride(R) 4029 { 4030 import core.bitop : bsf; 4031 Unqual!R range; 4032 size_t _nSteps; 4033 size_t _length; 4034 alias E = ElementType!(R); 4035 4036 this(R range, size_t nStepsIn) 4037 { 4038 this.range = range; 4039 _nSteps = nStepsIn; 4040 _length = (range.length + _nSteps - 1) / nSteps; 4041 } 4042 4043 size_t length() const @property 4044 { 4045 return _length; 4046 } 4047 4048 typeof(this) save() @property 4049 { 4050 auto ret = this; 4051 ret.range = ret.range.save; 4052 return ret; 4053 } 4054 4055 E opIndex(size_t index) 4056 { 4057 return range[index * _nSteps]; 4058 } 4059 4060 E front() @property 4061 { 4062 return range[0]; 4063 } 4064 4065 void popFront() 4066 { 4067 if (range.length >= _nSteps) 4068 { 4069 range = range[_nSteps .. range.length]; 4070 _length--; 4071 } 4072 else 4073 { 4074 range = range[0 .. 0]; 4075 _length = 0; 4076 } 4077 } 4078 4079 // Pops half the range's stride. 4080 void popHalf() 4081 { 4082 range = range[_nSteps / 2 .. range.length]; 4083 } 4084 4085 bool empty() const @property 4086 { 4087 return length == 0; 4088 } 4089 4090 size_t nSteps() const @property 4091 { 4092 return _nSteps; 4093 } 4094 4095 void doubleSteps() 4096 { 4097 _nSteps *= 2; 4098 _length /= 2; 4099 } 4100 4101 size_t nSteps(size_t newVal) @property 4102 { 4103 _nSteps = newVal; 4104 4105 // Using >> bsf(nSteps) is a few cycles faster than / nSteps. 4106 _length = (range.length + _nSteps - 1) >> bsf(nSteps); 4107 return newVal; 4108 } 4109 } 4110 4111 // Hard-coded base case for FFT of size 2. This is actually a TON faster than 4112 // using a generic slow DFT. This seems to be the best base case. (Size 1 4113 // can be coded inline as buf[0] = range[0]). 4114 void slowFourier2(Ret, R)(R range, Ret buf) 4115 { 4116 assert(range.length == 2); 4117 assert(buf.length == 2); 4118 buf[0] = range[0] + range[1]; 4119 buf[1] = range[0] - range[1]; 4120 } 4121 4122 // Hard-coded base case for FFT of size 4. Doesn't work as well as the size 4123 // 2 case. 4124 void slowFourier4(Ret, R)(R range, Ret buf) 4125 { 4126 alias C = ElementType!Ret; 4127 4128 assert(range.length == 4); 4129 assert(buf.length == 4); 4130 buf[0] = range[0] + range[1] + range[2] + range[3]; 4131 buf[1] = range[0] - range[1] * C(0, 1) - range[2] + range[3] * C(0, 1); 4132 buf[2] = range[0] - range[1] + range[2] - range[3]; 4133 buf[3] = range[0] + range[1] * C(0, 1) - range[2] - range[3] * C(0, 1); 4134 } 4135 4136 N roundDownToPowerOf2(N)(N num) 4137 if (isScalarType!N && !isFloatingPoint!N) 4138 { 4139 import core.bitop : bsr; 4140 return num & (cast(N) 1 << bsr(num)); 4141 } 4142 4143 @safe unittest 4144 { 4145 assert(roundDownToPowerOf2(7) == 4); 4146 assert(roundDownToPowerOf2(4) == 4); 4147 } 4148 4149 template isComplexLike(T) 4150 { 4151 enum bool isComplexLike = is(typeof(T.init.re)) && 4152 is(typeof(T.init.im)); 4153 } 4154 4155 @safe unittest 4156 { 4157 static assert(isComplexLike!(Complex!double)); 4158 static assert(!isComplexLike!(uint)); 4159 }