import argparse
import base64
import io
import os
import re
import secrets
import subprocess
import tempfile

import requests
from flask import Flask, render_template_string, request, send_from_directory, session
from PIL import Image, ImageDraw, ImageFont
app = Flask(__name__)
# Flask signs the session cookie with this key. The session carries
# ``uploaded_img_path``, a filesystem path the app later reads and deletes,
# so a key published in the source would let anyone forge that cookie.
# Set CATNOTE_SECRET_KEY to keep sessions valid across restarts; otherwise a
# random key is generated for each server process.
app.secret_key = os.environ.get("CATNOTE_SECRET_KEY") or secrets.token_hex(32)

# Ubuntu font family (regular, bold, italic, bold-italic) used for rendering.
FONT_PATH = "./fonts/ubuntu/Ubuntu-Regular.ttf"
FONT_PATH_BOLD = "./fonts/ubuntu/Ubuntu-Bold.ttf"
FONT_PATH_OBLIQUE = "./fonts/ubuntu/Ubuntu-Italic.ttf"
FONT_PATH_BOLDITALIC = "./fonts/ubuntu/Ubuntu-BoldItalic.ttf"

# Sizes passed to ImageFont.truetype(); each is also the height added per
# rendered row of that kind.
FONT_SIZE = 24
HEADER_SIZE_1 = 56  # "# " headers
HEADER_SIZE_2 = 34  # "## " headers
HEADER_SIZE_3 = 28  # "### " headers
BANNER_FONT_SIZE = 300  # banner mode

# Width in pixels of every rendered note (the printer's print width).
IMAGE_WIDTH = 384
BULLET_CHAR = "• "

# Dithering choices: display label -> value passed to the bleh driver's
# ``-d`` flag (presumably the form submits the value; confirm in the template).
DITHERING_MODES = {
    "Sin dithering": "none",
    "Floyd-Steinberg": "floyd",
    "Bayer 2x2": "bayer2x2",
    "Bayer 4x4": "bayer4x4",
    "Bayer 8x8": "bayer8x8",
    "Bayer 16x16": "bayer16x16",
    "Atkinson": "atkinson",
    "Jarvis-Judice-Ninke": "jjn"
}
# Jinja2 template for the single page, rendered by index() through
# render_template_string(). index() passes: dithering_modes, img (base64 PNG
# preview, or None before the first render), default_md, printed, error,
# current_dithering, current_rotation, current_printmode, current_bannermode.
# NOTE(review): this literal currently contains only text and Jinja tags --
# no HTML elements (no <form>, inputs or preview <img>), so only ``img`` is
# actually used. Confirm the markup was not lost.
HTML_FORM = '''
  CatNote
  
  
  
    
    
      {% if img %}
        
Vista previa
        
      {% else %}
        
Su vista previa aparecerá aquí
      {% endif %}
    
      Referencia rápida de Markdown
      
        - Negrita: **texto**
- Cursiva: *texto*
- Negrita y cursiva: ***texto***
- Encabezado grande: # Título
- Encabezado mediano: ## Título
- Encabezado chico: ### Título
- Lista con viñetas: - Elemento
- Lista numerada: 1. Elemento
- Imágen: 
- Salto de línea: Deje una línea vacía
- Imagen subida: !(img)(usa la imagen cargada abajo)
 
   
  
'''
def remove_uploaded_img():
    """Forget the previously uploaded image and delete its temp file.

    The path is popped from the session, which lives in a client-side
    cookie, so it is only deleted when it names a ``.catnote`` file directly
    inside the system temp directory. That is the only kind of file index()
    stores there. Any other value is dropped from the session without
    touching the filesystem.
    """
    path = session.pop('uploaded_img_path', None)
    if not isinstance(path, str) or not path:
        return
    real_path = os.path.realpath(path)
    in_temp_dir = os.path.dirname(real_path) == os.path.realpath(tempfile.gettempdir())
    if not (in_temp_dir and real_path.endswith(".catnote")):
        return
    try:
        os.remove(real_path)
    except OSError:
        # Best-effort cleanup: the file may already be gone.
        pass
def bleh_image_from_url(url, dithering, mode, timeout=30):
    """Download the image at *url* and convert it with the bleh driver.

    Args:
        url: address taken from a ``![alt](url)`` Markdown line.
        dithering: value for bleh's ``-d`` flag.
        mode: value for bleh's ``-mode`` flag.
        timeout: seconds allowed for connecting / between received bytes
            (requests' timeout semantics); previously the request could
            block forever.

    Returns:
        The PIL image produced by bleh_image_from_bytes().

    Raises:
        requests.RequestException on network or HTTP errors, plus whatever
        bleh_image_from_bytes() raises.
    """
    # NOTE(review): *url* comes straight from user-supplied Markdown, so the
    # server will fetch any address it can reach (including internal hosts).
    # Consider an allow-list if the app is exposed beyond a trusted network.
    with requests.get(url, timeout=timeout) as resp:
        resp.raise_for_status()
        image_bytes = resp.content
    return bleh_image_from_bytes(image_bytes, dithering, mode)
def bleh_image_from_bytes(image_bytes, dithering, mode, timeout=90):
    """Convert encoded image bytes with the ``./bleh`` driver.

    The driver is run as ``./bleh -o - -mode MODE -d DITHERING -``: the image
    is fed on stdin and the converted image is read from stdout. The result
    is converted to 8-bit greyscale ("L"). If it is not IMAGE_WIDTH pixels
    wide, it is scaled to that width with its aspect ratio preserved.

    Args:
        image_bytes: encoded image data.
        dithering: value for bleh's ``-d`` flag.
        mode: value for bleh's ``-mode`` flag.
        timeout: seconds before the driver is killed.

    Raises:
        RuntimeError: the driver exited with a non-zero status.
        subprocess.TimeoutExpired: the driver ran longer than *timeout*.
    """
    result = subprocess.run(
        ["./bleh", "-o", "-", "-mode", f"{mode}", "-d", f"{dithering}", "-"],
        input=image_bytes, capture_output=True, timeout=timeout,
    )
    if result.returncode != 0:
        # errors="replace" so a non-UTF-8 diagnostic cannot mask the failure.
        raise RuntimeError(f"Driver failed: {result.stderr.decode(errors='replace')}")
    img = Image.open(io.BytesIO(result.stdout)).convert("L")
    if img.width != IMAGE_WIDTH:
        # Scale both axes; resizing only the width used to distort the image.
        new_height = max(1, round(img.height * IMAGE_WIDTH / img.width))
        img = img.resize((IMAGE_WIDTH, new_height), Image.LANCZOS)
    return img
def rotate_image_bytes(image_bytes, rotation):
    """Rotate an encoded image clockwise by *rotation* degrees.

    When *rotation* is 0, *image_bytes* is returned untouched. Otherwise the
    image is decoded with PIL, rotated with the canvas enlarged to fit, and
    re-encoded as PNG.
    """
    if rotation != 0:
        # PIL's Image.rotate turns counter-clockwise; negate for clockwise.
        turned = Image.open(io.BytesIO(image_bytes)).rotate(-rotation, expand=True)
        encoded = io.BytesIO()
        turned.save(encoded, format="PNG")
        image_bytes = encoded.getvalue()
    return image_bytes
def parse_line(line):
    """Classify one Markdown source line and return a tagged tuple.

    Possible results, checked in this order:
      ('userimage', None)            -- ``!(img)`` or ``![alt](img)`` at line start
      ('image', url, alt)            -- ``![alt](url)`` at line start
      ('hr', [])                     -- a line consisting of ``---``
      ('header', level, segments)    -- ``#``, ``##`` or ``###`` then a space
      ('bullet', segments)           -- ``-``, ``*`` or ``•`` then a space
      ('ordered', number, segments)  -- ``N.`` then a space
      ('text', segments)             -- anything else
    ``segments`` is what parse_segments() returns for the remaining text.
    """
    refers_to_upload = (re.match(r"!\(img\)", line) is not None
                        or re.match(r"!\[.*\]\(img\)", line) is not None)
    if refers_to_upload:
        return ('userimage', None)

    found = re.match(r"^!\[(.*?)\]\((.+?)\)", line)
    if found is not None:
        alt_text, source = found.groups()
        return ('image', source, alt_text)

    if re.match(r'^\s*---\s*$', line) is not None:
        return ('hr', [])

    found = re.match(r"^(#{1,3}) +(.*)", line)
    if found is not None:
        hashes, title = found.groups()
        return ('header', len(hashes), parse_segments(title))

    found = re.match(r"^\s*([-*\u2022]) +(.*)", line)
    if found is not None:
        return ('bullet', parse_segments(found.group(2)))

    found = re.match(r"^\s*(\d+)\. +(.*)", line)
    if found is not None:
        number, body = found.groups()
        return ('ordered', int(number), parse_segments(body))

    return ('text', parse_segments(line))
def parse_segments(line):
    """Split inline Markdown emphasis into a list of ``(style, text)`` pairs.

    ``style`` is 'text', 'bold' (``**x**``), 'italic' (``*x*``) or
    'bolditalic' (``***x***``). A backslash-escaped asterisk is emitted as a
    literal '*'. Markers are matched non-greedily and do not nest. An empty
    line yields an empty list.
    """
    # Hide escaped asterisks behind a sentinel so the emphasis regexes skip them.
    work = line.replace(r'\*', '\x07')

    # Longest marker first, so "***" is not consumed as "**" + "*". Each
    # emphasised run is wrapped in a pair of control characters.
    for pattern, marker, opener, closer in (
        (r"\*\*\*(.+?)\*\*\*", "***", "\x01", "\x02"),
        (r"\*\*(.+?)\*\*", "**", "\x03", "\x04"),
        (r"\*(.+?)\*", "*", "\x05", "\x06"),
    ):
        for inner in re.findall(pattern, work):
            work = work.replace(f"{marker}{inner}{marker}", f"{opener}{inner}{closer}")

    # opener -> (closer, style)
    spans = {
        '\x01': ('\x02', 'bolditalic'),
        '\x03': ('\x04', 'bold'),
        '\x05': ('\x06', 'italic'),
    }
    pieces = []
    pos = 0
    limit = len(work)
    while pos < limit:
        head = work[pos]
        if head in spans:
            closer, style = spans[head]
            stop = work.find(closer, pos + 1)
            if stop == -1:
                stop = limit  # unterminated run extends to the end
            pieces.append((style, work[pos + 1:stop]))
            pos = stop + 1
        else:
            stop = pos
            while stop < limit and work[stop] not in spans:
                stop += 1
            pieces.append(('text', work[pos:stop]))
            pos = stop

    # Turn the sentinel back into the literal asterisk the user escaped.
    return [(style, text.replace('\x07', '*')) for style, text in pieces]
def font_for_style(style, font, font_bold, font_italic, font_bolditalic):
    """Return the font to draw a segment of the given *style* with.

    'bold' and 'italic' map straight to their fonts. 'bolditalic' uses the
    first truthy of bold-italic, bold, italic, falling back to *font*. Any
    other style uses *font*.
    """
    if style == 'bold':
        return font_bold
    if style == 'italic':
        return font_italic
    if style != 'bolditalic':
        return font
    # Degrade gracefully when the bold-italic face is unavailable (None).
    for candidate in (font_bolditalic, font_bold, font_italic):
        if candidate:
            return candidate
    return font
def wrap_segments(segments, font, font_bold, font_italic, font_bolditalic, max_width, start_x=0):
    """Greedily word-wrap styled segments into lines no wider than *max_width*.

    Yields one list of ``(style, token)`` pairs per output line. Tokens are
    words or whitespace runs; whitespace is kept so that drawing the tokens
    back to back reproduces the spacing. A token's width is the horizontal
    extent of ``getbbox`` in the font font_for_style() picks for its style.
    Every line starts at *start_x*. A token moves to a new line only when
    the current line already holds something, so a token wider than
    *max_width* still gets a (overflowing) line of its own.
    """
    current = []
    cursor = start_x
    for style, text in segments:
        face = font_for_style(style, font, font_bold, font_italic, font_bolditalic)
        for token in re.split(r'(\s+)', text):
            if not token:
                continue
            left, _, right, _ = face.getbbox(token)
            width = right - left
            if current and cursor + width > max_width:
                yield current
                current, cursor = [], start_x
            current.append((style, token))
            cursor += width
    if current:
        yield current
def _text_width(fnt, text):
    """Return the width of *text* in *fnt* (getbbox right minus left)."""
    left, _, right, _ = fnt.getbbox(text)
    return right - left


def _draw_segments(draw, x, y, segments, pick_font):
    """Draw ``(style, text)`` pairs left to right starting at (x, y).

    *pick_font* maps a style name to the font for that segment; the pen
    advances by each segment's measured width.
    """
    for style, text in segments:
        fnt = pick_font(style)
        draw.text((x, y), text, font=fnt, fill=0)
        x += _text_width(fnt, text)


def _render_banner(md, font_banner):
    """Render *md* as a single line of banner text to print lengthwise.

    Emphasis markers are stripped by parse_segments(), but every segment is
    drawn with *font_banner*. The canvas is IMAGE_WIDTH pixels tall and as
    wide as the text. It is then rotated so the text runs along the paper.
    """
    # Every line-break style becomes a space. Previously only "\r\n" was
    # handled, and a bare "\n" made draw.text start a second line.
    md = re.sub(r"\r\n|[\r\n]", " ", md)
    segments = parse_segments(md)
    widths = [_text_width(font_banner, text) for _, text in segments]
    image = Image.new("L", (max(sum(widths), 1), IMAGE_WIDTH), 255)
    draw = ImageDraw.Draw(image)
    # Roughly centre the text across the printer width.
    y = (IMAGE_WIDTH - BANNER_FONT_SIZE) // 2
    x = 0
    for (_, text), width in zip(segments, widths):
        draw.text((x, y), text, font=font_banner, fill=0)
        x += width
    return image.rotate(270, expand=True)


def render(md, dithering, printmode, uploaded_img_bytes=None, bannermode=False):
    """Render Markdown *md* into a greyscale (mode "L") PIL image.

    In normal mode the image is IMAGE_WIDTH pixels wide, with one row per
    wrapped line. It supports the line types parse_line() recognises:
    headers, bullet and numbered lists, horizontal rules, blank lines,
    remote images (``![alt](url)``) and the uploaded image (``!(img)``).
    Images are converted by the bleh driver. Failures are rendered inline as
    a wrapped error message rather than raised.

    In *bannermode* the whole text is drawn as one large line and rotated to
    print lengthwise (see _render_banner).

    Args:
        md: Markdown source.
        dithering: value passed to bleh's ``-d`` flag for images.
        printmode: value passed to bleh's ``-mode`` flag for images.
        uploaded_img_bytes: encoded bytes of the user's upload, or None.
        bannermode: render a banner instead of a note.
    """
    if bannermode:
        return _render_banner(md, ImageFont.truetype(FONT_PATH_BOLD, BANNER_FONT_SIZE))

    font = ImageFont.truetype(FONT_PATH, FONT_SIZE)
    font_bold = ImageFont.truetype(FONT_PATH_BOLD, FONT_SIZE)
    font_italic = ImageFont.truetype(FONT_PATH_OBLIQUE, FONT_SIZE)
    try:
        font_bolditalic = ImageFont.truetype(FONT_PATH_BOLDITALIC, FONT_SIZE)
    except OSError:
        # Optional face: font_for_style() falls back to bold, italic, regular.
        font_bolditalic = None
    # header level -> (font, row height); parse_line() only yields 1-3.
    header_fonts = {
        1: (ImageFont.truetype(FONT_PATH_BOLD, HEADER_SIZE_1), HEADER_SIZE_1),
        2: (ImageFont.truetype(FONT_PATH_BOLD, HEADER_SIZE_2), HEADER_SIZE_2),
        3: (ImageFont.truetype(FONT_PATH_BOLD, HEADER_SIZE_3), HEADER_SIZE_3),
    }

    def body_picker(base):
        # Style -> font for body text whose plain style is drawn with *base*.
        return lambda style: font_for_style(style, base, font_bold, font_italic, font_bolditalic)

    def wrap_body(segments, max_width):
        return wrap_segments(segments, font, font_bold, font_italic, font_bolditalic, max_width)

    # Layout pass. Each entry is one output row or image:
    #   ('blank', [])  ('hr',)  ('image', pil_image)
    #   ('header' | 'text', segments, font, row_height)
    #   ('bullet', segments, font, row_height, first_row, indent)
    #   ('ordered', segments, font, row_height, label, first_row, indent)
    lines_out = []

    def add_message(msg):
        # Wrap error notes like ordinary text so long exception details are
        # not clipped at the right edge (they used to be one unwrapped row).
        for wrapped in wrap_body([('text', msg)], IMAGE_WIDTH):
            lines_out.append(('text', wrapped, font, FONT_SIZE))

    for src_line in md.splitlines():
        if src_line.strip() == '':
            lines_out.append(('blank', []))
            continue
        tag = parse_line(src_line)
        kind = tag[0]
        if kind == 'image':
            try:
                lines_out.append(('image', bleh_image_from_url(tag[1], dithering, printmode)))
            except Exception as e:  # fetch, driver or decode failure: show inline
                add_message(f"[Imagen inválida: {e}]")
        elif kind == 'userimage':
            if uploaded_img_bytes:
                try:
                    lines_out.append(('image', bleh_image_from_bytes(uploaded_img_bytes, dithering, printmode)))
                except Exception as e:
                    add_message("[Error al procesar imagen]")
                    print(f"Image processing error: {e}")
            else:
                add_message("[No se subió una imagen]")
        elif kind == 'hr':
            lines_out.append(('hr',))
        elif kind == 'header':
            font_h, size_h = header_fonts.get(tag[1], header_fonts[3])
            for wrapped in wrap_segments(tag[2], font_h, font_h, font_h, font_h, IMAGE_WIDTH):
                lines_out.append(('header', wrapped, font_h, size_h))
        elif kind == 'bullet':
            bullet_w = _text_width(font, BULLET_CHAR)
            # Continuation rows are indented by the bullet width (hanging indent).
            for i, wrapped in enumerate(wrap_body(tag[1], IMAGE_WIDTH - bullet_w)):
                lines_out.append(('bullet', wrapped, font, FONT_SIZE, i == 0, bullet_w))
        elif kind == 'ordered':
            num_str = f"{tag[1]}. "
            num_w = _text_width(font, num_str)
            for i, wrapped in enumerate(wrap_body(tag[2], IMAGE_WIDTH - num_w)):
                lines_out.append(('ordered', wrapped, font, FONT_SIZE, num_str, i == 0, num_w))
        else:  # plain text
            for wrapped in wrap_body(tag[1], IMAGE_WIDTH):
                lines_out.append(('text', wrapped, font, FONT_SIZE))

    # Drawing starts at y=0, so these 10px end up as slack below the last
    # row rather than as a top margin.
    height = 10
    for item in lines_out:
        kind = item[0]
        if kind in ('header', 'text', 'bullet', 'ordered'):
            height += item[3]
        elif kind == 'hr':
            height += 10
        elif kind == 'blank':
            height += FONT_SIZE
        elif kind == 'image':
            height += item[1].height + 10  # 10px gap below each image

    image = Image.new("L", (IMAGE_WIDTH, height), 255)
    draw = ImageDraw.Draw(image)
    y = 0
    for item in lines_out:
        kind = item[0]
        if kind == 'blank':
            y += FONT_SIZE
        elif kind == 'hr':
            draw.line((0, y + 5, IMAGE_WIDTH, y + 5), fill=0, width=2)
            y += 10
        elif kind == 'header':
            _, segments, fnt, sz = item
            # Headers ignore emphasis: every segment uses the header font.
            _draw_segments(draw, 0, y, segments, lambda style, f=fnt: f)
            y += sz
        elif kind == 'bullet':
            _, segments, bullet_font, sz, show_bullet, bullet_w = item
            if show_bullet:
                draw.text((0, y), BULLET_CHAR, font=bullet_font, fill=0)
            _draw_segments(draw, bullet_w, y, segments, body_picker(bullet_font))
            y += sz
        elif kind == 'ordered':
            _, segments, number_font, sz, num_str, show_num, num_w = item
            if show_num:
                draw.text((0, y), num_str, font=number_font, fill=0)
            _draw_segments(draw, num_w, y, segments, body_picker(number_font))
            y += sz
        elif kind == 'text':
            _, segments, fnt, sz = item
            _draw_segments(draw, 0, y, segments, body_picker(fnt))
            y += sz
        elif kind == 'image':
            img = item[1]
            # Centre images narrower than the paper; wider ones start at x=0.
            image.paste(img, (max((IMAGE_WIDTH - img.width) // 2, 0), y))
            y += img.height + 10
    return image
@app.route("/", methods=["GET", "POST"])
def index():
    """Show the editor; on POST, render the submitted Markdown and maybe print it.

    POST form fields read here:
        md: Markdown source (required).
        userimg: optional image upload.
        rotation: degrees clockwise; 90/180/270 rotate the uploaded image.
        dithering, printmode: passed through to render() and the bleh driver.
        bannermode: any non-empty value enables banner rendering.
        intensity: 0-100 print intensity (default 85).
        print: present when the print button was pressed.

    The chosen options are stored in the session so the form can be
    pre-filled next time. The uploaded image is kept in a temp file whose
    path is remembered in the session.
    """
    img_data = None
    md = ""
    printed = False
    error = None
    if request.method == "POST":
        userimg = request.files.get("userimg")
        uploaded_img_bytes = None
        try:
            rotation = int(request.form.get("rotation", "0"))
        except ValueError:
            # Malformed input used to crash the request with a 500.
            rotation = 0
        if userimg and userimg.filename:
            # Replace any previous upload with the new one.
            remove_uploaded_img()
            with tempfile.NamedTemporaryFile(delete=False, suffix=".catnote") as tmpf:
                userimg.save(tmpf)
                tmpf.flush()
                session['uploaded_img_path'] = tmpf.name
                tmpf.seek(0)
                uploaded_img_bytes = tmpf.read()
        elif 'uploaded_img_path' in session:
            # Re-use the image uploaded on an earlier request.
            path = session['uploaded_img_path']
            real_path = os.path.realpath(path) if isinstance(path, str) and path else ""
            # The path comes from the client's session cookie. Only read it
            # back if it is a .catnote file directly in the temp dir, i.e.
            # one created by the branch above.
            is_own_upload = (
                os.path.dirname(real_path) == os.path.realpath(tempfile.gettempdir())
                and real_path.endswith(".catnote")
                and os.path.isfile(real_path)
            )
            if is_own_upload:
                with open(real_path, "rb") as f:
                    uploaded_img_bytes = f.read()
            else:
                # Missing or untrusted file: forget it.
                session.pop('uploaded_img_path', None)
        if uploaded_img_bytes and rotation in (90, 180, 270):
            try:
                uploaded_img_bytes = rotate_image_bytes(uploaded_img_bytes, rotation)
            except Exception as e:
                error = f"Error rotating image: {e}"
                uploaded_img_bytes = None
        md = request.form["md"]
        dithering = request.form.get("dithering", "floyd")
        printmode = request.form.get("printmode", "1bpp")
        bannermode = bool(request.form.get("bannermode"))
        image = render(md, dithering, printmode, uploaded_img_bytes, bannermode=bannermode)
        session['dithering'] = dithering
        session['printmode'] = printmode
        session['rotation'] = rotation
        session['bannermode'] = bannermode
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        img_data = base64.b64encode(buf.getvalue()).decode()
        try:
            intensity = int(request.form.get("intensity", "85"))
            if not 0 <= intensity <= 100:
                raise ValueError("Intensity must be between 0 and 100")
        except ValueError:
            error = "Intensidad debe ser un número entre 0 y 100"
            intensity = 85
        session['intensity'] = intensity
        # If the print button was pressed, hand the rendered PNG to the driver.
        if "print" in request.form:
            try:
                with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmpfile:
                    image.save(tmpfile, format="PNG")
                    tmpfile.flush()
                    result = subprocess.run([
                        "./bleh",
                        "-mode", f"{printmode}",
                        "-intensity", f"{intensity}",
                        tmpfile.name
                    ], capture_output=True, text=True, timeout=90)
                    if result.returncode != 0:
                        error = f"Printer error: {result.stderr or result.stdout}"
                    else:
                        printed = True
            except Exception as e:
                error = f"Failed to print: {e}"
    return render_template_string(HTML_FORM, dithering_modes=DITHERING_MODES, img=img_data, default_md=md,
        printed=printed, error=error, current_dithering=session.get('dithering', 'floyd'), current_rotation=session.get('rotation', 0), current_printmode=session.get('printmode', '1bpp'), current_bannermode=session.get('bannermode', False))
@app.route('/manifest.json')
def manifest():
    """Serve ``static/manifest.json`` (the web app manifest)."""
    asset = 'manifest.json'
    return send_from_directory('static', asset)
@app.route('/icon512_maskable.png')
def icon_maskable():
    """Serve the 512px maskable app icon from ``static/``."""
    asset = 'icon512_maskable.png'
    return send_from_directory('static', asset)
@app.route('/icon512_rounded.png')
def icon_rounded():
    """Serve the 512px rounded app icon from ``static/``."""
    asset = 'icon512_rounded.png'
    return send_from_directory('static', asset)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='CatNote: Server for Markdown to MXW01 Cat Printer')
    parser.add_argument('-p', '--port', type=int, default=5000, help='Port to run the server on (default: 5000)')
    # The app has no authentication, so allow binding to e.g. 127.0.0.1
    # instead of every interface; the default keeps the old behaviour.
    parser.add_argument('--host', default='0.0.0.0',
                        help='Interface to bind to (default: 0.0.0.0, all interfaces)')
    args = parser.parse_args()
    # app.run() starts Flask's built-in development server.
    app.run(host=args.host, port=args.port)