Skip to content

Commit 29cf558

Browse files
authored
Merge pull request #259 from Joshua-Dias-Barreto/oneHotEncoding
Fixed #222. Added DataSeries>> #encodeOneHot
2 parents 4aaaf37 + bc6cd84 commit 29cf558

2 files changed

Lines changed: 81 additions & 0 deletions

File tree

src/DataFrame-Tests/DataSeriesTest.class.st

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,68 @@ DataSeriesTest >> testEighth [
12731273
self assert: series eighth equals: 10
12741274
]
12751275

1276+
{ #category : #'tests - converting' }
1277+
DataSeriesTest >> testEncodeOneHot [
1278+
1279+
| actual expected |
1280+
actual := #( 1 2 3 4 ) asDataSeries encodeOneHot.
1281+
expected := #(
1282+
#( 1 0 0 0 )
1283+
#( 0 1 0 0 )
1284+
#( 0 0 1 0 )
1285+
#( 0 0 0 1 )
1286+
) asDataSeries.
1287+
1288+
self assert: actual equals: expected
1289+
]
1290+
1291+
{ #category : #'tests - converting' }
1292+
DataSeriesTest >> testEncodeOneHotRomanNumbers [
1293+
1294+
| actual expected |
1295+
actual := (#( I XIV VII XXXII ) collect: [ :each | each romanNumber ])
1296+
asDataSeries encodeOneHot.
1297+
expected := #(
1298+
#( 1 0 0 0 )
1299+
#( 0 0 1 0 )
1300+
#( 0 1 0 0 )
1301+
#( 0 0 0 1 )
1302+
) asDataSeries.
1303+
1304+
self assert: actual equals: expected
1305+
]
1306+
1307+
{ #category : #'tests - converting' }
1308+
DataSeriesTest >> testEncodeOneHotStrings [
1309+
1310+
| actual expected |
1311+
actual := #( apple avocado orange banana ) asDataSeries encodeOneHot.
1312+
expected := #(
1313+
#( 1 0 0 0 )
1314+
#( 0 1 0 0 )
1315+
#( 0 0 0 1 )
1316+
#( 0 0 1 0 )
1317+
) asDataSeries.
1318+
1319+
self assert: actual equals: expected
1320+
]
1321+
1322+
{ #category : #'tests - converting' }
1323+
DataSeriesTest >> testEncodeOneHotWithDecimalAndLargeValues [
1324+
1325+
| actual expected |
1326+
actual := #( 0.5 1 2.2 3 4000 ) asDataSeries encodeOneHot.
1327+
expected := #(
1328+
#( 1 0 0 0 0 )
1329+
#( 0 1 0 0 0 )
1330+
#( 0 0 1 0 0 )
1331+
#( 0 0 0 1 0 )
1332+
#( 0 0 0 0 1 )
1333+
) asDataSeries.
1334+
1335+
self assert: actual equals: expected
1336+
]
1337+
12761338
{ #category : #'tests - comparing' }
12771339
DataSeriesTest >> testEquality [
12781340

src/DataFrame/DataSeries.class.st

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,25 @@ DataSeries >> eighth [
319319
^ self atIndex: 8
320320
]
321321

322+
{ #category : #converting }
323+
DataSeries >> encodeOneHot [
324+
"Encode the values of the DataSeries into one-hot vectors."
325+
326+
| uniqueValues encodingDataSeries oneHotValues |
327+
uniqueValues := self removeDuplicates sortIfPossible.
328+
encodingDataSeries := self class new.
329+
uniqueValues withIndexDo: [ :value :index |
330+
encodingDataSeries at: value put: index ].
331+
oneHotValues := self values collect: [ :value |
332+
| oneHot |
333+
oneHot := encodingDataSeries keys collect: [ :key |
334+
value = key
335+
ifTrue: [ 1 ]
336+
ifFalse: [ 0 ] ].
337+
oneHot ].
338+
^ DataSeries withKeys: self keys values: oneHotValues name: self name
339+
]
340+
322341
{ #category : #private }
323342
DataSeries >> errorKeyNotFound: aKey [
324343

0 commit comments

Comments
 (0)