diff --git a/docs/source/reference/databases/reddit.rst b/docs/source/reference/databases/reddit.rst new file mode 100644 index 000000000..c25115e91 --- /dev/null +++ b/docs/source/reference/databases/reddit.rst @@ -0,0 +1,48 @@ +Reddit +========== + +The connection to Reddit is based on the `praw `_ library. + +Dependency +---------- + +* praw + + +Parameters +---------- + +Required: + +* ``subreddit`` is the name of the subreddit from which the data is fetched. +* ``clientId`` is the unique identifier issued to the client when creating credentials on Reddit. Refer to the [First Steps](https://github.com/reddit-archive/reddit/wiki/OAuth2-Quick-Start-Example#first-steps) guide for more details on how to get this and the next two parameters. +* ``clientSecret`` is the secret key obtained when credentials are created that is used for authentication and authorization. +* ``userAgent`` is a string of your choosing that explains your use of the the Reddit API. More details are available in the guide linked above. + +Optional: + + +Create Connection +----------------- + +.. code-block:: text + + CREATE DATABASE reddit_data WITH ENGINE = 'reddit', PARAMETERS = { + "subreddit": "AskReddit", + "client_id": "abcd", + "clientSecret": "abcd1234", + "userAgent": "Eva DB Staging Build" + }; + +Supported Tables +---------------- + +* ``submissions``: Lists top submissions in the given subreddit. Check `databases/reddit/table_column_info.py` for all the available columns in the table. + +.. code-block:: sql + + SELECT * FROM hackernews_data.search_results LIMIT 3; + +.. note:: + + Looking for another table from Hackernews? Please raise a `Feature Request `_. diff --git a/evadb/third_party/databases/interface.py b/evadb/third_party/databases/interface.py index cacb4110f..3743ebee7 100644 --- a/evadb/third_party/databases/interface.py +++ b/evadb/third_party/databases/interface.py @@ -52,6 +52,8 @@ def _get_database_handler(engine: str, **kwargs): return mod.HackernewsSearchHandler(engine, **kwargs) elif engine == "slack": return mod.SlackHandler(engine, **kwargs) + elif engine == "reddit": + return mod.RedditHandler(engine, **kwargs) else: raise NotImplementedError(f"Engine {engine} is not supported") diff --git a/evadb/third_party/databases/reddit/__init__.py b/evadb/third_party/databases/reddit/__init__.py new file mode 100644 index 000000000..01c7e9e97 --- /dev/null +++ b/evadb/third_party/databases/reddit/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""third party/applications/reddit""" diff --git a/evadb/third_party/databases/reddit/reddit_handler.py b/evadb/third_party/databases/reddit/reddit_handler.py new file mode 100644 index 000000000..3450e3c5b --- /dev/null +++ b/evadb/third_party/databases/reddit/reddit_handler.py @@ -0,0 +1,170 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +from praw import Reddit +from prawcore import ResponseException + +from ..types import DBHandler, DBHandlerResponse, DBHandlerStatus +from .table_column_info import SUBMISSION_COLUMNS + + +class RedditHandler(DBHandler): + def __init__(self, name: str, **kwargs): + super().__init__(name) + self.clientId = kwargs.get("client_id") + self.clientSecret = kwargs.get("clientSecret") + self.userAgent = kwargs.get("userAgent") + self.subreddit = kwargs.get("subreddit") + + def connect(self): + try: + self.client = Reddit( + client_id=self.clientId, + client_secret=self.clientSecret, + user_agent=self.userAgent, + ) + return DBHandlerStatus(status=True) + except Exception as e: + return DBHandlerStatus(status=False, error=str(e)) + + @property + def supported_table(self): + def _submission_generator(): + for submission in self.client.subreddit(self.subreddit).hot(): + yield { + property_name: getattr(submission, property_name) + for property_name, _ in SUBMISSION_COLUMNS + } + + mapping = { + "submissions": { + "columns": SUBMISSION_COLUMNS, + "generator": _submission_generator(), + }, + } + return mapping + + def disconnect(self): + """ + No action required to disconnect from Reddit datasource + TODO: Add support for destroying session token if used in other flows + """ + return + # raise NotImplementedError() + + def check_connection(self) -> DBHandlerStatus: + try: + self.client.user.me() + except ResponseException as e: + return DBHandlerStatus( + status=False, error=f"Received ResponseException: {e.response}" + ) + return DBHandlerStatus(status=True) + + def get_tables(self) -> DBHandlerResponse: + connection_status = self.check_connection() + if not connection_status.status: + return DBHandlerResponse(data=None, error=str(connection_status)) + + try: + tables_df = pd.DataFrame( + list(self.supported_table.keys()), columns=["table_name"] + ) + return DBHandlerResponse(data=tables_df) + except Exception as e: + return DBHandlerResponse(data=None, error=str(e)) + + def get_columns(self, table_name: str) -> DBHandlerResponse: + columns = self.supported_table[table_name]["columns"] + columns_df = pd.DataFrame(columns, columns=["name", "dtype"]) + return DBHandlerResponse(data=columns_df) + + def select(self, table_name: str) -> DBHandlerResponse: + """ + Returns a generator that yields the data from the given table. + Args: + table_name (str): name of the table whose data is to be retrieved. + Returns: + DBHandlerResponse + """ + if not self.client: + return DBHandlerResponse(data=None, error="Not connected to the database.") + try: + if table_name not in self.supported_table: + return DBHandlerResponse( + data=None, + error="{} is not supported or does not exist.".format(table_name), + ) + # TODO: Projection column trimming optimization opportunity + return DBHandlerResponse( + data=None, + data_generator=self.supported_table[table_name]["generator"], + ) + except Exception as e: + return DBHandlerResponse(data=None, error=str(e)) + + # def post_message(self, message) -> DBHandlerResponse: + # try: + # response = self.client.chat_postMessage(channel=self.channel, text=message) + # return DBHandlerResponse(data=response["message"]["text"]) + # except SlackApiError as e: + # assert e.response["ok"] is False + # assert e.response["error"] + # return DBHandlerResponse(data=None, error=e.response["error"]) + # + # def _convert_json_response_to_DataFrame(self, json_response): + # messages = json_response["messages"] + # columns = ["text", "ts", "user"] + # data_df = pd.DataFrame(columns=columns) + # for message in messages: + # if message["text"] and message["ts"] and message["user"]: + # data_df.loc[len(data_df.index)] = [ + # message["text"], + # message["ts"], + # message["user"], + # ] + # return data_df + # + # def get_messages(self) -> DBHandlerResponse: + # try: + # channels = self.client.conversations_list( + # types="public_channel,private_channel" + # )["channels"] + # channel_ids = {c["name"]: c["id"] for c in channels} + # response = self.client.conversations_history( + # channel=channel_ids[self.channel_name] + # ) + # data_df = self._convert_json_response_to_DataFrame(response) + # return data_df + # + # except SlackApiError as e: + # assert e.response["ok"] is False + # assert e.response["error"] + # return DBHandlerResponse(data=None, error=e.response["error"]) + # + # def del_message(self, timestamp) -> DBHandlerResponse: + # try: + # self.client.chat_delete(channel=self.channel, ts=timestamp) + # except SlackApiError as e: + # assert e.response["ok"] is False + # assert e.response["error"] + # return DBHandlerResponse(data=None, error=e.response["error"]) + + # def execute_native_query(self, query_string: str) -> DBHandlerResponse: + # """ + # TODO: integrate code for executing query on Reddit + # """ + # raise NotImplementedError() diff --git a/evadb/third_party/databases/reddit/table_column_info.py b/evadb/third_party/databases/reddit/table_column_info.py new file mode 100644 index 000000000..6dd819b6e --- /dev/null +++ b/evadb/third_party/databases/reddit/table_column_info.py @@ -0,0 +1,41 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union + +SUBMISSION_COLUMNS = [ + ["author", str], + ["author_flair_text", Union[str, None]], + ["clicked", bool], + ["created_utc", str], + ["distinguished", bool], + ["edited", bool], + ["id", str], + ["is_original_content", bool], + ["is_self", bool], + ["link_flair_text", Union[str, None]], + ["locked", bool], + ["name", str], + ["num_comments", int], + ["over_18", bool], + ["permalink", str], + ["saved", bool], + ["score", float], + ["selftext", str], + ["spoiler", bool], + ["stickied", bool], + ["title", str], + ["upvote_ratio", float], + ["url", str], +] diff --git a/run_reddit_command.py b/run_reddit_command.py new file mode 100644 index 000000000..e69de29bb diff --git a/script/formatting/spelling.txt b/script/formatting/spelling.txt index 1dd5566ca..0444f935b 100644 --- a/script/formatting/spelling.txt +++ b/script/formatting/spelling.txt @@ -695,6 +695,7 @@ PlanOprType Popen PostgresHandler PostgresNativeStorageEngineTest +praw PredicateExecutor PredicatePlan PredictEmployee diff --git a/setup.py b/setup.py index e3d211ece..be7bdbdf9 100644 --- a/setup.py +++ b/setup.py @@ -138,6 +138,10 @@ def read(path, encoding="utf-8"): "replicate" ] +reddit_libs = [ + "praw" +] + ### NEEDED FOR DEVELOPER TESTING ONLY dev_libs = [ @@ -183,8 +187,9 @@ def read(path, encoding="utf-8"): "xgboost": xgboost_libs, "forecasting": forecasting_libs, "hackernews": hackernews_libs, + "reddit": reddit_libs, # everything except ray, qdrant, ludwig and postgres. The first three fail on pyhton 3.11. - "dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs + xgboost_libs + "dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs + xgboost_libs + reddit_libs } setup( diff --git a/test/integration_tests/long/test_reddit_datasource.py b/test/integration_tests/long/test_reddit_datasource.py new file mode 100644 index 000000000..4ba41086a --- /dev/null +++ b/test/integration_tests/long/test_reddit_datasource.py @@ -0,0 +1,59 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +from test.markers import reddit_skip_marker +from test.util import get_evadb_for_testing + +import pytest + +from evadb.server.command_handler import execute_query_fetch_all +from evadb.third_party.databases.reddit.table_column_info import SUBMISSION_COLUMNS + + +@pytest.mark.notparallel +class RedditDataSourceTest(unittest.TestCase): + def setUp(self): + self.evadb = get_evadb_for_testing() + # reset the catalog manager before running each test + self.evadb.catalog().reset() + + def tearDown(self): + execute_query_fetch_all(self.evadb, "DROP DATABASE IF EXISTS reddit_data;") + + @reddit_skip_marker + def test_should_run_select_query_on_reddit(self): + # Create database. + params = { + "subreddit": "cricket", + "client_id": "clientid..", + "client_secret": "clientsecret..", + "user_agent": "test script for dev eva", + } + query = f"""CREATE DATABASE reddit_data + WITH ENGINE = "reddit", + PARAMETERS = {params};""" + execute_query_fetch_all(self.evadb, query) + + query = "SELECT * FROM reddit_data.submissions LIMIT 10;" + batch = execute_query_fetch_all(self.evadb, query) + self.assertEqual(len(batch), 10) + expected_column = list( + ["submissions.{}".format(col) for col, _ in SUBMISSION_COLUMNS] + ) + self.assertEqual(batch.columns, expected_column) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/markers.py b/test/markers.py index deefadb29..070b885fe 100644 --- a/test/markers.py +++ b/test/markers.py @@ -117,3 +117,5 @@ stable_diffusion_skip_marker = pytest.mark.skipif( is_replicate_available() is False, reason="requires replicate" ) + +reddit_skip_marker = pytest.mark.skip(reason="requires Reddit secret key")