#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
User-defined table function related classes and functions
"""
import pickle
import sys
import warnings
from typing import Any, Type, TYPE_CHECKING, Optional, Union

from py4j.java_gateway import JavaObject

from pyspark.errors import PySparkAttributeError, PySparkRuntimeError, PySparkTypeError
from pyspark.rdd import PythonEvalType
from pyspark.sql.column import _to_java_column, _to_seq
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
from pyspark.sql.types import StructType, _parse_datatype_string
from pyspark.sql.udf import _wrap_function

if TYPE_CHECKING:
    from pyspark.sql._typing import ColumnOrName
    from pyspark.sql.dataframe import DataFrame
    from pyspark.sql.session import SparkSession

__all__ = ["UDTFRegistration"]


def _create_udtf(
    cls: Type,
    returnType: Union[StructType, str],
    name: Optional[str] = None,
    evalType: int = PythonEvalType.SQL_TABLE_UDF,
    deterministic: bool = False,
) -> "UserDefinedTableFunction":
    """Create a Python UDTF with the given eval type.

    Parameters
    ----------
    cls : type
        The handler class of the UDTF.
    returnType : :class:`pyspark.sql.types.StructType` or str
        The return type of the UDTF, either as a ``StructType`` or as a string
        that is parsed lazily by :attr:`UserDefinedTableFunction.returnType`.
    name : str, optional
        The name of the UDTF. When omitted, ``UserDefinedTableFunction`` falls
        back to ``cls.__name__``.
    evalType : int, optional
        The eval type forwarded to the JVM side. Defaults to
        ``PythonEvalType.SQL_TABLE_UDF``.
    deterministic : bool, optional
        Whether the UDTF is flagged as deterministic. Defaults to False.

    Returns
    -------
    :class:`UserDefinedTableFunction`
        The newly constructed UDTF wrapper.

    Raises
    ------
    PySparkTypeError
        If ``cls`` is not a class (raised by the handler validation in the
        ``UserDefinedTableFunction`` constructor).
    PySparkAttributeError
        If ``cls`` has no ``eval`` attribute.
    """
    udtf_obj = UserDefinedTableFunction(
        cls, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
    )
    return udtf_obj


def _create_py_udtf(
    cls: Type,
    returnType: Union[StructType, str],
    name: Optional[str] = None,
    deterministic: bool = False,
    useArrow: Optional[bool] = None,
) -> "UserDefinedTableFunction":
    """Create a regular or an Arrow-optimized Python UDTF.

    Parameters
    ----------
    cls : type
        The handler class of the UDTF.
    returnType : :class:`pyspark.sql.types.StructType` or str
        The return type of the UDTF.
    name : str, optional
        The name of the UDTF.
    deterministic : bool, optional
        Whether the UDTF is flagged as deterministic. Defaults to False.
    useArrow : bool, optional
        Explicitly enables or disables the Arrow optimization. When None, the
        ``spark.sql.execution.pythonUDTF.arrow.enabled`` configuration of
        ``SparkSession._instantiatedSession`` decides; if there is no such
        session the optimization stays disabled.

    Returns
    -------
    :class:`UserDefinedTableFunction`
        A UDTF whose eval type is ``SQL_ARROW_TABLE_UDF`` when Arrow is enabled
        and the pandas/pyarrow requirements are met, ``SQL_TABLE_UDF`` otherwise.

    Notes
    -----
    When Arrow is requested but the minimum pandas or pyarrow version check
    raises ``ImportError``, a ``UserWarning`` is emitted and a regular UDTF is
    created instead.
    """
    # Determine whether to create Arrow-optimized UDTFs.
    if useArrow is not None:
        arrow_enabled = useArrow
    else:
        from pyspark.sql import SparkSession

        session = SparkSession._instantiatedSession
        arrow_enabled = False
        if session is not None:
            value = session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled")
            # Only the string "true" (case-insensitive) turns the optimization on.
            if isinstance(value, str) and value.lower() == "true":
                arrow_enabled = True

    eval_type: int = PythonEvalType.SQL_TABLE_UDF
    if arrow_enabled:
        # Return the regular UDTF if the required dependencies are not satisfied.
        try:
            require_minimum_pandas_version()
            require_minimum_pyarrow_version()
            eval_type = PythonEvalType.SQL_ARROW_TABLE_UDF
        except ImportError as e:
            warnings.warn(
                f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. "
                f"Falling back to using regular Python UDTFs.",
                UserWarning,
            )

    return _create_udtf(
        cls=cls,
        returnType=returnType,
        name=name,
        evalType=eval_type,
        deterministic=deterministic,
    )


def _validate_udtf_handler(cls: Any) -> None:
    """Validate the handler class of a UDTF.

    Parameters
    ----------
    cls : Any
        The object supplied as the UDTF handler.

    Raises
    ------
    PySparkTypeError
        With error class ``INVALID_UDTF_HANDLER_TYPE`` if ``cls`` is not a class.
    PySparkAttributeError
        With error class ``INVALID_UDTF_NO_EVAL`` if ``cls`` has no ``eval``
        attribute.
    """
    if not isinstance(cls, type):
        raise PySparkTypeError(
            error_class="INVALID_UDTF_HANDLER_TYPE", message_parameters={"type": type(cls).__name__}
        )

    if not hasattr(cls, "eval"):
        raise PySparkAttributeError(
            error_class="INVALID_UDTF_NO_EVAL", message_parameters={"name": cls.__name__}
        )
class UserDefinedTableFunction:
    """
    User-defined table function in Python

    .. versionadded:: 3.5.0

    Parameters
    ----------
    func : type
        The handler class of the UDTF. It must be a class that has an ``eval``
        attribute; this is validated in the constructor.
    returnType : :class:`pyspark.sql.types.StructType` or str
        The return type of the UDTF. A string is parsed lazily the first time
        :attr:`returnType` is accessed.
    name : str, optional
        The name of the UDTF. Falls back to ``func.__name__`` when not given
        (or when falsy).
    evalType : int, optional
        The eval type passed to the JVM UDTF. Defaults to
        ``PythonEvalType.SQL_TABLE_UDF``.
    deterministic : bool, optional
        Whether the UDTF is flagged as deterministic. Defaults to False.

    Notes
    -----
    The constructor of this class is not supposed to be directly called.
    Use :meth:`pyspark.sql.functions.udtf` to create this instance.

    This API is evolving.
    """

    def __init__(
        self,
        func: Type,
        returnType: Union[StructType, str],
        name: Optional[str] = None,
        evalType: int = PythonEvalType.SQL_TABLE_UDF,
        deterministic: bool = False,
    ):
        _validate_udtf_handler(func)

        self.func = func
        self._returnType = returnType
        # Caches filled lazily by the `returnType` and `_judtf` properties.
        self._returnType_placeholder: Optional[StructType] = None
        self._inputTypes_placeholder = None
        self._judtf_placeholder = None
        self._name = name or func.__name__
        self.evalType = evalType
        self.deterministic = deterministic

    @property
    def returnType(self) -> StructType:
        """The return type of the UDTF as a ``StructType``, parsed once and cached.

        Raises
        ------
        PySparkTypeError
            With error class ``UDTF_RETURN_TYPE_MISMATCH`` if the (parsed)
            return type is not a ``StructType``.
        """
        # `_parse_datatype_string` accesses to JVM for parsing a DDL formatted string.
        # This makes sure this is called after SparkContext is initialized.
        if self._returnType_placeholder is None:
            if isinstance(self._returnType, str):
                parsed = _parse_datatype_string(self._returnType)
            else:
                parsed = self._returnType
            if not isinstance(parsed, StructType):
                raise PySparkTypeError(
                    error_class="UDTF_RETURN_TYPE_MISMATCH",
                    message_parameters={
                        "name": self._name,
                        "return_type": f"{parsed}",
                    },
                )
            self._returnType_placeholder = parsed
        return self._returnType_placeholder

    @property
    def _judtf(self) -> JavaObject:
        """The JVM-side UDTF object, created on first access and cached."""
        if self._judtf_placeholder is None:
            self._judtf_placeholder = self._create_judtf(self.func)
        return self._judtf_placeholder

    def _create_judtf(self, func: Type) -> JavaObject:
        """Wrap ``func`` and build the JVM ``UserDefinedPythonTableFunction``.

        Parameters
        ----------
        func : type
            The handler class to serialize and hand over to the JVM.

        Returns
        -------
        JavaObject
            The JVM UDTF built from the name, wrapped function, return type,
            eval type and deterministic flag of this instance.

        Raises
        ------
        PySparkRuntimeError
            With error class ``UDTF_SERIALIZATION_ERROR`` if wrapping ``func``
            raises ``pickle.PicklingError``.
        """
        from pyspark.sql import SparkSession

        spark = SparkSession._getActiveSessionOrCreate()
        sc = spark.sparkContext

        try:
            wrapped_func = _wrap_function(sc, func)
        except pickle.PicklingError as e:
            # A more specific hint is given when the pickling failure message
            # carries the CONTEXT_ONLY_VALID_ON_DRIVER marker.
            if "CONTEXT_ONLY_VALID_ON_DRIVER" in str(e):
                raise PySparkRuntimeError(
                    error_class="UDTF_SERIALIZATION_ERROR",
                    message_parameters={
                        "name": self._name,
                        "message": "it appears that you are attempting to reference SparkSession "
                        "inside a UDTF. SparkSession can only be used on the driver, "
                        "not in code that runs on workers. Please remove the reference "
                        "and try again.",
                    },
                ) from None
            raise PySparkRuntimeError(
                error_class="UDTF_SERIALIZATION_ERROR",
                message_parameters={
                    "name": self._name,
                    "message": "Please check the stack trace and make sure the "
                    "function is serializable.",
                },
            )

        # Accessing `self.returnType` parses and validates the return type.
        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
        assert sc._jvm is not None
        judtf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction(
            self._name, wrapped_func, jdt, self.evalType, self.deterministic
        )
        return judtf

    def __call__(self, *cols: "ColumnOrName") -> "DataFrame":
        """Apply the UDTF to the given columns and return the result as a DataFrame.

        Parameters
        ----------
        *cols : :class:`~pyspark.sql.Column` or str
            The arguments of the UDTF call; each one is converted with
            ``_to_java_column``.

        Returns
        -------
        :class:`~pyspark.sql.DataFrame`
            A DataFrame wrapping the result of applying the JVM UDTF.
        """
        from pyspark.sql import DataFrame, SparkSession

        spark = SparkSession._getActiveSessionOrCreate()
        sc = spark.sparkContext

        judtf = self._judtf
        jPythonUDTF = judtf.apply(spark._jsparkSession, _to_seq(sc, cols, _to_java_column))
        return DataFrame(jPythonUDTF, spark)

    def asDeterministic(self) -> "UserDefinedTableFunction":
        """
        Updates UserDefinedTableFunction to deterministic.

        Returns
        -------
        :class:`UserDefinedTableFunction`
            This same instance, with ``deterministic`` set to True.
        """
        # Explicitly clean the cache to create a JVM UDTF instance.
        self._judtf_placeholder = None
        self.deterministic = True
        return self
class UDTFRegistration:
    """
    Wrapper for user-defined table function registration. This instance can be accessed by
    :attr:`spark.udtf` or :attr:`sqlContext.udtf`.

    .. versionadded:: 3.5.0

    Parameters
    ----------
    sparkSession : :class:`~pyspark.sql.SparkSession`
        The session whose JVM counterpart receives the registered functions.
    """

    def __init__(self, sparkSession: "SparkSession"):
        self.sparkSession = sparkSession

    def register(
        self,
        name: str,
        f: "UserDefinedTableFunction",
    ) -> "UserDefinedTableFunction":
        """Register a Python user-defined table function as a SQL table function.

        .. versionadded:: 3.5.0

        Parameters
        ----------
        name : str
            The name of the user-defined table function in SQL statements.
        f : function or :meth:`pyspark.sql.functions.udtf`
            The user-defined table function.

        Returns
        -------
        function
            The registered user-defined table function.

        Raises
        ------
        PySparkTypeError
            With error class ``INVALID_UDTF_EVAL_TYPE`` if the eval type of ``f``
            is neither ``SQL_TABLE_UDF`` nor ``SQL_ARROW_TABLE_UDF``.

        Notes
        -----
        Spark uses the return type of the given user-defined table function as the return
        type of the registered user-defined function.

        To register a nondeterministic Python table function, users need to first build
        a nondeterministic user-defined table function and then register it as a SQL function.

        Examples
        --------
        >>> from pyspark.sql.functions import udtf
        >>> @udtf(returnType="c1: int, c2: int")
        ... class PlusOne:
        ...     def eval(self, x: int):
        ...         yield x, x + 1
        ...
        >>> _ = spark.udtf.register(name="plus_one", f=PlusOne)
        >>> spark.sql("SELECT * FROM plus_one(1)").collect()
        [Row(c1=1, c2=2)]

        Use it with lateral join

        >>> spark.sql("SELECT * FROM VALUES (0, 1), (1, 2) t(x, y), LATERAL plus_one(x)").collect()
        [Row(x=0, y=1, c1=0, c2=1), Row(x=1, y=2, c1=1, c2=2)]
        """
        if f.evalType not in [PythonEvalType.SQL_TABLE_UDF, PythonEvalType.SQL_ARROW_TABLE_UDF]:
            raise PySparkTypeError(
                error_class="INVALID_UDTF_EVAL_TYPE",
                message_parameters={
                    "name": name,
                    "eval_type": "SQL_TABLE_UDF, SQL_ARROW_TABLE_UDF",
                },
            )

        # Build a fresh UDTF under the SQL name instead of mutating `f`, reusing
        # its handler class, return type, eval type and deterministic flag.
        register_udtf = _create_udtf(
            cls=f.func,
            returnType=f.returnType,
            name=name,
            evalType=f.evalType,
            deterministic=f.deterministic,
        )
        # Accessing `_judtf` creates the JVM UDTF that is handed to the registry.
        self.sparkSession._jsparkSession.udtf().registerPython(name, register_udtf._judtf)
        return register_udtf