diff --git a/.github/workflows/pythontest.yml b/.github/workflows/pythontest.yml index 443e445b4e..e156ed0bbd 100644 --- a/.github/workflows/pythontest.yml +++ b/.github/workflows/pythontest.yml @@ -32,7 +32,10 @@ jobs: # Label used to access the service container postgres: # Docker Hub image - image: postgres:12 + image: ghcr.io/learningequality/postgres:${{ github.base_ref || github.ref_name }} + credentials: + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} # Provide the password for postgres env: POSTGRES_USER: learningequality diff --git a/contentcuration/contentcuration/management/commands/setup.py b/contentcuration/contentcuration/management/commands/setup.py index 3284349ebe..78c94fa2d7 100644 --- a/contentcuration/contentcuration/management/commands/setup.py +++ b/contentcuration/contentcuration/management/commands/setup.py @@ -4,11 +4,13 @@ from django.core.management import call_command from django.core.management.base import BaseCommand +from django.db import connection from django.db import Error as DBError from le_utils.constants import content_kinds from le_utils.constants import file_formats from le_utils.constants import format_presets from le_utils.constants import licenses +from pgvector.django import VectorExtension from contentcuration.models import ContentNode from contentcuration.models import ContentTag @@ -40,11 +42,20 @@ SORT_ORDER = 0 +def enable_pgvector_extension(connection): + """ + Enables pgvector extension in postgres. 
+ """ + with connection.cursor() as cursor: + cursor.execute("CREATE EXTENSION IF NOT EXISTS %s;" % VectorExtension().name) + + class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument('--email', dest="email", default="a@a.com") parser.add_argument('--password', dest="password", default="a") + parser.add_argument('--clean-data-state', action='store_true', default=False, help='Sets database in clean state.') def handle(self, *args, **options): # Validate email @@ -65,6 +76,9 @@ def handle(self, *args, **options): except DBError as e: logging.error('Error creating cache table: {}'.format(str(e))) + # Enable pgvector extension. + enable_pgvector_extension(connection) + # Run migrations call_command('migrate') @@ -79,67 +93,68 @@ def handle(self, *args, **options): user2 = create_user("user@b.com", "b", "User", "B") user3 = create_user("user@c.com", "c", "User", "C") - # Create channels - - channel1 = create_channel("Published Channel", DESCRIPTION, editors=[admin], bookmarkers=[user1, user2], public=True) - channel2 = create_channel("Ricecooker Channel", DESCRIPTION, editors=[admin, user1], bookmarkers=[user2], viewers=[user3]) - channel3 = create_channel("Empty Channel", editors=[user3], viewers=[user2]) - channel4 = create_channel("Imported Channel", editors=[admin]) - - # Invite admin to channel 3 - try: - invitation, _new = Invitation.objects.get_or_create( - invited=admin, - sender=user3, - channel=channel3, - email=admin.email, - ) - invitation.share_mode = "edit" - invitation.save() - except MultipleObjectsReturned: - # we don't care, just continue - pass - - # Create pool of tags - tags = [] - for t in TAGS: - tag, _new = ContentTag.objects.get_or_create(tag_name=t, channel=channel1) - - # Generate file objects - document_file = create_file("Sample Document", format_presets.DOCUMENT, file_formats.PDF, user=admin) - video_file = create_file("Sample Video", format_presets.VIDEO_HIGH_RES, file_formats.MP4, user=admin) - subtitle_file = 
create_file("Sample Subtitle", format_presets.VIDEO_SUBTITLE, file_formats.VTT, user=admin) - audio_file = create_file("Sample Audio", format_presets.AUDIO, file_formats.MP3, user=admin) - html5_file = create_file("Sample HTML", format_presets.HTML5_ZIP, file_formats.HTML5, user=admin) - - # Populate channel 1 with content - generate_tree(channel1.main_tree, document_file, video_file, subtitle_file, audio_file, html5_file, user=admin, tags=tags) - - # Populate channel 2 with staged content - channel2.ricecooker_version = "0.0.0" - channel2.save() - generate_tree(channel2.staging_tree, document_file, video_file, subtitle_file, audio_file, html5_file, user=admin, tags=tags) - - # Import content from channel 1 into channel 4 - channel1.main_tree.children.first().copy_to(channel4.main_tree) - - # Get validation to be reflected in nodes properly - ContentNode.objects.all().update(complete=True) - call_command('mark_incomplete') - - # Mark this node as incomplete even though it is complete - # for testing purposes - node = ContentNode.objects.get(tree_id=channel1.main_tree.tree_id, title="Sample Audio") - node.complete = False - node.save() - - # Publish - publish_channel(admin.id, channel1.pk) - - # Add nodes to clipboard in legacy way - legacy_clipboard_nodes = channel1.main_tree.get_children() - for legacy_node in legacy_clipboard_nodes: - legacy_node.copy_to(target=user1.clipboard_tree) + # Only create additional data when clean-data-state is False (i.e. default behaviour). 
+ if options["clean_data_state"] is False: + # Create channels + channel1 = create_channel("Published Channel", DESCRIPTION, editors=[admin], bookmarkers=[user1, user2], public=True) + channel2 = create_channel("Ricecooker Channel", DESCRIPTION, editors=[admin, user1], bookmarkers=[user2], viewers=[user3]) + channel3 = create_channel("Empty Channel", editors=[user3], viewers=[user2]) + channel4 = create_channel("Imported Channel", editors=[admin]) + + # Invite admin to channel 3 + try: + invitation, _new = Invitation.objects.get_or_create( + invited=admin, + sender=user3, + channel=channel3, + email=admin.email, + ) + invitation.share_mode = "edit" + invitation.save() + except MultipleObjectsReturned: + # we don't care, just continue + pass + + # Create pool of tags + tags = [] + for t in TAGS: + tag, _new = ContentTag.objects.get_or_create(tag_name=t, channel=channel1) + + # Generate file objects + document_file = create_file("Sample Document", format_presets.DOCUMENT, file_formats.PDF, user=admin) + video_file = create_file("Sample Video", format_presets.VIDEO_HIGH_RES, file_formats.MP4, user=admin) + subtitle_file = create_file("Sample Subtitle", format_presets.VIDEO_SUBTITLE, file_formats.VTT, user=admin) + audio_file = create_file("Sample Audio", format_presets.AUDIO, file_formats.MP3, user=admin) + html5_file = create_file("Sample HTML", format_presets.HTML5_ZIP, file_formats.HTML5, user=admin) + + # Populate channel 1 with content + generate_tree(channel1.main_tree, document_file, video_file, subtitle_file, audio_file, html5_file, user=admin, tags=tags) + + # Populate channel 2 with staged content + channel2.ricecooker_version = "0.0.0" + channel2.save() + generate_tree(channel2.staging_tree, document_file, video_file, subtitle_file, audio_file, html5_file, user=admin, tags=tags) + + # Import content from channel 1 into channel 4 + channel1.main_tree.children.first().copy_to(channel4.main_tree) + + # Get validation to be reflected in nodes properly + 
ContentNode.objects.all().update(complete=True) + call_command('mark_incomplete') + + # Mark this node as incomplete even though it is complete + # for testing purposes + node = ContentNode.objects.get(tree_id=channel1.main_tree.tree_id, title="Sample Audio") + node.complete = False + node.save() + + # Publish + publish_channel(admin.id, channel1.pk) + + # Add nodes to clipboard in legacy way + legacy_clipboard_nodes = channel1.main_tree.get_children() + for legacy_node in legacy_clipboard_nodes: + legacy_node.copy_to(target=user1.clipboard_tree) print("\n\n\nSETUP DONE: Log in as admin to view data (email: {}, password: {})\n\n\n".format(email, password)) diff --git a/contentcuration/contentcuration/migrations/0147_embeddings_embeddingscontentnode.py b/contentcuration/contentcuration/migrations/0147_embeddings_embeddingscontentnode.py new file mode 100644 index 0000000000..d00f289b85 --- /dev/null +++ b/contentcuration/contentcuration/migrations/0147_embeddings_embeddingscontentnode.py @@ -0,0 +1,36 @@ +# Generated by Django 3.2.19 on 2023-11-08 09:55 +import uuid + +import django.db.models.deletion +import pgvector.django +from django.conf import settings +from django.db import migrations +from django.db import models + +import contentcuration.models + + +class Migration(migrations.Migration): + + dependencies = [ + ('contentcuration', '0146_drop_taskresult_fields'), + ] + + operations = [ + migrations.CreateModel( + name='EmbeddingsContentNode', + fields=[ + ('cid', models.CharField(max_length=64, primary_key=True, serialize=False)), + ('contentnode', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='node_cid', to='contentcuration.contentnode')), + ], + ), + migrations.CreateModel( + name='Embeddings', + fields=[ + ('embedding_id', contentcuration.models.UUIDField(default=uuid.uuid4, max_length=32, primary_key=True, serialize=False)), + ('model', models.CharField(db_index=True, max_length=64)), + ('embedding', 
pgvector.django.VectorField()), + ('embedded_node', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='embeddings', to='contentcuration.embeddingscontentnode')), + ], + ), + ] diff --git a/contentcuration/contentcuration/models.py b/contentcuration/contentcuration/models.py index 0d73096bfa..4b3ebd9147 100644 --- a/contentcuration/contentcuration/models.py +++ b/contentcuration/contentcuration/models.py @@ -57,6 +57,7 @@ from mptt.models import MPTTModel from mptt.models import raise_if_unsaved from mptt.models import TreeForeignKey +from pgvector.django import VectorField from postmark.core import PMMailInactiveRecipientException from postmark.core import PMMailUnauthorizedException from rest_framework.authtoken.models import Token @@ -1983,6 +1984,24 @@ class Meta: ] +class EmbeddingsContentNode(models.Model): + """ + A model that stores the canonical contentnode for embedding purposes. + """ + cid = models.CharField(primary_key=True, max_length=64) + contentnode = models.ForeignKey(ContentNode, related_name="node_cid", blank=False, null=False, on_delete=models.CASCADE) + + +class Embeddings(models.Model): + """ + A model that stores generated embeddings. 
+ """ + embedding_id = UUIDField(primary_key=True, default=uuid.uuid4) + model = models.CharField(max_length=64, db_index=True) + embedded_node = models.ForeignKey(EmbeddingsContentNode, related_name="embeddings", blank=False, null=False, on_delete=models.CASCADE) + embedding = VectorField() + + class ContentKind(models.Model): kind = models.CharField(primary_key=True, max_length=200, choices=content_kinds.choices) diff --git a/contentcuration/contentcuration/test_settings.py b/contentcuration/contentcuration/test_settings.py index a1fcf20e6b..556a4e13ae 100644 --- a/contentcuration/contentcuration/test_settings.py +++ b/contentcuration/contentcuration/test_settings.py @@ -11,3 +11,5 @@ INSTALLED_APPS += ("django_concurrent_tests",) # noqa F405 MANAGE_PY_PATH = "./contentcuration/manage.py" + +DATABASES["default"]["ENGINE"] = "contentcuration.tests.custom_pytest_db_backend" # noqa diff --git a/contentcuration/contentcuration/tests/custom_pytest_db_backend/base.py b/contentcuration/contentcuration/tests/custom_pytest_db_backend/base.py new file mode 100644 index 0000000000..b4318cec38 --- /dev/null +++ b/contentcuration/contentcuration/tests/custom_pytest_db_backend/base.py @@ -0,0 +1,34 @@ +from django.db.backends.postgresql.base import DatabaseWrapper as PostgresDatabaseWrapper +from django.db.backends.postgresql.creation import DatabaseCreation as PostgresDatabaseCreation + + +class CustomDBCreationForPytest(PostgresDatabaseCreation): + """ + Overriding creation module to enable postgres pgvector extension before + pytest runs migration. Because embeddings table rely on pgvector + extension to work. + """ + + def _create_test_db(self, verbosity, autoclobber, keepdb=False): + from contentcuration.management.commands.setup import enable_pgvector_extension + + # Create test database and get its name. 
+ test_db_name = super()._create_test_db(verbosity, autoclobber, keepdb) + + # Close current nodb_cursor connection and point connection to the + # newly created test database. + self.connection.close() + self.connection.settings_dict["NAME"] = test_db_name + + # Enable pgvector extension. + enable_pgvector_extension(self.connection) + + return self.connection.settings_dict["NAME"] + + +class DatabaseWrapper(PostgresDatabaseWrapper): + """ + A database wrapper to customise the creation module. Everything else is + the same as Postgres. + """ + creation_class = CustomDBCreationForPytest diff --git a/contentcuration/contentcuration/tests/test_models.py b/contentcuration/contentcuration/tests/test_models.py index d53be0176b..e8d1747813 100644 --- a/contentcuration/contentcuration/tests/test_models.py +++ b/contentcuration/contentcuration/tests/test_models.py @@ -19,6 +19,8 @@ from contentcuration.models import ChannelSet from contentcuration.models import ContentNode from contentcuration.models import CONTENTNODE_TREE_ID_CACHE_KEY +from contentcuration.models import Embeddings +from contentcuration.models import EmbeddingsContentNode from contentcuration.models import File from contentcuration.models import FILE_DURATION_CONSTRAINT from contentcuration.models import generate_object_storage_name @@ -981,3 +983,50 @@ def test_prune(self): ChannelHistory.prune() self.assertEqual(2, ChannelHistory.objects.count()) self.assertEqual(2, ChannelHistory.objects.filter(id__in=last_history_ids).count()) + + +class EmbeddingsTestCase(StudioTestCase): + persist_bucket = True + + @classmethod + def setUpClass(cls): + super(EmbeddingsTestCase, cls).setUpClass() + cls.create_bucket() + node_1 = testdata.node({ + "kind_id": "video", + "title": "first" + }) + node_2 = testdata.node({ + "kind_id": "video", + "title": "second" + }) + node_3 = testdata.node({ + "kind_id": "video", + "title": "third" + }) + + embedded_node_1 = EmbeddingsContentNode.objects.create(cid=node_1.content_id, 
contentnode=node_1) + embedded_node_2 = EmbeddingsContentNode.objects.create(cid=node_2.content_id, contentnode=node_2) + embedded_node_3 = EmbeddingsContentNode.objects.create(cid=node_3.content_id, contentnode=node_3) + + # Two closely placed vectors i.e. they are similar. + Embeddings.objects.create(model="studio-embedder-v1.0", embedded_node=embedded_node_1, embedding=[2, 3]) + Embeddings.objects.create(model="studio-embedder-v1.0", embedded_node=embedded_node_2, embedding=[2, 2]) + # A vector placed at far distance i.e. not similar to above vectors. + Embeddings.objects.create(model="studio-embedder-v1.0", embedded_node=embedded_node_3, embedding=[4, 1]) + + @classmethod + def tearDownClass(cls): + super(EmbeddingsTestCase, cls).tearDownClass() + cls.delete_bucket() + + def test_can_create_embeddings(self): + embeddings_count = Embeddings.objects.count() + self.assertEqual(embeddings_count, 3) + + def test_get_nearest_neighbors(self): + from pgvector.django import L2Distance + import numpy as np + # Get the nearest neighbor of [2, 3] which is [2, 2]. + closest_embedding = list(Embeddings.objects.order_by(L2Distance('embedding', [2, 3]))[1:2])[0] + self.assertTrue(np.array_equal(closest_embedding.embedding, np.array([2, 2]))) diff --git a/docs/host_services_setup.md b/docs/host_services_setup.md index bd21c52b35..49c6aabc6e 100644 --- a/docs/host_services_setup.md +++ b/docs/host_services_setup.md @@ -130,7 +130,7 @@ yarn run services ``` ## Initializing Studio -With the services running, in a separate terminal/terminal-tab, we can now initialize the database for Studio development purposes. The command below will initialize the database, in addition to adding a user account for development: +With the services running, in a separate terminal/terminal-tab, we can now initialize the database for Studio development purposes. 
The command below will initialize the database, enable the required Postgres extensions, and add a user account for development: ```bash yarn run devsetup ``` diff --git a/docs/local_dev_docker.md b/docs/local_dev_docker.md index 9234330ff4..8bff7e2c93 100644 --- a/docs/local_dev_docker.md +++ b/docs/local_dev_docker.md @@ -92,7 +92,7 @@ make dcservicesdown ``` ## Initializing Studio -With the services running, in a separate terminal/terminal-tab, we can now initialize the database for Studio development purposes. The command below will initialize the database tables, import constants, and a user account for development: +With the services running, in a separate terminal/terminal-tab, we can now initialize the database for Studio development purposes. The command below will initialize the database tables, import constants, enable the required Postgres extensions, and add a Studio user account for development: ```bash yarn run devsetup ``` diff --git a/package.json b/package.json index a9202aebec..22fa77bde1 100644 --- a/package.json +++ b/package.json @@ -15,7 +15,8 @@ "build": "webpack --env prod --config webpack.config.js", "postgres": "pg_ctl -D /usr/local/var/postgresql@9.6 start || true", "redis": "redis-server /usr/local/etc/redis.conf || true", - "devsetup": "cd contentcuration && python manage.py setup --settings=contentcuration.dev_settings", + "devsetup": "python contentcuration/manage.py setup --settings=contentcuration.dev_settings", + "devsetup:clean": "python contentcuration/manage.py setup --clean-data-state --settings=contentcuration.dev_settings", "services": "npm-run-all -c --parallel --silent celery minio redis postgres", "test": "jest --config jest_config/jest.conf.js", "build:dev": "webpack serve --env dev --config webpack.config.js --progress", diff --git a/requirements.in b/requirements.in index bfce094b72..400d07d392 100644 --- a/requirements.in +++ b/requirements.in @@ -35,3 +35,4 @@ pillow==10.2.0 python-dateutil>=2.8.1 
jsonschema>=3.2.0 django-celery-results +pgvector diff --git a/requirements.txt b/requirements.txt index e37ae00636..f4b46f9542 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # -# This file is autogenerated by pip-compile with Python 3.10 -# by the following command: +# This file is autogenerated by pip-compile with python 3.10 +# To update, run: # # pip-compile requirements.in # @@ -98,7 +98,6 @@ future==0.18.3 # via -r requirements.in google-api-core[grpc]==1.27.0 # via - # google-api-core # google-cloud-core # google-cloud-error-reporting # google-cloud-kms @@ -160,11 +159,15 @@ le-utils==0.2.1 # via -r requirements.in newrelic==6.2.0.156 # via -r requirements.in +numpy==1.26.0 + # via pgvector packaging==20.9 # via # google-api-core # google-cloud-error-reporting # google-cloud-kms +pgvector==0.2.3 + # via -r requirements.in pillow==10.2.0 # via -r requirements.in prometheus-client==0.10.1