[Model] fix model testing for TeleChat2ForCausalLM and V0 llama4 (#16112)
Signed-off-by: Lu Fang <fanglu@fb.com>
This commit is contained in:
parent 29283eaa7e
commit 620fc2d09e
@@ -617,10 +617,15 @@ class FlashAttentionImpl(AttentionImpl):
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        use_irope: bool = False,
     ) -> None:
         if blocksparse_params is not None:
             raise ValueError(
                 "FlashAttention does not support block-sparse attention.")
+        if use_irope:
+            logger.warning(
+                "Using irope in V0 is not supported yet, it will fall back "
+                "to global attention for long context.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
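The V0 backend change above only accepts the new use_irope keyword and warns, falling back to global attention rather than implementing irope. A standalone sketch of that warn-and-fall-back pattern follows (hypothetical class and argument list; the real FlashAttentionImpl constructor takes many more parameters):

# Hypothetical sketch of the warn-and-fall-back flag pattern shown in the
# hunk above; not the actual FlashAttentionImpl constructor.
import logging

logger = logging.getLogger(__name__)


class AttentionImplSketch:

    def __init__(self, num_heads: int, head_size: int, scale: float,
                 use_irope: bool = False) -> None:
        if use_irope:
            # The flag is accepted for API compatibility only; the backend
            # falls back to global attention instead of implementing irope.
            logger.warning(
                "Using irope in V0 is not supported yet, it will fall back "
                "to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)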
@@ -19,7 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Iterable, Set, Tuple
+from typing import Iterable, Set, Tuple, Type

 import torch

@@ -27,6 +27,7 @@ from vllm.config import VllmConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel

+from .llama import LlamaDecoderLayer
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
                     is_pp_missing_parameter)

@@ -120,7 +121,10 @@ class TeleChat2ForCausalLM(LlamaForCausalLM):
         },
     )

-    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
+    def _init_model(self,
+                    vllm_config: VllmConfig,
+                    prefix: str = "",
+                    layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
         return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)

     def load_weights(self, weights: Iterable[Tuple[str,
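The widened _init_model signature above mirrors the layer_type parameter defaulted to LlamaDecoderLayer, presumably so that the override stays compatible with callers (the parent LlamaForCausalLM or the model tests named in the title) that pass a decoder-layer class; TeleChat2 still builds its own TeleChat2Model and simply accepts the extra argument. A standalone sketch of that override pattern, using stand-in names rather than vLLM's real classes:

# Standalone sketch of the signature-compatibility pattern; all names here
# are illustrative stand-ins, not vLLM's actual classes.
from typing import Type


class DecoderLayer:  # stand-in for LlamaDecoderLayer
    pass


class BaseForCausalLM:  # stand-in for LlamaForCausalLM

    def __init__(self, layer_type: Type[DecoderLayer] = DecoderLayer):
        # The base class threads layer_type through to _init_model, so every
        # subclass override must accept the keyword.
        self.model = self._init_model(prefix="model", layer_type=layer_type)

    def _init_model(self,
                    prefix: str = "",
                    layer_type: Type[DecoderLayer] = DecoderLayer):
        return {"prefix": prefix, "layer_type": layer_type}


class TeleChat2Like(BaseForCausalLM):  # stand-in for TeleChat2ForCausalLM

    def _init_model(self,
                    prefix: str = "",
                    layer_type: Type[DecoderLayer] = DecoderLayer):
        # layer_type is accepted for compatibility but ignored; the subclass
        # builds its own dedicated model, as TeleChat2 does in the diff.
        return {"prefix": prefix, "custom": True}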