#!/usr/bin/env python

"""
Java class file decoder. Specification found at the following URL:
http://java.sun.com/docs/books/vmspec/2nd-edition/html/ClassFile.doc.html
"""

import struct

# Utility functions.

def u1(data):

    """
    Return the first byte of 'data' as an unsigned integer.
    Raises struct.error if 'data' is empty.
    """

    return struct.unpack(">B", data[0:1])[0]

def u2(data):

    """
    Return the first two bytes of 'data' as a big-endian unsigned integer.
    Raises struct.error if fewer than two bytes are available.
    """

    return struct.unpack(">H", data[0:2])[0]

def u4(data):

    """
    Return the first four bytes of 'data' as a big-endian unsigned integer.
    Raises struct.error if fewer than four bytes are available.
    """

    return struct.unpack(">L", data[0:4])[0]

# Constant information.
# Objects of these classes are not directly aware of the class they reside in.
# Each 'init' method reads the object's fields from the start of 'data' and
# returns the data remaining after those fields.

class ClassInfo:

    "A class constant, holding the index of its name."

    def init(self, data):
        self.name_index = u2(data[0:2])
        return data[2:]

class RefInfo:

    "A reference constant: a class index plus a name-and-type index."

    def init(self, data):
        self.class_index = u2(data[0:2])
        self.name_and_type_index = u2(data[2:4])
        return data[4:]

class FieldRefInfo(RefInfo):
    pass

class MethodRefInfo(RefInfo):
    pass

class InterfaceMethodRefInfo(RefInfo):
    pass

class NameAndTypeInfo:

    "A name-and-type constant: a name index plus a descriptor index."

    def init(self, data):
        self.name_index = u2(data[0:2])
        self.descriptor_index = u2(data[2:4])
        return data[4:]

class Utf8Info:

    "A text constant: a two-byte length followed by that many bytes."

    def init(self, data):
        self.length = u2(data[0:2])
        end = 2 + self.length
        self.bytes = data[2:end]
        return data[end:]

    # NOTE(review): the class file format describes a modified UTF-8 encoding;
    # the bytes are decoded here as standard UTF-8 - confirm this is acceptable
    # for the inputs of interest.

    def __str__(self):

        """
        Return the constant as a native string: the raw byte string where
        'str' is a byte string (Python 2), the decoded text otherwise.
        """

        if isinstance(self.bytes, str):
            return self.bytes
        return self.bytes.decode("utf-8")

    def __unicode__(self):

        "Return the constant's bytes decoded as UTF-8."

        return self.bytes.decode("utf-8")

class StringInfo:

    "A string constant, holding the index of its text."

    def init(self, data):
        self.string_index = u2(data[0:2])
        return data[2:]

class SmallNumInfo:

    """
    A four-byte numeric constant. The value is kept in 'bytes' as an unsigned
    32-bit integer; it is not converted to a signed integer or a float here.
    """

    def init(self, data):
        self.bytes = u4(data[0:4])
        return data[4:]

class IntegerInfo(SmallNumInfo):
    pass

class FloatInfo(SmallNumInfo):
    pass

class LargeNumInfo:

    """
    An eight-byte numeric constant, kept as two unsigned 32-bit integers:
    'high_bytes' (first four bytes) and 'low_bytes' (last four bytes).
    """

    def init(self, data):
        self.high_bytes = u4(data[0:4])
        self.low_bytes = u4(data[4:8])
        return data[8:]

class LongInfo(LargeNumInfo):
    pass

class DoubleInfo(LargeNumInfo):
    pass

# Other information.
# Objects of these classes are generally aware of the class they reside in.
# Text handling.
# Python 3 has no 'unicode' built-in; note whether it is available so that
# constant text can be produced on either major version.

try:
    _unicode = unicode
except NameError:
    _unicode = None

def _constant_text(constant):

    """
    Return the text of the given text 'constant' as a Unicode string. Where
    the 'unicode' built-in exists it is applied to the constant (as before);
    otherwise the constant's 'bytes' attribute is decoded as UTF-8.
    """

    if _unicode is not None:
        return _unicode(constant)
    return constant.bytes.decode("utf-8")

class ItemInfo:

    "Common behaviour of the field and method entries of a class file."

    def init(self, data, class_file):

        """
        Read the access flags, name index, descriptor index and attributes
        from the start of 'data', remembering 'class_file' for later constant
        lookups. Return the data remaining after the item.
        """

        self.class_file = class_file
        self.access_flags = u2(data[0:2])
        self.name_index = u2(data[2:4])
        self.descriptor_index = u2(data[4:6])
        self.attributes, data = self.class_file._get_attributes(data[6:])
        return data

    # Symbol parsing.
    # A field type is represented as a tuple (base_type, object_type,
    # array_type): 'object_type' is the class name for base type "L",
    # 'array_type' is the (nested) field type tuple of the component for base
    # type "[", and both are None otherwise.

    def _get_method_descriptor(self, s):

        """
        Parse the method descriptor 's', such as "([I)V". Return a tuple
        (params, return_type) where 'params' is a list of field type tuples
        and 'return_type' is a field type tuple, or None for "V" (void).
        """

        assert s[0] == "("
        params = []
        s = s[1:]
        while s[0] != ")":
            parameter_descriptor, s = self._get_parameter_descriptor(s)
            params.append(parameter_descriptor)
        # Here s[0] is ")" and s[1] starts the return descriptor.
        if s[1] != "V":
            return_type, s = self._get_field_type(s[1:])
        else:
            return_type, s = None, s[1:]
        return params, return_type

    def _get_parameter_descriptor(self, s):
        return self._get_field_type(s)

    def _get_field_descriptor(self, s):
        return self._get_field_type(s)

    def _get_component_type(self, s):
        return self._get_field_type(s)

    def _get_field_type(self, s):

        """
        Parse one field type from the start of 's'. Return a tuple
        ((base_type, object_type, array_type), remaining_text).
        """

        base_type, s = self._get_base_type(s)
        object_type = None
        array_type = None
        if base_type == "L":
            object_type, s = self._get_object_type(s)
        elif base_type == "[":
            array_type, s = self._get_array_type(s)
        return (base_type, object_type, array_type), s

    def _get_base_type(self, s):

        "Split off the first character of 's'; give None if 's' is empty."

        if len(s) > 0:
            return s[0], s[1:]
        else:
            return None, s

    def _get_object_type(self, s):

        """
        Given the text following an "L", return the class name up to the
        terminating ";" together with the text after that ";".
        """

        if len(s) > 0:
            s_end = s.find(";")
            assert s_end != -1
            return s[:s_end], s[s_end+1:]
        else:
            return None, s

    def _get_array_type(self, s):

        """
        Given the text following a "[", return the component's field type
        tuple together with the remaining text.
        """

        if len(s) > 0:
            # The "[" has already been consumed by _get_base_type, so 's'
            # starts at the component type; nothing more must be skipped.
            return self._get_component_type(s)
        else:
            return None, s

    # Processed details.

    def get_name(self):

        "Return the item's name, looked up in the class file's constants."

        return _constant_text(self.class_file.constants[self.name_index - 1])

class FieldInfo(ItemInfo):
    def get_descriptor(self):

        "Return the parsed field descriptor: (field_type, remaining_text)."

        return self._get_field_descriptor(_constant_text(self.class_file.constants[self.descriptor_index - 1]))

class MethodInfo(ItemInfo):
    def get_descriptor(self):

        "Return the parsed method descriptor: (params, return_type)."

        return self._get_method_descriptor(_constant_text(self.class_file.constants[self.descriptor_index - 1]))

class AttributeInfo:

    """
    A generic attribute: a four-byte length followed by that many bytes of
    undecoded content, kept in 'info'.
    """

    def init(self, data, class_file):

        """
        Read the attribute from 'data', which starts just after the attribute
        name index. Return the data remaining after the attribute.
        """

        self.attribute_length = u4(data[0:4])
        end = 4 + self.attribute_length
        self.info = data[4:end]
        return data[end:]

# NOTE: Decode the different attribute formats.

class SourceFileAttributeInfo(AttributeInfo):
    pass

class ConstantValueAttributeInfo(AttributeInfo):
    def init(self, data, class_file):
        self.attribute_length = u4(data[0:4])
        self.constant_value_index = u2(data[4:6])
        # The content is a single two-byte index, so the length must be 2.
        assert 4+self.attribute_length == 6
        return data[4+self.attribute_length:]

class CodeAttributeInfo(AttributeInfo):
    def init(self, data, class_file):

        """
        Read the stack and locals limits, the code bytes, the exception table
        and any nested attributes. Return the data remaining afterwards.
        """

        self.class_file = class_file
        self.attribute_length = u4(data[0:4])
        self.max_stack = u2(data[4:6])
        self.max_locals = u2(data[6:8])
        self.code_length = u4(data[8:12])
        end_of_code = 12+self.code_length
        self.code = data[12:end_of_code]
        self.exception_table_length = u2(data[end_of_code:end_of_code+2])
        self.exception_table = []
        data = data[end_of_code + 2:]
        for _ in range(self.exception_table_length):
            exception = ExceptionInfo()
            data = exception.init(data)
            # Keep each entry; otherwise the table would always be empty.
            self.exception_table.append(exception)
        self.attributes, data = self.class_file._get_attributes(data)
        return data

class ExceptionsAttributeInfo(AttributeInfo):
    def init(self, data, class_file):

        """
        Read the table of two-byte exception constant indexes. Return the
        data remaining after the table.
        """

        self.class_file = class_file
        self.attribute_length = u4(data[0:4])
        self.number_of_exceptions = u2(data[4:6])
        self.exception_index_table = []
        index = 6
        for _ in range(self.number_of_exceptions):
            self.exception_index_table.append(u2(data[index:index+2]))
            index += 2
        return data[index:]

    def get_exception(self, i):

        "Return the constant referenced by entry 'i' of the index table."

        exception_index = self.exception_index_table[i]
        return self.class_file.constants[exception_index - 1]

class InnerClassesAttributeInfo(AttributeInfo):
    pass

class SyntheticAttributeInfo(AttributeInfo):
    pass

class LineNumberAttributeInfo(AttributeInfo):
    pass

class LocalVariableAttributeInfo(AttributeInfo):
    pass

class DeprecatedAttributeInfo(AttributeInfo):
    pass

class ExceptionInfo:

    "One exception table entry of a code attribute (four two-byte values)."

    def __init__(self):
        self.start_pc, self.end_pc, self.handler_pc, self.catch_type = None, None, None, None

    def init(self, data):
        self.start_pc = u2(data[0:2])
        self.end_pc = u2(data[2:4])
        self.handler_pc = u2(data[4:6])
        self.catch_type = u2(data[6:8])
        return data[8:]

class UnknownTag(Exception):
    pass

class UnknownAttribute(Exception):
    pass

# Attribute names mapped to the classes which decode them.

_ATTRIBUTE_CLASSES = {
    "SourceFile": SourceFileAttributeInfo,
    "ConstantValue": ConstantValueAttributeInfo,
    "Code": CodeAttributeInfo,
    "Exceptions": ExceptionsAttributeInfo,
    "InnerClasses": InnerClassesAttributeInfo,
    "Synthetic": SyntheticAttributeInfo,
    "LineNumberTable": LineNumberAttributeInfo,
    "LocalVariableTable": LocalVariableAttributeInfo,
    "Deprecated": DeprecatedAttributeInfo,
}

# Abstractions for the main structures.

class ClassFile:

    "A class representing a Java class file."

    def __init__(self, s):

        """
        Process the given string 's', populating the object with the class
        file's details.
        """

        # The first eight bytes (the magic number and version fields of the
        # class file format) are skipped without being examined.
        self.constants, s = self._get_constants(s[8:])
        self.access_flags, s = self._get_access_flags(s)
        self.this_class, s = self._get_this_class(s)
        self.super_class, s = self._get_super_class(s)
        self.interfaces, s = self._get_interfaces(s)
        self.fields, s = self._get_fields(s)
        self.methods, s = self._get_methods(s)
        self.attributes, s = self._get_attributes(s)

    def _decode_const(self, s):

        """
        Decode the constant at the start of 's', choosing its class from the
        leading tag byte. Return (constant, remaining_data). Raise UnknownTag
        for a tag which is not recognised.
        """

        constant_classes = {
            1: Utf8Info,
            3: IntegerInfo,
            4: FloatInfo,
            5: LongInfo,
            6: DoubleInfo,
            7: ClassInfo,
            8: StringInfo,
            9: FieldRefInfo,
            10: MethodRefInfo,
            11: InterfaceMethodRefInfo,
            12: NameAndTypeInfo,
        }
        tag = u1(s[0:1])
        constant_class = constant_classes.get(tag)
        if constant_class is None:
            raise UnknownTag(tag)
        const = constant_class()
        s = const.init(s[1:])
        return const, s

    def _get_constants_from_table(self, count, s):

        """
        Decode the constants for a table whose declared size is 'count'.
        Return (constants, remaining_data). Position 0 of the list holds
        constant number 1, hence the 'index - 1' lookups elsewhere.
        """

        constants = []
        # Have to skip certain entries specially.
        i = 1
        while i < count:
            const, s = self._decode_const(s)
            constants.append(const)
            # Add a blank entry after "large" entries.
            if isinstance(const, LargeNumInfo):
                constants.append(None)
                i += 1
            i += 1
        return constants, s

    def _get_items_from_table(self, cls, number, s):

        "Decode 'number' items of class 'cls'; return (items, remaining_data)."

        items = []
        for _ in range(number):
            item = cls()
            s = item.init(s, self)
            items.append(item)
        return items, s

    def _get_methods_from_table(self, number, s):
        return self._get_items_from_table(MethodInfo, number, s)

    def _get_fields_from_table(self, number, s):
        return self._get_items_from_table(FieldInfo, number, s)

    def _get_attribute_from_table(self, s):

        """
        Decode the attribute at the start of 's', choosing its class from the
        attribute's name. Return (attribute, remaining_data). Raise
        UnknownAttribute for a name which is not recognised.
        """

        attribute_name_index = u2(s[0:2])
        constant_name = self.constants[attribute_name_index - 1].bytes
        # Where 'str' is not a byte string (Python 3), decode the name so that
        # it can be compared with the table's keys.
        if not isinstance(constant_name, str) and hasattr(constant_name, "decode"):
            constant_name = constant_name.decode("utf-8")
        attribute_class = _ATTRIBUTE_CLASSES.get(constant_name)
        if attribute_class is None:
            raise UnknownAttribute(constant_name)
        attribute = attribute_class()
        s = attribute.init(s[2:], self)
        return attribute, s

    def _get_attributes_from_table(self, number, s):

        "Decode 'number' attributes; return (attributes, remaining_data)."

        attributes = []
        for _ in range(number):
            attribute, s = self._get_attribute_from_table(s)
            attributes.append(attribute)
        return attributes, s

    def _get_constants(self, s):
        count = u2(s[0:2])
        return self._get_constants_from_table(count, s[2:])

    def _get_access_flags(self, s):
        return u2(s[0:2]), s[2:]

    def _get_this_class(self, s):

        "Return (constant for this class, remaining_data)."

        index = u2(s[0:2])
        return self.constants[index - 1], s[2:]

    def _get_super_class(self, s):

        """
        Return (constant for the superclass, remaining_data). An index of
        zero denotes no superclass, for which None is given instead of
        wrongly selecting the last constant via index -1.
        """

        index = u2(s[0:2])
        if index == 0:
            return None, s[2:]
        return self.constants[index - 1], s[2:]

    def _get_interfaces(self, s):

        "Return (list of interface constants, remaining_data)."

        interfaces = []
        number = u2(s[0:2])
        s = s[2:]
        for _ in range(number):
            index = u2(s[0:2])
            interfaces.append(self.constants[index - 1])
            s = s[2:]
        return interfaces, s

    def _get_fields(self, s):
        number = u2(s[0:2])
        return self._get_fields_from_table(number, s[2:])

    def _get_attributes(self, s):
        number = u2(s[0:2])
        return self._get_attributes_from_table(number, s[2:])

    def _get_methods(self, s):
        number = u2(s[0:2])
        return self._get_methods_from_table(number, s[2:])

if __name__ == "__main__":
    import sys
    # Class files are binary data: read them as such and close the file.
    f = open(sys.argv[1], "rb")
    try:
        c = ClassFile(f.read())
    finally:
        f.close()

# vim: tabstop=4 expandtab shiftwidth=4