How Java Deserialization Works

Description

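This is the first post in my Java deserialization series. It walks through the general Java deserialization process. I wrote it at Qiandao Lake, and being surrounded by water on all sides really does make it easier to settle down and focus.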

Test code

import java.io.*;

public class serializeDemo {
    public static void main(String[] args) throws IOException, ClassNotFoundException {

        testClass demo = new testClass();

        // Serialize: write the object to test.ser
        FileOutputStream fos = new FileOutputStream("test.ser");
        ObjectOutputStream os = new ObjectOutputStream(fos);
        os.writeObject(demo);
        os.close();

        // Deserialize: read it back; this readObject call is the path analyzed below
        FileInputStream fileInputStream = new FileInputStream("test.ser");
        ObjectInputStream ois = new ObjectInputStream(fileInputStream);
        testClass t = (testClass) ois.readObject();
        ois.close();

        System.out.println(t.flag); // flag{test}
    }
}

class testClass implements Serializable {
    public String flag;

    public testClass() {
        flag = "flag{test}";
    }
}

[Figure: contents of the generated test.ser file]
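The screenshot of test.ser did not survive. As a stand-in, here is a minimal sketch, assuming JDK 8 or later, that serializes an object shaped like testClass into memory (instead of test.ser) and prints its first six bytes. HeaderDump, Sample, serialize and hexPrefix are illustrative names, not part of the original code. The prefix should be the stream magic AC ED and version 00 05, followed by TC_OBJECT (0x73) and TC_CLASSDESC (0x72), which are the type codes readObject0 and readClassDesc dispatch on in the analysis.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;

public class HeaderDump {
    // Same shape as testClass in the test code (hypothetical name)
    static class Sample implements Serializable {
        public String flag = "flag{test}";
    }

    // Hypothetical helper: serialize an object into memory instead of test.ser
    static byte[] serialize(Object o) {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
            oos.writeObject(o);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bos.toByteArray();
    }

    // Format the first n bytes as uppercase hex, separated by spaces
    static String hexPrefix(byte[] data, int n) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < n; i++) {
            if (i > 0) sb.append(' ');
            sb.append(String.format("%02X", data[i] & 0xFF));
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        byte[] data = serialize(new Sample());
        // AC ED = STREAM_MAGIC, 00 05 = STREAM_VERSION,
        // 73 = TC_OBJECT, 72 = TC_CLASSDESC
        System.out.println(hexPrefix(data, 6)); // AC ED 00 05 73 72
    }
}
```

After the class descriptor marker come the class name, serialVersionUID, flags and field descriptors, which is what readClassDesc consumes.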

Analysis

ObjectInputStream constructor

public ObjectInputStream(InputStream in) throws IOException {
    verifySubclass();
    bin = new BlockDataInputStream(in);
    handles = new HandleTable(10);
    vlist = new ValidationList();
    serialFilter = ObjectInputFilter.Config.getSerialFilter();
    enableOverride = false;
    readStreamHeader();
    bin.setBlockDataMode(true);
}

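// verifySubclass(): checks that this instance (which may be a subclass) can be constructed without violating security constraints
// bin: a BlockDataInputStream wrapping the underlying stream; it can read both plain data and block data
// handles: the handle table. Every object and class descriptor that is read gets a handle here (added with assign), so later back-references (TC_REFERENCE) resolve to the same object. In ObjectInputStream it also tracks per-handle dependencies and ClassNotFoundExceptions (arrays status, entries, deps); the hash table built from spine (buckets), next (collision chains) and objs belongs to ObjectOutputStream's HandleTable, which caches objects already written
// vlist: a ValidationList holding the callbacks registered via registerValidation; they run once the top-level readObject call completes
// serialFilter: the process-wide ObjectInputFilter, if one is configured
// readStreamHeader(): reads the stream header and checks the magic number (0xACED) and version (5)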
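Since readStreamHeader() runs inside the constructor, a malformed stream is rejected before readObject is ever called. Below is a small sketch of that behaviour, assuming JDK 8 or later; HeaderCheck and its helper methods are hypothetical names used only for this demonstration.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectStreamConstants;
import java.io.StreamCorruptedException;
import java.io.UncheckedIOException;

public class HeaderCheck {
    // Build a 4-byte header: 2-byte magic followed by 2-byte version, big-endian
    static byte[] header(short magic, short version) {
        return new byte[] {
            (byte) (magic >> 8), (byte) magic,
            (byte) (version >> 8), (byte) version
        };
    }

    static byte[] validHeader() {
        return header(ObjectStreamConstants.STREAM_MAGIC, ObjectStreamConstants.STREAM_VERSION);
    }

    // The constructor calls readStreamHeader() immediately, so a bad header
    // fails here with StreamCorruptedException
    static boolean acceptsHeader(byte[] header) {
        try {
            new ObjectInputStream(new ByteArrayInputStream(header));
            return true;
        } catch (StreamCorruptedException e) {
            return false;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(acceptsHeader(validHeader()));                                         // true
        System.out.println(acceptsHeader(header((short) 0, (short) 5)));                          // false: bad magic
        System.out.println(acceptsHeader(header(ObjectStreamConstants.STREAM_MAGIC, (short) 4))); // false: bad version
    }
}
```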

readObject

public final Object readObject()
    throws IOException, ClassNotFoundException
{
    if (enableOverride) {
        return readObjectOverride();
    }

    // if nested read, passHandle contains handle of enclosing object
    int outerHandle = passHandle;
    try {
        Object obj = readObject0(false);
        handles.markDependency(outerHandle, passHandle);
        ClassNotFoundException ex = handles.lookupException(passHandle);
        if (ex != null) {
            throw ex;
        }
        if (depth == 0) {
            vlist.doCallbacks();
        }
        return obj;
    } finally {
        passHandle = outerHandle;
        if (closed && depth == 0) {
            clear();
        }
    }
}
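The vlist.doCallbacks() call at depth == 0 means validation callbacks are deferred until the whole object graph has been read. A minimal sketch, assuming JDK 8 or later: Inner registers a callback with registerValidation from its own readObject, and the callback only fires after Outer's readObject has also finished. ValidationOrder, Inner, Outer and LOG are hypothetical names.

```java
import java.io.*;
import java.util.ArrayList;
import java.util.List;

public class ValidationOrder {
    static final List<String> LOG = new ArrayList<>();

    static class Inner implements Serializable {
        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.defaultReadObject();
            LOG.add("inner readObject");
            // Registered now, but only run by vlist.doCallbacks() once depth == 0
            in.registerValidation(() -> LOG.add("inner validateObject"), 0);
        }
    }

    static class Outer implements Serializable {
        Inner inner = new Inner();

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.defaultReadObject(); // reads the nested Inner through readObject0
            LOG.add("outer readObject");
        }
    }

    static List<String> roundTrip() {
        LOG.clear();
        try {
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
                oos.writeObject(new Outer());
            }
            try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()))) {
                ois.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
        return new ArrayList<>(LOG);
    }

    public static void main(String[] args) {
        // [inner readObject, outer readObject, inner validateObject]
        System.out.println(roundTrip());
    }
}
```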

readObject0

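It first checks oldMode, i.e. whether the stream is currently in block-data mode. If it is, leftover unread block data causes an OptionalDataException; otherwise block-data mode is switched off.

It then peeks at the next type code and dispatches on it in a switch. For TC_OBJECT it calls checkResolve(readOrdinaryObject(unshared)).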
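To make the dispatch concrete, here is a hypothetical, simplified mirror of the switch in readObject0. It only returns the name of the JDK method that handles each type code (it parses nothing), and it is fed the first type code of a few real serialized values. TypeCodeDispatch, handlerFor and firstTypeCode are illustrative names; the TC_* constants are the real ones from java.io.ObjectStreamConstants. JDK 8 or later is assumed.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamConstants;
import java.io.UncheckedIOException;

public class TypeCodeDispatch {
    // Hypothetical mirror of readObject0's switch: type code -> handling method
    static String handlerFor(byte tc) {
        switch (tc) {
            case ObjectStreamConstants.TC_NULL:          return "readNull";
            case ObjectStreamConstants.TC_REFERENCE:     return "readHandle";
            case ObjectStreamConstants.TC_CLASS:         return "readClass";
            case ObjectStreamConstants.TC_CLASSDESC:
            case ObjectStreamConstants.TC_PROXYCLASSDESC: return "readClassDesc";
            case ObjectStreamConstants.TC_STRING:
            case ObjectStreamConstants.TC_LONGSTRING:    return "readString";
            case ObjectStreamConstants.TC_ARRAY:         return "readArray";
            case ObjectStreamConstants.TC_ENUM:          return "readEnum";
            case ObjectStreamConstants.TC_OBJECT:        return "readOrdinaryObject";
            case ObjectStreamConstants.TC_EXCEPTION:     return "readFatalException";
            // Block-data codes and unknown bytes end in an exception in the JDK
            default:                                     return "unexpected";
        }
    }

    // Serialize o and return the first type code after the 4-byte header
    static byte firstTypeCode(Object o) {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
            oos.writeObject(o);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bos.toByteArray()[4];
    }

    public static void main(String[] args) {
        System.out.println(handlerFor(firstTypeCode(Integer.valueOf(1)))); // readOrdinaryObject
        System.out.println(handlerFor(firstTypeCode("hi")));               // readString
        System.out.println(handlerFor(firstTypeCode(null)));               // readNull
        System.out.println(handlerFor(firstTypeCode(new int[] {1})));      // readArray
    }
}
```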

readOrdinaryObject

private Object readOrdinaryObject(boolean unshared)
    throws IOException
{
    if (bin.readByte() != TC_OBJECT) {
        throw new InternalError();
    }
    // Read testClass's class descriptor from the stream; resolveClass (Class.forName) binds it to the local class
    ObjectStreamClass desc = readClassDesc(false);
    // Check whether this class may be deserialized
    desc.checkDeserialize();

    // forClass returns the local Class the descriptor represents, or null if no local class was found
    Class<?> cl = desc.forClass();
    if (cl == String.class || cl == Class.class
            || cl == ObjectStreamClass.class) {
        throw new InvalidClassException("invalid class descriptor");
    }

    Object obj;
    try {
        // Create a testClass instance whose fields are not populated yet. For a Serializable class this
        // runs the no-arg constructor of its first non-Serializable superclass (here Object), not testClass()
        obj = desc.isInstantiable() ? desc.newInstance() : null;
    } catch (Exception ex) {
        throw (IOException) new InvalidClassException(
            desc.forClass().getName(),
            "unable to create instance").initCause(ex);
    }

    passHandle = handles.assign(unshared ? unsharedMarker : obj);
    ClassNotFoundException resolveEx = desc.getResolveException();
    if (resolveEx != null) {
        handles.markException(passHandle, resolveEx);
    }

    // Externalizable classes read their own data; everything else goes through readSerialData
    if (desc.isExternalizable()) {
        readExternalData((Externalizable) obj, desc);
    } else {
        readSerialData(obj, desc);
    }

    handles.finish(passHandle);

    // If the class defines readResolve, its return value replaces obj
    if (obj != null &&
        handles.lookupException(passHandle) == null &&
        desc.hasReadResolveMethod())
    {
        Object rep = desc.invokeReadResolve(obj);
        if (unshared && rep.getClass().isArray()) {
            rep = cloneArray(rep);
        }
        if (rep != obj) {
            // Filter the replacement object
            if (rep != null) {
                if (rep.getClass().isArray()) {
                    filterCheck(rep.getClass(), Array.getLength(rep));
                } else {
                    filterCheck(rep.getClass(), -1);
                }
            }
            handles.setObject(passHandle, obj = rep);
        }
    }

    return obj;
}
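Two consequences of readOrdinaryObject are easy to check from the outside: desc.newInstance() does not run the Serializable class's own constructor, and readResolve can replace the freshly created object. A sketch under those assumptions, for JDK 8 or later; InstantiationDemo, Base, Child, Singleton, LOG and roundTrip are hypothetical names.

```java
import java.io.*;
import java.util.ArrayList;
import java.util.List;

public class InstantiationDemo {
    static final List<String> LOG = new ArrayList<>();

    // Not Serializable: desc.newInstance() runs this no-arg constructor
    static class Base {
        Base() { LOG.add("Base()"); }
    }

    // Serializable subclass: its own constructor is skipped when reading
    static class Child extends Base implements Serializable {
        String flag;

        Child() {
            LOG.add("Child()");
            flag = "flag{test}";
        }
    }

    // readResolve: readOrdinaryObject swaps the new instance for the returned one
    static class Singleton implements Serializable {
        static final Singleton INSTANCE = new Singleton();
        private Singleton() {}
        private Object readResolve() { return INSTANCE; }
    }

    // Serialize, clear LOG, then deserialize, so LOG only records what reading ran
    static Object roundTrip(Object o) {
        try {
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
                oos.writeObject(o);
            }
            LOG.clear();
            try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()))) {
                return ois.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    public static void main(String[] args) {
        Child copy = (Child) roundTrip(new Child());
        System.out.println(LOG);       // [Base()]
        System.out.println(copy.flag); // flag{test}, restored from the stream
        System.out.println(roundTrip(Singleton.INSTANCE) == Singleton.INSTANCE); // true
    }
}
```

This is also why a payload class's constructor never matters during deserialization: field values come from the stream, and only readObject/readResolve style hooks run class code.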

readClassDesc

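readClassDesc reads the next type code from test.ser. On TC_CLASSDESC it calls readNonProxyDesc, which reads the class descriptor from the stream and passes it to resolveClass.

resolveClass loads the local testClass class with Class.forName (without initializing it), and the result is bound to the descriptor (ObjectStreamClass).

The concrete call chain is: readOrdinaryObject → readClassDesc → readNonProxyDesc → resolveClass → Class.forName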
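Because resolveClass sits between reading the descriptor and creating an instance, overriding it is a common way to restrict which classes may be deserialized. The following is a minimal allow-list sketch, not a hardened defence, assuming JDK 8 or later; AllowListDemo, AllowListObjectInputStream, Allowed and canRead are hypothetical names.

```java
import java.io.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Set;

public class AllowListDemo {
    static class Allowed implements Serializable {
        String flag = "flag{test}";
    }

    // Hypothetical allow-list stream: readNonProxyDesc calls resolveClass for
    // every TC_CLASSDESC, before readOrdinaryObject creates any instance
    static class AllowListObjectInputStream extends ObjectInputStream {
        private final Set<String> allowed;

        AllowListObjectInputStream(InputStream in, Set<String> allowed) throws IOException {
            super(in);
            this.allowed = allowed;
        }

        @Override
        protected Class<?> resolveClass(ObjectStreamClass desc)
                throws IOException, ClassNotFoundException {
            if (!allowed.contains(desc.getName())) {
                throw new InvalidClassException(desc.getName(), "class not in allow-list");
            }
            return super.resolveClass(desc); // default path ends in Class.forName
        }
    }

    // Returns true if o can be read back when only Allowed is permitted
    static boolean canRead(Object o) {
        try {
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
                oos.writeObject(o);
            }
            Set<String> allowList = Collections.singleton(Allowed.class.getName());
            try (ObjectInputStream ois = new AllowListObjectInputStream(
                    new ByteArrayInputStream(bos.toByteArray()), allowList)) {
                ois.readObject();
                return true;
            }
        } catch (InvalidClassException e) {
            return false;
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(canRead(new Allowed()));           // true
        System.out.println(canRead(new ArrayList<String>())); // false
    }
}
```

Note that the String field does not need to be allowed: its value is written as TC_STRING, so no class descriptor for String reaches resolveClass.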

readSerialData

private void readSerialData(Object obj, ObjectStreamClass desc)
    throws IOException
{
    ObjectStreamClass.ClassDataSlot[] slots = desc.getClassDataLayout();
    for (int i = 0; i < slots.length; i++) {
        ObjectStreamClass slotDesc = slots[i].desc;

        if (slots[i].hasData) {
            if (obj == null || handles.lookupException(passHandle) != null) {
                defaultReadFields(null, slotDesc); // skip field values
            } else if (slotDesc.hasReadObjectMethod()) {
                ThreadDeath t = null;
                boolean reset = false;
                SerialCallbackContext oldContext = curContext;
                if (oldContext != null)
                    oldContext.check();
                try {
                    curContext = new SerialCallbackContext(obj, slotDesc);

                    bin.setBlockDataMode(true);
                    slotDesc.invokeReadObject(obj, this);
                } catch (ClassNotFoundException ex) {
                    /*
                     * In most cases, the handle table has already
                     * propagated a CNFException to passHandle at this
                     * point; this mark call is included to address cases
                     * where the custom readObject method has cons'ed and
                     * thrown a new CNFException of its own.
                     */
                    handles.markException(passHandle, ex);
                } finally {
                    do {
                        try {
                            curContext.setUsed();
                            if (oldContext != null)
                                oldContext.check();
                            curContext = oldContext;
                            reset = true;
                        } catch (ThreadDeath x) {
                            t = x; // defer until reset is true
                        }
                    } while (!reset);
                    if (t != null)
                        throw t;
                }

                /*
                 * defaultDataEnd may have been set indirectly by custom
                 * readObject() method when calling defaultReadObject() or
                 * readFields(); clear it to restore normal read behavior.
                 */
                defaultDataEnd = false;
            } else {
                defaultReadFields(obj, slotDesc);
            }

            if (slotDesc.hasWriteObjectData()) {
                skipCustomData();
            } else {
                bin.setBlockDataMode(false);
            }
        } else {
            if (obj != null &&
                slotDesc.hasReadObjectNoDataMethod() &&
                handles.lookupException(passHandle) == null)
            {
                slotDesc.invokeReadObjectNoData(obj);
            }
        }
    }
}

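When the class being deserialized does not define its own readObject method, readSerialData falls back to defaultReadFields to assign the field values of the freshly created obj. This happens once per class slot, from the topmost Serializable superclass down to the class itself.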
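When a class does define readObject, readSerialData calls it through invokeReadObject instead, still walking the slots from superclass to subclass. A minimal sketch, assuming JDK 8 or later, with hypothetical Parent and Child classes: each readObject calls defaultReadObject (which reads that slot's fields through defaultReadFields), and Child then rebuilds a transient field that the stream does not carry. A custom readObject like this is also exactly the kind of hook that deserialization gadget chains start from.

```java
import java.io.*;
import java.util.ArrayList;
import java.util.List;

public class CustomReadObjectDemo {
    static final List<String> LOG = new ArrayList<>();

    static class Parent implements Serializable {
        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.defaultReadObject();
            LOG.add("Parent.readObject");
        }
    }

    static class Child extends Parent {
        String flag = "flag{test}";
        transient int flagLength = flag.length(); // transient: never written to the stream

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.defaultReadObject();     // reads this slot's fields via defaultReadFields
            flagLength = flag.length(); // rebuild state the stream does not carry
            LOG.add("Child.readObject");
        }
    }

    // Serialize, clear LOG, then deserialize
    static Child roundTrip(Child c) {
        try {
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
                oos.writeObject(c);
            }
            LOG.clear();
            try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()))) {
                return (Child) ois.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    public static void main(String[] args) {
        Child copy = roundTrip(new Child());
        System.out.println(LOG);             // [Parent.readObject, Child.readObject]
        System.out.println(copy.flagLength); // 10
    }
}
```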

defaultReadFields

private void defaultReadFields(Object obj, ObjectStreamClass desc)
    throws IOException
{
    Class<?> cl = desc.forClass();
    if (cl != null && obj != null && !cl.isInstance(obj)) {
        throw new ClassCastException();
    }

    int primDataSize = desc.getPrimDataSize();
    if (primVals == null || primVals.length < primDataSize) {
        primVals = new byte[primDataSize];
    }
    bin.readFully(primVals, 0, primDataSize, false);
    if (obj != null) {
        desc.setPrimFieldValues(obj, primVals);
    }

    int objHandle = passHandle;
    ObjectStreamField[] fields = desc.getFields(false);
    Object[] objVals = new Object[desc.getNumObjFields()];
    int numPrimFields = fields.length - objVals.length;
    for (int i = 0; i < objVals.length; i++) {
        ObjectStreamField f = fields[numPrimFields + i];
        objVals[i] = readObject0(f.isUnshared());
        if (f.getField() != null) {
            handles.markDependency(objHandle, passHandle);
        }
    }
    if (obj != null) {
        desc.setObjFieldValues(obj, objVals);
    }
    passHandle = objHandle;
}

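The key logic of this method: all primitive fields are read in one go as raw bytes (primVals) and assigned with setPrimFieldValues, while each object-typed field is read by a recursive call to readObject0 and then assigned with setObjFieldValues.

Once every field has been read and assigned, deserialization of the object is complete.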
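The split between primVals and objVals relies on the field order inside the descriptor: primitive fields come first, then object fields, each group sorted by name, which is why fields[numPrimFields + i] picks out the object fields. The public ObjectStreamClass.lookup(...).getFields() exposes the same order. The sketch below, assuming JDK 8 or later, prints that order and recomputes the primitive data size from the type codes; FieldLayoutDemo, Sample and primDataSize are hypothetical names, and primDataSize only mirrors the private getPrimDataSize, it is not that method.

```java
import java.io.ObjectStreamClass;
import java.io.ObjectStreamField;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

public class FieldLayoutDemo {
    static class Sample implements Serializable {
        int b;
        String a;
        long c;
        Object d;
        transient int skipped;   // transient: not part of the serial form
        static int alsoSkipped;  // static: not part of the serial form
    }

    // Field names in descriptor order: primitives first, then objects, each by name
    static List<String> fieldOrder(Class<?> cl) {
        List<String> names = new ArrayList<>();
        for (ObjectStreamField f : ObjectStreamClass.lookup(cl).getFields()) {
            names.add(f.getName());
        }
        return names;
    }

    // Hypothetical re-computation of the bytes read into primVals
    static int primDataSize(Class<?> cl) {
        int size = 0;
        for (ObjectStreamField f : ObjectStreamClass.lookup(cl).getFields()) {
            switch (f.getTypeCode()) {
                case 'Z': case 'B': size += 1; break;
                case 'C': case 'S': size += 2; break;
                case 'I': case 'F': size += 4; break;
                case 'J': case 'D': size += 8; break;
                default: break; // 'L' and '[' fields are read via readObject0
            }
        }
        return size;
    }

    public static void main(String[] args) {
        System.out.println(fieldOrder(Sample.class));   // [b, c, a, d]
        System.out.println(primDataSize(Sample.class)); // 12 (int 4 + long 8)
    }
}
```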

Summary

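In Java deserialization, the two central methods are readOrdinaryObject and readSerialData. readOrdinaryObject creates an instance whose fields are not populated yet and picks the matching field-reading path (readExternalData for Externalizable classes, readSerialData otherwise). readSerialData is Java's default field-reading routine: it invokes the class's own readObject when one is defined and otherwise falls back to defaultReadFields. Because field values are themselves stored in serialized form, defaultReadFields reads them by calling readObject0 recursively, which drives the field-reading process for the whole object graph.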
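The recursion through readObject0 works together with the handles table set up in the constructor: an object that appears twice in the graph is written in full once and then as a TC_REFERENCE, which readHandle resolves back to the same instance. A minimal sketch, assuming JDK 8 or later; SharedReferenceDemo, Holder, Payload and roundTrip are hypothetical names.

```java
import java.io.*;

public class SharedReferenceDemo {
    static class Payload implements Serializable {
        String flag = "flag{test}";
    }

    static class Holder implements Serializable {
        Payload first;
        Payload second;
    }

    // Serialize and immediately deserialize a Holder
    static Holder roundTrip(Holder h) {
        try {
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
                oos.writeObject(h);
            }
            try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()))) {
                return (Holder) ois.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    public static void main(String[] args) {
        Holder h = new Holder();
        h.first = new Payload();
        h.second = h.first; // the same object referenced twice
        Holder copy = roundTrip(h);
        System.out.println(copy.first == copy.second); // true: second field read via TC_REFERENCE
        System.out.println(copy.first == h.first);     // false: a freshly created object
    }
}
```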



Permalink: http://blog.lousix.top/2023/01/31/反序列化原理/
Author: Lousix
Published: January 31, 2023