
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

import java.io.File;
import java.util.List;
import java.util.Map;
import java.util.Hashtable;
import java.util.Vector;


/**
 * Computes wound odds for a series of close-combat attacks.
 *
 * <p>Given the attacker's weapon skill (WS), strength (S) and number of attacks,
 * and the defender's WS, toughness (T) and armor save, this class works out the
 * chance that a single attack causes an unsaved wound. {@link #main} then prints
 * the resulting binomial distribution of wounds, its expected value and its
 * standard deviation. NOTE(review): the tables look like Warhammer Fantasy
 * Battle close combat; confirm which edition is intended.
 *
 * <p>Optional rules are passed as extra tokens. The first character of a token
 * is the option key and the rest is its value:
 * {@code p} (poisoned attacks), {@code h} (hatred: re-roll missed hits once)
 * and {@code wN} (ward save of N+).
 *
 * <p>The nested {@code Combat}/{@code Side}/{@code Unit}/{@code Model} classes
 * are the start of an XML-driven data model. {@link #initialize(String)} parses
 * such a file but does not build the model yet.
 */
public class AttackOdds {
	
	/** Thrown when the XML combat description does not have the expected structure. */
	class InvalidDataException extends RuntimeException {
		public InvalidDataException(String s) {
			super(s);
		}
		public InvalidDataException(String s, Throwable t) {
			super(s,t);
		}
	}

	/**
	 * Base of the XML data model: an element with a fixed tag name whose
	 * descendant elements with a given tag become this object's children.
	 */
	abstract class BaseObject {

		/** Child objects built by {@link #createChild}; null until initialized. */
		List m_children = null;

		/**
		 * Checks {@code root}'s tag name and rebuilds {@link #m_children} from
		 * every descendant element of {@code root} tagged {@code childList}.
		 *
		 * @param root element this object is built from
		 * @param name tag name {@code root} must have
		 * @param childList tag name of the descendant elements to turn into children
		 * @throws InvalidDataException if {@code root} is not tagged {@code name}
		 */
		public void initialize(Element root, String name, String childList) {
			if (m_children == null) {
				m_children = new Vector();
			} else {
				m_children.clear();
			}
			if (!name.equals(root.getTagName())) {
				throw new InvalidDataException("tag '"+name+"' not found where expected");
			}
			// getElementsByTagName matches all descendants, not only direct children.
			NodeList nl = root.getElementsByTagName(childList);
			for (int i=0; i<nl.getLength(); i++) {
				m_children.add(createChild((Element) nl.item(i)));
			}
		}

		/** Builds the child object for one matching descendant element. */
		public abstract BaseObject createChild(Element item);
	}

	/** Root of the data model: a {@code <combat>} element holding {@code <side>} elements. */
	class Combat extends BaseObject {

		public Combat(Element root) {
			initialize(root);
		}

		public void initialize(Element root) {
			// Each child is handed to Side, which only accepts elements tagged
			// exactly "side", so that is the tag to collect here.
			initialize(root, "combat", "side");
		}

		public BaseObject createChild(Element item) {
			return new Side(item);
		}
	}

	/** One side of the combat: a {@code <side>} element holding {@code <unit>} elements. */
	class Side extends BaseObject {
		public Side(Element root) {
			initialize(root);
		}

		public void initialize(Element root) {
			initialize(root, "side", "unit");
		}
		public BaseObject createChild(Element item) {
			return new Unit(item);
		}
	}

	/** A {@code <unit>} element holding {@code <model>} elements. */
	class Unit extends BaseObject {
		public Unit (Element root) {
			initialize(root);
		}
		public void initialize(Element root) {
			initialize(root, "unit", "model");

			// TODO: process 'special' tags
		}
		public BaseObject createChild(Element item) {
			return new Model(item);
		}
	}

	/** Leaf of the data model; currently a stub that reads nothing from its element. */
	class Model extends BaseObject {
		public Model (Element root) {
		}

		public BaseObject createChild(Element item) {
			return null;
		}
	}

	/** Option key: attacks are poisoned. */
	public static final String OPT_POISON = "p";
	/** Option key: hatred, meaning missed hits are re-rolled once. */
	public static final String OPT_HATRED = "h";
	/** Option key: ward save; the value is the save target (e.g. "w4" for a 4+ ward). */
	public static final String OPT_WARDSAVE = "w";

	/**
	 * Active options, keyed by option letter, with the rest of each token as
	 * the value. Starts empty so the option queries work even if
	 * {@link #setOptions} is never called.
	 */
	protected Map m_options = new Hashtable();
	
	public AttackOdds() {
	}

	/**
	 * Replaces the active options with the given tokens.
	 *
	 * <p>Each token's first character is the key and the remainder is the value.
	 * Empty or {@code null} tokens are ignored, and a {@code null} array clears
	 * all options.
	 *
	 * @param options option tokens such as {@code "p"}, {@code "h"}, {@code "w4"}
	 */
	public void setOptions(String[] options) {
		m_options = new Hashtable();
		if (options == null) {
			return;
		}
		for (int i=0; i<options.length; i++) {
			String option = options[i];
			// An empty token has no key character (getKey would throw), so skip it.
			if (option == null || option.length() == 0) {
				continue;
			}
			m_options.put(getKey(option), getValue(option));
		}
	}

	/** Returns the option key, i.e. the first character of {@code str}. */
	public String getKey(String str) {
		return str.substring(0, 1);
	}

	/** Returns the option value: everything after the first character, or "" if none. */
	public String getValue(String str) {
		if (str.length() > 1) {
			return str.substring(1);
		} else {
			return "";
		}
	}

	/**
	 * Parses the XML combat description in {@code fileName}.
	 *
	 * <p>TODO: the parsed root element is not used yet. It is presumably meant to
	 * become a {@link Combat}.
	 *
	 * @param fileName path of the XML file to read
	 * @throws Exception if the parser cannot be created or the file cannot be read or parsed
	 */
	public void initialize(String fileName) throws Exception {
		// NOTE(review): DTDs / external entities are left enabled on this parser.
		// If the file can come from an untrusted source, harden it against XXE.
		DocumentBuilder docBuilder = DocumentBuilderFactory.newInstance().newDocumentBuilder();
		Document data = docBuilder.parse (new File(fileName));
		Element root = data.getDocumentElement();
	}

	/** Returns true if the poison option is set. */
	public boolean isPoison() {
		return m_options.containsKey(OPT_POISON);
	}

	/** Returns true if the hatred option is set. */
	public boolean isHatred() {
		return m_options.containsKey(OPT_HATRED);
	}

	/**
	 * Returns the ward save target number, or 7 (no ward save) when the option
	 * is absent.
	 *
	 * @throws NumberFormatException if the ward option's value is not an integer
	 */
	public int getWardSave() {
		if (m_options.containsKey(OPT_WARDSAVE)) {
			return Integer.parseInt((String) m_options.get(OPT_WARDSAVE));
		} else {
			return 7;
		}
	}

	/**
	 * Probability that one attack scores a poisoned hit whose wound is not
	 * stopped by armor. The poisoned path skips the to-wound roll entirely.
	 * Returns 0 when poison is off. The ward save is not applied here.
	 *
	 * @param wsAtt attacker weapon skill
	 * @param wsDef defender weapon skill
	 * @param sAtt attacker strength (worsens the armor save)
	 * @param armorDef defender armor save target number
	 * @return probability of an armor-unsaved poisoned wound
	 */
	public double getProbOfPoison(int wsAtt, int wsDef, int sAtt, int armorDef) {
		double ret;
		if (isPoison()) {
			ret = getBasePoisonProb(wsAtt, wsDef) * getArmorNumber(sAtt, armorDef);
		} else {
			ret = 0.0;
		}
		debug ("prob of poison: "+ret);
		return ret;
	}

	/**
	 * Chance of a single attack producing a poisoned hit.
	 *
	 * <p>The hit comes from one face of the to-hit die (1/6). With hatred, a
	 * first-roll miss followed by that face on the re-roll also counts.
	 */
	public double getBasePoisonProb(int wsAtt, int wsDef) {
		double poisonProb = (1.0/6.0);
		if (isHatred()) {
			poisonProb += getBaseToMissNumber(wsAtt, wsDef) * poisonProb;
		}
		return poisonProb;
	}

	/**
	 * Binomial probability that exactly {@code numAttacks} of {@code attacks}
	 * independent attacks wound.
	 *
	 * @param toWoundNumber probability that a single attack wounds
	 * @param attacks total number of attacks
	 * @param numAttacks number of wounding attacks
	 * @return probability of exactly {@code numAttacks} wounds
	 */
	public double getProbOfWound (double toWoundNumber, int attacks, int numAttacks) {
		return Math.pow( toWoundNumber , (double) (numAttacks))
			* Math.pow(1-toWoundNumber, (double) (attacks-numAttacks))
			* comb(attacks, numAttacks);
	}

	/**
	 * Computes the binomial coefficient n choose k, i.e. n! / (k! * (n-k)!).
	 *
	 * <p>Uses the multiplicative form, so it stays finite well past n = 170,
	 * where the factorials themselves overflow a double.
	 *
	 * @return C(n, k), or 0 when k &lt; 0 or k &gt; n
	 */
	public double comb( int n, int k) {
		if (k < 0 || k > n) {
			return 0.0;
		}
		// C(n, k) == C(n, n-k); iterate over the shorter side.
		int terms = Math.min(k, n - k);
		double result = 1.0;
		for (int i = 1; i <= terms; i++) {
			// After step i, result == C(n - terms + i, i): always a whole number.
			result = result * (n - terms + i) / i;
		}
		return result;
	}

	/** Returns n! as a double; any n below 1 yields 1.0. */
	public double factorial (int n) {
		if (n<1) return 1.0;
		double total = 1.0;
		while (n > 1) {
			total *= (double) n;
			n--;
		}
		return total;
	}


	/**
	 * Probability that one attack causes an unsaved wound.
	 *
	 * <p>This adds two paths:
	 * <ul>
	 * <li>the normal path: hit, wound, then fail the armor save;</li>
	 * <li>when poison is on, the poisoned-hit path.</li>
	 * </ul>
	 * The sum is then multiplied by the chance of failing the ward save, so the
	 * ward save is taken against every wound, poisoned or not.
	 *
	 * @param wsAtt attacker weapon skill
	 * @param sAtt attacker strength
	 * @param wsDef defender weapon skill
	 * @param tDef defender toughness
	 * @param armorDef defender armor save target (e.g. 5 for 5+; 7 for none)
	 * @return probability in [0, 1]
	 */
	public double getTotalToWoundNumber(
		int wsAtt,
		int sAtt,
		int wsDef,
		int tDef,
		int armorDef
	) {
		// Evaluated in this order so the debug trace matches: hit, wound, armor, ward, poison.
		double toHit = getToHitNumber(wsAtt, wsDef);
		double toWound = getToWoundNumber(sAtt, tDef);
		double failArmor = getArmorNumber(sAtt, armorDef);
		double failWard = getWardSaveNumber();
		double poisonWound = getProbOfPoison(wsAtt, wsDef, sAtt, armorDef);
		return (toHit * toWound * failArmor + poisonWound) * failWard;
	}

	/** Probability of missing on the first to-hit roll, before any hatred re-roll. */
	public double getBaseToMissNumber(int wsAtt, int wsDef) {
		return 1.0 - getBaseToHitNumber(wsAtt, wsDef);
	}

	/**
	 * Probability of hitting on a single to-hit roll.
	 *
	 * @return 2/3 if the attacker's WS is higher; 1/3 if the defender's WS is
	 *         more than double the attacker's; otherwise 1/2
	 */
	public double getBaseToHitNumber(int wsAtt, int wsDef) {
		double ret;
		if (wsAtt > wsDef) {
			ret = (2.0/3.0);
		} else if (wsAtt*2 < wsDef) {
			ret = (1.0/3.0);
		} else {
			ret = 0.5;
		}
		return ret;
	}

	/**
	 * Probability of a non-poisoned hit, including the hatred re-roll.
	 *
	 * <p>If the attacks are poisoned, poisoned hits are excluded here because
	 * {@link #getProbOfPoison} counts them separately.
	 */
	public double getToHitNumber(int wsAtt, int wsDef) {
		double ret = getBaseToHitNumber(wsAtt, wsDef);
		if (isHatred()) {
			// hatred = reroll misses once: 1 - (the chance of MISSING twice)
			ret = 1.0 - Math.pow((1.0 - ret), 2);
		}
		if (isPoison()) {
			ret -= getBasePoisonProb(wsAtt, wsDef);
		}
		debug ("to hit number: "+ret);
		return ret;
	}

	/**
	 * Probability that a hit wounds, from strength minus toughness.
	 *
	 * <p>The difference maps as follows: +2 or more gives 5/6; +1 gives 4/6;
	 * 0 gives 3/6; -1 gives 2/6; -2 or -3 give 1/6; -4 or less give 0.
	 * A difference outside -9..9 prints a warning but still uses these limits.
	 */
	public double getToWoundNumber (int sAtt, int tDef) {
		int diff = sAtt - tDef;
		if (diff < -9 || diff > 9) {
			// Differences this large suggest bad input. Warn, but still answer
			// with the table's limiting value rather than 0.
			out ("oops, bad to wound number: "+diff);
		}
		double ret;
		if (diff >= 2) {
			ret = (5.0/6.0);
		} else if (diff == 1) {
			ret = (4.0/6.0);
		} else if (diff == 0) {
			ret = (3.0/6.0);
		} else if (diff == -1) {
			ret = (2.0/6.0);
		} else if (diff >= -3) {
			ret = (1.0/6.0);
		} else {
			ret = 0.0;
		}
		debug ("to wound number: "+ret);
		return ret;
	}

	/**
	 * Probability that the defender FAILS its armor save against this strength.
	 *
	 * <p>Each point of strength above 3 worsens the save by one. The result is
	 * clamped between a 2+ save (fails 1/6) and no save at all (fails 6/6).
	 *
	 * @param sAtt attacker strength
	 * @param armorDef armor save target number (e.g. 5 for a 5+ save; 7 or more means none)
	 * @return failure probability in [1/6, 1]
	 */
	public double getArmorNumber(int sAtt, int armorDef) {

		int saveMod = sAtt - 3;
		
		if (saveMod < 0) saveMod = 0;

		// NOTE: the armor number is expected to be from 1 (2+ armor save) 
		// to 6 (7+ armor save)
		int effectiveArmorDef = armorDef + saveMod - 1;

		if (effectiveArmorDef < 1) { effectiveArmorDef = 1; }  // 1 corresponds to a 2+ save

		if (effectiveArmorDef > 6) { effectiveArmorDef = 6; }  // 6 corresponds to a 7+ save

		double armorNumber = ((double) (effectiveArmorDef)) / 6.0;
		debug ("armor number: "+armorNumber);
		return armorNumber;
	}

	/**
	 * Probability that the ward save fails.
	 *
	 * <p>Uses the armor table at strength 3, which applies no modifier. With no
	 * ward option this is 1.0.
	 */
	public double getWardSaveNumber() {
		debug("ward save:");
		return getArmorNumber(3, getWardSave());
	}

	/** Converts a probability to a percentage rounded to two decimal places. */
	public double round(double num) {
		long prob = Math.round(
			num * 10000
		);
		return ((double) prob) / 100;
	}

	/**
	 * Command-line entry point.
	 *
	 * <p>Arguments: attacker WS, S and attacks, then defender WS, T and armor
	 * save, followed by optional option tokens. Prints, for every possible
	 * number of wounds, its probability and the cumulative "at most" and
	 * "more than" probabilities. Then prints the expected number of wounds and
	 * the standard deviation.
	 */
	public static void main (String args[]) {
		if (args.length < 6) {
			out("usage:");
			out("AttackOdds <attacker ws> <attacker strength> <attacker attacks>");
			out("           <defender ws> <defender toughness> <defender armor>");
			return;
		}
		int wsAtt = Integer.parseInt(args[0]);
		int sAtt = Integer.parseInt(args[1]);
		int attAtt = Integer.parseInt(args[2]);
		int wsDef = Integer.parseInt(args[3]);
		int tDef = Integer.parseInt(args[4]);

		int armorDef = Integer.parseInt(args[5]);

		String[] tokens = new String[args.length-6];
		System.arraycopy(args, 6, tokens,0,args.length-6);

		AttackOdds att = new AttackOdds();
		att.setOptions(tokens);

		double toWoundNumber = att.getTotalToWoundNumber(wsAtt, sAtt, wsDef, tDef, armorDef);

		debug ("to wound number: "+toWoundNumber);

		double totalHits = 0.0;
		double totalMiss = 1.0;
		double expectedValue = 0.0;
		out ("num wound\tprob wound\tat most x\tmore than x");
		for (int i=0; i<= attAtt; i++) {
			double expectedWounds = att.getProbOfWound(toWoundNumber, attAtt, i);

			expectedValue += expectedWounds * (double) i;

			totalHits += expectedWounds;
			totalMiss -= expectedWounds;
			out (""+i+"\t\t"+att.round(expectedWounds)+"\t\t"+
				att.round(totalHits)+"\t\t"+
				att.round(totalMiss) );
		}

		out ("expected number of wounds: "+	expectedValue);
		// Variance of the wound count: sum over outcomes of P(i) * (i - mean)^2.
		double variance = 0.0;
		for (int i=0; i<= attAtt; i++) {
			double probOfI = att.getProbOfWound(toWoundNumber, attAtt, i);
			double deviation = ((double) i) - expectedValue;
			variance += deviation * deviation * probOfI;
		}
		out ("stdev: "+ Math.sqrt(variance));
	}

	/** Prints a line of regular output to standard out. */
	public static void out (String s) {
		System.out.println(s);
	}

	/** Prints a debug line; currently goes to standard out, exactly like {@link #out}. */
	public static void debug (String s) {
		System.out.println(s);
	}
}

