# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from argparse import ArgumentParser
from enum import Enum


class AttentionMaskFormat:
    # Build 1D mask index (sequence lengths). It requires right-side padding! Recommended for BERT models to get the best performance.
    MaskIndexEnd = 0

    # For experiment only. Do not use it in production.
    MaskIndexEndAndStart = 1

    # Raw attention mask with 0 meaning padding (or no attention) and 1 otherwise.
    AttentionMask = 2

    # No attention mask
    NoMask = 3


class AttentionOpType(Enum):
    Attention = "Attention"
    MultiHeadAttention = "MultiHeadAttention"
    GroupQueryAttention = "GroupQueryAttention"
    PagedAttention = "PagedAttention"

    def __str__(self):
        return self.value

    # Override __hash__ and __eq__ so that members compare and hash by their string value.
    def __hash__(self):
        return hash(self.value)

    def __eq__(self, other):
        return other.value == self.value


class FusionOptions:
    """Options of fusion in graph optimization"""

    def __init__(self, model_type):
        self.enable_gelu = True
        self.enable_layer_norm = True
        self.enable_attention = True
        self.enable_rotary_embeddings = True

        # Use MultiHeadAttention instead of Attention operator. The differences:
        # (1) Attention has merged weights for Q/K/V projection, which might be faster in some cases since the
        #     three MatMuls are merged into one.
        # (2) Attention can only handle self attention; MultiHeadAttention can handle both self and cross attention.
        self.use_multi_head_attention = False
        self.disable_multi_head_attention_bias = False

        self.enable_skip_layer_norm = True
        self.enable_embed_layer_norm = True
        self.enable_bias_skip_layer_norm = True
        self.enable_bias_gelu = True
        self.enable_gelu_approximation = False
        self.enable_qordered_matmul = True

        self.enable_shape_inference = True
        self.enable_gemm_fast_gelu = False
        self.group_norm_channels_last = True

        if model_type == "clip":
            self.enable_embed_layer_norm = False

        # Default to the raw attention mask; for BERT, use the mask index (sequence lengths) so fused attention
        # can be used to speed up. Note that embed layer normalization will convert the 2D mask to 1D when the
        # mask type is MaskIndexEnd.
        self.attention_mask_format = AttentionMaskFormat.AttentionMask
        if model_type == "bert":
            self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd
        elif model_type == "vit":
            self.attention_mask_format = AttentionMaskFormat.NoMask

        self.attention_op_type = None

        # Options for stable diffusion
        if model_type in ["unet", "vae", "clip"]:
            self.enable_nhwc_conv = True
            self.enable_group_norm = True
            self.enable_skip_group_norm = True
            self.enable_bias_splitgelu = True
            self.enable_packed_qkv = True
            self.enable_packed_kv = True
            self.enable_bias_add = True

    def use_raw_attention_mask(self, use_raw_mask=True):
        if use_raw_mask:
            self.attention_mask_format = AttentionMaskFormat.AttentionMask
        else:
            self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd

    def disable_attention_mask(self):
        self.attention_mask_format = AttentionMaskFormat.NoMask

    def set_attention_op_type(self, attn_op_type: AttentionOpType):
        self.attention_op_type = attn_op_type

    @staticmethod
    def parse(args):
        options = FusionOptions(args.model_type)
        if args.disable_gelu:
            options.enable_gelu = False
        if args.disable_layer_norm:
            options.enable_layer_norm = False
        if args.disable_rotary_embeddings:
            options.enable_rotary_embeddings = False
        if args.disable_attention:
            options.enable_attention = False
        if args.use_multi_head_attention:
            options.use_multi_head_attention = True
        if args.disable_skip_layer_norm:
            options.enable_skip_layer_norm = False
        if args.disable_embed_layer_norm:
            options.enable_embed_layer_norm = False
        if args.disable_bias_skip_layer_norm:
            options.enable_bias_skip_layer_norm = False
        if args.disable_bias_gelu:
            options.enable_bias_gelu = False
        if args.enable_gelu_approximation:
            options.enable_gelu_approximation = True
        if args.disable_shape_inference:
            options.enable_shape_inference = False
        if args.enable_gemm_fast_gelu:
            options.enable_gemm_fast_gelu = True
        if args.use_mask_index:
            options.use_raw_attention_mask(False)
        if args.use_raw_attention_mask:
            options.use_raw_attention_mask(True)
        if args.no_attention_mask:
            options.disable_attention_mask()
        if args.model_type in ["unet", "vae", "clip"]:
            if args.use_group_norm_channels_first:
                options.group_norm_channels_last = False
            if args.disable_nhwc_conv:
                options.enable_nhwc_conv = False
            if args.disable_group_norm:
                options.enable_group_norm = False
            if args.disable_skip_group_norm:
                options.enable_skip_group_norm = False
            if args.disable_bias_splitgelu:
                options.enable_bias_splitgelu = False
            if args.disable_packed_qkv:
                options.enable_packed_qkv = False
            if args.disable_packed_kv:
                options.enable_packed_kv = False
            if args.disable_bias_add:
                options.enable_bias_add = False
        return options

    @staticmethod
    def add_arguments(parser: ArgumentParser):
        parser.add_argument(
            "--disable_attention",
            required=False,
            action="store_true",
            help="disable Attention fusion",
        )
        parser.set_defaults(disable_attention=False)

        parser.add_argument(
            "--disable_skip_layer_norm",
            required=False,
            action="store_true",
            help="disable SkipLayerNormalization fusion",
        )
        parser.set_defaults(disable_skip_layer_norm=False)

        parser.add_argument(
            "--disable_embed_layer_norm",
            required=False,
            action="store_true",
            help="disable EmbedLayerNormalization fusion",
        )
        parser.set_defaults(disable_embed_layer_norm=False)

        parser.add_argument(
            "--disable_bias_skip_layer_norm",
            required=False,
            action="store_true",
            help="disable Add Bias and SkipLayerNormalization fusion",
        )
        parser.set_defaults(disable_bias_skip_layer_norm=False)

        parser.add_argument(
            "--disable_bias_gelu",
            required=False,
            action="store_true",
            help="disable Add Bias and Gelu/FastGelu fusion",
        )
        parser.set_defaults(disable_bias_gelu=False)

        parser.add_argument(
            "--disable_layer_norm",
            required=False,
            action="store_true",
            help="disable LayerNormalization fusion",
        )
        parser.set_defaults(disable_layer_norm=False)

        parser.add_argument(
            "--disable_gelu",
            required=False,
            action="store_true",
            help="disable Gelu fusion",
        )
        parser.set_defaults(disable_gelu=False)

        parser.add_argument(
            "--enable_gelu_approximation",
            required=False,
            action="store_true",
            help="enable Gelu/BiasGelu to FastGelu conversion",
        )
        parser.set_defaults(enable_gelu_approximation=False)

        parser.add_argument(
            "--disable_shape_inference",
            required=False,
            action="store_true",
            help="disable symbolic shape inference",
        )
        parser.set_defaults(disable_shape_inference=False)

        parser.add_argument(
            "--enable_gemm_fast_gelu",
            required=False,
            action="store_true",
            help="enable GemmFastGelu fusion",
        )
        parser.set_defaults(enable_gemm_fast_gelu=False)

        parser.add_argument(
            "--use_mask_index",
            required=False,
            action="store_true",
            help="use mask index to activate fused attention to speed up. It requires right-side padding!",
        )
        parser.set_defaults(use_mask_index=False)

        parser.add_argument(
            "--use_raw_attention_mask",
            required=False,
            action="store_true",
            help="use raw attention mask. Use this option if your input is not right-side padded. "
            "This might deactivate fused attention and degrade performance.",
        )
        parser.set_defaults(use_raw_attention_mask=False)

        parser.add_argument(
            "--no_attention_mask",
            required=False,
            action="store_true",
            help="no attention mask. Only works for model_type=bert",
        )
        parser.set_defaults(no_attention_mask=False)

        parser.add_argument(
            "--use_multi_head_attention",
            required=False,
            action="store_true",
            help="Use MultiHeadAttention instead of Attention operator for testing purposes. "
            "Note that MultiHeadAttention might be slower than Attention when qkv are not packed.",
        )
        parser.set_defaults(use_multi_head_attention=False)

        parser.add_argument(
            "--disable_group_norm",
            required=False,
            action="store_true",
            help="not fuse GroupNorm. Only works for model_type=unet or vae",
        )
        parser.set_defaults(disable_group_norm=False)

        parser.add_argument(
            "--disable_skip_group_norm",
            required=False,
            action="store_true",
            help="not fuse Add + GroupNorm to SkipGroupNorm. Only works for model_type=unet or vae",
        )
        parser.set_defaults(disable_skip_group_norm=False)

        parser.add_argument(
            "--disable_packed_kv",
            required=False,
            action="store_true",
            help="not use packed kv for cross attention in MultiHeadAttention. Only works for model_type=unet",
        )
        parser.set_defaults(disable_packed_kv=False)

        parser.add_argument(
            "--disable_packed_qkv",
            required=False,
            action="store_true",
            help="not use packed qkv for self attention in MultiHeadAttention. Only works for model_type=unet",
        )
        parser.set_defaults(disable_packed_qkv=False)

        parser.add_argument(
            "--disable_bias_add",
            required=False,
            action="store_true",
            help="not fuse BiasAdd. Only works for model_type=unet",
        )
        parser.set_defaults(disable_bias_add=False)

        parser.add_argument(
            "--disable_bias_splitgelu",
            required=False,
            action="store_true",
            help="not fuse BiasSplitGelu. Only works for model_type=unet",
        )
        parser.set_defaults(disable_bias_splitgelu=False)

        parser.add_argument(
            "--disable_nhwc_conv",
            required=False,
            action="store_true",
            help="Do not use NhwcConv. Only works for model_type=unet or vae",
        )
        parser.set_defaults(disable_nhwc_conv=False)

        parser.add_argument(
            "--use_group_norm_channels_first",
            required=False,
            action="store_true",
            help="Use channels_first (NCHW) instead of channels_last (NHWC) for GroupNorm. "
            "Only works for model_type=unet or vae",
        )
        parser.set_defaults(use_group_norm_channels_first=False)

        parser.add_argument(
            "--disable_rotary_embeddings",
            required=False,
            action="store_true",
            help="Do not fuse rotary embeddings into RotaryEmbedding op",
        )
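

# Illustrative usage sketch (not part of the original module): shows how add_arguments and
# parse fit together. The "--model_type" argument is added here only for this example; the
# real optimizer CLI defines model_type elsewhere before calling FusionOptions.parse.
if __name__ == "__main__":
    example_parser = ArgumentParser(description="FusionOptions example")
    example_parser.add_argument("--model_type", required=False, type=str, default="bert")
    FusionOptions.add_arguments(example_parser)

    # Parse a hypothetical command line for a BERT model with a raw attention mask.
    example_args = example_parser.parse_args(["--model_type", "bert", "--use_raw_attention_mask"])
    example_options = FusionOptions.parse(example_args)

    # For BERT the default mask format is MaskIndexEnd; --use_raw_attention_mask switches it
    # back to the raw 2D attention mask.
    print(example_options.attention_mask_format == AttentionMaskFormat.AttentionMask)  # True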