add gelu for cuda
黄宇扬 committed Sep 23, 2024
1 parent 016eb32 commit 02de743
Showing 4 changed files with 36 additions and 2 deletions.
4 changes: 4 additions & 0 deletions include/devices/cuda/cudadevice.h
@@ -82,6 +82,10 @@ namespace fastllm {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaGeluOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaGeluNewOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
3 changes: 2 additions & 1 deletion include/devices/cuda/fastllm-cuda.cuh
@@ -32,7 +32,8 @@ bool FastllmBF16ToFloat(void *a, void *b, int len);
bool FastllmCudaEmbedding(const fastllm::Data &input, const fastllm::Data &weight, fastllm::Data &output);
bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const fastllm::Data &v,
const fastllm::Data &mask, const fastllm::Data &output, int group, float scale);
bool FastllmCudaGeluNew(const fastllm::Data &input, fastllm::Data &output);
bool FastllmCudaGelu(const fastllm::Data &input, fastllm::Data &output);
bool FastllmCudaSilu(const fastllm::Data &input, fastllm::Data &output);
bool FastllmCudaSwiglu(const fastllm::Data &input, fastllm::Data &output);
bool FastllmCudaMul(const fastllm::Data &input, float v, fastllm::Data &output);
10 changes: 10 additions & 0 deletions src/devices/cuda/cudadevice.cpp
@@ -25,6 +25,7 @@ namespace fastllm {
this->ops["MatMul"] = (BaseOperator*)(new CudaMatMulOp());
this->ops["MatMulTransB"] = (BaseOperator*)(new CudaMatMulTransBOp());
this->ops["SoftMax"] = (BaseOperator*)(new CudaSoftMaxOp());
this->ops["Gelu"] = (BaseOperator*)(new CudaGeluOp());
this->ops["GeluNew"] = (BaseOperator*)(new CudaGeluNewOp());
this->ops["Silu"] = (BaseOperator*)(new CudaSiluOp());
this->ops["Swiglu"] = (BaseOperator*)(new CudaSwigluOp());
@@ -571,6 +572,15 @@ namespace fastllm {
FastllmCudaGeluNew(input, output);
}

void CudaGeluOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &input = *(datas.find("input")->second);
Data &output = *(datas.find("output")->second);
output.Allocate();
AssertInFastLLM(input.dataType == DataType::FLOAT32, "Gelu error: Data's type should be float32.\n");
FastllmCudaGelu(input, output);
}

void CudaSwigluOp::Reshape(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &input = *(datas.find("input")->second);
21 changes: 20 additions & 1 deletion src/devices/cuda/fastllm-cuda.cu
@@ -433,6 +433,14 @@ __global__ void FastllmCudaBiasKernel(half *a, half *bias, int k) {
}

__global__ void FastllmGeluKernel(float* a, float *b, int len) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
float x = a[idx];
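// exact (erf-based) GELU: b = 0.5 * x * (1 + erf(x / sqrt(2)))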
b[idx] = x * 0.5f * (1.0f + erf(x / sqrt(2.0)));
}
}

__global__ void FastllmGeluNewKernel(float* a, float *b, int len) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
float x = a[idx];
@@ -3127,7 +3135,7 @@ void FastllmCudaMemcpy2DDeviceToDeviceBatch(void ** dsts, size_t * dpitchs, voi
DeviceSync();
}

bool FastllmCudaGeluNew(const fastllm::Data &input, fastllm::Data &output) {
bool FastllmCudaGelu(const fastllm::Data &input, fastllm::Data &output) {
int len = input.Count(0);
float *cudaInput = (float *) FastllmCudaPrepareInput(input);
float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);
@@ -3138,6 +3146,17 @@ bool FastllmCudaGeluNew(const fastllm::Data &input, fastllm::Data &output) {
return true;
}

bool FastllmCudaGeluNew(const fastllm::Data &input, fastllm::Data &output) {
int len = input.Count(0);
float *cudaInput = (float *) FastllmCudaPrepareInput(input);
float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);
int threadPerBlock = std::min(256, len);
FastllmGeluNewKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, len);
FastllmCudaFinishInput(input, cudaInput);
FastllmCudaFinishOutput(output, cudaOutput);
return true;
}

bool FastllmCudaSilu(const fastllm::Data &input, fastllm::Data &output) {
int len = input.Count(0);
float *cudaInput = (float *) FastllmCudaPrepareInput(input);
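For readers who want to sanity-check the new kernel outside of fastllm, here is a minimal, self-contained sketch. It is not part of this commit; the file name, test sizes, and host-side reference are illustrative assumptions. It reuses the erf-based kernel body and the launch arithmetic from FastllmCudaGelu above (min(256, len) threads per block, (len - 1) / threadPerBlock + 1 blocks) and compares the device result against a CPU reference.

// gelu_check.cu -- illustrative harness, not part of the fastllm commit
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

// same body as FastllmGeluKernel in the diff above
__global__ void GeluKernel(float *a, float *b, int len) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < len) {
        float x = a[idx];
        b[idx] = x * 0.5f * (1.0f + erf(x / sqrt(2.0)));
    }
}

int main() {
    const int len = 1000;
    float hostIn[len], hostOut[len];
    for (int i = 0; i < len; i++) hostIn[i] = (i - len / 2) * 0.01f;

    float *devIn, *devOut;
    cudaMalloc(&devIn, len * sizeof(float));
    cudaMalloc(&devOut, len * sizeof(float));
    cudaMemcpy(devIn, hostIn, len * sizeof(float), cudaMemcpyHostToDevice);

    // launch arithmetic copied from FastllmCudaGelu:
    // min(256, len) threads per block, (len - 1) / threadPerBlock + 1 blocks
    int threadPerBlock = std::min(256, len);
    GeluKernel<<<(len - 1) / threadPerBlock + 1, threadPerBlock>>>(devIn, devOut, len);
    cudaMemcpy(hostOut, devOut, len * sizeof(float), cudaMemcpyDeviceToHost);

    // compare against a double-precision host reference
    double maxErr = 0.0;
    for (int i = 0; i < len; i++) {
        double x = hostIn[i];
        double ref = 0.5 * x * (1.0 + erf(x / sqrt(2.0)));
        maxErr = std::max(maxErr, std::fabs(ref - (double) hostOut[i]));
    }
    printf("max abs error vs host reference: %g\n", maxErr);

    cudaFree(devIn);
    cudaFree(devOut);
    return 0;
}

Built with something like nvcc gelu_check.cu -o gelu_check, the reported error should be on the order of float rounding, since the kernel and the reference use the same erf-based formula.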
