diff --git a/docs/source/reference/databases/reddit.rst b/docs/source/reference/databases/reddit.rst
new file mode 100644
index 000000000..c25115e91
--- /dev/null
+++ b/docs/source/reference/databases/reddit.rst
@@ -0,0 +1,48 @@
+Reddit
+==========
+
+The connection to Reddit is based on the `praw <https://praw.readthedocs.io/>`_ library.
+
+Dependency
+----------
+
+* praw
+
+
+Parameters
+----------
+
+Required:
+
+* ``subreddit`` is the name of the subreddit from which the data is fetched.
+* ``client_id`` is the unique identifier issued to the client when creating credentials on Reddit. Refer to the `First Steps <https://github.com/reddit-archive/reddit/wiki/OAuth2-Quick-Start-Example#first-steps>`_ guide for more details on how to get this and the next two parameters.
+* ``clientSecret`` is the secret key obtained when credentials are created that is used for authentication and authorization.
+* ``userAgent`` is a string of your choosing that explains your use of the Reddit API. More details are available in the guide linked above.
+
+Optional:
+
+
+Create Connection
+-----------------
+
+.. code-block:: text
+
+ CREATE DATABASE reddit_data WITH ENGINE = 'reddit', PARAMETERS = {
+ "subreddit": "AskReddit",
+ "client_id": "abcd",
+ "clientSecret": "abcd1234",
+ "userAgent": "Eva DB Staging Build"
+ };
+
+Supported Tables
+----------------
+
+* ``submissions``: Lists hot submissions in the given subreddit. Check ``evadb/third_party/databases/reddit/table_column_info.py`` for all the available columns in the table.
+
+.. code-block:: sql
+
+   SELECT * FROM reddit_data.submissions LIMIT 3;
+
+.. note::
+
+   Looking for another table from Reddit? Please raise a `Feature Request <https://github.com/georgia-tech-db/evadb/issues/new/choose>`_.
diff --git a/evadb/third_party/databases/interface.py b/evadb/third_party/databases/interface.py
index cacb4110f..3743ebee7 100644
--- a/evadb/third_party/databases/interface.py
+++ b/evadb/third_party/databases/interface.py
@@ -52,6 +52,8 @@ def _get_database_handler(engine: str, **kwargs):
return mod.HackernewsSearchHandler(engine, **kwargs)
elif engine == "slack":
return mod.SlackHandler(engine, **kwargs)
+ elif engine == "reddit":
+ return mod.RedditHandler(engine, **kwargs)
else:
raise NotImplementedError(f"Engine {engine} is not supported")
diff --git a/evadb/third_party/databases/reddit/__init__.py b/evadb/third_party/databases/reddit/__init__.py
new file mode 100644
index 000000000..01c7e9e97
--- /dev/null
+++ b/evadb/third_party/databases/reddit/__init__.py
@@ -0,0 +1,15 @@
+# coding=utf-8
+# Copyright 2018-2023 EvaDB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""third party/databases/reddit"""
diff --git a/evadb/third_party/databases/reddit/reddit_handler.py b/evadb/third_party/databases/reddit/reddit_handler.py
new file mode 100644
index 000000000..3450e3c5b
--- /dev/null
+++ b/evadb/third_party/databases/reddit/reddit_handler.py
@@ -0,0 +1,170 @@
+# coding=utf-8
+# Copyright 2018-2023 EvaDB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pandas as pd
+from praw import Reddit
+from prawcore import ResponseException
+
+from ..types import DBHandler, DBHandlerResponse, DBHandlerStatus
+from .table_column_info import SUBMISSION_COLUMNS
+
+
+class RedditHandler(DBHandler):
+ def __init__(self, name: str, **kwargs):
+ super().__init__(name)
+ self.clientId = kwargs.get("client_id")
+ self.clientSecret = kwargs.get("clientSecret")
+ self.userAgent = kwargs.get("userAgent")
+ self.subreddit = kwargs.get("subreddit")
+
+ def connect(self):
+ try:
+ self.client = Reddit(
+ client_id=self.clientId,
+ client_secret=self.clientSecret,
+ user_agent=self.userAgent,
+ )
+ return DBHandlerStatus(status=True)
+ except Exception as e:
+ return DBHandlerStatus(status=False, error=str(e))
+
+ @property
+ def supported_table(self):
+ def _submission_generator():
+ for submission in self.client.subreddit(self.subreddit).hot():
+ yield {
+ property_name: getattr(submission, property_name)
+ for property_name, _ in SUBMISSION_COLUMNS
+ }
+
+ mapping = {
+ "submissions": {
+ "columns": SUBMISSION_COLUMNS,
+ "generator": _submission_generator(),
+ },
+ }
+ return mapping
+
+ def disconnect(self):
+ """
+ No action required to disconnect from Reddit datasource
+ TODO: Add support for destroying session token if used in other flows
+ """
+ return
+ # raise NotImplementedError()
+
+ def check_connection(self) -> DBHandlerStatus:
+ try:
+ self.client.user.me()
+ except ResponseException as e:
+ return DBHandlerStatus(
+ status=False, error=f"Received ResponseException: {e.response}"
+ )
+ return DBHandlerStatus(status=True)
+
+ def get_tables(self) -> DBHandlerResponse:
+ connection_status = self.check_connection()
+ if not connection_status.status:
+ return DBHandlerResponse(data=None, error=str(connection_status))
+
+ try:
+ tables_df = pd.DataFrame(
+ list(self.supported_table.keys()), columns=["table_name"]
+ )
+ return DBHandlerResponse(data=tables_df)
+ except Exception as e:
+ return DBHandlerResponse(data=None, error=str(e))
+
+ def get_columns(self, table_name: str) -> DBHandlerResponse:
+ columns = self.supported_table[table_name]["columns"]
+ columns_df = pd.DataFrame(columns, columns=["name", "dtype"])
+ return DBHandlerResponse(data=columns_df)
+
+ def select(self, table_name: str) -> DBHandlerResponse:
+ """
+ Returns a generator that yields the data from the given table.
+ Args:
+ table_name (str): name of the table whose data is to be retrieved.
+ Returns:
+ DBHandlerResponse
+ """
+ if not self.client:
+ return DBHandlerResponse(data=None, error="Not connected to the database.")
+ try:
+ if table_name not in self.supported_table:
+ return DBHandlerResponse(
+ data=None,
+ error="{} is not supported or does not exist.".format(table_name),
+ )
+ # TODO: Projection column trimming optimization opportunity
+ return DBHandlerResponse(
+ data=None,
+ data_generator=self.supported_table[table_name]["generator"],
+ )
+ except Exception as e:
+ return DBHandlerResponse(data=None, error=str(e))
+
+ # def post_message(self, message) -> DBHandlerResponse:
+ # try:
+ # response = self.client.chat_postMessage(channel=self.channel, text=message)
+ # return DBHandlerResponse(data=response["message"]["text"])
+ # except SlackApiError as e:
+ # assert e.response["ok"] is False
+ # assert e.response["error"]
+ # return DBHandlerResponse(data=None, error=e.response["error"])
+ #
+ # def _convert_json_response_to_DataFrame(self, json_response):
+ # messages = json_response["messages"]
+ # columns = ["text", "ts", "user"]
+ # data_df = pd.DataFrame(columns=columns)
+ # for message in messages:
+ # if message["text"] and message["ts"] and message["user"]:
+ # data_df.loc[len(data_df.index)] = [
+ # message["text"],
+ # message["ts"],
+ # message["user"],
+ # ]
+ # return data_df
+ #
+ # def get_messages(self) -> DBHandlerResponse:
+ # try:
+ # channels = self.client.conversations_list(
+ # types="public_channel,private_channel"
+ # )["channels"]
+ # channel_ids = {c["name"]: c["id"] for c in channels}
+ # response = self.client.conversations_history(
+ # channel=channel_ids[self.channel_name]
+ # )
+ # data_df = self._convert_json_response_to_DataFrame(response)
+ # return data_df
+ #
+ # except SlackApiError as e:
+ # assert e.response["ok"] is False
+ # assert e.response["error"]
+ # return DBHandlerResponse(data=None, error=e.response["error"])
+ #
+ # def del_message(self, timestamp) -> DBHandlerResponse:
+ # try:
+ # self.client.chat_delete(channel=self.channel, ts=timestamp)
+ # except SlackApiError as e:
+ # assert e.response["ok"] is False
+ # assert e.response["error"]
+ # return DBHandlerResponse(data=None, error=e.response["error"])
+
+ # def execute_native_query(self, query_string: str) -> DBHandlerResponse:
+ # """
+ # TODO: integrate code for executing query on Reddit
+ # """
+ # raise NotImplementedError()
diff --git a/evadb/third_party/databases/reddit/table_column_info.py b/evadb/third_party/databases/reddit/table_column_info.py
new file mode 100644
index 000000000..6dd819b6e
--- /dev/null
+++ b/evadb/third_party/databases/reddit/table_column_info.py
@@ -0,0 +1,41 @@
+# coding=utf-8
+# Copyright 2018-2023 EvaDB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Union
+
+# Column specification of the ``submissions`` table: one [attribute name,
+# Python type] pair per column. RedditHandler reads each attribute name off a
+# praw submission with getattr() and reports the pairs as name/dtype columns.
+# NOTE(review): the praw documentation describes ``created_utc`` as Unix time
+# and ``score`` as an upvote count - confirm that str / float are the intended
+# declared types for those two columns.
+SUBMISSION_COLUMNS = [
+ ["author", str],
+ ["author_flair_text", Union[str, None]],
+ ["clicked", bool],
+ ["created_utc", str],
+ ["distinguished", bool],
+ ["edited", bool],
+ ["id", str],
+ ["is_original_content", bool],
+ ["is_self", bool],
+ ["link_flair_text", Union[str, None]],
+ ["locked", bool],
+ ["name", str],
+ ["num_comments", int],
+ ["over_18", bool],
+ ["permalink", str],
+ ["saved", bool],
+ ["score", float],
+ ["selftext", str],
+ ["spoiler", bool],
+ ["stickied", bool],
+ ["title", str],
+ ["upvote_ratio", float],
+ ["url", str],
+]
diff --git a/run_reddit_command.py b/run_reddit_command.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/script/formatting/spelling.txt b/script/formatting/spelling.txt
index 1dd5566ca..0444f935b 100644
--- a/script/formatting/spelling.txt
+++ b/script/formatting/spelling.txt
@@ -695,6 +695,7 @@ PlanOprType
Popen
PostgresHandler
PostgresNativeStorageEngineTest
+praw
PredicateExecutor
PredicatePlan
PredictEmployee
diff --git a/setup.py b/setup.py
index e3d211ece..be7bdbdf9 100644
--- a/setup.py
+++ b/setup.py
@@ -138,6 +138,10 @@ def read(path, encoding="utf-8"):
"replicate"
]
+reddit_libs = [
+ "praw"
+]
+
### NEEDED FOR DEVELOPER TESTING ONLY
dev_libs = [
@@ -183,8 +187,9 @@ def read(path, encoding="utf-8"):
"xgboost": xgboost_libs,
"forecasting": forecasting_libs,
"hackernews": hackernews_libs,
+ "reddit": reddit_libs,
# everything except ray, qdrant, ludwig and postgres. The first three fail on pyhton 3.11.
- "dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs + xgboost_libs
+ "dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs + xgboost_libs + reddit_libs
}
setup(
diff --git a/test/integration_tests/long/test_reddit_datasource.py b/test/integration_tests/long/test_reddit_datasource.py
new file mode 100644
index 000000000..4ba41086a
--- /dev/null
+++ b/test/integration_tests/long/test_reddit_datasource.py
@@ -0,0 +1,59 @@
+# coding=utf-8
+# Copyright 2018-2023 EvaDB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+from test.markers import reddit_skip_marker
+from test.util import get_evadb_for_testing
+
+import pytest
+
+from evadb.server.command_handler import execute_query_fetch_all
+from evadb.third_party.databases.reddit.table_column_info import SUBMISSION_COLUMNS
+
+
+@pytest.mark.notparallel
+class RedditDataSourceTest(unittest.TestCase):
+ def setUp(self):
+ self.evadb = get_evadb_for_testing()
+ # reset the catalog manager before running each test
+ self.evadb.catalog().reset()
+
+ def tearDown(self):
+ execute_query_fetch_all(self.evadb, "DROP DATABASE IF EXISTS reddit_data;")
+
+ @reddit_skip_marker
+ def test_should_run_select_query_on_reddit(self):
+ # Create database.
+ params = {
+ "subreddit": "cricket",
+ "client_id": "clientid..",
+ "client_secret": "clientsecret..",
+ "user_agent": "test script for dev eva",
+ }
+ query = f"""CREATE DATABASE reddit_data
+ WITH ENGINE = "reddit",
+ PARAMETERS = {params};"""
+ execute_query_fetch_all(self.evadb, query)
+
+ query = "SELECT * FROM reddit_data.submissions LIMIT 10;"
+ batch = execute_query_fetch_all(self.evadb, query)
+ self.assertEqual(len(batch), 10)
+ expected_column = list(
+ ["submissions.{}".format(col) for col, _ in SUBMISSION_COLUMNS]
+ )
+ self.assertEqual(batch.columns, expected_column)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/markers.py b/test/markers.py
index deefadb29..070b885fe 100644
--- a/test/markers.py
+++ b/test/markers.py
@@ -117,3 +117,5 @@
stable_diffusion_skip_marker = pytest.mark.skipif(
is_replicate_available() is False, reason="requires replicate"
)
+
+# Unconditional skip: the Reddit integration test needs a real Reddit secret
+# key, which is not available to automated runs.
+reddit_skip_marker = pytest.mark.skip(reason="requires Reddit secret key")