51 changes: 51 additions & 0 deletions cacheflow/logger.py
@@ -0,0 +1,51 @@
# Adapted from https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py

import logging
import sys


_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"


class NewLineFormatter(logging.Formatter):
"""Adds logging prefix to newlines to align multi-line messages."""

def __init__(self, fmt, datefmt=None):
logging.Formatter.__init__(self, fmt, datefmt)

def format(self, record):
msg = logging.Formatter.format(self, record)
if record.message != "":
parts = msg.split(record.message)
msg = msg.replace("\n", "\r\n" + parts[0])
return msg


_root_logger = logging.getLogger("cacheflow")
_default_handler = None


def _setup_logger():
_root_logger.setLevel(logging.DEBUG)
global _default_handler
if _default_handler is None:
_default_handler = logging.StreamHandler(sys.stdout)
_default_handler.flush = sys.stdout.flush # type: ignore
_default_handler.setLevel(logging.INFO)
_root_logger.addHandler(_default_handler)
fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT)
_default_handler.setFormatter(fmt)
# Setting this will avoid the message
# being propagated to the parent logger.
_root_logger.propagate = False


# The logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
_setup_logger()


def init_logger(name: str):
return logging.getLogger(name)
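A minimal usage sketch (not part of the diff): modules call init_logger(__name__), and since every cacheflow module name starts with "cacheflow.", the returned logger is a child of the configured "cacheflow" root logger and inherits its stdout handler. NewLineFormatter then repeats the log prefix after each newline, so multi-line messages stay aligned:

# Hypothetical caller, e.g. cacheflow/example.py; the filename, line number,
# and timestamp in the sample output are illustrative.
from cacheflow.logger import init_logger

logger = init_logger(__name__)  # __name__ == "cacheflow.example"
logger.info("first line\nsecond line")
# INFO 05-01 12:00:00 example.py:6] first line
# INFO 05-01 12:00:00 example.py:6] second line
logger.debug("not shown")  # the handler level is INFO, so DEBUG records are dropped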
19 changes: 16 additions & 3 deletions cacheflow/master/server.py
@@ -8,6 +8,7 @@
except ImportError:
ray = None

+from cacheflow.logger import init_logger
from cacheflow.master.scheduler import Scheduler
from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.models import get_memory_analyzer
@@ -17,6 +18,9 @@
from cacheflow.utils import get_gpu_memory, get_cpu_memory


+logger = init_logger(__name__)


class Server:
def __init__(
self,
@@ -42,6 +46,17 @@ def __init__(
collect_stats: bool = False,
do_memory_analysis: bool = False,
):
+        logger.info(
+            "Initializing a server with config: "
+            f"model={model!r}, "
+            f"dtype={dtype}, "
+            f"use_dummy_weights={use_dummy_weights}, "
+            f"cache_dir={cache_dir}, "
+            f"use_np_cache={use_np_cache}, "
+            f"tensor_parallel_size={tensor_parallel_size}, "
+            f"block_size={block_size}, "
+            f"seed={seed}"
+        )
self.num_nodes = num_nodes
self.num_devices_per_node = num_devices_per_node
self.world_size = pipeline_parallel_size * tensor_parallel_size
@@ -61,9 +76,7 @@ def __init__(
self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks(
max_num_batched_tokens=max_num_batched_tokens)
self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks(
-            swap_space=swap_space)
-        print(f'# GPU blocks: {self.num_gpu_blocks}, '
-              f'# CPU blocks: {self.num_cpu_blocks}')
+            swap_space_gib=swap_space)

# Create a controller for each pipeline stage.
self.controllers: List[Controller] = []
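One design note, offered as an assumption rather than something stated in the PR: building the message eagerly with f-strings is fine here because it runs once per Server construction. On hot paths, the stdlib's lazy %-style arguments defer interpolation until a handler actually accepts the record:

# Sketch only; `model` and `dtype` are made-up stand-ins.
import logging

logger = logging.getLogger("cacheflow.master.server")
model, dtype = "facebook/opt-125m", "half"
# Interpolated only if the record passes the level/handler checks:
logger.info("Initializing a server with config: model=%r, dtype=%s", model, dtype)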
8 changes: 6 additions & 2 deletions cacheflow/master/simple_frontend.py
@@ -1,13 +1,17 @@
import time
-from typing import List, Optional, Set, Tuple
+from typing import List, Optional, Tuple

from transformers import AutoTokenizer

+from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.utils import Counter


+logger = init_logger(__name__)


class SimpleFrontend:

def __init__(
@@ -66,4 +70,4 @@ def print_response(
token_ids = seq.get_token_ids()
output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
output = output.strip()
-        print(f'Seq {seq.seq_id}: {output!r}')
+        logger.info(f"Seq {seq.seq_id}: {output!r}")
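For illustration (the sequence id and text are made up), the same response line before and after this change:

Seq 0: 'Paris is the capital of France.'
INFO 05-01 12:00:00 simple_frontend.py:73] Seq 0: 'Paris is the capital of France.'

The first form went straight to stdout via print; the second arrives through the "cacheflow" handler and picks up the standard level/timestamp/location prefix.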
18 changes: 11 additions & 7 deletions cacheflow/models/memory_analyzer.py
@@ -1,8 +1,12 @@
import torch
from transformers import AutoConfig

+from cacheflow.logger import init_logger
from cacheflow.models.utils import get_dtype_size


+logger = init_logger(__name__)

_GiB = 1 << 30


@@ -23,20 +27,20 @@ def get_cache_block_size(self) -> int:

def get_max_num_cpu_blocks(
self,
-        swap_space: int,
+        swap_space_gib: int,
) -> int:
-        swap_space = swap_space * _GiB
+        swap_space = swap_space_gib * _GiB
cpu_memory = self.cpu_memory
if swap_space > 0.8 * cpu_memory:
-            raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
+            raise ValueError(f'The swap space ({swap_space_gib:.2f} GiB) '
'takes more than 80% of the available memory '
f'({cpu_memory / _GiB:.2f} GiB).'
'Please check the swap space size.')
if swap_space > 0.5 * cpu_memory:
-            print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
-                  'takes more than 50% of the available memory '
-                  f'({cpu_memory / _GiB:.2f} GiB).'
-                  'This may slow the system performance.')
+            logger.info(f'WARNING: The swap space ({swap_space_gib:.2f} GiB) '
+                        'takes more than 50% of the available memory '
+                        f'({cpu_memory / _GiB:.2f} GiB).'
+                        'This may slow the system performance.')
max_num_blocks = swap_space // self.get_cache_block_size()
return max_num_blocks

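To make the swap-space checks concrete, a worked sketch with made-up numbers (the host memory and cache block size below are assumptions, not values from this PR):

_GiB = 1 << 30
swap_space_gib = 20
cpu_memory = 64 * _GiB            # assume a 64 GiB host
swap_space = swap_space_gib * _GiB

swap_space > 0.8 * cpu_memory     # 20 GiB vs 51.2 GiB -> False, no ValueError
swap_space > 0.5 * cpu_memory     # 20 GiB vs 32 GiB   -> False, no warning

cache_block_size = 8 * 1024 * 1024        # hypothetical 8 MiB per KV-cache block
max_num_blocks = swap_space // cache_block_size
print(max_num_blocks)                      # 20 GiB / 8 MiB = 2560 CPU cache blocks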
2 changes: 1 addition & 1 deletion simple_server.py
@@ -1,11 +1,11 @@
import argparse
from typing import List

from cacheflow.master.server import (
add_server_arguments, process_server_arguments,
init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams


def main(args: argparse.Namespace):
server, frontend = init_local_server_and_frontend_with_arguments(args)
# Test the following inputs.