Skip to content

Commit 94c5ab5

Browse files
committed
gpu puzzles - buffered terminal outputs, split multi-input vectors when printing to web terminal
1 parent 9bd5282 commit 94c5ab5

2 files changed

Lines changed: 90 additions & 48 deletions

File tree

experimental/fasthtml/gpu_puzzles/evaluator.h

Lines changed: 68 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -943,7 +943,8 @@ std::string getTemplate(int puzzleIndex) {
943943
result +=
944944
R"(fn main(
945945
@builtin(global_invocation_id) gid: vec3<u32>,
946-
@builtin(local_invocation_id) lid: vec3<u32>) {
946+
@builtin(workgroup_id) wid: vec3<u32>,
947+
@builtin(local_invocation_id) lid: vec3<u32>,) {
947948
let i: u32 = gid.x;
948949
out[i] = in0[i];
949950
}
@@ -983,8 +984,8 @@ bool evaluate(Context &ctx, const std::string &kernelCode, int puzzleIndex) {
983984
CompilationInfo compilationInfo;
984985

985986
bool allPassed = true;
986-
for (int i = 0; i < testCases.size(); ++i) {
987-
auto testCase = testCases[i];
987+
for (int caseIdx = 0; caseIdx < testCases.size(); ++caseIdx) {
988+
auto testCase = testCases[caseIdx];
988989

989990
auto start = std::chrono::high_resolution_clock::now();
990991

@@ -1041,36 +1042,78 @@ bool evaluate(Context &ctx, const std::string &kernelCode, int puzzleIndex) {
10411042
auto end = std::chrono::high_resolution_clock::now();
10421043
std::chrono::duration<double> elapsed = end - start;
10431044

1045+
const char *red = "\033[1;31m";
1046+
const char *green = "\033[1;32m";
1047+
const char *grey = "\033[1;30m";
1048+
const char *reset = "\033[0m";
1049+
10441050
// wprintf("Time taken: %f s\n", elapsed.count());
1051+
1052+
bool compilePassed = true;
1053+
int ptr = 0;
1054+
constexpr size_t kBufSize = 1024 * 10;
1055+
char buf[kBufSize];
1056+
10451057
if (compilationInfo.messages.size() > 0) {
1046-
for (size_t idx = 0; idx < compilationInfo.messages.size(); ++idx) {
1047-
wprintf("\033[1;31mError\033[0m line %d, column %d:\n",
1048-
static_cast<int>(compilationInfo.lineNums[idx]),
1049-
static_cast<int>(compilationInfo.linePos[idx]));
1050-
wprintf(" %s\n",compilationInfo.messages[idx].c_str());
1051-
1058+
if (caseIdx == 0) {
1059+
// Don't print compilation errors more than once
1060+
for (size_t idx = 0; idx < compilationInfo.messages.size(); ++idx) {
1061+
ptr +=
1062+
snprintf(buf, kBufSize, "%sError%s line %d, column %d:\n\r", red,
1063+
reset, static_cast<int>(compilationInfo.lineNums[idx]),
1064+
static_cast<int>(compilationInfo.linePos[idx]));
1065+
ptr += snprintf(buf + ptr, kBufSize - ptr, "%s\n\n\r",
1066+
compilationInfo.messages[idx].c_str());
1067+
}
1068+
ptr += snprintf(buf + ptr, kBufSize,
1069+
"* * * * * * * *\n\n\r");
10521070
}
10531071
allPassed = false;
1054-
break; // don't iterate to other tests
1072+
compilePassed = false;
1073+
}
1074+
1075+
// ptr = 0;
1076+
if (compilePassed && checkOutput(output, testCase.expectedOutput)) {
1077+
ptr += snprintf(buf + ptr, kBufSize, "Test case %d %sPASSED%s\n\n\r",
1078+
caseIdx + 1, green, reset);
10551079
} else {
1080+
ptr += snprintf(buf + ptr, kBufSize, "Test case %d %sFAILED%s\n\n\r",
1081+
caseIdx + 1, red, reset);
1082+
allPassed = false;
1083+
}
10561084

1057-
if (checkOutput(output, testCase.expectedOutput)) {
1058-
wprintf("Test case %d \033[1;32mPASSED\033[0m\n", i + 1);
1059-
} else {
1060-
wprintf("Test case %d \033[1;31mFAILED\033[0m\n", i + 1);
1061-
allPassed = false;
1085+
ptr += snprintf(buf + ptr, kBufSize,
1086+
"\033[1;30mWorkgroup Size ( %s )\n\r",
1087+
toString(testCase.workgroupSize).c_str());
1088+
ptr += snprintf(buf + ptr, kBufSize,
1089+
"Number of Workgroups ( %s )\n\033[0m\n\r",
1090+
toString(testCase.gridSize).c_str());
1091+
1092+
wprintf("%s", buf);
1093+
1094+
if (testCase.nInputs > 1) {
1095+
for (size_t inp = 0; inp < testCase.nInputs; ++inp) {
1096+
size_t sz = testCase.input.size() / testCase.nInputs;
1097+
size_t offset = inp * sz;
1098+
snprintf(buf, sizeof(buf), "%sInput %zu%s", grey, inp, reset);
1099+
printVec({begin(testCase.input) + offset,
1100+
begin(testCase.input) + offset + sz},
1101+
buf);
10621102
}
1063-
1064-
// print workgrou psize and num workgroups
1065-
wprintf("Workgroup Size ( %s )",
1066-
toString(testCase.workgroupSize).c_str());
1067-
wprintf("Number of Workgroups ( %s )",
1068-
toString(testCase.gridSize).c_str());
1069-
1070-
printVec(testCase.input, "\nInput ");
1071-
printVec(output, "Got ");
1072-
printVec(testCase.expectedOutput, "Expected");
1103+
} else {
1104+
snprintf(buf, sizeof(buf), "%sInput %s", grey, reset);
1105+
printVec(testCase.input, buf);
1106+
}
1107+
if (compilePassed) {
1108+
wprintf("");
1109+
snprintf(buf, sizeof(buf), "%sGot %s", grey, reset);
1110+
printVec(output, buf);
1111+
wprintf("");
10731112
}
1113+
1114+
snprintf(buf, sizeof(buf), "%sExpected%s", grey, reset);
1115+
printVec(testCase.expectedOutput, buf);
1116+
wprintf("");
10741117
}
10751118

10761119
return allPassed;

experimental/fasthtml/gpu_puzzles/run.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,28 @@
1919
height: 100vh
2020
"""
2121

22+
23+
def button(text, id):
24+
return Button(
25+
text,
26+
cls="bg-blue-300 hover:bg-blue-900 text-white font-bold py-2 px-4 rounded",
27+
id=id,
28+
)
29+
30+
2231
def controls():
2332
# left and right buttons
2433
return (
2534
Div(
2635
Div(
27-
Button(
28-
"<<",
29-
cls="bg-blue-300 hover:bg-blue-900 text-white font-bold py-2 px-4 rounded",
30-
id="prev",
31-
),
36+
button("<<", "prev"),
3237
# don't start a new row for div
3338
Div(
3439
"Puzzle 1: Map",
3540
id="puzzle_name",
36-
style="font-size: 1.5rem; width: 25vw; font-weight: bold;"
37-
),
38-
Button(
39-
">>",
40-
cls="bg-blue-300 hover:bg-blue-900 text-white font-bold py-2 px-4 rounded",
41-
id="next",
41+
style="font-size: 1.5rem; width: 25vw; font-weight: bold;",
4242
),
43+
button(">>", "next"),
4344
style="display: flex; align-items: center; justify-content: center;",
4445
),
4546
style="text-align: center; margin-top: 5vh; margin-left: 2rem; margin-right: 2rem;",
@@ -52,8 +53,6 @@ def controls():
5253
)
5354

5455

55-
56-
5756
def init_app() -> str:
5857
return f"""
5958
document.addEventListener('DOMContentLoaded', () => {{
@@ -111,19 +110,18 @@ async def serve_wasm():
111110

112111

113112
def output():
114-
correctHeight = 27;
115-
outputHeight = 100 - correctHeight + 1;
113+
correctHeight = 27
114+
outputHeight = 100 - correctHeight + 1
116115
return (
117116
Div(
118-
"(Result Check)",
117+
Div(
118+
"(Result Check)",
119+
id="correct",
120+
),
121+
button("Solution", "solution"),
119122
style=f"width: 49vw; height:{correctHeight / 3 * 2}vh; float: right; font-size: 2rem; text-align: center; align-items: center; justify-content: center; margin-top: {correctHeight / 3}vh;",
120-
id="correct",
121123
),
122-
Div(
123-
id="output",
124-
style=f"width: 49vw; height:{outputHeight}vh; float: right;"
125-
),
126-
124+
Div(id="output", style=f"width: 49vw; height:{outputHeight}vh; float: right;"),
127125
)
128126

129127

@@ -138,6 +136,7 @@ def CodeEditor():
138136
),
139137
)
140138

139+
141140
@app.get("/")
142141
def get():
143142
return (
@@ -151,7 +150,7 @@ def get():
151150
),
152151
output(),
153152
),
154-
style = body_style,
153+
style=body_style,
155154
),
156155
)
157156

0 commit comments

Comments
 (0)