Skip to content

Commit cdc4fc5

Browse files
committed
feat: implement float16 arithmetic with IEEE 754 conversion and tests
1 parent eff0d28 commit cdc4fc5

16 files changed

Lines changed: 4684 additions & 0 deletions

arithmetic.go

Lines changed: 663 additions & 0 deletions
Large diffs are not rendered by default.

arithmetic_mode_test.go

Lines changed: 337 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,337 @@
1+
package float16
2+
3+
import "testing"
4+
5+
func TestSubWithMode(t *testing.T) {
6+
tests := []struct {
7+
name string
8+
a, b Float16
9+
mode ArithmeticMode
10+
rounding RoundingMode
11+
want Float16
12+
wantErr bool
13+
}{
14+
{
15+
name: "exact mode with valid inputs",
16+
a: FromBits(0x4000), // 2.0
17+
b: FromBits(0x3C00), // 1.0
18+
mode: ModeExactArithmetic,
19+
rounding: RoundNearestEven,
20+
want: FromBits(0x3C00), // 1.0
21+
wantErr: false,
22+
},
23+
{
24+
name: "exact mode with NaN",
25+
a: NaN(),
26+
b: FromBits(0x3C00), // 1.0
27+
mode: ModeExactArithmetic,
28+
rounding: RoundNearestEven,
29+
want: 0,
30+
wantErr: true,
31+
},
32+
{
33+
name: "fast mode with valid inputs",
34+
a: FromBits(0x4200), // 3.0
35+
b: FromBits(0x3C00), // 1.0
36+
mode: ModeFastArithmetic,
37+
rounding: RoundNearestEven,
38+
want: FromBits(0x4000), // 2.0
39+
wantErr: false,
40+
},
41+
{
42+
name: "IEEE mode with valid inputs",
43+
a: FromBits(0x4400), // 4.0
44+
b: FromBits(0x3C00), // 1.0
45+
mode: ModeIEEEArithmetic,
46+
rounding: RoundNearestEven,
47+
want: FromBits(0x4200), // 3.0
48+
wantErr: false,
49+
},
50+
{
51+
name: "infinity minus infinity",
52+
a: PositiveInfinity,
53+
b: PositiveInfinity,
54+
mode: ModeExactArithmetic,
55+
rounding: RoundNearestEven,
56+
want: 0,
57+
wantErr: true,
58+
},
59+
{
60+
name: "IEEE mode with valid inputs",
61+
a: FromBits(0x4400), // 4.0
62+
b: FromBits(0x3C00), // 1.0
63+
mode: ModeIEEEArithmetic,
64+
rounding: RoundNearestEven,
65+
want: FromBits(0x4200), // 3.0
66+
wantErr: false,
67+
},
68+
{
69+
name: "SubWithMode - Negative infinity minus negative infinity",
70+
a: NegativeInfinity,
71+
b: NegativeInfinity,
72+
mode: ModeExactArithmetic, // This operation should only error in exact mode
73+
rounding: RoundNearestEven,
74+
want: QuietNaN,
75+
wantErr: true,
76+
},
77+
{
78+
name: "IEEE mode with negative infinity minus negative infinity",
79+
a: NegativeInfinity,
80+
b: NegativeInfinity,
81+
mode: ModeIEEEArithmetic,
82+
rounding: RoundNearestEven,
83+
want: QuietNaN, // Should return NaN without error in IEEE mode
84+
wantErr: false,
85+
},
86+
}
87+
88+
for _, tt := range tests {
89+
t.Run(tt.name, func(t *testing.T) {
90+
got, err := SubWithMode(tt.a, tt.b, tt.mode, tt.rounding)
91+
if (err != nil) != tt.wantErr {
92+
t.Errorf("SubWithMode() error = %v, wantErr %v", err, tt.wantErr)
93+
return
94+
}
95+
if !tt.wantErr && got != tt.want {
96+
t.Errorf("SubWithMode() = %v, want %v", got, tt.want)
97+
}
98+
})
99+
}
100+
}
101+
102+
func TestMulWithMode(t *testing.T) {
103+
tests := []struct {
104+
name string
105+
a, b Float16
106+
mode ArithmeticMode
107+
rounding RoundingMode
108+
want Float16
109+
wantErr bool
110+
}{
111+
{
112+
name: "exact mode with valid inputs",
113+
a: FromBits(0x4000), // 2.0
114+
b: FromBits(0x4200), // 3.0
115+
mode: ModeExactArithmetic,
116+
rounding: RoundNearestEven,
117+
want: FromBits(0x4600), // 6.0 (0x4600 is 6.0 in float16)
118+
wantErr: false,
119+
},
120+
{
121+
name: "exact mode with NaN",
122+
a: NaN(),
123+
b: FromBits(0x3C00), // 1.0
124+
mode: ModeExactArithmetic,
125+
rounding: RoundNearestEven,
126+
want: 0,
127+
wantErr: true,
128+
},
129+
{
130+
name: "infinity times zero in exact mode",
131+
a: Infinity(1),
132+
b: FromBits(0x0000), // 0.0
133+
mode: ModeExactArithmetic,
134+
rounding: RoundNearestEven,
135+
want: 0,
136+
wantErr: true,
137+
},
138+
{
139+
name: "infinity times zero in IEEE mode",
140+
a: Infinity(1),
141+
b: FromBits(0x0000), // 0.0
142+
mode: ModeIEEEArithmetic,
143+
rounding: RoundNearestEven,
144+
want: QuietNaN,
145+
wantErr: false,
146+
},
147+
{
148+
name: "Zero times infinity in exact mode",
149+
a: PositiveZero,
150+
b: PositiveInfinity,
151+
mode: ModeExactArithmetic,
152+
rounding: RoundNearestEven,
153+
want: 0,
154+
wantErr: true,
155+
},
156+
{
157+
name: "Zero times infinity in IEEE mode",
158+
a: PositiveZero,
159+
b: PositiveInfinity,
160+
mode: ModeIEEEArithmetic,
161+
rounding: RoundNearestEven,
162+
want: QuietNaN,
163+
wantErr: false,
164+
},
165+
}
166+
167+
for _, tt := range tests {
168+
t.Run(tt.name, func(t *testing.T) {
169+
got, err := MulWithMode(tt.a, tt.b, tt.mode, tt.rounding)
170+
if (err != nil) != tt.wantErr {
171+
t.Errorf("MulWithMode() error = %v, wantErr %v", err, tt.wantErr)
172+
return
173+
}
174+
if !tt.wantErr && got != tt.want {
175+
t.Errorf("MulWithMode() = %v, want %v", got, tt.want)
176+
}
177+
})
178+
}
179+
}
180+
181+
func TestDivWithMode(t *testing.T) {
182+
tests := []struct {
183+
name string
184+
a, b Float16
185+
mode ArithmeticMode
186+
rounding RoundingMode
187+
want Float16
188+
wantErr bool
189+
}{
190+
{
191+
name: "exact mode with valid inputs",
192+
a: FromBits(0x4600), // 6.0 (0x4600 = 6.0, 0x4800 = 8.0)
193+
b: FromBits(0x4200), // 3.0
194+
mode: ModeExactArithmetic,
195+
rounding: RoundNearestEven,
196+
want: FromBits(0x4000), // 2.0 (0x4000 = 2.0, 0x4400 = 4.0)
197+
wantErr: false,
198+
},
199+
{
200+
name: "exact mode with NaN",
201+
a: NaN(),
202+
b: FromBits(0x3C00), // 1.0
203+
mode: ModeExactArithmetic,
204+
rounding: RoundNearestEven,
205+
want: 0,
206+
wantErr: true,
207+
},
208+
{
209+
name: "division by zero",
210+
a: FromBits(0x3C00), // 1.0
211+
b: FromBits(0x0000), // 0.0
212+
mode: ModeExactArithmetic,
213+
rounding: RoundNearestEven,
214+
want: 0,
215+
wantErr: true,
216+
},
217+
{
218+
name: "infinity divided by infinity",
219+
a: Infinity(1),
220+
b: Infinity(1),
221+
mode: ModeExactArithmetic,
222+
rounding: RoundNearestEven,
223+
want: 0,
224+
wantErr: true,
225+
},
226+
{
227+
name: "DivWithMode - Division by zero",
228+
a: FromBits(0x3C00),
229+
b: PositiveZero,
230+
mode: ModeIEEEArithmetic,
231+
rounding: RoundNearestEven,
232+
want: PositiveInfinity,
233+
wantErr: false,
234+
},
235+
}
236+
237+
for _, tt := range tests {
238+
t.Run(tt.name, func(t *testing.T) {
239+
got, err := DivWithMode(tt.a, tt.b, tt.mode, tt.rounding)
240+
if (err != nil) != tt.wantErr {
241+
t.Errorf("DivWithMode() error = %v, wantErr %v", err, tt.wantErr)
242+
return
243+
}
244+
if !tt.wantErr && got != tt.want {
245+
t.Errorf("DivWithMode() = %v, want %v", got, tt.want)
246+
}
247+
})
248+
}
249+
}
250+
251+
// Infinity returns positive or negative infinity based on the sign parameter
252+
func Infinity(sign int) Float16 {
253+
if sign >= 0 {
254+
return PositiveInfinity
255+
}
256+
return NegativeInfinity
257+
}
258+
259+
func TestSliceOperationsWithMode(t *testing.T) {
260+
t.Run("SubSlice", func(t *testing.T) {
261+
a := []Float16{FromBits(0x4400), FromBits(0x4500), FromBits(0x4600)} // [4.0, 5.0, 6.0]
262+
b := []Float16{FromBits(0x3C00), FromBits(0x4000), FromBits(0x4200)} // [1.0, 2.0, 3.0]
263+
want := []Float16{FromBits(0x4200), FromBits(0x4200), FromBits(0x4200)} // [3.0, 3.0, 3.0]
264+
got := SubSlice(a, b)
265+
if len(got) != len(want) {
266+
t.Fatalf("SubSlice() length = %d, want %d", len(got), len(want))
267+
}
268+
for i := range got {
269+
if got[i] != want[i] {
270+
t.Errorf("SubSlice()[%d] = %v, want %v", i, got[i], want[i])
271+
}
272+
}
273+
})
274+
275+
t.Run("MulSlice", func(t *testing.T) {
276+
a := []Float16{FromBits(0x3C00), FromBits(0x4000), FromBits(0x4400)} // [1.0, 2.0, 4.0]
277+
b := []Float16{FromBits(0x4400), FromBits(0x4400), FromBits(0x4400)} // [4.0, 4.0, 4.0]
278+
want := []Float16{FromBits(0x4400), FromBits(0x4800), FromBits(0x4C00)} // [4.0, 8.0, 16.0]
279+
got := MulSlice(a, b)
280+
if len(got) != len(want) {
281+
t.Fatalf("MulSlice() length = %d, want %d", len(got), len(want))
282+
}
283+
for i := range got {
284+
if got[i] != want[i] {
285+
t.Errorf("MulSlice()[%d] = %v, want %v", i, got[i], want[i])
286+
}
287+
}
288+
})
289+
290+
t.Run("DivSlice", func(t *testing.T) {
291+
a := []Float16{FromBits(0x4400), FromBits(0x4800), FromBits(0x4C00)} // [4.0, 8.0, 16.0]
292+
b := []Float16{FromBits(0x3C00), FromBits(0x4000), FromBits(0x4400)} // [1.0, 2.0, 4.0]
293+
want := []Float16{FromBits(0x4400), FromBits(0x4400), FromBits(0x4400)} // [4.0, 4.0, 4.0]
294+
got := DivSlice(a, b)
295+
if len(got) != len(want) {
296+
t.Fatalf("DivSlice() length = %d, want %d", len(got), len(want))
297+
}
298+
for i := range got {
299+
if got[i] != want[i] {
300+
t.Errorf("DivSlice()[%d] = %v, want %v", i, got[i], want[i])
301+
}
302+
}
303+
})
304+
305+
t.Run("ScaleSlice", func(t *testing.T) {
306+
s := []Float16{FromBits(0x3C00), FromBits(0x4000), FromBits(0x4200)}
307+
scalar := FromBits(0x4000)
308+
want := []Float16{FromBits(0x4000), FromBits(0x4400), FromBits(0x4600)}
309+
got := ScaleSlice(s, scalar)
310+
if len(got) != len(want) {
311+
t.Fatalf("ScaleSlice() length = %d, want %d", len(got), len(want))
312+
}
313+
for i := range got {
314+
if got[i] != want[i] {
315+
t.Errorf("ScaleSlice()[%d] = %v, want %v", i, got[i], want[i])
316+
}
317+
}
318+
})
319+
320+
t.Run("SumSlice", func(t *testing.T) {
321+
s := []Float16{FromBits(0x3C00), FromBits(0x4000), FromBits(0x4200)}
322+
want := FromBits(0x4600) // 6.0
323+
got := SumSlice(s)
324+
if got != want {
325+
t.Errorf("SumSlice() = %v, want %v", got, want)
326+
}
327+
})
328+
329+
t.Run("Norm2", func(t *testing.T) {
330+
s := []Float16{FromBits(0x4200), FromBits(0x4400)} // 3-4-5 right triangle
331+
want := FromBits(0x4500) // 5.0
332+
got := Norm2(s)
333+
if got != want {
334+
t.Errorf("Norm2() = %v, want %v", got, want)
335+
}
336+
})
337+
}

bitpattern_test.go

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,28 @@
1+
package float16
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestBitPatterns(t *testing.T) {
8+
tests := []struct {
9+
name string
10+
bits uint16
11+
}{
12+
{"1.0", 0x3C00},
13+
{"2.0", 0x4000},
14+
{"4.0", 0x4400},
15+
{"8.0", 0x4800},
16+
{"16.0", 0x4C00},
17+
{"32.0", 0x5000},
18+
{"0x3136", 0x3136},
19+
{"0x3332", 0x3332},
20+
}
21+
22+
for _, tt := range tests {
23+
t.Run(tt.name, func(t *testing.T) {
24+
f := FromBits(tt.bits)
25+
_ = f // Use the value to prevent unused variable warning
26+
})
27+
}
28+
}

0 commit comments

Comments
 (0)