Skip to content

Commit

Permalink
Fix regex parser issue for parsing functions having SQL body with language sql (PG15 feature)
Browse files Browse the repository at this point in the history
  • Loading branch information
priyanshi-yb committed Jan 16, 2025
1 parent 8b91860 commit 860e99a
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 3 deletions.
3 changes: 3 additions & 0 deletions yb-voyager/cmd/analyzeSchema.go
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,9 @@ sqlParsingLoop:
} else if matches := dollarQuoteRegex.FindStringSubmatch(currLine); matches != nil {
dollarQuoteFlag = 1 //denotes start of the code/body part
codeBlockDelimiter = matches[0]
} else if strings.Contains(currLine, "BEGIN ATOMIC") {
dollarQuoteFlag = 1 //denotes start of the sql body part https://www.postgresql.org/docs/15/sql-createfunction.html#:~:text=a%20new%20session.-,sql_body,-The%20body%20of
codeBlockDelimiter = "END"
}
case CODE_BLOCK_STARTED:
if strings.Contains(currLine, codeBlockDelimiter) {
Expand Down
93 changes: 90 additions & 3 deletions yb-voyager/cmd/analyzeSchema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,12 @@ CREATE TABLE another_table (
t.Errorf("Error creating file for the objType %s: %v", objType, err)
}


defer os.Remove(sqlFile.Name())

sqlInfoArr := parseSqlFileForObjectType(sqlFile.Name(), objType)
// Validate the number of SQL statements found
if len(sqlInfoArr) != len(expectedSqlInfoArr) {
t.Errorf("Expected %d SQL statements for %s, got %d",len(expectedSqlInfoArr), objType, len(sqlInfoArr))
t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr))
}

for i, expectedSqlInfo := range expectedSqlInfoArr {
Expand Down Expand Up @@ -140,9 +139,97 @@ $$ LANGUAGE plpgsql;`,

// Validate the number of SQL statements found
if len(sqlInfoArr) != len(expectedSqlInfoArr) {
t.Errorf("Expected %d SQL statements for %s, got %d",len(expectedSqlInfoArr), objType, len(sqlInfoArr))
t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr))
}

for i, expectedSqlInfo := range expectedSqlInfoArr {
assert.Equal(t, expectedSqlInfo.objName, sqlInfoArr[i].objName)
assert.Equal(t, expectedSqlInfo.stmt, sqlInfoArr[i].stmt)
assert.Equal(t, expectedSqlInfo.formattedStmt, sqlInfoArr[i].formattedStmt)
}

}

func TestFunctionSQLFile(t *testing.T) {
functionFileContent := `CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
BEGIN ATOMIC
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
END;
CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$
DECLARE
temp_salary employees.salary%TYPE;
BEGIN
CREATE TEMP TABLE temp_high_earners AS
SELECT * FROM employees WHERE salary > threshold;
FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP
RAISE NOTICE 'High earner salary: %', temp_salary;
END LOOP;
END;
$$ LANGUAGE plpgsql;
CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END;
CREATE FUNCTION public.asterisks1(n integer) RETURNS text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
RETURN repeat('*'::text, n);`

expectedSqlInfoArr := []sqlInfo{
sqlInfo{
objName: "public.asterisks",
stmt: "CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE BEGIN ATOMIC SELECT repeat('*'::text, g.g) AS repeat FROM generate_series(1, asterisks.n) g(g); END; ",
formattedStmt: `CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
BEGIN ATOMIC
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
END;`,
},
sqlInfo{
objName: "copy_high_earners",
stmt: "CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$ DECLARE temp_salary employees.salary%TYPE; BEGIN CREATE TEMP TABLE temp_high_earners AS SELECT * FROM employees WHERE salary > threshold; FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP RAISE NOTICE 'High earner salary: %', temp_salary; END LOOP; END; $$ LANGUAGE plpgsql; ",
formattedStmt: `CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$
DECLARE
temp_salary employees.salary%TYPE;
BEGIN
CREATE TEMP TABLE temp_high_earners AS
SELECT * FROM employees WHERE salary > threshold;
FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP
RAISE NOTICE 'High earner salary: %', temp_salary;
END LOOP;
END;
$$ LANGUAGE plpgsql;`,
},
sqlInfo{
objName: "add",
stmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END; ",
formattedStmt: `CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END;`,
},
sqlInfo{
objName: "public.asterisks1",
stmt: "CREATE FUNCTION public.asterisks1(n integer) RETURNS text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE RETURN repeat('*'::text, n); ",
formattedStmt: `CREATE FUNCTION public.asterisks1(n integer) RETURNS text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
RETURN repeat('*'::text, n);`,
},
}
objType := "FUNCTION"
sqlFile, err := setupFile(objType, functionFileContent)
if err != nil {
t.Errorf("Error creating file for the objType %s: %v", objType, err)
}

defer os.Remove(sqlFile.Name())

sqlInfoArr := parseSqlFileForObjectType(sqlFile.Name(), objType)

// Validate the number of SQL statements found
if len(sqlInfoArr) != len(expectedSqlInfoArr) {
t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr))
}

for i, expectedSqlInfo := range expectedSqlInfoArr {
assert.Equal(t, expectedSqlInfo.objName, sqlInfoArr[i].objName)
assert.Equal(t, expectedSqlInfo.stmt, sqlInfoArr[i].stmt)
Expand Down

0 comments on commit 860e99a

Please sign in to comment.