Recent Posts
printf("ho_tari\n");
[지능] Vision AI 기반 컨베이어 벨트 객체 인식 딥러닝 모델 최적화 5일차 본문
2024.12.09
5일차 마지막 날에는 학습시킨 모델을 이용하여 카메라가 실시간으로 PCB를 인식하고 실시간으로 객체를 탐지해 바운딩 박스를 생성하고 모든 클래스를 탐지하면서 그 PCB가 정상품이라면 정상품으로 판단하고 불량품이면 불량품으로 판단하도록 하였다.
import time
import serial
import requests
import numpy as np
import os
import cv2
import json
import logging
import random
import sqlite3
from datetime import datetime
from requests.auth import HTTPBasicAuth
import threading
import queue # 큐를 사용하기 위한 임포트
# -------------------------------
# 1. Logging Configuration
# -------------------------------
logging.basicConfig(
    level=logging.DEBUG,  # DEBUG level for verbose troubleshooting output
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.FileHandler("app.log", encoding='utf-8'),  # Log file (UTF-8 encoding)
        logging.StreamHandler()  # Console output
    ]
)
# -------------------------------
# 2. API Configuration
# -------------------------------
# NOTE(security): credentials are hard-coded in source; prefer environment
# variables or a secrets store for anything beyond a classroom exercise.
URL = "https://suite-endpoint-api-apne2.superb-ai.com/endpoints/f928ebdf-89bd-4611-9c8a-dc9fe2da0049/inference"
ACCESS_KEY = "NbeRrNlMV99thyzJbJJ4k5NqIKm61TQU9iXRfYRx"
USERNAME = "kdt2024_1-26"
# -------------------------------
# 3. Serial Port Initialization
# -------------------------------
SERIAL_PORT = "/dev/ttyACM0"  # Change as needed for your setup
BAUD_RATE = 9600
try:
    ser = serial.Serial(SERIAL_PORT, BAUD_RATE, timeout=1)
    logging.info(f"Connected to serial port {SERIAL_PORT}.")
except serial.SerialException as e:
    logging.error(f"Failed to connect to serial port: {e}")
    exit(1)
# -------------------------------
# 4. Directory Configuration
# -------------------------------
SAVE_DIR = "/home/rokey/vision-ai-inference-practice/save_dir2"
ANNOTATED_DIR = "/home/rokey/vision-ai-inference-practice/annotate_dir2"
NORMAL_DIR = "/home/rokey/vision-ai-inference-practice/annotate_dir2/normal"
DEFECTIVE_DIR = "/home/rokey/vision-ai-inference-practice/annotate_dir2/defective"
DB_PATH = "/home/rokey/vision-ai-inference-practice/products.db"
# Create directories if they don't exist
for directory in [SAVE_DIR, ANNOTATED_DIR, NORMAL_DIR, DEFECTIVE_DIR]:
    os.makedirs(directory, exist_ok=True)
    logging.debug(f"Ensured directory exists: {directory}")
# -------------------------------
# 5. Normal Product Judgment Criteria
# -------------------------------
# Expected per-class detection counts for a board to be judged "normal".
NORMAL_CRITERIA = {
    "BOOTSEL": 1,
    "CHIPSET": 1,
    "OSCILLATOR": 1,
    "USB": 1,
    "HOLE": 4
}
# -------------------------------
# 6. Database Initialization
# -------------------------------
def init_db(db_path=DB_PATH):
    """Initialize the SQLite database and return an open connection.

    Creates the normal_products and defective_products tables when they do
    not already exist. Exits the process on any SQLite error.
    """
    try:
        # check_same_thread=False: the connection is shared across threads.
        conn = sqlite3.connect(db_path, check_same_thread=False)
        cursor = conn.cursor()
        # Both tables share the same schema: auto-id, timestamp, image path.
        for table_name in ("normal_products", "defective_products"):
            cursor.execute(f"""
                CREATE TABLE IF NOT EXISTS {table_name} (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp TEXT NOT NULL,
                    image_path TEXT NOT NULL
                )
            """)
        conn.commit()
        logging.info(f"Database initialized successfully: {db_path}")
        return conn
    except sqlite3.Error as e:
        logging.error(f"Failed to initialize database: {e}")
        exit(1)
# -------------------------------
# 7. Function Definitions
# -------------------------------
def get_latest_frame(cam, num_frames=5):
    """Flush the camera's frame buffer and return the most recent frame.

    Reads num_frames consecutive frames so stale buffered frames are
    discarded. Returns None as soon as any read fails.
    """
    latest = None
    for _ in range(num_frames):
        ok, latest = cam.read()
        if not ok:
            logging.error("Failed to read frame from camera.")
            return None
    logging.debug("Captured latest frame.")
    return latest
def crop_img(img, region=None, alpha=1, beta=5):
    """Crop the image to the inspection window and adjust brightness.

    Args:
        img: BGR image as a numpy array.
        region: Optional dict with keys "x", "y", "width", "height".
            Defaults to the fixed conveyor-camera window used by this project.
        alpha: Contrast gain passed to cv2.convertScaleAbs.
        beta: Brightness offset passed to cv2.convertScaleAbs.

    Returns:
        The cropped, brightness-adjusted image.
    """
    if region is None:
        # Default inspection window (matches the original hard-coded crop).
        region = {"x": 800, "y": 380, "width": 400, "height": 550}
    x, y = region["x"], region["y"]
    w, h = region["width"], region["height"]
    img_cropped = img[y : y + h, x : x + w]
    # Adjust brightness/contrast; convertScaleAbs saturates to uint8.
    img_bright = cv2.convertScaleAbs(img_cropped, alpha=alpha, beta=beta)
    logging.debug("Cropped and adjusted image brightness.")
    return img_bright
def save_img(img, save_dir, timestamp):
    """Save the image as image_<timestamp>.jpg under save_dir.

    Args:
        img: Image to write (numpy array).
        save_dir: Destination directory (must already exist).
        timestamp: Unique timestamp string used in the filename.

    Returns:
        The saved file path, or None if the write failed.
    """
    image_filename = f"image_{timestamp}.jpg"
    image_path = os.path.join(save_dir, image_filename)
    try:
        # cv2.imwrite reports many failures by returning False rather than
        # raising, so the return value must be checked explicitly.
        if not cv2.imwrite(image_path, img):
            logging.error(f"Failed to save image ({image_path}): imwrite returned False")
            return None
        logging.info(f"Image saved: {image_path}")
        return image_path
    except Exception as e:
        logging.error(f"Failed to save image ({image_path}): {e}")
        return None
def inference_request(img: np.array, api_url: str):
    """Send the image to the inference API and return the parsed response.

    Args:
        img (numpy.array): Image to infer.
        api_url (str): API endpoint URL.

    Returns:
        dict: API response JSON, or None on any failure.
    """
    try:
        # Encode image as JPEG; imencode signals failure via its first
        # return value, which must be checked.
        ok, img_encoded = cv2.imencode(".jpg", img)
        if not ok:
            logging.error("Failed to encode image as JPEG.")
            return None
        image_data = img_encoded.tobytes()
        logging.debug("Encoded image as JPEG for API request.")
        # Send API request. A timeout is essential: without one a stalled
        # endpoint would hang this call (and the detection thread) forever.
        response = requests.post(
            url=api_url,
            auth=HTTPBasicAuth(USERNAME, ACCESS_KEY),
            headers={"Content-Type": "image/jpeg"},
            data=image_data,
            timeout=30,
        )
        logging.info(f"API response status code: {response.status_code}")
        if response.status_code == 200:
            response_json = response.json()
            logging.info(f"API response: {json.dumps(response_json, indent=4, ensure_ascii=False)}")
            return response_json
        logging.error(f"Image upload failed. Status code: {response.status_code}")
        logging.error(f"Response content: {response.text}")
        return None
    except requests.exceptions.RequestException as e:
        # Covers connection errors and the new timeout as well.
        logging.error(f"Error during request transmission: {e}")
        return None
    except Exception as e:
        logging.error(f"Unexpected error occurred: {e}")
        return None
def annotate_image(image, response_json, colors, normal_criteria):
    """Annotate the image based on the API response.

    Draws a bounding box and score label per detection, a class-count
    summary panel in the top-left corner, and a Normal/Poor status badge
    in the bottom-right corner.

    Args:
        image (numpy.array): Image to annotate (modified in place).
        response_json (dict): API response JSON.
        colors (dict): Class-name -> BGR color mapping.
        normal_criteria (dict): Expected per-class counts for a normal product.

    Returns:
        tuple: (annotated image, True if the product is judged normal).
    """
    if not response_json:
        logging.warning("No API response for annotation.")
        return image, False
    objects = response_json.get('objects', [])
    if not objects:
        logging.info("No objects detected.")
        return image, False
    class_counts = _draw_detections(image, objects, colors)
    _draw_class_summary(image, class_counts)
    is_normal = _matches_criteria(class_counts, normal_criteria)
    _draw_status_badge(image, is_normal)
    return image, is_normal


def _draw_detections(image, objects, colors):
    """Draw a box and score label for each detection; return class -> count."""
    class_counts = {}
    for obj in objects:
        class_name = obj['class'].strip().upper()
        score = obj.get('score', 0.0)
        box = obj.get('box', [0, 0, 0, 0])
        if len(box) != 4:
            logging.warning(f"Invalid box coordinates: {box}")
            continue
        # Cast to int: coordinates from a JSON API may arrive as floats,
        # but OpenCV drawing functions require integer pixel coordinates.
        x1, y1, x2, y2 = (int(v) for v in box)
        # Count classes (invalid boxes above are intentionally not counted).
        class_counts[class_name] = class_counts.get(class_name, 0) + 1
        color = colors.get(class_name, (0, 255, 0))  # Default color: Green
        cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
        # Label with filled background so it stays readable over the image.
        label = f"{class_name}: {score:.2f}"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        font_thickness = 1
        (text_w, text_h), _ = cv2.getTextSize(label, font, font_scale, font_thickness)
        cv2.rectangle(image, (x1, y1 - text_h - 4), (x1 + text_w, y1), color, -1)
        cv2.putText(image, label, (x1, y1 - 2), font, font_scale, (255, 255, 255), font_thickness, cv2.LINE_AA)
    return class_counts


def _draw_class_summary(image, class_counts):
    """Render 'CLASS: count' lines (descending by count) on a black panel."""
    sorted_class_info = sorted(class_counts.items(), key=lambda item: item[1], reverse=True)
    class_info_lines = [f"{cls}: {count}" for cls, count in sorted_class_info]
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    font_thickness = 2
    line_height = 20
    padding = 10
    # Background rectangle must fit the widest line.
    max_text_w = 0
    for line in class_info_lines:
        (text_w, _), _ = cv2.getTextSize(line, font, font_scale, font_thickness)
        if text_w > max_text_w:
            max_text_w = text_w
    total_text_h = line_height * len(class_info_lines)
    cv2.rectangle(
        image,
        (padding - 5, padding - 5),
        (padding + max_text_w + 10, padding + total_text_h + 5),
        (0, 0, 0),  # Black
        -1
    )
    for idx, line in enumerate(class_info_lines):
        cv2.putText(
            image,
            line,
            (padding, padding + (idx + 1) * line_height),
            font,
            font_scale,
            (255, 255, 255),  # White
            font_thickness,
            cv2.LINE_AA
        )


def _matches_criteria(class_counts, normal_criteria):
    """Return True only if every expected class count matches exactly."""
    for key, value in normal_criteria.items():
        detected = class_counts.get(key.upper(), 0)
        if detected != value:
            logging.warning(f"Mismatch detected - {key.upper()}: Expected {value}, Found {detected}")
            return False
    return True


def _draw_status_badge(image, is_normal):
    """Draw the Normal (green) / Poor (red) badge in the bottom-right corner."""
    if is_normal:
        status_text = "Normal"
        status_color = (0, 255, 0)  # Green
    else:
        status_text = "Poor"
        status_color = (0, 0, 255)  # Red
    font = cv2.FONT_HERSHEY_SIMPLEX
    status_font_scale = 1.0
    status_font_thickness = 2
    (status_text_w, status_text_h), _ = cv2.getTextSize(status_text, font, status_font_scale, status_font_thickness)
    image_height, image_width = image.shape[:2]
    status_padding = 10
    status_x = image_width - status_text_w - status_padding
    status_y = image_height - status_padding
    # Filled colored background behind white status text.
    cv2.rectangle(
        image,
        (status_x - 5, status_y - status_text_h - 5),
        (status_x + status_text_w + 5, status_y + 5),
        status_color,
        -1
    )
    cv2.putText(
        image,
        status_text,
        (status_x, status_y),
        font,
        status_font_scale,
        (255, 255, 255),  # White
        status_font_thickness,
        cv2.LINE_AA
    )
# -------------------------------
# 8. Real-Time Detection Function
# -------------------------------
def real_time_detection(cam, colors, normal_criteria, api_url, image_queue, stop_event):
    """Continuously capture, run inference, annotate, and queue frames.

    Runs until stop_event is set. Annotated frames are pushed to
    image_queue for the main thread to display (OpenCV GUI calls must stay
    on the main thread); frames are dropped when the queue is full.
    """
    while not stop_event.is_set():
        ret, frame = cam.read()
        if not ret:
            logging.error("실시간 객체 탐지: 프레임을 읽을 수 없습니다.")
            # Back off briefly before retrying: without this sleep a
            # persistent camera failure busy-spins the CPU and floods the log.
            time.sleep(0.1)
            continue
        # Crop to the inspection window and adjust brightness.
        img_cropped = crop_img(frame)
        # Remote inference request.
        response = inference_request(img_cropped, api_url)
        if response and 'objects' in response:
            # Assign a random color to any class seen for the first time.
            for obj in response['objects']:
                class_name = obj['class'].strip().upper()
                if class_name not in colors:
                    colors[class_name] = tuple(random.randint(0, 255) for _ in range(3))
        else:
            logging.warning("실시간 객체 탐지: API 응답에 'objects'가 없습니다.")
        # Annotate a copy so the raw crop is left untouched.
        annotated_img, is_normal = annotate_image(img_cropped.copy(), response, colors, normal_criteria)
        # Hand the annotated frame to the main thread; drop it if the
        # display queue is full rather than blocking capture.
        try:
            image_queue.put_nowait(annotated_img)
            logging.debug("Annotated image added to queue.")
        except queue.Full:
            logging.warning("Image queue is full. Skipping frame.")
        # Small delay to keep CPU usage down.
        time.sleep(0.1)
    logging.info("실시간 객체 탐지 스레드 종료.")
# -------------------------------
# 9. Main Function Definition
# -------------------------------
def main():
    """Program entry point.

    Wires together the camera, the serial trigger line, remote inference,
    annotation, database logging, and the background live-view thread.
    """
    # Initialize database; connection/cursor are used for result logging below.
    conn = init_db()
    cursor = conn.cursor()
    # Initialize camera at Full-HD resolution.
    cam = cv2.VideoCapture(0)
    if not cam.isOpened():
        logging.error("Cannot open camera.")
        exit(-1)
    cam.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
    cam.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)
    logging.info("Camera initialized successfully.")
    # Optimize camera settings if needed
    # cam.set(cv2.CAP_PROP_AUTO_EXPOSURE, 0.25) # Adjust based on camera model
    # cam.set(cv2.CAP_PROP_EXPOSURE, -4) # Example value; adjust as needed
    # Class -> BGR color mapping, filled in lazily as classes are detected.
    colors = {}
    # Queue carrying annotated frames from the detection thread to the main
    # thread (OpenCV GUI calls must happen on the main thread).
    image_queue = queue.Queue(maxsize=10)
    # Event used to signal the detection thread to terminate.
    stop_event = threading.Event()
    # Start real-time detection in a separate thread.
    detection_thread = threading.Thread(
        target=real_time_detection,
        args=(cam, colors, NORMAL_CRITERIA, URL, image_queue, stop_event),
        daemon=True  # Daemonize thread to exit when main thread exits
    )
    detection_thread.start()
    logging.info("실시간 객체 탐지 스레드 시작.")
    try:
        while True:
            # Display any frames produced by the detection thread.
            try:
                annotated_img = image_queue.get_nowait()
                cv2.imshow("실시간 객체 탐지", annotated_img)
                logging.debug("Displayed image from queue.")
            except queue.Empty:
                pass
            # Read from the serial port only when data is available.
            if ser.in_waiting > 0:
                data = ser.read(1)  # Read 1 byte
                logging.debug(f"Read data from serial port: {data}")
                if data == b"0":
                    logging.info("Trigger signal received: Starting image capture.")
                    # Record start time for the elapsed-time log below.
                    start_time = time.time()
                    # Delay after the trigger so the object settles on the belt.
                    time.sleep(0.5)  # Wait for 0.5 seconds
                    # Capture latest frame (flushes stale buffered frames).
                    img = get_latest_frame(cam, num_frames=5)
                    if img is None:
                        continue
                    img_cropped = crop_img(img)
                    # Generate unique timestamp-based filename.
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                    image_path = save_img(img_cropped, SAVE_DIR, timestamp)
                    if image_path is None:
                        continue
                    # Send image to API.
                    logging.info(f"Sending image {timestamp} to API.")
                    response = inference_request(img_cropped, URL)
                    if response and 'objects' in response:
                        # Update color mapping for newly seen classes.
                        for obj in response['objects']:
                            class_name = obj['class'].strip().upper()
                            if class_name not in colors:
                                colors[class_name] = tuple(random.randint(0, 255) for _ in range(3))
                    else:
                        logging.warning(f"No 'objects' in API response. Image {timestamp} considered defective.")
                    # Annotate a copy; also yields the normal/defective verdict.
                    annotated_img, is_normal = annotate_image(img_cropped.copy(), response, colors, NORMAL_CRITERIA)
                    # Choose destination directory based on the verdict.
                    if is_normal:
                        annotated_image_path = os.path.join(NORMAL_DIR, f"annotated_normal_{timestamp}.jpg")
                    else:
                        annotated_image_path = os.path.join(DEFECTIVE_DIR, f"annotated_defective_{timestamp}.jpg")
                    # Save annotated image.
                    try:
                        cv2.imwrite(annotated_image_path, annotated_img)
                        logging.info(f"Annotated image saved: {annotated_image_path}")
                    except Exception as e:
                        logging.error(f"Failed to save annotated image ({annotated_image_path}): {e}")
                        continue
                    # Record the verdict in the matching database table.
                    try:
                        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                        if is_normal:
                            cursor.execute("""
                                INSERT INTO normal_products (timestamp, image_path)
                                VALUES (?, ?)
                            """, (current_time, annotated_image_path))
                            logging.info(f"Recorded as Normal product: {current_time}, {annotated_image_path}")
                        else:
                            cursor.execute("""
                                INSERT INTO defective_products (timestamp, image_path)
                                VALUES (?, ?)
                            """, (current_time, annotated_image_path))
                            logging.info(f"Recorded as Defective product: {current_time}, {annotated_image_path}")
                        conn.commit()
                    except sqlite3.Error as e:
                        logging.error(f"Failed to record in database: {e}")
                    # Acknowledge completion back over the serial port.
                    try:
                        ser.write(b"1")
                        logging.info(f"Sent response '1' for image {timestamp}.")
                    except serial.SerialException as e:
                        logging.error(f"Failed to send response via serial port: {e}")
                    # Record end time and log how long the cycle took.
                    end_time = time.time()
                    elapsed_time = end_time - start_time
                    logging.info(f"Elapsed time for processing image {timestamp}: {elapsed_time:.2f} seconds.")
                    # Display image (optional).
                    cv2.imshow("Annotated Image", annotated_img)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        logging.info("Quit signal received: Preparing to exit program.")
                        break
            # Handle GUI events and exit condition.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                logging.info("Quit signal received: Preparing to exit program.")
                break
    except KeyboardInterrupt:
        logging.info("KeyboardInterrupt received: Preparing to exit program.")
    except Exception as e:
        logging.error(f"Error in main loop: {e}")
    finally:
        # Signal the detection thread to stop and wait for it to finish.
        stop_event.set()
        detection_thread.join()
        logging.info("실시간 객체 탐지 스레드 종료 대기.")
        # Close database connection.
        try:
            conn.close()
            logging.info("Database connection closed.")
        except Exception as e:
            logging.error(f"Failed to close database connection: {e}")
        # Release camera / serial / GUI resources.
        try:
            cam.release()
            ser.close()
            cv2.destroyAllWindows()
            logging.info("Camera and serial port resources released.")
        except Exception as e:
            logging.error(f"Error during resource release: {e}")
    logging.info("Program terminated successfully.")


if __name__ == "__main__":
    main()
결과적으로 높은 정확도의 모델을 생성하진 못했지만 그 과정에서 라벨링을 통한 컴퓨터 비전 기술에 대해 제대로 배우고 경험할 수 있었다.
'두산 로보틱스 부트캠프 ROKEY > 실무 프로젝트' 카테고리의 다른 글
[협동] DART 플랫폼(두산로보틱스)을 활용한 협동로봇 동작 운영 실습 2일차 (0) | 2024.12.12 |
---|---|
[협동] DART 플랫폼(두산로보틱스)을 활용한 협동로봇 동작 운영 실습 1일차 (0) | 2024.12.12 |
[지능] Vision AI 기반 컨베이어 벨트 객체 인식 딥러닝 모델 최적화 4일차 (0) | 2024.12.10 |
[지능] Vision AI 기반 컨베이어 벨트 객체 인식 딥러닝 모델 최적화 3일차 (0) | 2024.12.06 |
[지능] Vision AI 기반 컨베이어 벨트 객체 인식 딥러닝 모델 최적화 2일차 (0) | 2024.12.06 |