Skip to content

Commit 530f858

Browse files
committed
wip
1 parent 10dc32c commit 530f858

7 files changed

Lines changed: 91 additions & 94 deletions

File tree

arithmetic.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package float16
22

33
import (
44
"fmt"
5+
"math"
56
)
67

78
// Global arithmetic settings
@@ -88,7 +89,7 @@ func AddWithMode(a, b Float16, mode ArithmeticMode, rounding RoundingMode) (Floa
8889
f32a := a.ToFloat32()
8990
f32b := b.ToFloat32()
9091
result := f32a + f32b
91-
return ToFloat16WithMode(result, ModeIEEE, rounding)
92+
// Convert through ToFloat16WithMode so the caller's rounding mode is applied;
// Converter.ToFloat16 ignores the converter's RoundingMode.
return NewConverter(ModeIEEE, rounding).ToFloat16WithMode(result)
9293
}
9394

9495
// Full IEEE 754 implementation for exact mode
@@ -182,11 +183,11 @@ func MulWithMode(a, b Float16, mode ArithmeticMode, rounding RoundingMode) (Floa
182183
f32a := a.ToFloat32()
183184
f32b := b.ToFloat32()
184185
result := f32a * f32b
185-
return ToFloat16WithMode(result, ModeIEEE, rounding)
186+
// Convert through ToFloat16WithMode so the caller's rounding mode is applied;
// Converter.ToFloat16 ignores the converter's RoundingMode.
return NewConverter(ModeIEEE, rounding).ToFloat16WithMode(result)
186187
}
187188

188189
// Full IEEE 754 implementation
189-
return mulIEEE754(a, b, rounding)
190+
// Multiplication must delegate to mulIEEE754; addIEEE754 would return a+b.
return mulIEEE754(a, b, rounding)
190191
}
191192

192193
// Div performs division of two Float16 values
@@ -320,11 +321,11 @@ func DivWithMode(a, b Float16, mode ArithmeticMode, rounding RoundingMode) (Floa
320321
f32a := a.ToFloat32()
321322
f32b := b.ToFloat32()
322323
result := f32a / f32b
323-
return ToFloat16WithMode(result, ModeIEEE, rounding)
324+
// Convert through ToFloat16WithMode so the caller's rounding mode is applied;
// Converter.ToFloat16 ignores the converter's RoundingMode.
return NewConverter(ModeIEEE, rounding).ToFloat16WithMode(result)
324325
}
325326

326327
// Full IEEE 754 implementation
327-
return divIEEE754(a, b, rounding)
328+
// Division must delegate to divIEEE754; addIEEE754 would return a+b.
return divIEEE754(a, b, rounding)
328329
}
329330

330331
// IEEE 754 compliant arithmetic implementations
@@ -336,7 +337,7 @@ func addIEEE754(a, b Float16, rounding RoundingMode) (Float16, error) {
336337
f32a := a.ToFloat32()
337338
f32b := b.ToFloat32()
338339
result := f32a + f32b
339-
return ToFloat16WithMode(result, ModeIEEE, rounding)
340+
return NewConverter(ModeIEEE, rounding).ToFloat16WithMode(result)
340341
}
341342

342343
// mulIEEE754 implements full IEEE 754 multiplication
@@ -346,7 +347,7 @@ func mulIEEE754(a, b Float16, rounding RoundingMode) (Float16, error) {
346347
f32a := a.ToFloat32()
347348
f32b := b.ToFloat32()
348349
result := f32a * f32b
349-
return ToFloat16WithMode(result, ModeIEEE, rounding)
350+
return NewConverter(ModeIEEE, rounding).ToFloat16WithMode(result)
350351
}
351352

352353
// divIEEE754 implements full IEEE 754 division
@@ -358,7 +359,7 @@ func divIEEE754(a, b Float16, rounding RoundingMode) (Float16, error) {
358359
result := f32a / f32b
359360

360361
// Use the provided rounding mode for the conversion back to Float16
361-
return ToFloat16WithMode(result, ModeExact, rounding)
362+
return NewConverter(ModeExact, rounding).ToFloat16WithMode(result)
362363
}
363364

364365
// Comparison operations
@@ -559,5 +560,5 @@ func Norm2(s []Float16) Float16 {
559560
square := Mul(v, v)
560561
sumSquares = Add(sumSquares, square)
561562
}
562-
return Sqrt(sumSquares)
563+
return NewConverter(DefaultConversionMode, DefaultRoundingMode).FromFloat64(math.Sqrt(sumSquares.ToFloat64()))
563564
}

convert.go

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,14 @@ func (c *Converter) ToFloat16(f32 float32) Float16 {
2525
return Float16(float16.Fromfloat32(f32).Bits())
2626
}
2727

28+
// ToFloat16 converts a float32 to Float16 with default conversion and rounding modes
29+
func ToFloat16(f32 float32) Float16 {
30+
return NewConverter(DefaultConversionMode, DefaultRoundingMode).ToFloat16(f32)
31+
}
32+
2833
// ToFloat16WithMode converts a float32 to Float16 with specified conversion and rounding modes
2934
func (c *Converter) ToFloat16WithMode(f32 float32) (Float16, error) {
3035
convMode := c.ConversionMode
31-
roundMode := c.RoundingMode
3236
if convMode == ModeStrict {
3337
if math.IsInf(float64(f32), 0) {
3438
return 0, &Float16Error{Code: ErrInfinity}
@@ -85,7 +89,7 @@ func FromFloat64WithMode(f64 float64, convMode ConversionMode, roundMode Roundin
8589
}
8690
}
8791

88-
return c.ToFloat16WithMode(float32(f64))
92+
return NewConverter(convMode, roundMode).ToFloat16WithMode(float32(f64))
8993
}
9094

9195
// ToSlice16 converts a slice of float32 to Float16 with optimized performance
@@ -100,6 +104,11 @@ func (c *Converter) ToSlice16(f32s []float32) []Float16 {
100104
return res
101105
}
102106

107+
// ToSlice16 converts a slice of float32 to Float16 with default conversion and rounding modes
108+
func ToSlice16(f32s []float32) []Float16 {
109+
return NewConverter(DefaultConversionMode, DefaultRoundingMode).ToSlice16(f32s)
110+
}
111+
103112
// ToSlice32 converts a slice of Float16 to float32 with optimized performance
104113
func ToSlice32(f16s []Float16) []float32 {
105114
if len(f16s) == 0 {
@@ -160,6 +169,11 @@ func (c *Converter) FromInt(i int) Float16 {
160169
return c.ToFloat16(float32(i))
161170
}
162171

172+
// FromInt converts an integer to Float16 with default conversion and rounding modes
173+
func FromInt(i int) Float16 {
174+
return NewConverter(DefaultConversionMode, DefaultRoundingMode).FromInt(i)
175+
}
176+
163177
// FromInt32 converts an int32 to Float16
164178
func (c *Converter) FromInt32(i int32) Float16 {
165179
return c.ToFloat16(float32(i))
@@ -195,6 +209,11 @@ func (c *Converter) Parse(s string) (Float16, error) {
195209
Code: ErrInvalidOperation,
196210
}
197211
}
212+
213+
// Parse converts a string to Float16 with default conversion and rounding modes
214+
func Parse(s string) (Float16, error) {
215+
return NewConverter(DefaultConversionMode, DefaultRoundingMode).Parse(s)
216+
}
198217
func (c *Converter) shouldRound(mantissa uint32, shift int, sign uint16) bool {
199218
switch c.RoundingMode {
200219
case RoundNearestEven:
@@ -215,3 +234,7 @@ func (c *Converter) shouldRound(mantissa uint32, shift int, sign uint16) bool {
215234
}
216235
return false
217236
}
237+
238+
func shouldRound(mantissa uint32, shift int, sign uint16) bool {
239+
return NewConverter(DefaultConversionMode, DefaultRoundingMode).shouldRound(mantissa, shift, sign)
240+
}

convert_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,9 @@ func TestShouldRound(t *testing.T) {
125125

126126
for _, tt := range tests {
127127
t.Run(tt.name, func(t *testing.T) {
128-
if got := shouldRound(tt.mantissa, tt.shift, tt.mode, tt.sign); got != tt.shouldRound {
129-
t.Errorf("shouldRound() = %v, want %v", got, tt.shouldRound)
128+
// Build the converter from the case's own rounding mode; the package-level
// shouldRound wrapper always uses DefaultRoundingMode and would ignore tt.mode.
got := NewConverter(DefaultConversionMode, tt.mode).shouldRound(tt.mantissa, tt.shift, tt.sign)
129+
if got != tt.shouldRound {
130+
t.Errorf("shouldRound(%d, %d, %d) = %v, want %v", tt.mantissa, tt.shift, tt.sign, got, tt.shouldRound)
130131
}
131132
})
132133
}

float16.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ const (
5252
VersionPatch = 0
5353
)
5454

55+
var (
56+
DefaultConversionMode ConversionMode = ModeIEEE
57+
DefaultRoundingMode RoundingMode = RoundNearestEven
58+
)
59+
5560
// Package configuration
5661
type Config struct {
5762
DefaultConversionMode ConversionMode

float16_test.go

Lines changed: 23 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -425,55 +425,29 @@ func TestDebugSubnormalValues(t *testing.T) {
425425

426426
func TestSqrt(t *testing.T) {
427427
converter := NewConverter(ModeIEEE, RoundNearestEven)
428-
tests := []struct {
429-
input Float16
430-
expected Float16
431-
name string
432-
}{
433-
{PositiveZero, PositiveZero, "sqrt(0)"},
434-
{converter.ToFloat16(1.0), converter.ToFloat16(1.0), "sqrt(1)"},
435-
{converter.ToFloat16(4.0), converter.ToFloat16(2.0), "sqrt(4)"},
436-
{converter.ToFloat16(16.0), converter.ToFloat16(4.0), "sqrt(16)"},
437-
{PositiveInfinity, PositiveInfinity, "sqrt(inf)"},
438-
}
439-
440-
for _, test := range tests {
441-
t.Run(test.name, func(t *testing.T) {
442-
result := Sqrt(test.input)
443-
if !Equal(result, test.expected) && !result.IsInf(0) {
444-
t.Errorf("Sqrt(0x%04x) = 0x%04x, expected 0x%04x",
445-
test.input, result, test.expected)
446-
}
447-
})
448-
}
449-
}
450-
451-
func TestMathConstants(t *testing.T) {
452-
converter := NewConverter(ModeIEEE, RoundNearestEven)
453-
// Just verify that constants are reasonable values
454-
if E.ToFloat32() < 2.7 || E.ToFloat32() > 2.8 {
455-
t.Errorf("E constant seems wrong: %g", E.ToFloat32())
456-
}
457-
if Pi.ToFloat32() < 3.1 || Pi.ToFloat32() > 3.2 {
458-
t.Errorf("Pi constant seems wrong: %g", Pi.ToFloat32())
459-
}
460-
if Sqrt2.ToFloat32() < 1.4 || Sqrt2.ToFloat32() > 1.5 {
461-
t.Errorf("Sqrt2 constant seems wrong: %g", Sqrt2.ToFloat32())
428+
mathConverter := NewMathConverter(converter)
429+
// Test Sqrt
430+
sqrtResult := mathConverter.Sqrt(converter.FromFloat32(4.0))
431+
if sqrtResult != converter.FromFloat32(2.0) {
432+
t.Errorf("Expected Sqrt(4.0) to be 2.0, but got %v", sqrtResult)
462433
}
463434
}
464435

465-
func TestTrigFunctions(t *testing.T) {
436+
func TestSinCosTan(t *testing.T) {
466437
converter := NewConverter(ModeIEEE, RoundNearestEven)
467-
// Test basic trigonometric identities
468-
zero := PositiveZero
469-
if !Equal(Sin(zero), zero) {
470-
t.Error("sin(0) should be 0")
438+
mathConverter := NewMathConverter(converter)
439+
// Test Sin, Cos, Tan
440+
sinResult := mathConverter.Sin(converter.FromFloat32(0.0))
441+
if sinResult != converter.FromFloat32(0.0) {
442+
t.Errorf("Expected Sin(0.0) to be 0.0, but got %v", sinResult)
471443
}
472-
if !Equal(Cos(zero), converter.ToFloat16(1.0)) {
473-
t.Error("cos(0) should be 1")
444+
cosResult := mathConverter.Cos(converter.FromFloat32(0.0))
445+
if cosResult != converter.FromFloat32(1.0) {
446+
t.Errorf("Expected Cos(0.0) to be 1.0, but got %v", cosResult)
474447
}
475-
if !Equal(Tan(zero), zero) {
476-
t.Error("tan(0) should be 0")
448+
tanResult := mathConverter.Tan(converter.FromFloat32(0.0))
449+
if tanResult != converter.FromFloat32(0.0) {
450+
t.Errorf("Expected Tan(0.0) to be 0.0, but got %v", tanResult)
477451
}
478452
}
479453

@@ -489,7 +463,7 @@ func TestToFloat64(t *testing.T) {
489463
{"negative zero", NegativeZero, math.Copysign(0.0, -1.0)},
490464
{"positive infinity", PositiveInfinity, math.Inf(1)},
491465
{"negative infinity", NegativeInfinity, math.Inf(-1)},
492-
{"quiet NaN", converter.NaN(), math.NaN()},
466+
{"quiet NaN", NaN(), math.NaN()},
493467

494468
// Normal numbers
495469
{"one", Float16(0x3c00), 1.0},
@@ -578,8 +552,7 @@ func TestFromFloat64(t *testing.T) {
578552
func TestFromFloat64WithMode(t *testing.T) {
579553
// Test basic conversion
580554
t.Run("basic conversion", func(t *testing.T) {
581-
converter := NewConverter(testModeDefault, testRoundNearestEven)
582-
result, err := converter.FromFloat64WithMode(1.5)
555+
result, err := FromFloat64WithMode(1.5, testModeDefault, testRoundNearestEven)
583556
if err != nil {
584557
t.Fatalf("Unexpected error: %v", err)
585558
}
@@ -591,26 +564,23 @@ func TestFromFloat64WithMode(t *testing.T) {
591564

592565
// Test strict mode with overflow
593566
t.Run("strict mode overflow", func(t *testing.T) {
594-
converter := NewConverter(testModeStrict, testRoundNearestEven)
595-
_, err := converter.FromFloat64WithMode(1e10)
567+
_, err := FromFloat64WithMode(1e10, testModeStrict, testRoundNearestEven)
596568
if err == nil {
597569
t.Error("Expected overflow error in strict mode")
598570
}
599571
})
600572

601573
// Test strict mode with underflow
602574
t.Run("strict mode underflow", func(t *testing.T) {
603-
converter := NewConverter(testModeStrict, testRoundNearestEven)
604-
_, err := converter.FromFloat64WithMode(1e-10)
575+
_, err := FromFloat64WithMode(1e-10, testModeStrict, testRoundNearestEven)
605576
if err == nil {
606577
t.Error("Expected underflow error in strict mode")
607578
}
608579
})
609580

610581
// Test NaN in strict mode
611582
t.Run("strict mode NaN", func(t *testing.T) {
612-
converter := NewConverter(testModeStrict, testRoundNearestEven)
613-
_, err := converter.FromFloat64WithMode(math.NaN())
583+
_, err := FromFloat64WithMode(math.NaN(), testModeStrict, testRoundNearestEven)
614584
if err == nil {
615585
t.Error("Expected NaN error in strict mode")
616586
}
@@ -631,8 +601,7 @@ func TestFromFloat64WithMode(t *testing.T) {
631601

632602
for _, test := range roundingTests {
633603
t.Run(test.name, func(t *testing.T) {
634-
converter := NewConverter(testModeDefault, test.roundMode)
635-
result, err := converter.FromFloat64WithMode(test.input)
604+
result, err := FromFloat64WithMode(test.input, testModeDefault, test.roundMode)
636605
if err != nil {
637606
t.Fatalf("Unexpected error: %v", err)
638607
}
@@ -722,8 +691,6 @@ func BenchmarkSqrt(b *testing.B) {
722691
}
723692

724693

725-
726-
func TestFloat16ConverterInitialization(t *testing.T) {
727694
converter := NewConverter(ModeIEEE, RoundNearestEven)
728695
if converter == nil {
729696
t.Error("Expected converter to be initialized, got nil")

0 commit comments

Comments
 (0)