diff --git a/core/src/browser/extensions/engines/LocalOAIEngine.test.ts b/core/src/browser/extensions/engines/LocalOAIEngine.test.ts index 4a36f6b12b..8a7722f3a6 100644 --- a/core/src/browser/extensions/engines/LocalOAIEngine.test.ts +++ b/core/src/browser/extensions/engines/LocalOAIEngine.test.ts @@ -44,48 +44,14 @@ describe('LocalOAIEngine', () => { it('should load model correctly', async () => { const model: Model = { engine: 'testProvider', file_path: 'path/to/model' } as any - const modelFolder = 'path/to' - const systemInfo = { os: 'testOS' } - const res = { error: null } - ;(dirName as jest.Mock).mockResolvedValue(modelFolder) - ;(systemInformation as jest.Mock).mockResolvedValue(systemInfo) - ;(executeOnMain as jest.Mock).mockResolvedValue(res) - - await engine.loadModel(model) - - expect(systemInformation).toHaveBeenCalled() - expect(executeOnMain).toHaveBeenCalledWith( - engine.nodeModule, - engine.loadModelFunctionName, - { modelFolder, model }, - systemInfo - ) - expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelReady, model) - }) - - it('should handle load model error', async () => { - const model: any = { engine: 'testProvider', file_path: 'path/to/model' } as any - const modelFolder = 'path/to' - const systemInfo = { os: 'testOS' } - const res = { error: 'load error' } - - ;(dirName as jest.Mock).mockResolvedValue(modelFolder) - ;(systemInformation as jest.Mock).mockResolvedValue(systemInfo) - ;(executeOnMain as jest.Mock).mockResolvedValue(res) - - await expect(engine.loadModel(model)).rejects.toEqual('load error') - - expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelFail, { error: res.error }) + expect(engine.loadModel(model)).toBeTruthy() }) it('should unload model correctly', async () => { const model: Model = { engine: 'testProvider' } as any - await engine.unloadModel(model) - - expect(executeOnMain).toHaveBeenCalledWith(engine.nodeModule, engine.unloadModelFunctionName) - expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelStopped, {}) + expect(engine.unloadModel(model)).toBeTruthy() }) it('should not unload model if engine does not match', async () => { diff --git a/core/src/browser/extensions/engines/LocalOAIEngine.ts b/core/src/browser/extensions/engines/LocalOAIEngine.ts index cb5b6760e2..e8bd8cdf2d 100644 --- a/core/src/browser/extensions/engines/LocalOAIEngine.ts +++ b/core/src/browser/extensions/engines/LocalOAIEngine.ts @@ -36,11 +36,6 @@ export abstract class LocalOAIEngine extends OAIEngine { * Stops the model. */ override async unloadModel(model?: Model) { - if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve() - - this.loadedModel = undefined - await executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => { - events.emit(ModelEvent.OnModelStopped, {}) - }) + return Promise.resolve() } } diff --git a/core/src/node/api/processors/download.test.ts b/core/src/node/api/processors/download.test.ts index 370f1746f6..21d94165dc 100644 --- a/core/src/node/api/processors/download.test.ts +++ b/core/src/node/api/processors/download.test.ts @@ -8,7 +8,8 @@ jest.mock('../../helper', () => ({ jest.mock('../../helper/path', () => ({ validatePath: jest.fn().mockReturnValue('path/to/folder'), - normalizeFilePath: () => process.platform === 'win32' ? 'C:\\Users\path\\to\\file.gguf' : '/Users/path/to/file.gguf', + normalizeFilePath: () => + process.platform === 'win32' ? 'C:\\Users\\path\\to\\file.gguf' : '/Users/path/to/file.gguf', })) jest.mock( diff --git a/core/src/types/model/modelEntity.ts b/core/src/types/model/modelEntity.ts index ed1db94bd1..25ed95b8d6 100644 --- a/core/src/types/model/modelEntity.ts +++ b/core/src/types/model/modelEntity.ts @@ -31,7 +31,7 @@ export enum InferenceEngine { cortex = 'cortex', cortex_llamacpp = 'llama-cpp', cortex_onnx = 'onnxruntime', - cortex_tensorrtllm = '.tensorrt-llm', + cortex_tensorrtllm = 'tensorrt-llm', } export type ModelArtifact = { diff --git a/extensions/model-extension/rollup.config.ts b/extensions/model-extension/rollup.config.ts index 781c4df84d..64e62480ff 100644 --- a/extensions/model-extension/rollup.config.ts +++ b/extensions/model-extension/rollup.config.ts @@ -20,8 +20,8 @@ export default [ replace({ preventAssignment: true, SETTINGS: JSON.stringify(settingJson), - API_URL: 'http://127.0.0.1:39291', - SOCKET_URL: 'ws://127.0.0.1:39291', + API_URL: JSON.stringify('http://127.0.0.1:39291'), + SOCKET_URL: JSON.stringify('ws://127.0.0.1:39291'), }), // Allow json resolution json(), diff --git a/web/helpers/atoms/Model.atom.test.ts b/web/helpers/atoms/Model.atom.test.ts index 57827efec1..923f24df47 100644 --- a/web/helpers/atoms/Model.atom.test.ts +++ b/web/helpers/atoms/Model.atom.test.ts @@ -32,13 +32,22 @@ describe('Model.atom.ts', () => { }) describe('showEngineListModelAtom', () => { - it('should initialize as an empty array', () => { - expect(ModelAtoms.showEngineListModelAtom.init).toEqual(['nitro']) + it('should initialize with local engines', () => { + expect(ModelAtoms.showEngineListModelAtom.init).toEqual([ + 'nitro', + 'cortex', + 'llama-cpp', + 'onnxruntime', + 'tensorrt-llm', + ]) }) }) describe('addDownloadingModelAtom', () => { it('should add downloading model', async () => { + const { result: reset } = renderHook(() => + useSetAtom(ModelAtoms.downloadingModelsAtom) + ) const { result: setAtom } = renderHook(() => useSetAtom(ModelAtoms.addDownloadingModelAtom) ) @@ -49,11 +58,16 @@ describe('Model.atom.ts', () => { setAtom.current({ id: '1' } as any) }) expect(getAtom.current).toEqual([{ id: '1' }]) + reset.current([]) }) }) describe('removeDownloadingModelAtom', () => { it('should remove downloading model', async () => { + const { result: reset } = renderHook(() => + useSetAtom(ModelAtoms.downloadingModelsAtom) + ) + const { result: setAtom } = renderHook(() => useSetAtom(ModelAtoms.addDownloadingModelAtom) ) @@ -63,16 +77,21 @@ describe('Model.atom.ts', () => { const { result: getAtom } = renderHook(() => useAtomValue(ModelAtoms.getDownloadingModelAtom) ) + expect(getAtom.current).toEqual([]) act(() => { - setAtom.current({ id: '1' } as any) + setAtom.current('1') removeAtom.current('1') }) expect(getAtom.current).toEqual([]) + reset.current([]) }) }) describe('removeDownloadedModelAtom', () => { it('should remove downloaded model', async () => { + const { result: reset } = renderHook(() => + useSetAtom(ModelAtoms.downloadingModelsAtom) + ) const { result: setAtom } = renderHook(() => useSetAtom(ModelAtoms.downloadedModelsAtom) ) @@ -94,6 +113,7 @@ describe('Model.atom.ts', () => { removeAtom.current('1') }) expect(getAtom.current).toEqual([]) + reset.current([]) }) }) @@ -284,10 +304,4 @@ describe('Model.atom.ts', () => { expect(importAtom.current[0]).toEqual([]) }) }) - - describe('defaultModelAtom', () => { - it('should initialize as undefined', () => { - expect(ModelAtoms.defaultModelAtom.init).toBeUndefined() - }) - }) }) diff --git a/web/helpers/atoms/Model.atom.ts b/web/helpers/atoms/Model.atom.ts index 0f5367f64f..dd4414801a 100644 --- a/web/helpers/atoms/Model.atom.ts +++ b/web/helpers/atoms/Model.atom.ts @@ -64,13 +64,13 @@ export const stateModel = atom({ state: 'start', loading: false, model: '' }) /** * Stores the list of models which are being downloaded. */ -const downloadingModelsAtom = atom([]) +export const downloadingModelsAtom = atom([]) export const getDownloadingModelAtom = atom((get) => get(downloadingModelsAtom)) export const addDownloadingModelAtom = atom(null, (get, set, model: string) => { const downloadingModels = get(downloadingModelsAtom) - if (!downloadingModels.find((e) => e === model)) { + if (!downloadingModels.includes(model)) { set(downloadingModelsAtom, [...downloadingModels, model]) } }) diff --git a/web/hooks/useDeleteModel.test.ts b/web/hooks/useDeleteModel.test.ts index 3a6587d7b8..3ee0926f94 100644 --- a/web/hooks/useDeleteModel.test.ts +++ b/web/hooks/useDeleteModel.test.ts @@ -35,7 +35,7 @@ describe('useDeleteModel', () => { await result.current.deleteModel(mockModel) }) - expect(mockDeleteModel).toHaveBeenCalledWith(mockModel) + expect(mockDeleteModel).toHaveBeenCalledWith('test-model') expect(toaster).toHaveBeenCalledWith({ title: 'Model Deletion Successful', description: `Model ${mockModel.name} has been successfully deleted.`, @@ -67,7 +67,7 @@ describe('useDeleteModel', () => { ) }) - expect(mockDeleteModel).toHaveBeenCalledWith(mockModel) + expect(mockDeleteModel).toHaveBeenCalledWith("test-model") expect(toaster).not.toHaveBeenCalled() }) }) diff --git a/web/hooks/useDownloadModel.test.ts b/web/hooks/useDownloadModel.test.ts index fc0b7c21f4..ff75fbcd8b 100644 --- a/web/hooks/useDownloadModel.test.ts +++ b/web/hooks/useDownloadModel.test.ts @@ -13,12 +13,6 @@ jest.mock('jotai', () => ({ })) jest.mock('@janhq/core') jest.mock('@/extension/ExtensionManager') -jest.mock('./useGpuSetting', () => ({ - __esModule: true, - default: () => ({ - getGpuSettings: jest.fn().mockResolvedValue({ some: 'gpuSettings' }), - }), -})) describe('useDownloadModel', () => { beforeEach(() => { @@ -29,25 +23,24 @@ describe('useDownloadModel', () => { it('should download a model', async () => { const mockModel: core.Model = { id: 'test-model', - sources: [{ filename: 'test.bin' }], + sources: [{ filename: 'test.bin', url: 'https://fake.url' }], } as core.Model const mockExtension = { - downloadModel: jest.fn().mockResolvedValue(undefined), + pullModel: jest.fn().mockResolvedValue(undefined), } ;(useSetAtom as jest.Mock).mockReturnValue(() => undefined) ;(extensionManager.get as jest.Mock).mockReturnValue(mockExtension) const { result } = renderHook(() => useDownloadModel()) - await act(async () => { - await result.current.downloadModel(mockModel) + act(() => { + result.current.downloadModel(mockModel.sources[0].url, mockModel.id) }) - expect(mockExtension.downloadModel).toHaveBeenCalledWith( - mockModel, - { some: 'gpuSettings' }, - { ignoreSSL: undefined, proxy: '' } + expect(mockExtension.pullModel).toHaveBeenCalledWith( + mockModel.sources[0].url, + mockModel.id ) }) @@ -58,15 +51,18 @@ describe('useDownloadModel', () => { } as core.Model ;(core.joinPath as jest.Mock).mockResolvedValue('/path/to/model/test.bin') - ;(core.abortDownload as jest.Mock).mockResolvedValue(undefined) + const mockExtension = { + cancelModelPull: jest.fn().mockResolvedValue(undefined), + } ;(useSetAtom as jest.Mock).mockReturnValue(() => undefined) + ;(extensionManager.get as jest.Mock).mockReturnValue(mockExtension) const { result } = renderHook(() => useDownloadModel()) - await act(async () => { - await result.current.abortModelDownload(mockModel) + act(() => { + result.current.abortModelDownload(mockModel.id) }) - expect(core.abortDownload).toHaveBeenCalledWith('/path/to/model/test.bin') + expect(mockExtension.cancelModelPull).toHaveBeenCalledWith('test-model') }) it('should handle proxy settings', async () => { @@ -76,7 +72,7 @@ describe('useDownloadModel', () => { } as core.Model const mockExtension = { - downloadModel: jest.fn().mockResolvedValue(undefined), + pullModel: jest.fn().mockResolvedValue(undefined), } ;(useSetAtom as jest.Mock).mockReturnValue(() => undefined) ;(extensionManager.get as jest.Mock).mockReturnValue(mockExtension) @@ -85,14 +81,13 @@ describe('useDownloadModel', () => { const { result } = renderHook(() => useDownloadModel()) - await act(async () => { - await result.current.downloadModel(mockModel) + act(() => { + result.current.downloadModel(mockModel.sources[0].url, mockModel.id) }) - expect(mockExtension.downloadModel).toHaveBeenCalledWith( - mockModel, - expect.objectContaining({ some: 'gpuSettings' }), - expect.anything() + expect(mockExtension.pullModel).toHaveBeenCalledWith( + mockModel.sources[0].url, + mockModel.id ) }) }) diff --git a/web/hooks/useGetHFRepoData.test.ts b/web/hooks/useGetHFRepoData.test.ts index eaf86d79a0..01055612d8 100644 --- a/web/hooks/useGetHFRepoData.test.ts +++ b/web/hooks/useGetHFRepoData.test.ts @@ -1,6 +1,10 @@ +/** + * @jest-environment jsdom + */ import { renderHook, act } from '@testing-library/react' import { useGetHFRepoData } from './useGetHFRepoData' import { extensionManager } from '@/extension' +import * as hf from '@/utils/huggingface' jest.mock('@/extension', () => ({ extensionManager: { @@ -8,6 +12,8 @@ jest.mock('@/extension', () => ({ }, })) +jest.mock('@/utils/huggingface') + describe('useGetHFRepoData', () => { beforeEach(() => { jest.clearAllMocks() @@ -15,10 +21,7 @@ describe('useGetHFRepoData', () => { it('should fetch HF repo data successfully', async () => { const mockData = { name: 'Test Repo', stars: 100 } - const mockFetchHuggingFaceRepoData = jest.fn().mockResolvedValue(mockData) - ;(extensionManager.get as jest.Mock).mockReturnValue({ - fetchHuggingFaceRepoData: mockFetchHuggingFaceRepoData, - }) + ;(hf.fetchHuggingFaceRepoData as jest.Mock).mockReturnValue(mockData) const { result } = renderHook(() => useGetHFRepoData()) @@ -34,6 +37,5 @@ describe('useGetHFRepoData', () => { expect(result.current.error).toBeUndefined() expect(await data).toEqual(mockData) - expect(mockFetchHuggingFaceRepoData).toHaveBeenCalledWith('test-repo') }) }) diff --git a/web/hooks/useImportModel.test.ts b/web/hooks/useImportModel.test.ts index 2148f581b8..d37e4a8533 100644 --- a/web/hooks/useImportModel.test.ts +++ b/web/hooks/useImportModel.test.ts @@ -18,7 +18,7 @@ describe('useImportModel', () => { it('should import models successfully', async () => { const mockImportModels = jest.fn().mockResolvedValue(undefined) const mockExtension = { - importModels: mockImportModels, + importModel: mockImportModels, } as any jest.spyOn(extensionManager, 'get').mockReturnValue(mockExtension) @@ -26,15 +26,16 @@ describe('useImportModel', () => { const { result } = renderHook(() => useImportModel()) const models = [ - { importId: '1', name: 'Model 1', path: '/path/to/model1' }, - { importId: '2', name: 'Model 2', path: '/path/to/model2' }, + { modelId: '1', path: '/path/to/model1' }, + { modelId: '2', path: '/path/to/model2' }, ] as any await act(async () => { await result.current.importModels(models, 'local' as any) }) - expect(mockImportModels).toHaveBeenCalledWith(models, 'local') + expect(mockImportModels).toHaveBeenCalledWith('1', '/path/to/model1') + expect(mockImportModels).toHaveBeenCalledWith('2', '/path/to/model2') }) it('should update model info successfully', async () => { @@ -42,7 +43,7 @@ describe('useImportModel', () => { .fn() .mockResolvedValue({ id: 'model-1', name: 'Updated Model' }) const mockExtension = { - updateModelInfo: mockUpdateModelInfo, + updateModel: mockUpdateModelInfo, } as any jest.spyOn(extensionManager, 'get').mockReturnValue(mockExtension) diff --git a/web/hooks/useImportModel.ts b/web/hooks/useImportModel.ts index df6b085ca3..5650c73bda 100644 --- a/web/hooks/useImportModel.ts +++ b/web/hooks/useImportModel.ts @@ -103,6 +103,7 @@ const useImportModel = () => { const localImportModels = async ( models: ImportingModel[], + // TODO: @louis - We will set this option when cortex.cpp supports it optionType: OptionType ): Promise => { await models diff --git a/web/hooks/useModels.test.ts b/web/hooks/useModels.test.ts index 4c53ffaa71..33c1526727 100644 --- a/web/hooks/useModels.test.ts +++ b/web/hooks/useModels.test.ts @@ -1,7 +1,7 @@ // useModels.test.ts import { renderHook, act } from '@testing-library/react' -import { events, ModelEvent } from '@janhq/core' +import { events, ModelEvent, ModelManager } from '@janhq/core' import { extensionManager } from '@/extension' // Mock dependencies @@ -11,18 +11,11 @@ jest.mock('@/extension') import useModels from './useModels' // Mock data -const mockDownloadedModels = [ +const models = [ { id: 'model-1', name: 'Model 1' }, { id: 'model-2', name: 'Model 2' }, ] -const mockConfiguredModels = [ - { id: 'model-3', name: 'Model 3' }, - { id: 'model-4', name: 'Model 4' }, -] - -const mockDefaultModel = { id: 'default-model', name: 'Default Model' } - describe('useModels', () => { beforeEach(() => { jest.clearAllMocks() @@ -30,20 +23,23 @@ describe('useModels', () => { it('should fetch and set models on mount', async () => { const mockModelExtension = { - getDownloadedModels: jest.fn().mockResolvedValue(mockDownloadedModels), - getConfiguredModels: jest.fn().mockResolvedValue(mockConfiguredModels), - getDefaultModel: jest.fn().mockResolvedValue(mockDefaultModel), + getModels: jest.fn().mockResolvedValue(models), } as any + ;(ModelManager.instance as jest.Mock).mockReturnValue({ + models: { + values: () => ({ + toArray: () => {}, + }), + }, + }) jest.spyOn(extensionManager, 'get').mockReturnValue(mockModelExtension) - await act(async () => { + act(() => { renderHook(() => useModels()) }) - expect(mockModelExtension.getDownloadedModels).toHaveBeenCalled() - expect(mockModelExtension.getConfiguredModels).toHaveBeenCalled() - expect(mockModelExtension.getDefaultModel).toHaveBeenCalled() + expect(mockModelExtension.getModels).toHaveBeenCalled() }) it('should remove event listener on unmount', async () => { diff --git a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx index cac0cb350f..488f795b77 100644 --- a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx +++ b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx @@ -15,12 +15,13 @@ import { modelDownloadStateAtom } from '@/hooks/useDownloadState' import { formatDownloadPercentage, toGibibytes } from '@/utils/converter' +import { normalizeModelId } from '@/utils/model' + import { mainViewStateAtom } from '@/helpers/atoms/App.atom' import { assistantsAtom } from '@/helpers/atoms/Assistant.atom' import { importHuggingFaceModelStageAtom } from '@/helpers/atoms/HuggingFace.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' -import { normalizeModelId } from '@/utils/model' type Props = { index: number