From 93042cedb38f27272046e3f1cc73bbf0baffebb3 Mon Sep 17 00:00:00 2001 From: Vivek Agrawal Date: Wed, 1 Nov 2023 19:13:39 +0530 Subject: [PATCH 1/7] Welcome Embeddings :rocket: --- .../0147_create_embeddings_table.py | 30 +++++++++++++++++++ contentcuration/contentcuration/models.py | 14 +++++++++ .../contentcuration/tests/test_models.py | 23 ++++++++++++++ docker-compose.yml | 6 +++- docker/Dockerfile.postgres.dev | 4 +++ requirements.in | 1 + requirements.txt | 8 +++-- 7 files changed, 83 insertions(+), 3 deletions(-) create mode 100644 contentcuration/contentcuration/migrations/0147_create_embeddings_table.py create mode 100644 docker/Dockerfile.postgres.dev diff --git a/contentcuration/contentcuration/migrations/0147_create_embeddings_table.py b/contentcuration/contentcuration/migrations/0147_create_embeddings_table.py new file mode 100644 index 0000000000..6371267169 --- /dev/null +++ b/contentcuration/contentcuration/migrations/0147_create_embeddings_table.py @@ -0,0 +1,30 @@ +# Generated by Django 3.2.19 on 2023-11-01 12:00 +from django.db import migrations +from django.db import models +from pgvector.django import VectorExtension +from pgvector.django import VectorField + +import contentcuration.models + + +class Migration(migrations.Migration): + + dependencies = [ + ('contentcuration', '0146_drop_taskresult_fields'), + ] + + operations = [ + VectorExtension(), + migrations.CreateModel( + name='Embeddings', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('content_id', contentcuration.models.UUIDField(max_length=32)), + ('embedding', VectorField()), + ], + ), + migrations.AddIndex( + model_name='embeddings', + index=models.Index(fields=['content_id'], name='contentcura_content_e09ec2_idx'), + ), + ] diff --git a/contentcuration/contentcuration/models.py b/contentcuration/contentcuration/models.py index 0d73096bfa..fbcf7a02c7 100644 --- a/contentcuration/contentcuration/models.py +++ b/contentcuration/contentcuration/models.py @@ -57,6 +57,7 @@ from mptt.models import MPTTModel from mptt.models import raise_if_unsaved from mptt.models import TreeForeignKey +from pgvector.django import VectorField from postmark.core import PMMailInactiveRecipientException from postmark.core import PMMailUnauthorizedException from rest_framework.authtoken.models import Token @@ -1983,6 +1984,19 @@ class Meta: ] +class Embeddings(models.Model): + """ + A model that caches embeddings. + """ + content_id = UUIDField(primary_key=False) + embedding = VectorField() + + class Meta: + indexes = [ + models.Index(fields=["content_id"]), + ] + + class ContentKind(models.Model): kind = models.CharField(primary_key=True, max_length=200, choices=content_kinds.choices) diff --git a/contentcuration/contentcuration/tests/test_models.py b/contentcuration/contentcuration/tests/test_models.py index d53be0176b..187fbe4a76 100644 --- a/contentcuration/contentcuration/tests/test_models.py +++ b/contentcuration/contentcuration/tests/test_models.py @@ -19,6 +19,7 @@ from contentcuration.models import ChannelSet from contentcuration.models import ContentNode from contentcuration.models import CONTENTNODE_TREE_ID_CACHE_KEY +from contentcuration.models import Embeddings from contentcuration.models import File from contentcuration.models import FILE_DURATION_CONSTRAINT from contentcuration.models import generate_object_storage_name @@ -981,3 +982,25 @@ def test_prune(self): ChannelHistory.prune() self.assertEqual(2, ChannelHistory.objects.count()) self.assertEqual(2, ChannelHistory.objects.filter(id__in=last_history_ids).count()) + + +class EmbeddingsTestCase(StudioTestCase): + @classmethod + def setUpClass(cls): + super(EmbeddingsTestCase, cls).setUpClass() + # Two closely placed vectors i.e. they are similar. + Embeddings.objects.create(content_id=uuid.uuid4().hex, embedding=[2, 3]) + Embeddings.objects.create(content_id=uuid.uuid4().hex, embedding=[2, 2]) + # A vector placed at far distance i.e. not similar to above vectors. + Embeddings.objects.create(content_id=uuid.uuid4().hex, embedding=[4, 1]) + + def test_can_create_embeddings(self): + embeddings_count = Embeddings.objects.count() + self.assertEqual(embeddings_count, 3) + + def test_get_nearest_neighbors(self): + from pgvector.django import L2Distance + import numpy as np + # Get the nearest neighbor of [2, 3] which is [2, 2]. + closest_embedding = list(Embeddings.objects.order_by(L2Distance('embedding', [2, 3]))[1:2])[0] + self.assertTrue(np.array_equal(closest_embedding.embedding, np.array([2, 2]))) diff --git a/docker-compose.yml b/docker-compose.yml index f0b2cb7b86..5796e34bb5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -53,7 +53,11 @@ services: - .docker/minio:/data postgres: - image: postgres:12 + build: + context: ./docker + dockerfile: Dockerfile.postgres.dev + args: + PG_MAJOR: 12 environment: PGDATA: /var/lib/postgresql/data/pgdata POSTGRES_USER: learningequality diff --git a/docker/Dockerfile.postgres.dev b/docker/Dockerfile.postgres.dev new file mode 100644 index 0000000000..c981c7726d --- /dev/null +++ b/docker/Dockerfile.postgres.dev @@ -0,0 +1,4 @@ +# Installs pgvector to postgres base image. +ARG PG_MAJOR +FROM postgres:$PG_MAJOR +RUN apt-get update && apt-get install -y postgresql-$PG_MAJOR-pgvector diff --git a/requirements.in b/requirements.in index 79d56d527d..30b9879f5a 100644 --- a/requirements.in +++ b/requirements.in @@ -40,3 +40,4 @@ python-dateutil>=2.8.1 jsonschema>=3.2.0 importlib-metadata==1.7.0 django-celery-results +pgvector diff --git a/requirements.txt b/requirements.txt index a10cf719cd..eea4364652 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # -# This file is autogenerated by pip-compile with python 3.9 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: # # pip-compile requirements.in # @@ -168,6 +168,8 @@ le-utils==0.2.1 # via -r requirements.in newrelic==6.2.0.156 # via -r requirements.in +numpy==1.26.0 + # via pgvector oauth2client==4.1.3 # via -r requirements.in packaging==20.9 @@ -177,6 +179,8 @@ packaging==20.9 # google-cloud-kms pathlib==1.0.1 # via -r requirements.in +pgvector==0.2.3 + # via -r requirements.in pillow==9.4.0 # via -r requirements.in prometheus-client==0.10.1 From 2acd50cd8663ec7cf6567b6f3ca71b90852cc2ec Mon Sep 17 00:00:00 2001 From: Vivek Agrawal Date: Wed, 8 Nov 2023 16:32:15 +0530 Subject: [PATCH 2/7] New embedding model structure --- .../management/commands/setup.py | 5 +++ .../0147_create_embeddings_table.py | 30 ---------------- .../0147_embeddings_embeddingscontentnode.py | 35 +++++++++++++++++++ contentcuration/contentcuration/models.py | 19 ++++++---- 4 files changed, 52 insertions(+), 37 deletions(-) delete mode 100644 contentcuration/contentcuration/migrations/0147_create_embeddings_table.py create mode 100644 contentcuration/contentcuration/migrations/0147_embeddings_embeddingscontentnode.py diff --git a/contentcuration/contentcuration/management/commands/setup.py b/contentcuration/contentcuration/management/commands/setup.py index 3284349ebe..d032b90942 100644 --- a/contentcuration/contentcuration/management/commands/setup.py +++ b/contentcuration/contentcuration/management/commands/setup.py @@ -4,11 +4,13 @@ from django.core.management import call_command from django.core.management.base import BaseCommand +from django.db import connection from django.db import Error as DBError from le_utils.constants import content_kinds from le_utils.constants import file_formats from le_utils.constants import format_presets from le_utils.constants import licenses +from pgvector.django import VectorExtension from contentcuration.models import ContentNode from contentcuration.models import ContentTag @@ -65,6 +67,9 @@ def handle(self, *args, **options): except DBError as e: logging.error('Error creating cache table: {}'.format(str(e))) + with connection.cursor() as cursor: + cursor.execute("CREATE EXTENSION IF NOT EXISTS %s;" % VectorExtension().name) + # Run migrations call_command('migrate') diff --git a/contentcuration/contentcuration/migrations/0147_create_embeddings_table.py b/contentcuration/contentcuration/migrations/0147_create_embeddings_table.py deleted file mode 100644 index 6371267169..0000000000 --- a/contentcuration/contentcuration/migrations/0147_create_embeddings_table.py +++ /dev/null @@ -1,30 +0,0 @@ -# Generated by Django 3.2.19 on 2023-11-01 12:00 -from django.db import migrations -from django.db import models -from pgvector.django import VectorExtension -from pgvector.django import VectorField - -import contentcuration.models - - -class Migration(migrations.Migration): - - dependencies = [ - ('contentcuration', '0146_drop_taskresult_fields'), - ] - - operations = [ - VectorExtension(), - migrations.CreateModel( - name='Embeddings', - fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('content_id', contentcuration.models.UUIDField(max_length=32)), - ('embedding', VectorField()), - ], - ), - migrations.AddIndex( - model_name='embeddings', - index=models.Index(fields=['content_id'], name='contentcura_content_e09ec2_idx'), - ), - ] diff --git a/contentcuration/contentcuration/migrations/0147_embeddings_embeddingscontentnode.py b/contentcuration/contentcuration/migrations/0147_embeddings_embeddingscontentnode.py new file mode 100644 index 0000000000..f6fa1de4f1 --- /dev/null +++ b/contentcuration/contentcuration/migrations/0147_embeddings_embeddingscontentnode.py @@ -0,0 +1,35 @@ +# Generated by Django 3.2.19 on 2023-11-08 09:55 +import uuid + +import django.db.models.deletion +import pgvector.django +from django.db import migrations +from django.db import models + +import contentcuration.models + + +class Migration(migrations.Migration): + + dependencies = [ + ('contentcuration', '0146_drop_taskresult_fields'), + ] + + operations = [ + migrations.CreateModel( + name='EmbeddingsContentNode', + fields=[ + ('cid', models.CharField(max_length=64, primary_key=True, serialize=False)), + ('contentnode', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='node_cid', to='contentcuration.contentnode')), + ], + ), + migrations.CreateModel( + name='Embeddings', + fields=[ + ('embedding_id', contentcuration.models.UUIDField(default=uuid.uuid4, max_length=32, primary_key=True, serialize=False)), + ('model', models.CharField(db_index=True, max_length=64)), + ('embedding', pgvector.django.VectorField()), + ('embedded_node', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='embeddings', to='contentcuration.embeddingscontentnode')), + ], + ), + ] diff --git a/contentcuration/contentcuration/models.py b/contentcuration/contentcuration/models.py index fbcf7a02c7..4b3ebd9147 100644 --- a/contentcuration/contentcuration/models.py +++ b/contentcuration/contentcuration/models.py @@ -1984,18 +1984,23 @@ class Meta: ] +class EmbeddingsContentNode(models.Model): + """ + A model that stores the canonical contentnode for embedding purposes. + """ + cid = models.CharField(primary_key=True, max_length=64) + contentnode = models.ForeignKey(ContentNode, related_name="node_cid", blank=False, null=False, on_delete=models.CASCADE) + + class Embeddings(models.Model): """ - A model that caches embeddings. + A model that stores generated embeddings. """ - content_id = UUIDField(primary_key=False) + embedding_id = UUIDField(primary_key=True, default=uuid.uuid4) + model = models.CharField(max_length=64, db_index=True) + embedded_node = models.ForeignKey(EmbeddingsContentNode, related_name="embeddings", blank=False, null=False, on_delete=models.CASCADE) embedding = VectorField() - class Meta: - indexes = [ - models.Index(fields=["content_id"]), - ] - class ContentKind(models.Model): kind = models.CharField(primary_key=True, max_length=200, choices=content_kinds.choices) From e003bbc4103ff3f06a82517e2a6d0583d5e68ea8 Mon Sep 17 00:00:00 2001 From: Vivek Agrawal Date: Wed, 8 Nov 2023 17:54:05 +0530 Subject: [PATCH 3/7] Enable extension only during pytest & docs update --- README.md | 2 +- .../management/commands/setup.py | 124 +++++++++--------- .../0147_embeddings_embeddingscontentnode.py | 8 ++ .../contentcuration/tests/test_models.py | 24 +++- docs/host_services_setup.md | 2 +- package.json | 3 +- 6 files changed, 96 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index aa0a9ad8d5..cb79d7a5e6 100644 --- a/README.md +++ b/README.md @@ -106,7 +106,7 @@ make dcservicesdown ``` ### Initializing Studio -With the services running, in a separate terminal/terminal-tab, we can now initialize the database for Studio development purposes. The command below will initialize the database tables, import constants, and a user account for development: +With the services running, in a separate terminal/terminal-tab, we can now initialize the database for Studio development purposes. The command below will initialize the database tables, import constants, enable required postgres extensions and a studio user account for development: ```bash yarn run devsetup ``` diff --git a/contentcuration/contentcuration/management/commands/setup.py b/contentcuration/contentcuration/management/commands/setup.py index d032b90942..44143534af 100644 --- a/contentcuration/contentcuration/management/commands/setup.py +++ b/contentcuration/contentcuration/management/commands/setup.py @@ -47,6 +47,7 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument('--email', dest="email", default="a@a.com") parser.add_argument('--password', dest="password", default="a") + parser.add_argument('--clean-data-state', action='store_true', default=False, help='Sets database in clean state.') def handle(self, *args, **options): # Validate email @@ -84,67 +85,68 @@ def handle(self, *args, **options): user2 = create_user("user@b.com", "b", "User", "B") user3 = create_user("user@c.com", "c", "User", "C") - # Create channels - - channel1 = create_channel("Published Channel", DESCRIPTION, editors=[admin], bookmarkers=[user1, user2], public=True) - channel2 = create_channel("Ricecooker Channel", DESCRIPTION, editors=[admin, user1], bookmarkers=[user2], viewers=[user3]) - channel3 = create_channel("Empty Channel", editors=[user3], viewers=[user2]) - channel4 = create_channel("Imported Channel", editors=[admin]) - - # Invite admin to channel 3 - try: - invitation, _new = Invitation.objects.get_or_create( - invited=admin, - sender=user3, - channel=channel3, - email=admin.email, - ) - invitation.share_mode = "edit" - invitation.save() - except MultipleObjectsReturned: - # we don't care, just continue - pass - - # Create pool of tags - tags = [] - for t in TAGS: - tag, _new = ContentTag.objects.get_or_create(tag_name=t, channel=channel1) - - # Generate file objects - document_file = create_file("Sample Document", format_presets.DOCUMENT, file_formats.PDF, user=admin) - video_file = create_file("Sample Video", format_presets.VIDEO_HIGH_RES, file_formats.MP4, user=admin) - subtitle_file = create_file("Sample Subtitle", format_presets.VIDEO_SUBTITLE, file_formats.VTT, user=admin) - audio_file = create_file("Sample Audio", format_presets.AUDIO, file_formats.MP3, user=admin) - html5_file = create_file("Sample HTML", format_presets.HTML5_ZIP, file_formats.HTML5, user=admin) - - # Populate channel 1 with content - generate_tree(channel1.main_tree, document_file, video_file, subtitle_file, audio_file, html5_file, user=admin, tags=tags) - - # Populate channel 2 with staged content - channel2.ricecooker_version = "0.0.0" - channel2.save() - generate_tree(channel2.staging_tree, document_file, video_file, subtitle_file, audio_file, html5_file, user=admin, tags=tags) - - # Import content from channel 1 into channel 4 - channel1.main_tree.children.first().copy_to(channel4.main_tree) - - # Get validation to be reflected in nodes properly - ContentNode.objects.all().update(complete=True) - call_command('mark_incomplete') - - # Mark this node as incomplete even though it is complete - # for testing purposes - node = ContentNode.objects.get(tree_id=channel1.main_tree.tree_id, title="Sample Audio") - node.complete = False - node.save() - - # Publish - publish_channel(admin.id, channel1.pk) - - # Add nodes to clipboard in legacy way - legacy_clipboard_nodes = channel1.main_tree.get_children() - for legacy_node in legacy_clipboard_nodes: - legacy_node.copy_to(target=user1.clipboard_tree) + # Only create additional data when clean-data-state is False (i.e. default behaviour). + if options["clean_data_state"] is False: + # Create channels + channel1 = create_channel("Published Channel", DESCRIPTION, editors=[admin], bookmarkers=[user1, user2], public=True) + channel2 = create_channel("Ricecooker Channel", DESCRIPTION, editors=[admin, user1], bookmarkers=[user2], viewers=[user3]) + channel3 = create_channel("Empty Channel", editors=[user3], viewers=[user2]) + channel4 = create_channel("Imported Channel", editors=[admin]) + + # Invite admin to channel 3 + try: + invitation, _new = Invitation.objects.get_or_create( + invited=admin, + sender=user3, + channel=channel3, + email=admin.email, + ) + invitation.share_mode = "edit" + invitation.save() + except MultipleObjectsReturned: + # we don't care, just continue + pass + + # Create pool of tags + tags = [] + for t in TAGS: + tag, _new = ContentTag.objects.get_or_create(tag_name=t, channel=channel1) + + # Generate file objects + document_file = create_file("Sample Document", format_presets.DOCUMENT, file_formats.PDF, user=admin) + video_file = create_file("Sample Video", format_presets.VIDEO_HIGH_RES, file_formats.MP4, user=admin) + subtitle_file = create_file("Sample Subtitle", format_presets.VIDEO_SUBTITLE, file_formats.VTT, user=admin) + audio_file = create_file("Sample Audio", format_presets.AUDIO, file_formats.MP3, user=admin) + html5_file = create_file("Sample HTML", format_presets.HTML5_ZIP, file_formats.HTML5, user=admin) + + # Populate channel 1 with content + generate_tree(channel1.main_tree, document_file, video_file, subtitle_file, audio_file, html5_file, user=admin, tags=tags) + + # Populate channel 2 with staged content + channel2.ricecooker_version = "0.0.0" + channel2.save() + generate_tree(channel2.staging_tree, document_file, video_file, subtitle_file, audio_file, html5_file, user=admin, tags=tags) + + # Import content from channel 1 into channel 4 + channel1.main_tree.children.first().copy_to(channel4.main_tree) + + # Get validation to be reflected in nodes properly + ContentNode.objects.all().update(complete=True) + call_command('mark_incomplete') + + # Mark this node as incomplete even though it is complete + # for testing purposes + node = ContentNode.objects.get(tree_id=channel1.main_tree.tree_id, title="Sample Audio") + node.complete = False + node.save() + + # Publish + publish_channel(admin.id, channel1.pk) + + # Add nodes to clipboard in legacy way + legacy_clipboard_nodes = channel1.main_tree.get_children() + for legacy_node in legacy_clipboard_nodes: + legacy_node.copy_to(target=user1.clipboard_tree) print("\n\n\nSETUP DONE: Log in as admin to view data (email: {}, password: {})\n\n\n".format(email, password)) diff --git a/contentcuration/contentcuration/migrations/0147_embeddings_embeddingscontentnode.py b/contentcuration/contentcuration/migrations/0147_embeddings_embeddingscontentnode.py index f6fa1de4f1..9fde6731d1 100644 --- a/contentcuration/contentcuration/migrations/0147_embeddings_embeddingscontentnode.py +++ b/contentcuration/contentcuration/migrations/0147_embeddings_embeddingscontentnode.py @@ -3,6 +3,7 @@ import django.db.models.deletion import pgvector.django +from django.conf import settings from django.db import migrations from django.db import models @@ -33,3 +34,10 @@ class Migration(migrations.Migration): ], ), ] + + # Enable Pgvector postgres extension only when pytest is running. + # For development, its enabled via devsetup management command. + # For production, it'll be enabled manually. + if getattr(settings, "TEST_ENV", False) is True: + from pgvector.django import VectorExtension + operations.insert(0, VectorExtension()) diff --git a/contentcuration/contentcuration/tests/test_models.py b/contentcuration/contentcuration/tests/test_models.py index 187fbe4a76..77c3826483 100644 --- a/contentcuration/contentcuration/tests/test_models.py +++ b/contentcuration/contentcuration/tests/test_models.py @@ -20,6 +20,7 @@ from contentcuration.models import ContentNode from contentcuration.models import CONTENTNODE_TREE_ID_CACHE_KEY from contentcuration.models import Embeddings +from contentcuration.models import EmbeddingsContentNode from contentcuration.models import File from contentcuration.models import FILE_DURATION_CONSTRAINT from contentcuration.models import generate_object_storage_name @@ -988,11 +989,28 @@ class EmbeddingsTestCase(StudioTestCase): @classmethod def setUpClass(cls): super(EmbeddingsTestCase, cls).setUpClass() + node_1 = testdata.node({ + "kind_id": "video", + "title": "first" + }) + node_2 = testdata.node({ + "kind_id": "video", + "title": "second" + }) + node_3 = testdata.node({ + "kind_id": "video", + "title": "third" + }) + + embedded_node_1 = EmbeddingsContentNode.objects.create(cid=node_1.content_id, contentnode=node_1) + embedded_node_2 = EmbeddingsContentNode.objects.create(cid=node_2.content_id, contentnode=node_2) + embedded_node_3 = EmbeddingsContentNode.objects.create(cid=node_3.content_id, contentnode=node_3) + # Two closely placed vectors i.e. they are similar. - Embeddings.objects.create(content_id=uuid.uuid4().hex, embedding=[2, 3]) - Embeddings.objects.create(content_id=uuid.uuid4().hex, embedding=[2, 2]) + Embeddings.objects.create(model="studio-embedder-v1.0", embedded_node=embedded_node_1, embedding=[2, 3]) + Embeddings.objects.create(model="studio-embedder-v1.0", embedded_node=embedded_node_2, embedding=[2, 2]) # A vector placed at far distance i.e. not similar to above vectors. - Embeddings.objects.create(content_id=uuid.uuid4().hex, embedding=[4, 1]) + Embeddings.objects.create(model="studio-embedder-v1.0", embedded_node=embedded_node_3, embedding=[4, 1]) def test_can_create_embeddings(self): embeddings_count = Embeddings.objects.count() diff --git a/docs/host_services_setup.md b/docs/host_services_setup.md index 7ded077aa7..15bd2a2cf2 100644 --- a/docs/host_services_setup.md +++ b/docs/host_services_setup.md @@ -130,7 +130,7 @@ yarn run services ``` ## Initializing Studio -With the services running, in a separate terminal/terminal-tab, we can now initialize the database for Studio development purposes. The command below will initialize the database, in addition to adding a user account for development: +With the services running, in a separate terminal/terminal-tab, we can now initialize the database for Studio development purposes. The command below will initialize the database then it will enable the required postgres extensions in addition to adding a user account for development: ```bash yarn run devsetup ``` diff --git a/package.json b/package.json index 2c7ebd30dc..816e0b16e9 100644 --- a/package.json +++ b/package.json @@ -15,7 +15,8 @@ "build": "webpack --env prod --config webpack.config.js", "postgres": "pg_ctl -D /usr/local/var/postgresql@9.6 start || true", "redis": "redis-server /usr/local/etc/redis.conf || true", - "devsetup": "cd contentcuration && python manage.py setup --settings=contentcuration.dev_settings", + "devsetup": "python contentcuration/manage.py setup --settings=contentcuration.dev_settings", + "devsetup:clean": "python contentcuration/manage.py setup --clean-data-state --settings=contentcuration.dev_settings", "services": "npm-run-all -c --parallel --silent celery minio redis postgres", "test": "jest --config jest_config/jest.conf.js", "build:dev": "webpack serve --env dev --config webpack.config.js --progress", From 3735b971451069e2971f90128c507164fb176f16 Mon Sep 17 00:00:00 2001 From: Vivek Agrawal Date: Fri, 10 Nov 2023 20:48:33 +0530 Subject: [PATCH 4/7] Use custom db backend for pytest instead of hacking migrations --- .../management/commands/setup.py | 12 +++++-- .../0147_embeddings_embeddingscontentnode.py | 7 ---- .../contentcuration/test_settings.py | 2 ++ .../tests/custom_pytest_db_backend/base.py | 32 +++++++++++++++++++ 4 files changed, 44 insertions(+), 9 deletions(-) create mode 100644 contentcuration/contentcuration/tests/custom_pytest_db_backend/base.py diff --git a/contentcuration/contentcuration/management/commands/setup.py b/contentcuration/contentcuration/management/commands/setup.py index 44143534af..78c94fa2d7 100644 --- a/contentcuration/contentcuration/management/commands/setup.py +++ b/contentcuration/contentcuration/management/commands/setup.py @@ -42,6 +42,14 @@ SORT_ORDER = 0 +def enable_pgvector_extension(connection): + """ + Enables pgvector extension in postgres. + """ + with connection.cursor() as cursor: + cursor.execute("CREATE EXTENSION IF NOT EXISTS %s;" % VectorExtension().name) + + class Command(BaseCommand): def add_arguments(self, parser): @@ -68,8 +76,8 @@ def handle(self, *args, **options): except DBError as e: logging.error('Error creating cache table: {}'.format(str(e))) - with connection.cursor() as cursor: - cursor.execute("CREATE EXTENSION IF NOT EXISTS %s;" % VectorExtension().name) + # Enable pgvector extension. + enable_pgvector_extension(connection) # Run migrations call_command('migrate') diff --git a/contentcuration/contentcuration/migrations/0147_embeddings_embeddingscontentnode.py b/contentcuration/contentcuration/migrations/0147_embeddings_embeddingscontentnode.py index 9fde6731d1..d00f289b85 100644 --- a/contentcuration/contentcuration/migrations/0147_embeddings_embeddingscontentnode.py +++ b/contentcuration/contentcuration/migrations/0147_embeddings_embeddingscontentnode.py @@ -34,10 +34,3 @@ class Migration(migrations.Migration): ], ), ] - - # Enable Pgvector postgres extension only when pytest is running. - # For development, its enabled via devsetup management command. - # For production, it'll be enabled manually. - if getattr(settings, "TEST_ENV", False) is True: - from pgvector.django import VectorExtension - operations.insert(0, VectorExtension()) diff --git a/contentcuration/contentcuration/test_settings.py b/contentcuration/contentcuration/test_settings.py index a1fcf20e6b..556a4e13ae 100644 --- a/contentcuration/contentcuration/test_settings.py +++ b/contentcuration/contentcuration/test_settings.py @@ -11,3 +11,5 @@ INSTALLED_APPS += ("django_concurrent_tests",) # noqa F405 MANAGE_PY_PATH = "./contentcuration/manage.py" + +DATABASES["default"]["ENGINE"] = "contentcuration.tests.custom_pytest_db_backend" # noqa diff --git a/contentcuration/contentcuration/tests/custom_pytest_db_backend/base.py b/contentcuration/contentcuration/tests/custom_pytest_db_backend/base.py new file mode 100644 index 0000000000..d15926c916 --- /dev/null +++ b/contentcuration/contentcuration/tests/custom_pytest_db_backend/base.py @@ -0,0 +1,32 @@ +from django.db.backends.postgresql.base import DatabaseWrapper as PostgresDatabaseWrapper +from django.db.backends.postgresql.creation import DatabaseCreation as PostgresDatabaseCreation + + +class CustomDBCreationForPytest(PostgresDatabaseCreation): + """ + Overriding creation module to enable postgres pgvector extension before + pytest runs migration. Because embeddings table rely on pgvector + extension to work. + """ + + def _create_test_db(self, verbosity, autoclobber, keepdb=False): + from contentcuration.management.commands.setup import enable_pgvector_extension + + # Create test database and get its name. + test_db_name = super()._create_test_db(verbosity, autoclobber, keepdb) + + # Close current nodb_cursor connection and point connection to the + # newly created test database. + self.connection.close() + self.connection.settings_dict["NAME"] = test_db_name + + # Enable pgvector extension. + enable_pgvector_extension(self.connection) + + +class DatabaseWrapper(PostgresDatabaseWrapper): + """ + A database wrapper to customise creation module. Rest everything is same as + Postgres. + """ + creation_class = CustomDBCreationForPytest From d1e38c56d94ebb33f791a678b6db3e31828e06c1 Mon Sep 17 00:00:00 2001 From: Blaine Jester Date: Mon, 18 Dec 2023 14:32:00 -0800 Subject: [PATCH 5/7] Add creds to PG service in workflow and update postgres image --- .github/workflows/pythontest.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pythontest.yml b/.github/workflows/pythontest.yml index 84154ab71c..25ea5f9f8b 100644 --- a/.github/workflows/pythontest.yml +++ b/.github/workflows/pythontest.yml @@ -32,7 +32,10 @@ jobs: # Label used to access the service container postgres: # Docker Hub image - image: postgres:12 + image: ghcr.io/learningequality/postgres:${{ github.base_ref || github.ref_name }} + credentials: + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} # Provide the password for postgres env: POSTGRES_USER: learningequality From e3cb4c02976963c46f44b0276a4e47dfb290544a Mon Sep 17 00:00:00 2001 From: Vivek Agrawal Date: Thu, 21 Dec 2023 18:11:39 +0530 Subject: [PATCH 6/7] Return test db name, now its perfect. --- .../contentcuration/tests/custom_pytest_db_backend/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/contentcuration/contentcuration/tests/custom_pytest_db_backend/base.py b/contentcuration/contentcuration/tests/custom_pytest_db_backend/base.py index d15926c916..b4318cec38 100644 --- a/contentcuration/contentcuration/tests/custom_pytest_db_backend/base.py +++ b/contentcuration/contentcuration/tests/custom_pytest_db_backend/base.py @@ -23,6 +23,8 @@ def _create_test_db(self, verbosity, autoclobber, keepdb=False): # Enable pgvector extension. enable_pgvector_extension(self.connection) + return self.connection.settings_dict["NAME"] + class DatabaseWrapper(PostgresDatabaseWrapper): """ From 4376391cf182a89cca9dfb63fd5101670ff84a09 Mon Sep 17 00:00:00 2001 From: Vivek Agrawal Date: Sat, 23 Dec 2023 17:41:10 +0530 Subject: [PATCH 7/7] Create and persist bucket --- contentcuration/contentcuration/tests/test_models.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/contentcuration/contentcuration/tests/test_models.py b/contentcuration/contentcuration/tests/test_models.py index 77c3826483..e8d1747813 100644 --- a/contentcuration/contentcuration/tests/test_models.py +++ b/contentcuration/contentcuration/tests/test_models.py @@ -986,9 +986,12 @@ def test_prune(self): class EmbeddingsTestCase(StudioTestCase): + persist_bucket = True + @classmethod def setUpClass(cls): super(EmbeddingsTestCase, cls).setUpClass() + cls.create_bucket() node_1 = testdata.node({ "kind_id": "video", "title": "first" @@ -1012,6 +1015,11 @@ def setUpClass(cls): # A vector placed at far distance i.e. not similar to above vectors. Embeddings.objects.create(model="studio-embedder-v1.0", embedded_node=embedded_node_3, embedding=[4, 1]) + @classmethod + def tearDownClass(cls): + super(EmbeddingsTestCase, cls).tearDownClass() + cls.delete_bucket() + def test_can_create_embeddings(self): embeddings_count = Embeddings.objects.count() self.assertEqual(embeddings_count, 3)