Skip to content

Commit

Permalink
Merge pull request #4322 from vkWeb/embedds-model
Browse files Browse the repository at this point in the history
Welcome Embeddings 🚀
  • Loading branch information
bjester authored Mar 12, 2024
2 parents 91dfec0 + c999814 commit 3861ca1
Show file tree
Hide file tree
Showing 12 changed files with 231 additions and 68 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/pythontest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ jobs:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres:12
image: ghcr.io/learningequality/postgres:${{ github.base_ref || github.ref_name }}
credentials:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
# Provide the password for postgres
env:
POSTGRES_USER: learningequality
Expand Down
137 changes: 76 additions & 61 deletions contentcuration/contentcuration/management/commands/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

from django.core.management import call_command
from django.core.management.base import BaseCommand
from django.db import connection
from django.db import Error as DBError
from le_utils.constants import content_kinds
from le_utils.constants import file_formats
from le_utils.constants import format_presets
from le_utils.constants import licenses
from pgvector.django import VectorExtension

from contentcuration.models import ContentNode
from contentcuration.models import ContentTag
Expand Down Expand Up @@ -40,11 +42,20 @@
SORT_ORDER = 0


def enable_pgvector_extension(connection):
"""
Enables pgvector extension in postgres.
"""
with connection.cursor() as cursor:
cursor.execute("CREATE EXTENSION IF NOT EXISTS %s;" % VectorExtension().name)


class Command(BaseCommand):

def add_arguments(self, parser):
parser.add_argument('--email', dest="email", default="[email protected]")
parser.add_argument('--password', dest="password", default="a")
parser.add_argument('--clean-data-state', action='store_true', default=False, help='Sets database in clean state.')

def handle(self, *args, **options):
# Validate email
Expand All @@ -65,6 +76,9 @@ def handle(self, *args, **options):
except DBError as e:
logging.error('Error creating cache table: {}'.format(str(e)))

# Enable pgvector extension.
enable_pgvector_extension(connection)

# Run migrations
call_command('migrate')

Expand All @@ -79,67 +93,68 @@ def handle(self, *args, **options):
user2 = create_user("[email protected]", "b", "User", "B")
user3 = create_user("[email protected]", "c", "User", "C")

# Create channels

channel1 = create_channel("Published Channel", DESCRIPTION, editors=[admin], bookmarkers=[user1, user2], public=True)
channel2 = create_channel("Ricecooker Channel", DESCRIPTION, editors=[admin, user1], bookmarkers=[user2], viewers=[user3])
channel3 = create_channel("Empty Channel", editors=[user3], viewers=[user2])
channel4 = create_channel("Imported Channel", editors=[admin])

# Invite admin to channel 3
try:
invitation, _new = Invitation.objects.get_or_create(
invited=admin,
sender=user3,
channel=channel3,
email=admin.email,
)
invitation.share_mode = "edit"
invitation.save()
except MultipleObjectsReturned:
# we don't care, just continue
pass

# Create pool of tags
tags = []
for t in TAGS:
tag, _new = ContentTag.objects.get_or_create(tag_name=t, channel=channel1)

# Generate file objects
document_file = create_file("Sample Document", format_presets.DOCUMENT, file_formats.PDF, user=admin)
video_file = create_file("Sample Video", format_presets.VIDEO_HIGH_RES, file_formats.MP4, user=admin)
subtitle_file = create_file("Sample Subtitle", format_presets.VIDEO_SUBTITLE, file_formats.VTT, user=admin)
audio_file = create_file("Sample Audio", format_presets.AUDIO, file_formats.MP3, user=admin)
html5_file = create_file("Sample HTML", format_presets.HTML5_ZIP, file_formats.HTML5, user=admin)

# Populate channel 1 with content
generate_tree(channel1.main_tree, document_file, video_file, subtitle_file, audio_file, html5_file, user=admin, tags=tags)

# Populate channel 2 with staged content
channel2.ricecooker_version = "0.0.0"
channel2.save()
generate_tree(channel2.staging_tree, document_file, video_file, subtitle_file, audio_file, html5_file, user=admin, tags=tags)

# Import content from channel 1 into channel 4
channel1.main_tree.children.first().copy_to(channel4.main_tree)

# Get validation to be reflected in nodes properly
ContentNode.objects.all().update(complete=True)
call_command('mark_incomplete')

# Mark this node as incomplete even though it is complete
# for testing purposes
node = ContentNode.objects.get(tree_id=channel1.main_tree.tree_id, title="Sample Audio")
node.complete = False
node.save()

# Publish
publish_channel(admin.id, channel1.pk)

# Add nodes to clipboard in legacy way
legacy_clipboard_nodes = channel1.main_tree.get_children()
for legacy_node in legacy_clipboard_nodes:
legacy_node.copy_to(target=user1.clipboard_tree)
# Only create additional data when clean-data-state is False (i.e. default behaviour).
if options["clean_data_state"] is False:
# Create channels
channel1 = create_channel("Published Channel", DESCRIPTION, editors=[admin], bookmarkers=[user1, user2], public=True)
channel2 = create_channel("Ricecooker Channel", DESCRIPTION, editors=[admin, user1], bookmarkers=[user2], viewers=[user3])
channel3 = create_channel("Empty Channel", editors=[user3], viewers=[user2])
channel4 = create_channel("Imported Channel", editors=[admin])

# Invite admin to channel 3
try:
invitation, _new = Invitation.objects.get_or_create(
invited=admin,
sender=user3,
channel=channel3,
email=admin.email,
)
invitation.share_mode = "edit"
invitation.save()
except MultipleObjectsReturned:
# we don't care, just continue
pass

# Create pool of tags
tags = []
for t in TAGS:
tag, _new = ContentTag.objects.get_or_create(tag_name=t, channel=channel1)

# Generate file objects
document_file = create_file("Sample Document", format_presets.DOCUMENT, file_formats.PDF, user=admin)
video_file = create_file("Sample Video", format_presets.VIDEO_HIGH_RES, file_formats.MP4, user=admin)
subtitle_file = create_file("Sample Subtitle", format_presets.VIDEO_SUBTITLE, file_formats.VTT, user=admin)
audio_file = create_file("Sample Audio", format_presets.AUDIO, file_formats.MP3, user=admin)
html5_file = create_file("Sample HTML", format_presets.HTML5_ZIP, file_formats.HTML5, user=admin)

# Populate channel 1 with content
generate_tree(channel1.main_tree, document_file, video_file, subtitle_file, audio_file, html5_file, user=admin, tags=tags)

# Populate channel 2 with staged content
channel2.ricecooker_version = "0.0.0"
channel2.save()
generate_tree(channel2.staging_tree, document_file, video_file, subtitle_file, audio_file, html5_file, user=admin, tags=tags)

# Import content from channel 1 into channel 4
channel1.main_tree.children.first().copy_to(channel4.main_tree)

# Get validation to be reflected in nodes properly
ContentNode.objects.all().update(complete=True)
call_command('mark_incomplete')

# Mark this node as incomplete even though it is complete
# for testing purposes
node = ContentNode.objects.get(tree_id=channel1.main_tree.tree_id, title="Sample Audio")
node.complete = False
node.save()

# Publish
publish_channel(admin.id, channel1.pk)

# Add nodes to clipboard in legacy way
legacy_clipboard_nodes = channel1.main_tree.get_children()
for legacy_node in legacy_clipboard_nodes:
legacy_node.copy_to(target=user1.clipboard_tree)

print("\n\n\nSETUP DONE: Log in as admin to view data (email: {}, password: {})\n\n\n".format(email, password))

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Generated by Django 3.2.19 on 2023-11-08 09:55
import uuid

import django.db.models.deletion
import pgvector.django
from django.conf import settings
from django.db import migrations
from django.db import models

import contentcuration.models


class Migration(migrations.Migration):

dependencies = [
('contentcuration', '0146_drop_taskresult_fields'),
]

operations = [
migrations.CreateModel(
name='EmbeddingsContentNode',
fields=[
('cid', models.CharField(max_length=64, primary_key=True, serialize=False)),
('contentnode', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='node_cid', to='contentcuration.contentnode')),
],
),
migrations.CreateModel(
name='Embeddings',
fields=[
('embedding_id', contentcuration.models.UUIDField(default=uuid.uuid4, max_length=32, primary_key=True, serialize=False)),
('model', models.CharField(db_index=True, max_length=64)),
('embedding', pgvector.django.VectorField()),
('embedded_node', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='embeddings', to='contentcuration.embeddingscontentnode')),
],
),
]
19 changes: 19 additions & 0 deletions contentcuration/contentcuration/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from mptt.models import MPTTModel
from mptt.models import raise_if_unsaved
from mptt.models import TreeForeignKey
from pgvector.django import VectorField
from postmark.core import PMMailInactiveRecipientException
from postmark.core import PMMailUnauthorizedException
from rest_framework.authtoken.models import Token
Expand Down Expand Up @@ -1983,6 +1984,24 @@ class Meta:
]


class EmbeddingsContentNode(models.Model):
"""
A model that stores the canonical contentnode for embedding purposes.
"""
cid = models.CharField(primary_key=True, max_length=64)
contentnode = models.ForeignKey(ContentNode, related_name="node_cid", blank=False, null=False, on_delete=models.CASCADE)


class Embeddings(models.Model):
"""
A model that stores generated embeddings.
"""
embedding_id = UUIDField(primary_key=True, default=uuid.uuid4)
model = models.CharField(max_length=64, db_index=True)
embedded_node = models.ForeignKey(EmbeddingsContentNode, related_name="embeddings", blank=False, null=False, on_delete=models.CASCADE)
embedding = VectorField()


class ContentKind(models.Model):
kind = models.CharField(primary_key=True, max_length=200, choices=content_kinds.choices)

Expand Down
2 changes: 2 additions & 0 deletions contentcuration/contentcuration/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@
INSTALLED_APPS += ("django_concurrent_tests",) # noqa F405

MANAGE_PY_PATH = "./contentcuration/manage.py"

DATABASES["default"]["ENGINE"] = "contentcuration.tests.custom_pytest_db_backend" # noqa
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from django.db.backends.postgresql.base import DatabaseWrapper as PostgresDatabaseWrapper
from django.db.backends.postgresql.creation import DatabaseCreation as PostgresDatabaseCreation


class CustomDBCreationForPytest(PostgresDatabaseCreation):
"""
Overriding creation module to enable postgres pgvector extension before
pytest runs migration. Because embeddings table rely on pgvector
extension to work.
"""

def _create_test_db(self, verbosity, autoclobber, keepdb=False):
from contentcuration.management.commands.setup import enable_pgvector_extension

# Create test database and get its name.
test_db_name = super()._create_test_db(verbosity, autoclobber, keepdb)

# Close current nodb_cursor connection and point connection to the
# newly created test database.
self.connection.close()
self.connection.settings_dict["NAME"] = test_db_name

# Enable pgvector extension.
enable_pgvector_extension(self.connection)

return self.connection.settings_dict["NAME"]


class DatabaseWrapper(PostgresDatabaseWrapper):
"""
A database wrapper to customise creation module. Rest everything is same as
Postgres.
"""
creation_class = CustomDBCreationForPytest
49 changes: 49 additions & 0 deletions contentcuration/contentcuration/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from contentcuration.models import ChannelSet
from contentcuration.models import ContentNode
from contentcuration.models import CONTENTNODE_TREE_ID_CACHE_KEY
from contentcuration.models import Embeddings
from contentcuration.models import EmbeddingsContentNode
from contentcuration.models import File
from contentcuration.models import FILE_DURATION_CONSTRAINT
from contentcuration.models import generate_object_storage_name
Expand Down Expand Up @@ -981,3 +983,50 @@ def test_prune(self):
ChannelHistory.prune()
self.assertEqual(2, ChannelHistory.objects.count())
self.assertEqual(2, ChannelHistory.objects.filter(id__in=last_history_ids).count())


class EmbeddingsTestCase(StudioTestCase):
persist_bucket = True

@classmethod
def setUpClass(cls):
super(EmbeddingsTestCase, cls).setUpClass()
cls.create_bucket()
node_1 = testdata.node({
"kind_id": "video",
"title": "first"
})
node_2 = testdata.node({
"kind_id": "video",
"title": "second"
})
node_3 = testdata.node({
"kind_id": "video",
"title": "third"
})

embedded_node_1 = EmbeddingsContentNode.objects.create(cid=node_1.content_id, contentnode=node_1)
embedded_node_2 = EmbeddingsContentNode.objects.create(cid=node_2.content_id, contentnode=node_2)
embedded_node_3 = EmbeddingsContentNode.objects.create(cid=node_3.content_id, contentnode=node_3)

# Two closely placed vectors i.e. they are similar.
Embeddings.objects.create(model="studio-embedder-v1.0", embedded_node=embedded_node_1, embedding=[2, 3])
Embeddings.objects.create(model="studio-embedder-v1.0", embedded_node=embedded_node_2, embedding=[2, 2])
# A vector placed at far distance i.e. not similar to above vectors.
Embeddings.objects.create(model="studio-embedder-v1.0", embedded_node=embedded_node_3, embedding=[4, 1])

@classmethod
def tearDownClass(cls):
super(EmbeddingsTestCase, cls).tearDownClass()
cls.delete_bucket()

def test_can_create_embeddings(self):
embeddings_count = Embeddings.objects.count()
self.assertEqual(embeddings_count, 3)

def test_get_nearest_neighbors(self):
from pgvector.django import L2Distance
import numpy as np
# Get the nearest neighbor of [2, 3] which is [2, 2].
closest_embedding = list(Embeddings.objects.order_by(L2Distance('embedding', [2, 3]))[1:2])[0]
self.assertTrue(np.array_equal(closest_embedding.embedding, np.array([2, 2])))
2 changes: 1 addition & 1 deletion docs/host_services_setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ yarn run services
```

## Initializing Studio
With the services running, in a separate terminal/terminal-tab, we can now initialize the database for Studio development purposes. The command below will initialize the database, in addition to adding a user account for development:
With the services running, in a separate terminal/terminal-tab, we can now initialize the database for Studio development purposes. The command below will initialize the database then it will enable the required postgres extensions in addition to adding a user account for development:
```bash
yarn run devsetup
```
Expand Down
2 changes: 1 addition & 1 deletion docs/local_dev_docker.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ make dcservicesdown
```

## Initializing Studio
With the services running, in a separate terminal/terminal-tab, we can now initialize the database for Studio development purposes. The command below will initialize the database tables, import constants, and a user account for development:
With the services running, in a separate terminal/terminal-tab, we can now initialize the database for Studio development purposes. The command below will initialize the database tables, import constants, enable required postgres extensions and a studio user account for development:
```bash
yarn run devsetup
```
Expand Down
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"build": "webpack --env prod --config webpack.config.js",
"postgres": "pg_ctl -D /usr/local/var/[email protected] start || true",
"redis": "redis-server /usr/local/etc/redis.conf || true",
"devsetup": "cd contentcuration && python manage.py setup --settings=contentcuration.dev_settings",
"devsetup": "python contentcuration/manage.py setup --settings=contentcuration.dev_settings",
"devsetup:clean": "python contentcuration/manage.py setup --clean-data-state --settings=contentcuration.dev_settings",
"services": "npm-run-all -c --parallel --silent celery minio redis postgres",
"test": "jest --config jest_config/jest.conf.js",
"build:dev": "webpack serve --env dev --config webpack.config.js --progress",
Expand Down
1 change: 1 addition & 0 deletions requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@ pillow==10.2.0
python-dateutil>=2.8.1
jsonschema>=3.2.0
django-celery-results
pgvector
Loading

0 comments on commit 3861ca1

Please sign in to comment.