Files
twitter-project/eca/arff.py

385 lines
9.8 KiB
Python

"""
ARFF format loading and saving module.
This module implements the book version [1] of the ARFF format. This means there
is no support for instance weights.
Known limitations:
- This implementation does not parse dates
[1]: http://weka.wikispaces.com/ARFF+%28book+version%29
"""
import re
from collections import namedtuple
# One column of a data set: `name` is the attribute name and `type` is the
# field type object (a Numeric, Text or Nominal instance) used to parse and
# format that column's values.
Field = namedtuple('Field',['name','type'])
# Names exported by `from <module> import *`.
__all__ = ['load', 'save', 'Field', 'Numeric', 'Text', 'Nominal']
#
# Line type functions
#
def is_empty(line):
    """Returns True when the line holds nothing but whitespace."""
    stripped = line.strip()
    return len(stripped) == 0
def is_comment(line):
    """Returns True when the line opens with the comment marker '%'."""
    marker = '%'
    return line.startswith(marker)
def format_comment(line):
    """Turns the text into a comment by prefixing it with '% '."""
    prefix = '% '
    return prefix + line
def is_relation(line):
    """Returns True when the line starts with '@relation' (any letter case)."""
    lowered = line.lower()
    return lowered.startswith('@relation')
def format_relation(name):
    """Builds the '@relation' header line (newline included) for the name."""
    pieces = ['@relation ', format_identifier(name), '\n']
    return ''.join(pieces)
def is_attribute(line):
    """Returns True when the line starts with '@attribute' (any letter case)."""
    lowered = line.lower()
    return lowered.startswith('@attribute')
def format_attribute(field):
    """Builds the '@attribute' header line (newline included) for a field.

    The field's name is formatted as an identifier and its type is rendered
    through `str`.
    """
    pieces = [
        '@attribute ',
        format_identifier(field.name),
        ' ',
        str(field.type),
        '\n',
    ]
    return ''.join(pieces)
def format_attributes(fields):
    """Builds the '@attribute' header lines for all fields, in order."""
    return ''.join([format_attribute(field) for field in fields])
def is_data(line):
    """Returns True when the line starts with '@data' (any letter case)."""
    lowered = line.lower()
    return lowered.startswith('@data')
def format_data():
    """Returns the line that opens the data section."""
    marker = '@data'
    return marker + '\n'
def format_row(row, fields, sparse=False):
    """Formats a data row based on the given fields.

    row -- a dictionary-like object mapping field names to values
    fields -- a sequence of objects with `name` and `type` attributes
    sparse -- when true, emit '{index value,...}' and leave out every value
              that equals the default of its field type

    Names missing from the row are fetched with `get` and therefore handed to
    the field type as None. The returned line ends with a newline.
    """
    if not sparse:
        cells = [field.type.format(row.get(field.name)) for field in fields]
        return ','.join(cells) + '\n'

    entries = []
    for index, field in enumerate(fields):
        value = row.get(field.name)
        if value != field.type.default():
            entries.append(format_numeric(index) + ' ' + field.type.format(value))
    return '{' + ','.join(entries) + '}\n'
def safe_next(it):
    """Returns the next item from the iterator, or '' once it is exhausted."""
    return next(it, '')
def whitespace(rest):
    """Skips whitespace at the start of the input and returns what is left."""
    remaining = rest.lstrip()
    return remaining
number_pattern = re.compile(r'[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?')


def numeric(rest):
    """Parses a number at the beginning of the input.

    Returns a (number, remaining input) pair. The number is an int when the
    matched text is a valid integer literal and a float otherwise.

    Raises ValueError when the input does not start with a number.
    """
    match = number_pattern.match(rest)
    if not match:
        raise ValueError('Number not parsable')

    text = match.group(0)
    try:
        value = int(text)
    except ValueError:
        # Not an integer literal (fraction or exponent present).
        value = float(text)
    return value, rest[len(text):]
def format_numeric(number):
    """Renders a number as text using `str`."""
    text = str(number)
    return text
def expect(rest, string):
    """Checks for the string at the start of the input and consumes it.

    Returns (True, input after the string) when the input starts with the
    string, and (False, unchanged input) otherwise.
    """
    if not rest.startswith(string):
        return False, rest
    return True, rest[len(string):]
# Maps the character that follows a backslash inside a quoted identifier to
# the character it stands for.
identifier_escapes = {
    '\\': '\\',
    'n' : '\n',
    't' : '\t',
    'r' : '\r',
    '%' : '%',
    "'" : "'"
}

# Inverse of `identifier_escapes`: maps a character that must be escaped to
# the escape code written after the backslash. Built once instead of on every
# `format_identifier` call.
_identifier_unescapes = {char: code for code, char in identifier_escapes.items()}


def identifier(rest):
    """Parses an optionally quoted identifier at the start of the input.

    Returns a (name, remaining input) pair.

    An unquoted identifier runs up to the first space, tab or comma; that
    delimiter is left in the remaining input. A single-quoted identifier runs
    up to the closing quote, which is consumed. Inside quotes, the backslash
    escapes listed in `identifier_escapes` are decoded and unknown escapes are
    kept verbatim (backslash included). A missing closing quote is tolerated:
    the identifier then extends to the end of the input.

    Raises ValueError when the input ends right after a backslash.
    """
    chars = iter(rest)
    current = next(chars, '')
    pieces = []

    # non-quoted
    if current != "'":
        while current and current not in (' ', '\t', ','):
            pieces.append(current)
            current = next(chars, '')
        # `current` is the delimiter (or '' at end of input); hand it back.
        return ''.join(pieces), current + ''.join(chars)

    # quoted: the opening quote has been consumed already
    for current in chars:
        if current == "'":
            break
        if current == '\\':
            code = next(chars, '')
            if not code:
                raise ValueError('Input end during escape.')
            pieces.append(identifier_escapes.get(code, '\\' + code))
        else:
            pieces.append(current)
    return ''.join(pieces), ''.join(chars)


def format_identifier(name):
    """Formats an identifier, quoting and escaping it where needed.

    The name is wrapped in single quotes when it
    - is empty (an empty token would make a row with a single empty string a
      blank line, which `load` skips),
    - starts with '?' (the field types read a leading '?' as the
      missing-value marker), or
    - contains a space, a comma or any character that has an escape code.

    Inside the quotes every character with an escape code is written as a
    backslash followed by that code. Any other name is returned unchanged.
    """
    needs_quotes = (
        not name
        or name.startswith('?')
        or any(c in (' ', ',') or c in _identifier_unescapes for c in name)
    )
    if not needs_quotes:
        return name

    escaped = ''.join(
        '\\' + _identifier_unescapes[c] if c in _identifier_unescapes else c
        for c in name
    )
    return "'" + escaped + "'"
class Numeric:
    """Numeric field type."""

    def parse(self, rest):
        """Parses a number, or the missing-value marker '?', from the input.

        Returns a (value, remaining input) pair; the value is None for '?'.
        """
        if not rest.startswith('?'):
            return numeric(rest)
        return None, rest[1:]

    def format(self, number):
        """Formats a number; None is written as the missing-value marker."""
        return '?' if number is None else format_numeric(number)

    def default(self):
        """Returns the default value of this type, 0."""
        return 0

    def __repr__(self):
        return 'Numeric'

    def __str__(self):
        # The type name as written in an '@attribute' line.
        return 'numeric'
class Text:
    """Text field type."""

    def parse(self, rest):
        """Parses a string, or the missing-value marker '?', from the input.

        Returns a (value, remaining input) pair; the value is None for '?'.
        """
        if not rest.startswith('?'):
            return identifier(rest)
        return None, rest[1:]

    def format(self, name):
        """Formats a string; None is written as the missing-value marker."""
        return '?' if name is None else format_identifier(name)

    def default(self):
        """Returns the default value of this type, the empty string."""
        return ''

    def __repr__(self):
        return 'Text'

    def __str__(self):
        # The type name as written in an '@attribute' line.
        return 'string'
class Nominal:
    """Nominal field type."""

    def __init__(self, names):
        # The allowed constants, in declaration order.
        self.values = names

    def parse(self, rest):
        """Parses one of the allowed constants, or '?', from the input.

        Returns a (value, remaining input) pair; the value is None for '?'.
        Raises ValueError when the parsed name is not an allowed constant.
        """
        if rest.startswith('?'):
            return None, rest[1:]
        name, remaining = identifier(rest)
        if name not in self.values:
            raise ValueError('Unknown nominal constant "{}" for {}.'.format(name, self.values))
        return name, remaining

    def format(self, name):
        """Formats a constant; None is written as the missing-value marker.

        Raises ValueError when the name is not an allowed constant.
        """
        if name is None:
            return '?'
        if name not in self.values:
            raise ValueError('Unknown nominal constant "{}" for {}.'.format(name, self.values))
        return format_identifier(name)

    def default(self):
        """Returns the default value of this type, the first constant."""
        return self.values[0]

    def __repr__(self):
        return 'Nominal in {}'.format(self.values)

    def __str__(self):
        # The '{a, b, ...}' specification as written in an '@attribute' line.
        formatted = [format_identifier(name) for name in self.values]
        return '{' + ', '.join(formatted) + '}'
def attr_type(rest):
    """Parses a field type. Uses the whole rest.

    'numeric', 'integer' and 'real' give a Numeric, 'string' gives a Text and
    a '{a, b, ...}' list gives a Nominal; keywords are matched regardless of
    letter case.

    Raises NotImplementedError for date types and ValueError for anything
    else that is not recognised.
    """
    keyword = rest.lower()
    if keyword in ('numeric', 'integer', 'real'):
        return Numeric()
    if keyword == 'string':
        return Text()
    if keyword.startswith('date'):
        raise NotImplementedError('date parsing is not implemented.')
    if not (rest.startswith('{') and rest.endswith('}')):
        raise ValueError('Unknown attribute type "{}"'.format(rest))

    # Nominal specification: comma separated identifiers between the braces.
    names = []
    remaining = rest[1:-1]
    while remaining:
        name, remaining = identifier(whitespace(remaining))
        names.append(name)
        found_comma, remaining = expect(whitespace(remaining), ',')
        if not found_comma:
            # Anything after the last identifier is ignored.
            break
    return Nominal(names)
def parse_attribute(line):
    """Parses an attribute line.

    The line is expected to start with the '@attribute' keyword, which is
    cut off by length. Returns a (name, field type) pair.
    """
    # @attribute WS name WS type
    body = line[len('@attribute'):].strip()
    name, remainder = identifier(whitespace(body))
    field_type = attr_type(whitespace(remainder))
    return name, field_type
def parse_row(line, fields):
    """Parses a row. Row can be normal or sparse.

    line -- the text of one data row
    fields -- the sequence of fields (objects with `name` and `type`)

    Returns a dictionary mapping field names to parsed values. In a sparse
    row ('{index value, ...}') every field that is not listed receives the
    default of its type. In a normal row a missing separating comma is not
    reported; parsing simply continues with the next field.
    """
    line = line.strip()
    parsed = {}

    if line.startswith('{'):
        # Sparse row: the first and last character (the braces) are dropped.
        pending = set(range(len(fields)))
        rest = line[1:-1].strip()
        entries_seen = 0
        while rest:
            if entries_seen:
                found_comma, rest = expect(whitespace(rest), ',')
                if not found_comma:
                    break
            entries_seen += 1
            index, rest = numeric(whitespace(rest))
            field = fields[index]
            value, rest = field.type.parse(whitespace(rest))
            pending.remove(index)
            parsed[field.name] = value
        # Fields that were never mentioned take their type's default.
        for index in pending:
            field = fields[index]
            parsed[field.name] = field.type.default()
        return parsed

    # Normal row: one value per field, separated by commas.
    rest = line
    for position, field in enumerate(fields):
        if position:
            _, rest = expect(whitespace(rest), ',')
        value, rest = field.type.parse(whitespace(rest))
        parsed[field.name] = value
    return parsed
def load(fileish):
    """
    Loads a data set from an arff formatted file-like object.

    This generator function will parse the arff format's header to determine
    data shape. Each generated item is a single expanded row (a dictionary
    mapping field names to values).

    Lines are stripped of surrounding whitespace before they are classified,
    so indented directives ('@attribute', '@data', ...) and indented comments
    are recognised as well. Header lines that are not recognised are skipped.

    fileish -- a file-like object (any iterable of text lines)
    """
    lines = iter(fileish)
    fields = []

    # parse header first
    for raw_line in lines:
        line = raw_line.strip()
        if is_empty(line) or is_comment(line):
            continue
        if is_relation(line):
            # No care is given for the relation name.
            continue
        if is_attribute(line):
            name, field_type = parse_attribute(line)
            fields.append(Field(name, field_type))
            continue
        if is_data(line):
            # We are done with the header, next up is 1 row per line
            break

    # parse data lines; the same iterator continues after the '@data' line
    for raw_line in lines:
        line = raw_line.strip()
        if is_empty(line) or is_comment(line):
            continue
        yield parse_row(line, fields)
def save(fileish, fields, rows, name='unnamed relation', sparse=False):
    """
    Saves an arff formatted data set to a file-like object.

    The header (relation line, blank line, attribute lines, blank line, data
    marker) is written first, followed by one line per row. The rows
    parameter can be any iterable. The fields parameter must be a list of
    `Field` instances.

    fileish -- a file-like object to write to
    fields -- a list of `Field` instances
    rows -- an iterable containing one dictionary per data row
    name -- the relation name, defaults to 'unnamed relation'
    sparse -- whether the output should be in sparse format, defaults to False
    """
    write = fileish.write

    write(format_relation(name))
    write('\n')
    write(format_attributes(fields))
    write('\n')
    write(format_data())

    for row in rows:
        write(format_row(row, fields, sparse))