Allow model to be served under multiple names (#2894)
Co-authored-by: Alexandre Payot <alexandrep@graphcore.ai>

parent 6dc1fc9cfe
commit 66ded03067
@@ -150,18 +150,18 @@ if __name__ == "__main__":
     logger.info(f"args: {args}")
 
     if args.served_model_name is not None:
-        served_model = args.served_model_name
+        served_model_names = args.served_model_name
     else:
-        served_model = args.model
+        served_model_names = [args.model]
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(
         engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-    openai_serving_chat = OpenAIServingChat(engine, served_model,
+    openai_serving_chat = OpenAIServingChat(engine, served_model_names,
                                             args.response_role,
                                             args.lora_modules,
                                             args.chat_template)
     openai_serving_completion = OpenAIServingCompletion(
-        engine, served_model, args.lora_modules)
+        engine, served_model_names, args.lora_modules)
 
     app.root_path = args.root_path
     uvicorn.run(app,
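A minimal sketch of the normalization above, using made-up values in place of the real argparse namespace: the serving classes now always receive a list of names, whether or not the new flag is passed.

# Illustrative values only, not the server's real `args` object.
args_served_model_name = None                    # --served-model-name omitted
args_model = "meta-llama/Llama-2-7b-chat-hf"     # value of --model

if args_served_model_name is not None:
    served_model_names = args_served_model_name  # already a list via nargs="+"
else:
    served_model_names = [args_model]            # wrap the single HF name

print(served_model_names)  # ['meta-llama/Llama-2-7b-chat-hf']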
@@ -54,11 +54,15 @@ def make_arg_parser():
                         help="If provided, the server will require this key "
                         "to be presented in the header.")
     parser.add_argument("--served-model-name",
+                        nargs="+",
                         type=str,
                         default=None,
-                        help="The model name used in the API. If not "
-                        "specified, the model name will be the same as "
-                        "the huggingface name.")
+                        help="The model name(s) used in the API. If multiple "
+                        "names are provided, the server will respond to any "
+                        "of the provided names. The model name in the model "
+                        "field of a response will be the first name in this "
+                        "list. If not specified, the model name will be the "
+                        "same as the `--model` argument.")
     parser.add_argument(
         "--lora-modules",
         type=str,
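A standalone sketch of what nargs="+" produces for this flag, using a throwaway parser rather than make_arg_parser itself; the alias names are illustrative.

# Throwaway parser mirroring only the --served-model-name flag.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--served-model-name", nargs="+", type=str, default=None)

print(parser.parse_args([]).served_model_name)
# None -> the server then falls back to [args.model]
print(parser.parse_args(
    ["--served-model-name", "llama-2"]).served_model_name)
# ['llama-2']
print(parser.parse_args(
    ["--served-model-name", "llama-2", "llama-2-7b-chat"]).served_model_name)
# ['llama-2', 'llama-2-7b-chat']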
@@ -24,12 +24,12 @@ class OpenAIServingChat(OpenAIServing):
 
     def __init__(self,
                  engine: AsyncLLMEngine,
-                 served_model: str,
+                 served_model_names: List[str],
                  response_role: str,
                  lora_modules: Optional[List[LoRA]] = None,
                  chat_template=None):
         super().__init__(engine=engine,
-                         served_model=served_model,
+                         served_model_names=served_model_names,
                          lora_modules=lora_modules)
         self.response_role = response_role
         self._load_chat_template(chat_template)
@@ -109,7 +109,7 @@ class OpenAIServingChat(OpenAIServing):
             result_generator: AsyncIterator[RequestOutput], request_id: str
     ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
 
-        model_name = request.model
+        model_name = self.served_model_names[0]
         created_time = int(time.time())
         chunk_object_type = "chat.completion.chunk"
         first_iteration = True
@@ -251,7 +251,7 @@ class OpenAIServingChat(OpenAIServing):
             result_generator: AsyncIterator[RequestOutput],
             request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
 
-        model_name = request.model
+        model_name = self.served_model_names[0]
         created_time = int(time.time())
         final_res: RequestOutput = None
 
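A hedged sketch of the naming rule the two chat handlers above now share: whatever alias the client sends in request.model, the response's "model" field reports the first configured name. The aliases below are made up.

# Simplified stand-in for the handlers above; not the real server code.
served_model_names = ["llama-2", "llama-2-7b-chat"]

def response_model_field(requested_model: str) -> str:
    # Mirrors model_name = self.served_model_names[0] in both generators.
    return served_model_names[0]

print(response_model_field("llama-2-7b-chat"))  # llama-2
print(response_model_field("llama-2"))          # llama-2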
@@ -53,10 +53,10 @@ class OpenAIServingCompletion(OpenAIServing):
 
     def __init__(self,
                  engine: AsyncLLMEngine,
-                 served_model: str,
+                 served_model_names: List[str],
                  lora_modules: Optional[List[LoRA]] = None):
         super().__init__(engine=engine,
-                         served_model=served_model,
+                         served_model_names=served_model_names,
                          lora_modules=lora_modules)
 
     async def create_completion(self, request: CompletionRequest,
@@ -79,7 +79,7 @@ class OpenAIServingCompletion(OpenAIServing):
             return self.create_error_response(
                 "suffix is not currently supported")
 
-        model_name = request.model
+        model_name = self.served_model_names[0]
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())
 
@@ -29,10 +29,10 @@ class OpenAIServing:
 
     def __init__(self,
                  engine: AsyncLLMEngine,
-                 served_model: str,
+                 served_model_names: List[str],
                  lora_modules=Optional[List[LoRA]]):
         self.engine = engine
-        self.served_model = served_model
+        self.served_model_names = served_model_names
         if lora_modules is None:
             self.lora_requests = []
         else:
@@ -74,13 +74,14 @@ class OpenAIServing:
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
-            ModelCard(id=self.served_model,
-                      root=self.served_model,
+            ModelCard(id=served_model_name,
+                      root=self.served_model_names[0],
                       permission=[ModelPermission()])
+            for served_model_name in self.served_model_names
         ]
         lora_cards = [
             ModelCard(id=lora.lora_name,
-                      root=self.served_model,
+                      root=self.served_model_names[0],
                       permission=[ModelPermission()])
             for lora in self.lora_requests
         ]
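A rough sketch of what the listing above produces, with plain dicts standing in for the real ModelCard/ModelPermission classes and made-up aliases: every alias becomes its own entry in /v1/models, and "root" always points at the first (primary) name.

# Plain-dict stand-ins for ModelCard/ModelPermission; aliases are illustrative.
served_model_names = ["llama-2", "llama-2-7b-chat"]

model_cards = [
    {"id": name, "root": served_model_names[0], "permission": []}
    for name in served_model_names
]

print([card["id"] for card in model_cards])    # ['llama-2', 'llama-2-7b-chat']
print({card["root"] for card in model_cards})  # {'llama-2'}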
@@ -150,7 +151,7 @@ class OpenAIServing:
         return json_str
 
     async def _check_model(self, request) -> Optional[ErrorResponse]:
-        if request.model == self.served_model:
+        if request.model in self.served_model_names:
             return
         if request.model in [lora.lora_name for lora in self.lora_requests]:
             return
@@ -160,7 +161,7 @@ class OpenAIServing:
             status_code=HTTPStatus.NOT_FOUND)
 
     def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
-        if request.model == self.served_model:
+        if request.model in self.served_model_names:
             return
         for lora in self.lora_requests:
             if request.model == lora.lora_name:
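A minimal sketch of the membership check that _check_model and _maybe_get_lora now perform, leaving out the LoRA branch and the ErrorResponse plumbing; the names are illustrative.

# Simplified alias check; not the real methods above.
served_model_names = ["llama-2", "llama-2-7b-chat"]

def is_known_model(requested: str) -> bool:
    # Mirrors `request.model in self.served_model_names` above.
    return requested in served_model_names

print(is_known_model("llama-2"))          # True  -- any configured alias works
print(is_known_model("llama-2-7b-chat"))  # True
print(is_known_model("gpt-4"))            # False -> 404 NOT_FOUND from the server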