use std::{
    collections::{BTreeSet, HashSet},
    fs::File,
    io::{BufRead, BufReader},
    path::Path,
};
/// Returns the edge-adjacent neighbours of `p` (given as `(y, x)`) as `(char, y, x)`
/// triples, in the order up, down, left, right.
///
/// Every lookup is bounds-checked, so cells that do not exist are simply omitted. This
/// covers the top/left edges, positions past the end of a shorter row, and empty rows
/// such as one produced by a trailing blank line. Nothing here panics on a ragged map.
fn get_around(map: &[Vec<char>], p: (usize, usize)) -> Vec<(char, usize, usize)> {
    let (y, x) = p;
    // Candidate coordinates; `checked_sub` drops the up/left neighbour at the 0 edge.
    let candidates = [
        y.checked_sub(1).map(|up| (up, x)),
        Some((y + 1, x)),
        x.checked_sub(1).map(|left| (y, left)),
        Some((y, x + 1)),
    ];
    candidates
        .iter()
        .flatten()
        .filter_map(|&(ny, nx)| map.get(ny)?.get(nx).map(|&f| (f, ny, nx)))
        .collect()
}
/// A connected group of garden plots that all grow the same plant.
///
/// Two regions compare equal only when they have the same plant and exactly the same set
/// of plots.
#[derive(PartialEq, Eq, Hash, Debug)]
struct Region {
    /// Plant character shared by every plot in this region.
    c: char,
    /// Member plots, kept in `Plot`'s derived ordering (row, then column).
    plots: BTreeSet<Plot>,
}
impl Region {
    /// Flood-fills the region of plant `c` that contains `p` (given as `(y, x)`).
    ///
    /// Every cell reachable from `p` through edge-adjacent cells holding `c` becomes a
    /// [`Plot`] of the returned region.
    fn read_from_map(map: &[Vec<char>], c: &char, p: (usize, usize)) -> Self {
        let mut region = Region {
            c: *c,
            plots: BTreeSet::new(),
        };
        region.plots.insert(Self::plot_at(map, c, p));
        region.read_recurse(map, c, p);
        region
    }

    /// Builds the [`Plot`] at `p`.
    ///
    /// Its `perimiter` is the number of its four sides not shared with a neighbouring `c`
    /// cell. Sides on the map edge count as fence.
    fn plot_at(map: &[Vec<char>], c: &char, p: (usize, usize)) -> Plot {
        let same = get_around(map, p)
            .iter()
            .filter(|(f, _, _)| f == c)
            .count();
        // `get_around` reports at most four neighbours, so this cannot underflow.
        Plot {
            y: p.0,
            x: p.1,
            perimiter: 4 - same,
        }
    }

    /// Adds to `self.plots` every `c` cell reachable from `p` through edge-adjacent `c`
    /// cells.
    ///
    /// Uses an explicit work stack instead of recursion, so a very large region cannot
    /// overflow the call stack.
    fn read_recurse(&mut self, map: &[Vec<char>], c: &char, p: (usize, usize)) {
        let mut pending = vec![p];
        while let Some(current) = pending.pop() {
            for (f, y, x) in get_around(map, current) {
                // `insert` returns false for plots already collected, which stops revisits.
                if f == *c && self.plots.insert(Self::plot_at(map, c, (y, x))) {
                    pending.push((y, x));
                }
            }
        }
    }

    /// Returns the bulk-discount fence price: number of plots × number of straight sides.
    ///
    /// A side is a maximal straight run of fence:
    /// - Top and bottom sides are found by scanning the region's plots row by row.
    /// - Left and right sides are found by scanning them column by column.
    ///
    /// A run ends at a gap in the region, or where the neighbour on that side belongs to
    /// the region (no fence there).
    ///
    /// `_max_pos` is no longer used. Row and column extents come from the plots
    /// themselves, so maps wider than they are tall are handled correctly. The parameter
    /// is kept for call compatibility.
    fn get_fence_cost(&self, map: &[Vec<char>], _max_pos: usize) -> usize {
        // True when (y, x) exists on the map and holds this region's plant. A coordinate of
        // `None` (i.e. stepping past index 0) or a position outside a row counts as "not
        // ours", so a fence is placed there. Same-plant neighbours are always in this
        // region because the flood fill follows edge adjacency.
        let same = |y: Option<usize>, x: Option<usize>| -> bool {
            y.zip(x).and_then(|(y, x)| map.get(y)?.get(x)) == Some(&self.c)
        };

        // `BTreeSet<Plot>` iterates in (y, x) order, i.e. row-major.
        let top = Self::count_sides(
            self.plots
                .iter()
                .map(|p| (p.y, p.x, !same(p.y.checked_sub(1), Some(p.x)))),
        );
        let bottom = Self::count_sides(
            self.plots
                .iter()
                .map(|p| (p.y, p.x, !same(Some(p.y + 1), Some(p.x)))),
        );

        // Column-major order is needed for the vertical sides.
        let mut by_column: Vec<&Plot> = self.plots.iter().collect();
        by_column.sort_unstable_by_key(|p| (p.x, p.y));
        let left = Self::count_sides(
            by_column
                .iter()
                .map(|p| (p.x, p.y, !same(Some(p.y), p.x.checked_sub(1)))),
        );
        let right = Self::count_sides(
            by_column
                .iter()
                .map(|p| (p.x, p.y, !same(Some(p.y), Some(p.x + 1)))),
        );

        self.plots.len() * (top + bottom + left + right)
    }

    /// Counts maximal runs of fenced cells along parallel lines.
    ///
    /// `cells` yields `(line, pos, fenced)` sorted by `(line, pos)` with no duplicates. A
    /// new side starts at every fenced cell that does not directly continue a fenced cell
    /// at `(line, pos - 1)`.
    fn count_sides(cells: impl IntoIterator<Item = (usize, usize, bool)>) -> usize {
        let mut sides = 0;
        let mut prev: Option<(usize, usize, bool)> = None;
        for (line, pos, fenced) in cells {
            let continues = matches!(prev, Some((l, q, true)) if l == line && q + 1 == pos);
            if fenced && !continues {
                sides += 1;
            }
            prev = Some((line, pos, fenced));
        }
        sides
    }
}
/// A single garden cell belonging to a region, at row `y`, column `x`.
///
/// Field order is significant. The derived ordering compares `y`, then `x`, then
/// `perimiter`, so a `BTreeSet<Plot>` iterates row-major.
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Debug)]
struct Plot {
    /// Row index in the map.
    y: usize,
    /// Column index in the map.
    x: usize,
    /// Number of this cell's four sides not shared with a same-plant neighbour.
    perimiter: usize,
}
pub fn part_two(input: &Path) -> anyhow::Result<usize> {
    let reader = BufReader::new(File::open(input)?);
    let mut map: Vec<Vec<char>> = vec![];
    for line in reader.lines() {
        map.push(line?.chars().collect());
    }
    let mut regions: HashSet<Region> = HashSet::new();
    for (y, row) in map.iter().enumerate() {
        for (x, c) in row.iter().enumerate() {
            regions.insert(Region::read_from_map(&map, c, (y, x)));
        }
    }
    Ok(regions
        .iter()
        .map(|r| -> usize { r.get_fence_cost(&map, map.len()) })
        .sum())
}