diff --git a/include/scsi-lowlevel.h b/include/scsi-lowlevel.h index b924cf5..2aa24bb 100644 --- a/include/scsi-lowlevel.h +++ b/include/scsi-lowlevel.h @@ -270,9 +270,8 @@ struct scsi_iovector { struct scsi_iovec *iov; int niov; int nalloc; - - size_t reserved_1; - int reserved_2; + size_t offset; + int consumed; }; struct scsi_task { diff --git a/lib/socket.c b/lib/socket.c index 9844084..8812e41 100644 --- a/lib/socket.c +++ b/lib/socket.c @@ -408,82 +408,89 @@ iscsi_queue_length(struct iscsi_context *iscsi) return i; } -static ssize_t -iscsi_iovector_readv_writev(struct iscsi_context *iscsi, struct scsi_iovector *iovector, uint32_t pos, size_t count, int do_write) +ssize_t +iscsi_iovector_readv_writev(struct iscsi_context *iscsi, struct scsi_iovector *iovector, uint32_t pos, ssize_t max_read, int do_write) { - struct iovec *iovs; - struct iovec *first_iov; - struct iovec *last_iov; - int i, niov; - size_t skip_first, skip_last; - if (iovector->iov == NULL) { errno = EINVAL; return -1; } - niov = iovector->niov; - iovs = alloca(sizeof(struct iovec) * niov); - if (iovs == NULL) { - errno = ENOMEM; + if (pos < iovector->offset) { + /* start over in case we are going backwards */ + iovector->offset = 0; + iovector->consumed = 0; + } + + if (iovector->niov <= iovector->consumed) { + /* someone issued a read/write but did not provide enough user buffers for all the data. + * maybe someone tried to read just 512 bytes off a MMC device? + */ + errno = EINVAL; return -1; } - for(i = 0; i < niov; i++) { - iovs[i].iov_base = iovector->iov[i].iov_base; - iovs[i].iov_len = iovector->iov[i].iov_len; - } - first_iov = &iovs[0]; + /* iov is a pointer to the first iovec to pass */ + struct scsi_iovec *iov = &iovector->iov[iovector->consumed]; + pos -= iovector->offset; - /* Step past iovectors until we find the first iov to send */ - while (pos >= first_iov->iov_len) { - pos -= first_iov->iov_len; - first_iov++; - niov--; - if (niov <= 0) { - /* We ran out of iovectors. */ + /* forward until iov points to the first iov to pass */ + while (pos >= iov->iov_len) { + iovector->offset += iov->iov_len; + iovector->consumed++; + pos -= iov->iov_len; + if (iovector->niov <= iovector->consumed) { errno = EINVAL; return -1; } - } - /* How many bytes in the first iov to skip */ - skip_first = pos; - if (skip_first > 0) { - char *buf = first_iov->iov_base; - first_iov->iov_base = &buf[skip_first]; - first_iov->iov_len -= skip_first; + iov = &iovector->iov[iovector->consumed]; } + /* iov2 is a pointer to the last iovec to pass */ + struct scsi_iovec *iov2 = iov; - /* Find the last iovector to send */ - last_iov = first_iov; - while (count >last_iov->iov_len) { - count -= last_iov->iov_len; - last_iov++; - niov--; - if (niov <= 0) { - /* We ran out of iovectors. */ + int niov=1; /* number of iovectors to pass */ + uint32_t len2 = pos + max_read; /* adjust length of iov2 */ + + /* forward until iov2 points to the last iovec we pass later. it might + happen that we have a lot of iovectors but are limited by max_read */ + while (len2 > iov2->iov_len) { + if (iovector->niov <= iovector->consumed+niov-1) { errno = EINVAL; return -1; } - } - /* How many bytes in the last iov to skip */ - skip_last = last_iov->iov_len - count; - if (skip_last > 0) { - last_iov->iov_len -= skip_last; + niov++; + len2 -= iov2->iov_len; + iov2 = &iovector->iov[iovector->consumed+niov-1]; } + /* we might limit the length of the last iovec we pass to readv/writev + store its orignal length to restore it later */ + size_t _len2 = iov2->iov_len; - /* number of iovectors we will be using */ - niov = last_iov - first_iov + 1; + /* adjust base+len of start iovec and len of last iovec */ + iov2->iov_len = len2; + iov->iov_base = (void*) ((uintptr_t)iov->iov_base + pos); + iov->iov_len -= pos; + ssize_t n; if (do_write) { - count = writev(iscsi->fd, first_iov, niov); + n = writev(iscsi->fd, (struct iovec*) iov, niov); } else { - count = readv(iscsi->fd, first_iov, niov); + n = readv(iscsi->fd, (struct iovec*) iov, niov); } - return count; + /* restore original values */ + iov->iov_base = (void*) ((uintptr_t)iov->iov_base - pos); + iov->iov_len += pos; + iov2->iov_len = _len2; + + if (n > max_read) { + /* we read/write more bytes than expected, this MUST not happen */ + errno = EINVAL; + return -1; + } + return n; } static int