Note: PySpark is deprecating df.sql_ctx in an upcoming version, so this answer is not future-proof.
I like many of the other answers, but there are a few lingering questions in the comments. I think they can be addressed as follows:
- We need to think of everything as immutable, so we return the subclass.
- We do not need to call self._jdf anywhere; instead, just use self as if it were a DataFrame (since it is one, which is why we used inheritance!).
- We need to explicitly construct a new instance of our class, since returns from self.foo will be of the base DataFrame type.
- I have added a DataFrameExtender subclass that mediates creation of the new classes. Subclasses inherit the parent constructor if they do not override it, so we can neaten up the DataFrame constructor to take a DataFrame, and add the capability to store metadata.
We can make a new class for each conceptual stage that the data arrives in, and we can sidecar flags that help us identify the state of the data in the DataFrame. Here I add a flag whenever either add-column method is called, and I carry forward all existing flags. You can do whatever you like.
This setup means that you can create a sequence of DataFrameExtender objects, such as:
- RawData, which implements a .clean() method, returning CleanedData
- CleanedData, which implements a .normalize() method, returning ModelReadyData
- ModelReadyData, which implements .train(model) and .predict(model), or .summarize(), and which is used in a model just as a base DataFrame object would be
By splitting these methods across different classes, we ensure that we cannot call .train() on RawData, but we can take a RawData object and chain together .clean().normalize().train(); a sketch of such a chain follows. This is a functional-like approach, but one that uses immutable objects to aid interpretation.
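As a minimal sketch of that chain (assuming the DataFrameExtender base class defined in the code further down; the column names "label" and "value" and the specific transformations are placeholders):

from pyspark.sql.functions import col, lower, trim

class RawData(DataFrameExtender):
    def clean(self):
        # Hypothetical cleanup: drop rows with nulls, tidy a text column
        cleaned = self.dropna().withColumn("label", trim(lower(col("label"))))
        return CleanedData(cleaned, cleaned=True, **self.flags)

class CleanedData(DataFrameExtender):
    def normalize(self):
        # Hypothetical normalization: cast a numeric column to double
        normalized = self.withColumn("value", col("value").cast("double"))
        return ModelReadyData(normalized, normalized=True, **self.flags)

class ModelReadyData(DataFrameExtender):
    def summarize(self):
        # Behaves like any other DataFrame; describe() returns summary statistics
        return self.describe()

# RawData(df).clean().normalize().summarize() works;
# RawData(df).normalize() raises AttributeError, as intended.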
Note that DataFrames in Spark are lazily evaluated, which is great for this approach. All of this code just produces a final unevaluated DataFrame object that contains all of the operations that will be performed. We don't have to worry about memory or copies or anything.
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import lit

class DataFrameExtender(DataFrame):
    def __init__(self, df, **kwargs):
        self.flags = kwargs
        super().__init__(df._jdf, df.sql_ctx)

class ColumnAddedData(DataFrameExtender):
    def add_column3(self):
        df_added_column = self.withColumn("col3", lit(3))
        return ColumnAddedData(df_added_column, with_col3=True, **self.flags)

    def add_column4(self):
        ## Add a bit of complexity: do not call again if we have already called this method
        if not self.flags.get('with_col4'):
            df_added_column = self.withColumn("col4", lit(4))
            return ColumnAddedData(df_added_column, with_col4=True, **self.flags)
        return self
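And a rough usage sketch (assuming an existing SparkSession named spark; the sample data and column names are placeholders):

df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])

extended = ColumnAddedData(df)       # no flags set yet
extended = extended.add_column3()    # new ColumnAddedData, with_col3=True
extended = extended.add_column4()    # new ColumnAddedData, with_col4=True
extended = extended.add_column4()    # flag already set, so this returns self

print(extended.flags)                # {'with_col3': True, 'with_col4': True}
extended.show()                      # still usable as an ordinary DataFrame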