Allow model to be served under multiple names (#2894)

Co-authored-by: Alexandre Payot <alexandrep@graphcore.ai>
Harry Mellor 2024-04-18 08:16:26 +01:00 committed by GitHub
parent 6dc1fc9cfe
commit 66ded03067
5 changed files with 26 additions and 21 deletions

@@ -150,18 +150,18 @@ if __name__ == "__main__":
     logger.info(f"args: {args}")
     if args.served_model_name is not None:
-        served_model = args.served_model_name
+        served_model_names = args.served_model_name
     else:
-        served_model = args.model
+        served_model_names = [args.model]
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(
         engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-    openai_serving_chat = OpenAIServingChat(engine, served_model,
+    openai_serving_chat = OpenAIServingChat(engine, served_model_names,
                                             args.response_role,
                                             args.lora_modules,
                                             args.chat_template)
     openai_serving_completion = OpenAIServingCompletion(
-        engine, served_model, args.lora_modules)
+        engine, served_model_names, args.lora_modules)
     app.root_path = args.root_path
     uvicorn.run(app,

@@ -54,11 +54,15 @@ def make_arg_parser():
                         help="If provided, the server will require this key "
                         "to be presented in the header.")
     parser.add_argument("--served-model-name",
+                        nargs="+",
                         type=str,
                         default=None,
-                        help="The model name used in the API. If not "
-                        "specified, the model name will be the same as "
-                        "the huggingface name.")
+                        help="The model name(s) used in the API. If multiple "
+                        "names are provided, the server will respond to any "
+                        "of the provided names. The model name in the model "
+                        "field of a response will be the first name in this "
+                        "list. If not specified, the model name will be the "
+                        "same as the `--model` argument.")
     parser.add_argument(
         "--lora-modules",
         type=str,
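
The flag's new behaviour comes from argparse's nargs="+": every value following --served-model-name is collected into a single list, and the entrypoint hunk above falls back to [args.model] when the flag is omitted. A minimal, self-contained sketch of that parsing (the default and the alias names below are illustrative, not taken from the vLLM sources):

import argparse

# Stand-in parser mirroring only the two flags touched by this commit.
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="some/base-model")
parser.add_argument("--served-model-name", nargs="+", type=str, default=None)

args = parser.parse_args(
    ["--model", "some/base-model",
     "--served-model-name", "alias-a", "alias-b"])

# nargs="+" gathers every trailing value into one list.
assert args.served_model_name == ["alias-a", "alias-b"]

# Fallback used by the entrypoint when --served-model-name is omitted.
served_model_names = (args.served_model_name
                      if args.served_model_name is not None else [args.model])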

@@ -24,12 +24,12 @@ class OpenAIServingChat(OpenAIServing):
     def __init__(self,
                  engine: AsyncLLMEngine,
-                 served_model: str,
+                 served_model_names: List[str],
                  response_role: str,
                  lora_modules: Optional[List[LoRA]] = None,
                  chat_template=None):
         super().__init__(engine=engine,
-                         served_model=served_model,
+                         served_model_names=served_model_names,
                          lora_modules=lora_modules)
         self.response_role = response_role
         self._load_chat_template(chat_template)
@@ -109,7 +109,7 @@ class OpenAIServingChat(OpenAIServing):
             result_generator: AsyncIterator[RequestOutput], request_id: str
     ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
-        model_name = request.model
+        model_name = self.served_model_names[0]
         created_time = int(time.time())
         chunk_object_type = "chat.completion.chunk"
         first_iteration = True
@@ -251,7 +251,7 @@ class OpenAIServingChat(OpenAIServing):
             result_generator: AsyncIterator[RequestOutput],
             request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
-        model_name = request.model
+        model_name = self.served_model_names[0]
         created_time = int(time.time())
         final_res: RequestOutput = None

@@ -53,10 +53,10 @@ class OpenAIServingCompletion(OpenAIServing):
     def __init__(self,
                  engine: AsyncLLMEngine,
-                 served_model: str,
+                 served_model_names: List[str],
                  lora_modules: Optional[List[LoRA]] = None):
         super().__init__(engine=engine,
-                         served_model=served_model,
+                         served_model_names=served_model_names,
                          lora_modules=lora_modules)
     async def create_completion(self, request: CompletionRequest,
@@ -79,7 +79,7 @@ class OpenAIServingCompletion(OpenAIServing):
             return self.create_error_response(
                 "suffix is not currently supported")
-        model_name = request.model
+        model_name = self.served_model_names[0]
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())

@@ -29,10 +29,10 @@ class OpenAIServing:
     def __init__(self,
                  engine: AsyncLLMEngine,
-                 served_model: str,
+                 served_model_names: List[str],
                  lora_modules=Optional[List[LoRA]]):
         self.engine = engine
-        self.served_model = served_model
+        self.served_model_names = served_model_names
         if lora_modules is None:
             self.lora_requests = []
         else:
@@ -74,13 +74,14 @@ class OpenAIServing:
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
-            ModelCard(id=self.served_model,
-                      root=self.served_model,
+            ModelCard(id=served_model_name,
+                      root=self.served_model_names[0],
                       permission=[ModelPermission()])
+            for served_model_name in self.served_model_names
         ]
         lora_cards = [
             ModelCard(id=lora.lora_name,
-                      root=self.served_model,
+                      root=self.served_model_names[0],
                       permission=[ModelPermission()])
             for lora in self.lora_requests
         ]
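
With the comprehension above, /v1/models now lists one card per configured alias, all rooted at the first name. A small sketch of the resulting cards, using a reduced stand-in for the protocol's ModelCard rather than the real class:

from dataclasses import dataclass
from typing import List

@dataclass
class ModelCard:
    # Reduced stand-in; only the two fields touched by this commit are modelled.
    id: str
    root: str

served_model_names: List[str] = ["alias-a", "alias-b"]

model_cards = [
    ModelCard(id=name, root=served_model_names[0])
    for name in served_model_names
]

assert [card.id for card in model_cards] == ["alias-a", "alias-b"]
assert all(card.root == "alias-a" for card in model_cards)
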
@@ -150,7 +151,7 @@ class OpenAIServing:
         return json_str
     async def _check_model(self, request) -> Optional[ErrorResponse]:
-        if request.model == self.served_model:
+        if request.model in self.served_model_names:
             return
         if request.model in [lora.lora_name for lora in self.lora_requests]:
             return
@@ -160,7 +161,7 @@ class OpenAIServing:
             status_code=HTTPStatus.NOT_FOUND)
     def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
-        if request.model == self.served_model:
+        if request.model in self.served_model_names:
             return
         for lora in self.lora_requests:
             if request.model == lora.lora_name:
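
End to end, a client may now use any configured alias in the model field: _check_model and _maybe_get_lora accept all of them, while responses and /v1/models report the first alias. A hedged usage sketch with the openai Python client, assuming a locally running vLLM OpenAI-compatible server started with two aliases (base URL, API key, and alias names are placeholders, not part of this commit):

from openai import OpenAI

# Assumes the server was launched with something along the lines of
#   --model <hf-model> --served-model-name alias-a alias-b
# base_url and api_key are placeholders for a local deployment.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# /v1/models now returns one entry per configured alias.
print([model.id for model in client.models.list().data])

# Either alias passes _check_model; the response's model field carries the
# first alias ("alias-a"), per served_model_names[0] in the hunks above.
completion = client.chat.completions.create(
    model="alias-b",
    messages=[{"role": "user", "content": "Hello!"}])
print(completion.model)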