Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Welcome Embeddings 🚀 #4322

Merged
merged 11 commits into from
Mar 12, 2024
5 changes: 4 additions & 1 deletion .github/workflows/pythontest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ jobs:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres:12
image: ghcr.io/learningequality/postgres:${{ github.base_ref || github.ref_name }}
credentials:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
# Provide the password for postgres
env:
POSTGRES_USER: learningequality
Expand Down
137 changes: 76 additions & 61 deletions contentcuration/contentcuration/management/commands/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

from django.core.management import call_command
from django.core.management.base import BaseCommand
from django.db import connection
from django.db import Error as DBError
from le_utils.constants import content_kinds
from le_utils.constants import file_formats
from le_utils.constants import format_presets
from le_utils.constants import licenses
from pgvector.django import VectorExtension

from contentcuration.models import ContentNode
from contentcuration.models import ContentTag
Expand Down Expand Up @@ -40,11 +42,20 @@
SORT_ORDER = 0


def enable_pgvector_extension(connection):
"""
Enables pgvector extension in postgres.
"""
with connection.cursor() as cursor:
cursor.execute("CREATE EXTENSION IF NOT EXISTS %s;" % VectorExtension().name)


class Command(BaseCommand):

def add_arguments(self, parser):
parser.add_argument('--email', dest="email", default="[email protected]")
parser.add_argument('--password', dest="password", default="a")
parser.add_argument('--clean-data-state', action='store_true', default=False, help='Sets database in clean state.')

def handle(self, *args, **options):
# Validate email
Expand All @@ -65,6 +76,9 @@ def handle(self, *args, **options):
except DBError as e:
logging.error('Error creating cache table: {}'.format(str(e)))

# Enable pgvector extension.
enable_pgvector_extension(connection)

# Run migrations
call_command('migrate')

Expand All @@ -79,67 +93,68 @@ def handle(self, *args, **options):
user2 = create_user("[email protected]", "b", "User", "B")
user3 = create_user("[email protected]", "c", "User", "C")

# Create channels

channel1 = create_channel("Published Channel", DESCRIPTION, editors=[admin], bookmarkers=[user1, user2], public=True)
channel2 = create_channel("Ricecooker Channel", DESCRIPTION, editors=[admin, user1], bookmarkers=[user2], viewers=[user3])
channel3 = create_channel("Empty Channel", editors=[user3], viewers=[user2])
channel4 = create_channel("Imported Channel", editors=[admin])

# Invite admin to channel 3
try:
invitation, _new = Invitation.objects.get_or_create(
invited=admin,
sender=user3,
channel=channel3,
email=admin.email,
)
invitation.share_mode = "edit"
invitation.save()
except MultipleObjectsReturned:
# we don't care, just continue
pass

# Create pool of tags
tags = []
for t in TAGS:
tag, _new = ContentTag.objects.get_or_create(tag_name=t, channel=channel1)

# Generate file objects
document_file = create_file("Sample Document", format_presets.DOCUMENT, file_formats.PDF, user=admin)
video_file = create_file("Sample Video", format_presets.VIDEO_HIGH_RES, file_formats.MP4, user=admin)
subtitle_file = create_file("Sample Subtitle", format_presets.VIDEO_SUBTITLE, file_formats.VTT, user=admin)
audio_file = create_file("Sample Audio", format_presets.AUDIO, file_formats.MP3, user=admin)
html5_file = create_file("Sample HTML", format_presets.HTML5_ZIP, file_formats.HTML5, user=admin)

# Populate channel 1 with content
generate_tree(channel1.main_tree, document_file, video_file, subtitle_file, audio_file, html5_file, user=admin, tags=tags)

# Populate channel 2 with staged content
channel2.ricecooker_version = "0.0.0"
channel2.save()
generate_tree(channel2.staging_tree, document_file, video_file, subtitle_file, audio_file, html5_file, user=admin, tags=tags)

# Import content from channel 1 into channel 4
channel1.main_tree.children.first().copy_to(channel4.main_tree)

# Get validation to be reflected in nodes properly
ContentNode.objects.all().update(complete=True)
call_command('mark_incomplete')

# Mark this node as incomplete even though it is complete
# for testing purposes
node = ContentNode.objects.get(tree_id=channel1.main_tree.tree_id, title="Sample Audio")
node.complete = False
node.save()

# Publish
publish_channel(admin.id, channel1.pk)

# Add nodes to clipboard in legacy way
legacy_clipboard_nodes = channel1.main_tree.get_children()
for legacy_node in legacy_clipboard_nodes:
legacy_node.copy_to(target=user1.clipboard_tree)
# Only create additional data when clean-data-state is False (i.e. default behaviour).
if options["clean_data_state"] is False:
# Create channels
channel1 = create_channel("Published Channel", DESCRIPTION, editors=[admin], bookmarkers=[user1, user2], public=True)
channel2 = create_channel("Ricecooker Channel", DESCRIPTION, editors=[admin, user1], bookmarkers=[user2], viewers=[user3])
channel3 = create_channel("Empty Channel", editors=[user3], viewers=[user2])
channel4 = create_channel("Imported Channel", editors=[admin])

# Invite admin to channel 3
try:
invitation, _new = Invitation.objects.get_or_create(
invited=admin,
sender=user3,
channel=channel3,
email=admin.email,
)
invitation.share_mode = "edit"
invitation.save()
except MultipleObjectsReturned:
# we don't care, just continue
pass

# Create pool of tags
tags = []
for t in TAGS:
tag, _new = ContentTag.objects.get_or_create(tag_name=t, channel=channel1)

# Generate file objects
document_file = create_file("Sample Document", format_presets.DOCUMENT, file_formats.PDF, user=admin)
video_file = create_file("Sample Video", format_presets.VIDEO_HIGH_RES, file_formats.MP4, user=admin)
subtitle_file = create_file("Sample Subtitle", format_presets.VIDEO_SUBTITLE, file_formats.VTT, user=admin)
audio_file = create_file("Sample Audio", format_presets.AUDIO, file_formats.MP3, user=admin)
html5_file = create_file("Sample HTML", format_presets.HTML5_ZIP, file_formats.HTML5, user=admin)

# Populate channel 1 with content
generate_tree(channel1.main_tree, document_file, video_file, subtitle_file, audio_file, html5_file, user=admin, tags=tags)

# Populate channel 2 with staged content
channel2.ricecooker_version = "0.0.0"
channel2.save()
generate_tree(channel2.staging_tree, document_file, video_file, subtitle_file, audio_file, html5_file, user=admin, tags=tags)

# Import content from channel 1 into channel 4
channel1.main_tree.children.first().copy_to(channel4.main_tree)

# Get validation to be reflected in nodes properly
ContentNode.objects.all().update(complete=True)
call_command('mark_incomplete')

# Mark this node as incomplete even though it is complete
# for testing purposes
node = ContentNode.objects.get(tree_id=channel1.main_tree.tree_id, title="Sample Audio")
node.complete = False
node.save()

# Publish
publish_channel(admin.id, channel1.pk)

# Add nodes to clipboard in legacy way
legacy_clipboard_nodes = channel1.main_tree.get_children()
for legacy_node in legacy_clipboard_nodes:
legacy_node.copy_to(target=user1.clipboard_tree)

print("\n\n\nSETUP DONE: Log in as admin to view data (email: {}, password: {})\n\n\n".format(email, password))

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Generated by Django 3.2.19 on 2023-11-08 09:55
import uuid

import django.db.models.deletion
import pgvector.django
from django.conf import settings
from django.db import migrations
from django.db import models

import contentcuration.models


class Migration(migrations.Migration):

dependencies = [
('contentcuration', '0146_drop_taskresult_fields'),
]

operations = [
migrations.CreateModel(
name='EmbeddingsContentNode',
fields=[
('cid', models.CharField(max_length=64, primary_key=True, serialize=False)),
('contentnode', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='node_cid', to='contentcuration.contentnode')),
],
),
migrations.CreateModel(
name='Embeddings',
fields=[
('embedding_id', contentcuration.models.UUIDField(default=uuid.uuid4, max_length=32, primary_key=True, serialize=False)),
('model', models.CharField(db_index=True, max_length=64)),
('embedding', pgvector.django.VectorField()),
('embedded_node', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='embeddings', to='contentcuration.embeddingscontentnode')),
],
),
]
19 changes: 19 additions & 0 deletions contentcuration/contentcuration/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from mptt.models import MPTTModel
from mptt.models import raise_if_unsaved
from mptt.models import TreeForeignKey
from pgvector.django import VectorField
from postmark.core import PMMailInactiveRecipientException
from postmark.core import PMMailUnauthorizedException
from rest_framework.authtoken.models import Token
Expand Down Expand Up @@ -1983,6 +1984,24 @@ class Meta:
]


class EmbeddingsContentNode(models.Model):
"""
A model that stores the canonical contentnode for embedding purposes.
"""
cid = models.CharField(primary_key=True, max_length=64)
contentnode = models.ForeignKey(ContentNode, related_name="node_cid", blank=False, null=False, on_delete=models.CASCADE)


class Embeddings(models.Model):
"""
A model that stores generated embeddings.
"""
embedding_id = UUIDField(primary_key=True, default=uuid.uuid4)
model = models.CharField(max_length=64, db_index=True)
embedded_node = models.ForeignKey(EmbeddingsContentNode, related_name="embeddings", blank=False, null=False, on_delete=models.CASCADE)
embedding = VectorField()
vkWeb marked this conversation as resolved.
Show resolved Hide resolved


class ContentKind(models.Model):
kind = models.CharField(primary_key=True, max_length=200, choices=content_kinds.choices)

Expand Down
2 changes: 2 additions & 0 deletions contentcuration/contentcuration/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@
INSTALLED_APPS += ("django_concurrent_tests",) # noqa F405

MANAGE_PY_PATH = "./contentcuration/manage.py"

DATABASES["default"]["ENGINE"] = "contentcuration.tests.custom_pytest_db_backend" # noqa
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from django.db.backends.postgresql.base import DatabaseWrapper as PostgresDatabaseWrapper
from django.db.backends.postgresql.creation import DatabaseCreation as PostgresDatabaseCreation


class CustomDBCreationForPytest(PostgresDatabaseCreation):
"""
Overriding creation module to enable postgres pgvector extension before
pytest runs migration. Because embeddings table rely on pgvector
extension to work.
"""

def _create_test_db(self, verbosity, autoclobber, keepdb=False):
from contentcuration.management.commands.setup import enable_pgvector_extension

# Create test database and get its name.
test_db_name = super()._create_test_db(verbosity, autoclobber, keepdb)
vkWeb marked this conversation as resolved.
Show resolved Hide resolved

# Close current nodb_cursor connection and point connection to the
# newly created test database.
self.connection.close()
self.connection.settings_dict["NAME"] = test_db_name

# Enable pgvector extension.
enable_pgvector_extension(self.connection)

return self.connection.settings_dict["NAME"]


class DatabaseWrapper(PostgresDatabaseWrapper):
"""
A database wrapper to customise creation module. Rest everything is same as
Postgres.
"""
creation_class = CustomDBCreationForPytest
49 changes: 49 additions & 0 deletions contentcuration/contentcuration/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from contentcuration.models import ChannelSet
from contentcuration.models import ContentNode
from contentcuration.models import CONTENTNODE_TREE_ID_CACHE_KEY
from contentcuration.models import Embeddings
from contentcuration.models import EmbeddingsContentNode
from contentcuration.models import File
from contentcuration.models import FILE_DURATION_CONSTRAINT
from contentcuration.models import generate_object_storage_name
Expand Down Expand Up @@ -981,3 +983,50 @@ def test_prune(self):
ChannelHistory.prune()
self.assertEqual(2, ChannelHistory.objects.count())
self.assertEqual(2, ChannelHistory.objects.filter(id__in=last_history_ids).count())


class EmbeddingsTestCase(StudioTestCase):
persist_bucket = True

@classmethod
def setUpClass(cls):
super(EmbeddingsTestCase, cls).setUpClass()
cls.create_bucket()
node_1 = testdata.node({
"kind_id": "video",
"title": "first"
})
node_2 = testdata.node({
"kind_id": "video",
"title": "second"
})
node_3 = testdata.node({
"kind_id": "video",
"title": "third"
})

embedded_node_1 = EmbeddingsContentNode.objects.create(cid=node_1.content_id, contentnode=node_1)
embedded_node_2 = EmbeddingsContentNode.objects.create(cid=node_2.content_id, contentnode=node_2)
embedded_node_3 = EmbeddingsContentNode.objects.create(cid=node_3.content_id, contentnode=node_3)

# Two closely placed vectors i.e. they are similar.
Embeddings.objects.create(model="studio-embedder-v1.0", embedded_node=embedded_node_1, embedding=[2, 3])
Embeddings.objects.create(model="studio-embedder-v1.0", embedded_node=embedded_node_2, embedding=[2, 2])
# A vector placed at far distance i.e. not similar to above vectors.
Embeddings.objects.create(model="studio-embedder-v1.0", embedded_node=embedded_node_3, embedding=[4, 1])

@classmethod
def tearDownClass(cls):
super(EmbeddingsTestCase, cls).tearDownClass()
cls.delete_bucket()

def test_can_create_embeddings(self):
embeddings_count = Embeddings.objects.count()
self.assertEqual(embeddings_count, 3)

def test_get_nearest_neighbors(self):
from pgvector.django import L2Distance
import numpy as np
# Get the nearest neighbor of [2, 3] which is [2, 2].
closest_embedding = list(Embeddings.objects.order_by(L2Distance('embedding', [2, 3]))[1:2])[0]
self.assertTrue(np.array_equal(closest_embedding.embedding, np.array([2, 2])))
2 changes: 1 addition & 1 deletion docs/host_services_setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ yarn run services
```

## Initializing Studio
With the services running, in a separate terminal/terminal-tab, we can now initialize the database for Studio development purposes. The command below will initialize the database, in addition to adding a user account for development:
With the services running, in a separate terminal/terminal-tab, we can now initialize the database for Studio development purposes. The command below will initialize the database then it will enable the required postgres extensions in addition to adding a user account for development:
```bash
yarn run devsetup
```
Expand Down
2 changes: 1 addition & 1 deletion docs/local_dev_docker.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ make dcservicesdown
```

## Initializing Studio
With the services running, in a separate terminal/terminal-tab, we can now initialize the database for Studio development purposes. The command below will initialize the database tables, import constants, and a user account for development:
With the services running, in a separate terminal/terminal-tab, we can now initialize the database for Studio development purposes. The command below will initialize the database tables, import constants, enable required postgres extensions and a studio user account for development:
```bash
yarn run devsetup
```
Expand Down
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"build": "webpack --env prod --config webpack.config.js",
"postgres": "pg_ctl -D /usr/local/var/[email protected] start || true",
"redis": "redis-server /usr/local/etc/redis.conf || true",
"devsetup": "cd contentcuration && python manage.py setup --settings=contentcuration.dev_settings",
"devsetup": "python contentcuration/manage.py setup --settings=contentcuration.dev_settings",
"devsetup:clean": "python contentcuration/manage.py setup --clean-data-state --settings=contentcuration.dev_settings",
"services": "npm-run-all -c --parallel --silent celery minio redis postgres",
"test": "jest --config jest_config/jest.conf.js",
"build:dev": "webpack serve --env dev --config webpack.config.js --progress",
Expand Down
1 change: 1 addition & 0 deletions requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@ pillow==10.2.0
python-dateutil>=2.8.1
jsonschema>=3.2.0
django-celery-results
pgvector
Loading
Loading