Back to main

General-purpose data compression with PPM

This is an implementation of PPM, prediction by partial matching, a modern general-purpose data compression algorithm. Algorithms based on PPM use context information to predict unread symbols. This implementation uses a fifth-order model with an implicit probability table. Probability tables in this variant initially contain an escape symbol with a pseudo-count of one, which is incremented whenever an escape sequence is encountered. Escape sequences are used to introduce new symbols into the probability table. This implementation uses range coding for writing symbols, and a zero-order symbol dictionary when introducing previously unseen symbols into context models. The zero-order symbol probability table initially gives every byte a probability of 1/256. All context-based probability tables start with a probability of zero for all symbols, with the exception of the escape symbol. The set of all possible bytes is a strict subset of the symbols defined. The probability table for each context is adjusted as symbols are encountered in each respective context, and the global probability table is adjusted whenever a symbol is preceded by the escape symbol. This implementation uses an LRU cache to store only the most important contexts, discarding the rest, as storing every context is infeasible due to the memory requirements of a fifth-order model. Empirical evidence on natural language text compression shows that an LRU cache ten times larger than the one used here increases compression ratios by less than 5%. The model's order can also be changed; I have found that a fifth-order model is capable of recognizing many deterministic contexts without running the risk of becoming too specific and forcing excessively long escape sequences. Whenever the highest order context matched has no symbols, the next highest order context is automatically selected. 
When dropping down to a lower-order context, all symbols present in the higher-order context are excluded from consideration, reducing the range of possible symbols and increasing the probabilities of the remaining symbols. Consequently, the probability of an escape symbol is also increased. Whenever a probability is updated in a context with a lower order, the higher-order contexts are also updated. The counts handled by the range coder are scaled down whenever the counts of any model would push any individual probability value below the smallest positive value representable in the coder's fixed-point arithmetic. The only difference between range coding and arithmetic coding is that the sequence in range coding doesn't begin with an implicit decimal point. Consequently, a fixed range of [0, PPM_HIGH) is used, rather than the [0, 1) range implemented with arithmetic coding. Although shorter in length than most non-trivial compressors, correctly implementing any variant of PPM with a range coder is a fine art and a tricky exercise.

Source

#include <stdlib.h>
#include <stdio.h>
#include <assert.h>

// Short aliases for the unsigned integer types used throughout this file.
// NOTE(review): the names i16/i32 suggest 16- and 32-bit widths, but the
// underlying types are plain `unsigned short` / `unsigned int`, whose sizes are
// platform-dependent - confirm on the target before relying on exact widths.
typedef unsigned char byte;
typedef unsigned short i16;
typedef unsigned int i32;

// Once state.num_contexts reaches this value, ppm_context_create recycles the
// least recently used context instead of allocating a new one.
const int PPM_MAX_CONTEXTS = 50000;
// Longest run of preceding bytes that ppm_context_get looks up as a context.
const int PPM_MAX_ORDER = 5;

// Symbol indices. Values 0..255 are literal bytes; the two indices after them
// are the escape and end-of-stream symbols.
const i16 PPM_ESCAPE = 256;
const i16 PPM_EOF = 257;
// Number of entries in PPM_Context::count (bytes + PPM_ESCAPE + PPM_EOF).
const i16 PPM_TOTALS = 258;
// Not referenced by the code visible in this file.
const i16 PPM_NUM_SYMBOLS = PPM_TOTALS + 1;

// Initial bounds of the coder interval (see ppm_init_state).
const int PPM_HIGH = 0xFFFF;
const int PPM_LOW = 0x0000;
// Thresholds checked by ppm_context_rescale: a single count reaching
// PPM_MAX_COUNT, or a context total reaching PPM_MAX_TOTAL, triggers scaling.
const int PPM_MAX_COUNT = 1 << 15;
const int PPM_MAX_TOTAL = PPM_HIGH >> 2;
// Bit positions / masks of the top bit (bit 15) and the bit below it (bit 14)
// of the coder's low/high/code registers.
const int PPM_MSB_SHIFT = 15;
const int PPM_MSB2_SHIFT	= 14;
const int PPM_MSB_MASK = 1 << PPM_MSB_SHIFT;
const int PPM_MSB2_MASK = 1 << PPM_MSB2_SHIFT;

// Each context keeps its children in PPM_CHILD_COUNT buckets; a child's bucket
// is `symbol & PPM_CHILD_MASK`.
const int PPM_CHILD_COUNT = 1 << 5;
const int PPM_CHILD_MASK = PPM_CHILD_COUNT - 1;

// Thin wrapper around the standard assert(); like assert(), it expands to
// nothing when NDEBUG is defined, so its argument must not be relied on for
// side effects.
#define ASSERT(x) assert(x)

// Fallback min/max macros, defined only if the platform has not already
// provided them. Both evaluate their arguments more than once.
#ifndef min
#define min(x, y) (((x) < (y)) ? (x) : (y))
#endif

#ifndef max
#define max(x, y) (((x) > (y)) ? (x) : (y))
#endif

// One node of the context tree: a frequency table plus the links that tie the
// node into the tree, its sibling bucket and the LRU list.
struct PPM_Context
{
	// Symbol that leads from the parent to this node (matched by
	// ppm_context_get_sub).
	i16 symbol;
	// Order field; 0 for contexts built by ppm_context_one_create, and compared
	// against PPM_MAX_ORDER before a child is created.
	byte order;

	// Per-symbol frequency, indexed by symbol value (bytes, PPM_ESCAPE, PPM_EOF).
	i16 count[PPM_TOTALS];
	// Incremented whenever a symbol's count goes from zero to non-zero.
	i16 num_byte_types;
	// Sum of all entries in count[].
	i32 total;

	PPM_Context* parent;
	// Heads of the per-bucket child lists; bucket index is symbol & PPM_CHILD_MASK.
	PPM_Context* first_child[PPM_CHILD_COUNT];

	// Doubly linked LRU list threaded through PPM_State::lru_head / lru_tail.
	PPM_Context* lru_next;
	PPM_Context* lru_prev;

	// Doubly linked list of the children sharing one bucket of the parent.
	PPM_Context* sibling_next;
	PPM_Context* sibling_prev;
};

// Mutable state shared by the bit I/O helpers, the range coder and the
// context cache.
struct PPM_State
{
	// Partially written / partially consumed byte and the number of its bits
	// already used (0..7), as maintained by bit_write / bit_read.
	byte cur_byte;
	int cur_bits;

	// Decoder only: window of input bits compared against [low, high].
	i16 code;
	// Encoder only: underflow bits counted but not yet emitted.
	i32 underflow_bits;
	// Current coder interval bounds.
	i16 low;
	i16 high;

	// Most and least recently used contexts, and the number of contexts
	// created through ppm_context_create that are still alive.
	PPM_Context* lru_head;
	PPM_Context* lru_tail;
	int num_contexts;
};

// Reads the whole file at `path` into a newly malloc'ed buffer.
//
// The buffer holds the file contents followed by one extra '\0' byte, and is
// owned by the caller, who must release it with free().
//
// path: file to read; must not be null.
// len:  optional; receives the number of bytes read (0 on failure).
//
// Returns the buffer, or 0 if the file cannot be opened, its size cannot be
// determined or does not fit in an int, memory cannot be allocated, or the
// contents cannot be read completely.
byte* file_read(const char* path, int* len)
{
	ASSERT(path);
	if (len) { *len = 0; }

	FILE* f = fopen(path, "rb");
	if (!f) { return 0; }

	// ftell reports failure as -1L, so keep the result signed until checked.
	long size = -1;
	if (fseek(f, 0, SEEK_END) == 0) { size = ftell(f); }

	// The length is handed back through an int, so reject sizes that would not
	// survive the conversion.
	const int size_as_int = (int)size;
	const bool size_ok = size >= 0
		&& (long)size_as_int == size
		&& fseek(f, 0, SEEK_SET) == 0;

	byte* buf = size_ok ? (byte*)malloc((size_t)size + 1) : 0;

	// A short read would leave the tail of the buffer uninitialised; treat it
	// as a failure rather than returning partial data.
	if (buf && fread(buf, 1, (size_t)size, f) != (size_t)size)
	{
		free(buf);
		buf = 0;
	}

	fclose(f);

	if (!buf) { return 0; }

	buf[size] = '\0';
	if (len) { *len = size_as_int; }
	return buf;
}

// Appends the low `num_bits` bits of `val` to the output, least significant
// bit first. Bits accumulate in `cur_byte` (with `cur_bits` of it already
// used); every time the byte fills up it is written to `file` and reset.
void bit_write(FILE* file, byte& cur_byte, int& cur_bits, int val, int num_bits)
{
	while (num_bits != 0)
	{
		// Take as many bits as still fit in the current byte.
		const int room = 8 - cur_bits;
		const int take = (room > num_bits) ? num_bits : room;
		const int filled = cur_bits + take;

		cur_byte |= (val & ((1 << take) - 1)) << cur_bits;
		val >>= take;
		num_bits -= take;

		ASSERT(filled <= 8);
		if (filled == 8)
		{
			fwrite(&cur_byte, 1, 1, file);
			cur_byte = 0;
		}

		cur_bits = filled & 7;
	}
}

// Writes out the partially filled byte, if any, and resets the bit buffer.
// Does nothing when no bits are pending.
void bit_write_flush(FILE* file, byte& cur_byte, int& cur_bits)
{
	if (cur_bits <= 0) { return; }

	fwrite(&cur_byte, 1, 1, file);
	cur_bits = 0;
	cur_byte = 0;
}

// Reads `num_bits` bits from `file`, least significant bit first, and returns
// them packed into an int.
//
// `cur_byte` / `cur_bits` hold the byte currently being consumed and how many
// of its bits have already been handed out. Once the input is exhausted (or a
// read fails) the stream is treated as an endless run of zero bits.
int bit_read(FILE* file, byte& cur_byte, int& cur_bits, int num_bits)
{
	int value = 0;
	int value_bits = 0;

	while (value_bits < num_bits)
	{
		if (cur_bits == 0)
		{
			// feof() only turns true after a read has already failed, so test
			// the read itself; otherwise the first read past the end would
			// silently reuse the previous byte.
			if (fread(&cur_byte, 1, 1, file) != 1) { cur_byte = 0; }
		}

		int get = 8 - cur_bits;
		if (get > (num_bits - value_bits)) { get = num_bits - value_bits; }
		int fraction = cur_byte;
		fraction >>= cur_bits;
		fraction &= (1 << get) - 1;
		value |= fraction << value_bits;
		value_bits += get;
		cur_bits = (cur_bits + get) & 7;
	}

	return value;
}

// Marks `c` as the most recently used context: detaches it from wherever it
// currently sits in the LRU list and re-inserts it at the head.
void ppm_context_make_lru_head(PPM_Context* c, PPM_State& state)
{
	PPM_Context* const after = c->lru_next;
	PPM_Context* const before = c->lru_prev;

	// Detach: fix the list ends first, then bridge the neighbours.
	if (state.lru_head == c) { state.lru_head = after; }
	if (state.lru_tail == c) { state.lru_tail = before; }
	if (after) { after->lru_prev = before; }
	if (before) { before->lru_next = after; }

	// Re-insert in front of whatever is the head now.
	PPM_Context* const old_head = state.lru_head;
	c->lru_prev = 0;
	c->lru_next = old_head;
	if (old_head)
	{
		ASSERT(!old_head->lru_prev);
		old_head->lru_prev = c;
	}
	state.lru_head = c;

	// An empty list gets `c` as its tail as well.
	if (!state.lru_tail) { state.lru_tail = c; }
}

// Removes context `c` and its entire subtree from the model.
//
// `c` is unlinked from its parent's bucket, from its sibling list and from the
// LRU list; every descendant is then freed recursively. state.num_contexts is
// decremented once for `c` and once per descendant.
//
// reuse_memory: when true the node itself is NOT deleted, so the caller can
// recycle its storage (descendants are always deleted).
void ppm_context_free_single(PPM_Context* c, PPM_State& state, bool reuse_memory)
{
	// If c heads its bucket in the parent, the bucket must start at c's successor.
	if (c->parent && c == c->parent->first_child[c->symbol & PPM_CHILD_MASK])
	{
		c->parent->first_child[c->symbol & PPM_CHILD_MASK] = c->sibling_next;
	}

	if (c == state.lru_head) { state.lru_head = c->lru_next; }
	if (c == state.lru_tail) { state.lru_tail = c->lru_prev; }

	if (c->sibling_next) { c->sibling_next->sibling_prev = c->sibling_prev; }
	if (c->sibling_prev) { c->sibling_prev->sibling_next = c->sibling_next; }

	if (c->lru_next) { c->lru_next->lru_prev = c->lru_prev; }
	if (c->lru_prev) { c->lru_prev->lru_next = c->lru_next; }

	for (int i = 0; i < PPM_CHILD_COUNT; i++)
	{
		// Each recursive call unlinks the bucket head from c->first_child[i]
		// (via the child's parent pointer above), so this loop makes progress.
		while (c->first_child[i])
		{
			ppm_context_free_single(c->first_child[i], state, false);
		}
	}

	state.num_contexts--;
	if (!reuse_memory) { delete c; }
}

// Destroys a root context (one without siblings) together with every context
// below it. Unlike ppm_context_free_single, the root itself is not unlinked
// from any list and does not affect state.num_contexts.
void ppm_context_free_root(PPM_Context* c, PPM_State& state)
{
	ASSERT(!c->sibling_next);
	ASSERT(!c->sibling_prev);

	for (int bucket = 0; bucket < PPM_CHILD_COUNT; ++bucket)
	{
		// Freeing a child pops it off the bucket, exposing the next one.
		PPM_Context* child;
		while ((child = c->first_child[bucket]) != 0)
		{
			ppm_context_free_single(child, state, false);
		}
	}

	delete c;
}

// Creates a new, empty context for `symbol` as a child of `parent`.
//
// When the cache already holds PPM_MAX_CONTEXTS contexts, the least recently
// used one is evicted (together with its subtree) and its storage is reused.
// The new context starts with an escape count of one and every other count at
// zero, becomes the LRU head, and is pushed onto the front of the matching
// bucket of `parent`.
//
// parent: context being extended; must not be null.
// state:  owns the LRU list and the context counter.
// symbol: symbol that leads from `parent` to the new context.
//
// Returns the new context.
//
// NOTE(review): the child's order is now parent->order + 1. Previously it was
// copied from the parent, so every context had order 0 and the
// `order < PPM_MAX_ORDER` limit in ppm_context_update_sym never applied. This
// changes which contexts occupy the cache, so data compressed by an older
// build may not decode once that build's cache had filled up.
PPM_Context* ppm_context_create(PPM_Context* parent, PPM_State& state, i16 symbol)
{
	ASSERT(parent);

	PPM_Context* c = 0;

	if (state.num_contexts == PPM_MAX_CONTEXTS)
	{
		c = state.lru_tail;
		ASSERT(c);

		// Evicting a node frees its whole subtree; that must never include
		// the parent we are about to attach to.
		for (PPM_Context* p = parent; p; p = p->parent) { ASSERT(p != c); }

		ppm_context_free_single(c, state, true);
		ASSERT(state.num_contexts < PPM_MAX_CONTEXTS);
	}
	else
	{
		c = new PPM_Context();
	}

	state.num_contexts++;

	c->parent = parent;
	c->order = (byte)(parent->order + 1);
	c->symbol = symbol;

	for (int i = 0; i < PPM_CHILD_COUNT; i++) { c->first_child[i] = 0; }
	c->lru_next = c->lru_prev = 0;
	c->sibling_next = c->sibling_prev = 0;

	// Fresh table: only the escape symbol has a non-zero count.
	for (int i = 0; i < PPM_TOTALS; i++) { c->count[i] = 0; }
	c->num_byte_types = 0;
	c->count[PPM_ESCAPE] = 1;
	c->count[PPM_EOF] = 0;
	c->total = c->count[PPM_ESCAPE];

	// Insert at the head of the LRU list.
	c->lru_next = state.lru_head;
	if (state.lru_head)
	{
		ASSERT(!state.lru_head->lru_prev);
		state.lru_head->lru_prev = c;
	}
	state.lru_head = c;
	if (!state.lru_tail) { state.lru_tail = c; }

	// Push onto the front of the parent's bucket for this symbol.
	const int parent_idx = c->symbol & PPM_CHILD_MASK;
	PPM_Context* const bucket_head = parent->first_child[parent_idx];
	if (bucket_head)
	{
		ASSERT(!bucket_head->sibling_prev);

		bucket_head->sibling_prev = c;
		c->sibling_next = bucket_head;
	}
	parent->first_child[parent_idx] = c;

	return c;
}

// Allocates a stand-alone order-0 context in which every symbol (all bytes,
// PPM_ESCAPE and PPM_EOF) starts with a count of one. The context has no
// parent, no children and is not part of the LRU or any sibling list.
PPM_Context* ppm_context_one_create()
{
	PPM_Context* root = new PPM_Context();

	// Identity and links.
	root->symbol = 0;
	root->order = 0;
	root->parent = 0;
	root->lru_next = 0;
	root->lru_prev = 0;
	root->sibling_next = 0;
	root->sibling_prev = 0;
	for (int bucket = 0; bucket < PPM_CHILD_COUNT; ++bucket)
	{
		root->first_child[bucket] = 0;
	}

	// Uniform frequency table.
	for (int sym = 0; sym < PPM_TOTALS; ++sym)
	{
		root->count[sym] = 1;
	}
	root->num_byte_types = 256;
	root->total = PPM_TOTALS;

	return root;
}

// Looks up the child of `c` reached by `sym`. Returns the child, or 0 if `c`
// has no child for that symbol.
PPM_Context* ppm_context_get_sub(PPM_Context* c, i16 sym)
{
	// Only the bucket selected by the symbol's low bits can hold the child.
	PPM_Context* node = c->first_child[sym & PPM_CHILD_MASK];

	while (node && node->symbol != sym)
	{
		node = node->sibling_next;
	}

	return node;
}

// Recomputes c->total as the sum of every entry in c->count.
void ppm_context_update_total(PPM_Context* c)
{
	i32 sum = 0;

	for (int sym = 0; sym < PPM_TOTALS; ++sym)
	{
		sum += c->count[sym];
	}

	c->total = sum;
}

// Keeps the counts of `c` inside the limits the coder works with.
//
// The counts are divided by 4 when the total has reached twice PPM_MAX_TOTAL,
// by 2 when the total has reached PPM_MAX_TOTAL or any single count has reached
// PPM_MAX_COUNT, and left alone otherwise. Non-zero counts never drop below
// one; the total is recomputed after scaling.
void ppm_context_rescale(PPM_Context* c)
{
	int divisor = 1;

	if (c->total >= (PPM_MAX_TOTAL * 2))
	{
		divisor = 4;
	}
	else if (c->total >= PPM_MAX_TOTAL)
	{
		divisor = 2;
	}
	else
	{
		// The total is fine; look for a single oversized count instead.
		for (int sym = 0; divisor == 1 && sym < PPM_TOTALS; ++sym)
		{
			if (c->count[sym] >= PPM_MAX_COUNT) { divisor = 2; }
		}
	}

	if (divisor > 1)
	{
		for (int sym = 0; sym < PPM_TOTALS; ++sym)
		{
			const int n = c->count[sym];
			if (n == 0) { continue; }

			// Symbols that have been seen keep a count of at least one.
			c->count[sym] = (n > divisor) ? (i16)(n / divisor) : (i16)1;
		}
		ppm_context_update_total(c);
	}

	ASSERT(c->total < PPM_MAX_TOTAL);
}

// True when `sym` has a non-zero count in context `c`.
bool ppm_has_symbol(i16 sym, PPM_Context* c)
{
	const i16 occurrences = c->count[sym];
	return occurrences != 0;
}

// Overload taking an array in which entry `sym + 1` is compared with entry
// `sym`: true when the later entry is strictly greater.
bool ppm_has_symbol(i16 sym, i32* totals)
{
	const i32 here = totals[sym];
	const i32 next = totals[sym + 1];
	return here < next;
}

// Computes the cumulative-count interval of `sym` in context `c`:
// `low` is the sum of the counts of all symbols before `sym`, and `high` is
// `low` plus the count of `sym` itself.
void ppm_sym_to_range(i16 sym, PPM_Context* c, i32& low, i32& high)
{
	i32 below = 0;

	for (int s = 0; s < sym; ++s)
	{
		below += c->count[s];
	}

	low = below;
	high = below + c->count[sym];
}

// Finds the symbol whose cumulative-count interval in `c` contains `count`.
//
// On success the interval bounds are stored through `low` / `high` (either may
// be null) and the symbol is returned. If no interval contains `count` the
// function asserts and returns -1.
int ppm_count_to_prob(i32* low, i32* high, int count, PPM_Context* c)
{
	int base = 0;

	for (i16 sym = 0; sym < PPM_TOTALS; sym++)
	{
		const int top = base + c->count[sym];

		if (base <= count && count < top)
		{
			if (low) { *low = base; }
			if (high) { *high = top; }
			return sym;
		}

		base = top;
	}

	ASSERT(false);
	return -1;
}

// Maps the decoder's `code` value back to a cumulative count.
//
// With range = high - low + 1, the result is
// ((code - low + 1) * scale - 1) / range, truncated to i16.
// NOTE(review): presumably `low <= code <= high` holds on entry, which would
// put the result in [0, scale) - confirm against the callers.
i16 ppm_unscaled_count(i16 code, i16 low, i16 high, i16 scale)
{
	i32 range = (high - low) + 1;
	return (((i32)(code - low) + 1) * scale - 1) / range;
}

// Builds the initial coder state: empty bit buffer, full [PPM_LOW, PPM_HIGH]
// interval, no pending underflow bits and an empty context cache.
PPM_State ppm_init_state()
{
	PPM_State fresh;

	// Context cache.
	fresh.num_contexts = 0;
	fresh.lru_head = 0;
	fresh.lru_tail = 0;

	// Range coder registers.
	fresh.low = PPM_LOW;
	fresh.high = PPM_HIGH;
	fresh.code = 0;
	fresh.underflow_bits = 0;

	// Bit I/O buffer.
	fresh.cur_bits = 0;
	fresh.cur_byte = 0;

	return fresh;
}

// Encodes one symbol given its cumulative-count interval [low, high) out of
// `scale`: narrows the coder interval accordingly, then renormalizes it,
// writing bits to `outf` as they become determined.
void ppm_write_range(FILE* outf, PPM_State& state, i32 low, i32 high, i16 scale)
{
	// Narrow [state.low, state.high] to the symbol's share of the interval.
	// state.high is updated first; both lines still add to the old state.low.
	i32 range = (i32)(state.high - state.low) + 1;
	state.high = state.low + (i16)((range * high) / scale - 1);
	state.low = state.low + (i16)((range * low) / scale);

	while (true)
	{
		if ((state.high & PPM_MSB_MASK) == (state.low & PPM_MSB_MASK))
		{
			// Top bits agree: emit that bit, followed by one inverted bit for
			// every underflow counted so far.
			bit_write(outf, state.cur_byte, state.cur_bits, state.high >> PPM_MSB_SHIFT, 1);
			for (int i = 0; i < state.underflow_bits; i++)
			{
				bit_write(outf, state.cur_byte, state.cur_bits, ~state.high >> PPM_MSB_SHIFT, 1);
			}
			state.underflow_bits = 0;
		}
		else if ((state.low & PPM_MSB2_MASK) && !(state.high & PPM_MSB2_MASK))
		{
			// Top bits differ, but low has bit 14 set and high has it clear:
			// count an underflow and adjust the top bits of both bounds so the
			// shift below widens the interval.
			state.underflow_bits++;
			state.low &= PPM_MSB2_MASK - 1;
			state.high |= PPM_MSB2_MASK;
		}
		else
		{
			break;
		}

		// Shift both bounds left by one; high is refilled with a 1 bit.
		state.low <<= 1;
		state.high = (state.high << 1) | 1;
	}
}

void ppm_write_symbol(FILE* outf, i16 cur_sym, PPM_State& state, PPM_Context* context)
{
	ASSERT(ppm_has_symbol(cur_sym, context));

	i32 low, high;
	ppm_sym_to_range(cur_sym, context, low, high);
	ppm_write_range(outf, state, low, high, context->total);
}

// Finishes the encoded stream: writes bit 14 of state.low, then
// underflow_bits + 1 copies of its complement, then 16 zero bits, and finally
// flushes the partially filled output byte.
void ppm_flush_encoder(FILE* outf, PPM_State& state)
{
	bit_write(outf, state.cur_byte, state.cur_bits, state.low >> PPM_MSB2_SHIFT, 1);
	for (int i = 0; i < state.underflow_bits + 1; i++)
	{
		bit_write(outf, state.cur_byte, state.cur_bits, ~state.low >> PPM_MSB2_SHIFT, 1);
	}
	// Zero padding after the final bits.
	// NOTE(review): presumably sized to match the 16 bits ppm_init_decoder
	// preloads into state.code - confirm.
	bit_write(outf, state.cur_byte, state.cur_bits, 0, 16);

	bit_write_flush(outf, state.cur_byte, state.cur_bits);
}

// Primes the decoder by shifting the first 16 input bits into state.code,
// earliest bit ending up in the most significant position.
void ppm_init_decoder(FILE* inf, PPM_State& state)
{
	int remaining = 16;

	while (remaining-- > 0)
	{
		const int bit = bit_read(inf, state.cur_byte, state.cur_bits, 1);
		state.code = (state.code << 1) | bit;
	}
}

// Decoder counterpart of ppm_write_range: narrows the coder interval to the
// cumulative-count interval [low, high) out of `scale` of the symbol just
// decoded, then renormalizes, pulling one new input bit into state.code for
// every shift.
void ppm_read_range(FILE* inf, PPM_State& state, i32 low, i32 high, i16 scale)
{
	// Same interval update as in ppm_write_range.
	i32 range = (i32)(state.high - state.low) + 1;
	state.high = state.low + (i16)((range * high) / scale - 1);
	state.low = state.low + (i16)((range * low) / scale);

	while (true)
	{
		if ((state.high & PPM_MSB_MASK) == (state.low & PPM_MSB_MASK))
		{
			// Top bits agree: nothing to adjust, just shift below.
		}
		else if ((state.low & PPM_MSB2_MASK) && !(state.high & PPM_MSB2_MASK))
		{
			// Underflow case: adjust the bounds exactly as ppm_write_range
			// does, and flip bit 14 of the code to match.
			state.low &= PPM_MSB2_MASK - 1;
			state.high |= PPM_MSB2_MASK;
			state.code ^= PPM_MSB2_MASK;
		}
		else
		{
			break;
		}

		state.low <<= 1;
		state.high = (state.high << 1) | 1;
		state.code = (state.code << 1) | bit_read(inf, state.cur_byte, state.cur_bits, 1);
	}
}

// Decodes the next symbol using the frequencies of `context`, consumes it from
// the coder state and returns it.
i16 ppm_remove_symbol(FILE* inf, PPM_State& state, PPM_Context* context)
{
	// Where does the current code fall within the context's total?
	const int target = ppm_unscaled_count(state.code, state.low, state.high, context->total);

	// Which symbol owns that position, and what is its interval?
	i32 cum_low;
	i32 cum_high;
	const i16 decoded = ppm_count_to_prob(&cum_low, &cum_high, target, context);

	ppm_read_range(inf, state, cum_low, cum_high, context->total);

	return decoded;
}

// Records one occurrence of `sym` in context `c`, rescales the context if
// needed and, for byte symbols, makes sure `c` has a child context for `sym`
// as long as c->order is below PPM_MAX_ORDER.
void ppm_context_update_sym(PPM_Context* c, PPM_State& state, i16 sym)
{
	i16& occurrences = c->count[sym];

	// First sighting of this symbol in the context.
	if (occurrences == 0) { ++c->num_byte_types; }

	++occurrences;
	++c->total;
	ppm_context_rescale(c);

	// Only literal bytes extend the context tree.
	if (sym >= 256) { return; }
	if (c->order >= PPM_MAX_ORDER) { return; }
	if (ppm_context_get_sub(c, sym)) { return; }

	ppm_context_create(c, state, sym);
}

// Copies the frequency data of `c` (counts, total, number of symbol types and
// order) into the scratch context `local`. Links are not copied.
void ppm_context_local(PPM_Context* local, PPM_Context* c)
{
	local->order = c->order;
	local->num_byte_types = c->num_byte_types;
	local->total = c->total;

	const i16* from = c->count;
	i16* to = local->count;
	for (int sym = 0; sym < PPM_TOTALS; ++sym)
	{
		to[sym] = from[sym];
	}
}

// Removes from `local` every byte symbol that has a non-zero count in `c`:
// the symbol's count is zeroed, and the total and the number of symbol types
// are reduced accordingly. Non-byte symbols are left alone.
void ppm_context_local_exclude(PPM_Context* local, PPM_Context* c)
{
	for (int sym = 0; sym < 256; ++sym)
	{
		if (c->count[sym] == 0) { continue; }

		const i16 held = local->count[sym];
		if (held == 0) { continue; }

		local->total -= held;
		--local->num_byte_types;
		local->count[sym] = 0;
	}
}

// Picks the context to start coding with: the highest index in `ctx` whose
// total exceeds its combined escape and EOF counts. If no context qualifies
// the highest index (num_ctx - 1) is returned.
int ppm_context_select(PPM_Context* ctx[1 + PPM_MAX_ORDER], int num_ctx)
{
	int idx = num_ctx;

	while (idx-- > 0)
	{
		PPM_Context* candidate = ctx[idx];
		if (candidate->total > candidate->count[PPM_ESCAPE] + candidate->count[PPM_EOF])
		{
			return idx;
		}
	}

	return num_ctx - 1;
}

// Collects the contexts that match the bytes ending at src[pos].
//
// ctx[0] is always `global`. For k = 1..PPM_MAX_ORDER, ctx[k] is the context
// reached from `global` by following the k bytes src[pos-k+1..pos]; the search
// stops at the first k for which not enough bytes are available (pos < k) or
// no such context exists. Remaining entries are set to 0.
//
// Every context found is moved to the head of the LRU list, lowest index
// first. Returns the number of leading non-null entries in `ctx`.
// `src_len` is not used.
int ppm_context_get(PPM_State& state, PPM_Context* ctx[1 + PPM_MAX_ORDER], PPM_Context* global, const byte* src, int src_len, int pos)
{
	ctx[0] = global;
	for (int k = 1; k <= PPM_MAX_ORDER; ++k) { ctx[k] = 0; }

	for (int depth = 1; depth <= PPM_MAX_ORDER && depth <= pos; ++depth)
	{
		// Walk down from the root along the last `depth` bytes.
		const byte* history = src + (pos - depth + 1);
		PPM_Context* node = global;
		for (int step = 0; node && step < depth; ++step)
		{
			node = ppm_context_get_sub(node, history[step]);
		}

		if (!node) { break; }
		ctx[depth] = node;
	}

	int found = 0;
	while (found <= PPM_MAX_ORDER && ctx[found])
	{
		ppm_context_make_lru_head(ctx[found], state);
		++found;
	}

	return found;
}

// Compresses `src_file` into `dest_file`.
//
// Output layout: the source length as 32 bits, then the range-coded symbols,
// then a PPM_EOF symbol coded in the global context, then the encoder flush.
void ppm_encode(const char* src_file, const char* dest_file)
{
	int src_len = 0;
	byte* src = file_read(src_file, &src_len);
	ASSERT(src);

	FILE* outf = fopen(dest_file, "wb");
	ASSERT(outf);

	PPM_State state = ppm_init_state();
	// context_global is the root of the context tree; context_local is a
	// scratch copy used to apply exclusions without touching the real counts.
	PPM_Context* context_global = ppm_context_one_create();
	PPM_Context* context_local = ppm_context_one_create();

	bit_write(outf, state.cur_byte, state.cur_bits, src_len, 32);

	int pos = 0;
	while (pos < src_len)
	{
		i16 sym = src[pos];

		// Contexts matching the bytes that precede the current position.
		PPM_Context* ctx[1 + PPM_MAX_ORDER];
		int num_ctx = ppm_context_get(state, ctx, context_global, src, src_len, pos - 1);
		int selected_ctx_idx = ppm_context_select(ctx, num_ctx);
		int first_selected_ctx_idx = selected_ctx_idx;

		PPM_Context* sel_ctx = ctx[selected_ctx_idx];
		ppm_context_local(context_local, sel_ctx);
		// Escape down one context at a time until one contains the symbol.
		while (!ppm_has_symbol(sym, sel_ctx))
		{
			ppm_write_symbol(outf, PPM_ESCAPE, state, context_local);
			ppm_context_update_sym(sel_ctx, state, PPM_ESCAPE);

			ASSERT(selected_ctx_idx > 0);

			sel_ctx = ctx[selected_ctx_idx - 1];
			ppm_context_local(context_local, sel_ctx);
			// Exclude every byte already ruled out by the contexts escaped from.
			for (int i = selected_ctx_idx; i <= first_selected_ctx_idx; i++)
			{
				ASSERT(!ppm_has_symbol(sym, ctx[i]));
				ppm_context_local_exclude(context_local, ctx[i]);
			}
			selected_ctx_idx--;
		}
		ASSERT(ppm_has_symbol(sym, context_local));
		ppm_write_symbol(outf, sym, state, context_local);

		// Record the symbol in the context that coded it and in all
		// higher-index contexts.
		for (int i = selected_ctx_idx; i < num_ctx; i++)
		{
			ppm_context_update_sym(ctx[i], state, sym);
		}

		pos++;
	}
	ppm_write_symbol(outf, PPM_EOF, state, context_global);
	ppm_flush_encoder(outf, state);

	fclose(outf);

	ppm_context_free_root(context_global, state);
	ppm_context_free_root(context_local, state);

	free(src);
}

void ppm_decode(const char* src_file, const char* dest_file)
{
	FILE* inf = fopen(src_file, "rb");
	ASSERT(inf);

	ASSERT(!feof(inf));

	PPM_State state = ppm_init_state();
	PPM_Context* context_global = ppm_context_one_create();
	PPM_Context* context_local = ppm_context_one_create();

	int src_len = bit_read(inf, state.cur_byte, state.cur_bits, 32);
	byte* res = (byte*)malloc(src_len);

	ppm_init_decoder(inf, state);

	int pos = 0;
	while (pos < src_len)
	{
		PPM_Context* ctx[1 + PPM_MAX_ORDER];
		int num_ctx = ppm_context_get(state, ctx, context_global, res, src_len, pos - 1);
		int selected_ctx_idx = ppm_context_select(ctx, num_ctx);
		int first_selected_ctx_idx = selected_ctx_idx;

		PPM_Context* sel_ctx = ctx[selected_ctx_idx];
		ppm_context_local(context_local, sel_ctx);

		i16 sym = ppm_remove_symbol(inf, state, sel_ctx);
		while (sym == PPM_ESCAPE)
		{
			ppm_context_update_sym(sel_ctx, state, PPM_ESCAPE);

			ASSERT(selected_ctx_idx > 0);
			sel_ctx = ctx[selected_ctx_idx - 1];
			ppm_context_local(context_local, sel_ctx);
			for (int i = selected_ctx_idx; i <= first_selected_ctx_idx; i++)
			{
				ppm_context_local_exclude(context_local, ctx[i]);
			}
			selected_ctx_idx--;

			sym = ppm_remove_symbol(inf, state, context_local);
		}

		for (int i = selected_ctx_idx; i < num_ctx; i++)
		{
			ppm_context_update_sym(ctx[i], state, sym);
		}

		res[pos] = sym;
		pos++;

	}
	ASSERT(ppm_remove_symbol(inf, state, context_global) == PPM_EOF);

	ppm_context_free_root(context_global, state);
	ppm_context_free_root(context_local, state);
	fclose(inf);

	FILE* outf = fopen(dest_file, "wb");
	ASSERT(outf);
	fwrite(res, 1, src_len, outf);
	free(res);
	fclose(outf);
}