import numpy as np
from tempfile import NamedTemporaryFile as Tmp


def load_csv_file(fname, *args, **kargs):
    """Load a comma-separated text file into a NumPy array via ``genfromtxt``.

    Text following ``#`` on a line is treated as a comment and ignored.

    Args:
        fname: Path of the file to read.
        *args, **kargs: Forwarded unchanged to :func:`numpy.genfromtxt`.
            Extra positional arguments bind to ``genfromtxt``'s parameters
            after ``fname`` (``dtype``, ``comments``, ...), so passing options
            by keyword is the safer choice.

    Returns:
        The array produced by :func:`numpy.genfromtxt`.
    """
    # ``with`` closes the handle even if parsing raises.
    with open(fname) as f:
        return np.genfromtxt(f, *args, delimiter=",", comments="#", **kargs)


def load_csv_file_fast(fname):
    """Load a comma-separated numeric file with :func:`numpy.loadtxt`.

    A lighter-weight alternative to :func:`load_csv_file`, which uses
    ``genfromtxt`` and therefore accepts its extra options.
    """
    return np.loadtxt(fname, delimiter=",")


def load_binary_file(fname, dtype='float32', modify=False):
    """Memory-map a raw binary file as a 1-D NumPy array.

    Args:
        fname: Path of the binary file.
        dtype: Element type used to interpret the bytes.
        modify: If true, map with mode ``'r+'`` so assignments are written
            back to the file; otherwise use ``'c'`` (copy-on-write), where
            assignments change only the in-memory copy.

    Returns:
        A :class:`numpy.memmap` over the whole file.
    """
    mode = 'r+' if modify else 'c'
    return np.memmap(fname, dtype=dtype, mode=mode)


def _write_csv_rows(f, rows, header, fmt, break_col):
    """Write the optional header and the rows of a CSV table to text file ``f``."""
    if header:
        f.write('#')
        f.write(", ".join([fmt % str(x) for x in header]))
        f.write('\n')
    last_row = None
    for row in rows:
        # Insert an extra blank line whenever the value of break_col changes;
        # Gnuplot's 3-D (splot) data format uses blank lines to separate scans.
        if last_row is not None and break_col is not None:
            if last_row[break_col] != row[break_col]:
                f.write('\n')
        # The leading space keeps data columns aligned with the '#' header.
        f.write(' ')
        f.write(", ".join([fmt % str(x) for x in row]))
        f.write('\n')
        last_row = row


def write_csv_file(fname, rows, header=None, width=None, break_col=None):
    """Write ``rows`` as comma-separated text.

    Args:
        fname: Output path, or ``None`` to write to a named temporary file.
        rows: Iterable of rows; each row is an iterable of values that are
            converted with ``str``.
        header: Optional sequence of column names, written on a line
            prefixed with ``#``.
        width: If truthy, every field is right-justified to this width.
        break_col: If not ``None``, a blank line is inserted whenever
            ``row[break_col]`` differs from the previous row's value.

    Returns:
        ``None`` when ``fname`` is given (the file is closed). When ``fname``
        is ``None``, returns the still-open, flushed temporary file object.
        Its path is available as ``.name``. The temporary file uses the
        ``NamedTemporaryFile`` defaults, so it is deleted when closed. The
        caller must keep the object alive while the file is in use.
    """
    fmt = "%%%ds" % width if width else "%s"
    if fname is None:
        # Text mode is required: the default 'w+b' rejects the str writes.
        f = Tmp(mode='w+')
        try:
            _write_csv_rows(f, rows, header, fmt, break_col)
            f.flush()  # make the content visible to readers of f.name
        except BaseException:
            f.close()
            raise
        return f
    with open(fname, 'w') as f:
        _write_csv_rows(f, rows, header, fmt, break_col)


def select(keep, rows):
    """Return the rows for which ``keep(row)`` is truthy.

    Args:
        keep: Predicate called once per row.
        rows: Sequence supporting boolean-mask indexing. This presumably
            means a NumPy array; a plain list does not support ``rows[mask]``.

    Returns:
        ``rows[mask]``, where ``mask[i]`` is ``keep(rows[i])`` cast to bool.
    """
    mask = np.zeros(len(rows), dtype=bool)
    for i, row in enumerate(rows):
        mask[i] = keep(row)
    return rows[mask]