
Commit 916a779

[SPARK-53449][SQL] Simplify options for builtin Datasource Scan related classes
### What changes were proposed in this pull request?

Simplify interoperations between SQLConf and builtin Datasource Scan related classes, following [SPARK-52704](https://issues.apache.org/jira/browse/SPARK-52704) and [SPARK-53415](https://issues.apache.org/jira/browse/SPARK-53415).

### Why are the changes needed?

- Reduce code duplication
- Restore type annotations for IDEs

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #52192 from yaooqinn/SPARK-53449.

Authored-by: Kent Yao <yao@apache.org>
Signed-off-by: Kent Yao <yao@apache.org>
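For orientation, the recurring change in this commit is to mix in `SessionStateHelper` and read session-level settings through its typed accessors instead of chaining `sparkSession.sessionState.conf`. Below is a minimal sketch of that pattern, assuming only the `getSqlConf`/`getHadoopConf` signatures visible in the diffs; the `MyScanNode` class is purely illustrative and not part of this patch.

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.{SessionStateHelper, SQLConf}

// Illustrative only: a class that needs session settings mixes in the helper
// instead of reaching through sparkSession.sessionState at every call site.
class MyScanNode(sparkSession: SparkSession) extends SessionStateHelper {

  // Typed SQLConf handle; the explicit annotation is what the description
  // means by restoring type annotations for IDEs.
  private val conf: SQLConf = getSqlConf(sparkSession)

  def bucketingEnabled: Boolean = conf.bucketingEnabled

  // Session Hadoop configuration with per-source options applied on top.
  def hadoopConf(options: Map[String, String]): Configuration =
    getHadoopConf(sparkSession, options)
}
```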
1 parent 5a30d68 commit 916a779

File tree

5 files changed (+42, -39 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala

Lines changed: 10 additions & 10 deletions
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat =>
 import org.apache.spark.sql.execution.datasources.v2.{PushedDownOperators, TableSampleInfo}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.{SessionStateHelper, SQLConf}
 import org.apache.spark.sql.sources.{BaseRelation, Filter}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -295,7 +295,7 @@ case class RowDataSourceScanExec(
 /**
  * A base trait for file scans containing file listing and metrics code.
  */
-trait FileSourceScanLike extends DataSourceScanExec {
+trait FileSourceScanLike extends DataSourceScanExec with SessionStateHelper {

   // Filters on non-partition columns.
   def dataFilters: Seq[Expression]
@@ -327,7 +327,7 @@ trait FileSourceScanLike extends DataSourceScanExec {
     relation.fileFormat.vectorTypes(
       requiredSchema = requiredSchema,
       partitionSchema = relation.partitionSchema,
-      relation.sparkSession.sessionState.conf).map { vectorTypes =>
+      getSqlConf(relation.sparkSession)).map { vectorTypes =>
       vectorTypes ++
         // for column-based file format, append metadata column's vector type classes if any
         fileConstantMetadataColumns.map { _ => classOf[ConstantColumnVector].getName }
@@ -414,7 +414,7 @@ trait FileSourceScanLike extends DataSourceScanExec {

   // exposed for testing
   lazy val bucketedScan: Boolean = {
-    if (relation.sparkSession.sessionState.conf.bucketingEnabled && relation.bucketSpec.isDefined
+    if (getSqlConf(relation.sparkSession).bucketingEnabled && relation.bucketSpec.isDefined
       && !disableBucketedScan) {
       val spec = relation.bucketSpec.get
       val bucketColumns = spec.bucketColumnNames.flatMap(n => toAttribute(n))
@@ -535,7 +535,7 @@ trait FileSourceScanLike extends DataSourceScanExec {
         bucketedKey -> "true",
         "SelectedBucketsCount" -> (s"$numSelectedBuckets out of ${spec.numBuckets}" +
           optionalNumCoalescedBuckets.map { b => s" (Coalesced to $b)"}.getOrElse("")))
-    } else if (!relation.sparkSession.sessionState.conf.bucketingEnabled) {
+    } else if (!getSqlConf(relation.sparkSession).bucketingEnabled) {
       metadata + (bucketedKey -> "false (disabled by configuration)")
     } else if (disableBucketedScan) {
       metadata + (bucketedKey -> "false (disabled by query planner)")
@@ -646,7 +646,7 @@ trait FileSourceScanLike extends DataSourceScanExec {
   }

   override def calculateTotalPartitionBytes: Long = {
-    val openCostInBytes = relation.sparkSession.sessionState.conf.filesOpenCostInBytes
+    val openCostInBytes = getSqlConf(relation.sparkSession).filesOpenCostInBytes
     partitionDirectories.flatMap(_.files.map(_.getLen + openCostInBytes)).sum
   }

@@ -698,7 +698,7 @@ case class FileSourceScanExec(
   // Note that some vals referring the file-based relation are lazy intentionally
   // so that this plan can be canonicalized on executor side too. See SPARK-23731.
   override lazy val supportsColumnar: Boolean = {
-    val conf = relation.sparkSession.sessionState.conf
+    val conf = getSqlConf(relation.sparkSession)
     // Only output columnar if there is WSCG to read it.
     val requiredWholeStageCodegenSettings =
       conf.wholeStageEnabled && !WholeStageCodegenExec.isTooManyFields(conf, schema)
@@ -725,7 +725,7 @@ case class FileSourceScanExec(
       requiredSchema = requiredSchema,
       filters = pushedDownFilters,
       options = options,
-      hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options))
+      hadoopConf = getHadoopConf(relation.sparkSession, relation.options))

     val readRDD = if (bucketedScan) {
       createBucketedReadRDD(relation.bucketSpec.get, readFile, dynamicallySelectedPartitions)
@@ -849,15 +849,15 @@ case class FileSourceScanExec(
   private def createReadRDD(
       readFile: PartitionedFile => Iterator[InternalRow],
       selectedPartitions: ScanFileListing): RDD[InternalRow] = {
-    val openCostInBytes = relation.sparkSession.sessionState.conf.filesOpenCostInBytes
+    val openCostInBytes = getSqlConf(relation.sparkSession).filesOpenCostInBytes
     val maxSplitBytes =
       FilePartition.maxSplitBytes(relation.sparkSession, selectedPartitions)
     logInfo(log"Planning scan with bin packing, max size: ${MDC(MAX_SPLIT_BYTES, maxSplitBytes)} " +
       log"bytes, open cost is considered as scanning ${MDC(OPEN_COST_IN_BYTES, openCostInBytes)} " +
       log"bytes.")

     // Filter files with bucket pruning if possible
-    val bucketingEnabled = relation.sparkSession.sessionState.conf.bucketingEnabled
+    val bucketingEnabled = getSqlConf(relation.sparkSession).bucketingEnabled
     val shouldProcess: Path => Boolean = optionalBucketSet match {
       case Some(bucketSet) if bucketingEnabled =>
         // Do not prune the file if bucket file name is invalid

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala

Lines changed: 15 additions & 15 deletions
@@ -53,7 +53,7 @@ import org.apache.spark.sql.execution.streaming.{Sink, Source}
 import org.apache.spark.sql.execution.streaming.runtime._
 import org.apache.spark.sql.execution.streaming.sinks.FileStreamSink
 import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider}
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.{SessionStateHelper, SQLConf}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
@@ -100,12 +100,14 @@ case class DataSource(
     partitionColumns: Seq[String] = Seq.empty,
     bucketSpec: Option[BucketSpec] = None,
     options: Map[String, String] = Map.empty,
-    catalogTable: Option[CatalogTable] = None) extends Logging {
+    catalogTable: Option[CatalogTable] = None) extends SessionStateHelper with Logging {

   case class SourceInfo(name: String, schema: StructType, partitionColumns: Seq[String])

+  private val conf: SQLConf = getSqlConf(sparkSession)
+
   lazy val providingClass: Class[_] = {
-    val cls = DataSource.lookupDataSource(className, sparkSession.sessionState.conf)
+    val cls = DataSource.lookupDataSource(className, conf)
     // `providingClass` is used for resolving data source relation for catalog tables.
     // As now catalog for data source V2 is under development, here we fall back all the
     // [[FileDataSourceV2]] to [[FileFormat]] to guarantee the current catalog works.
@@ -120,8 +122,7 @@ case class DataSource(

   private[sql] def providingInstance(): Any = providingClass.getConstructor().newInstance()

-  private def newHadoopConfiguration(): Configuration =
-    sparkSession.sessionState.newHadoopConfWithOptions(options)
+  private def newHadoopConfiguration(): Configuration = getHadoopConf(sparkSession, options)

   private def makeQualified(path: Path): Path = {
     val fs = path.getFileSystem(newHadoopConfiguration())
@@ -130,7 +131,7 @@ case class DataSource(

   lazy val sourceInfo: SourceInfo = sourceSchema()
   private val caseInsensitiveOptions = CaseInsensitiveMap(options)
-  private val equality = sparkSession.sessionState.conf.resolver
+  private val equality = conf.resolver

   /**
    * Whether or not paths should be globbed before being used to access files.
@@ -262,7 +263,7 @@ case class DataSource(
       }
     }

-    val isSchemaInferenceEnabled = sparkSession.sessionState.conf.streamingSchemaInference
+    val isSchemaInferenceEnabled = conf.streamingSchemaInference
     val isTextSource = providingClass == classOf[text.TextFileFormat]
     val isSingleVariantColumn = (providingClass == classOf[json.JsonFileFormat] ||
       providingClass == classOf[csv.CSVFileFormat]) &&
@@ -281,8 +282,7 @@ case class DataSource(
           checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false)
           createInMemoryFileIndex(globbedPaths)
         })
-        val forceNullable = sparkSession.sessionState.conf
-          .getConf(SQLConf.FILE_SOURCE_SCHEMA_FORCE_NULLABLE)
+        val forceNullable = conf.getConf(SQLConf.FILE_SOURCE_SCHEMA_FORCE_NULLABLE)
         val sourceDataSchema = if (forceNullable) dataSchema.asNullable else dataSchema
         SourceInfo(
           s"FileSource[$path]",
@@ -381,7 +381,7 @@ case class DataSource(
           if FileStreamSink.hasMetadata(
             caseInsensitiveOptions.get("path").toSeq ++ paths,
             newHadoopConfiguration(),
-            sparkSession.sessionState.conf) =>
+            conf) =>
         val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head)
         val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath,
           caseInsensitiveOptions, userSpecifiedSchema)
@@ -407,11 +407,11 @@ case class DataSource(

       // This is a non-streaming file based datasource.
       case (format: FileFormat, _) =>
-        val useCatalogFileIndex = sparkSession.sessionState.conf.manageFilesourcePartitions &&
+        val useCatalogFileIndex = conf.manageFilesourcePartitions &&
           catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog &&
           catalogTable.get.partitionColumnNames.nonEmpty
         val (fileCatalog, dataSchema, partitionSchema) = if (useCatalogFileIndex) {
-          val defaultTableSize = sparkSession.sessionState.conf.defaultSizeInBytes
+          val defaultTableSize = conf.defaultSizeInBytes
           val index = new CatalogFileIndex(
             sparkSession,
             catalogTable.get,
@@ -475,7 +475,7 @@ case class DataSource(
       throw QueryExecutionErrors.dataPathNotSpecifiedError()
     }

-    val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+    val caseSensitive = conf.caseSensitiveAnalysis
     PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive)

     val fileIndex = catalogTable.map(_.identifier).map { tableIdent =>
@@ -531,7 +531,7 @@ case class DataSource(
         disallowWritingIntervals(
           outputColumns.toStructType.asNullable, format.toString, forbidAnsiIntervals = false)
         val cmd = planForWritingFileFormat(format, mode, data)
-        val qe = sparkSession.sessionState.executePlan(cmd)
+        val qe = sessionState(sparkSession).executePlan(cmd)
         qe.assertCommandExecuted()
         // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring
         copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation()
@@ -555,7 +555,7 @@ case class DataSource(
         SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode)
       case format: FileFormat =>
         disallowWritingIntervals(data.schema, format.toString, forbidAnsiIntervals = false)
-        DataSource.validateSchema(format.toString, data.schema, sparkSession.sessionState.conf)
+        DataSource.validateSchema(format.toString, data.schema, conf)
         planForWritingFileFormat(format, mode, data)
       case _ => throw SparkException.internalError(
         s"${providingClass.getCanonicalName} does not allow create table as select.")

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala

Lines changed: 4 additions & 3 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import org.apache.spark.sql.{SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.execution.FileRelation
+import org.apache.spark.sql.internal.SessionStateHelper
 import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister}
 import org.apache.spark.sql.types.{StructField, StructType}

@@ -46,7 +47,7 @@ case class HadoopFsRelation(
     bucketSpec: Option[BucketSpec],
     fileFormat: FileFormat,
     options: Map[String, String])(val sparkSession: SparkSession)
-  extends BaseRelation with FileRelation {
+  extends BaseRelation with FileRelation with SessionStateHelper{

   override def sqlContext: SQLContext = sparkSession.sqlContext

@@ -55,7 +56,7 @@ case class HadoopFsRelation(
   // respects the data types of the partition schema.
   val (schema: StructType, overlappedPartCols: Map[String, StructField]) =
     PartitioningUtils.mergeDataAndPartitionSchema(dataSchema,
-      partitionSchema, sparkSession.sessionState.conf.caseSensitiveAnalysis)
+      partitionSchema, getSqlConf(sparkSession).caseSensitiveAnalysis)

   override def toString: String = {
     fileFormat match {
@@ -65,7 +66,7 @@ case class HadoopFsRelation(
   }

   override def sizeInBytes: Long = {
-    val compressionFactor = sparkSession.sessionState.conf.fileCompressionFactor
+    val compressionFactor = getSqlConf(sparkSession).fileCompressionFactor
     (location.sizeInBytes * compressionFactor).toLong
   }


sql/core/src/main/scala/org/apache/spark/sql/internal/SessionStateHelper.scala

Lines changed: 5 additions & 1 deletion
@@ -27,7 +27,7 @@ import org.apache.spark.sql.SparkSession
  * It also provides type annotations for IDEs to build indexes.
  */
 trait SessionStateHelper {
-  private def sessionState(sparkSession: SparkSession): SessionState = {
+  protected def sessionState(sparkSession: SparkSession): SessionState = {
     sparkSession.sessionState
   }

@@ -48,6 +48,10 @@ trait SessionStateHelper {
       options: Map[String, String]): Configuration = {
     sessionState(sparkSession).newHadoopConfWithOptions(options)
   }
+
+  def getHadoopConf(sparkSession: SparkSession): Configuration = {
+    sessionState(sparkSession).newHadoopConf()
+  }
 }

 object SessionStateHelper extends SessionStateHelper
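This hunk does two things: `sessionState` becomes `protected`, so classes mixing in the trait (as in the other files of this commit) can call it directly, and a no-options `getHadoopConf(sparkSession)` overload is added alongside the existing options-taking one. Because the file also defines `object SessionStateHelper extends SessionStateHelper`, the public helpers can be reached without mixing in the trait. A hedged usage sketch follows; the app name and the `fs.defaultFS` option are arbitrary example values, and the trait is in any case Spark-internal API shown here only to illustrate the call shape.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SessionStateHelper

object SessionStateHelperUsage {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("SessionStateHelperUsage")
      .master("local[1]")
      .getOrCreate()

    // New overload in this commit: plain session Hadoop conf, no extra options.
    val plain = SessionStateHelper.getHadoopConf(spark)

    // Existing overload: session Hadoop conf with per-source options layered on top.
    val layered = SessionStateHelper.getHadoopConf(spark, Map("fs.defaultFS" -> "file:///"))

    println(plain.get("fs.defaultFS"))
    println(layered.get("fs.defaultFS"))
    spark.stop()
  }
}
```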

sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala

Lines changed: 8 additions & 10 deletions
@@ -40,7 +40,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.hive.client.HiveClientImpl
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.{SessionStateHelper, SQLConf}
 import org.apache.spark.sql.types.{BooleanType, DataType}
 import org.apache.spark.util.Utils

@@ -57,12 +57,12 @@ case class HiveTableScanExec(
     relation: HiveTableRelation,
     partitionPruningPred: Seq[Expression])(
     @transient private val sparkSession: SparkSession)
-  extends LeafExecNode with CastSupport {
+  extends LeafExecNode with CastSupport with SessionStateHelper {

   require(partitionPruningPred.isEmpty || relation.isPartitioned,
     "Partition pruning predicates only supported for partitioned tables.")

-  override def conf: SQLConf = sparkSession.sessionState.conf
+  override def conf: SQLConf = getSqlConf(sparkSession)

   override def nodeName: String = s"Scan hive ${relation.tableMeta.qualifiedName}"

@@ -98,7 +98,7 @@ case class HiveTableScanExec(
   // Create a local copy of hadoopConf,so that scan specific modifications should not impact
   // other queries
   @transient private lazy val hadoopConf = {
-    val c = sparkSession.sessionState.newHadoopConf()
+    val c = getHadoopConf(sparkSession)
     // append columns ids and names before broadcast
     addColumnMetadataToConf(c)
     c
@@ -175,8 +175,7 @@ case class HiveTableScanExec(
         prunePartitions(hivePartitions)
       }
     } else {
-      if (sparkSession.sessionState.conf.metastorePartitionPruning &&
-        partitionPruningPred.nonEmpty) {
+      if (conf.metastorePartitionPruning && partitionPruningPred.nonEmpty) {
        rawPartitions
      } else {
        prunePartitions(rawPartitions)
@@ -187,16 +186,15 @@ case class HiveTableScanExec(
   // exposed for tests
   @transient lazy val rawPartitions: Seq[HivePartition] = {
     val prunedPartitions =
-      if (sparkSession.sessionState.conf.metastorePartitionPruning &&
-        partitionPruningPred.nonEmpty) {
+      if (conf.metastorePartitionPruning && partitionPruningPred.nonEmpty) {
        // Retrieve the original attributes based on expression ID so that capitalization matches.
        val normalizedFilters = partitionPruningPred.map(_.transform {
          case a: AttributeReference => originalAttributes(a)
        })
-        sparkSession.sessionState.catalog
+        sessionState(sparkSession).catalog
          .listPartitionsByFilter(relation.tableMeta.identifier, normalizedFilters)
      } else {
-        sparkSession.sessionState.catalog.listPartitions(relation.tableMeta.identifier)
+        sessionState(sparkSession).catalog.listPartitions(relation.tableMeta.identifier)
      }
    prunedPartitions.map(HiveClientImpl.toHivePartition(_, hiveQlTable))
  }
