Skip to content

Commit ca376aa

Browse files
nsicchaclaude
andcommitted
Add MWE: reduce_sum with tuple(param, data) and STAN_THREADS
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 85b5472 commit ca376aa

8 files changed

Lines changed: 640 additions & 0 deletions

File tree

mwe/.env-main/Manifest.toml

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# This file is machine-generated - editing it directly is not advised
2+
3+
julia_version = "1.10.11"
4+
manifest_format = "2.0"
5+
project_hash = "5578dd2aac58cda5d4cf32fe876f733701deb830"
6+
7+
[[deps.ArgTools]]
8+
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
9+
version = "1.1.1"
10+
11+
[[deps.Artifacts]]
12+
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
13+
14+
[[deps.BridgeStan]]
15+
deps = ["Downloads", "Inflate", "TOML", "Tar"]
16+
git-tree-sha1 = "f8689ac4ce3245df7a436889988762a6c4a9da12"
17+
uuid = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"
18+
version = "2.7.0"
19+
20+
[[deps.Dates]]
21+
deps = ["Printf"]
22+
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
23+
24+
[[deps.Downloads]]
25+
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
26+
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
27+
version = "1.6.0"
28+
29+
[[deps.FileWatching]]
30+
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
31+
32+
[[deps.Inflate]]
33+
git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d"
34+
uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9"
35+
version = "0.1.5"
36+
37+
[[deps.LibCURL]]
38+
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
39+
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
40+
version = "0.6.4"
41+
42+
[[deps.LibCURL_jll]]
43+
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
44+
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
45+
version = "8.4.0+0"
46+
47+
[[deps.LibSSH2_jll]]
48+
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
49+
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
50+
version = "1.11.0+1"
51+
52+
[[deps.Libdl]]
53+
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
54+
55+
[[deps.MbedTLS_jll]]
56+
deps = ["Artifacts", "Libdl"]
57+
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
58+
version = "2.28.1010+0"
59+
60+
[[deps.MozillaCACerts_jll]]
61+
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
62+
version = "2025.12.2"
63+
64+
[[deps.NetworkOptions]]
65+
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
66+
version = "1.2.0"
67+
68+
[[deps.Printf]]
69+
deps = ["Unicode"]
70+
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
71+
72+
[[deps.SHA]]
73+
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
74+
version = "0.7.0"
75+
76+
[[deps.TOML]]
77+
deps = ["Dates"]
78+
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
79+
version = "1.0.3"
80+
81+
[[deps.Tar]]
82+
deps = ["ArgTools", "SHA"]
83+
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
84+
version = "1.10.0"
85+
86+
[[deps.Unicode]]
87+
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
88+
89+
[[deps.Zlib_jll]]
90+
deps = ["Libdl"]
91+
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
92+
version = "1.2.13+1"
93+
94+
[[deps.nghttp2_jll]]
95+
deps = ["Artifacts", "Libdl"]
96+
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
97+
version = "1.52.0+1"

mwe/.env-main/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[deps]
2+
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"

mwe/.env-worktree/Manifest.toml

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# This file is machine-generated - editing it directly is not advised
2+
3+
julia_version = "1.10.11"
4+
manifest_format = "2.0"
5+
project_hash = "5578dd2aac58cda5d4cf32fe876f733701deb830"
6+
7+
[[deps.ArgTools]]
8+
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
9+
version = "1.1.1"
10+
11+
[[deps.Artifacts]]
12+
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
13+
14+
[[deps.BridgeStan]]
15+
deps = ["Downloads", "Inflate", "TOML", "Tar"]
16+
git-tree-sha1 = "f8689ac4ce3245df7a436889988762a6c4a9da12"
17+
uuid = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"
18+
version = "2.7.0"
19+
20+
[[deps.Dates]]
21+
deps = ["Printf"]
22+
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
23+
24+
[[deps.Downloads]]
25+
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
26+
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
27+
version = "1.6.0"
28+
29+
[[deps.FileWatching]]
30+
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
31+
32+
[[deps.Inflate]]
33+
git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d"
34+
uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9"
35+
version = "0.1.5"
36+
37+
[[deps.LibCURL]]
38+
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
39+
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
40+
version = "0.6.4"
41+
42+
[[deps.LibCURL_jll]]
43+
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
44+
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
45+
version = "8.4.0+0"
46+
47+
[[deps.LibSSH2_jll]]
48+
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
49+
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
50+
version = "1.11.0+1"
51+
52+
[[deps.Libdl]]
53+
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
54+
55+
[[deps.MbedTLS_jll]]
56+
deps = ["Artifacts", "Libdl"]
57+
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
58+
version = "2.28.1010+0"
59+
60+
[[deps.MozillaCACerts_jll]]
61+
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
62+
version = "2025.12.2"
63+
64+
[[deps.NetworkOptions]]
65+
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
66+
version = "1.2.0"
67+
68+
[[deps.Printf]]
69+
deps = ["Unicode"]
70+
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
71+
72+
[[deps.SHA]]
73+
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
74+
version = "0.7.0"
75+
76+
[[deps.TOML]]
77+
deps = ["Dates"]
78+
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
79+
version = "1.0.3"
80+
81+
[[deps.Tar]]
82+
deps = ["ArgTools", "SHA"]
83+
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
84+
version = "1.10.0"
85+
86+
[[deps.Unicode]]
87+
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
88+
89+
[[deps.Zlib_jll]]
90+
deps = ["Libdl"]
91+
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
92+
version = "1.2.13+1"
93+
94+
[[deps.nghttp2_jll]]
95+
deps = ["Artifacts", "Libdl"]
96+
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
97+
version = "1.52.0+1"

mwe/.env-worktree/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[deps]
2+
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"

mwe/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[deps]
2+
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"

mwe/mwe.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
using BridgeStan
2+
3+
stan_code = """
4+
functions {
5+
real partial_sum(array[] real slice, int start, int end,
6+
tuple(real, int) params) {
7+
real mu = params.1;
8+
int K = params.2;
9+
real lp = 0;
10+
for (i in 1:size(slice)) {
11+
lp += normal_lpdf(slice[i] | mu, K);
12+
}
13+
return lp;
14+
}
15+
}
16+
data {
17+
int N;
18+
int K;
19+
array[N] real y;
20+
}
21+
parameters {
22+
real mu;
23+
}
24+
model {
25+
mu ~ normal(0, 10);
26+
target += reduce_sum(partial_sum, y, 1, (mu, K));
27+
}
28+
"""
29+
30+
stan_math = ENV["MWE_RUN_DIR"]
31+
label = ENV["MWE_LABEL"]
32+
33+
workdir = mktempdir()
34+
stan_file = joinpath(workdir, "mwe.stan")
35+
write(stan_file, stan_code)
36+
37+
lib = compile_model(stan_file; make_args=["STAN_MATH=$stan_math", "STAN_THREADS=true"])
38+
39+
data = """{"N": 5, "K": 2, "y": [1.0, 2.0, 3.0, 4.0, 5.0]}"""
40+
sm = StanModel(lib, data)
41+
42+
params = [3.0] # mu (unconstrained)
43+
lp = log_density(sm, params)
44+
lp_grad, grad = log_density_gradient(sm, params)
45+
46+
println("[$label] log_density = $lp")
47+
println("[$label] gradient = $grad")
48+
49+
@assert isfinite(lp) "log_density should be finite"
50+
@assert all(isfinite, grad) "gradient should be finite"
51+
@assert lp lp_grad "log_density values should match"
52+
@assert length(grad) == 1 "gradient should have 1 element"
53+
54+
println("[$label] All checks passed!")

0 commit comments

Comments
 (0)