One of the basic principles of how MapReduce processes data is to divide the input into splits and read it split by split, with each split handled by one Mapper. Note that an input split is only a logical division, unlike an HDFS block, which is a physical division of the data. The InputFormat class abstracts these two operations, splitting and reading, and leaves the concrete implementation to subclasses; besides the subclasses Hadoop provides by default, we can write our own extensions as needed.
The figure below lists the core classes involved in reading data for MapReduce, together with several common extensions.
As shown in the figure, InputFormat declares two abstract methods: getSplits(), which creates the splits, and createRecordReader(), which creates the reader that actually pulls the data. By extending InputFormat and overriding these two methods we can read from a different data source, or read the same source in a different way.
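For reference, the InputFormat contract in the new (org.apache.hadoop.mapreduce) API looks roughly like this:
public abstract class InputFormat<K, V> {
    // split the input into logical pieces, one per Mapper
    public abstract List<InputSplit> getSplits(JobContext context)
            throws IOException, InterruptedException;
    // create the reader that turns one split into key/value records
    public abstract RecordReader<K, V> createRecordReader(InputSplit split,
            TaskAttemptContext context) throws IOException, InterruptedException;
}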
InputSplit represents a logical slice of the data. The most common implementation is FileSplit, used for slices of files: it extends InputSplit and carries the file path, the byte offset of the slice within the source file, the length of the slice in bytes, and the data nodes that hold the blocks the slice belongs to. Because split information is serialized to a file when the client submits the job and deserialized again while the job runs, FileSplit also implements the Writable interface, i.e. the two methods write(DataOutput out) and readFields(DataInput in).
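Simplified, the core of FileSplit is roughly the following sketch; the host list is only used for scheduling and is not serialized:
public class FileSplit extends InputSplit implements Writable {
    private Path file;      // path of the source file
    private long start;     // byte offset of the split within the file
    private long length;    // length of the split in bytes
    private String[] hosts; // data nodes holding this part of the file (not serialized)

    @Override
    public void write(DataOutput out) throws IOException {
        Text.writeString(out, file.toString());
        out.writeLong(start);
        out.writeLong(length);
    }
    // readFields(DataInput in) restores the same three fields
}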
What actually feeds input to MapReduce is the RecordReader: given a split, it reads the section of the data source that the split describes and packages the data into the key/value structure the Mapper expects.
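For reference, the RecordReader contract is roughly the following; the framework calls initialize() once per split and then loops over nextKeyValue()/getCurrentKey()/getCurrentValue() until the split is exhausted:
public abstract class RecordReader<KEYIN, VALUEIN> implements Closeable {
    public abstract void initialize(InputSplit split, TaskAttemptContext context)
            throws IOException, InterruptedException;
    public abstract boolean nextKeyValue() throws IOException, InterruptedException;
    public abstract KEYIN getCurrentKey() throws IOException, InterruptedException;
    public abstract VALUEIN getCurrentValue() throws IOException, InterruptedException;
    public abstract float getProgress() throws IOException, InterruptedException;
    public abstract void close() throws IOException;
}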
The readers Hadoop provides by default cover most needs; in special cases we can also extend them ourselves. The following simple example illustrates how.
The goal of the extension is to read data from several different types of databases (mysql, oracle, db2, and so on) and from several tables. The code is as follows:
public class MultiTableInputSplit extends DBInputFormat.DBInputSplit implements Writable{
// query SQL
private String inputSql;
// database connection URL
private String dbConUrl;
// database user name and password
private String userName;
private String passWord;
// database type (mysql, oracle, db2, ...)
private String dbType;
// the framework creates split instances via reflection at run time, so a no-arg constructor is required
public MultiTableInputSplit(){
}
public MultiTableInputSplit(long start, long end, String inputSql,
String dbConUrl, String userName, String passWord, String dbType) {
super(start, end);
this.inputSql = inputSql;
this.dbConUrl = dbConUrl;
this.userName = userName;
this.passWord = passWord;
this.dbType = dbType;
}
@Override
public void write(DataOutput output) throws IOException {
super.write(output);
output.writeUTF(inputSql);
output.writeUTF(dbConUrl);
output.writeUTF(userName);
output.writeUTF(passWord);
output.writeUTF(dbType);
}
@Override
public void readFields(DataInput input) throws IOException {
super.readFields(input);
this.inputSql = input.readUTF();
this.dbConUrl = input.readUTF();
this.userName = input.readUTF();
this.passWord = input.readUTF();
this.dbType = input.readUTF();
}
// getters and setters omitted
}
The MultiTableInputSplit class indirectly extends InputSplit, adding the database connection information and the SQL statement used to query the data.
public class MultiTableInputFormat extends InputFormat<LongWritable, MapDBWritable>{
/**
* Creates the splits, one table at a time.
* Format of a single sqlInfo string: its fields are separated by "##", e.g.
* dbType##driver##url##user##password##sql##counts
* Multiple sqlInfo strings are separated by "#_#", e.g.
* sqlInfo1#_#sqlInfo2
* (a concrete example follows after this class)
*/
@Override
public List<InputSplit> getSplits(JobContext job) throws IOException {
List<InputSplit> splits = new ArrayList<InputSplit>();
String inputQuery = job.getConfiguration().get("sqlInfo");
String[] inputQueryArr = inputQuery.split("#_#");
for(String sql : inputQueryArr){
getSplit(sql, job, splits);
}
return splits;
}
@Override
public RecordReader<LongWritable, MapDBWritable> createRecordReader(
InputSplit split, TaskAttemptContext context) throws IOException,
InterruptedException {
try {
return new MultiTableRecordReader((MultiTableInputSplit)split,
context.getConfiguration());
} catch (SQLException ex) {
throw new IOException(ex.getMessage());
}
}
/**
* The number of splits can be derived from the number and size of the tables.
*/
private int getSplitCount(String[] sqlInfo) {
return 1; // trivial implementation
}
/**
* Computes the size of a split.
*/
private int getSplitSize(String[] sqlInfo){
return 100000; // trivial implementation
}
public void getSplit(String inputQuery, JobContext job, List<InputSplit> splits){
String[] sql = inputQuery.split("##");
int recordCount = Integer.parseInt(sql[6]);
long splitCount = getSplitCount(sql);
long splitSize = getSplitSize(sql);
for (int i = 0; i < splitCount; i++) {
InputSplit split;
if (i + 1 == splitCount) {
split = new MultiTableInputSplit(i * splitSize, recordCount + 1,
sql[5], sql[2], sql[3], sql[4], sql[0]);
} else {
split = new MultiTableInputSplit(i * splitSize, i * splitSize + splitSize,
sql[5], sql[2], sql[3], sql[4], sql[0]);
}
splits.add(split);
}
}
}
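As a usage sketch, a single-table sqlInfo string for a MySQL source could be assembled and put into the job configuration as below; the connection details and query are made up for illustration, and job is the org.apache.hadoop.mapreduce.Job being configured:
// hypothetical connection details, for illustration only
String sqlInfo = "mysql"                        // dbType
        + "##com.mysql.jdbc.Driver"             // driver
        + "##jdbc:mysql://db1:3306/test"        // url
        + "##etl_user##etl_pass"                // user name and password
        + "##SELECT id, name FROM t_user"       // query sql
        + "##100000";                           // total record count of the table
// several sqlInfo strings would be joined with "#_#"
job.getConfiguration().set("sqlInfo", sqlInfo);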
class MultiTableRecordReader extends RecordReader<LongWritable, MapDBWritable> {
protected ResultSet results = null;
protected DBInputFormat.DBInputSplit split;
protected long pos = 0;
protected LongWritable key = null;
protected MapDBWritable value = null;
protected String dbType;
protected Connection connection;
protected Statement statement;
public MultiTableRecordReader(DBInputFormat.DBInputSplit split, Configuration conf)
throws SQLException {
this.split = split;
initConnection(); // set up the database connection
}
@Override
public boolean nextKeyValue() throws IOException {
try {
if (key == null) {
key = new LongWritable();
}
if (value == null) {
value = new MapDBWritable();
}
if (null == this.results) {
this.results = executeQuery(getSelectQuery());
}
if (!results.next()) {
return false;
}
key.set(pos + split.getStart());
value.readFields(results);
pos++;
} catch (SQLException e) {
throw new IOException("SQLException in nextKeyValue", e);
}
return true;
}
@Override
public LongWritable getCurrentKey() {
return key;
}
@Override
public MapDBWritable getCurrentValue() {
return value;
}
@Override
public float getProgress() throws IOException {
return pos / (float) split.getLength();
}
@Override
public void initialize(InputSplit split, TaskAttemptContext context)
throws IOException, InterruptedException {
// nothing to do here: the connection is opened in the constructor
}
/**
* Builds a paged query suited to the type of database being read.
*/
protected String getSelectQuery() {
StringBuilder query = new StringBuilder();
try {
DBInputFormat.DBInputSplit split = getSplit();
MultiTableInputSplit extSplit = (MultiTableInputSplit) split;
if (extSplit.getLength() > 0 && extSplit.getStart() >= 0) {
query.append(extSplit.getInputSql());
String dbType = ((MultiTableInputSplit)split).getDbType();
if(dbType.equalsIgnoreCase("TERADATA")){
query.append(" QUALIFY ROW_NUM>=").append(split.getStart())
.append(" AND ROW_NUM<").append(split.getEnd());
} else if(dbType.equalsIgnoreCase("ORACLE")){
query.append(" WHERE ROW_NUM>=").append(extSplit.getStart())
.append(" AND ROW_NUM<").append(extSplit.getEnd());
} else{
// LIMIT takes an offset and a row count, not an end position
query.append(" LIMIT ").append(extSplit.getStart())
.append(", ").append(extSplit.getEnd() - extSplit.getStart());
}
}
} catch (IOException ex) {
ex.printStackTrace();
}
return query.toString();
}
public void initConnection() throws SQLException{
MultiTableInputSplit sp = (MultiTableInputSplit)split;
String conUrl = sp.getDbConUrl();
String userName = sp.getUserName();
String passWord = sp.getPassWord();
connection = DriverManager.getConnection(conUrl, userName, passWord);
statement= connection.createStatement();
}
@Override
public void close() throws IOException {
try {
if (null != results) {
results.close();
}
if (null != statement) {
statement.close();
}
if (null != connection) {
connection.close();
}
} catch (SQLException e) {
throw new IOException(e.getMessage());
}
}
protected ResultSet executeQuery(String query) throws SQLException {
return statement.executeQuery(query);
}
public Connection getConnection() {
return connection;
}
public DBInputFormat.DBInputSplit getSplit() {
return split;
}
protected void setStatement(PreparedStatement stmt) {
this.statement = stmt;
}
}
public class MapDBWritable implements DBWritable{
/**
* column name ---> column value
*/
private Map<String, Object> values = null;
/**
* column names of the table
*/
private String[] colNames;
/**
* column name ---> column data type
*/
private Map<String, String> colType;
public MapDBWritable(){
}
public void readFields(ResultSet resultSet) throws SQLException {
ResultSetMetaData meta = resultSet.getMetaData();
int count = meta.getColumnCount();
colNames = new String[count];
values = new HashMap<String, Object>(count);
colType = new HashMap<String, String>(count);
for (int i = 0; i < count; i++) {
String colName = meta.getColumnName(i + 1);
colNames[i] = colName;
values.put(colName, resultSet.getString(colName));
colType.put(colName, meta.getColumnTypeName(i + 1));
}
}
@Override
public String toString() {
StringBuilder builder = new StringBuilder();
for (Map.Entry<String, Object> entry : this.values.entrySet()) {
builder.append(entry.getKey() + "->" + entry.getValue());
builder.append(";");
}
return builder.substring(0, builder.length() - 1);
}
@Override
public void write(PreparedStatement preparedstatement) throws SQLException {
// not needed: this writable is only used for reading input
}
public Map<String, Object> getValues() {
return values;
}
public void setValues(Map<String, Object> values) {
this.values = values;
}
public String[] getColNames() {
return colNames;
}
public void setColNames(String[] colNames) {
this.colNames = colNames;
}
public Map<String, String> getColType() {
return colType;
}
public void setColType(Map<String, String> colType) {
this.colType = colType;
}
}
MapDBWritable reads the column type information from the ResultSet into a map keyed by column name with the type as value, and reads the column values into a map keyed by column name with the value as value; the Mapper's input value type is MapDBWritable.
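A minimal Mapper over this input might look like the following sketch; the class name and output types are only illustrative, and it simply dumps each database row as a line of text:
public class TableDumpMapper extends Mapper<LongWritable, MapDBWritable, Text, Text> {
    @Override
    protected void map(LongWritable key, MapDBWritable value, Context context)
            throws IOException, InterruptedException {
        // key is the row position within the split; value.toString() renders "col->value;col->value"
        context.write(new Text(key.toString()), new Text(value.toString()));
    }
}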
By extending the classes above we can read data from multiple databases and multiple tables.
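To put it all together, a job driver could be wired up roughly as follows; TableDumpMapper is the hypothetical Mapper sketched above, sqlInfo1 and sqlInfo2 are strings in the format described earlier, and the output path is only an example:
Configuration conf = new Configuration();
Job job = Job.getInstance(conf, "multi-table import");
job.setJarByClass(MultiTableInputFormat.class);
// one sqlInfo per table, joined with "#_#"
job.getConfiguration().set("sqlInfo", sqlInfo1 + "#_#" + sqlInfo2);
job.setInputFormatClass(MultiTableInputFormat.class);
job.setMapperClass(TableDumpMapper.class);
job.setNumReduceTasks(0); // map-only job
job.setOutputKeyClass(Text.class);
job.setOutputValueClass(Text.class);
FileOutputFormat.setOutputPath(job, new Path("/tmp/table_dump"));
System.exit(job.waitForCompletion(true) ? 0 : 1);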