Skip to content

Commit 1d91357

Browse files
authored
Merge pull request #826 from docker/fix-mmproj-for-hf
feat: add support for multimodal projector in GGUF model building
2 parents b8c70b6 + 76fa68e commit 1d91357

2 files changed

Lines changed: 113 additions & 5 deletions

File tree

pkg/distribution/huggingface/model.go

Lines changed: 30 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,9 @@ func BuildModel(ctx context.Context, client *Client, repo, revision, tag string,
9090
_ = progress.WriteProgress(progressWriter, "Building model artifact...", 0, 0, 0, "", "pull")
9191
}
9292

93-
model, err := buildModelFromFiles(result.LocalPaths, weightFiles, configFiles, tempDir, createdTime)
93+
model, err := buildModelFromFiles(
94+
result.LocalPaths, weightFiles, configFiles, mmprojFile, tempDir, createdTime,
95+
)
9496
if err != nil {
9597
return nil, fmt.Errorf("build model: %w", err)
9698
}
@@ -103,14 +105,20 @@ func BuildModel(ctx context.Context, client *Client, repo, revision, tag string,
103105
// which preserves directory structure and adds each file as an individual layer with
104106
// filepath annotations. For GGUF models, it uses the V0.1 packaging (FromPaths)
105107
// for backward compatibility.
106-
func buildModelFromFiles(localPaths map[string]string, weightFiles, configFiles []RepoFile, tempDir string, createdTime *time.Time) (types.ModelArtifact, error) {
108+
func buildModelFromFiles(
109+
localPaths map[string]string,
110+
weightFiles, configFiles []RepoFile,
111+
mmprojFile *RepoFile,
112+
tempDir string,
113+
createdTime *time.Time,
114+
) (types.ModelArtifact, error) {
107115
// Check if this is a safetensors model - use V0.2 packaging
108116
if isSafetensorsModel(weightFiles) {
109117
return buildSafetensorsModelV02(tempDir, createdTime)
110118
}
111119

112120
// For GGUF models, use V0.1 packaging (backward compatible)
113-
return buildGGUFModelV01(localPaths, weightFiles, configFiles, createdTime)
121+
return buildGGUFModelV01(localPaths, weightFiles, configFiles, mmprojFile, createdTime)
114122
}
115123

116124
// buildSafetensorsModelV02 builds a safetensors model using V0.2 layer-per-file packaging.
@@ -133,7 +141,12 @@ func buildSafetensorsModelV02(tempDir string, createdTime *time.Time) (types.Mod
133141
}
134142

135143
// buildGGUFModelV01 builds a GGUF model using V0.1 packaging (backward compatible).
136-
func buildGGUFModelV01(localPaths map[string]string, weightFiles, configFiles []RepoFile, createdTime *time.Time) (types.ModelArtifact, error) {
144+
func buildGGUFModelV01(
145+
localPaths map[string]string,
146+
weightFiles, configFiles []RepoFile,
147+
mmprojFile *RepoFile,
148+
createdTime *time.Time,
149+
) (types.ModelArtifact, error) {
137150
// Collect weight file paths (sorted for reproducibility)
138151
var weightPaths []string
139152
for _, f := range weightFiles {
@@ -157,7 +170,19 @@ func buildGGUFModelV01(localPaths map[string]string, weightFiles, configFiles []
157170
return nil, fmt.Errorf("create builder: %w", err)
158171
}
159172

160-
// Check for chat template and add it
173+
// Add multimodal projector if present (F16 preferred, selected upstream).
174+
if mmprojFile != nil {
175+
localPath, ok := localPaths[mmprojFile.Path]
176+
if !ok {
177+
return nil, fmt.Errorf("missing local path for mmproj %s", mmprojFile.Path)
178+
}
179+
b, err = b.WithMultimodalProjector(localPath)
180+
if err != nil {
181+
return nil, fmt.Errorf("add mmproj: %w", err)
182+
}
183+
}
184+
185+
// Check for chat template and add it.
161186
for _, f := range configFiles {
162187
if isChatTemplate(f.Path) {
163188
localPath := localPaths[f.Path]
Lines changed: 83 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,83 @@
1+
package huggingface
2+
3+
import (
4+
"path/filepath"
5+
"testing"
6+
"time"
7+
8+
"github.com/docker/model-runner/pkg/distribution/types"
9+
)
10+
11+
// TestBuildGGUFModelV01WithMMProj verifies that buildGGUFModelV01 includes
12+
// the multimodal projector as a MediaTypeMultimodalProjector layer when an
13+
// mmprojFile is provided.
14+
func TestBuildGGUFModelV01WithMMProj(t *testing.T) {
15+
assetsDir := filepath.Join("..", "assets")
16+
ggufPath := filepath.Join(assetsDir, "dummy.gguf")
17+
mmprojPath := filepath.Join(assetsDir, "dummy.mmproj")
18+
19+
weightFiles := []RepoFile{
20+
{Type: "file", Path: "dummy.gguf"},
21+
}
22+
mmprojFile := &RepoFile{Type: "file", Path: "mmproj-model-f16.gguf"}
23+
localPaths := map[string]string{
24+
"dummy.gguf": ggufPath,
25+
"mmproj-model-f16.gguf": mmprojPath,
26+
}
27+
28+
artifact, err := buildGGUFModelV01(localPaths, weightFiles, nil, mmprojFile, nil)
29+
if err != nil {
30+
t.Fatalf("buildGGUFModelV01 failed: %v", err)
31+
}
32+
33+
// Retrieve the manifest and look for the mmproj layer.
34+
manifest, err := artifact.Manifest()
35+
if err != nil {
36+
t.Fatalf("get manifest: %v", err)
37+
}
38+
39+
found := false
40+
for _, layer := range manifest.Layers {
41+
if layer.MediaType == types.MediaTypeMultimodalProjector {
42+
found = true
43+
break
44+
}
45+
}
46+
if !found {
47+
t.Errorf("expected manifest to contain a %s layer, but none was found",
48+
types.MediaTypeMultimodalProjector)
49+
}
50+
}
51+
52+
// TestBuildGGUFModelV01WithoutMMProj verifies that buildGGUFModelV01 succeeds
53+
// and produces no MediaTypeMultimodalProjector layer when no mmprojFile is
54+
// provided.
55+
func TestBuildGGUFModelV01WithoutMMProj(t *testing.T) {
56+
assetsDir := filepath.Join("..", "assets")
57+
ggufPath := filepath.Join(assetsDir, "dummy.gguf")
58+
59+
weightFiles := []RepoFile{
60+
{Type: "file", Path: "dummy.gguf"},
61+
}
62+
localPaths := map[string]string{
63+
"dummy.gguf": ggufPath,
64+
}
65+
createdTime := time.Now()
66+
67+
artifact, err := buildGGUFModelV01(localPaths, weightFiles, nil, nil, &createdTime)
68+
if err != nil {
69+
t.Fatalf("buildGGUFModelV01 failed: %v", err)
70+
}
71+
72+
manifest, err := artifact.Manifest()
73+
if err != nil {
74+
t.Fatalf("get manifest: %v", err)
75+
}
76+
77+
for _, layer := range manifest.Layers {
78+
if layer.MediaType == types.MediaTypeMultimodalProjector {
79+
t.Errorf("expected no %s layer, but one was found",
80+
types.MediaTypeMultimodalProjector)
81+
}
82+
}
83+
}

0 commit comments

Comments
 (0)