Skip to content

vllm.logprobs

LogprobsOnePosition module-attribute

LogprobsOnePosition = dict[int, Logprob]

PromptLogprobs module-attribute

PromptLogprobs = (
    FlatLogprobs | list[LogprobsOnePosition | None]
)

SampleLogprobs module-attribute

FlatLogprobs dataclass

Bases: MutableSequence[LogprobsOnePosition]

Flat logprobs of a request into multiple primitive type lists.

Compared to list[dict[int, Logprob]], this data structure reduces GC overhead significantly, as it flattens logprob information for all positions and ranks into multiple primitive type lists (i.e. logprobs, token_ids, ranks per token_ids, decoded_tokens). So regardless of the sequence length and top_logprobs setup, FlatLogprobs only introduces a constant number of objects.

As each position might contain a different number of ranks, the per-position start/end indices (start_indices / end_indices) are used to access the logprob ranges for different positions.

NOTE: To reduce the migration overhead and improve backward compatibility, we support the key Sequence APIs of list, so it can act as a list[LogprobsOnePosition]

Source code in vllm/logprobs.py
@dataclass
class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
    """
    Flat logprobs of a request stored in multiple primitive type lists.

    Compared to list[dict[int, Logprob]], this data structure reduces GC
    overhead significantly, as it flattens logprob information for
    all positions and ranks into multiple primitive type lists (i.e.
    logprobs, token_ids, ranks per token_ids, decoded_tokens).
    So regardless of the sequence length and top_logprobs setup,
    FlatLogprobs only introduces a constant number of objects.

    As each position might contain a different number of ranks,
    start_indices / end_indices are used to access the logprob ranges
    for different positions.

    NOTE: To reduce the migration overhead and improve backward compatibility,
    we support the key Sequence APIs of list, so it can act as a
    list[LogprobsOnePosition]
    """

    # Start / end indices to indicate the range of logprobs for each position.
    start_indices: list[int] = field(default_factory=list)
    end_indices: list[int] = field(default_factory=list)

    # Flattened Logprob information for (each position, rank).
    # For position <i>, the logprobs are ranged
    # from self.start_indices[i] to self.end_indices[i] (exclusive).
    token_ids: list[int] = field(default_factory=list)
    logprobs: list[float] = field(default_factory=list)
    ranks: list[int | None] = field(default_factory=list)
    decoded_tokens: list[str | None] = field(default_factory=list)

    def append(self, logprobs_one_position: LogprobsOnePosition | None) -> None:
        """Appends the container with logprobs for the next position.

        A None (or empty) mapping records an empty range for the position,
        e.g. the first prompt token, which has no logprob.
        """
        self.start_indices.append(len(self.logprobs))
        if logprobs_one_position:
            for token_id, logprob in logprobs_one_position.items():
                self.token_ids.append(token_id)
                self.logprobs.append(logprob.logprob)
                self.ranks.append(logprob.rank)
                self.decoded_tokens.append(logprob.decoded_token)
        self.end_indices.append(len(self.logprobs))

    def append_fast(
        self,
        token_ids: list[int],
        logprobs: list[float],
        ranks: Iterable[int],
        decoded_tokens: Iterable[str | None],
    ) -> None:
        """
        Appends logprobs for the next position without creating
        the intermediate logprob dictionary.

        The four iterables must be aligned element-wise; iteration stops
        at the shortest one.
        """
        self.start_indices.append(len(self.logprobs))
        for token_id, logprob, rank, decoded_token in zip(
            token_ids, logprobs, ranks, decoded_tokens
        ):
            self.token_ids.append(token_id)
            self.logprobs.append(logprob)
            self.ranks.append(rank)
            self.decoded_tokens.append(decoded_token)
        self.end_indices.append(len(self.logprobs))

    def extend(
        self, logprobs_multi_positions: Iterable[LogprobsOnePosition | None]
    ) -> None:
        """Extends the container with logprobs for the next multiple positions"""
        for logprobs_one_position in logprobs_multi_positions:
            self.append(logprobs_one_position)

    def __len__(self) -> int:
        """Gets number of positions stored in the container"""
        return len(self.start_indices)

    @overload
    def __getitem__(self, position: int) -> LogprobsOnePosition: ...

    @overload
    def __getitem__(self, s: slice, /) -> "FlatLogprobs": ...

    def __getitem__(self, index: int | slice):
        """Extracts logprobs of a given position or slice"""
        if isinstance(index, int):
            # Rebuild the per-position dict view on demand.
            return {
                self.token_ids[i]: Logprob(
                    logprob=self.logprobs[i],
                    rank=self.ranks[i],
                    decoded_token=self.decoded_tokens[i],
                )
                for i in range(self.start_indices[index], self.end_indices[index])
            }
        elif isinstance(index, slice):
            start_indices = self.start_indices[index]
            end_indices = self.end_indices[index]
            if not start_indices:
                # BUGFIX: an empty slice (e.g. flat[2:2]) previously raised
                # IndexError on start_indices[...][0]; mirror list slicing
                # semantics and return an empty container instead.
                return FlatLogprobs()
            min_index = start_indices[0]
            max_index = end_indices[-1]
            return FlatLogprobs(
                # Shift updated start_indices and end_indices to
                # be 0-indexed
                start_indices=[i - min_index for i in start_indices],
                end_indices=[i - min_index for i in end_indices],
                token_ids=self.token_ids[min_index:max_index],
                logprobs=self.logprobs[min_index:max_index],
                ranks=self.ranks[min_index:max_index],
                decoded_tokens=self.decoded_tokens[min_index:max_index],
            )
        else:
            raise TypeError(f"Invalid index type: {type(index)}")

    def __setitem__(self, item, value) -> None:
        # Append-only container: in-place updates are rejected.
        raise TypeError("Cannot set logprobs in FlatLogprobs")

    def __delitem__(self, item) -> None:
        # Append-only container: deletion is rejected.
        raise TypeError("Cannot delete logprobs from FlatLogprobs")

    def insert(self, index, value) -> None:
        # Signature now matches MutableSequence.insert(index, value);
        # insertion is still rejected, as the container is append-only.
        raise TypeError("Cannot insert logprobs to FlatLogprobs")

    def __iter__(self) -> Iterator[LogprobsOnePosition]:
        """
        Iterates the container and yields LogprobsOnePosition for
        each position.
        """
        for i in range(len(self.start_indices)):
            yield self[i]

decoded_tokens class-attribute instance-attribute

decoded_tokens: list[str | None] = field(
    default_factory=list
)

end_indices class-attribute instance-attribute

end_indices: list[int] = field(default_factory=list)

logprobs class-attribute instance-attribute

logprobs: list[float] = field(default_factory=list)

ranks class-attribute instance-attribute

ranks: list[int | None] = field(default_factory=list)

start_indices class-attribute instance-attribute

start_indices: list[int] = field(default_factory=list)

token_ids class-attribute instance-attribute

token_ids: list[int] = field(default_factory=list)

__delitem__

__delitem__(item) -> None
Source code in vllm/logprobs.py
def __delitem__(self, item) -> None:
    raise TypeError("Cannot delete logprobs from FlatLogprobs")

__getitem__

__getitem__(position: int) -> LogprobsOnePosition
__getitem__(s: slice) -> FlatLogprobs
__getitem__(index: int | slice)

Extracts logprobs of a given position or slice

Source code in vllm/logprobs.py
def __getitem__(self, index: int | slice):
    """Extracts logprobs of a given position or slice"""
    if isinstance(index, int):
        return {
            self.token_ids[i]: Logprob(
                logprob=self.logprobs[i],
                rank=self.ranks[i],
                decoded_token=self.decoded_tokens[i],
            )
            for i in range(self.start_indices[index], self.end_indices[index])
        }
    elif isinstance(index, slice):
        min_index = self.start_indices[index][0]
        max_index = self.end_indices[index][-1]
        return FlatLogprobs(
            # Shift updated start_indices and end_indices to
            # be 0-indexed
            start_indices=[i - min_index for i in self.start_indices[index]],
            end_indices=[i - min_index for i in self.end_indices[index]],
            token_ids=self.token_ids[min_index:max_index],
            logprobs=self.logprobs[min_index:max_index],
            ranks=self.ranks[min_index:max_index],
            decoded_tokens=self.decoded_tokens[min_index:max_index],
        )
    else:
        raise TypeError(f"Invalid index type: {type(index)}")

__init__

__init__(
    start_indices: list[int] = list(),
    end_indices: list[int] = list(),
    token_ids: list[int] = list(),
    logprobs: list[float] = list(),
    ranks: list[int | None] = list(),
    decoded_tokens: list[str | None] = list(),
) -> None

__iter__

Iterates the container and yields LogprobsOnePosition for each position.

Source code in vllm/logprobs.py
def __iter__(self) -> Iterator[LogprobsOnePosition]:
    """
    Iterates the container and yields LogprobsOnePosition for
    each position.
    """
    for i in range(0, len(self.start_indices)):
        yield self.__getitem__(i)

__len__

__len__() -> int

Gets number of positions stored in the container

Source code in vllm/logprobs.py
def __len__(self) -> int:
    """Gets number of positions stored in the container"""
    return len(self.start_indices)

__setitem__

__setitem__(item, value) -> None
Source code in vllm/logprobs.py
def __setitem__(self, item, value) -> None:
    raise TypeError("Cannot set logprobs in FlatLogprobs")

append

append(
    logprobs_one_position: LogprobsOnePosition | None,
) -> None

Appends the container with logprobs for the next position

Source code in vllm/logprobs.py
def append(self, logprobs_one_position: LogprobsOnePosition | None) -> None:
    """Appends the container with logprobs for the next position"""
    self.start_indices.append(len(self.logprobs))
    if logprobs_one_position:
        for token_id, logprob in logprobs_one_position.items():
            self.token_ids.append(token_id)
            self.logprobs.append(logprob.logprob)
            self.ranks.append(logprob.rank)
            self.decoded_tokens.append(logprob.decoded_token)
    self.end_indices.append(len(self.logprobs))

append_fast

append_fast(
    token_ids: list[int],
    logprobs: list[float],
    ranks: chain[int],
    decoded_tokens: Iterable[str | None],
) -> None

Appends logprobs for the next position without creating the intermediate logprob dictionary.

Source code in vllm/logprobs.py
def append_fast(
    self,
    token_ids: list[int],
    logprobs: list[float],
    ranks: itertools.chain[int],
    decoded_tokens: Iterable[str | None],
) -> None:
    """
    Appends logprobs for the next position without creating
    the intermediate logprob dictionary.
    """
    self.start_indices.append(len(self.logprobs))
    for token_id, logprob, rank, decoded_token in zip(
        token_ids, logprobs, ranks, decoded_tokens
    ):
        self.token_ids.append(token_id)
        self.logprobs.append(logprob)
        self.ranks.append(rank)
        self.decoded_tokens.append(decoded_token)
    self.end_indices.append(len(self.logprobs))

extend

extend(logprobs_multi_positions) -> None

Extends the container with logprobs for the next multiple positions

Source code in vllm/logprobs.py
def extend(self, logprobs_multi_positions) -> None:
    """Extends the container with logprobs for the next multiple positions"""
    for logprobs_one_position in logprobs_multi_positions:
        self.append(logprobs_one_position)

insert

insert(item) -> None
Source code in vllm/logprobs.py
def insert(self, item) -> None:
    raise TypeError("Cannot insert logprobs to FlatLogprobs")

Logprob dataclass

Infos for supporting OpenAI compatible logprobs and token ranks.

Attributes:

Name Type Description
logprob float

The logprob of chosen token

rank int | None

The vocab rank of chosen token (>=1)

decoded_token str | None

The decoded string of the chosen token, if available

Source code in vllm/logprobs.py
@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token
        rank: The vocab rank of the chosen token (>=1); None if unranked
        decoded_token: The decoded string of the chosen token, if available
    """

    logprob: float
    rank: int | None = None
    decoded_token: str | None = None

decoded_token class-attribute instance-attribute

decoded_token: str | None = None

logprob instance-attribute

logprob: float

rank class-attribute instance-attribute

rank: int | None = None

__init__

__init__(
    logprob: float,
    rank: int | None = None,
    decoded_token: str | None = None,
) -> None

append_logprobs_for_next_position

append_logprobs_for_next_position(
    request_logprobs: PromptLogprobs | SampleLogprobs,
    token_ids: list[int],
    logprobs: list[float],
    decoded_tokens: Iterable[str | None],
    rank: int,
    num_logprobs: int,
) -> None

Appends logprobs for the next position

Source code in vllm/logprobs.py
def append_logprobs_for_next_position(
    request_logprobs: PromptLogprobs | SampleLogprobs,
    token_ids: list[int],
    logprobs: list[float],
    decoded_tokens: Iterable[str | None],
    rank: int,
    num_logprobs: int,
) -> None:
    """Appends logprobs for the next position"""
    # num_logprobs == -1 means "keep every provided logprob".
    effective_k = len(logprobs) if num_logprobs == -1 else num_logprobs
    # The sampled token's rank comes first, followed by the top-k ranks.
    # We do not need a special case for the sampled token being in the
    # topk, since inserting duplicated data into a dictionary twice is
    # the same as doing it once.
    all_ranks = itertools.chain((rank,), range(1, effective_k + 1))

    if isinstance(request_logprobs, FlatLogprobs):
        # Fast path: write straight into the flat primitive lists.
        request_logprobs.append_fast(token_ids, logprobs, all_ranks, decoded_tokens)
        return

    # Legacy path: materialize a per-position dict of Logprob entries.
    position_entry: LogprobsOnePosition = {}
    for token_id, logprob, token_rank, token in zip(
        token_ids, logprobs, all_ranks, decoded_tokens
    ):
        position_entry[token_id] = Logprob(
            logprob=logprob,
            rank=token_rank,
            decoded_token=token,
        )
    request_logprobs.append(position_entry)

create_prompt_logprobs

create_prompt_logprobs() -> PromptLogprobs

Creates a container to store prompt logprobs for a request

Source code in vllm/logprobs.py
def create_prompt_logprobs() -> PromptLogprobs:
    """Creates a container to store prompt logprobs for a request"""
    if envs.VLLM_FLAT_LOGPROBS:
        container: PromptLogprobs = FlatLogprobs()
    else:
        container = []
    # NOTE: the first prompt token has no preceding context, so its
    # logprob is recorded as None.
    container.append(None)
    return container

create_sample_logprobs

create_sample_logprobs() -> SampleLogprobs

Creates a container to store decode logprobs for a request

Source code in vllm/logprobs.py
def create_sample_logprobs() -> SampleLogprobs:
    """Creates a container to store decode logprobs for a request"""
    # Pick the flat representation when the env toggle is on; otherwise
    # fall back to the legacy list-of-dicts container.
    if envs.VLLM_FLAT_LOGPROBS:
        return FlatLogprobs()
    return []