@@ -943,7 +943,8 @@ std::string getTemplate(int puzzleIndex) {
943943 result +=
944944 R"( fn main(
945945 @builtin(global_invocation_id) gid: vec3<u32>,
946- @builtin(local_invocation_id) lid: vec3<u32>) {
946+ @builtin(workgroup_id) wid: vec3<u32>,
947+ @builtin(local_invocation_id) lid: vec3<u32>,) {
947948 let i: u32 = gid.x;
948949 out[i] = in0[i];
949950}
@@ -983,8 +984,8 @@ bool evaluate(Context &ctx, const std::string &kernelCode, int puzzleIndex) {
983984 CompilationInfo compilationInfo;
984985
985986 bool allPassed = true ;
986- for (int i = 0 ; i < testCases.size (); ++i ) {
987- auto testCase = testCases[i ];
987+ for (int caseIdx = 0 ; caseIdx < testCases.size (); ++caseIdx ) {
988+ auto testCase = testCases[caseIdx ];
988989
989990 auto start = std::chrono::high_resolution_clock::now ();
990991
@@ -1041,36 +1042,78 @@ bool evaluate(Context &ctx, const std::string &kernelCode, int puzzleIndex) {
10411042 auto end = std::chrono::high_resolution_clock::now ();
10421043 std::chrono::duration<double > elapsed = end - start;
10431044
1045+ const char *red = " \033 [1;31m" ;
1046+ const char *green = " \033 [1;32m" ;
1047+ const char *grey = " \033 [1;30m" ;
1048+ const char *reset = " \033 [0m" ;
1049+
10441050 // wprintf("Time taken: %f s\n", elapsed.count());
1051+
1052+ bool compilePassed = true ;
1053+ int ptr = 0 ;
1054+ constexpr size_t kBufSize = 1024 * 10 ;
1055+ char buf[kBufSize ];
1056+
10451057 if (compilationInfo.messages .size () > 0 ) {
1046- for (size_t idx = 0 ; idx < compilationInfo.messages .size (); ++idx) {
1047- wprintf (" \033 [1;31mError\033 [0m line %d, column %d:\n " ,
1048- static_cast <int >(compilationInfo.lineNums [idx]),
1049- static_cast <int >(compilationInfo.linePos [idx]));
1050- wprintf (" %s\n " ,compilationInfo.messages [idx].c_str ());
1051-
1058+ if (caseIdx == 0 ) {
1059+ // Don't print compilation errors more than once
1060+ for (size_t idx = 0 ; idx < compilationInfo.messages .size (); ++idx) {
1061+ ptr +=
1062+ snprintf (buf, kBufSize , " %sError%s line %d, column %d:\n\r " , red,
1063+ reset, static_cast <int >(compilationInfo.lineNums [idx]),
1064+ static_cast <int >(compilationInfo.linePos [idx]));
1065+ ptr += snprintf (buf + ptr, kBufSize - ptr, " %s\n\n\r " ,
1066+ compilationInfo.messages [idx].c_str ());
1067+ }
1068+ ptr += snprintf (buf + ptr, kBufSize ,
1069+ " * * * * * * * *\n\n\r " );
10521070 }
10531071 allPassed = false ;
1054- break ; // don't iterate to other tests
1072+ compilePassed = false ;
1073+ }
1074+
1075+ // ptr = 0;
1076+ if (compilePassed && checkOutput (output, testCase.expectedOutput )) {
1077+ ptr += snprintf (buf + ptr, kBufSize , " Test case %d %sPASSED%s\n\n\r " ,
1078+ caseIdx + 1 , green, reset);
10551079 } else {
1080+ ptr += snprintf (buf + ptr, kBufSize , " Test case %d %sFAILED%s\n\n\r " ,
1081+ caseIdx + 1 , red, reset);
1082+ allPassed = false ;
1083+ }
10561084
1057- if (checkOutput (output, testCase.expectedOutput )) {
1058- wprintf (" Test case %d \033 [1;32mPASSED\033 [0m\n " , i + 1 );
1059- } else {
1060- wprintf (" Test case %d \033 [1;31mFAILED\033 [0m\n " , i + 1 );
1061- allPassed = false ;
1085+ ptr += snprintf (buf + ptr, kBufSize ,
1086+ " \033 [1;30mWorkgroup Size ( %s )\n\r " ,
1087+ toString (testCase.workgroupSize ).c_str ());
1088+ ptr += snprintf (buf + ptr, kBufSize ,
1089+ " Number of Workgroups ( %s )\n\033 [0m\n\r " ,
1090+ toString (testCase.gridSize ).c_str ());
1091+
1092+ wprintf (" %s" , buf);
1093+
1094+ if (testCase.nInputs > 1 ) {
1095+ for (size_t inp = 0 ; inp < testCase.nInputs ; ++inp) {
1096+ size_t sz = testCase.input .size () / testCase.nInputs ;
1097+ size_t offset = inp * sz;
1098+ snprintf (buf, sizeof (buf), " %sInput %zu%s" , grey, inp, reset);
1099+ printVec ({begin (testCase.input ) + offset,
1100+ begin (testCase.input ) + offset + sz},
1101+ buf);
10621102 }
1063-
1064- // print workgrou psize and num workgroups
1065- wprintf (" Workgroup Size ( %s )" ,
1066- toString (testCase.workgroupSize ).c_str ());
1067- wprintf (" Number of Workgroups ( %s )" ,
1068- toString (testCase.gridSize ).c_str ());
1069-
1070- printVec (testCase.input , " \n Input " );
1071- printVec (output, " Got " );
1072- printVec (testCase.expectedOutput , " Expected" );
1103+ } else {
1104+ snprintf (buf, sizeof (buf), " %sInput %s" , grey, reset);
1105+ printVec (testCase.input , buf);
1106+ }
1107+ if (compilePassed) {
1108+ wprintf (" " );
1109+ snprintf (buf, sizeof (buf), " %sGot %s" , grey, reset);
1110+ printVec (output, buf);
1111+ wprintf (" " );
10731112 }
1113+
1114+ snprintf (buf, sizeof (buf), " %sExpected%s" , grey, reset);
1115+ printVec (testCase.expectedOutput , buf);
1116+ wprintf (" " );
10741117 }
10751118
10761119 return allPassed;
0 commit comments