Skip to content

Commit

Permalink
Add the complex method.
Browse files Browse the repository at this point in the history
  • Loading branch information
yjcyxky committed Jun 30, 2024
1 parent 1c11927 commit e42e2a7
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 10 deletions.
4 changes: 2 additions & 2 deletions docker/Dockerfile.tgmc
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ RUN curl https://www.postgresql.org/media/keys/ACCC4CF8.asc | gpg --dearmor | te
ENV TZ=UTC
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update -y && apt install git wget -y
RUN wget https://github.com/yjcyxky/postgresml/releases/download/v2.8.3-a1fcad1/postgresql-pgml-15_2.8.3-a1fcad1-ubuntu22.04-amd64.deb && (dpkg -i ./postgresql-pgml-*.deb || apt-get install -f -y) && git clone --branch v0.5.0 https://github.com/pgvector/pgvector && \
RUN wget https://github.com/yjcyxky/postgresml/releases/download/v2.8.3-1c11927/postgresql-pgml-15_2.8.3-1c11927-ubuntu22.04-amd64.deb && (dpkg -i ./postgresql-pgml-*.deb || apt-get install -f -y) && git clone --branch v0.5.0 https://github.com/pgvector/pgvector && \
cd pgvector && \
echo "trusted = true" >> vector.control && \
make && \
make install
RUN wget https://github.com/yjcyxky/postgresml/releases/download/v2.8.3-a1fcad1/postgresml-15-2.8.3-a1fcad1-ubuntu22.04-all.deb && dpkg -i ./postgresml-*-all.deb || apt-get install -f -y
# Download and install the arch-independent postgresml package. The dpkg/apt fallback is grouped
# in parentheses so that a failed download fails the build instead of being masked by
# `apt-get install -f` succeeding (`A && B || C` parses as `(A && B) || C`).
RUN wget https://github.com/yjcyxky/postgresml/releases/download/v2.8.3-1c11927/postgresml-15-2.8.3-1c11927-ubuntu22.04-all.deb && (dpkg -i ./postgresml-*-all.deb || apt-get install -f -y)
# RUN wget https://github.com/yjcyxky/postgresml/releases/download/2.8.3/postgresml-dashboard-2.8.3-ubuntu22.04-amd64.deb && dpkg -i ./postgresml-dashboard-*.deb || apt-get install -f -y

COPY entrypoint.sh /app/entrypoint.sh
Expand Down
44 changes: 36 additions & 8 deletions pgml-extension/sql/pgml--2.8.2--2.8.3.sql
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ AS 'MODULE_PATHNAME', 'median_wrapper';

-- src/kge.rs
-- pgml::kge::transe_l2_parallel
DROP FUNCTION IF EXISTS pgml."transe_l2_parallel"(real[], real[], real[], real, bool);
DROP FUNCTION IF EXISTS pgml."transe_l2_parallel"(real[], real[], real[], real, bool, bool);
CREATE FUNCTION pgml."transe_l2_parallel"(
"head_vector" REAL[] /* f32[] */,
"relation_vector" REAL[] /* f32[] */,
Expand All @@ -45,7 +45,7 @@ AS 'MODULE_PATHNAME', 'transe_l2_parallel_wrapper';

-- src/kge.rs
-- pgml::kge::transe_l2_ndarray
DROP FUNCTION IF EXISTS pgml."transe_l2_ndarray"(real[], real[], real[], real, bool);
DROP FUNCTION IF EXISTS pgml."transe_l2_ndarray"(real[], real[], real[], real, bool, bool);
CREATE FUNCTION pgml."transe_l2_ndarray"(
"head_vector" REAL[] /* f32[] */,
"relation_vector" REAL[] /* f32[] */,
Expand All @@ -60,7 +60,7 @@ AS 'MODULE_PATHNAME', 'transe_l2_ndarray_wrapper';

-- src/kge.rs
-- pgml::kge::transe_l2
DROP FUNCTION IF EXISTS pgml."transe_l2"(real[], real[], real[], real, bool);
DROP FUNCTION IF EXISTS pgml."transe_l2"(real[], real[], real[], real, bool, bool);
CREATE FUNCTION pgml."transe_l2"(
"head_vector" REAL[] /* f32[] */,
"relation_vector" REAL[] /* f32[] */,
Expand All @@ -75,7 +75,7 @@ AS 'MODULE_PATHNAME', 'transe_l2_wrapper';

-- src/kge.rs
-- pgml::kge::transe_l1
DROP FUNCTION IF EXISTS pgml."transe_l1"(real[], real[], real[], real, bool);
DROP FUNCTION IF EXISTS pgml."transe_l1"(real[], real[], real[], real, bool, bool);
CREATE FUNCTION pgml."transe_l1"(
"head_vector" REAL[] /* f32[] */,
"relation_vector" REAL[] /* f32[] */,
Expand All @@ -90,7 +90,7 @@ AS 'MODULE_PATHNAME', 'transe_l1_wrapper';

-- src/kge.rs
-- pgml::kge::transe_l1_ndarray
DROP FUNCTION IF EXISTS pgml."transe_l1_ndarray"(real[], real[], real[], real, bool);
DROP FUNCTION IF EXISTS pgml."transe_l1_ndarray"(real[], real[], real[], real, bool, bool);
CREATE FUNCTION pgml."transe_l1_ndarray"(
"head_vector" REAL[] /* f32[] */,
"relation_vector" REAL[] /* f32[] */,
Expand All @@ -105,7 +105,7 @@ AS 'MODULE_PATHNAME', 'transe_l1_ndarray_wrapper';

-- src/kge.rs
-- pgml::kge::transe_l1_parallel
DROP FUNCTION IF EXISTS pgml."transe_l1_parallel"(real[], real[], real[], real, bool);
DROP FUNCTION IF EXISTS pgml."transe_l1_parallel"(real[], real[], real[], real, bool, bool);
CREATE FUNCTION pgml."transe_l1_parallel"(
"head_vector" REAL[] /* f32[] */,
"relation_vector" REAL[] /* f32[] */,
Expand All @@ -120,7 +120,7 @@ AS 'MODULE_PATHNAME', 'transe_l1_parallel_wrapper';

-- src/kge.rs
-- pgml::kge::distmult
DROP FUNCTION IF EXISTS pgml."distmult"(real[], real[], real[], bool);
DROP FUNCTION IF EXISTS pgml."distmult"(real[], real[], real[], bool, bool);
CREATE FUNCTION pgml."distmult"(
"head_vector" REAL[] /* f32[] */,
"relation_vector" REAL[] /* f32[] */,
Expand All @@ -134,7 +134,7 @@ AS 'MODULE_PATHNAME', 'distmult_wrapper';

-- src/kge.rs
-- pgml::kge::distmult_ndarray
DROP FUNCTION IF EXISTS pgml."distmult_ndarray"(real[], real[], real[], bool);
DROP FUNCTION IF EXISTS pgml."distmult_ndarray"(real[], real[], real[], bool, bool);
CREATE FUNCTION pgml."distmult_ndarray"(
"head_vector" REAL[] /* f32[] */,
"relation_vector" REAL[] /* f32[] */,
Expand All @@ -145,3 +145,31 @@ CREATE FUNCTION pgml."distmult_ndarray"(
IMMUTABLE STRICT PARALLEL SAFE
LANGUAGE c /* Rust */
AS 'MODULE_PATHNAME', 'distmult_ndarray_wrapper';

-- src/kge.rs
-- pgml::kge::complex
-- SQL binding for the Rust function `complex`, exposed through the C symbol `complex_wrapper`.
-- Takes three REAL[] vectors (head, relation, tail) plus two boolean flags and returns a single
-- REAL score. Declared STRICT, so any NULL argument yields NULL without calling into Rust.
-- The function is dropped first (same argument signature) so the upgrade script can be re-run.
DROP FUNCTION IF EXISTS pgml."complex"(real[], real[], real[], bool, bool);
CREATE FUNCTION pgml."complex"(
"head_vector" REAL[] /* f32[] */,
"relation_vector" REAL[] /* f32[] */,
"tail_vector" REAL[] /* f32[] */,
"exp_enabled" BOOLEAN,
"reverse" BOOLEAN
) RETURNS REAL /* f32 */
IMMUTABLE STRICT PARALLEL SAFE
LANGUAGE c /* Rust */
AS 'MODULE_PATHNAME', 'complex_wrapper';

-- src/kge.rs
-- pgml::kge::complex_ndarray
-- SQL binding for the Rust function `complex_ndarray`, exposed through the C symbol
-- `complex_ndarray_wrapper`. The Rust side returns `Vec<Option<f32>>` (one score per candidate
-- embedding packed in "tail_vector"), so the SQL return type must be REAL[]; declaring a scalar
-- REAL would make PostgreSQL misinterpret the returned array datum as a float4.
DROP FUNCTION IF EXISTS pgml."complex_ndarray"(real[], real[], real[], bool, bool);
CREATE FUNCTION pgml."complex_ndarray"(
"head_vector" REAL[] /* f32[] */,
"relation_vector" REAL[] /* f32[] */,
"tail_vector" REAL[] /* f32[] */,
"exp_enabled" BOOLEAN,
"reverse" BOOLEAN
) RETURNS REAL[] /* Vec<Option<f32>> */
IMMUTABLE STRICT PARALLEL SAFE
LANGUAGE c /* Rust */
AS 'MODULE_PATHNAME', 'complex_ndarray_wrapper';
101 changes: 101 additions & 0 deletions pgml-extension/src/kge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,107 @@ fn distmult_ndarray(
}
}

#[pg_extern(immutable, parallel_safe, strict, name = "complex")]
fn complex(
head_array: Array<f32>,
relation_array: Array<f32>,
tail_array: Array<f32>,
exp_enabled: bool,
reverse: bool,
) -> f32 {
let len = head_array.len();
if len % 2 != 0 || len != relation_array.len() || len != tail_array.len() {
error!("The length of the head, relation, and tail arrays must be the same and even.");
}

let half_len = len / 2;

let head_real = Array1::from_vec(head_array.iter_deny_null().take(half_len).cloned().collect());
let head_img = Array1::from_vec(head_array.iter_deny_null().skip(half_len).cloned().collect());

let tail_real = Array1::from_vec(tail_array.iter_deny_null().take(half_len).cloned().collect());
let tail_img = Array1::from_vec(tail_array.iter_deny_null().skip(half_len).cloned().collect());

let rel_real = Array1::from_vec(relation_array.iter_deny_null().take(half_len).cloned().collect());
let rel_img = Array1::from_vec(relation_array.iter_deny_null().skip(half_len).cloned().collect());

let score = if reverse {
(&tail_real * &rel_real - &tail_img * &rel_img) * &head_real
+ (&tail_real * &rel_img + &tail_img * &rel_real) * &head_img
} else {
(&head_real * &rel_real - &head_img * &rel_img) * &tail_real
+ (&head_real * &rel_img + &head_img * &rel_real) * &tail_img
};

let score_sum: f32 = score.sum();

if exp_enabled {
exp(logsigmoid(score_sum))
} else {
logsigmoid(score_sum)
}
}

#[pg_extern(immutable, parallel_safe, strict, name = "complex_ndarray")]
fn complex_ndarray(
head: Array<f32>,
rel: Array<f32>,
tails: Array<f32>,
exp_enabled: bool,
reverse: bool,
) -> Vec<Option<f32>> {
let len = head.len();
if len % 2 != 0 || len != rel.len() {
error!("The length of the head and relation arrays must be the same and even.");
}

let half_len = len / 2;
let tails_len = tails.len();
if tails_len % len != 0 {
error!("The length of the tail array must be a multiple of the head array.");
}

let head_real = Array1::from_vec(head.iter_deny_null().take(half_len).cloned().collect());
let head_img = Array1::from_vec(head.iter_deny_null().skip(half_len).cloned().collect());

let rel_real = Array1::from_vec(rel.iter_deny_null().take(half_len).cloned().collect());
let rel_img = Array1::from_vec(rel.iter_deny_null().skip(half_len).cloned().collect());

let tails_real = tails.iter_deny_null().cloned().collect::<Vec<f32>>();
let tails_img = tails.iter_deny_null().skip(half_len).cloned().collect::<Vec<f32>>();

let tails_real = Array2::from_shape_vec((tails_len / len, half_len), tails_real).unwrap();
let tails_img = Array2::from_shape_vec((tails_len / len, half_len), tails_img).unwrap();

let score = if reverse {
(head_real.broadcast(tails_real.dim()).unwrap() * rel_real.broadcast(tails_real.dim()).unwrap()
- head_img.broadcast(tails_img.dim()).unwrap() * rel_img.broadcast(tails_img.dim()).unwrap())
* tails_real
+ (head_real.broadcast(tails_real.dim()).unwrap() * rel_img.broadcast(tails_img.dim()).unwrap()
+ head_img.broadcast(tails_img.dim()).unwrap() * rel_real.broadcast(tails_real.dim()).unwrap())
* tails_img
} else {
(head_real.broadcast(tails_real.dim()).unwrap() * rel_real.broadcast(tails_real.dim()).unwrap()
- head_img.broadcast(tails_img.dim()).unwrap() * rel_img.broadcast(tails_img.dim()).unwrap())
* tails_real
+ (head_real.broadcast(tails_real.dim()).unwrap() * rel_img.broadcast(tails_img.dim()).unwrap()
+ head_img.broadcast(tails_img.dim()).unwrap() * rel_real.broadcast(tails_real.dim()).unwrap())
* tails_img
};

let scores_sum = score.sum_axis(Axis(1));

let adjusted_scores = scores_sum.mapv(|x| gamma - x);

let result = if exp_enabled {
logsigmoid_vectorized(&adjusted_scores).mapv(|x| E.powf(x))
} else {
logsigmoid_vectorized(&adjusted_scores)
};

result.iter().map(|&x| Some(x)).collect()
}

#[cfg(any(test, feature = "pg_test"))]
#[pg_schema]
mod tests {
Expand Down

0 comments on commit e42e2a7

Please sign in to comment.