Skip to content

Commit 43a7429

Browse files
committed
Refactor Byte Reader and decoder
1 parent e3c46bb commit 43a7429

5 files changed

Lines changed: 149 additions & 165 deletions

File tree

decoder.go

Lines changed: 93 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,71 @@
11
package bencode
22

33
import (
4-
"bytes"
5-
"errors"
64
"fmt"
7-
"io"
85
"reflect"
96
"strconv"
107
)
118

12-
type reader struct {
13-
*bytes.Reader
9+
type decoder struct {
10+
reader *reader
1411
}
1512

16-
func newReader(data []byte) *reader {
17-
return &reader{bytes.NewReader(data)}
13+
func (d *decoder) init(data []byte) {
14+
d.reader = newReader(data)
1815
}
1916

20-
func (r *reader) ReadUntil(c byte) ([]byte, error) {
21-
res := []byte("")
22-
for {
23-
b, err := r.ReadByte()
24-
if err != nil {
25-
if errors.Is(err, io.EOF) {
26-
break
27-
}
28-
29-
return []byte{}, err
30-
}
31-
if b == c {
32-
break
17+
func (d *decoder) setArray(rv reflect.Value, values []interface{}) error {
18+
indirect := reflect.Indirect(rv)
19+
for k, value := range values {
20+
switch decodedValue := value.(type) {
21+
case string:
22+
indirect.Set(reflect.Append(indirect, reflect.ValueOf(decodedValue)))
23+
case []interface{}:
24+
n := reflect.New(indirect.Type().Elem())
25+
indirect.Set(reflect.Append(indirect, reflect.Indirect(n)))
26+
d.setArray(indirect.Index(k).Addr(), decodedValue)
3327
}
34-
res = append(res, b)
3528
}
36-
return res, nil
29+
return nil
3730
}
3831

39-
func (r *reader) readNBytes(n uint64) ([]byte, error) {
40-
res := []byte("")
41-
var i uint64
42-
for i = 0; i < n; i++ {
43-
b, err := r.ReadByte()
44-
if err != nil {
45-
return []byte(""), err
32+
func (d *decoder) setStruct(rv reflect.Value, values map[string]interface{}) error {
33+
indirect := reflect.Indirect(rv)
34+
if rv.Kind() != reflect.Pointer || rv.IsNil() {
35+
return fmt.Errorf("passed object is not a pointer or is nil: %v", reflect.TypeOf(rv))
36+
}
37+
if indirect.Kind() != reflect.Struct {
38+
return fmt.Errorf("passed object is not a struct: %v", reflect.TypeOf(rv))
39+
}
40+
fields := getStructFields(indirect.Type())
41+
for key, value := range values {
42+
if field, ok := fields[key]; ok {
43+
switch decodedValue := value.(type) {
44+
case string:
45+
switch field.Type.Kind() {
46+
case reflect.Uint64:
47+
insertValue, err := stringToUint(decodedValue)
48+
if err != nil {
49+
return err
50+
}
51+
indirect.FieldByName(field.Name).SetUint(insertValue)
52+
case reflect.String:
53+
indirect.FieldByName(field.Name).SetString(decodedValue)
54+
}
55+
case map[string]interface{}:
56+
n := reflect.New(indirect.FieldByName(field.Name).Type())
57+
indirect.FieldByName(field.Name).Set(reflect.Indirect(n))
58+
d.setStruct(indirect.FieldByName(field.Name).Addr(), decodedValue)
59+
case []interface{}:
60+
d.setArray(indirect.FieldByName(field.Name).Addr(), decodedValue)
61+
}
4662
}
47-
res = append(res, b)
4863
}
49-
return res, nil
64+
return nil
5065
}
5166

52-
func (r * reader) readStringLength() (uint64, error) {
53-
// Read the length of string
54-
size, err := r.ReadUntil(':')
67+
func (d *decoder) readStringLength() (uint64, error) {
68+
size, err := d.reader.readUntil(':')
5569
if err != nil {
5670
return 0, err
5771
}
@@ -62,215 +76,138 @@ func (r * reader) readStringLength() (uint64, error) {
6276
return res, nil
6377
}
6478

65-
func (r *reader) ReadString() (string, error) {
66-
size, err := r.readStringLength()
79+
func (d *decoder) readString() (string, error) {
80+
size, err := d.readStringLength()
6781
if err != nil {
6882
return "", err
6983
}
70-
b, err := r.readNBytes(size)
84+
b, err := d.reader.readNBytes(size)
7185
return string(b), err
7286
}
7387

74-
func (r *reader) readInteger() (string, error) {
75-
r.ReadByte()
76-
value, err := r.ReadUntil('e')
88+
func (d *decoder) readInteger() (string, error) {
89+
d.reader.ReadByte()
90+
value, err := d.reader.readUntil('e')
7791
if err != nil {
7892
return "", err
7993
}
8094
return string(value), nil
8195
}
8296

83-
func (r *reader) ReadList() ([]interface{}, error) {
97+
func (d *decoder) getValueType() (reflect.Kind, error) {
98+
b, err := d.reader.ReadByte()
99+
defer d.reader.UnreadByte()
100+
if err != nil {
101+
return reflect.Invalid, err
102+
}
103+
switch b {
104+
case 'i':
105+
return reflect.Uint64, nil
106+
case 'l':
107+
return reflect.Array, nil
108+
case 'd':
109+
return reflect.Map, nil
110+
default:
111+
return reflect.String, nil
112+
}
113+
}
114+
115+
func (d *decoder) readList() ([]interface{}, error) {
84116
res := make([]interface{}, 0)
85-
r.ReadByte()
117+
d.reader.ReadByte()
86118
var v interface{}
87119
for {
88-
e, err := r.ReadByte()
120+
e, err := d.reader.ReadByte()
89121
if err != nil {
90122
return res, err
91123
}
92124
if e == 'e' {
93125
break
94126
}
95-
r.UnreadByte()
96-
vt, err := r.getValueType()
127+
d.reader.UnreadByte()
128+
vt, err := d.getValueType()
97129
if err != nil {
98130
return res, err
99131
}
100132
switch vt {
101133
case reflect.Uint64:
102-
v, err = r.readInteger()
134+
v, err = d.readInteger()
103135
if err != nil {
104136
return res, err
105137
}
106138
case reflect.String:
107-
v, err = r.ReadString()
139+
v, err = d.readString()
108140
if err != nil {
109141
return res, err
110142
}
111143
case reflect.Array:
112-
v, err = r.ReadList()
144+
v, err = d.readList()
113145
if err != nil {
114146
return res, err
115147
}
116148
case reflect.Map:
117-
v, err = r.ReadDictionary()
149+
v, err = d.readDictionary()
118150
}
119151
res = append(res, v)
120152
}
121153
return res, nil
122154
}
123155

124-
func (r *reader) ReadDictionary() (map[string]interface{}, error) {
156+
func (d *decoder) readDictionary() (map[string]interface{}, error) {
125157
res := make(map[string]interface{})
126-
r.ReadByte()
158+
d.reader.ReadByte()
127159
var v interface{}
128160
for {
129-
s, _ := r.ReadString()
130-
vt, err := r.getValueType()
161+
s, _ := d.readString()
162+
vt, err := d.getValueType()
131163
if err != nil {
132164
return res, err
133165
}
134166
switch vt {
135167
case reflect.Uint64:
136-
v, err = r.readInteger()
168+
v, err = d.readInteger()
137169
if err != nil {
138170
return res, err
139171
}
140172
case reflect.String:
141-
v, err = r.ReadString()
173+
v, err = d.readString()
142174
if err != nil {
143175
return res, err
144176
}
145177
case reflect.Array:
146-
v, err = r.ReadList()
178+
v, err = d.readList()
147179
if err != nil {
148180
return res, err
149181
}
150182
case reflect.Map:
151-
v, err = r.ReadDictionary()
183+
v, err = d.readDictionary()
152184
}
153185
res[s] = v
154186
if err != nil {
155187
return res, err
156188
}
157-
e, err := r.ReadByte()
189+
e, err := d.reader.ReadByte()
158190
if err != nil {
159191
return res, err
160192
}
161193
if e == 'e' {
162194
break
163195
}
164-
r.UnreadByte()
196+
d.reader.UnreadByte()
165197
}
166198
return res, nil
167199
}
168200

169-
func (r *reader) getValueType() (reflect.Kind, error) {
170-
b, err := r.ReadByte()
171-
defer r.UnreadByte()
172-
if err != nil {
173-
return reflect.Invalid, err
174-
}
175-
switch b {
176-
case 'i':
177-
return reflect.Uint64, nil
178-
case 'l':
179-
return reflect.Array, nil
180-
case 'd':
181-
return reflect.Map, nil
182-
default:
183-
return reflect.String, nil
184-
}
185-
}
186-
187-
func (r *reader) readValue(vt reflect.Kind) (string, error) {
188-
var res string
189-
switch vt {
190-
case reflect.Uint64:
191-
res, err := r.readInteger()
192-
if err != nil {
193-
return res, err
194-
}
195-
return res, nil
196-
case reflect.String:
197-
res, err := r.ReadString()
198-
if err != nil {
199-
return res, err
200-
}
201-
return res, nil
202-
}
203-
return res, nil
204-
}
205-
206-
type decoder struct {
207-
reader *reader
208-
}
209-
210-
func (d *decoder) init(data []byte) {
211-
d.reader = newReader(data)
212-
}
213-
214-
func (d *decoder) setArray(rv reflect.Value, values []interface{}) error {
215-
indirect := reflect.Indirect(rv)
216-
for k, value := range values {
217-
switch decodedValue := value.(type) {
218-
case string:
219-
indirect.Set(reflect.Append(indirect, reflect.ValueOf(decodedValue)))
220-
case []interface{}:
221-
n := reflect.New(indirect.Type().Elem())
222-
indirect.Set(reflect.Append(indirect, reflect.Indirect(n)))
223-
d.setArray(indirect.Index(k).Addr(), decodedValue)
224-
}
225-
}
226-
return nil
227-
}
228-
229-
func (d *decoder) setStruct(rv reflect.Value, values map[string]interface{}) error {
230-
indirect := reflect.Indirect(rv)
231-
if rv.Kind() != reflect.Pointer || rv.IsNil() {
232-
return fmt.Errorf("passed object is not a pointer or is nil: %v", reflect.TypeOf(rv))
233-
}
234-
if indirect.Kind() != reflect.Struct {
235-
return fmt.Errorf("passed object is not a struct: %v", reflect.TypeOf(rv))
236-
}
237-
fields := getStructFields(indirect.Type())
238-
for key, value := range values {
239-
if field, ok := fields[key]; ok {
240-
switch decodedValue := value.(type) {
241-
case string:
242-
switch field.Type.Kind() {
243-
case reflect.Uint64:
244-
insertValue, err := stringToUint(decodedValue)
245-
if err != nil {
246-
return err
247-
}
248-
indirect.FieldByName(field.Name).SetUint(insertValue)
249-
case reflect.String:
250-
indirect.FieldByName(field.Name).SetString(decodedValue)
251-
}
252-
case map[string]interface{}:
253-
n := reflect.New(indirect.FieldByName(field.Name).Type())
254-
indirect.FieldByName(field.Name).Set(reflect.Indirect(n))
255-
d.setStruct(indirect.FieldByName(field.Name).Addr(), decodedValue)
256-
case []interface{}:
257-
d.setArray(indirect.FieldByName(field.Name).Addr(), decodedValue)
258-
}
259-
}
260-
}
261-
return nil
262-
}
263-
264201
func (d *decoder) unmarshal(v any) error{
265-
values, err := d.reader.ReadDictionary()
202+
values, err := d.readDictionary()
266203
if err != nil {
267204
return err
268205
}
269206
rv := reflect.ValueOf(v)
270207
return d.setStruct(rv, values)
271208
}
272209

273-
func decode(input []byte, v any) {
210+
func Decode(input []byte, v any) {
274211
var d decoder
275212
d.init(input)
276213
d.unmarshal(v)

0 commit comments

Comments
 (0)