class Solution:
    def spiralOrder(self, matrix: List[List[int]]) -> List[int]:
        rows = len(matrix)
        cols = len(matrix[0])
        visited = set()
        spiral = []
        i, j = 0, 0
        spiral.append(matrix[0][0])
        visited.add((0, 0))
        #print((0, 0))
        while True:
            count = 0
            while j+1 < cols and (i, j+1) not in visited:
                j += 1
                spiral.append(matrix[i][j])
                visited.add((i, j))
                count += 1
                #print((i, j))
            while i+1 < rows and (i+1, j) not in visited:
                i += 1
                spiral.append(matrix[i][j])
                visited.add((i, j))
                count += 1
                #print((i, j))
            while j > 0 and (i, j-1) not in visited:
                j -= 1
                spiral.append(matrix[i][j])
                visited.add((i, j))
                count += 1
                #print((i, j))
            while i > 0 and (i-1, j) not in visited:
                i -= 1
                spiral.append(matrix[i][j])
                visited.add((i, j))
                count += 1
                #print((i, j))
            
            if count == 0:
                break
        return spiral