Commit 8f23817

[OPENVINO-CODE] add-device-options (#895)

kumarijy authored Apr 3, 2024
1 parent 8a7bf32 commit 8f23817

Showing 11 changed files with 210 additions and 27 deletions.
12 changes: 12 additions & 0 deletions modules/openvino_code/package.json
@@ -57,6 +57,7 @@
     "vsce:publish": "vsce publish",
     "ovsx:publish": "ovsx publish",
     "clear-out": "rimraf ./out"
+
   },
   "devDependencies": {
     "@types/glob": "8.1.0",
@@ -200,6 +201,17 @@
       ],
       "description": "Which model to use for code generation."
     },
+    "openvinoCode.device": {
+      "order": 1,
+      "type": "string",
+      "default": "CPU",
+      "enum":[
+        "CPU",
+        "GPU",
+        "NPU"
+      ],
+      "description": "Which device to use for code generation"
+    },
     "openvinoCode.serverUrl": {
       "order": 1,
       "type": "string",
17 changes: 11 additions & 6 deletions modules/openvino_code/server/pyproject.toml
@@ -11,7 +11,7 @@ dependencies = [
     'torch @ https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.0.1%2Bcpu.cxx11.abi-cp310-cp310-linux_x86_64.whl ; sys_platform=="linux" and python_version == "3.10"',
     'torch @ https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.0.1%2Bcpu.cxx11.abi-cp311-cp311-linux_x86_64.whl ; sys_platform=="linux" and python_version == "3.11"',
     'torch ; sys_platform != "linux"',
-    'openvino==2023.3.0',
+    'openvino==2024.0.0',
     'transformers==4.36.0',
     'optimum==1.17.1',
     'optimum-intel[openvino]==1.15.0',
@@ -27,13 +27,18 @@ build-backend = "setuptools.build_meta"

 [tool.black]
 line-length = 119
-target-versions = ["py38", "py39", "py310", "py311"]
+target-version = ['py38', 'py39', 'py310', 'py311']
+unstable = true
+preview = true

 [tool.ruff]
-ignore = ["C901", "E501", "E741", "W605"]
-select = ["C", "E", "F", "I", "W"]
+lint.ignore = ["C901", "E501", "E741", "W605", "F401", "W292"]
+lint.select = ["C", "E", "F", "I", "W"]
+lint.extend-safe-fixes = ["F601"]
+lint.extend-unsafe-fixes = ["UP034"]
+lint.fixable = ["F401"]
 line-length = 119

-[tool.ruff.isort]
-
+[tool.ruff.lint.isort]
 lines-after-imports = 2
18 changes: 15 additions & 3 deletions modules/openvino_code/server/src/app.py
@@ -114,7 +114,11 @@ async def generate_stream(
     generation_request = TypeAdapter(GenerationRequest).validate_python(await request.json())
     logger.info(generation_request)
     return StreamingResponse(
-        generator.generate_stream(generation_request.inputs, generation_request.parameters.model_dump(), request)
+        generator.generate_stream(
+            generation_request.inputs,
+            generation_request.parameters.model_dump(),
+            request,
+        )
     )


@@ -127,7 +131,11 @@ async def summarize(

     start = perf_counter()
     generated_text: str = generator.summarize(
-        request.inputs, request.template, request.definition, request.format, request.parameters.model_dump()
+        request.inputs,
+        request.template,
+        request.definition,
+        request.format,
+        request.parameters.model_dump(),
     )
     stop = perf_counter()

@@ -148,6 +156,10 @@ async def summarize_stream(
     logger.info(request)
     return StreamingResponse(
         generator.summarize_stream(
-            request.inputs, request.template, request.definition, request.format, request.parameters.model_dump()
+            request.inputs,
+            request.template,
+            request.definition,
+            request.format,
+            request.parameters.model_dump(),
         )
     )
91 changes: 76 additions & 15 deletions modules/openvino_code/server/src/generators.py
@@ -5,7 +5,17 @@
 from pathlib import Path
 from threading import Thread
 from time import time
-from typing import Any, Callable, Container, Dict, Generator, List, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Container,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Type,
+    Union,
+)

 import torch
 from fastapi import Request
@@ -53,11 +63,20 @@ def get_model(checkpoint: str, device: str = "CPU") -> OVModel:
     model_class = get_model_class(checkpoint)
     try:
         model = model_class.from_pretrained(
-            checkpoint, ov_config=ov_config, compile=False, device=device, trust_remote_code=True
+            checkpoint,
+            ov_config=ov_config,
+            compile=False,
+            device=device,
+            trust_remote_code=True,
         )
     except EntryNotFoundError:
         model = model_class.from_pretrained(
-            checkpoint, ov_config=ov_config, export=True, compile=False, device=device, trust_remote_code=True
+            checkpoint,
+            ov_config=ov_config,
+            export=True,
+            compile=False,
+            device=device,
+            trust_remote_code=True,
         )
     model.save_pretrained(model_path)
     model.compile()
@@ -75,10 +94,24 @@ def __call__(self, input_text: str, parameters: Dict[str, Any]) -> str:
     async def generate_stream(self, input_text: str, parameters: Dict[str, Any], request: Request):
         raise NotImplementedError

-    def summarize(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]):
+    def summarize(
+        self,
+        input_text: str,
+        template: str,
+        signature: str,
+        style: str,
+        parameters: Dict[str, Any],
+    ):
         raise NotImplementedError

-    def summarize_stream(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]):
+    def summarize_stream(
+        self,
+        input_text: str,
+        template: str,
+        signature: str,
+        style: str,
+        parameters: Dict[str, Any],
+    ):
         raise NotImplementedError


@@ -128,13 +161,19 @@ def __call__(self, input_text: str, parameters: Dict[str, Any]) -> str:
         prompt_len = input_ids.shape[-1]
         config = GenerationConfig.from_dict({**self.generation_config.to_dict(), **parameters})
         output_ids = self.model.generate(
-            input_ids, generation_config=config, stopping_criteria=stopping_criteria, **self.assistant_model_config
+            input_ids,
+            generation_config=config,
+            stopping_criteria=stopping_criteria,
+            **self.assistant_model_config,
         )[0][prompt_len:]
         logger.info(f"Number of input tokens: {prompt_len}; generated {len(output_ids)} tokens")
         return self.tokenizer.decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

     async def generate_stream(
-        self, input_text: str, parameters: Dict[str, Any], request: Optional[Request] = None
+        self,
+        input_text: str,
+        parameters: Dict[str, Any],
+        request: Optional[Request] = None,
     ) -> Generator[str, None, None]:
         input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
         streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
@@ -192,7 +231,10 @@ def generate_between(
         prev_len = prompt.shape[-1]

         prompt = self.model.generate(
-            prompt, generation_config=config, stopping_criteria=stopping_criteria, **self.assistant_model_config
+            prompt,
+            generation_config=config,
+            stopping_criteria=stopping_criteria,
+            **self.assistant_model_config,
         )[
             :, :-1
         ]  # skip the last token - stop token
@@ -219,7 +261,10 @@ async def generate_between_stream(
         prev_len = prompt.shape[-1]

         prompt = self.model.generate(
-            prompt, generation_config=config, stopping_criteria=stopping_criteria, **self.assistant_model_config
+            prompt,
+            generation_config=config,
+            stopping_criteria=stopping_criteria,
+            **self.assistant_model_config,
         )[
             :, :-1
         ]  # skip the last token - stop token
@@ -237,24 +282,40 @@ def summarization_input(function: str, signature: str, style: str) -> str:
             signature=signature,
         )

-    def summarize(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]) -> str:
+    def summarize(
+        self,
+        input_text: str,
+        template: str,
+        signature: str,
+        style: str,
+        parameters: Dict[str, Any],
+    ) -> str:
         prompt = self.summarization_input(input_text, signature, style)
         splited_template = re.split(r"\$\{.*\}", template)
         splited_template[0] = prompt + splited_template[0]

-        return self.generate_between(splited_template, parameters, stopping_criteria=self.summarize_stopping_criteria)[
-            len(prompt) :
-        ]
+        return self.generate_between(
+            splited_template,
+            parameters,
+            stopping_criteria=self.summarize_stopping_criteria,
+        )[len(prompt) :]

     async def summarize_stream(
-        self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]
+        self,
+        input_text: str,
+        template: str,
+        signature: str,
+        style: str,
+        parameters: Dict[str, Any],
     ):
         prompt = self.summarization_input(input_text, signature, style)
         splited_template = re.split(r"\$\{.*\}", template)
         splited_template = [prompt] + splited_template

         async for token in self.generate_between_stream(
-            splited_template, parameters, stopping_criteria=self.summarize_stopping_criteria
+            splited_template,
+            parameters,
+            stopping_criteria=self.summarize_stopping_criteria,
         ):
             yield token
25 changes: 25 additions & 0 deletions modules/openvino_code/shared/device.ts
@@ -0,0 +1,25 @@
+import { Features } from './features';
+
+enum DeviceId {
+  CPU = 'CPU',
+  GPU = 'GPU',
+  NPU = 'NPU',
+}
+
+export enum DeviceName {
+  CPU = 'CPU',
+  GPU = 'GPU',
+  NPU = 'NPU',
+}
+
+export const DEVICE_NAME_TO_ID_MAP: Record<DeviceName, DeviceId> = {
+  [DeviceName.CPU]: DeviceId.CPU,
+  [DeviceName.GPU]: DeviceId.GPU,
+  [DeviceName.NPU]: DeviceId.NPU,
+};
+
+export const DEVICE_SUPPORTED_FEATURES: Record<DeviceName, Features[]> = {
+  [DeviceName.CPU]: [Features.CODE_COMPLETION, Features.SUMMARIZATION, Features.FIM],
+  [DeviceName.GPU]: [Features.CODE_COMPLETION, Features.SUMMARIZATION, Features.FIM],
+  [DeviceName.NPU]: [Features.CODE_COMPLETION, Features.SUMMARIZATION, Features.FIM],
+};
1 change: 1 addition & 0 deletions modules/openvino_code/shared/side-panel-message.ts
@@ -10,6 +10,7 @@ export enum SidePanelMessageTypes {
   GENERATE_COMPLETION_CLICK = `${sidePanelMessagePrefix}.generateCompletionClick`,
   SETTINGS_CLICK = `${sidePanelMessagePrefix}.settingsClick`,
   MODEL_CHANGE = `${sidePanelMessagePrefix}.modelChange`,
+  DEVICE_CHANGE = `${sidePanelMessagePrefix}.deviceChange`,
 }

 export interface ISidePanelMessage<P = unknown> {
41 changes: 41 additions & 0 deletions …/ServerSection/DeviceSelect/DeviceSelect.tsx

@@ -0,0 +1,41 @@
+//import { ModelName } from '@shared/model';
+import { DeviceName } from '@shared/device';
+import { Select, SelectOptionProps } from '../../../shared/Select/Select';
+import { ServerStatus } from '@shared/server-state';
+import { Features } from '@shared/features';
+
+const options: SelectOptionProps<DeviceName>[] = [
+  { value: DeviceName.CPU },
+  { value: DeviceName.GPU },
+  { value: DeviceName.NPU },
+];
+
+interface DeviceSelectProps {
+  disabled: boolean;
+  selectedDeviceName: DeviceName;
+  onChange: (deviceName: DeviceName) => void;
+  supportedFeatures: Features[];
+  serverStatus: ServerStatus;
+}
+
+export const DeviceSelect = ({
+  disabled,
+  selectedDeviceName,
+  onChange,
+  supportedFeatures,
+  serverStatus,
+}: DeviceSelectProps): JSX.Element => {
+  const isServerStopped = serverStatus === ServerStatus.STOPPED;
+  return (
+    <>
+      <Select
+        label="Device"
+        options={options}
+        selectedValue={selectedDeviceName}
+        disabled={disabled}
+        onChange={(value) => onChange(value)}
+      ></Select>
+      {isServerStopped && <span>Supported Features: {supportedFeatures.join(', ')}</span>}
+    </>
+  );
+};
…/ServerSection/ServerSection.tsx

@@ -7,6 +7,8 @@ import { ServerStatus } from './ServerStatus/ServerStatus';
 import './ServerSection.css';
 import { ModelSelect } from './ModelSelect/ModelSelect';
 import { ModelName } from '@shared/model';
+import { DeviceSelect } from './DeviceSelect/DeviceSelect';
+import { DeviceName } from '@shared/device';

 interface ServerSectionProps {
   state: IExtensionState | null;
@@ -46,6 +48,15 @@ export function ServerSection({ state }: ServerSectionProps): JSX.Element {
     });
   };

+  const handleDeviceChange = (deviceName: DeviceName) => {
+    vscode.postMessage({
+      type: SidePanelMessageTypes.DEVICE_CHANGE,
+      payload: {
+        deviceName,
+      },
+    });
+  };
+
   if (!state) {
     return <>Extension state is not available</>;
   }
@@ -64,6 +75,13 @@ export function ServerSection({ state }: ServerSectionProps): JSX.Element {
         supportedFeatures={state.features.supportedList}
         serverStatus={state.server.status}
       ></ModelSelect>
+      <DeviceSelect
+        disabled={!isServerStopped}
+        onChange={handleDeviceChange}
+        selectedDeviceName={state.config.device}
+        supportedFeatures={state.features.supportedList}
+        serverStatus={state.server.status}
+      ></DeviceSelect>
       {isServerStarting && <StartingStages currentStage={state.server.stage}></StartingStages>}
       <div className="button-group">
         {isServerStopped && <button onClick={handleStartServerClick}>Start Server</button>}
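The excerpt shows the side panel posting DEVICE_CHANGE, but not the extension-host side that receives it. A hedged sketch of what a receiver could look like, assuming the host persists the choice into the openvinoCode.device setting contributed above (the wiring, function name, and section name are assumptions, not code from this commit):

import * as vscode from 'vscode';
import { SidePanelMessageTypes } from '@shared/side-panel-message';

// Hypothetical receiver: persist the device chosen in the side panel
// into the `openvinoCode.device` setting contributed by package.json.
function listenForDeviceChange(webview: vscode.Webview): vscode.Disposable {
  return webview.onDidReceiveMessage(async (message: { type: string; payload?: { deviceName?: string } }) => {
    if (message.type === SidePanelMessageTypes.DEVICE_CHANGE && message.payload?.deviceName) {
      await vscode.workspace
        .getConfiguration('openvinoCode')
        .update('device', message.payload.deviceName, vscode.ConfigurationTarget.Global);
    }
  });
}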
2 changes: 2 additions & 0 deletions modules/openvino_code/src/configuration.ts
@@ -1,4 +1,5 @@
 import { ModelName } from '@shared/model';
+import { DeviceName } from '@shared/device';
 import { WorkspaceConfiguration, workspace } from 'vscode';
 import { CONFIG_KEY } from './constants';

@@ -7,6 +8,7 @@ import { CONFIG_KEY } from './constants';
  */
 export type CustomConfiguration = {
   model: ModelName;
+  device: DeviceName;
   serverUrl: string;
   serverRequestTimeout: number;
   streamInlineCompletion: boolean;
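With device on CustomConfiguration, reads of the setting can be typed end to end. A minimal sketch, assuming CONFIG_KEY resolves to the 'openvinoCode' section used in package.json (the getter name is hypothetical):

import { workspace } from 'vscode';
import { DeviceName } from '@shared/device';

const CONFIG_KEY = 'openvinoCode'; // assumption: matches the settings section in package.json

// Hypothetical typed getter; falls back to CPU, matching the
// default declared for openvinoCode.device in package.json.
function getConfiguredDevice(): DeviceName {
  return workspace.getConfiguration(CONFIG_KEY).get<DeviceName>('device', DeviceName.CPU);
}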