Skip to content

Commit 913b56c

Browse files
committed
feat: implement Azure OpenAI client support and update provider configuration
1 parent 27ac47f commit 913b56c

4 files changed

Lines changed: 42 additions & 10 deletions

File tree

internal/controller/ai_controller.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,10 @@ func (c *AIController) createOpenAIClient() *openai.Client {
292292

293293
aiProvider := aiConfig.GetProvider()
294294

295+
if aiProvider.Provider == "azure_ai" {
296+
return c.createAzureAIClient(aiProvider)
297+
}
298+
295299
config = openai.DefaultConfig(aiProvider.APIKey)
296300
config.BaseURL = aiProvider.APIHost
297301
if !strings.HasSuffix(config.BaseURL, "/v1") {
@@ -300,6 +304,36 @@ func (c *AIController) createOpenAIClient() *openai.Client {
300304
return openai.NewClientWithConfig(config)
301305
}
302306

307+
// createAzureAIClient creates an OpenAI client configured for Azure AI.
308+
// Uses the Azure OpenAI compatibility endpoint: https://{resource}.openai.azure.com/openai/v1
309+
func (c *AIController) createAzureAIClient(aiProvider *schema.SiteAIProvider) *openai.Client {
310+
azureBaseURL := strings.TrimRight(aiProvider.APIHost, "/") + "/openai/v1"
311+
312+
config := openai.DefaultConfig(aiProvider.APIKey)
313+
config.BaseURL = azureBaseURL
314+
config.HTTPClient = &http.Client{
315+
Transport: &azureAPIKeyTransport{
316+
apiKey: aiProvider.APIKey,
317+
transport: http.DefaultTransport,
318+
},
319+
}
320+
return openai.NewClientWithConfig(config)
321+
}
322+
323+
// azureAPIKeyTransport is an http.RoundTripper that replaces the Authorization
324+
// header with the Azure-style api-key header for Azure OpenAI requests.
325+
type azureAPIKeyTransport struct {
326+
apiKey string
327+
transport http.RoundTripper
328+
}
329+
330+
func (t *azureAPIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
331+
req = req.Clone(req.Context())
332+
req.Header.Del("Authorization")
333+
req.Header.Set("api-key", t.apiKey)
334+
return t.transport.RoundTrip(req)
335+
}
336+
303337
// getPromptByLanguage
304338
func (c *AIController) getPromptByLanguage(language i18n.Language, question string) string {
305339
aiConfig, err := c.siteInfoService.GetSiteAI(context.Background())
@@ -497,6 +531,10 @@ func (c *AIController) processAIStream(
497531
break
498532
}
499533

534+
if len(response.Choices) == 0 {
535+
continue
536+
}
537+
500538
choice := response.Choices[0]
501539

502540
if len(choice.Delta.ToolCalls) > 0 {

internal/migrations/init_data.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ var (
353353
{ID: 128, Key: "rank.answer.undeleted", Value: `-1`},
354354
{ID: 129, Key: "rank.question.undeleted", Value: `-1`},
355355
{ID: 130, Key: "rank.tag.undeleted", Value: `-1`},
356-
{ID: 131, Key: "ai_config.provider", Value: `[{"default_api_host":"https://api.openai.com","display_name":"OpenAI","name":"openai"},{"default_api_host":"https://generativelanguage.googleapis.com","display_name":"Gemini","name":"gemini"},{"default_api_host":"https://api.anthropic.com","display_name":"Anthropic","name":"anthropic"},{"default_api_host":"https://{your-project}.services.ai.azure.com","display_name":"Azure AI Foundry","name":"azure_openai"}]`},
356+
{ID: 131, Key: "ai_config.provider", Value: `[{"default_api_host":"https://api.openai.com","display_name":"OpenAI","name":"openai"},{"default_api_host":"https://generativelanguage.googleapis.com","display_name":"Gemini","name":"gemini"},{"default_api_host":"https://api.anthropic.com","display_name":"Anthropic","name":"anthropic"},{"default_api_host":"https://{your-resource}.openai.azure.com","display_name":"Azure AI","name":"azure_ai"}]`},
357357
}
358358

359359
defaultBadgeGroupTable = []*entity.BadgeGroup{

internal/migrations/v31.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func addAPIKey(ctx context.Context, x *xorm.Engine) error {
6161
}
6262

6363
defaultConfigTable := []*entity.Config{
64-
{ID: 131, Key: "ai_config.provider", Value: `[{"default_api_host":"https://api.openai.com","display_name":"OpenAI","name":"openai"},{"default_api_host":"https://generativelanguage.googleapis.com","display_name":"Gemini","name":"gemini"},{"default_api_host":"https://api.anthropic.com","display_name":"Anthropic","name":"anthropic"},{"default_api_host":"https://{your-project}.services.ai.azure.com","display_name":"Azure AI Foundry","name":"azure_openai"}]`},
64+
{ID: 131, Key: "ai_config.provider", Value: `[{"default_api_host":"https://api.openai.com","display_name":"OpenAI","name":"openai"},{"default_api_host":"https://generativelanguage.googleapis.com","display_name":"Gemini","name":"gemini"},{"default_api_host":"https://api.anthropic.com","display_name":"Anthropic","name":"anthropic"},{"default_api_host":"https://{your-resource}.openai.azure.com","display_name":"Azure AI","name":"azure_ai"}]`},
6565
}
6666
for _, c := range defaultConfigTable {
6767
exist, err := x.Context(ctx).Get(&entity.Config{Key: c.Key})

internal/service/siteinfo/siteinfo_service.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import (
2424
"encoding/json"
2525
errpkg "errors"
2626
"fmt"
27-
"net/url"
2827
"strings"
2928

3029
"github.com/apache/answer/internal/base/constant"
@@ -730,14 +729,9 @@ func (s *SiteInfoService) GetAIModels(ctx context.Context, req *schema.GetAIMode
730729
var respBody *resty.Response
731730
apiHost := strings.TrimRight(req.APIHost, "/")
732731
if req.Provider == "azure_ai" {
733-
// Azure AI: parse resource name from apiHost and list deployments via *.openai.azure.com
732+
// Azure AI: list deployments via the Azure OpenAI endpoint
734733
r.SetHeader("api-key", req.APIKey)
735-
parsedURL, parseErr := url.Parse(apiHost)
736-
if parseErr != nil || parsedURL.Host == "" {
737-
return resp, errors.BadRequest("invalid api_host URL")
738-
}
739-
resourceName := strings.Split(parsedURL.Hostname(), ".")[0]
740-
deploymentsURL := fmt.Sprintf("https://%s.openai.azure.com/openai/deployments?api-version=2022-12-01", resourceName)
734+
deploymentsURL := apiHost + "/openai/deployments?api-version=2022-12-01"
741735
respBody, err = r.R().Get(deploymentsURL)
742736
} else {
743737
// Standard OpenAI-compatible providers

0 commit comments

Comments
 (0)