Revert "fix: performance improvements (#374)"
This reverts commit 11cc5d5.
michellyrds committed Dec 4, 2024
1 parent b802f69 commit bccd33c
Showing 6 changed files with 29 additions and 65 deletions.
14 changes: 2 additions & 12 deletions butterfree/_cli/migrate.py
@@ -4,7 +4,7 @@
 import os
 import pkgutil
 import sys
-from typing import Set, Type
+from typing import Set
 
 import boto3
 import setuptools
@@ -90,18 +90,8 @@ def __fs_objects(path: str) -> Set[FeatureSetPipeline]:
 
             instances.add(value)
 
-    def create_instance(cls: Type[FeatureSetPipeline]) -> FeatureSetPipeline:
-        sig = inspect.signature(cls.__init__)
-        parameters = sig.parameters
-
-        if "run_date" in parameters:
-            run_date = datetime.datetime.today().strftime("%Y-%m-%d")
-            return cls(run_date)
-
-        return cls()
-
     logger.info("Creating instances...")
-    return set(create_instance(value) for value in instances)  # type: ignore
+    return set(value() for value in instances)  # type: ignore
 
 
 PATH = typer.Argument(
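
For context, a minimal self-contained sketch of the two instantiation strategies the migrate.py hunk above switches between; PipelineWithRunDate and create_instance here are illustrative stand-ins, not butterfree code:

import datetime
import inspect


class PipelineWithRunDate:
    """Hypothetical stand-in for a discovered FeatureSetPipeline subclass."""

    def __init__(self, run_date: str):
        self.run_date = run_date


def create_instance(cls):
    # removed behaviour: pass today's date when __init__ accepts `run_date`
    if "run_date" in inspect.signature(cls.__init__).parameters:
        return cls(datetime.datetime.today().strftime("%Y-%m-%d"))
    return cls()


print(create_instance(PipelineWithRunDate).run_date)  # e.g. "2024-12-04"
# After the revert the CLI simply calls cls(), so a pipeline whose __init__
# requires run_date would raise TypeError during discovery.
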
14 changes: 4 additions & 10 deletions butterfree/extract/source.py
@@ -3,7 +3,6 @@
 from typing import List, Optional
 
 from pyspark.sql import DataFrame
-from pyspark.storagelevel import StorageLevel
 
 from butterfree.clients import SparkClient
 from butterfree.extract.readers.reader import Reader
@@ -96,21 +95,16 @@ def construct(
             DataFrame with the query result against all readers.
 
         """
-        # Step 1: Build temporary views for each reader
         for reader in self.readers:
-            reader.build(client=client, start_date=start_date, end_date=end_date)
+            reader.build(
+                client=client, start_date=start_date, end_date=end_date
+            )  # create temporary views for each reader
 
-        # Step 2: Execute SQL query on the combined readers
         dataframe = client.sql(self.query)
 
-        # Step 3: Cache the dataframe if necessary, using memory and disk storage
         if not dataframe.isStreaming and self.eager_evaluation:
-            # Persist to ensure the DataFrame is stored in mem and disk (if necessary)
-            dataframe.persist(StorageLevel.MEMORY_AND_DISK)
-            # Trigger the cache/persist operation by performing an action
-            dataframe.count()
+            dataframe.cache().count()
 
-        # Step 4: Run post-processing hooks on the dataframe
         post_hook_df = self.run_post_hooks(dataframe)
 
         return post_hook_df
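
The caching idioms swapped above behave similarly for batch DataFrames; a rough sketch of both, assuming a local SparkSession (not part of butterfree itself):

from pyspark.sql import SparkSession
from pyspark.storagelevel import StorageLevel

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.range(10)

# restored idiom: cache() persists with Spark's default storage level,
# and count() forces the plan to materialize
df.cache().count()

# removed idiom: choose the storage level explicitly, then trigger an action
df.unpersist()
df.persist(StorageLevel.MEMORY_AND_DISK).count()
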
20 changes: 4 additions & 16 deletions butterfree/pipelines/feature_set_pipeline.py
@@ -2,8 +2,6 @@
 
 from typing import List, Optional
 
-from pyspark.storagelevel import StorageLevel
-
 from butterfree.clients import SparkClient
 from butterfree.dataframe_service import repartition_sort_df
 from butterfree.extract import Source
@@ -211,45 +209,35 @@ def run(
                 soon. Use only if strictly necessary.
 
         """
-        # Step 1: Construct input dataframe from the source.
         dataframe = self.source.construct(
             client=self.spark_client,
             start_date=self.feature_set.define_start_date(start_date),
             end_date=end_date,
         )
 
-        # Step 2: Repartition and sort if required, avoid if not necessary.
         if partition_by:
             order_by = order_by or partition_by
             dataframe = repartition_sort_df(
                 dataframe, partition_by, order_by, num_processors
             )
 
-        # Step 3: Construct the feature set dataframe using defined transformations.
-        transformed_dataframe = self.feature_set.construct(
+        dataframe = self.feature_set.construct(
             dataframe=dataframe,
             client=self.spark_client,
             start_date=start_date,
             end_date=end_date,
             num_processors=num_processors,
         )
 
-        if transformed_dataframe.storageLevel != StorageLevel(
-            False, False, False, False, 1
-        ):
-            dataframe.unpersist()  # Clear the data from the cache (disk and memory)
-
-        # Step 4: Load the data into the configured sink.
         self.sink.flush(
-            dataframe=transformed_dataframe,
+            dataframe=dataframe,
             feature_set=self.feature_set,
             spark_client=self.spark_client,
         )
 
-        # Step 5: Validate the output if not streaming and data volume is reasonable.
-        if not transformed_dataframe.isStreaming:
+        if not dataframe.isStreaming:
             self.sink.validate(
-                dataframe=transformed_dataframe,
+                dataframe=dataframe,
                 feature_set=self.feature_set,
                 spark_client=self.spark_client,
             )
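
The guard deleted above compares against StorageLevel(False, False, False, False, 1), which encodes "no disk, no memory, no off-heap, not deserialized", i.e. a DataFrame that was never persisted. A small sketch of how that reads, assuming a local SparkSession:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.range(10)

print(df.storageLevel)  # StorageLevel(False, False, False, False, 1) before caching
df.cache()
print(df.storageLevel)  # a memory-and-disk level once persisted
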
37 changes: 13 additions & 24 deletions butterfree/transform/aggregated_feature_set.py
@@ -387,7 +387,6 @@ def _aggregate(
         ]
 
         groupby = self.keys_columns.copy()
-
         if window is not None:
             dataframe = dataframe.withColumn("window", window.get())
             groupby.append("window")
@@ -411,23 +410,19 @@
             "keep_rn", functions.row_number().over(partition_window)
         ).filter("keep_rn = 1")
 
-        current_partitions = dataframe.rdd.getNumPartitions()
-        optimal_partitions = num_processors or current_partitions
-
-        if current_partitions != optimal_partitions:
-            dataframe = repartition_df(
-                dataframe,
-                partition_by=groupby,
-                num_processors=optimal_partitions,
-            )
-
+        # repartition to have all rows for each group at the same partition
+        # by doing that, we won't have to shuffle data on grouping by id
+        dataframe = repartition_df(
+            dataframe,
+            partition_by=groupby,
+            num_processors=num_processors,
+        )
         grouped_data = dataframe.groupby(*groupby)
-
-        if self._pivot_column and self._pivot_values:
+        if self._pivot_column:
             grouped_data = grouped_data.pivot(self._pivot_column, self._pivot_values)
 
         aggregated = grouped_data.agg(*aggregations)
 
         return self._with_renamed_columns(aggregated, features, window)
 
     def _with_renamed_columns(
@@ -639,18 +634,12 @@ def construct(
         output_df = output_df.select(*self.columns).replace(  # type: ignore
             float("nan"), None
        )
-
-        if not output_df.isStreaming and self.deduplicate_rows:
-            output_df = self._filter_duplicated_rows(output_df)
+        if not output_df.isStreaming:
+            if self.deduplicate_rows:
+                output_df = self._filter_duplicated_rows(output_df)
+            if self.eager_evaluation:
+                output_df.cache().count()
 
         post_hook_df = self.run_post_hooks(output_df)
 
-        # Eager evaluation, only if needed and managable
-        if not output_df.isStreaming and self.eager_evaluation:
-            # Small dataframes only
-            if output_df.count() < 1_000_000:
-                post_hook_df.cache().count()
-            else:
-                post_hook_df.cache()  # Cache without materialization for large volumes
-
         return post_hook_df
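
The restored comment states the intent: pre-partitioning by the grouping keys so the subsequent groupby aggregates with less shuffling. A rough illustration in plain PySpark (repartition_df is butterfree's helper; DataFrame.repartition stands in for it here):

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame([(1, 10.0), (1, 20.0), (2, 5.0)], ["id", "value"])

# place all rows of each id on the same partition before grouping
df = df.repartition("id")
df.groupBy("id").agg(F.avg("value").alias("avg_value")).show()
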
7 changes: 5 additions & 2 deletions butterfree/transform/feature_set.py
@@ -436,8 +436,11 @@ def construct(
             pre_hook_df,
         ).select(*self.columns)
 
-        if not output_df.isStreaming and self.deduplicate_rows:
-            output_df = self._filter_duplicated_rows(output_df)
+        if not output_df.isStreaming:
+            if self.deduplicate_rows:
+                output_df = self._filter_duplicated_rows(output_df)
+            if self.eager_evaluation:
+                output_df.cache().count()
 
         output_df = self.incremental_strategy.filter_with_incremental_strategy(
             dataframe=output_df, start_date=start_date, end_date=end_date
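
The eager evaluation restored here is what the test change below verifies: after cache().count() the DataFrame reports itself as cached. A minimal sketch, assuming a local SparkSession:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.range(10)

df.cache().count()   # materialize and keep the result cached
print(df.is_cached)  # True, the property that `assert result_df.is_cached` checks
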
2 changes: 1 addition & 1 deletion tests/unit/butterfree/transform/test_feature_set.py
@@ -220,7 +220,7 @@ def test_construct(
             + feature_divide.get_output_columns()
         )
         assert_dataframe_equality(result_df, feature_set_dataframe)
-        assert not result_df.is_cached
+        assert result_df.is_cached
 
     def test_construct_invalid_df(
         self, key_id, timestamp_c, feature_add, feature_divide
