mirror of https://github.com/lobehub/lobe-chat.git
🐛 fix: vertex ai create image (#9710)
@@ -1,38 +0,0 @@
import { ModelRuntime } from '@lobechat/model-runtime';
import { LobeVertexAI } from '@lobechat/model-runtime/vertexai';
import { ModelProvider } from 'model-bank';

import { checkAuth } from '@/app/(backend)/middleware/auth';
import { safeParseJSON } from '@/utils/safeParseJSON';

import { POST as UniverseRoute } from '../[provider]/route';

export const maxDuration = 300;

// due to the Chinese region does not support accessing Google
// we need to use proxy to access it
// refs: https://github.com/google/generative-ai-js/issues/29#issuecomment-1866246513
// if (process.env.HTTP_PROXY_URL) {
//   const { setGlobalDispatcher, ProxyAgent } = require('undici');
//
//   setGlobalDispatcher(new ProxyAgent({ uri: process.env.HTTP_PROXY_URL }));
// }

export const POST: any = checkAuth(async (req: Request, { jwtPayload }) =>
  UniverseRoute(req, {
    createRuntime: () => {
      const googleAuthStr = jwtPayload.apiKey ?? process.env.VERTEXAI_CREDENTIALS ?? undefined;

      const credentials = safeParseJSON(googleAuthStr);
      const googleAuthOptions = credentials ? { credentials } : undefined;

      const instance = LobeVertexAI.initFromVertexAI({
        googleAuthOptions,
        location: process.env.VERTEXAI_LOCATION,
        project: !!credentials?.project_id ? credentials?.project_id : process.env.VERTEXAI_PROJECT,
      });

      return new ModelRuntime(instance);
    },
    params: Promise.resolve({ provider: ModelProvider.VertexAI }),
  }),
);
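For reference, the commented-out workaround in the deleted route above relied on undici's global dispatcher so that Google traffic can be routed through an HTTP proxy in regions where the endpoints are unreachable. A minimal ESM sketch of that pattern, assuming HTTP_PROXY_URL is set and that the runtime's fetch honours undici's global dispatcher:

// proxy-setup.ts — hedged sketch of the proxy workaround described in the deleted comments.
import { ProxyAgent, setGlobalDispatcher } from 'undici';

if (process.env.HTTP_PROXY_URL) {
  // Route all undici-backed fetch traffic (including the Google SDK's requests)
  // through the proxy named by HTTP_PROXY_URL.
  setGlobalDispatcher(new ProxyAgent({ uri: process.env.HTTP_PROXY_URL }));
}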
@@ -20,6 +20,7 @@ import {
  LobeZhipuAI,
  ModelRuntime,
} from '@lobechat/model-runtime';
import { LobeVertexAI } from '@lobechat/model-runtime/vertexai';
import { ClientSecretPayload } from '@lobechat/types';
import { ModelProvider } from 'model-bank';
import { describe, expect, it, vi } from 'vitest';

@@ -128,6 +129,36 @@ describe('initModelRuntimeWithUserPayload method', () => {
    expect(runtime['_runtime']).toBeInstanceOf(LobeQwenAI);
  });

  it('Vertex AI provider: with service account json', async () => {
    const credentials = {
      client_email: 'vertex@test-project.iam.gserviceaccount.com',
      private_key: '-----BEGIN PRIVATE KEY-----\nTEST\n-----END PRIVATE KEY-----\n',
      project_id: 'test-project',
      type: 'service_account',
    };
    const payload: ClientSecretPayload = { apiKey: JSON.stringify(credentials) };
    const initSpy = vi
      .spyOn(LobeVertexAI, 'initFromVertexAI')
      .mockImplementation((options: any) => {
        expect(options.project).toBe('test-project');
        expect(options.googleAuthOptions?.credentials?.private_key).toContain('TEST');

        return new LobeGoogleAI({
          apiKey: 'avoid-error',
          client: {} as any,
          isVertexAi: true,
        });
      });

    const runtime = await initModelRuntimeWithUserPayload(ModelProvider.VertexAI, payload);

    expect(initSpy).toHaveBeenCalledTimes(1);
    expect(runtime).toBeInstanceOf(ModelRuntime);
    expect(runtime['_runtime']).toBeInstanceOf(LobeGoogleAI);

    initSpy.mockRestore();
  });

  it('Bedrock AI provider: with apikey, awsAccessKeyId, awsSecretAccessKey, awsRegion', async () => {
    const jwtPayload: ClientSecretPayload = {
      apiKey: 'user-bedrock-key',
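The test above inlines the service-account object and passes it through JSON.stringify; outside of tests the same payload shape would typically come from a downloaded key file. A hypothetical sketch (the helper name and file path are illustrative and not part of this commit):

// build-vertex-payload.ts — illustrative only.
import { readFileSync } from 'node:fs';

// Reads a Google service-account key file and wraps it in the payload shape
// expected by the VertexAI branch: the whole JSON string travels in apiKey.
const buildVertexPayloadFromKeyFile = (keyFilePath: string) => {
  const serviceAccountJson = readFileSync(keyFilePath, 'utf8');

  return { apiKey: serviceAccountJson };
};

// const payload = buildVertexPayloadFromKeyFile('./vertex-service-account.json');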
@@ -1,5 +1,8 @@
import { GoogleGenAIOptions } from '@google/genai';
import { ModelRuntime } from '@lobechat/model-runtime';
import { LobeVertexAI } from '@lobechat/model-runtime/vertexai';
import { ClientSecretPayload } from '@lobechat/types';
import { safeParseJSON } from '@lobechat/utils';
import { ModelProvider } from 'model-bank';

import { getLLMConfig } from '@/envs/llm';

@@ -20,6 +23,10 @@ const getParamsFromPayload = (provider: string, payload: ClientSecretPayload) =>
  const llmConfig = getLLMConfig() as Record<string, any>;

  switch (provider) {
    case ModelProvider.VertexAI: {
      return {};
    }

    default: {
      let upperProvider = provider.toUpperCase();
@@ -116,6 +123,35 @@ const getParamsFromPayload = (provider: string, payload: ClientSecretPayload) =>
  }
};

const buildVertexOptions = (
  payload: ClientSecretPayload,
  params: Partial<GoogleGenAIOptions> = {},
): GoogleGenAIOptions => {
  const rawCredentials = payload.apiKey ?? process.env.VERTEXAI_CREDENTIALS ?? '';
  const credentials = safeParseJSON<Record<string, string>>(rawCredentials);

  const projectFromParams = params.project as string | undefined;
  const projectFromCredentials = credentials?.project_id;
  const projectFromEnv = process.env.VERTEXAI_PROJECT;

  const project = projectFromParams ?? projectFromCredentials ?? projectFromEnv;
  const location =
    (params.location as string | undefined) ?? process.env.VERTEXAI_LOCATION ?? undefined;

  const googleAuthOptions = params.googleAuthOptions ?? (credentials ? { credentials } : undefined);

  const options: GoogleGenAIOptions = {
    ...params,
    vertexai: true,
  };

  if (googleAuthOptions) options.googleAuthOptions = googleAuthOptions;
  if (project) options.project = project;
  if (location) options.location = location as GoogleGenAIOptions['location'];

  return options;
};

/**
 * Initializes the agent runtime with the user payload in backend
 * @param provider - The provider name.
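buildVertexOptions resolves each field with a fixed precedence: explicit params first, then values parsed from the service-account JSON, then the VERTEXAI_* environment variables, and it always forces vertexai: true. Since the helper is module-private, the sketch below only mirrors that resolution for a typical payload; the inputs are illustrative:

// precedence-sketch.ts — mirrors the resolution order used by buildVertexOptions; values are illustrative.
const payload = {
  apiKey: JSON.stringify({
    client_email: 'bot@demo-project.iam.gserviceaccount.com',
    private_key: '-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n',
    project_id: 'demo-project',
    type: 'service_account',
  }),
};

// safeParseJSON returns undefined instead of throwing on malformed input;
// plain JSON.parse in a try/catch stands in for it in this sketch.
let credentials: Record<string, string> | undefined;
try {
  credentials = JSON.parse(payload.apiKey ?? process.env.VERTEXAI_CREDENTIALS ?? '');
} catch {
  credentials = undefined;
}

// No explicit params in this sketch, so the JSON and the env vars decide:
const project = credentials?.project_id ?? process.env.VERTEXAI_PROJECT; // 'demo-project'
const location = process.env.VERTEXAI_LOCATION; // e.g. 'us-central1', or undefined if unset

console.log({ location, project, vertexai: true });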
@@ -130,6 +166,13 @@ export const initModelRuntimeWithUserPayload = (
) => {
  const runtimeProvider = payload.runtimeProvider ?? provider;

  if (runtimeProvider === ModelProvider.VertexAI) {
    const vertexOptions = buildVertexOptions(payload, params);
    const runtime = LobeVertexAI.initFromVertexAI(vertexOptions);

    return new ModelRuntime(runtime);
  }

  return ModelRuntime.initializeWithProvider(runtimeProvider, {
    ...getParamsFromPayload(runtimeProvider, payload),
    ...params,
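With this change the Vertex AI provider goes through the same entry point as every other provider, so the dedicated route deleted above is no longer needed. A hedged sketch of a caller (the import path of initModelRuntimeWithUserPayload and the service-account values are illustrative):

// usage-sketch.ts — illustrative caller for the module edited above.
import { ModelProvider } from 'model-bank';

// Illustrative path; point this at wherever initModelRuntimeWithUserPayload actually lives.
import { initModelRuntimeWithUserPayload } from './initModelRuntimeWithUserPayload';

const serviceAccount = {
  client_email: 'image-bot@demo-project.iam.gserviceaccount.com',
  private_key: '-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n',
  project_id: 'demo-project',
  type: 'service_account',
};

// The whole service-account JSON travels in apiKey, exactly as in the deleted route and the new test.
const runtime = await initModelRuntimeWithUserPayload(ModelProvider.VertexAI, {
  apiKey: JSON.stringify(serviceAccount),
});

// runtime is a ModelRuntime wrapping LobeVertexAI, ready to serve chat and image-generation requests.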