# Page Scanner Wizard, copyright 2023 Evil Mr Henry
# Distribution allowed under the GPL v3 or later.

import json
import tkinter as tk
import tkinter.filedialog as filedialog
import os.path
import os
from collections import defaultdict
from typing import Text, List

import regex as regex
from PIL import Image, ImageTk

# TODO: picture with "" filename breaks everything

# Values of the coarse-rotation radio buttons. They are stored per file under
# OPERATION_ROTATE and mapped to degrees in MainWindow.get_rotation_degrees.
ROTATE_NONE = "None"
ROTATE_LEFT = "Left"
ROTATE_RIGHT = "Right"
ROTATE_180 = "180 Degrees"

# Keys of the per-file dictionaries held in MainWindow.operations. That mapping
# is written to the JSON state file by MainWindow.save, so these strings are
# part of the saved format.
OPERATION_ROTATE = "Rotate"
OPERATION_PRECISE_ROTATE = "Precise_rotate"
OPERATION_RENAME = "Rename"
OPERATION_CROP = "Crop"
OPERATION_CROP_POINTS = "Crop points"
# Also used as the caption of the "new series" button when the flag is set.
OPERATION_NEW_SERIES = "New series"
# Only used as the button caption when the flag is not set.
OPERATION_NOT_NEW_SERIES = "Not New series"

# Values of the crop radio buttons, stored per file under OPERATION_CROP.
CROP_NONE = "None"
CROP_BASIC = "Crop"

# Captures a run of digits; get_next_page_number uses it to find the number to
# increment inside a page name such as "m0001.jpg".
page_number_regex = regex.compile(r"(\d+)")


class MainWindow(object):
    def __init__(self):
        """Create the Tk root window, all widgets and the initial edit state.

        Grid layout of the root window: the image canvas in column 0, the
        rotate / rename / crop / export controls in column 1, and the browse
        button with the scrollable file list in columns 2-3.
        """
        # Per-file edits: {full image path: {OPERATION_* key: value}}.
        self.operations = defaultdict(dict)
        self._current_directory = None
        self._current_index = None
        self.current_filename = None
        self.image_thumb = None
        self.original_image = None
        self.root = tk.Tk()

        # Tk variables backing the widgets. The write traces push edits made in
        # the entry boxes into self.operations.
        self._filename_list = tk.Variable(self.root, value=[])
        self.rotation = tk.StringVar(self.root, ROTATE_NONE)
        self.crop_type = tk.StringVar(self.root, CROP_NONE)
        self.page_number = tk.StringVar(self.root, "m0001.jpg")
        self.page_number.trace("w", self.change_page_number)
        self.precise_rotate = tk.StringVar(self.root, "0")
        self.precise_rotate.trace("w", self.change_precise_rotate)
        self.page_number_increase = tk.StringVar(self.root, "2")
        self.default_rotation = ROTATE_NONE

        # Image canvas; a resize (<Configure>) re-renders the thumbnail.
        self.image_viewer = tk.Canvas(self.root, background="#990099")
        self.image_viewer.grid(row=0, column=0, rowspan=2, sticky="NSEW")
        self.image_viewer.bind("<Configure>", self.look_at_image)

        self.controls_frame = tk.Frame(self.root)
        self.controls_frame.grid(row=0, column=1, rowspan=2)

        # rotation
        self.rotate_frame = tk.Frame(self.controls_frame)
        self.rotate_frame.grid(row=0, column=0)
        self.rotate_none = tk.Radiobutton(self.rotate_frame, text="No rotation", value=ROTATE_NONE,
                                          variable=self.rotation, command=self.set_rotation)
        self.rotate_none.grid(row=0, column=1)
        self.rotate_none.select()

        self.rotate_left = tk.Radiobutton(self.rotate_frame, text=ROTATE_LEFT, value=ROTATE_LEFT,
                                          variable=self.rotation, command=self.set_rotation)
        self.rotate_left.grid(row=1, column=1)

        self.rotate_right = tk.Radiobutton(self.rotate_frame, text=ROTATE_RIGHT, value=ROTATE_RIGHT,
                                           variable=self.rotation, command=self.set_rotation)
        self.rotate_right.grid(row=2, column=1)

        self.rotate_180 = tk.Radiobutton(self.rotate_frame, text=ROTATE_180, value=ROTATE_180,
                                         variable=self.rotation, command=self.set_rotation)
        self.rotate_180.grid(row=3, column=1)
        # NOTE(review): Tk expects a validatecommand callback to return a
        # boolean - confirm the callbacks registered here and below do.
        self.precise_rotate_entry = tk.Entry(self.rotate_frame, textvariable=self.precise_rotate, width=10,
                                             validate="key",
                                             validatecommand=(self.root.register(self.change_precise_rotate), "%P"))
        self.precise_rotate_entry.grid(row=0, column=2)

        self.rotate_save = tk.Button(self.rotate_frame, text="Rotate following pages", command=self.rotate_pages)
        self.rotate_save.grid(row=1, column=2)

        # rename
        self.rename_frame = tk.Frame(self.controls_frame)
        self.rename_frame.grid(row=2, column=0)

        self.page_number_frame = tk.Frame(self.rename_frame)
        self.page_number_frame.grid(row=0, column=0, columnspan=2)
        self.page_number_entry = tk.Entry(self.page_number_frame, textvariable=self.page_number, width=10,
                                          validate="key",
                                          validatecommand=(self.root.register(self.change_page_number), "%P"))
        self.page_number_entry.grid(row=3, column=2)
        self.page_number_label = tk.Label(self.page_number_frame, text="Page Number")
        self.page_number_label.grid(row=3, column=3)
        self.page_number_increase_entry = tk.Entry(self.rename_frame, width=2, textvariable=self.page_number_increase)
        self.page_number_increase_entry.grid(row=4, column=0)
        self.page_number_increase_label = tk.Label(self.rename_frame, text="+Page Number")
        self.page_number_increase_label.grid(row=4, column=1)

        self.page_number_renumber = tk.Button(self.rename_frame, text="Renumber following pages",
                                              command=self.renumber_pages)
        self.page_number_renumber.grid(row=5, column=0, columnspan=2)

        # crop
        self.crop_frame = tk.Frame(self.controls_frame)
        self.crop_frame.grid(row=3, column=0)
        # The crop radio buttons share self.crop_type, so "No crop" must carry
        # the crop constant (the original used ROTATE_NONE, which only worked
        # because both constants are the same string).
        self.crop_none = tk.Radiobutton(self.crop_frame, text="No crop", value=CROP_NONE,
                                        variable=self.crop_type, command=self.set_crop)
        self.crop_none.grid(row=0, column=1)
        self.crop_none.select()

        self.crop_basic = tk.Radiobutton(self.crop_frame, text="Crop", value=CROP_BASIC, variable=self.crop_type,
                                         command=self.set_crop)
        self.crop_basic.grid(row=1, column=1)

        self.crop_save = tk.Button(self.crop_frame, text="Crop following pages", command=self.crop_pages)
        self.crop_save.grid(row=1, column=2)

        self.other_controls_frame = tk.Frame(self.controls_frame)
        self.other_controls_frame.grid(row=4, column=0)
        # New series
        self.new_series_button = tk.Button(self.other_controls_frame, text="Set new series",
                                           command=self.set_new_series)
        self.new_series_button.grid(row=0, column=0, sticky="S")

        # output
        self.output_file_button = tk.Button(self.other_controls_frame, text="Export", command=self.export_files)
        self.output_file_button.grid(row=0, column=1, sticky="S")

        # file list
        self.image_browse = tk.Button(self.root, text="Browse", command=self.set_directory)
        self.image_browse.grid(row=0, column=2, columnspan=2, sticky="N")

        self.image_list = tk.Listbox(self.root, exportselection=False, selectmode=tk.BROWSE)
        self.image_list.grid(row=1, column=2, sticky="NS")
        self.image_list.bind('<<ListboxSelect>>', self.look_at_image)

        self.image_list_scrollbar = tk.Scrollbar(self.root, orient=tk.VERTICAL)
        self.image_list_scrollbar.grid(row=1, column=3, sticky="NS")

        # Wire the scrollbar and the listbox to each other.
        self.image_list_scrollbar['command'] = self.image_list.yview
        self.image_list['yscrollcommand'] = self.image_list_scrollbar.set

        # Only the canvas cell grows when the window is resized.
        self.root.rowconfigure(1, weight=1)
        self.root.columnconfigure(0, weight=1)

        # Canvas items: the thumbnail, the crop rectangle and its four handles.
        # The crop items start hidden; show_crop_points() positions and shows them.
        self.thumbnail = self.image_viewer.create_image(0, 0, anchor=tk.NW, image=None)

        self.crop_rectangle = self.image_viewer.create_rectangle(0, 0, 0, 0)
        self.image_viewer.itemconfigure(self.crop_rectangle, state=tk.HIDDEN)
        self.crop_points = []
        for index in range(4):
            self.crop_points.append(self.image_viewer.create_oval(0, 0, 0, 0, fill="#FFFFFF", activefill="#FF3333"))
            reference = self.crop_points[-1]
            self.image_viewer.itemconfigure(reference, state=tk.HIDDEN)
            # x=index binds the handle number now rather than when the lambda runs.
            self.image_viewer.tag_bind(reference, "<ButtonPress-1>",
                                       (lambda e, x=index: self.drag_start(e, x)))
            self.image_viewer.tag_bind(reference, "<ButtonRelease-1>", self.drag_stop)
            self.image_viewer.tag_bind(reference, "<B1-Motion>", self.drag)

        self.current_drag_item = None

        # Keyboard navigation, bound on the root window and on the listbox.
        self.root.bind('<Up>', self.move)
        self.root.bind('<Down>', self.move)
        self.root.bind('<Prior>', self.move)
        self.root.bind('<Next>', self.move)

        self.image_list.bind('<Up>', self.move)
        self.image_list.bind('<Down>', self.move)
        self.image_list.bind('<Prior>', self.move)
        self.image_list.bind('<Next>', self.move)

    def move(self, event):
        """Key binding: step the listbox selection and display that image.

        Up/Down move by one entry and Prior/Next by ten, clamped to the ends
        of the list. Returns "break" once a key has been handled, and None
        when nothing is selected or the key is not one of those four.
        """
        current = self.get_file_index()
        if current is None:
            return
        step = {'Up': -1, 'Down': 1, 'Prior': -10, 'Next': 10}.get(event.keysym)
        if step is None:
            return
        if step < 0:
            target = max(0, current + step)
        else:
            target = min(self.image_list.size() - 1, current + step)
        self._move_listbox_selection(target)
        self.look_at_image()
        return "break"

    def _move_listbox_selection(self, index):
        """Make ``index`` the only selected, active and visible listbox entry."""
        listbox = self.image_list
        listbox.selection_clear(0, tk.END)
        # Select, activate and scroll to the same entry, in that order.
        for apply_to_entry in (listbox.selection_set, listbox.activate, listbox.see):
            apply_to_entry(index)

    def drag_start(self, event, crop_point):
        """Remember which crop handle is being dragged.

        Bound to <ButtonPress-1> on each crop handle; ``crop_point`` is that
        handle's index into ``self.crop_points``. ``event`` is not used.
        """
        self.current_drag_item = crop_point

    def drag(self, event):
        """Move the dragged crop handle and keep the crop area rectangular.

        The pointer position is converted to a fraction of the thumbnail size,
        clamped to 0..1, and stored for the dragged handle. The neighbouring
        handle on the same vertical edge takes the new x, the one on the same
        horizontal edge takes the new y, and the overlay is redrawn.
        """
        # Handle layout (indices into the crop point list):
        # 0 3
        # 1 2
        same_x_as = {0: 1, 1: 0, 2: 3, 3: 2}
        same_y_as = {0: 3, 1: 2, 2: 1, 3: 0}
        x = min(max(0, event.x / self.image_thumb.width()), 1)
        y = min(max(0, event.y / self.image_thumb.height()), 1)

        dragged = self.current_drag_item
        points = self.operations[self.current_filename][OPERATION_CROP_POINTS]
        points[dragged] = (x, y)

        x_partner = same_x_as[dragged]
        points[x_partner] = (x, points[x_partner][1])
        y_partner = same_y_as[dragged]
        points[y_partner] = (points[y_partner][0], y)

        self.show_crop_points()

    def drag_stop(self, event):
        """Forget the dragged crop handle; bound to <ButtonRelease-1> on every handle."""
        self.current_drag_item = None

    def save(self) -> None:
        """Save self.operations to json file."""
        filename = os.path.join(self._current_directory, ".page_scanner_wizard.json")
        with open(filename, mode="w") as fp:
            json.dump(self.operations, fp)

    def load(self) -> None:
        """Load self.operations from json file."""
        filename = os.path.join(self._current_directory, ".page_scanner_wizard.json")
        if not os.path.exists(filename):
            return
        with open(filename, mode="r") as fp:
            temp_operations = json.load(fp)  # indirection is so self.operations remains a defaultdict
            self.operations.update(temp_operations)

    def set_directory(self) -> None:
        """Ask for an image file and load every image from its directory.

        The directory of the chosen file becomes the working directory. Its
        ``.jpg`` / ``.png`` files (case-insensitive) are listed in sorted
        order, given default page names, overlaid with any saved state, and
        the first image is selected and displayed. Cancelling the dialog
        changes nothing.
        """
        file_name = filedialog.askopenfilename()
        if not file_name:
            return
        dir_name = os.path.dirname(file_name)
        self._current_directory = dir_name
        # The remembered selection belongs to the previous directory. Keeping it
        # could index past the end of the new list or edit the old file.
        self._current_index = None
        self.current_filename = None
        files = sorted(
            filename for filename in os.listdir(dir_name)
            if filename.lower().endswith((".jpg", ".png"))
        )
        self._filename_list.set(files)
        self.renumber_pages(0)
        self.load()
        self.refresh_image_list()
        if files:
            # Rebuilding the listbox drops the selection made by renumber_pages.
            self._move_listbox_selection(0)
        self.look_at_image()

    def get_file_index(self) -> int:
        """Return the index of the selected listbox entry.

        When the listbox reports no selection, the last remembered index is
        returned instead; that is None until something has been selected.
        """
        selected = self.image_list.curselection()
        if selected == ():
            return self._current_index
        self._current_index = selected[0]
        return self._current_index

    def get_filename(self, index):
        """Return the full path of the image at position ``index`` in the file list."""
        return os.path.join(self._current_directory, self._filename_list.get()[index])

    def change_page_number(self, *_unused):
        """Store the page-name entry's text as the selected image's export name.

        Used both as the write trace on ``self.page_number`` and as the entry's
        validatecommand. Tk requires a validation callback to return a boolean,
        so every path returns True (the edit is always accepted); a trace
        ignores the return value.
        """
        if not self._filename_list.get():
            return True
        index = self.get_file_index()
        if index is None:
            # Files are listed but nothing has been selected yet.
            return True
        new_page_number = self.page_number.get()
        if self.image_list.get(index) == new_page_number:
            # The listbox already shows this name; nothing to update.
            return True
        filename = self.get_filename(index)
        self.operations[filename][OPERATION_RENAME] = new_page_number
        self.refresh_image_list(index)
        # Re-inserting the entry dropped its selection, so restore it.
        self.image_list.selection_set(index)
        self.image_list.activate(index)
        return True

    def set_new_series(self):
        """Toggle the "new series" flag of the selected image and refresh the view.

        The "following pages" actions stop before a later page that carries
        this flag. Does nothing when no files are loaded or none is selected.
        """
        if not self._filename_list.get():
            return
        index = self.get_file_index()
        if index is None:
            # Files are listed but nothing has been selected yet.
            return
        file_operations = self.operations[self.get_filename(index)]
        file_operations[OPERATION_NEW_SERIES] = not file_operations.get(OPERATION_NEW_SERIES, False)
        self.look_at_image()

    def export_files(self):
        """Start exporting every listed image into ``<directory>/output``.

        The output directory is created if needed and the first
        ``export_file`` step is scheduled on the Tk event loop, so the files
        are processed one per callback. Does nothing when no directory has
        been chosen yet.
        """
        if self._current_directory is None:
            # Export pressed before Browse: there is nothing to export.
            return
        output = os.path.join(self._current_directory, "output")
        os.makedirs(output, exist_ok=True)
        pending = list(self._filename_list.get())
        self.root.after(1, self.export_file, pending, 0, len(pending))

    def export_file(self, file_list: List, current_index, total_file_count):
        """Export the first file of ``file_list`` and schedule the next step.

        Args:
            file_list: Base names still to be exported; the first one is
                removed from the list by this call.
            current_index: Number of files handled so far, shown on the
                export button as progress.
            total_file_count: Total number of files in this export run.

        The image is rotated, cropped when its crop mode is CROP_BASIC, and
        saved under its rename value inside the ``output`` directory. Files
        whose rename value is empty are skipped. When the list is empty the
        button caption is reset to "Export".
        """
        if not file_list:
            self.output_file_button.config(text="Export")
            return
        output = os.path.join(self._current_directory, "output")
        base_filename = file_list.pop(0)
        self.output_file_button.config(text="{}/{}".format(current_index, total_file_count))
        current_index += 1

        filename = os.path.join(self._current_directory, base_filename)
        file_operations = self.operations[filename]
        new_name = file_operations.get(OPERATION_RENAME)
        if not new_name:  # name files "" to skip them
            self.root.after(1, self.export_file, file_list, current_index, total_file_count)
            return

        # rotate (the result is a new image, so the source can be closed right away)
        rotate_degrees = self.get_rotation_degrees(filename)
        with Image.open(filename) as source_image:
            export_image = source_image.rotate(rotate_degrees, expand=True)

        # crop
        if file_operations.get(OPERATION_CROP) == CROP_BASIC:
            points = file_operations[OPERATION_CROP_POINTS]
            # Use the bounding box of all four handles, exactly like the
            # rectangle drawn on screen. Handles can be dragged past each
            # other, so points 0 and 2 alone may describe an inverted box.
            xs = [points[index][0] for index in range(4)]
            ys = [points[index][1] for index in range(4)]
            export_image = export_image.crop((
                min(xs) * export_image.width,
                min(ys) * export_image.height,
                max(xs) * export_image.width,
                max(ys) * export_image.height
            ))
        export_image.save(os.path.join(output, new_name), quality=95)
        self.root.after(10, self.export_file, file_list, current_index, total_file_count)

    def rotate_pages(self):
        """Apply the current rotation settings to the selected page and the following ones.

        Both the coarse rotation and the precise-rotation text are copied.
        The run stops before the first later page flagged as a new series.
        Does nothing when no page is selected.
        """
        start_index = self.get_file_index()
        if start_index is None:
            # Slicing with None would silently cover the whole list.
            return
        filenames = self._filename_list.get()
        rotation = self.rotation.get()
        precise_rotation = self.precise_rotate.get()
        for offset, partial_filename in enumerate(filenames[start_index:]):
            filename = os.path.join(self._current_directory, partial_filename)
            # The selected page itself is always processed, even if it starts a series.
            if offset > 0 and self.operations[filename].get(OPERATION_NEW_SERIES):
                break
            self.operations[filename][OPERATION_ROTATE] = rotation
            self.operations[filename][OPERATION_PRECISE_ROTATE] = precise_rotation

    def get_next_page_number(self, page_number: Text) -> Text:
        """Return ``page_number`` with its first number advanced by the page increment.

        The increment comes from the "+Page Number" entry and is reset to "2"
        when that entry does not hold digits. The new number is zero-padded to
        the width of the old one. Only the first run of digits is replaced;
        any later digits in the name are left untouched. Returns "?.jpg" when
        the name contains no digits.
        """
        match = page_number_regex.search(page_number)
        if match is None:
            return "?.jpg"
        digits = match.group(1)
        if not self.page_number_increase.get().isdigit():
            self.page_number_increase.set("2")
        next_number = str(int(digits) + int(self.page_number_increase.get()))
        next_number = next_number.rjust(len(digits), '0')
        # Splice by position instead of sub(), which would overwrite every
        # digit run in the name with the new number.
        return page_number[:match.start(1)] + next_number + page_number[match.end(1):]

    def renumber_pages(self, start_index=None):
        """Assign consecutive page names starting at ``start_index``.

        Args:
            start_index: List position to start from; defaults to the
                selected entry.

        The first page takes the text of the page-name entry and each
        following page takes the name produced by ``get_next_page_number``.
        The run stops before the first later page flagged as a new series.
        The listbox is then rebuilt and the start entry re-selected. Does
        nothing when no start index is available.
        """
        if start_index is None:
            start_index = self.get_file_index()
        if start_index is None:
            # Nothing selected: slicing with None would renumber every page.
            return
        filenames = self._filename_list.get()
        current_page_number = self.page_number.get()
        for offset, partial_filename in enumerate(filenames[start_index:]):
            filename = os.path.join(self._current_directory, partial_filename)
            # The start page itself is always renamed, even if it starts a series.
            if offset > 0 and self.operations[filename].get(OPERATION_NEW_SERIES):
                break
            self.operations[filename][OPERATION_RENAME] = current_page_number
            current_page_number = self.get_next_page_number(current_page_number)
        self.refresh_image_list()
        self.image_list.selection_set(start_index)
        self.image_list.activate(start_index)
        self.image_list.see(start_index)

    def refresh_image_list(self, index=None):
        """Refresh the names shown in the listbox.

        Args:
            index: Position of the single entry to refresh. With None (the
                default) the whole list is rebuilt.

        Each entry shows the file's rename value, or the original file name
        when no rename is recorded. Replacing entries drops their selection,
        so callers re-select afterwards.
        """
        filenames = self._filename_list.get()

        def display_name(filename):
            path = os.path.join(self._current_directory, filename)
            return self.operations.get(path, {}).get(OPERATION_RENAME, filename)

        # "is not None" so that index 0 takes the single-entry path as well.
        if index is not None:
            self.image_list.delete(index)
            self.image_list.insert(index, display_name(filenames[index]))
            return
        self.image_list.delete(0, tk.END)
        self.image_list.insert(tk.END, *[display_name(filename) for filename in filenames])

    def set_rotation(self, *_unused):
        self.operations[self.current_filename][OPERATION_ROTATE] = self.rotation.get()
        self.look_at_image()

    @staticmethod
    def precise_rotate_correct(rotate_value):
        try:
            return float(rotate_value)
        except ValueError:
            return 0

    def change_precise_rotate(self, *unused):
        self.operations[self.current_filename][OPERATION_PRECISE_ROTATE] = self.precise_rotate.get()
        self.look_at_image()

    def set_crop(self, *unused):
        self.operations[self.current_filename][OPERATION_CROP] = self.crop_type.get()
        if self.operations[self.current_filename][OPERATION_CROP] == CROP_BASIC:
            if OPERATION_CROP_POINTS not in self.operations[self.current_filename]:
                self.operations[self.current_filename][OPERATION_CROP_POINTS] = [
                    (0,0), (0, 1), (1,1), (1, 0)
                ]
        self.look_at_image()

    def crop_pages(self, *unused):
        """Apply the current crop mode to the selected page and the following ones.

        With CROP_BASIC each page also receives its own copy of the current
        image's crop points. The run stops before the first later page flagged
        as a new series. Does nothing when no page is selected.
        """
        start_index = self.get_file_index()
        if start_index is None:
            # Slicing with None would silently cover the whole list.
            return
        filenames = self._filename_list.get()
        crop_type = self.crop_type.get()
        source_points = None
        if crop_type == CROP_BASIC:
            source_points = self.operations[self.current_filename][OPERATION_CROP_POINTS]
        for offset, partial_filename in enumerate(filenames[start_index:]):
            filename = os.path.join(self._current_directory, partial_filename)
            # The selected page itself is always processed, even if it starts a series.
            if offset > 0 and self.operations[filename].get(OPERATION_NEW_SERIES):
                break
            self.operations[filename][OPERATION_CROP] = crop_type
            if source_points is not None:
                # A separate list per page, so a later drag affects one page only.
                self.operations[filename][OPERATION_CROP_POINTS] = list(source_points)

    def get_rotation(self, filename):
        """Return the coarse rotation stored for ``filename``.

        A file without one is first given ``self.default_rotation``.
        """
        return self.operations[filename].setdefault(OPERATION_ROTATE, self.default_rotation)

    def get_precise_rotation(self, filename):
        """Return the precise rotation of ``filename`` as a number.

        A file without a stored value is first given 0. The stored value is
        converted with ``precise_rotate_correct``.
        """
        stored = self.operations[filename].setdefault(OPERATION_PRECISE_ROTATE, 0)
        return self.precise_rotate_correct(stored)

    def get_rotation_degrees(self, filename):
        """Return the total angle for ``filename``, as handed to ``Image.rotate``.

        This is the coarse rotation in degrees (0 for an unknown setting)
        minus the precise rotation.
        """
        coarse_degrees = {
            ROTATE_NONE: 0,
            ROTATE_LEFT: 90,
            ROTATE_RIGHT: -90,
            ROTATE_180: 180
        }
        coarse = coarse_degrees.get(self.get_rotation(filename), 0)
        fine = self.get_precise_rotation(filename)
        return coarse - fine

    def get_crop(self, filename):
        """Return the crop mode stored for ``filename``.

        A file without one is first given CROP_NONE.
        """
        return self.operations[filename].setdefault(OPERATION_CROP, CROP_NONE)

    def look_at_image(self, *_unused):
        """Show the selected image and sync every control to its stored operations.

        Used as the handler for listbox selection and canvas resize events and
        called directly after most edits. Returns early when no files are
        loaded or nothing is selected. Each successful call ends by saving the
        operations to the JSON state file.
        """
        if not self._filename_list.get():
            return
        index = self.get_file_index()
        if index is None:
            return
        # current_filename must be set before the variables below are written:
        # precise_rotate and page_number carry write traces (registered in
        # __init__) whose callbacks read it.
        self.current_filename = self.get_filename(index)
        self.rotation.set(self.operations[self.current_filename].get(OPERATION_ROTATE, ROTATE_NONE))
        self.precise_rotate.set(str(self.operations[self.current_filename].get(OPERATION_PRECISE_ROTATE, "")))

        # Without a stored rename, fall back to the name following the one
        # currently in the entry. The fallback is computed either way.
        self.page_number.set(self.operations[self.current_filename].get(
            OPERATION_RENAME, self.get_next_page_number(self.page_number.get())))

        # Render a rotated thumbnail no larger than the current canvas size.
        self.original_image = Image.open(self.current_filename)
        temp_thumbnail = self.original_image.copy()
        rotate_degrees = self.get_rotation_degrees(self.current_filename)
        temp_thumbnail = temp_thumbnail.rotate(rotate_degrees, expand=True)
        temp_thumbnail.thumbnail((self.image_viewer.winfo_width(), self.image_viewer.winfo_height()))
        # Stored on self; drag() and show_crop_points() read its size.
        self.image_thumb = ImageTk.PhotoImage(image=temp_thumbnail)
        self.image_viewer.itemconfigure(self.thumbnail, image=self.image_thumb)

        self.show_crop_points()
        # The button caption shows whether this image starts a new series.
        new_series_text = (OPERATION_NEW_SERIES if self.operations[self.current_filename].get(OPERATION_NEW_SERIES)
                           else OPERATION_NOT_NEW_SERIES)
        self.new_series_button.config(text=new_series_text)

        self.image_viewer.update()
        self.save()

    def show_crop_points(self):
        """Sync the crop radio buttons and the crop overlay with the current image.

        With no crop mode (or CROP_NONE) the four handles and the rectangle
        are hidden. With CROP_BASIC each handle is drawn as a 20-pixel circle
        centred on its stored point, scaled to the thumbnail, and the
        rectangle spans the bounding box of all four points.
        """
        file_operations = self.operations[self.current_filename]
        crop_mode = file_operations.get(OPERATION_CROP, CROP_NONE)
        self.crop_type.set(crop_mode)

        if crop_mode in (None, CROP_NONE):
            for handle in self.crop_points:
                self.image_viewer.itemconfigure(handle, state=tk.HIDDEN)
            self.image_viewer.itemconfigure(self.crop_rectangle, state=tk.HIDDEN)
            return

        if crop_mode != CROP_BASIC:
            return

        points = file_operations[OPERATION_CROP_POINTS]
        thumb_width = self.image_thumb.width()
        thumb_height = self.image_thumb.height()

        for index, handle in enumerate(self.crop_points):
            centre_x = points[index][0] * thumb_width
            centre_y = points[index][1] * thumb_height
            self.image_viewer.coords(handle, centre_x - 10, centre_y - 10, centre_x + 10, centre_y + 10)
            self.image_viewer.itemconfigure(handle, state=tk.NORMAL)

        xs = [points[index][0] for index in range(4)]
        ys = [points[index][1] for index in range(4)]
        self.image_viewer.coords(
            self.crop_rectangle,
            min(xs) * thumb_width,
            min(ys) * thumb_height,
            max(xs) * thumb_width,
            max(ys) * thumb_height
        )
        self.image_viewer.itemconfigure(self.crop_rectangle, state=tk.NORMAL)




# Script entry: build the window and hand control to Tk's event loop.
# NOTE(review): there is no __main__ guard, so this also runs when the module
# is imported - confirm that is intended.
main_window = MainWindow()
main_window.root.mainloop()
