SparkConf conf = new SparkConf(); conf.setMaster("local"); conf.setAppName("udf"); JavaSparkContext sc = new JavaSparkContext(conf); SQLContext sqlContext = new SQLContext(sc); JavaRDD<String> parallelize = sc.parallelize(Arrays.asList("zhangsan","lisi","wangwu")); JavaRDD<Row> rowRDD = parallelize.map(new Function<String, Row>() { private static final long serialVersionUID = 1L; @Override public Row call(String s) throws Exception { return RowFactory.create(s); } }); /** * 动态创建Schema方式加载DF */ List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField("name", DataTypes.StringType,true)); StructType schema = DataTypes.createStructType(fields); DataFrame df = sqlContext.createDataFrame(rowRDD,schema); df.registerTempTable("user"); /** * 根据UDF函数参数的个数来决定是实现哪一个UDF UDF1,UDF2。。。。UDF1xxx */ sqlContext.udf().register("StrLen",new UDF2<String, Integer, Integer>() { private static final long serialVersionUID = 1L; @Override public Integer call(String t1, Integer t2) throws Exception { return t1.length() + t2; } } ,DataTypes.IntegerType ); sqlContext.sql("select name ,StrLen(name,100) as length from user").show(); sc.stop();
SparkConf conf = new SparkConf(); conf.setMaster("local").setAppName("udaf"); conf.set("spark.sql.shuffle.partitions", "1"); JavaSparkContext sc = new JavaSparkContext(conf); SQLContext sqlContext = new SQLContext(sc); JavaRDDparallelize = sc.parallelize( Arrays.asList("zhangsan", "lisi", "wangwu", "zhangsan", "zhangsan", "lisi","zhangsan", "lisi", "wangwu", "zhangsan", "zhangsan", "lisi"),2); JavaRDDrowRDD = parallelize.map(new Function() { private static final long serialVersionUID = 1L; @Override public Row call(String s) throws Exception { return RowFactory.create(s); } }); Listfields = new ArrayList(); fields.add(DataTypes.createStructField("name", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); DataFrame df = sqlContext.createDataFrame(rowRDD, schema); df.registerTempTable("user");
注册一个UDAF函数,实现统计相同值的个数,注意:这里可以自定义一个类继承UserDefinedAggregateFunction类
sqlContext.udf().register("StringCount", new UserDefinedAggregateFunction() { private static final long serialVersionUID = 1L; /** * 初始化一个内部的自己定义的值,在Aggregate之前每组数据的初始化结果 */ @Override public void initialize(MutableAggregationBuffer buffer) { buffer.update(0, 0); System.out.println("init ....." + buffer.get(0)); } /** * 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑 * buffer.getInt(0)获取的是上一次聚合后的值 * 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合 * 大聚和发生在reduce端. * 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算 */ @Override public void update(MutableAggregationBuffer buffer, Row arg1) { buffer.update(0, buffer.getInt(0) + 1); System.out.println("update.....buffer" + buffer.toString() + " | row" + arg1.toString() ); } /** * 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理 * 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来 * buffer1.getInt(0) : 大聚合的时候 上一次聚合后的值 * buffer2.getInt(0) : 这次计算传入进来的update的结果 * 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作 */ public void merge(MutableAggregationBuffer buffer1, Row arg1) { // 2 3 4 5 6 7 // 0 + 2 = 2 // 2 + 3 = 5 // 5 + 4 = 9 buffer1.update(0, buffer1.getInt(0) + arg1.getInt(0)); System.out.println("merge.....buffer : " + buffer1.toString() + "| row" + arg1.toString() ); } /** * 在进行聚合操作的时候所要处理的数据的结果的类型 */ @Override public StructType bufferSchema() { return DataTypes.createStructType(Arrays.asList(DataTypes.createStructField("bffer", DataTypes.IntegerType, true))); } /** * 最后返回一个和DataType的类型要一致的类型,返回UDAF最后的计算结果 */ @Override public Object evaluate(Row row) { return row.getInt(0); } /** * 指定UDAF函数计算后返回的结果类型 */ @Override public DataType dataType() { return DataTypes.IntegerType; } /** * 指定输入字段的字段及类型 */ @Override public StructType inputSchema() { return DataTypes.createStructType(Arrays.asList(DataTypes.createStructField("name", DataTypes.StringType, true))); } /** * 确保一致性 一般用true,用以标记针对给定的一组输入,UDAF是否总是生成相同的结果。 */ @Override public boolean deterministic() { return true; } }); sqlContext.sql("select name ,StringCount(name) as number from user group by name").show(); sc.stop();