[Model] Support Tele-FLM Model (#15023)
Signed-off-by: Naitong Yu <ntyu@baai.ac.cn> Signed-off-by: jiangxin <horizon94@outlook.com> Co-authored-by: Jason Fang <jasonfang3900@gmail.com> Co-authored-by: jiangxin <horizon94@outlook.com>
This commit is contained in:
parent
8a8b30eac1
commit
2f4bd358f1
@ -472,6 +472,11 @@ See [this page](#generative-models) for more information on how to use generativ
|
|||||||
* `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc.
|
* `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc.
|
||||||
* ✅︎
|
* ✅︎
|
||||||
* ✅︎
|
* ✅︎
|
||||||
|
- * `TeleFLMForCausalLM`
|
||||||
|
* TeleFLM
|
||||||
|
* `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc.
|
||||||
|
* ✅︎
|
||||||
|
* ✅︎
|
||||||
- * `XverseForCausalLM`
|
- * `XverseForCausalLM`
|
||||||
* XVERSE
|
* XVERSE
|
||||||
* `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
|
* `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
|
||||||
|
12
examples/template_teleflm.jinja
Normal file
12
examples/template_teleflm.jinja
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
{%- for message in messages %}
|
||||||
|
{%- if message['role'] == 'user' %}
|
||||||
|
{{- '<_user>' + message['content']|trim }}
|
||||||
|
{%- elif message['role'] == 'system' %}
|
||||||
|
{{- '<_system>' + message['content']|trim }}
|
||||||
|
{%- elif message['role'] == 'assistant' %}
|
||||||
|
{{- '<_bot>' + message['content'] }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- if add_generation_prompt %}
|
||||||
|
{{- '<_bot>' }}
|
||||||
|
{%- endif %}
|
@ -192,6 +192,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
|
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
|
||||||
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
|
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
|
||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
|
"TeleFLMForCausalLM": _HfExamplesInfo("CofeAI/FLM-2-52B-Instruct-2407",
|
||||||
|
trust_remote_code=True),
|
||||||
"XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
|
"XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
|
||||||
is_available_online=False,
|
is_available_online=False,
|
||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
|
@ -104,6 +104,7 @@ _TEXT_GENERATION_MODELS = {
|
|||||||
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
|
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
|
||||||
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
|
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
|
||||||
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
|
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
|
||||||
|
"TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
|
||||||
"XverseForCausalLM": ("llama", "LlamaForCausalLM"),
|
"XverseForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||||
"Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
|
"Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
|
||||||
# [Encoder-decoder]
|
# [Encoder-decoder]
|
||||||
|
79
vllm/model_executor/models/teleflm.py
Normal file
79
vllm/model_executor/models/teleflm.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
|
||||||
|
LlamaForCausalLM, LlamaModel)
|
||||||
|
|
||||||
|
|
||||||
|
class TeleFLMModel(LlamaModel):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer,
|
||||||
|
):
|
||||||
|
super().__init__(vllm_config=vllm_config,
|
||||||
|
prefix=prefix,
|
||||||
|
layer_type=layer_type)
|
||||||
|
"""
|
||||||
|
This implementation is based on the µScaling paper presented at
|
||||||
|
the ICLR 2025 Workshop:
|
||||||
|
NanoLM: An Affordable LLM Study Benchmark \
|
||||||
|
via Accurate Loss Prediction across Scales
|
||||||
|
by Yiqun Yao et al.
|
||||||
|
Available at: https://openreview.net/forum?id=IwaPYg1SCA
|
||||||
|
arXiv preprint: https://arxiv.org/abs/2304.06875
|
||||||
|
"""
|
||||||
|
self.use_mup = self.config.use_mup
|
||||||
|
if self.use_mup:
|
||||||
|
self.input_mult = self.config.input_mult
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
embedding = self.embed_tokens(input_ids)
|
||||||
|
if self.use_mup:
|
||||||
|
embedding = embedding * self.input_mult
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
class TeleFLMForCausalLM(LlamaForCausalLM):
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
# mup
|
||||||
|
self.use_mup = self.config.use_mup
|
||||||
|
if self.use_mup:
|
||||||
|
self.mup_scale_factor = self.config.mup_scale_factor
|
||||||
|
self.output_mult = self.config.output_mult / self.mup_scale_factor
|
||||||
|
logit_scale = self.output_mult
|
||||||
|
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||||
|
self.config.vocab_size,
|
||||||
|
logit_scale)
|
Loading…
x
Reference in New Issue
Block a user