printf("ho_tari\n");

[지능] Vision AI 기반 컨베이어 벨트 객체 인식 딥러닝 모델 최적화 5일차 본문

두산 로보틱스 부트캠프 ROKEY/실무 프로젝트

[지능] Vision AI 기반 컨베이어 벨트 객체 인식 딥러닝 모델 최적화 5일차

호타리 2024. 12. 10. 14:20

2024.12.09

 

5일차 마지막 날에는 학습시킨 모델을 이용하여 카메라가 실시간으로 PCB를 인식하고 실시간으로 객체를 탐지해 바운딩 박스를 생성하고 모든 클래스를 탐지하면서 그 PCB가 정상품이라면 정상품으로 판단하고 불량품이면 불량품으로 판단하도록 하였다.

 

 

 

 

 

 

import time
import serial
import requests
import numpy as np
import os
import cv2
import json
import logging
import random
import sqlite3
from datetime import datetime
from requests.auth import HTTPBasicAuth
import threading
import queue  # 큐를 사용하기 위한 임포트

# -------------------------------
# 1. Logging Configuration
# -------------------------------
# Logs go both to app.log and to the console so the operator can watch the
# line while a persistent record is kept for later inspection.
logging.basicConfig(
    level=logging.DEBUG,  # DEBUG level enables verbose per-frame diagnostics
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.FileHandler("app.log", encoding='utf-8'),  # Log file (UTF-8 encoding)
        logging.StreamHandler()                              # Console output
    ]
)

# -------------------------------
# 2. API Configuration
# -------------------------------
# Superb AI Suite inference endpoint used by inference_request().
# NOTE(review): the access key and username are hard-coded in source; move
# them to environment variables or a secrets store before publishing/deploying.
URL = "https://suite-endpoint-api-apne2.superb-ai.com/endpoints/f928ebdf-89bd-4611-9c8a-dc9fe2da0049/inference"
ACCESS_KEY = "NbeRrNlMV99thyzJbJJ4k5NqIKm61TQU9iXRfYRx"
USERNAME = "kdt2024_1-26"

# -------------------------------
# 3. Serial Port Initialization
# -------------------------------
SERIAL_PORT = "/dev/ttyACM0"  # Adjust to match the connected board if needed
BAUD_RATE = 9600

try:
    ser = serial.Serial(SERIAL_PORT, BAUD_RATE, timeout=1)
    logging.info(f"Connected to serial port {SERIAL_PORT}.")
except serial.SerialException as e:
    logging.error(f"Failed to connect to serial port: {e}")
    # The capture loop is driven entirely by serial trigger bytes, so the
    # program cannot do anything useful without the port: abort at import.
    exit(1)

# -------------------------------
# 4. Directory Configuration
# -------------------------------
# NOTE(review): absolute, machine-specific paths — parameterize or derive
# from a config file if this script is reused on another machine.
SAVE_DIR = "/home/rokey/vision-ai-inference-practice/save_dir2"          # raw cropped captures
ANNOTATED_DIR = "/home/rokey/vision-ai-inference-practice/annotate_dir2" # parent for annotated output
NORMAL_DIR = "/home/rokey/vision-ai-inference-practice/annotate_dir2/normal"        # annotated passes
DEFECTIVE_DIR = "/home/rokey/vision-ai-inference-practice/annotate_dir2/defective"  # annotated failures
DB_PATH = "/home/rokey/vision-ai-inference-practice/products.db"         # SQLite verdict log

# Create directories if they don't exist
for directory in [SAVE_DIR, ANNOTATED_DIR, NORMAL_DIR, DEFECTIVE_DIR]:
    os.makedirs(directory, exist_ok=True)
    logging.debug(f"Ensured directory exists: {directory}")

# -------------------------------
# 5. Normal Product Judgment Criteria
# -------------------------------
# Exact per-class detection counts required for a board to be judged normal.
# annotate_image() compares detected counts against these values with !=,
# so both missing AND extra detections of a listed class mark the item defective.
NORMAL_CRITERIA = {
    "BOOTSEL": 1,
    "CHIPSET": 1,
    "OSCILLATOR": 1,
    "USB": 1,
    "HOLE": 4
}

# -------------------------------
# 6. Database Initialization
# -------------------------------
def init_db(db_path=DB_PATH):
    """Open (creating if necessary) the product-tracking SQLite database.

    Ensures both verdict tables exist — ``normal_products`` and
    ``defective_products``, each holding a timestamp and the annotated image
    path — and returns the open connection. On any SQLite error the process
    is terminated, since nothing downstream can run without the database.

    Args:
        db_path (str): Filesystem path of the SQLite database file.

    Returns:
        sqlite3.Connection: Open connection with the schema in place.
    """
    try:
        # check_same_thread=False lets the connection be shared between the
        # main thread and worker threads.
        connection = sqlite3.connect(db_path, check_same_thread=False)
        db_cursor = connection.cursor()
        # Verdict table for boards judged normal.
        db_cursor.execute("""
            CREATE TABLE IF NOT EXISTS normal_products (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT NOT NULL,
                image_path TEXT NOT NULL
            )
        """)
        # Verdict table for boards judged defective.
        db_cursor.execute("""
            CREATE TABLE IF NOT EXISTS defective_products (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT NOT NULL,
                image_path TEXT NOT NULL
            )
        """)
        connection.commit()
        logging.info(f"Database initialized successfully: {db_path}")
        return connection
    except sqlite3.Error as e:
        logging.error(f"Failed to initialize database: {e}")
        exit(1)

# -------------------------------
# 7. Function Definitions
# -------------------------------

def get_latest_frame(cam, num_frames=5):
    """Flush the camera's internal buffer and return the most recent frame.

    Several consecutive reads discard stale buffered frames so the returned
    image reflects the scene *now*. Returns None if any read fails, or if
    num_frames is 0 (nothing was read).

    Args:
        cam: Capture object exposing ``read() -> (ok, frame)``.
        num_frames (int): How many frames to read; the last one is returned.

    Returns:
        The final frame read, or None on failure.
    """
    latest = None
    remaining = num_frames
    while remaining > 0:
        ok, latest = cam.read()
        if not ok:
            logging.error("Failed to read frame from camera.")
            return None
        remaining -= 1
    logging.debug("Captured latest frame.")
    return latest

def crop_img(img, region=None, alpha=1, beta=5):
    """Crop a region of interest out of *img* and adjust its brightness.

    Generalized from a hard-coded crop: callers that pass no arguments get
    the original behavior (the conveyor-camera ROI with a slight brighten),
    while other rigs can supply their own window and gain.

    Args:
        img (numpy.ndarray): Source frame (BGR, as produced by OpenCV).
        region (dict | None): Crop window with keys "x", "y", "width",
            "height"; defaults to the rig's fixed ROI.
        alpha (float): Contrast multiplier for cv2.convertScaleAbs.
        beta (float): Brightness offset for cv2.convertScaleAbs.

    Returns:
        numpy.ndarray: Cropped, brightness-adjusted image.
    """
    if region is None:
        region = {"x": 800, "y": 380, "width": 400, "height": 550}
    x = region["x"]
    y = region["y"]
    w = region["width"]
    h = region["height"]
    img_cropped = img[y : y + h, x : x + w]
    # convertScaleAbs computes saturate(|alpha * pixel + beta|) per channel.
    img_bright = cv2.convertScaleAbs(img_cropped, alpha=alpha, beta=beta)
    logging.debug("Cropped and adjusted image brightness.")
    return img_bright

def save_img(img, save_dir, timestamp):
    """Save *img* under *save_dir* as ``image_<timestamp>.jpg``.

    Args:
        img (numpy.ndarray): Image to write.
        save_dir (str): Destination directory (assumed to exist).
        timestamp (str): Unique timestamp used in the filename.

    Returns:
        str | None: Written path on success, None on failure.
    """
    image_filename = f"image_{timestamp}.jpg"
    image_path = os.path.join(save_dir, image_filename)
    try:
        # Bug fix: cv2.imwrite reports failure by returning False rather than
        # raising, so the original code logged success even for failed writes.
        if not cv2.imwrite(image_path, img):
            logging.error(f"Failed to save image ({image_path}): imwrite returned False")
            return None
        logging.info(f"Image saved: {image_path}")
        return image_path
    except Exception as e:
        logging.error(f"Failed to save image ({image_path}): {e}")
        return None

def inference_request(img: np.ndarray, api_url: str, timeout: float = 30.0):
    """Send the image to the inference API and return the parsed response.

    Args:
        img (numpy.ndarray): Image to infer (BGR array from OpenCV).
        api_url (str): API endpoint URL.
        timeout (float): Max seconds to wait for the HTTP round-trip.
            Bug fix: the original call had no timeout, so a stalled endpoint
            would hang the calling thread forever. A timeout raises a
            requests Timeout, which the existing RequestException handler
            already converts to a None return.

    Returns:
        dict | None: API response JSON on HTTP 200, otherwise None.
    """
    try:
        # Encode image as JPEG
        _, img_encoded = cv2.imencode(".jpg", img)
        image_data = img_encoded.tobytes()
        logging.debug("Encoded image as JPEG for API request.")

        # Send API request
        response = requests.post(
            url=api_url,
            auth=HTTPBasicAuth(USERNAME, ACCESS_KEY),
            headers={"Content-Type": "image/jpeg"},
            data=image_data,
            timeout=timeout
        )

        # Log response status code
        logging.info(f"API response status code: {response.status_code}")

        if response.status_code == 200:
            response_json = response.json()
            logging.info(f"API response: {json.dumps(response_json, indent=4, ensure_ascii=False)}")
            return response_json
        else:
            logging.error(f"Image upload failed. Status code: {response.status_code}")
            logging.error(f"Response content: {response.text}")
            return None
    except requests.exceptions.RequestException as e:
        logging.error(f"Error during request transmission: {e}")
        return None
    except Exception as e:
        logging.error(f"Unexpected error occurred: {e}")
        return None

def annotate_image(image, response_json, colors, normal_criteria):
    """Annotate the image based on the API response and judge the product.

    Draws, in place on *image*: one bounding box + "CLASS: score" label per
    detected object, a black class-count summary panel in the top-left
    corner, and a green "Normal" / red "Poor" status banner in the
    bottom-right corner. A product is normal only when every class in
    *normal_criteria* was detected EXACTLY its expected number of times.

    Args:
        image (numpy.array): Original image (mutated by the drawing calls)
        response_json (dict): API response JSON
        colors (dict): Class-color mapping (class name -> BGR tuple)
        normal_criteria (dict): Normal judgment criteria (class -> exact count)
    Returns:
        tuple: Annotated image, Normality status (True/False)
    """
    # A missing response or an empty detection list is treated as defective.
    if not response_json:
        logging.warning("No API response for annotation.")
        return image, False

    objects = response_json.get('objects', [])
    if not objects:
        logging.info("No objects detected.")
        return image, False

    class_counts = {}
    # Draw Bounding Boxes and Labels
    for obj in objects:
        # Class names are normalized to uppercase so they match both the
        # criteria keys and the shared color mapping.
        class_name = obj['class'].strip().upper()
        score = obj.get('score', 0.0)
        box = obj.get('box', [0, 0, 0, 0])
        if len(box) != 4:
            logging.warning(f"Invalid box coordinates: {box}")
            continue
        # NOTE(review): assumes the API returns integer pixel coordinates;
        # cv2 drawing calls reject floats — confirm against the endpoint schema.
        x1, y1, x2, y2 = box

        # Count classes
        class_counts[class_name] = class_counts.get(class_name, 0) + 1

        # Draw Bounding Box
        color = colors.get(class_name, (0, 255, 0))  # Default color: Green
        thickness = 2
        cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)

        # Draw Label
        label = f"{class_name}: {score:.2f}"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        font_thickness = 1
        text_size, _ = cv2.getTextSize(label, font, font_scale, font_thickness)
        text_w, text_h = text_size
        # Text background rectangle (filled, sits just above the box's top edge)
        cv2.rectangle(image, (x1, y1 - text_h - 4), (x1 + text_w, y1), color, -1)
        # Put Text
        cv2.putText(image, label, (x1, y1 - 2), font, font_scale, (255, 255, 255), font_thickness, cv2.LINE_AA)

    # Sort class counts in descending order (most frequent class first)
    sorted_class_info = sorted(class_counts.items(), key=lambda item: item[1], reverse=True)
    class_info_lines = [f"{cls}: {count}" for cls, count in sorted_class_info]

    # Draw class information text
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    font_thickness = 2
    line_height = 20
    padding = 10

    # Calculate background rectangle size (widest line determines panel width)
    max_text_w = 0
    for line in class_info_lines:
        text_size, _ = cv2.getTextSize(line, font, font_scale, font_thickness)
        text_w, _ = text_size
        if text_w > max_text_w:
            max_text_w = text_w
    total_text_h = line_height * len(class_info_lines)

    # Draw background rectangle for class info
    cv2.rectangle(
        image,
        (padding - 5, padding - 5),
        (padding + max_text_w + 10, padding + total_text_h + 5),
        (0, 0, 0),  # Black
        -1
    )

    # Put each class info line
    for idx, line in enumerate(class_info_lines):
        text_position = (padding, padding + (idx + 1) * line_height)
        cv2.putText(
            image,
            line,
            text_position,
            font,
            font_scale,
            (255, 255, 255),  # White
            font_thickness,
            cv2.LINE_AA
        )

    # Determine if the product is normal: every criteria class must be
    # detected exactly its expected number of times (!= catches both missing
    # and extra detections); stop at the first mismatch.
    is_normal = True
    for key, value in normal_criteria.items():
        detected = class_counts.get(key.upper(), 0)
        if detected != value:
            logging.warning(f"Mismatch detected - {key.upper()}: Expected {value}, Found {detected}")
            is_normal = False
            break

    # Set status text and color based on judgment
    if is_normal:
        status_text = "Normal"
        status_color = (0, 255, 0)  # Green
    else:
        status_text = "Poor"
        status_color = (0, 0, 255)  # Red

    # Draw status text
    status_font_scale = 1.0
    status_font_thickness = 2
    status_text_size, _ = cv2.getTextSize(status_text, font, status_font_scale, status_font_thickness)
    status_text_w, status_text_h = status_text_size

    # Calculate position for status text at the bottom-right corner
    image_height, image_width = image.shape[:2]
    status_padding = 10
    status_x = image_width - status_text_w - status_padding
    status_y = image_height - status_padding

    # Draw background rectangle for status text
    cv2.rectangle(
        image,
        (status_x - 5, status_y - status_text_h - 5),
        (status_x + status_text_w + 5, status_y + 5),
        status_color,
        -1
    )

    # Put status text (white text)
    cv2.putText(
        image,
        status_text,
        (status_x, status_y),
        font,
        status_font_scale,
        (255, 255, 255),  # White
        status_font_thickness,
        cv2.LINE_AA
    )

    return image, is_normal  # Return two values

# -------------------------------
# 8. Real-Time Detection Function
# -------------------------------
def real_time_detection(cam, colors, normal_criteria, api_url, image_queue, stop_event):
    """Continuously detect and annotate objects until *stop_event* is set.

    Runs in a worker thread: grabs a frame, crops/brightens it, queries the
    inference API, annotates a copy of the crop, and pushes the result onto
    *image_queue* for the main thread to display. Newly seen classes get a
    random color added to the shared *colors* mapping.

    Args:
        cam: OpenCV capture object (shared with the main thread).
        colors (dict): Shared class -> BGR color mapping, updated in place.
        normal_criteria (dict): Expected per-class counts for a normal item.
        api_url (str): Inference endpoint URL.
        image_queue (queue.Queue): Bounded queue of annotated frames.
        stop_event (threading.Event): Signals this loop to terminate.
    """
    while not stop_event.is_set():
        ret, frame = cam.read()
        if not ret:
            logging.error("실시간 객체 탐지: 프레임을 읽을 수 없습니다.")
            # Bug fix: `continue` used to skip the sleep at the bottom of the
            # loop, so a dead/disconnected camera made this thread busy-spin
            # at 100% CPU. Back off briefly before retrying.
            time.sleep(0.1)
            continue

        # Crop the region of interest and adjust brightness.
        img_cropped = crop_img(frame)

        # Ask the inference endpoint for detections.
        response = inference_request(img_cropped, api_url)
        if response and 'objects' in response:
            # Assign a random color to any class not seen before.
            for obj in response['objects']:
                class_name = obj['class'].strip().upper()
                if class_name not in colors:
                    colors[class_name] = tuple(random.randint(0, 255) for _ in range(3))
        else:
            logging.warning("실시간 객체 탐지: API 응답에 'objects'가 없습니다.")

        # Annotate a copy so the raw crop stays untouched.
        annotated_img, is_normal = annotate_image(img_cropped.copy(), response, colors, normal_criteria)

        # Hand the annotated frame to the display (main) thread; drop the
        # frame rather than block when the queue is full.
        try:
            image_queue.put_nowait(annotated_img)
            logging.debug("Annotated image added to queue.")
        except queue.Full:
            logging.warning("Image queue is full. Skipping frame.")

        # Small delay to keep CPU usage down.
        time.sleep(0.1)

    logging.info("실시간 객체 탐지 스레드 종료.")

# -------------------------------
# 9. Main Function Definition
# -------------------------------
def main():
    """Entry point: wire together camera, serial trigger, API, and database.

    Starts a daemon thread running real_time_detection for live preview,
    then loops in the main thread: shows queued annotated frames, and when a
    b"0" trigger byte arrives on the serial port it captures one stabilized
    image, sends it for inference, saves the annotated result into the
    normal/defective directory, records the verdict in SQLite, and answers
    b"1" back over serial. Press 'q' in an OpenCV window to quit.
    """
    # Initialize database
    conn = init_db()
    cursor = conn.cursor()

    # Initialize camera
    cam = cv2.VideoCapture(0)
    if not cam.isOpened():
        logging.error("Cannot open camera.")
        exit(-1)
    cam.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
    cam.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)
    logging.info("Camera initialized successfully.")

    # Optimize camera settings if needed
    # cam.set(cv2.CAP_PROP_AUTO_EXPOSURE, 0.25)  # Adjust based on camera model
    # cam.set(cv2.CAP_PROP_EXPOSURE, -4)         # Example value; adjust as needed

    # Initialize color mapping dictionary (shared with the detection thread)
    colors = {}

    # Create a queue for passing images from detection thread to main thread
    # (bounded so a slow display cannot accumulate unbounded frames)
    image_queue = queue.Queue(maxsize=10)

    # Create a threading event to signal thread termination
    stop_event = threading.Event()

    # Start real-time detection in a separate thread
    # NOTE(review): this thread and the trigger path below both call
    # cam.read() without a lock — confirm the capture driver tolerates
    # interleaved reads from two threads.
    detection_thread = threading.Thread(
        target=real_time_detection,
        args=(cam, colors, NORMAL_CRITERIA, URL, image_queue, stop_event),
        daemon=True  # Daemonize thread to exit when main thread exits
    )
    detection_thread.start()
    logging.info("실시간 객체 탐지 스레드 시작.")

    try:
        while True:
            # Display any images from the image queue (non-blocking)
            try:
                annotated_img = image_queue.get_nowait()
                cv2.imshow("실시간 객체 탐지", annotated_img)
                logging.debug("Displayed image from queue.")
            except queue.Empty:
                pass

            # Read data from serial port only if data is available
            if ser.in_waiting > 0:
                data = ser.read(1)  # Read 1 byte
                logging.debug(f"Read data from serial port: {data}")
                if data == b"0":
                    logging.info("Trigger signal received: Starting image capture.")

                    # Record start time
                    start_time = time.time()

                    # Add delay after trigger signal to stabilize object
                    time.sleep(0.5)  # Wait for 0.5 seconds

                    # Capture latest frame (reads 5 to flush stale buffered frames)
                    img = get_latest_frame(cam, num_frames=5)
                    if img is None:
                        continue
                    img_cropped = crop_img(img)

                    # Generate unique timestamp-based filename
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                    image_path = save_img(img_cropped, SAVE_DIR, timestamp)
                    if image_path is None:
                        continue

                    # Send image to API
                    logging.info(f"Sending image {timestamp} to API.")
                    response = inference_request(img_cropped, URL)
                    if response and 'objects' in response:
                        # Update color mapping for new classes
                        for obj in response['objects']:
                            class_name = obj['class'].strip().upper()
                            if class_name not in colors:
                                colors[class_name] = tuple(random.randint(0, 255) for _ in range(3))
                    else:
                        logging.warning(f"No 'objects' in API response. Image {timestamp} considered defective.")

                    # Annotate image (copy keeps the saved raw crop pristine)
                    annotated_img, is_normal = annotate_image(img_cropped.copy(), response, colors, NORMAL_CRITERIA)

                    # Determine save path based on normality
                    if is_normal:
                        annotated_image_path = os.path.join(NORMAL_DIR, f"annotated_normal_{timestamp}.jpg")
                    else:
                        annotated_image_path = os.path.join(DEFECTIVE_DIR, f"annotated_defective_{timestamp}.jpg")

                    # Save annotated image
                    try:
                        cv2.imwrite(annotated_image_path, annotated_img)
                        logging.info(f"Annotated image saved: {annotated_image_path}")
                    except Exception as e:
                        logging.error(f"Failed to save annotated image ({annotated_image_path}): {e}")
                        continue

                    # Record entry in the database
                    try:
                        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                        if is_normal:
                            cursor.execute("""
                                INSERT INTO normal_products (timestamp, image_path)
                                VALUES (?, ?)
                            """, (current_time, annotated_image_path))
                            logging.info(f"Recorded as Normal product: {current_time}, {annotated_image_path}")
                        else:
                            cursor.execute("""
                                INSERT INTO defective_products (timestamp, image_path)
                                VALUES (?, ?)
                            """, (current_time, annotated_image_path))
                            logging.info(f"Recorded as Defective product: {current_time}, {annotated_image_path}")
                        conn.commit()
                    except sqlite3.Error as e:
                        logging.error(f"Failed to record in database: {e}")

                    # Send response back via serial port (b"1" = processing done)
                    try:
                        ser.write(b"1")
                        logging.info(f"Sent response '1' for image {timestamp}.")
                    except serial.SerialException as e:
                        logging.error(f"Failed to send response via serial port: {e}")

                    # Record end time
                    end_time = time.time()

                    # Calculate elapsed time
                    elapsed_time = end_time - start_time
                    logging.info(f"Elapsed time for processing image {timestamp}: {elapsed_time:.2f} seconds.")

                    # Display image (optional)
                    cv2.imshow("Annotated Image", annotated_img)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        logging.info("Quit signal received: Preparing to exit program.")
                        break

            # Handle GUI events and exit condition
            # (waitKey also pumps the OpenCV window event loop each iteration)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                logging.info("Quit signal received: Preparing to exit program.")
                break

    except KeyboardInterrupt:
        logging.info("KeyboardInterrupt received: Preparing to exit program.")
    except Exception as e:
        logging.error(f"Error in main loop: {e}")
    finally:
        # Signal the detection thread to stop and wait for it to finish
        stop_event.set()
        detection_thread.join()
        logging.info("실시간 객체 탐지 스레드 종료 대기.")

        # Close database connection
        try:
            conn.close()
            logging.info("Database connection closed.")
        except Exception as e:
            logging.error(f"Failed to close database connection: {e}")

        # Release resources
        try:
            cam.release()
            ser.close()
            cv2.destroyAllWindows()
            logging.info("Camera and serial port resources released.")
        except Exception as e:
            logging.error(f"Error during resource release: {e}")
        logging.info("Program terminated successfully.")

if __name__ == "__main__":
    main()

 

라즈베리 피코 검사 시스템 제안서_F_1.pdf
0.86MB

 

 

 

결과적으로 높은 정확도의 모델을 생성하진 못했지만 그 과정에서 라벨링을 통한 컴퓨터 비전 기술에 대해 제대로 배우고 경험할 수 있었다.