Adding a nested column to a Spark DataFrame

How can I add or replace fields inside a struct, at any nesting level?

Given this input:

    val rdd = sc.parallelize(Seq(
      """{"a": {"xX": 1,"XX": 2},"b": {"z": 0}}""",
      """{"a": {"xX": 3},"b": {"z": 0}}""",
      """{"a": {"XX": 3},"b": {"z": 0}}""",
      """{"a": {"xx": 4},"b": {"z": 0}}"""))
    var df = sqlContext.read.json(rdd)

it yields the following schema:

    root
     |-- a: struct (nullable = true)
     |    |-- XX: long (nullable = true)
     |    |-- xX: long (nullable = true)
     |    |-- xx: long (nullable = true)
     |-- b: struct (nullable = true)
     |    |-- z: long (nullable = true)

Then I can do this:

    import org.apache.spark.sql.functions._

    val overlappingNames = Seq(col("a.xx"), col("a.xX"), col("a.XX"))
    df = df
      .withColumn("a_xx", coalesce(overlappingNames: _*))
      .dropNestedColumn("a.xX")
      .dropNestedColumn("a.XX")
      .dropNestedColumn("a.xx")

(dropNestedColumn is borrowed from this answer: https://stackoverflow.com/a/4646262/. I am essentially looking for the inverse operation.)

And the schema becomes:

    root
     |-- a: struct (nullable = false)
     |-- b: struct (nullable = true)
     |    |-- z: long (nullable = true)
     |-- a_xx: long (nullable = true)

Obviously, it does not replace (or add) a.xx, but instead adds a new field a_xx at the root level.
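For a single level of nesting there is a manual workaround: rebuild the whole struct with withColumn, listing every surviving field explicitly. A rough sketch against the schema above:

    // One-level workaround: rebuild struct "a" by hand. Every field that
    // should survive must be listed explicitly inside struct(...).
    df = df.withColumn("a", struct(
      coalesce(col("a.xx"), col("a.xX"), col("a.XX")) as "xx"))

This does not generalize to arbitrary depth or to schemas that are not known in advance, which is why I need a generic withNestedColumn.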

I would like to do this instead:

    val overlappingNames = Seq(col("a.xx"), col("a.xX"), col("a.XX"))
    df = df
      .withNestedColumn("a.xx", coalesce(overlappingNames: _*))
      .dropNestedColumn("a.xX")
      .dropNestedColumn("a.XX")

so that it results in this schema:

    root
     |-- a: struct (nullable = false)
     |    |-- xx: long (nullable = true)
     |-- b: struct (nullable = true)
     |    |-- z: long (nullable = true)

How can I achieve this?

The practical goal here is case-insensitive handling of column names in the input JSON. The last step would then be simple: collect all column names that overlap case-insensitively and apply coalesce to each group.
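For reference, that last step could look roughly like this (a sketch assuming the withNestedColumn asked for here and the dropNestedColumn from the linked answer; it requires spark.sql.caseSensitive=true so the case variants resolve as distinct fields):

    import org.apache.spark.sql.types.StructType

    // Group the fields of struct "a" by lower-cased name and collapse
    // each group into a single canonical field via coalesce.
    val fieldNames = df.schema("a").dataType.asInstanceOf[StructType].fieldNames
    fieldNames.groupBy(_.toLowerCase).foreach { case (lower, variants) =>
      df = df.withNestedColumn(s"a.$lower", coalesce(variants.map(v => col(s"a.$v")): _*))
      variants.filterNot(_ == lower).foreach(v => df = df.dropNestedColumn(s"a.$v"))
    }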

1 answer

It may not be as elegant or efficient as it could be, but here is what I came up with:

    import org.apache.spark.sql.{Column, DataFrame}
    import org.apache.spark.sql.functions.{col, struct, when}
    import org.apache.spark.sql.types.{DataType, StructType}

    object DataFrameUtils {

      // Wrap a column so it stays null when the parent struct is null,
      // instead of materializing a struct full of nulls.
      private def nullableCol(parentCol: Column, c: Column): Column = {
        when(parentCol.isNotNull, c)
      }

      private def nullableCol(c: Column): Column = {
        nullableCol(c, c)
      }

      // Build the chain of intermediate structs for the remaining path,
      // starting from the innermost field and working outwards.
      private def createNestedStructs(splitted: Seq[String], newCol: Column): Column = {
        splitted.foldRight(newCol) {
          case (colName, nestedStruct) => nullableCol(struct(nestedStruct as colName))
        }
      }

      // Rebuild the struct at the current level, replacing the target field
      // if it exists and appending it otherwise.
      private def recursiveAddNestedColumn(splitted: Seq[String], col: Column,
          colType: DataType, nullable: Boolean, newCol: Column): Column = {
        colType match {
          case colType: StructType if splitted.nonEmpty =>
            var modifiedFields: Seq[(String, Column)] = colType.fields
              .map { f =>
                var curCol = col.getField(f.name)
                if (f.name == splitted.head) {
                  curCol = recursiveAddNestedColumn(splitted.tail, curCol, f.dataType, f.nullable, newCol)
                }
                (f.name, curCol as f.name)
              }

            if (!modifiedFields.exists(_._1 == splitted.head)) {
              modifiedFields :+= ((splitted.head,
                nullableCol(col, createNestedStructs(splitted.tail, newCol)) as splitted.head))
            }

            var modifiedStruct: Column = struct(modifiedFields.map(_._2): _*)
            if (nullable) {
              modifiedStruct = nullableCol(col, modifiedStruct)
            }
            modifiedStruct

          case _ => createNestedStructs(splitted, newCol)
        }
      }

      private def addNestedColumn(df: DataFrame, newColName: String, newCol: Column): DataFrame = {
        if (newColName.contains('.')) {
          val splitted = newColName.split('.')

          val modifiedOrAdded: (String, Column) = df.schema.fields
            .find(_.name == splitted.head)
            .map(f => (f.name, recursiveAddNestedColumn(splitted.tail, col(f.name), f.dataType, f.nullable, newCol)))
            .getOrElse {
              (splitted.head, createNestedStructs(splitted.tail, newCol) as splitted.head)
            }

          df.withColumn(modifiedOrAdded._1, modifiedOrAdded._2)
        } else {
          // Top-level addition: the built-in Spark method handles this as-is.
          df.withColumn(newColName, newCol)
        }
      }

      implicit class ExtendedDataFrame(df: DataFrame) extends Serializable {
        /**
         * Adds (or replaces) a nested field in a DataFrame.
         *
         * @param newColName dot-separated nested field name
         * @param newCol     new column value
         */
        def withNestedColumn(newColName: String, newCol: Column): DataFrame = {
          DataFrameUtils.addNestedColumn(df, newColName, newCol)
        }
      }
    }

Feel free to improve it.

    val data = spark.sparkContext.parallelize(
      List("""{ "a1": 1, "a3": { "b1": 3, "b2": { "c1": 5, "c2": 6 } } }"""))
    val df: DataFrame = spark.read.json(data)
    val df2 = df.withNestedColumn("a3.b2.c3.d1", $"a3.b2")

It should produce:

    assertResult("struct<a1:bigint,a3:struct<b1:bigint,b2:struct<c1:bigint,c2:bigint,c3:struct<d1:struct<c1:bigint,c2:bigint>>>>>")(df2.schema.simpleString)
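Side note: on Spark 3.1 and later, Column has withField and dropFields built in, which cover this use case without a custom helper. A minimal sketch against the question's df:

    // Spark >= 3.1: replace/add a nested field in place, then drop the
    // leftover case variants, without rebuilding the struct manually.
    import org.apache.spark.sql.functions.{coalesce, col}

    val merged = coalesce(col("a.xx"), col("a.xX"), col("a.XX"))
    val fixed = df
      .withColumn("a", col("a").withField("xx", merged))
      .withColumn("a", col("a").dropFields("xX", "XX"))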

Source: https://habr.com/ru/post/1014203/

