385 lines
9.8 KiB
Python
385 lines
9.8 KiB
Python
"""
|
|
ARFF format loading and saving module.
|
|
|
|
This module implements the book version [1] of the ARFF format. This means there
|
|
is no support for instance weights.
|
|
|
|
Known limitations:
|
|
- This implementation does not parse dates
|
|
|
|
[1]: http://weka.wikispaces.com/ARFF+%28book+version%29
|
|
"""
|
|
|
|
import re
|
|
from collections import namedtuple
|
|
|
|
# One column of a data set: the attribute name paired with its type object
# (the loader stores Numeric, Text or Nominal instances here).
Field = namedtuple('Field',['name','type'])
|
|
|
|
# Names exported by `from <module> import *`.
__all__ = ['load', 'save', 'Field', 'Numeric', 'Text', 'Nominal']
|
|
|
|
|
|
#
|
|
# Line type functions
|
|
#
|
|
|
|
def is_empty(line):
    """Tell whether the line contains nothing but whitespace."""
    stripped = line.strip()
    return len(stripped) == 0
|
|
|
|
def is_comment(line):
    """Tell whether the line is a comment, i.e. its first character is '%'."""
    return line[:1] == '%'
|
|
|
|
def format_comment(line):
    """Turn the text into a comment line by prefixing it with '% '."""
    return ''.join(('% ', line))
|
|
|
|
def is_relation(line):
    """Tell whether the line opens with '@relation' (case-insensitive)."""
    lowered = line.lower()
    return lowered.startswith('@relation')
|
|
|
|
def format_relation(name):
    """Build the '@relation' header line, newline included, for the name."""
    return ''.join(['@relation ', format_identifier(name), '\n'])
|
|
|
|
def is_attribute(line):
    """Tell whether the line opens with '@attribute' (case-insensitive)."""
    lowered = line.lower()
    return lowered.startswith('@attribute')
|
|
|
|
def format_attribute(field):
    """Build the '@attribute' header line, newline included, for one field.

    The type part is whatever `str(field.type)` yields.
    """
    pieces = (
        '@attribute ',
        format_identifier(field.name),
        ' ',
        str(field.type),
        '\n',
    )
    return ''.join(pieces)
|
|
|
|
def format_attributes(fields):
    """Build the '@attribute' lines for all fields as a single string."""
    return ''.join(format_attribute(field) for field in fields)
|
|
|
|
def is_data(line):
    """Tell whether the line opens with '@data' (case-insensitive)."""
    lowered = line.lower()
    return lowered.startswith('@data')
|
|
|
|
def format_data():
    """Return the '@data' line that opens the data section."""
    marker = '@data\n'
    return marker
|
|
|
|
def format_row(row, fields, sparse=False):
    """Render one data row as a line of text, newline included.

    row -- mapping from field name to value; absent names are read as None
    fields -- the fields, in column order
    sparse -- when true, emit '{index value,...}' and leave out every column
              whose value equals the default of its type
    """
    if not sparse:
        cells = [field.type.format(row.get(field.name)) for field in fields]
        return ','.join(cells) + '\n'

    entries = []
    for index, field in enumerate(fields):
        value = row.get(field.name)
        # Columns holding their type's default are implied in sparse rows.
        if value != field.type.default():
            entries.append(format_numeric(index) + ' ' + field.type.format(value))
    return '{' + ','.join(entries) + '}\n'
|
|
|
|
|
|
def safe_next(it):
    """Fetch the next item of the iterator, or '' once it is exhausted."""
    return next(it, '')
|
|
|
|
|
|
def whitespace(rest):
    """Consume leading whitespace and return the remaining input."""
    remainder = rest.lstrip()
    return remainder
|
|
|
|
|
|
# Optional sign, then digits with an optional fraction (or a bare fraction),
# then an optional exponent.
number_pattern = re.compile(r'[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?')


def numeric(rest):
    """Read a number from the front of the input.

    Returns a `(number, remaining input)` pair. The number is an int when the
    matched text is a valid integer literal and a float otherwise.

    Raises ValueError when the input does not start with a number.
    """
    found = number_pattern.match(rest)
    if found is None:
        raise ValueError('Number not parsable')

    literal = found.group(0)
    remainder = rest[found.end():]
    try:
        value = int(literal)
    except ValueError:
        # Fractions and exponents are not integer literals.
        value = float(literal)
    return value, remainder
|
|
|
|
def format_numeric(number):
    """Render a number as the text written to the file."""
    text = str(number)
    return text
|
|
|
|
def expect(rest, string):
    """Check for the given string at the front of the input.

    Returns `(True, input after the string)` when it is there, and
    `(False, input unchanged)` when it is not.
    """
    if not rest.startswith(string):
        return False, rest
    return True, rest[len(string):]
|
|
|
|
|
|
# Escape codes understood inside quoted identifiers: the character that
# follows a backslash, mapped to the character it stands for.
identifier_escapes = {
    '\\': '\\',
    'n' : '\n',
    't' : '\t',
    'r' : '\r',
    '%' : '%',
    "'" : "'"
}

# Inverse table (literal character -> escape code), built once instead of on
# every call to format_identifier.
_reverse_identifier_escapes = {
    char: code for code, char in identifier_escapes.items()
}


def identifier(rest):
    """Parses an optionally quoted identifier at the start of the input.

    Returns a `(name, remaining input)` pair.

    An unquoted identifier runs up to, but not including, the first space,
    tab or comma. A quoted identifier is delimited by single quotes and may
    contain the backslash escapes listed in `identifier_escapes`; an unknown
    escape is kept verbatim, backslash included. If the closing quote is
    missing, the rest of the input is taken as the identifier.

    Raises ValueError when the input ends right after a backslash inside a
    quoted identifier.
    """
    # non-quoted
    if not rest.startswith("'"):
        for end, char in enumerate(rest):
            if char in ' \t,':
                break
        else:
            end = len(rest)
        return rest[:end], rest[end:]

    # quoted
    chars = []
    pos = 1  # skip the opening quote
    size = len(rest)
    while pos < size:
        c = rest[pos]
        pos += 1
        if c == "'":
            break
        if c == '\\':
            if pos >= size:
                raise ValueError('Input end during escape.')
            ec = rest[pos]
            pos += 1
            chars.append(identifier_escapes.get(ec, '\\' + ec))
        else:
            chars.append(c)
    return ''.join(chars), rest[pos:]


def format_identifier(name):
    """Formats an identifier so that `identifier` reads it back unchanged.

    The name is written bare when that is unambiguous. It is wrapped in
    single quotes, with escapable characters backslash-escaped, when it

    - is empty (a bare empty token would be invisible in the output),
    - starts with '?' (the type parsers read a leading '?' as a missing
      value),
    - contains '{' or '}' (braces delimit nominal specifications and sparse
      rows),
    - contains a space, a comma, or any character that has an escape code.
    """
    needs_quotes = (
        not name
        or name.startswith('?')
        or any(
            char in ' ,{}' or char in _reverse_identifier_escapes
            for char in name
        )
    )
    if not needs_quotes:
        return name

    escaped = ''.join(
        '\\' + _reverse_identifier_escapes[char]
        if char in _reverse_identifier_escapes else char
        for char in name
    )
    return "'" + escaped + "'"
|
|
|
|
class Numeric:
    """Field type for numeric attributes; '?' stands for a missing value."""

    def parse(self, rest):
        """Read a number or '?' from the front of the input.

        Returns `(value, remaining input)`; the value is None for '?'.
        """
        if rest[:1] == '?':
            return None, rest[1:]
        return numeric(rest)

    def format(self, number):
        """Render the number, writing None as '?'."""
        return '?' if number is None else format_numeric(number)

    def default(self):
        """Value assumed for columns left out of a sparse row."""
        return 0

    def __repr__(self):
        return 'Numeric'

    def __str__(self):
        # The type name as written in an '@attribute' line.
        return 'numeric'
|
|
|
|
|
|
class Text:
    """Field type for string attributes; '?' stands for a missing value."""

    def parse(self, rest):
        """Read an identifier or '?' from the front of the input.

        Returns `(value, remaining input)`; the value is None for '?'.
        """
        if rest[:1] == '?':
            return None, rest[1:]
        return identifier(rest)

    def format(self, name):
        """Render the string, writing None as '?'."""
        return '?' if name is None else format_identifier(name)

    def default(self):
        """Value assumed for columns left out of a sparse row."""
        return ''

    def __repr__(self):
        return 'Text'

    def __str__(self):
        # The type name as written in an '@attribute' line.
        return 'string'
|
|
|
|
|
|
class Nominal:
    """Field type for attributes restricted to a fixed list of names."""

    def __init__(self, names):
        # The allowed constants, in declaration order.
        self.values = names

    def parse(self, rest):
        """Read one of the allowed names, or '?', from the front of the input.

        Returns `(value, remaining input)`; the value is None for '?'.
        Raises ValueError when the name read is not an allowed constant.
        """
        if rest[:1] == '?':
            return None, rest[1:]

        name, remainder = identifier(rest)
        if name not in self.values:
            raise ValueError('Unknown nominal constant "{}" for {}.'.format(name, self.values))
        return name, remainder

    def format(self, name):
        """Render the name, writing None as '?'.

        Raises ValueError when the name is not an allowed constant.
        """
        if name is None:
            return '?'
        if name not in self.values:
            raise ValueError('Unknown nominal constant "{}" for {}.'.format(name, self.values))
        return format_identifier(name)

    def default(self):
        """Value assumed for columns left out of a sparse row: the first name."""
        return self.values[0]

    def __repr__(self):
        return 'Nominal in {}'.format(self.values)

    def __str__(self):
        # The '{a, b, ...}' specification as written in an '@attribute' line.
        listing = ', '.join(format_identifier(name) for name in self.values)
        return '{' + listing + '}'
|
|
|
|
|
|
def attr_type(rest):
    """Turn the type part of an attribute line into a field type object.

    The whole input is treated as the type. 'numeric', 'integer' and 'real'
    give a Numeric, 'string' gives a Text (both case-insensitive), and a
    '{...}' list gives a Nominal holding the listed names.

    Raises NotImplementedError for date types and ValueError for anything
    else.
    """
    lowered = rest.lower()
    if lowered in ('numeric', 'integer', 'real'):
        return Numeric()
    if lowered == 'string':
        return Text()
    if lowered.startswith('date'):
        raise NotImplementedError('date parsing is not implemented.')
    if not (rest.startswith('{') and rest.endswith('}')):
        raise ValueError('Unknown attribute type "{}"'.format(rest))

    # Nominal: comma-separated identifiers between the braces.
    names = []
    body = rest[1:-1]
    while body:
        body = whitespace(body)
        name, body = identifier(body)
        names.append(name)
        body = whitespace(body)
        seen, body = expect(body, ',')
        if not seen:
            break
    return Nominal(names)
|
|
|
|
|
|
def parse_attribute(line):
    """Split an '@attribute' line into its name and its field type object.

    Returns a `(name, type)` pair.
    """
    # Layout: @attribute WS name WS type
    remainder = whitespace(line[len('@attribute'):].strip())
    name, remainder = identifier(remainder)
    field_type = attr_type(whitespace(remainder))
    return name, field_type
|
|
|
|
|
|
def parse_row(line, fields):
    """Read one data line, in dense or sparse form, into a dictionary.

    Returns a dict mapping every field name to its parsed value. A line that
    starts with '{' is read as a sparse row of 'index value' pairs; fields it
    does not mention receive the default of their type.
    """
    line = line.strip()
    record = {}

    if line.startswith('{'):
        # Sparse row: track which columns have not been given a value yet.
        pending = set(range(len(fields)))
        remainder = line[1:-1].strip()
        at_start = True
        while remainder:
            if not at_start:
                remainder = whitespace(remainder)
                seen, remainder = expect(remainder, ',')
                if not seen:
                    break
            at_start = False
            remainder = whitespace(remainder)
            index, remainder = numeric(remainder)
            field = fields[index]
            remainder = whitespace(remainder)
            value, remainder = field.type.parse(remainder)
            pending.remove(index)
            record[field.name] = value
        for index in pending:
            field = fields[index]
            record[field.name] = field.type.default()
        return record

    # Dense row: one value per field, in field order.
    remainder = line
    for position, field in enumerate(fields):
        if position:
            remainder = whitespace(remainder)
            seen, remainder = expect(remainder, ',')
        remainder = whitespace(remainder)
        value, remainder = field.type.parse(remainder)
        record[field.name] = value
    return record
|
|
|
|
|
|
def load(fileish):
    """
    Reads a data set from an arff formatted file-like object.

    This is a generator. It first consumes the header to learn the fields,
    then yields one dictionary per data line, mapping field names to values.
    Empty lines and comment lines are skipped throughout.

    fileish -- a file-like object, iterated line by line
    """
    lines = iter(fileish)
    fields = []

    # Header: collect the attributes until the data marker shows up.
    for line in lines:
        # The relation name is not needed, so that line is skipped as well.
        if is_empty(line) or is_comment(line) or is_relation(line):
            continue
        if is_data(line):
            break
        if is_attribute(line):
            name, field_type = parse_attribute(line)
            fields.append(Field(name, field_type))

    # Data: the same iterator continues right after the data marker.
    for line in lines:
        if not (is_empty(line) or is_comment(line)):
            yield parse_row(line, fields)
|
|
|
|
def save(fileish, fields, rows, name='unnamed relation', sparse=False):
    """
    Writes an arff formatted data set to a file-like object.

    The output is the relation line, a blank line, the attribute lines, a
    blank line, the data marker, and then one line per row.

    fileish -- a file-like object to write to
    fields -- a list of `Field` instances
    rows -- an iterable containing one dictionary per data row
    name -- the relation name, defaults to 'unnamed relation'
    sparse -- whether the output should be in sparse format, defaults to False
    """
    emit = fileish.write

    emit(format_relation(name))
    emit('\n')
    emit(format_attributes(fields))
    emit('\n')
    emit(format_data())
    for record in rows:
        emit(format_row(record, fields, sparse))
|