385 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			385 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
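"""
ARFF format loading and saving module.

This module implements the book version [1] of the ARFF format. This means
there is no support for instance weights.

Known limitations:
  - This implementation does not parse dates.
  - Relational attributes are not supported.
  - Only single quotes are recognised for quoting names and values;
    double-quoted strings are not.

[1]: http://weka.wikispaces.com/ARFF+%28book+version%29
     (Wikispaces has since been shut down; the page may only be reachable
     through a web archive.)
"""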
 | |
| 
 | |
| import re
 | |
| from collections import namedtuple
 | |
| 
 | |
# Describes one attribute of the relation: ``name`` is the attribute's
# identifier and ``type`` is the field type object used to parse and format
# its values (e.g. Numeric, Text or Nominal).
Field = namedtuple('Field', 'name type')
 | |
| 
 | |
# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    'load',
    'save',
    'Field',
    'Numeric',
    'Text',
    'Nominal',
]
 | |
| 
 | |
| 
 | |
| #
 | |
| # Line type functions
 | |
| #
 | |
| 
 | |
def is_empty(line):
    """Return True when *line* contains nothing but whitespace."""
    return line.strip() == ''
 | |
| 
 | |
def is_comment(line):
    """Return True when *line* is an ARFF comment, i.e. starts with ``%``.

    Leading whitespace is not skipped, so an indented ``%`` does not count.
    """
    return line[:1] == '%'
 | |
| 
 | |
def format_comment(line):
    """Turn *line* into an ARFF comment by prefixing it with ``% ``."""
    return ''.join(('% ', line))
 | |
| 
 | |
def is_relation(line):
    """Return True when *line* starts with ``@relation`` (case-insensitive).

    Leading whitespace is not skipped.
    """
    keyword = '@relation'
    return line.lower()[:len(keyword)] == keyword
 | |
| 
 | |
def format_relation(name):
    """Return the ``@relation`` header line (with newline) for relation *name*."""
    return '@relation {}\n'.format(format_identifier(name))
 | |
| 
 | |
def is_attribute(line):
    """Return True when *line* starts with ``@attribute`` (case-insensitive).

    Leading whitespace is not skipped.
    """
    prefix = '@attribute'
    return line.lower()[:len(prefix)] == prefix
 | |
| 
 | |
def format_attribute(field):
    """Return the ``@attribute <name> <type>`` declaration line for *field*."""
    return '@attribute {} {}\n'.format(format_identifier(field.name), str(field.type))
 | |
| 
 | |
def format_attributes(fields):
    """Return the ``@attribute`` lines for all *fields*, concatenated in order."""
    return ''.join(format_attribute(field) for field in fields)
 | |
| 
 | |
def is_data(line):
    """Return True when *line* starts with the ``@data`` marker (case-insensitive).

    Leading whitespace is not skipped.
    """
    marker = '@data'
    return line.lower()[:len(marker)] == marker
 | |
| 
 | |
def format_data():
    """Return the ``@data`` line that separates the header from the rows."""
    marker_line = '@data\n'
    return marker_line
 | |
| 
 | |
def format_row(row, fields, sparse=False):
    """Format one data row according to *fields*.

    row -- a mapping from field name to value; names missing from it are
           looked up as None and rendered by the field type (``?`` for the
           built-in types)
    fields -- sequence of `Field` instances describing the columns, in order
    sparse -- when True, emit ``{index value,...}`` and leave out every value
              equal to the field type's default

    Returns the formatted line, including the trailing newline.
    """
    if not sparse:
        cells = [field.type.format(row.get(field.name)) for field in fields]
        return ','.join(cells) + '\n'

    entries = []
    for index, field in enumerate(fields):
        value = row.get(field.name)
        # Values equal to the default are implied by the sparse notation.
        if value != field.type.default():
            entries.append(format_numeric(index) + ' ' + field.type.format(value))
    return '{' + ','.join(entries) + '}\n'
 | |
| 
 | |
| 
 | |
def safe_next(it):
    """Return the next item of the iterator *it*, or '' once it is exhausted."""
    return next(it, '')
 | |
| 
 | |
| 
 | |
def whitespace(rest):
    """Skip any leading whitespace and return the remaining input."""
    remainder = rest.lstrip()
    return remainder
 | |
| 
 | |
| 
 | |
# An optionally signed decimal number with an optional fraction and exponent,
# e.g. "42", "-1.5", ".5", "3e-2".
number_pattern = re.compile(r'[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?')


def numeric(rest):
    """Parse a number at the start of *rest*.

    Returns ``(value, rest)``. *value* is an int when the matched text is a
    valid integer literal and a float otherwise. *rest* is the input that
    follows the number.

    Raises ValueError if *rest* does not start with a number.
    """
    found = number_pattern.match(rest)
    if found is None:
        raise ValueError('Number not parsable')
    token = found.group(0)
    try:
        value = int(token)
    except ValueError:
        value = float(token)
    return value, rest[found.end():]
 | |
| 
 | |
def format_numeric(number):
    """Return the ARFF text for *number*, which is simply its ``str`` form."""
    text = str(number)
    return text
 | |
| 
 | |
def expect(rest, string):
    """Try to consume *string* from the start of *rest*.

    Returns ``(True, remainder)`` when *rest* starts with *string*. Otherwise
    returns ``(False, rest)`` with the input unchanged.
    """
    if not rest.startswith(string):
        return False, rest
    return True, rest[len(string):]
 | |
| 
 | |
| 
 | |
# Characters that may follow a backslash inside a quoted identifier, mapped to
# the character they stand for. Unknown escapes are kept verbatim (backslash
# included) by `identifier`.
identifier_escapes = {
    '\\': '\\',
    'n': '\n',
    't': '\t',
    'r': '\r',
    '%': '%',
    "'": "'",
}

# Reverse of `identifier_escapes` (character -> escape letter). It is built
# once here rather than on every `format_identifier` call.
_identifier_reverse_escapes = {c: ec for ec, c in identifier_escapes.items()}

# Characters that force quoting without needing an escape. Space and comma
# end an unquoted identifier. Braces delimit sparse rows and nominal
# specifications; a leading '{' would make a dense row look sparse.
_identifier_quote_chars = frozenset(' ,{}')


def identifier(rest):
    """Parse an optionally quoted identifier at the start of *rest*.

    An unquoted identifier runs up to (not including) the first space, tab or
    comma. A quoted identifier is enclosed in single quotes and may contain
    the backslash escapes listed in `identifier_escapes`. Any other escaped
    character is kept together with its backslash.

    Returns a ``(name, rest)`` tuple. *rest* is the input that follows the
    identifier; for a quoted identifier, that is the input after the closing
    quote.

    Raises ValueError if the input ends inside an escape sequence or before
    the closing quote of a quoted identifier.
    """
    if not rest.startswith("'"):
        end = 0
        while end < len(rest) and rest[end] not in ' \t,':
            end += 1
        return rest[:end], rest[end:]

    chars = []
    pos = 1  # skip the opening quote
    while pos < len(rest):
        c = rest[pos]
        if c == "'":
            return ''.join(chars), rest[pos + 1:]
        if c == '\\':
            if pos + 1 == len(rest):
                raise ValueError('Input end during escape.')
            ec = rest[pos + 1]
            chars.append(identifier_escapes.get(ec, '\\' + ec))
            pos += 2
        else:
            chars.append(c)
            pos += 1
    # Previously an unterminated quote was silently accepted.
    raise ValueError('Input end inside quoted identifier.')


def format_identifier(name):
    """Format *name* so that `identifier` parses it back unchanged.

    A name is written bare when that is unambiguous. It is quoted, with the
    characters from `identifier_escapes` backslash-escaped, when it contains
    a separator, a brace or an escapable character. It is also quoted when it
    is empty or starts with ``?``. A bare ``?`` prefix would be read back as
    a missing value by the field types, and an empty bare name cannot be
    parsed in an ``@attribute`` line.
    """
    needs_quotes = (
        any(c in _identifier_quote_chars or c in _identifier_reverse_escapes
            for c in name)
        or name == ''
        or name[:1] == '?'
    )
    if not needs_quotes:
        return name
    escaped = ''.join(
        '\\' + _identifier_reverse_escapes[c] if c in _identifier_reverse_escapes else c
        for c in name
    )
    return "'" + escaped + "'"
 | |
| 
 | |
class Numeric:
    """Numeric field type (ARFF ``numeric``, ``integer`` or ``real``).

    Parsed values are ints or floats; ``?`` stands for a missing value (None).
    """

    def parse(self, rest):
        """Parse a number or ``?`` at the start of *rest*; return ``(value, rest)``."""
        if rest[:1] == '?':
            return None, rest[1:]
        return numeric(rest)

    def format(self, number):
        """Return *number* as text, or ``?`` when it is None."""
        return '?' if number is None else format_numeric(number)

    def default(self):
        """Return the value implied for this field when a sparse row omits it."""
        return 0

    def __repr__(self):
        return 'Numeric'

    def __str__(self):
        return 'numeric'
 | |
| 
 | |
| 
 | |
class Text:
    """String field type (ARFF ``string``); ``?`` stands for a missing value (None)."""

    def parse(self, rest):
        """Parse an optionally quoted string or ``?`` at the start of *rest*."""
        if rest[:1] == '?':
            return None, rest[1:]
        return identifier(rest)

    def format(self, name):
        """Return *name* as an ARFF identifier, or ``?`` when it is None."""
        return '?' if name is None else format_identifier(name)

    def default(self):
        """Return the value implied for this field when a sparse row omits it."""
        return ''

    def __repr__(self):
        return 'Text'

    def __str__(self):
        return 'string'
 | |
| 
 | |
| 
 | |
class Nominal:
    """Nominal field type: every value must be one of a fixed list of names."""

    def __init__(self, names):
        # Kept as given (not copied). The first entry is the sparse-row default.
        self.values = names

    def parse(self, rest):
        """Parse a declared name or ``?`` at the start of *rest*.

        Raises ValueError if the parsed name is not one of ``self.values``.
        """
        if rest[:1] == '?':
            return None, rest[1:]
        name, remainder = identifier(rest)
        if name not in self.values:
            raise ValueError('Unknown nominal constant "{}" for {}.'.format(name, self.values))
        return name, remainder

    def format(self, name):
        """Return *name* as an ARFF identifier, or ``?`` when it is None.

        Raises ValueError if *name* is not one of ``self.values``.
        """
        if name is None:
            return '?'
        if name not in self.values:
            raise ValueError('Unknown nominal constant "{}" for {}.'.format(name, self.values))
        return format_identifier(name)

    def default(self):
        """Return the value implied for this field when a sparse row omits it."""
        return self.values[0]

    def __repr__(self):
        return 'Nominal in {}'.format(self.values)

    def __str__(self):
        quoted = [format_identifier(name) for name in self.values]
        return '{' + ', '.join(quoted) + '}'
 | |
| 
 | |
| 
 | |
def attr_type(rest):
    """Parse an attribute type specification; the whole of *rest* is used.

    Recognised keywords (case-insensitive):
    ``numeric``, ``integer`` and ``real`` give `Numeric`; ``string`` gives
    `Text`; ``{a, b, ...}`` gives `Nominal` with the listed names in order.

    Raises NotImplementedError for ``date`` types. Raises ValueError for any
    other type, and for a nominal list whose values are not separated by
    commas.
    """
    lowered = rest.lower()
    if lowered in ('numeric', 'integer', 'real'):
        return Numeric()
    if lowered == 'string':
        return Text()
    if lowered.startswith('date'):
        raise NotImplementedError('date parsing is not implemented.')
    if rest.startswith('{') and rest.endswith('}'):
        names = []
        rest = rest[1:-1]
        while rest:
            rest = whitespace(rest)
            name, rest = identifier(rest)
            names.append(name)
            rest = whitespace(rest)
            seen, rest = expect(rest, ',')
            if not seen:
                # Anything left here means two values were not comma-separated;
                # previously the remaining values were silently dropped.
                if rest:
                    raise ValueError(
                        'Unexpected "{}" in nominal specification.'.format(rest))
                break
        return Nominal(names)
    raise ValueError('Unknown attribute type "{}"'.format(rest))
 | |
| 
 | |
| 
 | |
def parse_attribute(line):
    """Parse an ``@attribute`` declaration line.

    Expected shape: ``@attribute <name> <type>``. *name* is an optionally
    quoted identifier and *type* is everything after it (see `attr_type`).
    Returns a ``(name, field_type)`` tuple.
    """
    body = whitespace(line[len('@attribute'):].strip())
    name, body = identifier(body)
    field_type = attr_type(whitespace(body))
    return name, field_type
 | |
| 
 | |
| 
 | |
def parse_row(line, fields):
    """Parse one data row, written in either dense or sparse notation.

    line -- the raw data line; surrounding whitespace is ignored
    fields -- sequence of `Field` instances describing the columns, in order

    Returns a dict mapping each field name to its parsed value.

    Dense rows (``v1,v2,...``) provide one value per field, in order. A
    missing separating comma is tolerated. After the last field only a single
    trailing comma and/or a ``%`` comment may follow; anything else raises
    ValueError, so rows with more values than attributes are not silently
    truncated.

    Sparse rows (``{index value,...}``) list only some values; every field
    not mentioned gets its type's default. Indices must be integers in
    ``range(len(fields))`` and may appear at most once.

    Raises ValueError for malformed rows. The field types' ``parse`` methods
    may raise their own errors as well.
    """
    line = line.strip()
    values = {}

    if not line.startswith('{'):
        rest = line
        for position, field in enumerate(fields):
            if position:
                rest = whitespace(rest)
                # The separating comma is optional; a missing one is tolerated.
                _, rest = expect(rest, ',')
            rest = whitespace(rest)
            value, rest = field.type.parse(rest)
            values[field.name] = value
        rest = whitespace(rest)
        _, rest = expect(rest, ',')  # tolerate one trailing comma
        rest = whitespace(rest)
        if rest and not is_comment(rest):
            raise ValueError('Unexpected data "{}" after the last field.'.format(rest))
        return values

    todo = set(range(len(fields)))
    rest = line[1:-1].strip()
    first = True
    while rest:
        if not first:
            rest = whitespace(rest)
            seen, rest = expect(rest, ',')
            if not seen:
                break
        first = False
        rest = whitespace(rest)
        index, rest = numeric(rest)
        # numeric() can yield a float or a negative int. A negative index would
        # otherwise silently address a field counted from the end of the list.
        if not isinstance(index, int) or not 0 <= index < len(fields):
            raise ValueError('Invalid sparse index {!r}.'.format(index))
        if index not in todo:
            raise ValueError('Duplicate sparse index {}.'.format(index))
        field = fields[index]
        rest = whitespace(rest)
        value, rest = field.type.parse(rest)
        todo.remove(index)
        values[field.name] = value
    for index in todo:
        field = fields[index]
        values[field.name] = field.type.default()
    return values
 | |
| 
 | |
| 
 | |
def load(fileish):
    """Load a data set from an ARFF-formatted file-like object.

    This generator parses the header to learn the attributes, then yields one
    dict per data row, mapping attribute names to values (see `parse_row`).
    Blank lines and ``%`` comment lines are skipped in both sections. Leading
    whitespace before a header keyword or a comment marker is ignored. Other
    unrecognised header lines are skipped.

    fileish -- an iterable of text lines, e.g. a file opened in text mode
    """
    # One shared iterator: the header loop stops at @data, and the data loop
    # continues from the line after it.
    lines = iter(fileish)
    fields = []

    for line in lines:
        # Strip first so that indented keywords and comments are recognised.
        stripped = line.strip()
        if is_empty(stripped) or is_comment(stripped):
            continue
        if is_relation(stripped):
            # The relation name is not needed.
            continue
        if is_attribute(stripped):
            name, field_type = parse_attribute(stripped)
            fields.append(Field(name, field_type))
            continue
        if is_data(stripped):
            # End of header; every following non-blank, non-comment line is a row.
            break

    for line in lines:
        stripped = line.strip()
        if is_empty(stripped) or is_comment(stripped):
            continue
        yield parse_row(stripped, fields)
 | |
| 
 | |
def save(fileish, fields, rows, name='unnamed relation', sparse=False):
    """Write a data set to a file-like object in ARFF format.

    The output consists of, in order: the ``@relation`` line, a blank line,
    one ``@attribute`` line per field, another blank line, the ``@data``
    line, and then one line per row.

    fileish -- a writable file-like object (only its ``write`` method is used)
    fields -- a list of `Field` instances describing the columns
    rows -- any iterable yielding one dictionary per data row
    name -- the relation name, defaults to 'unnamed relation'
    sparse -- write rows in sparse ``{index value,...}`` notation, defaults to False
    """
    write = fileish.write
    write(format_relation(name))
    write('\n')
    write(format_attributes(fields))
    write('\n')
    write(format_data())
    for row in rows:
        write(format_row(row, fields, sparse))
 |