# Page Scanner Wizard, copyright 2023 Evil Mr Henry
# Distribution allowed under the GPL 3 or later.

import json
import math
import os
import os.path
import tkinter as tk
import tkinter.filedialog as filedialog
from collections import defaultdict
from typing import Text, List, Optional

import regex as regex
from PIL import Image, ImageTk

# TODO: picture with "" filename breaks everything

# Coarse rotation choices. These strings are the radiobutton values and are
# also what gets stored under OPERATION_ROTATE (and therefore in the JSON file).
ROTATE_NONE = "None"
ROTATE_LEFT = "Left"
ROTATE_RIGHT = "Right"
ROTATE_180 = "180 Degrees"

# Keys of the per-file settings dict held in MainWindow.operations.
# They are persisted verbatim in the JSON state file.
OPERATION_ROTATE = "Rotate"
OPERATION_PRECISE_ROTATE = "Precise_rotate"
OPERATION_RENAME = "Rename"
OPERATION_CROP = "Crop"
OPERATION_CROP_POINTS = "Crop points"
# OPERATION_NEW_SERIES is both a settings key and a button label;
# OPERATION_NOT_NEW_SERIES is only used as the button label.
OPERATION_NEW_SERIES = "New series"
OPERATION_NOT_NEW_SERIES = "Not New series"

# Crop modes stored under OPERATION_CROP.
CROP_NONE = "None"
CROP_BASIC = "Crop"

# Matches a run of digits; used to find the number inside a page name.
page_number_regex = regex.compile(r"(\d+)")


class MainWindow(object):
    """Main window of the Page Scanner Wizard.

    The window shows one scanned page at a time and lets the user choose a
    rotation, a crop rectangle and an output name for it.  All choices are
    kept in ``self.operations``, a ``defaultdict`` mapping the full path of
    an image to a dict keyed by the ``OPERATION_*`` constants, and are
    persisted as JSON inside the browsed directory.
    """

    # File (inside the browsed directory) where self.operations is persisted.
    _STATE_FILENAME = ".page_scanner_wizard.json"
    # Half the side length, in canvas pixels, of a crop handle.
    _HANDLE_RADIUS = 10
    # Relative (x, y) corners: 0 top-left, 1 bottom-left, 2 bottom-right,
    # 3 top-right.
    _DEFAULT_CROP_POINTS = ((0, 0), (0, 1), (1, 1), (1, 0))
    # Image modes that Pillow's JPEG writer accepts without conversion.
    _JPEG_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    def __init__(self):
        """Create the Tk root window, all widgets and the key bindings."""
        self.operations = defaultdict(dict)
        self._current_directory = None
        self._current_index = None
        self.current_filename = None
        self.image_thumb = None
        self.original_image = None
        self.current_drag_item = None
        # True while look_at_image() copies stored settings into the widgets;
        # lets the precise-rotate trace skip a redundant nested refresh.
        self._syncing_widgets = False

        self.root = tk.Tk()
        self._filename_list = tk.Variable(self.root, value=[])
        self.rotation = tk.StringVar(self.root, ROTATE_NONE)
        self.crop_type = tk.StringVar(self.root, CROP_NONE)
        self.page_number = tk.StringVar(self.root, "m0001.jpg")
        self.page_number.trace_add("write", self.change_page_number)
        self.precise_rotate = tk.StringVar(self.root, "0")
        self.precise_rotate.trace_add("write", self.change_precise_rotate)
        self.page_number_increase = tk.StringVar(self.root, "2")
        self.default_rotation = ROTATE_NONE

        self.image_viewer = tk.Canvas(self.root, background="#990099")
        self.image_viewer.grid(row=0, column=0, rowspan=2, sticky="NSEW")
        self.image_viewer.bind("<Configure>", self.look_at_image)

        self.controls_frame = tk.Frame(self.root)
        self.controls_frame.grid(row=0, column=1, rowspan=2)

        # rotation
        self.rotate_frame = tk.Frame(self.controls_frame)
        self.rotate_frame.grid(row=0, column=0)
        self.rotate_none = tk.Radiobutton(
            self.rotate_frame, text="No rotation", value=ROTATE_NONE,
            variable=self.rotation, command=self.set_rotation)
        self.rotate_none.grid(row=0, column=1)
        self.rotate_none.select()

        self.rotate_left = tk.Radiobutton(
            self.rotate_frame, text=ROTATE_LEFT, value=ROTATE_LEFT,
            variable=self.rotation, command=self.set_rotation)
        self.rotate_left.grid(row=1, column=1)

        self.rotate_right = tk.Radiobutton(
            self.rotate_frame, text=ROTATE_RIGHT, value=ROTATE_RIGHT,
            variable=self.rotation, command=self.set_rotation)
        self.rotate_right.grid(row=2, column=1)

        self.rotate_180 = tk.Radiobutton(
            self.rotate_frame, text=ROTATE_180, value=ROTATE_180,
            variable=self.rotation, command=self.set_rotation)
        self.rotate_180.grid(row=3, column=1)

        # No validatecommand here: the callback used to return None, which Tk
        # treats as a validation error (it then switches validation off and
        # may drop the keystroke).  The variable trace already reacts to
        # every edit.
        self.precise_rotate_entry = tk.Entry(
            self.rotate_frame, textvariable=self.precise_rotate, width=10)
        self.precise_rotate_entry.grid(row=0, column=2)

        self.rotate_save = tk.Button(
            self.rotate_frame, text="Rotate following pages",
            command=self.rotate_pages)
        self.rotate_save.grid(row=1, column=2)

        # rename
        self.rename_frame = tk.Frame(self.controls_frame)
        self.rename_frame.grid(row=2, column=0)

        self.page_number_frame = tk.Frame(self.rename_frame)
        self.page_number_frame.grid(row=0, column=0, columnspan=2)
        # Edits are picked up through the trace on self.page_number.
        self.page_number_entry = tk.Entry(
            self.page_number_frame, textvariable=self.page_number, width=10)
        self.page_number_entry.grid(row=3, column=2)
        self.page_number_label = tk.Label(
            self.page_number_frame, text="Page Number")
        self.page_number_label.grid(row=3, column=3)
        self.page_number_increase_entry = tk.Entry(
            self.rename_frame, width=2, textvariable=self.page_number_increase)
        self.page_number_increase_entry.grid(row=4, column=0)
        self.page_number_increase_label = tk.Label(
            self.rename_frame, text="+Page Number")
        self.page_number_increase_label.grid(row=4, column=1)

        # Own attribute, so the reference is not overwritten by the
        # renumber button below.
        self.page_insert_button = tk.Button(
            self.rename_frame, text="Page insert", command=self.insert_page)
        self.page_insert_button.grid(row=5, column=0, columnspan=2)

        self.page_number_renumber = tk.Button(
            self.rename_frame, text="Renumber following pages",
            command=self.renumber_pages)
        self.page_number_renumber.grid(row=6, column=0, columnspan=2)

        # crop
        self.crop_frame = tk.Frame(self.controls_frame)
        self.crop_frame.grid(row=3, column=0)
        self.crop_none = tk.Radiobutton(
            self.crop_frame, text="No crop", value=CROP_NONE,
            variable=self.crop_type, command=self.set_crop)
        self.crop_none.grid(row=0, column=1)
        self.crop_none.select()

        self.crop_basic = tk.Radiobutton(
            self.crop_frame, text="Crop", value=CROP_BASIC,
            variable=self.crop_type, command=self.set_crop)
        self.crop_basic.grid(row=1, column=1)

        self.crop_save = tk.Button(
            self.crop_frame, text="Crop following pages",
            command=self.crop_pages)
        self.crop_save.grid(row=1, column=2)

        self.other_controls_frame = tk.Frame(self.controls_frame)
        self.other_controls_frame.grid(row=4, column=0)
        # New series
        self.new_series_button = tk.Button(
            self.other_controls_frame, text="Set new series",
            command=self.set_new_series)
        self.new_series_button.grid(row=0, column=0, sticky="S")

        # output
        self.output_file_button = tk.Button(
            self.other_controls_frame, text="Export",
            command=self.export_files)
        self.output_file_button.grid(row=0, column=1, sticky="S")

        # file list
        self.image_browse = tk.Button(
            self.root, text="Browse", command=self.set_directory)
        self.image_browse.grid(row=0, column=2, columnspan=2, sticky="N")

        self.image_list = tk.Listbox(
            self.root, exportselection=False, selectmode=tk.BROWSE)
        self.image_list.grid(row=1, column=2, sticky="NS")
        self.image_list.bind('<<ListboxSelect>>', self.look_at_image)

        self.image_list_scrollbar = tk.Scrollbar(self.root, orient=tk.VERTICAL)
        self.image_list_scrollbar.grid(row=1, column=3, sticky="NS")

        self.image_list_scrollbar['command'] = self.image_list.yview
        self.image_list['yscrollcommand'] = self.image_list_scrollbar.set

        self.root.rowconfigure(1, weight=1)
        self.root.columnconfigure(0, weight=1)

        self.thumbnail = self.image_viewer.create_image(
            0, 0, anchor=tk.NW, image=None)

        # Crop overlay: one rectangle plus four draggable corner handles,
        # all hidden until a page with CROP_BASIC is shown.
        self.crop_rectangle = self.image_viewer.create_rectangle(0, 0, 0, 0)
        self.image_viewer.itemconfigure(self.crop_rectangle, state=tk.HIDDEN)
        self.crop_points = []
        for index in range(4):
            handle = self.image_viewer.create_oval(
                0, 0, 0, 0, fill="#FFFFFF", activefill="#FF3333")
            self.crop_points.append(handle)
            self.image_viewer.itemconfigure(handle, state=tk.HIDDEN)
            # x=index binds the handle number now, not at call time.
            self.image_viewer.tag_bind(
                handle, "<ButtonPress-1>",
                (lambda e, x=index: self.drag_start(e, x)))
            self.image_viewer.tag_bind(
                handle, "<ButtonRelease-1>", self.drag_stop)
            self.image_viewer.tag_bind(handle, "<B1-Motion>", self.drag)

        # Bound on the listbox as well so its default key handling is
        # replaced by move() (which returns "break").
        for widget in (self.root, self.image_list):
            for key in ('<Up>', '<Down>', '<Prior>', '<Next>'):
                widget.bind(key, self.move)

    def move(self, event):
        """Key binding: step the selection by 1 (arrows) or 10 (page keys)."""
        index = self.get_file_index()
        if index is None:
            return
        last_index = self.image_list.size() - 1
        if event.keysym == 'Up':
            index = max(0, index - 1)
        elif event.keysym == 'Down':
            index = min(last_index, index + 1)
        elif event.keysym == 'Prior':
            index = max(0, index - 10)
        elif event.keysym == 'Next':
            index = min(last_index, index + 10)
        else:
            return
        self._move_listbox_selection(index)
        self.look_at_image()
        return "break"

    def _move_listbox_selection(self, index):
        """Change the selected element in the listbox, which is much harder than it needs to be."""
        self.image_list.selection_clear(0, tk.END)
        self.image_list.selection_set(index)
        self.image_list.activate(index)
        self.image_list.see(index)

    def _series_paths(self, start_index):
        """Yield full paths from start_index up to the next "new series" page.

        The page at start_index is always yielded, even when it is itself
        flagged as the start of a series.
        """
        names = self._filename_list.get()[start_index:]
        for offset, partial_filename in enumerate(names):
            path = os.path.join(self._current_directory, partial_filename)
            if offset and self.operations[path].get(OPERATION_NEW_SERIES):
                break
            yield path

    def _get_crop_points(self, filename):
        """Return the crop corner list of filename, creating the default."""
        return self.operations[filename].setdefault(
            OPERATION_CROP_POINTS, list(self._DEFAULT_CROP_POINTS))

    @staticmethod
    def _crop_box(points, width, height):
        """Return (left, top, right, bottom) enclosing the relative points.

        Taking min/max over every corner keeps the box valid even when the
        handles have been dragged past each other.
        """
        xs = [point[0] for point in points]
        ys = [point[1] for point in points]
        return (min(xs) * width, min(ys) * height,
                max(xs) * width, max(ys) * height)

    def drag_start(self, event, crop_point):
        """Called on mouse-down on crop handles"""
        self.current_drag_item = crop_point

    def drag(self, event):
        """Called when dragging a crop handle around"""
        if self.current_drag_item is None or self.image_thumb is None:
            return
        # 0 3
        # 1 2
        # Each corner shares its x with one neighbour and its y with another.
        modify_x = {0: 1, 1: 0, 2: 3, 3: 2}
        modify_y = {0: 3, 1: 2, 2: 1, 3: 0}
        # Positions are stored as fractions of the thumbnail, clamped to [0, 1].
        x = min(max(0, event.x / self.image_thumb.width()), 1)
        y = min(max(0, event.y / self.image_thumb.height()), 1)

        points = self._get_crop_points(self.current_filename)
        dragged = self.current_drag_item
        points[dragged] = (x, y)
        x_index = modify_x[dragged]
        points[x_index] = (x, points[x_index][1])
        y_index = modify_y[dragged]
        points[y_index] = (points[y_index][0], y)

        self.show_crop_points()

    def drag_stop(self, event):
        """Called on mouse-up on a crop handle; persists the new rectangle."""
        self.current_drag_item = None
        self.save()

    def save(self) -> None:
        """Save self.operations to json file. No-op before a directory is chosen."""
        if self._current_directory is None:
            return
        filename = os.path.join(self._current_directory, self._STATE_FILENAME)
        with open(filename, mode="w") as fp:
            json.dump(self.operations, fp)

    def load(self) -> None:
        """Load self.operations from json file, if one exists."""
        if self._current_directory is None:
            return
        filename = os.path.join(self._current_directory, self._STATE_FILENAME)
        if not os.path.exists(filename):
            return
        with open(filename, mode="r") as fp:
            temp_operations = json.load(fp)  # indirection is so self.operations remains a defaultdict
        self.operations.update(temp_operations)

    def set_directory(self) -> None:
        """Ask for a file and list every .jpg/.png of its directory."""
        file_name = filedialog.askopenfilename()
        if not file_name:
            return
        dir_name = os.path.dirname(file_name)
        self._current_directory = dir_name
        # An index remembered from a previous directory may not exist here.
        self._current_index = None
        files = sorted(
            filename for filename in os.listdir(dir_name)
            if filename[-4:].lower() in (".jpg", ".png"))
        self._filename_list.set(files)
        # Default numbering first; anything stored in the JSON file wins.
        self.renumber_pages(0)
        self.load()
        self.refresh_image_list()
        if files:
            # refresh_image_list() rebuilt the listbox and dropped the
            # selection; select the first page so it is shown right away.
            self._move_listbox_selection(0)
        self.look_at_image()

    def get_file_index(self) -> Optional[int]:
        """Get current index from the listbox.

        Falls back to the last known index when nothing is selected; returns
        None when there never was a selection.
        """
        index = self.image_list.curselection()
        if index != ():
            self._current_index = index[0]
        return self._current_index

    def get_filename(self, index):
        """Get the full path of the file at a given index."""
        filename = self._filename_list.get()[index]
        return os.path.join(self._current_directory, filename)

    def change_page_number(self, *_unused):
        """Store the page number textbox content as the current page's name.

        Always returns True so it stays usable as a Tk validation callback.
        """
        if not self._filename_list.get():
            return True
        index = self.get_file_index()
        if index is None:
            return True
        new_page_number = self.page_number.get()
        if self.image_list.get(index) == new_page_number:
            return True
        filename = self.get_filename(index)
        self.operations[filename][OPERATION_RENAME] = new_page_number
        self.refresh_image_list(index)
        self.image_list.selection_set(index)
        self.image_list.activate(index)
        return True

    def set_new_series(self):
        """Switch value of "New series" of selected page."""
        if not self._filename_list.get():
            return
        index = self.get_file_index()
        if index is None:
            return
        filename = self.get_filename(index)
        current_value = self.operations[filename].get(OPERATION_NEW_SERIES, False)
        self.operations[filename][OPERATION_NEW_SERIES] = not current_value
        self.look_at_image()

    def export_files(self):
        """Start export process into the "output" subdirectory."""
        if self._current_directory is None:
            return
        output = os.path.join(self._current_directory, "output")
        os.makedirs(output, exist_ok=True)
        pending = list(self._filename_list.get())
        self.root.after(1, self.export_file, pending, 0, len(pending))

    def export_file(self, file_list: List, current_index, total_file_count):
        """Export one file, then reschedule itself for the rest of file_list.

        Going through ``root.after`` keeps the GUI responsive and lets the
        export button show progress.  Pages whose name is empty are skipped;
        a name that does not end in .jpg/.png aborts the export.
        """
        if not file_list:
            self.output_file_button.config(text="Export")
            return
        output = os.path.join(self._current_directory, "output")
        base_filename = file_list.pop(0)
        self.output_file_button.config(
            text="{}/{}".format(current_index, total_file_count))
        current_index += 1

        filename = os.path.join(self._current_directory, base_filename)
        settings = self.operations[filename]
        new_name = settings.get(OPERATION_RENAME)
        if not new_name:  # name files "" to skip them
            self.root.after(1, self.export_file, file_list, current_index,
                            total_file_count)
            return
        output_filename = os.path.join(output, new_name)
        if not output_filename.endswith((".jpg", ".png")):
            self.output_file_button.config(text="ERROR")
            print("unusable filename:", output_filename)
            return

        # rotate() returns an independent image, so the source file can be
        # closed as soon as the rotation is done.
        with Image.open(filename) as source_image:
            page = source_image.rotate(
                self.get_rotation_degrees(filename), expand=True)

        # crop
        if settings.get(OPERATION_CROP) == CROP_BASIC:
            page = page.crop(self._crop_box(
                self._get_crop_points(filename), page.width, page.height))

        # Pillow refuses to write e.g. RGBA or palette images as JPEG.
        if output_filename.endswith(".jpg") and page.mode not in self._JPEG_MODES:
            page = page.convert("RGB")
        page.save(output_filename, quality=95)
        self.root.after(10, self.export_file, file_list, current_index,
                        total_file_count)

    def rotate_pages(self):
        """Apply the current rotation to this page and the rest of its series."""
        start_index = self.get_file_index()
        if start_index is None:
            return
        rotation = self.rotation.get()
        precise_rotation = self.precise_rotate.get()
        for filename in self._series_paths(start_index):
            self.operations[filename][OPERATION_ROTATE] = rotation
            self.operations[filename][OPERATION_PRECISE_ROTATE] = precise_rotation
        self.save()

    def get_next_page_number(self, page_number: Text) -> Text:
        """Given a page number, will return the next highest page number.

        The first run of digits is increased by the "+Page Number" value and
        zero-padded to its original width; everything else in the name is
        kept.  Returns "?.jpg" when the name contains no digits.
        """
        match = page_number_regex.search(page_number)
        if match is None:
            return "?.jpg"
        digits = match.group(1)
        if not self.page_number_increase.get().isdigit():
            self.page_number_increase.set("2")
        actual_number = str(int(digits) + int(self.page_number_increase.get()))
        actual_number = actual_number.rjust(len(digits), '0')
        # Splice only the matched run; other digits in the name stay as is.
        return page_number[:match.start(1)] + actual_number + page_number[match.end(1):]

    def renumber_pages(self, start_index: Optional[int] = None) -> None:
        """Renumber pages starting at the given page/current page.

        The first page gets the content of the page number textbox, each
        following page of the same series the next page number.
        """
        if start_index is None:
            start_index = self.get_file_index()
        if start_index is None:
            return
        current_page_number = self.page_number.get()
        for filename in self._series_paths(start_index):
            self.operations[filename][OPERATION_RENAME] = current_page_number
            current_page_number = self.get_next_page_number(current_page_number)
        self.refresh_image_list()
        self._move_listbox_selection(start_index)

    def insert_page(self):
        """For the Page Insert button. Will renumber based on the previous page.

        The selected page keeps its own name; the following pages are
        renumbered as if the selected page carried the previous page's number.
        """
        current_index = self.get_file_index()
        if current_index is None or current_index == 0:
            return
        current_filename = self.get_filename(current_index)
        current_rename = self.operations[current_filename].get(
            OPERATION_RENAME, self.page_number.get())
        previous_filename = self.get_filename(current_index - 1)
        previous_rename = self.operations[previous_filename].get(OPERATION_RENAME)
        if previous_rename is None:
            return
        self.page_number.set(previous_rename)
        self.renumber_pages(current_index)
        # Setting the textbox back also restores the page's own name through
        # the trace on self.page_number.
        self.page_number.set(current_rename)

    def _display_name(self, partial_filename):
        """Return the listbox text of a file: its new name, else its filename."""
        path = os.path.join(self._current_directory, partial_filename)
        # .get() on purpose: do not create defaultdict entries just to display.
        return self.operations.get(path, {}).get(OPERATION_RENAME, partial_filename)

    def refresh_image_list(self, index=None):
        """Refresh listbox after renaming pages.

        With an index only that row is replaced, otherwise the whole list.
        """
        filenames = self._filename_list.get()
        if index is not None:
            self.image_list.delete(index)
            self.image_list.insert(index, self._display_name(filenames[index]))
            return
        self.image_list.delete(0, tk.END)
        new_filenames = [self._display_name(filename) for filename in filenames]
        self.image_list.insert(tk.END, *new_filenames)

    def set_rotation(self, *_unused):
        """Called when changing the rotation of an image."""
        if self.current_filename is None:
            return
        self.operations[self.current_filename][OPERATION_ROTATE] = self.rotation.get()
        self.look_at_image()

    @staticmethod
    def precise_rotate_correct(rotate_value: Text) -> float:
        """Used to make sure the precise rotation value is actually a number.

        Anything that is not a finite number (including "nan" and "inf",
        which float() accepts) counts as 0.
        """
        try:
            value = float(rotate_value)
        except (TypeError, ValueError):
            return 0
        return value if math.isfinite(value) else 0

    def change_precise_rotate(self, *unused):
        """Called when changing the precise rotation value."""
        if self.current_filename is None:
            return
        self.operations[self.current_filename][OPERATION_PRECISE_ROTATE] = \
            self.precise_rotate.get()
        if self._syncing_widgets:
            # look_at_image() wrote the variable itself and is about to
            # redraw; a nested refresh would decode the image twice.
            return
        self.look_at_image()

    def set_crop(self, *unused):
        """Called when changing the crop type in the GUI."""
        if self.current_filename is None:
            return
        crop = self.crop_type.get()
        self.operations[self.current_filename][OPERATION_CROP] = crop
        if crop == CROP_BASIC:
            # Make sure there is a rectangle to show and drag.
            self._get_crop_points(self.current_filename)
        self.look_at_image()

    def crop_pages(self, *unused):
        """Apply the current crop to this page and the rest of its series."""
        start_index = self.get_file_index()
        if start_index is None or self.current_filename is None:
            return
        crop = self.crop_type.get()
        for filename in self._series_paths(start_index):
            self.operations[filename][OPERATION_CROP] = crop
            if crop == CROP_BASIC:
                # Each page gets its own list so later drags stay per page.
                self.operations[filename][OPERATION_CROP_POINTS] = list(
                    self._get_crop_points(self.current_filename))
        self.save()

    def get_rotation(self, filename: Text):
        """Returns the basic rotation for a given filename"""
        return self.operations[filename].setdefault(
            OPERATION_ROTATE, self.default_rotation)

    def get_precise_rotation(self, filename: Text):
        """Returns the precise rotation for a given filename"""
        stored = self.operations[filename].setdefault(OPERATION_PRECISE_ROTATE, 0)
        return self.precise_rotate_correct(stored)

    def get_rotation_degrees(self, filename: Text) -> float:
        """Returns the actual rotation a page needs in degrees.

        This is the coarse rotation minus the precise rotation value.
        """
        lookup = {
            ROTATE_NONE: 0,
            ROTATE_LEFT: 90,
            ROTATE_RIGHT: -90,
            ROTATE_180: 180,
        }
        coarse = lookup.get(self.get_rotation(filename), 0)
        return coarse - self.get_precise_rotation(filename)

    def get_crop(self, filename):
        """Returns the crop mode for a given filename"""
        return self.operations[filename].setdefault(OPERATION_CROP, CROP_NONE)

    def look_at_image(self, *_unused):
        """Refresh the page: sync the controls, redraw the preview and save."""
        if not self._filename_list.get():
            return
        index = self.get_file_index()
        if index is None:
            return
        self.current_filename = self.get_filename(index)
        settings = self.operations[self.current_filename]
        self.rotation.set(settings.get(OPERATION_ROTATE, ROTATE_NONE))
        self._syncing_widgets = True
        try:
            self.precise_rotate.set(
                str(settings.get(OPERATION_PRECISE_ROTATE, "")))
        finally:
            self._syncing_widgets = False

        # A page without a name is offered the number after the one shown.
        self.page_number.set(settings.get(
            OPERATION_RENAME, self.get_next_page_number(self.page_number.get())))

        self.original_image = Image.open(self.current_filename)
        preview = self.original_image.copy()
        preview = preview.rotate(
            self.get_rotation_degrees(self.current_filename), expand=True)
        preview.thumbnail(
            (self.image_viewer.winfo_width(), self.image_viewer.winfo_height()))
        # Keep a reference on self: Tk does not hold one for PhotoImage.
        self.image_thumb = ImageTk.PhotoImage(image=preview)
        self.image_viewer.itemconfigure(self.thumbnail, image=self.image_thumb)

        self.show_crop_points()
        if settings.get(OPERATION_NEW_SERIES):
            new_series_text = OPERATION_NEW_SERIES
        else:
            new_series_text = OPERATION_NOT_NEW_SERIES
        self.new_series_button.config(text=new_series_text)

        self.image_viewer.update()
        self.save()

    def show_crop_points(self):
        """Shows the GUI elements that shows the crop."""
        crop = self.operations[self.current_filename].get(
            OPERATION_CROP, CROP_NONE)
        self.crop_type.set(crop)
        canvas = self.image_viewer
        if crop in (None, CROP_NONE):
            for handle in self.crop_points:
                canvas.itemconfigure(handle, state=tk.HIDDEN)
            canvas.itemconfigure(self.crop_rectangle, state=tk.HIDDEN)

        elif crop == CROP_BASIC:
            if self.image_thumb is None:
                return
            width = self.image_thumb.width()
            height = self.image_thumb.height()
            radius = self._HANDLE_RADIUS
            points = self._get_crop_points(self.current_filename)
            for handle, point in zip(self.crop_points, points):
                center_x = point[0] * width
                center_y = point[1] * height
                canvas.coords(handle,
                              center_x - radius, center_y - radius,
                              center_x + radius, center_y + radius)
                canvas.itemconfigure(handle, state=tk.NORMAL)
            canvas.coords(self.crop_rectangle,
                          *self._crop_box(points, width, height))
            canvas.itemconfigure(self.crop_rectangle, state=tk.NORMAL)




# Build the window and hand control to the Tk event loop; mainloop() blocks
# until the window is closed. There is no __main__ guard, so this also runs
# when the module is imported.
main_window = MainWindow()
main_window.root.mainloop()
