Skip to content

Latest commit

 

History

History
212 lines (179 loc) · 6.12 KB

custom_op.md

File metadata and controls

212 lines (179 loc) · 6.12 KB

自定义模型算子文档

AddTo

def AddTo(self, input0, input1, alpha = 1.0):
    """
    将两个输入节点相加,并乘以一个可选的缩放因子 alpha。

    参数:
    input0 (GraphNode): 第一个输入节点。
    input1 (GraphNode): 第二个输入节点。
    alpha (float, optional): 缩放因子,默认为 1.0。

    返回:
    无返回值,结果存储在内部图结构中。
    """
    self.graph.append({"type": "AddTo", 
                       "nodes": {"input0": input0, "input1": input1, "alpha": FloatGraphNode(alpha)}})

DataTypeAs

def DataTypeAs(self, input, input1):
    """
    将输入节点的数据类型转换为另一个输入节点的数据类型。

    参数:
    input (GraphNode): 需要转换数据类型的输入节点。
    input1 (GraphNode): 目标数据类型的输入节点。

    返回:
    无返回值,结果存储在内部图结构中。
    """
    self.graph.append({"type": "DataTypeAs", 
                       "nodes": {"input": input, "input1": input1}})

Embedding

def Embedding(self, input, weight, output):
    """
    执行嵌入操作,将输入索引映射到嵌入权重。

    参数:
    input (GraphNode): 输入索引节点。
    weight (GraphNode): 嵌入权重节点。
    output (GraphNode): 输出节点。

    返回:
    无返回值,结果存储在内部图结构中。
    """
    self.graph.append({"type": "Embedding", 
                       "nodes": {"input": input, "weight": weight, "output": output}})

ExpandHead

def ExpandHead(self, input, headDim):
    """
    把input最后一维展开成[-1, headDim]。

    参数:
    input (GraphNode): 输入节点。
    headDim (int): 头部维度大小。

    返回:
    无返回值,结果存储在内部图结构中。
    """
    self.graph.append({"type": "ExpandHeads", 
                       "nodes": {"input": input, "headDim": IntGraphNode(headDim)}})

FusedAttention

def FusedAttention(self, q, k, v, curk, curv, original, mask, output, seqLens, 
                   scale, maskType=0, unitLen=128):
    """
    执行Attention操作。

    参数:
    q (GraphNode): 查询节点。
    k (GraphNode): key cache 
    v (GraphNode): value cache 
    curk (GraphNode): 当前key
    curv (GraphNode): 当前value
    original (GraphNode): 原始节点,用于恢复计算后的shape
    mask (GraphNode): 掩码
    output (GraphNode): 输出
    seqLens (GraphNode): 序列长度
    scale (float): 缩放因子
    maskType (int, optional): 掩码类型,默认为 0。
    unitLen (int, optional): 单元长度,默认为 128。

    返回:
    无返回值,结果存储在内部图结构中。
    """
    self.graph.append({"type": "FusedAttention", 
                       "nodes": {"q": q, "k": k, "v": v, "curk": curk, "curv": curv, 
                                "original": original, "mask": mask, "output": output, "seqLens": seqLens, 
                                 "scale": FloatGraphNode(scale), 
                                 "maskType": IntGraphNode(maskType), "unitLen": IntGraphNode(unitLen)}})

Linear

def Linear(self, input, weight, bias, output):
    """
    执行线性变换操作。

    参数:
    input (GraphNode): 输入节点。
    weight (GraphNode): 权重节点。
    bias (GraphNode): 偏置节点。
    output (GraphNode): 输出节点。

    返回:
    无返回值,结果存储在内部图结构中。
    """
    self.graph.append({"type": "Linear", 
                       "nodes": {"input": input, "weight": weight, "bias": bias, "output": output}})

LlamaRotatePosition2D

def LlamaRotatePosition2D(self, input, positionIds, sin, cos, rotaryDim):
    """
    执行 Llama 模型的二维位置旋转操作。

    参数:
    input (GraphNode): 输入节点。
    positionIds (GraphNode): 位置 ID 节点。
    sin (GraphNode): 正弦节点。
    cos (GraphNode): 余弦节点。
    rotaryDim (int): 旋转维度大小。

    返回:
    无返回值,结果存储在内部图结构中。
    """
    self.graph.append({"type": "LlamaRotatePosition2D", 
                       "nodes": {"input": input, "positionIds": positionIds, "sin": sin, "cos": cos, "rotaryDim": IntGraphNode(rotaryDim)}})

MulTo

def MulTo(self, input0, input1):
    """
    将两个输入节点相乘。

    参数:
    input0 (GraphNode): 第一个输入节点。
    input1 (GraphNode): 第二个输入节点。

    返回:
    无返回值,结果存储在内部图结构中。
    """
    self.graph.append({"type": "MulTo", 
                       "nodes": {"input0": input0, "input1": input1}})

RMSNorm

def RMSNorm(self, input, weight, eps, output):
    """
    执行 RMS 归一化操作。

    参数:
    input (GraphNode): 输入节点。
    weight (GraphNode): 权重节点。
    eps (float): 小常数,用于防止除零错误。
    output (GraphNode): 输出节点。

    返回:
    无返回值,结果存储在内部图结构中。
    """
    self.graph.append({"type": "RMSNorm", 
                       "nodes": {"input": input, "weight": weight, "eps": FloatGraphNode(eps), "output": output}})

Silu

def Silu(self, input, output):
    """
    执行 SiLU(Sigmoid Linear Unit)激活函数操作。

    参数:
    input (GraphNode): 输入节点。
    output (GraphNode): 输出节点。

    返回:
    无返回值,结果存储在内部图结构中。
    """
    self.graph.append({"type": "Silu", 
                       "nodes": {"input": input, "output": output}})

SplitLastTokenStates

def SplitLastTokenStates(self, input, seqLens, output):
    """
    分割batch输入中每个batch的最后一个 token 状态。

    参数:
    input (GraphNode): 输入节点。
    seqLens (GraphNode): 序列长度节点。
    output (GraphNode): 输出节点。

    返回:
    无返回值,结果存储在内部图结构中。
    """
    self.graph.append({"type": "SplitLastTokenStates", 
                       "nodes": {"input": input, "output": output, "seqLens": seqLens}})